Pydantic モデルの定義ファイルを自動生成できるようにする

In [1]:
from typing import NotRequired, TypedDict

from pydantic import ConfigDict


class TypeInfo(TypedDict):
    """Description of one type reference used when generating model source code.

    Only ``name`` is required; every other key may be omitted.
    """

    # Identifier of the type as it should appear in code (e.g. "str", "UUID").
    name: str
    # Module to import ``name`` from; no import line is emitted when absent.
    import_from: NotRequired[str]
    # Alias for the import (``name as alias``); used instead of ``name`` in references.
    alias: NotRequired[str]
    # When truthy the reference is wrapped as ``list[...]``.
    list: NotRequired[bool]
    # When truthy ``| None`` is appended to the reference.
    optional: NotRequired[bool]


class ModelDefinition(TypedDict):
    """Declarative description of one Pydantic model class to be generated."""

    # Class name of the generated model.
    name: str
    # Text emitted as the class docstring.
    description: NotRequired[str]
    # Ordered (field name, field type) pairs rendered as annotated attributes.
    fields: list[tuple[str, TypeInfo]]
    # Base class of the model; the generator uses pydantic's BaseModel when absent.
    base: NotRequired[TypeInfo]
    # Settings rendered as ``model_config = ConfigDict(...)``.
    config: NotRequired[ConfigDict]

In [2]:
from collections import defaultdict


def build_model_definitions(models: list[ModelDefinition]) -> str:
    """Render Python source code declaring one Pydantic model per definition.

    The output consists of ``from ... import ...`` lines (sorted by module and
    then by imported name), one blank line, and the classes in the order given.
    ``BaseModel`` is imported from ``pydantic`` for every model without an
    explicit ``base``; ``ConfigDict`` is imported for every model with a
    non-empty ``config``.

    Note:
        Field annotations are emitted verbatim, so a model may reference a
        class that is declared later in the output. No
        ``from __future__ import annotations`` line is generated here.

    Args:
        models: Model definitions to render, in output order.

    Returns:
        The generated source without trailing blank lines; an empty string
        when ``models`` is empty.
    """

    def type_ref(t: TypeInfo) -> str:
        """Return the annotation text for ``t`` with list/optional applied."""
        ref = t.get("alias", t["name"])
        if t.get("list"):
            ref = f"list[{ref}]"
        if t.get("optional"):
            ref = f"{ref} | None"
        return ref

    def docstring_literal(text: str) -> str:
        """Return ``text`` as a triple-quoted literal that always parses."""
        # Backslashes must be doubled so the generated docstring equals ``text``.
        escaped = text.replace("\\", "\\\\")
        # A '"""' inside the text or a trailing '"' would terminate the literal
        # early; escaping every double quote is the simplest safe rendering.
        if '"""' in escaped or escaped.endswith('"'):
            escaped = escaped.replace('"', '\\"')
        return f'"""{escaped}"""'

    # module -> {(name, alias)}; a set de-duplicates repeated imports.
    imports: dict[str, set[tuple[str, str | None]]] = defaultdict(set)

    def register_import(t: TypeInfo) -> None:
        """Record the import required by ``t`` (if it names a module)."""
        module = t.get("import_from")
        if module:
            imports[module].add((t["name"], t.get("alias")))

    # Render classes, collecting every required import along the way so the
    # import block only has to be rendered once afterwards.
    class_lines: list[str] = []
    for m in models:
        base = m.get("base")
        if base:
            register_import(base)
            base_name = type_ref(base)
        else:
            # Implicit base class: pydantic's BaseModel.
            imports["pydantic"].add(("BaseModel", None))
            base_name = "BaseModel"

        class_lines.append(f"class {m['name']}({base_name}):")

        description = m.get("description")
        if description:
            class_lines.append(f"    {docstring_literal(description)}")
            class_lines.append("")

        config = m.get("config")
        if config:
            # repr() yields a valid Python literal for strings (including ones
            # containing quotes or backslashes) as well as for other values.
            config_str = ", ".join(f"{key}={value!r}" for key, value in config.items())
            class_lines.append(f"    model_config = ConfigDict({config_str})")
            class_lines.append("")
            imports["pydantic"].add(("ConfigDict", None))

        fields = m["fields"]
        for field_name, field_type in fields:
            register_import(field_type)
            class_lines.append(f"    {field_name}: {type_ref(field_type)}")
        if not fields and not config and not description:
            # A class body must contain at least one statement.
            class_lines.append("    pass")

        class_lines.append("")  # blank line after each class

    # Render imports in a deterministic order.
    import_lines: list[str] = []
    for module in sorted(imports):
        items = sorted(imports[module], key=lambda item: (item[0], item[1] or ""))
        parts = [f"{name} as {alias}" if alias else name for name, alias in items]
        import_lines.append(f"from {module} import {', '.join(parts)}")

    out: list[str] = list(import_lines)
    if import_lines:
        out.append("")  # blank line between imports and first class
    out.extend(class_lines)

    # Drop trailing blank lines so the result ends with the last code line.
    while out and out[-1] == "":
        out.pop()

    return "\n".join(out)

In [3]:
# Sample definition: a model with a docstring, imported field types and a
# list-valued, optional reference to another generated model.
User: ModelDefinition = {
    "name": "User",
    "description": "Represents a user in the system",
    "fields": [
        ("id", {"name": "UUID", "import_from": "uuid"}),
        ("name", {"name": "str"}),
        ("meta", {"name": "UserMeta", "import_from": "models"}),
        # relations (no "import_from": Post is declared in the same output)
        ("posts", {"name": "Post", "list": True, "optional": True}),
    ],
}

# Sample definition: a model that additionally carries a ConfigDict.
Post: ModelDefinition = {
    "name": "Post",
    "description": "Represents a blog post written by a user",
    "fields": [
        ("id", {"name": "UUID", "import_from": "uuid"}),
        ("title", {"name": "str"}),
        ("content", {"name": "str"}),
        ("created_at", {"name": "datetime", "import_from": "datetime"}),
        ("updated_at", {"name": "datetime", "import_from": "datetime"}),
        # relations (no "import_from": User is declared in the same output)
        ("author", {"name": "User", "optional": True}),
    ],
    "config": {"frozen": True, "extra": "forbid"},
}

# Render both models into a single source string and show it.
print(build_model_definitions([User, Post]))

from datetime import datetime
from models import UserMeta
from pydantic import BaseModel, ConfigDict
from uuid import UUID

class User(BaseModel):
    """Represents a user in the system"""

    id: UUID
    name: str
    meta: UserMeta
    posts: list[Post] | None

class Post(BaseModel):
    """Represents a blog post written by a user"""

    model_config = ConfigDict(frozen=True, extra='forbid')

    id: UUID
    title: str
    content: str
    created_at: datetime
    updated_at: datetime
    author: User | None


SQLAlchemy のテーブルを Pydantic モデルに変換する

In [None]:
import inspect
from typing import get_args, get_origin, Any, get_type_hints, Union
from types import UnionType
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import Mapped, DeclarativeBase


def detect_typeinfo(t: Any) -> TypeInfo:
    """Extract type information (a ``TypeInfo`` dict) from a type object.

    ``Mapped[X]`` is unwrapped to ``X``; ``X | None`` marks the result as
    optional; ``list[X]`` marks the result as a list. Objects that match none
    of the handled shapes are described by their ``str()`` representation.

    Raises:
        NotImplementedError: For a union with more than one non-``None`` member.
    """

    def describe(module: str, name: str) -> TypeInfo:
        # Builtins need no import line; everything else is imported from its module.
        if module in ("builtins", "__builtin__"):
            return {"name": name}
        return {"name": name, "import_from": module}

    origin = get_origin(t)
    args = get_args(t)
    none_type = type(None)

    # Mapped[X] -> describe X itself.
    if origin is Mapped:
        return detect_typeinfo(args[0])

    # Union (typing.Union[...] or the ``X | Y`` syntax)
    if origin is Union or origin is UnionType:
        members = [member for member in args if member is not none_type]

        # Unions with several non-None members are not supported.
        if len(members) > 1:
            raise NotImplementedError

        info = detect_typeinfo(members[0])
        if none_type in args:
            info["optional"] = True
        return info

    if t is none_type:
        return {"name": "None"}

    if origin is list:
        element_info = detect_typeinfo(args[0])
        element_info["list"] = True
        return element_info

    # TypedDict (and any other object exposing ``__annotations__``)
    if hasattr(t, "__annotations__"):
        module = t.__module__
        try:
            name = t.__qualname__
        except AttributeError:
            name = t.__name__
        return describe(module, name)

    if origin is dict:
        return {"name": "dict"}

    if inspect.isclass(t):
        # Classes such as uuid.UUID or datetime.datetime need no special
        # handling: they are imported from their defining module.
        return describe(t.__module__, t.__qualname__)

    return {"name": str(t)}


def sqlalchemy_model_to_pydantic_model_definition(
    cls: type[DeclarativeBase],
    name: str | None = None,
) -> ModelDefinition:
    """Build a ``ModelDefinition`` from a SQLAlchemy declarative class.

    Every mapped column becomes a field whose type is taken from the class's
    type hints (falling back to the column type's ``python_type``); nullable
    columns are marked optional. Every relationship becomes an optional field
    referencing the related class's ``__pydantic_model__`` attribute, or its
    class name when that attribute is absent.

    Args:
        cls: The mapped class to convert.
        name: Name of the generated model. Defaults to
            ``cls.__pydantic_model__``.

    Returns:
        A definition with ``name``, ``fields`` and ``config={"extra": "forbid"}``.

    Raises:
        ValueError: If no string name is given and ``cls`` has no string
            ``__pydantic_model__`` attribute.
    """
    mapper = sa_inspect(cls)
    fields: list[tuple[str, TypeInfo]] = []

    if name is None:
        # Default to None so a missing attribute reaches the ValueError below
        # instead of surfacing as an unrelated AttributeError.
        name = getattr(cls, "__pydantic_model__", None)
    if not isinstance(name, str):
        raise ValueError("name is not specified")

    type_hints = get_type_hints(cls)
    for col in mapper.columns:
        # Only consult ``python_type`` when no hint exists: it may raise for
        # column types that do not implement it.
        if col.name in type_hints:
            hint = type_hints[col.name]
        else:
            hint = col.type.python_type
        col_info = detect_typeinfo(hint)
        if col.nullable:
            col_info["optional"] = True
        fields.append((col.name, col_info))

    for rel in mapper.relationships:
        rel_cls = rel.mapper.class_
        rel_info: TypeInfo = {
            "name": getattr(rel_cls, "__pydantic_model__", rel_cls.__name__),
            "list": rel.uselist or False,
            # Relationships are always emitted as optional.
            "optional": True,
        }
        fields.append((rel.key, rel_info))

    return {
        "name": name,
        "fields": fields,
        "config": {"extra": "forbid"},
    }


In [43]:
from tables import UserTable, ArticleTable, CommentTable
from utils import ruff_format

# Convert each SQLAlchemy table class into a model definition (the model name
# comes from each class's ``__pydantic_model__`` attribute, since no explicit
# name is passed) and render all of them into one source string.
code = build_model_definitions(
    [
        sqlalchemy_model_to_pydantic_model_definition(cls)
        for cls in [UserTable, ArticleTable, CommentTable]
    ]
)

# Run the generated source through the project's ruff_format helper and show it.
print(ruff_format(code, line_length=120))

from datetime import datetime
from uuid import UUID

from pydantic import BaseModel, ConfigDict

from tables import UserMeta


class User(BaseModel):
    model_config = ConfigDict(extra="forbid")

    id: UUID
    created_at: datetime
    updated_at: datetime
    name: str
    meta: UserMeta
    articles: list[Article] | None
    comments: list[Comment] | None


class Article(BaseModel):
    model_config = ConfigDict(extra="forbid")

    id: UUID
    created_at: datetime
    updated_at: datetime
    author_id: UUID
    title: str
    body: str
    published_at: datetime | None
    author: User | None
    comments: list[Comment] | None


class Comment(BaseModel):
    model_config = ConfigDict(extra="forbid")

    id: UUID
    created_at: datetime
    updated_at: datetime
    article_id: UUID
    author_id: UUID
    body: str
    article: Article | None
    author: User | None



TODO: partial モデルを作れるようにする
- フィールドの除外/追加
- Optional の除外/追加