Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Pydantic 2 & additional python versions #414

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,22 @@ Options:
-r, --generate-routers Generate modular api with multiple routers using RouterAPI (for bigger applications).
--specify-tags Use along with --generate-routers to generate specific routers from given list of tags.
-c, --custom-visitors PATH - A custom visitor that adds variables to the template.
-d, --output-model-type Specify a Pydantic base model to use (see [datamodel-code-generator](https://github.com/koxudaxi/datamodel-code-generator); default is `pydantic.BaseModel`).
-p, --python-version Specify a Python version to target (default is `3.8`).
--install-completion Install completion for the current shell.
--show-completion Show completion for the current shell, to copy it
or customize the installation.
--help Show this message and exit.
```

### Pydantic 2 support

Specify the Pydantic 2 `BaseModel` version in the command line, for example:

```sh
$ fastapi-codegen --input api.yaml --output app --output-model-type pydantic_v2.BaseModel
```

## Example
### OpenAPI
```sh
Expand Down
47 changes: 27 additions & 20 deletions fastapi_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from typing import Any, Dict, List, Optional

import typer
from datamodel_code_generator import LiteralType, PythonVersion, chdir
from datamodel_code_generator import DataModelType, LiteralType, PythonVersion, chdir
from datamodel_code_generator.format import CodeFormatter
from datamodel_code_generator.imports import Import, Imports
from datamodel_code_generator.model import get_data_model_types
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.types import DataType
from jinja2 import Environment, FileSystemLoader
Expand Down Expand Up @@ -57,6 +58,12 @@ def main(
None, "--custom-visitor", "-c"
),
disable_timestamp: bool = typer.Option(False, "--disable-timestamp"),
output_model_type: DataModelType = typer.Option(
DataModelType.PydanticBaseModel.value, "--output-model-type", "-d"
),
python_version: PythonVersion = typer.Option(
PythonVersion.PY_38.value, "--python-version", "-p"
),
) -> None:
input_name: str = input_file
input_text: str
Expand All @@ -69,31 +76,20 @@ def main(
else:
model_path = MODEL_PATH

if enum_field_as_literal:
return generate_code(
input_name,
input_text,
encoding,
output_dir,
template_dir,
model_path,
enum_field_as_literal, # type: ignore[arg-type]
custom_visitors=custom_visitors,
disable_timestamp=disable_timestamp,
generate_routers=generate_routers,
specify_tags=specify_tags,
)
return generate_code(
input_name,
input_text,
encoding,
output_dir,
template_dir,
model_path,
enum_field_as_literal=enum_field_as_literal or None,
custom_visitors=custom_visitors,
disable_timestamp=disable_timestamp,
generate_routers=generate_routers,
specify_tags=specify_tags,
output_model_type=output_model_type,
python_version=python_version,
)


Expand All @@ -119,6 +115,8 @@ def generate_code(
disable_timestamp: bool = False,
generate_routers: Optional[bool] = None,
specify_tags: Optional[str] = None,
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
python_version: PythonVersion = PythonVersion.PY_38,
) -> None:
if not model_path:
model_path = MODEL_PATH
Expand All @@ -130,10 +128,19 @@ def generate_code(
template_dir = (
BUILTIN_MODULAR_TEMPLATE_DIR if generate_routers else BUILTIN_TEMPLATE_DIR
)
if enum_field_as_literal:
parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal) # type: ignore[arg-type]
else:
parser = OpenAPIParser(input_text)

data_model_types = get_data_model_types(output_model_type, python_version)

parser = OpenAPIParser(
input_text,
enum_field_as_literal=enum_field_as_literal,
data_model_type=data_model_types.data_model,
data_model_root_type=data_model_types.root_model,
data_model_field_type=data_model_types.field_model,
data_type_manager_type=data_model_types.data_type_manager,
dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
)

with chdir(output_dir):
models = parser.parse()
output = output_dir / model_path
Expand All @@ -153,7 +160,7 @@ def generate_code(
)

results: Dict[Path, str] = {}
code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
code_formatter = CodeFormatter(python_version, Path().resolve())

template_vars: Dict[str, object] = {"info": parser.parse_info()}
visitors: List[Visitor] = []
Expand Down