In [1]:
import ast
import sys
from pathlib import Path


class ClassExtractor(ast.NodeVisitor):
    def __init__(self, target_class_name):
        self.target_class = target_class_name
        self.found = False
        self.class_info = {}

    def visit_ClassDef(self, node):
        if node.name == self.target_class:
            self.found = True
            bases = [self._get_base_name(base) for base in node.bases]
            attrs = []
            methods = []

            for item in node.body:
                if isinstance(item, ast.FunctionDef):
                    methods.append(item.name)
                elif isinstance(item, ast.Assign):
                    for target in item.targets:
                        if isinstance(target, ast.Name):
                            attrs.append(target.id)

            self.class_info = {
                "name": node.name,
                "bases": bases,
                "attributes": attrs,
                "methods": methods
            }

    def _get_base_name(self, base):
        if isinstance(base, ast.Name):
            return base.id
        elif isinstance(base, ast.Attribute):
            return base.attr
        else:
            return "object"

def generate_puml(class_info):
    lines = ["@startuml", f"class {class_info['name']} {{"]
    for attr in class_info["attributes"]:
        lines.append(f"  - {attr}")
    for method in class_info["methods"]:
        lines.append(f"  + {method}()")
    lines.append("}")
    for base in class_info["bases"]:
        lines.append(f"{base} <|-- {class_info['name']}")
    lines.append("@enduml")
    return "\n".join(lines)

def extract_class_to_puml(py_path, class_name, output_path):
    with open(py_path, 'r', encoding='utf-8') as f:
        tree = ast.parse(f.read())

    extractor = ClassExtractor(class_name)
    extractor.visit(tree)

    if not extractor.found:
        print(f"❌ 类 '{class_name}' 未在文件中找到")
        return

    puml_code = generate_puml(extractor.class_info)

    Path(output_path).write_text(puml_code, encoding='utf-8')
    print(f"✅ 已生成 PlantUML 文件：{output_path}")


    

In [3]:

# parser = argparse.ArgumentParser(description="提取指定类为 PlantUML 类图")
# parser.add_argument("pyfile", help="Python 源文件路径")
# parser.add_argument("classname", help="要提取的类名")
# parser.add_argument("output", help="输出的 .puml 文件路径")

# args = parser.parse_args()
pyfile = "./agent/ExchangeAgent.py"
classname = "ExchangeAgent"
output = "./insight/uml/agent/ExchangeAgent.puml"
extract_class_to_puml(pyfile, classname, output)

✅ 已生成 PlantUML 文件：./insight/uml/agent/ExchangeAgent.puml


In [4]:
import ast
import os
from pathlib import Path

class ClassExtractor(ast.NodeVisitor):
    def __init__(self):
        self.classes = []

    def visit_ClassDef(self, node):
        bases = [self._get_base_name(base) for base in node.bases]
        attrs = []
        methods = []

        for item in node.body:
            if isinstance(item, ast.FunctionDef):
                methods.append(item.name)
            elif isinstance(item, ast.Assign):
                for target in item.targets:
                    if isinstance(target, ast.Name):
                        attrs.append(target.id)

        self.classes.append({
            "name": node.name,
            "bases": bases,
            "attributes": attrs,
            "methods": methods
        })

    def _get_base_name(self, base):
        if isinstance(base, ast.Name):
            return base.id
        elif isinstance(base, ast.Attribute):
            return base.attr
        else:
            return "object"

def generate_puml_code(class_info):
    lines = ["@startuml", f"class {class_info['name']} {{"]
    for attr in class_info['attributes']:
        lines.append(f"  - {attr}")
    for method in class_info['methods']:
        lines.append(f"  + {method}()")
    lines.append("}")
    for base in class_info["bases"]:
        lines.append(f"{base} <|-- {class_info['name']}")
    lines.append("@enduml")
    return "\n".join(lines)

def process_py_file(py_path: Path, input_root: Path, output_root: Path):
    rel_path = py_path.relative_to(input_root).with_suffix('')  # 去掉 .py 后缀
    try:
        with open(py_path, 'r', encoding='utf-8') as f:
            tree = ast.parse(f.read())
    except Exception as e:
        print(f"⚠️ 无法解析 {py_path}: {e}")
        return

    extractor = ClassExtractor()
    extractor.visit(tree)

    for class_info in extractor.classes:
        output_dir = output_root / rel_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"{class_info['name']}.puml"
        puml_code = generate_puml_code(class_info)
        output_file.write_text(puml_code, encoding='utf-8')
        print(f"✅ 生成类图：{output_file}")

def generate_all_puml(input_root: str, output_root: str):
    input_root = Path(input_root).resolve()
    output_root = Path(output_root).resolve()

    for py_file in input_root.rglob("*.py"):
        process_py_file(py_file, input_root, output_root)

In [14]:
class ClassExtractor(ast.NodeVisitor):
    def __init__(self):
        self.classes = {}
        self.current_class = None

    def visit_ClassDef(self, node):
        class_name = node.name
        bases = [self._get_base_name(base) for base in node.bases]
        attributes = {}
        methods = {}

        self.current_class = class_name

        for item in node.body:
            if isinstance(item, ast.FunctionDef):
                doc = ast.get_docstring(item) or ""
                methods[item.name] = doc.strip()
            elif isinstance(item, (ast.AnnAssign, ast.Assign)):
                # 支持类型注解的变量（AnnAssign）和普通变量（Assign）
                targets = [item.target] if isinstance(item, ast.AnnAssign) else item.targets
                for target in targets:
                    if isinstance(target, ast.Name):
                        attributes[target.id] = self._get_annotation(item)

        self.classes[class_name] = {
            "name": class_name,
            "bases": bases,
            "attributes": attributes,
            "methods": methods
        }

    def _get_annotation(self, node):
        if hasattr(node, 'annotation') and isinstance(node.annotation, ast.Name):
            return node.annotation.id
        elif hasattr(node, 'annotation') and isinstance(node.annotation, ast.Subscript):
            return self._parse_subscript(node.annotation)
        return ""

    def _parse_subscript(self, node):
        # 例如 List[str]
        value = node.value.id if isinstance(node.value, ast.Name) else "Any"
        slice_value = (
            node.slice.id if isinstance(node.slice, ast.Name) else
            self._parse_subscript(node.slice) if isinstance(node.slice, ast.Subscript) else "Any"
        )
        return f"{value}[{slice_value}]"

    def _get_base_name(self, base):
        if isinstance(base, ast.Name):
            return base.id
        elif isinstance(base, ast.Attribute):
            return base.attr
        else:
            return "object"

def generate_puml_code(class_info, all_classes):
    lines = [f"class {class_info['name']} {{"]

    # 当前类的属性（不重复列出父类的）
    own_attrs = class_info['attributes']
    parent_attrs = {}

    for base in class_info['bases']:
        if base in all_classes:
            parent_attrs.update(all_classes[base]['attributes'])

    for attr, type_str in own_attrs.items():
        if attr not in parent_attrs:
            if type_str:
                lines.append(f"  - {attr}: {type_str}")
            else:
                lines.append(f"  - {attr}")

    # 方法（带注释）
    for method, doc in class_info['methods'].items():
        doc_comment = f"  + {method}()"
        if doc:
            doc_comment += f' : "{doc}"'
        lines.append(doc_comment)

    lines.append("}")

    # 添加继承关系
    for base in class_info['bases']:
        lines.append(f"{base} <|-- {class_info['name']}")

    return "@startuml\n" + '\n'.join(lines) + "\n@enduml"

def process_py_file(py_path: Path, input_root: Path, output_root: Path):
    rel_path = py_path.relative_to(input_root).with_suffix('')  # 去掉 .py 后缀
    try:
        with open(py_path, 'r', encoding='utf-8') as f:
            tree = ast.parse(f.read())
    except Exception as e:
        print(f"⚠️ 无法解析 {py_path}: {e}")
        return

    extractor = ClassExtractor()
    extractor.visit(tree)

    all_classes = extractor.classes
    for class_name, class_info in all_classes.items():
        output_dir = output_root / rel_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"{class_name}.puml"
        puml_code = generate_puml_code(class_info, all_classes)
        output_file.write_text(puml_code, encoding='utf-8')
        print(f"✅ 生成类图：{output_file}")

def generate_all_puml(input_root: str, output_root: str):
    input_root = Path(input_root).resolve()
    output_root = Path(output_root).resolve()

    for py_file in input_root.rglob("*.py"):
        process_py_file(py_file, input_root, output_root)

In [15]:
ipt = "./"
output = "./insight/project_uml2"

generate_all_puml(ipt, output)

✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/Kernel.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/util/OrderBook.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/contributed_traders/SimpleAgent.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/message/MessageType.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/message/Message.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/agent/HeuristicBeliefLearningAgent.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/agent/ValueAgent.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/agent/FinancialAgent.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/agent/TradingAgent.puml
✅ 生成类图：/Users/mythezone/project/gitlab/abides-acc/insight/project_uml2/agent/OrderBookImbalanceAgent.puml
✅ 生成类图：/Users/mythezone/proj

In [7]:
from pathlib import Path

def extract_puml_body(puml_path: Path) -> str:
    """
    提取单个 .puml 文件中 @startuml 和 @enduml 之间的内容。
    """
    content = puml_path.read_text(encoding='utf-8')
    lines = content.splitlines()
    body = []
    inside = False
    for line in lines:
        if line.strip().lower().startswith('@startuml'):
            inside = True
            continue
        elif line.strip().lower().startswith('@enduml'):
            inside = False
            continue
        if inside:
            body.append(line)
    return '\n'.join(body)

def merge_puml_files(input_dir: str, output_file: str):
    input_dir = Path(input_dir)
    all_puml_files = list(input_dir.rglob("*.puml"))

    merged_lines = ["@startuml", "skinparam classAttributeIconSize 0\n"]

    for puml_file in sorted(all_puml_files):
        relative_path = puml_file.relative_to(input_dir).with_suffix('')
        namespace = str(relative_path.parent).replace(os.sep, '.')
        class_name = relative_path.name

        merged_lines.append(f"'' ===== From: {relative_path} =====")
        body = extract_puml_body(puml_file)

        # 可选：添加 package 标签（可关闭）
        if namespace and namespace != ".":
            merged_lines.append(f"package {namespace} {{")
            merged_lines.append(body)
            merged_lines.append("}")
        else:
            merged_lines.append(body)

        merged_lines.append("")  # 空行分隔

    merged_lines.append("@enduml")

    Path(output_file).write_text('\n'.join(merged_lines), encoding='utf-8')
    print(f"✅ 合成完成：{output_file}")

In [16]:
input_dir = "./insight/project_uml2"
output_file = "./insight/project_uml2.puml"
merge_puml_files(input_dir, output_file)

✅ 合成完成：./insight/project_uml2.puml


In [9]:
def extract_puml_body(puml_path: Path) -> str:
    content = puml_path.read_text(encoding='utf-8')
    lines = content.splitlines()
    body = []
    inside = False
    for line in lines:
        if line.strip().lower().startswith('@startuml'):
            inside = True
            continue
        elif line.strip().lower().startswith('@enduml'):
            inside = False
            continue
        if inside:
            body.append(line)
    return '\n'.join(body)

def group_by_namespace(puml_files: list, root: Path):
    grouped = {}
    for f in puml_files:
        rel_path = f.relative_to(root)
        namespace = str(rel_path.parent).replace(os.sep, ".")
        grouped.setdefault(namespace, []).append(f)
    return grouped

def merge_puml_grouped(input_dir: str, output_file: str):
    input_dir = Path(input_dir)
    all_puml_files = list(input_dir.rglob("*.puml"))

    grouped_files = group_by_namespace(all_puml_files, input_dir)

    merged_lines = [
        "@startuml",
        "skinparam classAttributeIconSize 0",
        "skinparam packageStyle rectangle",
        "skinparam shadowing false",
        "skinparam backgroundColor #FDFDFD",
        ""
    ]

    for namespace, files in sorted(grouped_files.items()):
        label = f"package {namespace or 'root'} {{"
        merged_lines.append("")
        merged_lines.append(label)
        for f in sorted(files):
            merged_lines.append(f"  '' From: {f.relative_to(input_dir)}")
            body = extract_puml_body(f)
            indented_body = '\n'.join("  " + line for line in body.splitlines())
            merged_lines.append(indented_body)
        merged_lines.append("}")

    merged_lines.append("@enduml")

    Path(output_file).write_text('\n'.join(merged_lines), encoding='utf-8')
    print(f"✅ 合并完成并分组布局优化：{output_file}")

In [17]:
output_file = "./insight/project_uml_grouped2.puml"
merge_puml_grouped(input_dir, output_file)

✅ 合并完成并分组布局优化：./insight/project_uml_grouped2.puml


In [12]:
import subprocess
from pathlib import Path

def render_plantuml(input_puml: str, output_format: str = 'svg', plantuml_jar: str = None):
    input_path = Path(input_puml).resolve()
    assert input_path.exists(), f"❌ 输入文件不存在: {input_path}"
    assert output_format in ['svg', 'png'], "❌ 仅支持 svg 或 png 格式"

    if plantuml_jar:
        # 方式二：使用 plantuml.jar
        cmd = ['java', '-jar', plantuml_jar, f'-t{output_format}', str(input_path)]
    else:
        # 方式一：使用 plantuml 命令行工具
        cmd = ['plantuml', f'-t{output_format}', str(input_path)]

    try:
        subprocess.run(cmd, check=True)
        output_file = input_path.with_suffix(f".{output_format}")
        print(f"✅ 已生成图像：{output_file}")
    except subprocess.CalledProcessError as e:
        print(f"❌ 渲染失败：{e}")

In [13]:
render_plantuml("./insight/project_uml_grouped.puml", 'svg')

FileNotFoundError: [Errno 2] No such file or directory: 'plantuml'