Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix infer missing bugs with nested objects and field defaults #8

Merged
merged 1 commit into from
Aug 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions dataclasses_json/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,21 @@ def from_json(cls,
parse_float=parse_float,
parse_int=parse_int,
parse_constant=parse_constant)

if infer_missing:
init_kwargs = ChainMap(init_kwargs,
{field.name: None for field in fields(cls)
if field.name not in init_kwargs})
return _decode_dataclass(cls, init_kwargs)
return _decode_dataclass(cls, init_kwargs, infer_missing)

@classmethod
def from_json_array(cls,
kvss,
*,
encoding=None,
parse_float=None,
parse_int=None,
parse_constant=None):
parse_constant=None,
infer_missing=False):
init_kwargs_array = json.loads(kvss,
encoding=encoding,
parse_float=parse_float,
parse_int=parse_int,
parse_constant=parse_constant)
return [_decode_dataclass(cls, init_kwargs)
return [_decode_dataclass(cls, init_kwargs, infer_missing)
for init_kwargs in init_kwargs_array]
34 changes: 24 additions & 10 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from dataclasses import fields, is_dataclass
from typing import Collection, Optional
import sys
from collections import ChainMap
from dataclasses import fields, is_dataclass, MISSING
from typing import Collection, Optional
from functools import partial


def _get_type_origin(type_):
Expand Down Expand Up @@ -51,14 +53,26 @@ def default(self, o):
return json.JSONEncoder.default(self, o)


def _decode_dataclass(cls, kvs):
def _decode_dataclass(cls, kvs, infer_missing):
kvs = {} if kvs is None and infer_missing else kvs
missing_fields = {field for field in fields(cls) if field.name not in kvs}
for field in missing_fields:
if field.default is not MISSING:
kvs[field.name] = field.default
elif infer_missing:
kvs[field.name] = None

init_kwargs = {}
for field in fields(cls):
field_value = kvs[field.name]
if is_dataclass(field.type):
init_kwargs[field.name] = _decode_dataclass(field.type, field_value)
init_kwargs[field.name] = _decode_dataclass(field.type,
field_value,
infer_missing)
elif _is_supported_generic(field.type) and field.type != str:
init_kwargs[field.name] = _decode_generic(field.type, field_value)
init_kwargs[field.name] = _decode_generic(field.type,
field_value,
infer_missing)
else:
init_kwargs[field.name] = field_value
return cls(**init_kwargs)
Expand All @@ -75,7 +89,7 @@ def _is_supported_generic(type_):
return is_collection or is_optional


def _decode_generic(type_, value):
def _decode_generic(type_, value, infer_missing):
if value is None:
res = value
elif _issubclass_safe(_get_type_origin(type_), Collection):
Expand All @@ -88,9 +102,9 @@ def _decode_generic(type_, value):
# hence the check of `is_dataclass(value)`
type_arg = type_.__args__[0]
if is_dataclass(type_arg) or is_dataclass(value):
xs = (_decode_dataclass(type_arg, v) for v in value)
xs = (_decode_dataclass(type_arg, v, infer_missing) for v in value)
elif _is_supported_generic(type_arg):
xs = (_decode_generic(type_arg, v) for v in value)
xs = (_decode_generic(type_arg, v, infer_missing) for v in value)
else:
xs = value
# get the constructor if using corresponding generic type in `typing`
Expand All @@ -102,9 +116,9 @@ def _decode_generic(type_, value):
else: # Optional
type_arg = type_.__args__[0]
if is_dataclass(type_arg) or is_dataclass(value):
res = _decode_dataclass(type_arg, value)
res = _decode_dataclass(type_arg, value, infer_missing)
elif _is_supported_generic(type_arg):
res = _decode_generic(type_arg, value)
res = _decode_generic(type_arg, value, infer_missing)
else:
res = value
return res
Expand Down
17 changes: 16 additions & 1 deletion tests/test_entities.py → tests/entities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import (Collection,
Deque,
FrozenSet,
Expand Down Expand Up @@ -44,6 +44,11 @@ class DataClassWithOptional(DataClassJsonMixin):
x: Optional[int]


@dataclass(frozen=True)
class DataClassWithOptionalRecursive(DataClassJsonMixin):
x: DataClassWithOptional


@dataclass(frozen=True)
class DataClassWithUnionIntNone(DataClassJsonMixin):
x: Union[int, None]
Expand All @@ -64,6 +69,16 @@ class DataClassXs(DataClassJsonMixin):
xs: List[DataClassX]


@dataclass(frozen=True)
class DataClassImmutableDefault(DataClassJsonMixin):
x: int = 0


@dataclass(frozen=True)
class DataClassMutableDefault(DataClassJsonMixin):
xs: List[int] = field(default_factory=list)


class MyCollection(Collection[A]):
def __init__(self, xs: Collection[A]) -> None:
self.xs = xs
Expand Down
41 changes: 32 additions & 9 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from collections import deque

from tests.test_entities import (DataClassWithDeque,
DataClassWithFrozenSet,
DataClassWithList,
DataClassWithMyCollection,
DataClassWithOptional,
DataClassWithSet,
DataClassWithTuple,
DataClassWithUnionIntNone,
MyCollection)
from tests.entities import (DataClassWithDeque,
DataClassWithFrozenSet,
DataClassWithList,
DataClassWithMyCollection,
DataClassWithOptional,
DataClassWithOptionalRecursive,
DataClassWithSet,
DataClassWithTuple,
DataClassWithUnionIntNone,
MyCollection,
DataClassImmutableDefault,
DataClassMutableDefault)


class TestEncoder:
Expand Down Expand Up @@ -39,6 +42,12 @@ def test_my_collection(self):
assert DataClassWithMyCollection(
MyCollection([1])).to_json() == '{"xs": [1]}'

def test_immutable_default(self):
assert DataClassImmutableDefault().to_json() == '{"x": 0}'

def test_mutable_default(self):
assert DataClassMutableDefault().to_json() == '{"xs": []}'


class TestDecoder:
def test_list(self):
Expand Down Expand Up @@ -71,10 +80,24 @@ def test_infer_missing(self):
actual = DataClassWithOptional.from_json('{}', infer_missing=True)
assert (actual == DataClassWithOptional(None))

def test_infer_missing_is_recursive(self):
actual = DataClassWithOptionalRecursive.from_json('{"x": null}',
infer_missing=True)
assert (actual == DataClassWithOptionalRecursive(
DataClassWithOptional(None)))

def test_my_collection(self):
assert (DataClassWithMyCollection.from_json('{"xs": [1]}') ==
DataClassWithMyCollection(MyCollection([1])))

def test_my_list_collection(self):
assert (DataClassWithMyCollection.from_json_array('[{"xs": [1]}]')
== [DataClassWithMyCollection(MyCollection([1]))])

def test_immutable_default(self):
assert (DataClassImmutableDefault.from_json('{"x": 0}')
== DataClassImmutableDefault())

def test_mutable_default(self):
assert (DataClassMutableDefault.from_json('{"xs": []}')
== DataClassMutableDefault())
2 changes: 1 addition & 1 deletion tests/test_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tests.hypothesis2 import examples
from tests.hypothesis2.strategies import deques, optionals
from tests.test_entities import (DataClassWithDeque, DataClassWithFrozenSet,
from tests.entities import (DataClassWithDeque, DataClassWithFrozenSet,
DataClassWithList, DataClassWithOptional,
DataClassWithSet, DataClassWithTuple)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_nested.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tests.test_entities import (DataClassWithDataClass,
from tests.entities import (DataClassWithDataClass,
DataClassWithList,
DataClassX,
DataClassXs)
Expand Down