Skip to content

Commit

Permalink
added option --use-double-quotes
Browse files Browse the repository at this point in the history
  • Loading branch information
nesb1 committed Aug 10, 2022
1 parent ca41c24 commit de3c5a8
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 8 deletions.
2 changes: 2 additions & 0 deletions datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def generate(
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: Optional[str] = None,
use_double_quotes: bool = False,
) -> None:
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
if isinstance(input_, str):
Expand Down Expand Up @@ -360,6 +361,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
use_annotated=use_annotated,
use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
original_field_name_delimiter=original_field_name_delimiter,
use_double_quotes=use_double_quotes,
**kwargs,
)

Expand Down
10 changes: 10 additions & 0 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,14 @@ def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover
default=None,
)

arg_parser.add_argument(
"--use-double-quotes",
action='store_true',
default=False,
help="Model generated with double quotes. Single quotes or "
"your black config skip_string_normalization value will be used without this option.",
)

arg_parser.add_argument(
'--encoding',
help=f'The encoding of input and output (default: {DEFAULT_ENCODING})',
Expand Down Expand Up @@ -446,6 +454,7 @@ def _validate_use_annotated(cls, values: Dict[str, Any]) -> Dict[str, Any]:
use_annotated: bool = False
use_non_positive_negative_number_constrained_types: bool = False
original_field_name_delimiter: Optional[str] = None
use_double_quotes: bool = False

def merge_args(self, args: Namespace) -> None:
set_args = {
Expand Down Expand Up @@ -579,6 +588,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
use_annotated=config.use_annotated,
use_non_positive_negative_number_constrained_types=config.use_non_positive_negative_number_constrained_types,
original_field_name_delimiter=config.original_field_name_delimiter,
use_double_quotes=config.use_double_quotes,
)
return Exit.OK
except InvalidClassNameError as e:
Expand Down
10 changes: 6 additions & 4 deletions datamodel_code_generator/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
python_version: PythonVersion,
settings_path: Optional[Path] = None,
wrap_string_literal: Optional[bool] = None,
skip_string_normalization: bool = True,
):
if not settings_path:
settings_path = Path().resolve()
Expand Down Expand Up @@ -95,12 +96,13 @@ def __init__(
] = experimental_string_processing

if TYPE_CHECKING:
self.back_mode: black.FileMode
self.black_mode: black.FileMode
else:
self.back_mode = black.FileMode(
self.black_mode = black.FileMode(
target_versions={BLACK_PYTHON_VERSION[python_version]},
line_length=config.get("line-length", black.DEFAULT_LINE_LENGTH),
string_normalization=not config.get("skip-string-normalization", True),
string_normalization=not skip_string_normalization
or not config.get("skip-string-normalization", True),
**black_kwargs,
)

Expand All @@ -121,7 +123,7 @@ def format_code(
def apply_black(self, code: str) -> str:
return black.format_str(
code,
mode=self.back_mode,
mode=self.black_mode,
)

if isort.__version__.startswith('4.'):
Expand Down
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/pydantic/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
IMPORT_CONSTR,
IMPORT_EMAIL_STR,
IMPORT_IPV4ADDRESS,
IMPORT_IPV6ADDRESS,
IMPORT_IPV4NETWORKS,
IMPORT_IPV6ADDRESS,
IMPORT_IPV6NETWORKS,
IMPORT_NEGATIVE_FLOAT,
IMPORT_NEGATIVE_INT,
Expand Down
7 changes: 6 additions & 1 deletion datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def __init__(
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: Optional[str] = None,
use_double_quotes: bool = False,
):
self.data_type_manager: DataTypeManager = data_type_manager_type(
target_python_version,
Expand Down Expand Up @@ -391,6 +392,7 @@ def __init__(
self.use_non_positive_negative_number_constrained_types = (
use_non_positive_negative_number_constrained_types
)
self.use_double_quotes = use_double_quotes

@property
def iter_source(self) -> Iterator[Source]:
Expand Down Expand Up @@ -454,7 +456,10 @@ def parse(

if format_:
code_formatter: Optional[CodeFormatter] = CodeFormatter(
self.target_python_version, settings_path, self.wrap_string_literal
self.target_python_version,
settings_path,
self.wrap_string_literal,
skip_string_normalization=not self.use_double_quotes,
)
else:
code_formatter = None
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def __init__(
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: Optional[str] = None,
use_double_quotes: bool = False,
):
super().__init__(
source=source,
Expand Down Expand Up @@ -377,6 +378,7 @@ def __init__(
use_annotated=use_annotated,
use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
original_field_name_delimiter=original_field_name_delimiter,
use_double_quotes=use_double_quotes,
)

self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict()
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: Optional[str] = None,
use_double_quotes: bool = False,
):
super().__init__(
source=source,
Expand Down Expand Up @@ -230,6 +231,7 @@ def __init__(
use_annotated=use_annotated,
use_non_positive_negative_number_constrained_types=use_non_positive_negative_number_constrained_types,
original_field_name_delimiter=original_field_name_delimiter,
use_double_quotes=use_double_quotes,
)
self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [
OpenAPIScope.Schemas
Expand Down
5 changes: 4 additions & 1 deletion docs/formatting.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
Code generated by `datamodel-codegen` will be passed through `isort` and
`black` to produce consistent, well-formatted results. Settings for these tools
can be specified in `pyproject.toml` (located in the output directory, or in
some parent of the output directory).
some parent of the output directory). Also for black you can disable
skip-string-normalization with using datamodel-codegen option `--use-double-quotes`,
it will override your black config skip-string-normalization value.
Using --use-double-quotes may be useful if you can't use black config.

Example `pyproject.toml`:
```toml
Expand Down
23 changes: 22 additions & 1 deletion tests/test_format.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,30 @@
import sys

from datamodel_code_generator.format import PythonVersion
import pytest

from datamodel_code_generator.format import CodeFormatter, PythonVersion


def test_python_version():
"""Ensure that the python version used for the tests is properly listed"""

_ = PythonVersion("{}.{}".format(*sys.version_info[:2]))


@pytest.mark.parametrize(
("skip_string_normalization", "expected_output"),
[
(True, "a = 'b'"),
(False, 'a = "b"'),
],
)
def test_format_code_with_skip_string_normalization(
skip_string_normalization: bool, expected_output: str
) -> None:
formatter = CodeFormatter(
PythonVersion.PY_37, skip_string_normalization=skip_string_normalization
)

formatted_code = formatter.format_code("a = 'b'")

assert formatted_code == expected_output + "\n"

0 comments on commit de3c5a8

Please sign in to comment.