diff --git a/README.md b/README.md index 43c9359..956d344 100644 --- a/README.md +++ b/README.md @@ -39,12 +39,22 @@ Options: -r, --generate-routers Generate modular api with multiple routers using RouterAPI (for bigger applications). --specify-tags Use along with --generate-routers to generate specific routers from given list of tags. -c, --custom-visitors PATH - A custom visitor that adds variables to the template. + -d, --output-model-type Specify a Pydantic base model to use (see [datamodel-code-generator](https://github.com/koxudaxi/datamodel-code-generator); default is `pydantic.BaseModel`). + -p, --python-version Specify a Python version to target (default is `3.8`). --install-completion Install completion for the current shell. --show-completion Show completion for the current shell, to copy it or customize the installation. --help Show this message and exit. ``` +### Pydantic 2 support + +Specify the Pydantic 2 `BaseModel` version in the command line, for example: + +```sh +$ fastapi-codegen --input api.yaml --output app --output-model-type pydantic_v2.BaseModel +``` + ## Example ### OpenAPI ```sh diff --git a/fastapi_code_generator/__main__.py b/fastapi_code_generator/__main__.py index f644b09..517cb58 100644 --- a/fastapi_code_generator/__main__.py +++ b/fastapi_code_generator/__main__.py @@ -5,9 +5,10 @@ from typing import Any, Dict, List, Optional import typer -from datamodel_code_generator import LiteralType, PythonVersion, chdir +from datamodel_code_generator import DataModelType, LiteralType, PythonVersion, chdir from datamodel_code_generator.format import CodeFormatter from datamodel_code_generator.imports import Import, Imports +from datamodel_code_generator.model import get_data_model_types from datamodel_code_generator.reference import Reference from datamodel_code_generator.types import DataType from jinja2 import Environment, FileSystemLoader @@ -57,6 +58,12 @@ def main( None, "--custom-visitor", "-c" ), disable_timestamp: bool = typer.Option(False, "--disable-timestamp"), + output_model_type: DataModelType = typer.Option( + DataModelType.PydanticBaseModel.value, "--output-model-type", "-d" + ), + python_version: PythonVersion = typer.Option( + PythonVersion.PY_38.value, "--python-version", "-p" + ), ) -> None: input_name: str = input_file input_text: str @@ -69,20 +76,6 @@ def main( else: model_path = MODEL_PATH - if enum_field_as_literal: - return generate_code( - input_name, - input_text, - encoding, - output_dir, - template_dir, - model_path, - enum_field_as_literal, # type: ignore[arg-type] - custom_visitors=custom_visitors, - disable_timestamp=disable_timestamp, - generate_routers=generate_routers, - specify_tags=specify_tags, - ) return generate_code( input_name, input_text, @@ -90,10 +83,13 @@ def main( output_dir, template_dir, model_path, + enum_field_as_literal=enum_field_as_literal or None, custom_visitors=custom_visitors, disable_timestamp=disable_timestamp, generate_routers=generate_routers, specify_tags=specify_tags, + output_model_type=output_model_type, + python_version=python_version, ) @@ -119,6 +115,8 @@ def generate_code( disable_timestamp: bool = False, generate_routers: Optional[bool] = None, specify_tags: Optional[str] = None, + output_model_type: DataModelType = DataModelType.PydanticBaseModel, + python_version: PythonVersion = PythonVersion.PY_38, ) -> None: if not model_path: model_path = MODEL_PATH @@ -130,10 +128,19 @@ def generate_code( template_dir = ( BUILTIN_MODULAR_TEMPLATE_DIR if generate_routers else BUILTIN_TEMPLATE_DIR ) - if enum_field_as_literal: - parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal) # type: ignore[arg-type] - else: - parser = OpenAPIParser(input_text) + + data_model_types = get_data_model_types(output_model_type, python_version) + + parser = OpenAPIParser( + input_text, + enum_field_as_literal=enum_field_as_literal, + data_model_type=data_model_types.data_model, + data_model_root_type=data_model_types.root_model, + data_model_field_type=data_model_types.field_model, + data_type_manager_type=data_model_types.data_type_manager, + dump_resolve_reference_action=data_model_types.dump_resolve_reference_action, + ) + with chdir(output_dir): models = parser.parse() output = output_dir / model_path @@ -153,7 +160,7 @@ def generate_code( ) results: Dict[Path, str] = {} - code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve()) + code_formatter = CodeFormatter(python_version, Path().resolve()) template_vars: Dict[str, object] = {"info": parser.parse_info()} visitors: List[Visitor] = []