Skip to content

Commit

Permalink
Fix handling of default values on nested dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
jacebrowning committed Mar 22, 2020
1 parent 14afc98 commit 658dd91
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 30 deletions.
7 changes: 5 additions & 2 deletions datafiles/converters/containers.py
@@ -1,10 +1,10 @@
from collections.abc import Iterable
from contextlib import suppress
from dataclasses import _MISSING_TYPE as Missing
from typing import Callable, Dict

import log

from ..utils import Missing, get_default_field_value
from ._bases import Converter


Expand Down Expand Up @@ -148,7 +148,10 @@ def to_python_value(cls, deserialized_data, *, target_object):
if name in data:
data[name] = converter.to_python_value(data[name], target_object=None)
else:
data[name] = converter.to_python_value(None, target_object=None)
if target_object is None:
data[name] = converter.to_python_value(None, target_object=None)
else:
data[name] = get_default_field_value(target_object, name)

new_value = cls.DATACLASS(**data) # pylint: disable=not-callable

Expand Down
28 changes: 8 additions & 20 deletions datafiles/mapper.py
Expand Up @@ -13,26 +13,14 @@

from . import config, formats, hooks
from .converters import Converter, List, map_type
from .utils import display, recursive_update, write


Trilean = Optional[bool]
Missing = dataclasses._MISSING_TYPE


def get_default_field_value(instance, name):
for field in dataclasses.fields(instance):
if field.name == name:
if not isinstance(field.default, Missing):
return field.default

if not isinstance(field.default_factory, Missing): # type: ignore
return field.default_factory() # type: ignore

if not field.init and hasattr(instance, '__post_init__'):
return getattr(instance, name)

return Missing
from .utils import (
Missing,
Trilean,
display,
get_default_field_value,
recursive_update,
write,
)


class Mapper:
Expand Down
24 changes: 22 additions & 2 deletions datafiles/utils.py
@@ -1,18 +1,38 @@
"""Internal helper functions."""

import dataclasses
from contextlib import suppress
from functools import lru_cache
from pathlib import Path
from pprint import pformat
from shutil import get_terminal_size
from typing import Dict, Union
from typing import Dict, Optional, Union

import log


Trilean = Optional[bool]
Missing = dataclasses._MISSING_TYPE


cached = lru_cache()


def get_default_field_value(instance, name):
for field in dataclasses.fields(instance):
if field.name == name:
if not isinstance(field.default, Missing):
return field.default

if not isinstance(field.default_factory, Missing): # type: ignore
return field.default_factory() # type: ignore

if not field.init and hasattr(instance, '__post_init__'):
return getattr(instance, name)

return Missing


def prettify(value) -> str:
"""Ensure value is a dictionary pretty-format it."""
return pformat(dictify(value))
Expand Down Expand Up @@ -108,4 +128,4 @@ def logbreak(message: str = "") -> None:
line = '-' * (width - len(message) - 1) + ' ' + message
else:
line = '-' * width
log.info(line)
log.critical(line)
97 changes: 96 additions & 1 deletion tests/test_loading.py
Expand Up @@ -2,9 +2,12 @@

# pylint: disable=unused-variable

from dataclasses import dataclass

import pytest

from datafiles.utils import logbreak, write
from datafiles import datafile
from datafiles.utils import dedent, logbreak, read, write

from .samples import (
Sample,
Expand Down Expand Up @@ -230,6 +233,98 @@ def with_extra_attributes(sample, expect):
expect(sample.nested.score) == 3.4
expect(hasattr(sample.nested, 'extra')) == False

def with_multiple_levels(expect):
@dataclass
class Bottom:
level: int = 4

@dataclass
class C:
level: int = 3
d: Bottom = Bottom()

@dataclass
class B:
level: int = 2
c: C = C()

@dataclass
class A:
level: int = 1
b: B = B()

@datafile('../tmp/sample.toml', defaults=True, auto_save=False)
class Top:
level: int = 0
a: A = A()

sample = Top()

expect(read('tmp/sample.toml')) == dedent(
"""
level = 0
[a]
level = 1
[a.b]
level = 2
[a.b.c]
level = 3
[a.b.c.d]
level = 4
"""
)

logbreak("Modifying attribute")
sample.a.b.c.d.level = 99

expect(read('tmp/sample.toml')) == dedent(
"""
level = 0
[a]
level = 1
[a.b]
level = 2
[a.b.c]
level = 3
[a.b.c.d]
level = 99
"""
)

write(
'tmp/sample.toml',
"""
level = 0
[a]
level = 10
[a.b]
level = 20
[a.b.c]
level = 30
[a.b.c.d]
level = 40
""",
)

logbreak("Reading attribute")
expect(sample.a.level) == 10
expect(sample.a.b.level) == 20
expect(sample.a.b.c.level) == 30

expect(sample.a.b.c.d.level) == 40


def describe_lists():
def with_matching_types(expect):
Expand Down
9 changes: 4 additions & 5 deletions tests/test_patched_methods.py
Expand Up @@ -103,7 +103,7 @@ def with_delitem(expect):
def with_append(expect):
sample = Sample()

logbreak("Appending to list")
logbreak("Appending to list: 2")
sample.items.append(2)

expect(read('tmp/sample.yml')) == dedent(
Expand All @@ -116,7 +116,7 @@ def with_append(expect):

sample.datafile.load()

logbreak("Appending to list")
logbreak("Appending to list: 3")
sample.items.append(3)

expect(read('tmp/sample.yml')) == dedent(
Expand All @@ -128,11 +128,10 @@ def with_append(expect):
"""
)

@pytest.mark.xfail(reason="TODO: fix this")
def with_append_on_nested_dataclass(expect):
sample = SampleWithNesting(1)

logbreak("Appending to nested list")
logbreak("Appending to nested list: 2")
sample.nested.items.append(2)

expect(read('tmp/sample.yml')) == dedent(
Expand All @@ -144,7 +143,7 @@ def with_append_on_nested_dataclass(expect):
"""
)

logbreak("Appending to nested list")
logbreak("Appending to nested list: 3")
sample.nested.items.append(3)

expect(read('tmp/sample.yml')) == dedent(
Expand Down

0 comments on commit 658dd91

Please sign in to comment.