Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
285 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ class ClassType(Enum): | |
Dataclass = "dataclass" | ||
Attrs = "attrs" | ||
Pydantic = "pydantic" | ||
SqlModel = "sqlmodel" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import List, Tuple | ||
|
||
from json_to_models.dynamic_typing import ImportPathList, MetaData | ||
from json_to_models.models.base import GenericModelCodeGenerator | ||
from json_to_models.models.pydantic import PydanticModelCodeGenerator | ||
|
||
|
||
class SqlModelCodeGenerator(PydanticModelCodeGenerator): | ||
def generate(self, nested_classes: List[str] = None, extra: str = "", **kwargs) \ | ||
-> Tuple[ImportPathList, str]: | ||
imports, body = GenericModelCodeGenerator.generate( | ||
self, | ||
bases='SQLModel, table=True', | ||
nested_classes=nested_classes, | ||
extra=extra | ||
) | ||
imports.append(('sqlmodel', ['SQLModel', 'Field'])) | ||
body = """ | ||
# Warn! This generated code does not respect SQLModel Relationship and foreign_key, please add them manually. | ||
""".strip() + '\n' + body | ||
return imports, body | ||
|
||
def convert_field_name(self, name): | ||
if name in ('id', 'pk'): | ||
return name | ||
return super().convert_field_name(name) | ||
|
||
def _get_field_kwargs(self, name: str, meta: MetaData, optional: bool, data: dict): | ||
kwargs = super()._get_field_kwargs(name, meta, optional, data) | ||
# Detect primary key | ||
if data['name'] in ('id', 'pk') and meta is int: | ||
kwargs['primary_key'] = True | ||
return kwargs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
from typing import Dict, List | ||
|
||
import pytest | ||
|
||
from json_to_models.dynamic_typing import ( | ||
DDict, | ||
DList, | ||
DOptional, | ||
DUnion, | ||
FloatString, | ||
IntString, | ||
ModelMeta, | ||
compile_imports, | ||
) | ||
from json_to_models.models.base import generate_code | ||
from json_to_models.models.sqlmodel import SqlModelCodeGenerator | ||
from json_to_models.models.structure import sort_fields | ||
from test.test_code_generation.test_models_code_generator import model_factory, trim | ||
|
||
# Data structure: | ||
# pytest.param id -> { | ||
# "model" -> (model_name, model_metadata), | ||
# test_name -> expected, ... | ||
# } | ||
test_data = { | ||
"base": { | ||
"model": ("Test", { | ||
"foo": int, | ||
"Bar": int, | ||
"baz": float | ||
}), | ||
"fields_data": { | ||
"foo": { | ||
"name": "foo", | ||
"type": "int" | ||
}, | ||
"Bar": { | ||
"name": "bar", | ||
"type": "int", | ||
"body": 'Field(..., alias="Bar")' | ||
}, | ||
"baz": { | ||
"name": "baz", | ||
"type": "float" | ||
} | ||
}, | ||
"fields": { | ||
"imports": "", | ||
"fields": [ | ||
f"foo: int", | ||
f'bar: int = Field(..., alias="Bar")', | ||
f"baz: float", | ||
] | ||
}, | ||
"generated": trim(f""" | ||
from sqlmodel import Field, SQLModel | ||
# Warn! This generated code does not respect SQLModel Relationship and foreign_key, please add them manually. | ||
class Test(SQLModel, table=True): | ||
foo: int | ||
bar: int = Field(..., alias="Bar") | ||
baz: float | ||
""") | ||
}, | ||
"complex": { | ||
"model": ("Test", { | ||
"foo": int, | ||
"baz": DOptional(DList(DList(str))), | ||
"bar": DOptional(IntString), | ||
"qwerty": FloatString, | ||
"asdfg": DOptional(int), | ||
"dict": DDict(int), | ||
"not": bool, | ||
"1day": int, | ||
"день_недели": str, | ||
}), | ||
"fields_data": { | ||
"foo": { | ||
"name": "foo", | ||
"type": "int" | ||
}, | ||
"baz": { | ||
"name": "baz", | ||
"type": "Optional[List[List[str]]]", | ||
"body": "[]" | ||
}, | ||
"bar": { | ||
"name": "bar", | ||
"type": "Optional[int]", | ||
"body": "None" | ||
}, | ||
"qwerty": { | ||
"name": "qwerty", | ||
"type": "float" | ||
}, | ||
"asdfg": { | ||
"name": "asdfg", | ||
"type": "Optional[int]", | ||
"body": "None" | ||
}, | ||
"dict": { | ||
"name": "dict_", | ||
"type": "Dict[str, int]", | ||
"body": 'Field(..., alias="dict")' | ||
}, | ||
"not": { | ||
"name": "not_", | ||
"type": "bool", | ||
"body": 'Field(..., alias="not")' | ||
}, | ||
"1day": { | ||
"name": "one_day", | ||
"type": "int", | ||
"body": 'Field(..., alias="1day")' | ||
}, | ||
"день_недели": { | ||
"name": "den_nedeli", | ||
"type": "str", | ||
"body": 'Field(..., alias="день_недели")' | ||
} | ||
}, | ||
"generated": trim(f""" | ||
from sqlmodel import Field, SQLModel | ||
from typing import Dict, List, Optional | ||
# Warn! This generated code does not respect SQLModel Relationship and foreign_key, please add them manually. | ||
class Test(SQLModel, table=True): | ||
foo: int | ||
qwerty: float | ||
dict_: Dict[str, int] = Field(..., alias="dict") | ||
not_: bool = Field(..., alias="not") | ||
one_day: int = Field(..., alias="1day") | ||
den_nedeli: str = Field(..., alias="день_недели") | ||
baz: Optional[List[List[str]]] = [] | ||
bar: Optional[int] = None | ||
asdfg: Optional[int] = None | ||
""") | ||
}, | ||
"converters": { | ||
"model": ("Test", { | ||
"a": int, | ||
"b": IntString, | ||
"c": DOptional(FloatString), | ||
"d": DList(DList(DList(IntString))), | ||
"e": DDict(IntString), | ||
"u": DUnion(DDict(IntString), DList(DList(IntString))), | ||
}), | ||
"generated": trim(""" | ||
from sqlmodel import Field, SQLModel | ||
from typing import Dict, List, Optional, Union | ||
# Warn! This generated code does not respect SQLModel Relationship and foreign_key, please add them manually. | ||
class Test(SQLModel, table=True): | ||
a: int | ||
b: int | ||
d: List[List[List[int]]] | ||
e: Dict[str, int] | ||
u: Union[Dict[str, int], List[List[int]]] | ||
c: Optional[float] = None | ||
""") | ||
}, | ||
"sql_models": { | ||
"model": ("Test", { | ||
"id": int, | ||
"name": str, | ||
"x": DList(int) | ||
}), | ||
"generated": trim(""" | ||
from sqlmodel import Field, SQLModel | ||
from typing import List | ||
# Warn! This generated code does not respect SQLModel Relationship and foreign_key, please add them manually. | ||
class Test(SQLModel, table=True): | ||
id: int = Field(..., primary_key=True) | ||
name: str | ||
x: List[int] | ||
""") | ||
} | ||
} | ||
|
||
test_data_unzip = { | ||
test: [ | ||
pytest.param( | ||
model_factory(*data["model"]), | ||
data[test], | ||
id=id | ||
) | ||
for id, data in test_data.items() | ||
if test in data | ||
] | ||
for test in ("fields_data", "fields", "generated") | ||
} | ||
|
||
|
||
@pytest.mark.parametrize("value,expected", test_data_unzip["fields_data"]) | ||
def test_fields_data_attr(value: ModelMeta, expected: Dict[str, dict]): | ||
gen = SqlModelCodeGenerator(value) | ||
required, optional = sort_fields(value) | ||
for is_optional, fields in enumerate((required, optional)): | ||
for field in fields: | ||
field_imports, data = gen.field_data(field, value.type[field], bool(is_optional)) | ||
assert data == expected[field] | ||
|
||
|
||
@pytest.mark.parametrize("value,expected", test_data_unzip["fields"]) | ||
def test_fields_attr(value: ModelMeta, expected: dict): | ||
expected_imports: str = expected["imports"] | ||
expected_fields: List[str] = expected["fields"] | ||
gen = SqlModelCodeGenerator(value) | ||
imports, fields = gen.fields | ||
imports = compile_imports(imports) | ||
assert imports == expected_imports | ||
assert fields == expected_fields | ||
|
||
|
||
@pytest.mark.parametrize("value,expected", test_data_unzip["generated"]) | ||
def test_generated_attr(value: ModelMeta, expected: str): | ||
generated = generate_code( | ||
( | ||
[{"model": value, "nested": []}], | ||
{} | ||
), | ||
SqlModelCodeGenerator, | ||
class_generator_kwargs={} | ||
) | ||
assert generated.rstrip() == expected, generated |