Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use recursive method to correctly handle one_to_many fields #45

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 57 additions & 4 deletions djantic/main.py
Expand Up @@ -140,6 +140,10 @@ def get_field_names(cls) -> List[str]:
model_fields = [
name for name in model_fields if name not in cls.__config__.exclude
]
for field in model_fields:
if hasattr(cls.__fields__[field].type_, 'get_field_names'):
sub_fields = cls.__fields__[field].type_.get_field_names()
model_fields = model_fields + [f'{field}__{sub_field}' for sub_field in sub_fields]

return model_fields

Expand All @@ -155,6 +159,58 @@ def _get_object_model(cls, obj_data: dict) -> "ModelSchema":

return model_schema

@classmethod
def _get_mapping(cls, qs_values: dict) -> dict:
"""Breaks an query set key values pair (with sub-fields separated by __) into a dictionary of dictionaries.

Args:
qs_values (dict):
{
'a': 1,
'y': 1,
'a__x': 1,
'a__b': 1,
'a__b__z': 1,
'a__b__c': 1,
'a__b__c__k': 1,
'a__b__c__j': 1,
'b': None,
'b__b': None,
}

Returns:
dict: Ex:
{
'a': {
'x': 1,
'b': {
'z': 1,
'c': {
'k': 1,
'j': 1
}
}
},
'b': None,
'y': 1
}

"""

data = {}
mapping = [x.split('__') for x in qs_values]
mapping.sort(key=lambda x: len(x))
for key_map in mapping:
if len(key_map) == 1:
data[key_map[0]] = qs_values[key_map[0]]
else:
if qs_values[key_map[0]] is None:
data[key_map[0]] = None
else:
sub_c = {'__'.join(k.split('__')[1:]): v for k, v in qs_values.items() if len(k.split('__')) > 1 and k.split('__')[0] == key_map[0]}
data[key_map[0]] = cls._get_mapping(sub_c)
return data

@classmethod
def from_django(
cls,
Expand Down Expand Up @@ -192,10 +248,7 @@ def from_django(
related_qs = related_obj.all()

if schema_cls:
related_obj_data = [
schema_cls.construct(**obj_vals)
for obj_vals in related_qs.values(*related_field_names)
]
related_obj_data = [schema_cls(**cls._get_mapping(x)) for x in related_qs.values(*related_field_names)]

else:
related_obj_data = list(related_obj.all().values("id"))
Expand Down