Skip to content

Commit

Permalink
Merge pull request #8 from lidatong/fix-infer-nested-and-default
Browse files Browse the repository at this point in the history
Fix infer missing bugs with nested objects and field defaults
  • Loading branch information
lidatong authored Aug 26, 2018
2 parents 72f53ef + 818878f commit 2b03814
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 30 deletions.
13 changes: 5 additions & 8 deletions dataclasses_json/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,21 @@ def from_json(cls,
parse_float=parse_float,
parse_int=parse_int,
parse_constant=parse_constant)

if infer_missing:
init_kwargs = ChainMap(init_kwargs,
{field.name: None for field in fields(cls)
if field.name not in init_kwargs})
return _decode_dataclass(cls, init_kwargs)
return _decode_dataclass(cls, init_kwargs, infer_missing)

@classmethod
def from_json_array(cls,
kvss,
*,
encoding=None,
parse_float=None,
parse_int=None,
parse_constant=None):
parse_constant=None,
infer_missing=False):
init_kwargs_array = json.loads(kvss,
encoding=encoding,
parse_float=parse_float,
parse_int=parse_int,
parse_constant=parse_constant)
return [_decode_dataclass(cls, init_kwargs)
return [_decode_dataclass(cls, init_kwargs, infer_missing)
for init_kwargs in init_kwargs_array]
34 changes: 24 additions & 10 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from dataclasses import fields, is_dataclass
from typing import Collection, Optional
import sys
from collections import ChainMap
from dataclasses import fields, is_dataclass, MISSING
from typing import Collection, Optional
from functools import partial


def _get_type_origin(type_):
Expand Down Expand Up @@ -51,14 +53,26 @@ def default(self, o):
return json.JSONEncoder.default(self, o)


def _decode_dataclass(cls, kvs):
def _decode_dataclass(cls, kvs, infer_missing):
kvs = {} if kvs is None and infer_missing else kvs
missing_fields = {field for field in fields(cls) if field.name not in kvs}
for field in missing_fields:
if field.default is not MISSING:
kvs[field.name] = field.default
elif infer_missing:
kvs[field.name] = None

init_kwargs = {}
for field in fields(cls):
field_value = kvs[field.name]
if is_dataclass(field.type):
init_kwargs[field.name] = _decode_dataclass(field.type, field_value)
init_kwargs[field.name] = _decode_dataclass(field.type,
field_value,
infer_missing)
elif _is_supported_generic(field.type) and field.type != str:
init_kwargs[field.name] = _decode_generic(field.type, field_value)
init_kwargs[field.name] = _decode_generic(field.type,
field_value,
infer_missing)
else:
init_kwargs[field.name] = field_value
return cls(**init_kwargs)
Expand All @@ -75,7 +89,7 @@ def _is_supported_generic(type_):
return is_collection or is_optional


def _decode_generic(type_, value):
def _decode_generic(type_, value, infer_missing):
if value is None:
res = value
elif _issubclass_safe(_get_type_origin(type_), Collection):
Expand All @@ -88,9 +102,9 @@ def _decode_generic(type_, value):
# hence the check of `is_dataclass(value)`
type_arg = type_.__args__[0]
if is_dataclass(type_arg) or is_dataclass(value):
xs = (_decode_dataclass(type_arg, v) for v in value)
xs = (_decode_dataclass(type_arg, v, infer_missing) for v in value)
elif _is_supported_generic(type_arg):
xs = (_decode_generic(type_arg, v) for v in value)
xs = (_decode_generic(type_arg, v, infer_missing) for v in value)
else:
xs = value
# get the constructor if using corresponding generic type in `typing`
Expand All @@ -102,9 +116,9 @@ def _decode_generic(type_, value):
else: # Optional
type_arg = type_.__args__[0]
if is_dataclass(type_arg) or is_dataclass(value):
res = _decode_dataclass(type_arg, value)
res = _decode_dataclass(type_arg, value, infer_missing)
elif _is_supported_generic(type_arg):
res = _decode_generic(type_arg, value)
res = _decode_generic(type_arg, value, infer_missing)
else:
res = value
return res
Expand Down
17 changes: 16 additions & 1 deletion tests/test_entities.py → tests/entities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import (Collection,
Deque,
FrozenSet,
Expand Down Expand Up @@ -44,6 +44,11 @@ class DataClassWithOptional(DataClassJsonMixin):
x: Optional[int]


@dataclass(frozen=True)
class DataClassWithOptionalRecursive(DataClassJsonMixin):
x: DataClassWithOptional


@dataclass(frozen=True)
class DataClassWithUnionIntNone(DataClassJsonMixin):
x: Union[int, None]
Expand All @@ -64,6 +69,16 @@ class DataClassXs(DataClassJsonMixin):
xs: List[DataClassX]


@dataclass(frozen=True)
class DataClassImmutableDefault(DataClassJsonMixin):
x: int = 0


@dataclass(frozen=True)
class DataClassMutableDefault(DataClassJsonMixin):
xs: List[int] = field(default_factory=list)


class MyCollection(Collection[A]):
def __init__(self, xs: Collection[A]) -> None:
self.xs = xs
Expand Down
41 changes: 32 additions & 9 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from collections import deque

from tests.test_entities import (DataClassWithDeque,
DataClassWithFrozenSet,
DataClassWithList,
DataClassWithMyCollection,
DataClassWithOptional,
DataClassWithSet,
DataClassWithTuple,
DataClassWithUnionIntNone,
MyCollection)
from tests.entities import (DataClassWithDeque,
DataClassWithFrozenSet,
DataClassWithList,
DataClassWithMyCollection,
DataClassWithOptional,
DataClassWithOptionalRecursive,
DataClassWithSet,
DataClassWithTuple,
DataClassWithUnionIntNone,
MyCollection,
DataClassImmutableDefault,
DataClassMutableDefault)


class TestEncoder:
Expand Down Expand Up @@ -39,6 +42,12 @@ def test_my_collection(self):
assert DataClassWithMyCollection(
MyCollection([1])).to_json() == '{"xs": [1]}'

def test_immutable_default(self):
assert DataClassImmutableDefault().to_json() == '{"x": 0}'

def test_mutable_default(self):
assert DataClassMutableDefault().to_json() == '{"xs": []}'


class TestDecoder:
def test_list(self):
Expand Down Expand Up @@ -71,10 +80,24 @@ def test_infer_missing(self):
actual = DataClassWithOptional.from_json('{}', infer_missing=True)
assert (actual == DataClassWithOptional(None))

def test_infer_missing_is_recursive(self):
actual = DataClassWithOptionalRecursive.from_json('{"x": null}',
infer_missing=True)
assert (actual == DataClassWithOptionalRecursive(
DataClassWithOptional(None)))

def test_my_collection(self):
assert (DataClassWithMyCollection.from_json('{"xs": [1]}') ==
DataClassWithMyCollection(MyCollection([1])))

def test_my_list_collection(self):
assert (DataClassWithMyCollection.from_json_array('[{"xs": [1]}]')
== [DataClassWithMyCollection(MyCollection([1]))])

def test_immutable_default(self):
assert (DataClassImmutableDefault.from_json('{"x": 0}')
== DataClassImmutableDefault())

def test_mutable_default(self):
assert (DataClassMutableDefault.from_json('{"xs": []}')
== DataClassMutableDefault())
2 changes: 1 addition & 1 deletion tests/test_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tests.hypothesis2 import examples
from tests.hypothesis2.strategies import deques, optionals
from tests.test_entities import (DataClassWithDeque, DataClassWithFrozenSet,
from tests.entities import (DataClassWithDeque, DataClassWithFrozenSet,
DataClassWithList, DataClassWithOptional,
DataClassWithSet, DataClassWithTuple)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_nested.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tests.test_entities import (DataClassWithDataClass,
from tests.entities import (DataClassWithDataClass,
DataClassWithList,
DataClassX,
DataClassXs)
Expand Down

0 comments on commit 2b03814

Please sign in to comment.