Merge pull request #160 from jacebrowning/recursive-conversion
Support unlimited nesting
jacebrowning committed Mar 22, 2020
2 parents a10feb6 + 01a90e1 commit a16288d
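
For readers skimming the diff, the behavior this merge enables looks roughly like the sketch below. It is a minimal illustration rather than code from the repository; the model names (`Tree`, `Branch`, `Leaf`) and the file path are hypothetical.

```python
from dataclasses import dataclass, field

from datafiles import datafile


@dataclass
class Leaf:
    count: int = 0


@dataclass
class Branch:
    leaf: Leaf = field(default_factory=Leaf)


@datafile('tree.yml')  # used as a drop-in replacement for @dataclass
class Tree:
    branch: Branch = field(default_factory=Branch)


# Previously, nested dataclasses were only converted one level deep
# (see the TODO referencing issue #22 removed from mapper.py below).
# After this change, the second level round-trips through the file too.
tree = Tree()
tree.branch.leaf.count = 42  # persisted to tree.yml and restored on load
```
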
Showing 10 changed files with 200 additions and 100 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,6 +1,7 @@
 # 0.8 (unreleased)

 - Updated the `@datafile(...)` decorator to be used as a drop-in replacement for `@dataclass(...)`.
+- Added support for loading unlimited levels of nested objects.

 # 0.7 (2020-02-20)

16 changes: 13 additions & 3 deletions datafiles/converters/containers.py
@@ -1,10 +1,10 @@
 from collections.abc import Iterable
 from contextlib import suppress
-from dataclasses import _MISSING_TYPE as Missing
 from typing import Callable, Dict

 import log

+from ..utils import Missing, get_default_field_value
 from ._bases import Converter

@@ -144,8 +144,18 @@ def to_python_value(cls, deserialized_data, *, target_object):
                 data.pop(name)

         for name, converter in cls.CONVERTERS.items():
-            if name not in data:
-                data[name] = converter.to_python_value(None, target_object=None)
+            log.debug(f"Converting '{name}' data with {converter}")
+            if name in data:
+                converted = converter.to_python_value(data[name], target_object=None)
+            else:
+                if target_object is None:
+                    converted = converter.to_python_value(None, target_object=None)
+                else:
+                    converted = get_default_field_value(target_object, name)
+                    if converted is Missing:
+                        converted = getattr(target_object, name)
+
+            data[name] = converted

         new_value = cls.DATACLASS(**data)  # pylint: disable=not-callable

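In effect, the new branch in `to_python_value` resolves a key missing from the file data in three steps: with no target object, it converts `None`; with a target object, it prefers the field's declared default; only when no usable default exists does it keep the attribute's current value. Below is a minimal sketch of the last two steps, using the helper this commit moves into `datafiles.utils` (the `Nested` class is illustrative):

```python
from dataclasses import dataclass

from datafiles.utils import Missing, get_default_field_value


@dataclass
class Nested:
    name: str = 'unknown'


target = Nested(name='set-at-runtime')

# Mirrors the else-branch above for a key absent from the file:
converted = get_default_field_value(target, 'name')
if converted is Missing:
    # No declared default: fall back to the current attribute value.
    converted = getattr(target, 'name')

assert converted == 'unknown'  # the declared default wins
```
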
125 changes: 43 additions & 82 deletions datafiles/mapper.py
@@ -13,11 +13,14 @@

 from . import config, formats, hooks
 from .converters import Converter, List, map_type
-from .utils import display, recursive_update, write
-
-
-Trilean = Optional[bool]
-Missing = dataclasses._MISSING_TYPE
+from .utils import (
+    Missing,
+    Trilean,
+    display,
+    get_default_field_value,
+    recursive_update,
+    write,
+)


 class Mapper:
@@ -147,11 +150,11 @@ def _get_data(self, include_default_values: Trilean = None) -> Dict:
                     value,
                     default_to_skip=Missing
                     if include_default_values
-                    else self._get_default_field_value(name),
+                    else get_default_field_value(self._instance, name),
                 )

             elif (
-                value == self._get_default_field_value(name)
+                value == get_default_field_value(self._instance, name)
                 and not include_default_values
             ):
                 log.debug(f"Skipped default value of {value!r} for {name!r} attribute")
@@ -178,9 +181,9 @@ def _get_text(self, **kwargs) -> str:
     def text(self, value: str):
         write(self.path, value.strip() + '\n')

-    def load(self, *, _log=True, _first=False) -> None:
+    def load(self, *, _log=True, _first_load=False) -> None:
         if self._root:
-            self._root.load(_log=_log, _first=_first)
+            self._root.load(_log=_log, _first_load=_first_load)
             return

         if self.path:
@@ -197,76 +200,48 @@ def load(self, *, _log=True, _first=False) -> None:

         for name, value in data.items():
             if name not in self.attrs and self.auto_attr:
-                cls: Any = type(value)
-                if issubclass(cls, list):
-                    cls.__origin__ = list
-
-                    if value:
-                        item_cls = type(value[0])
-                        for item in value:
-                            if not isinstance(item, item_cls):
-                                log.warn(f'{name!r} list type cannot be inferred')
-                                item_cls = Converter
-                                break
-                    else:
-                        log.warn(f'{name!r} list type cannot be inferred')
-                        item_cls = Converter
-
-                    log.debug(f'Inferring {name!r} type: {cls} of {item_cls}')
-                    self.attrs[name] = map_type(cls, name=name, item_cls=item_cls)
-                elif issubclass(cls, dict):
-                    cls.__origin__ = dict
-
-                    log.debug(f'Inferring {name!r} type: {cls}')
-                    self.attrs[name] = map_type(cls, name=name, item_cls=Converter)
-                else:
-                    log.debug(f'Inferring {name!r} type: {cls}')
-                    self.attrs[name] = map_type(cls, name=name)
+                self.attrs[name] = self._infer_attr(name, value)

         for name, converter in self.attrs.items():
-            log.debug(f"Converting '{name}' data with {converter}")
-
-            if getattr(converter, 'DATACLASS', None):
-                self._set_dataclass_value(data, name, converter)
-            else:
-                self._set_attribute_value(data, name, converter, _first)
+            self._set_value(self._instance, name, converter, data, _first_load)

         hooks.apply(self._instance, self)

         self.modified = False

-    def _set_dataclass_value(self, data, name, converter):
-        # TODO: Support nesting unlimited levels
-        # https://github.com/jacebrowning/datafiles/issues/22
-        nested_data = data.get(name)
-        if nested_data is None:
-            return
+    @staticmethod
+    def _infer_attr(name, value):
+        cls: Any = type(value)
+        if issubclass(cls, list):
+            cls.__origin__ = list
+            if value:
+                item_cls = type(value[0])
+                for item in value:
+                    if not isinstance(item, item_cls):
+                        log.warn(f'{name!r} list type cannot be inferred')
+                        item_cls = Converter
+                        break
+            else:
+                log.warn(f'{name!r} list type cannot be inferred')
+                item_cls = Converter
+            log.debug(f'Inferring {name!r} type: {cls} of {item_cls}')
+            return map_type(cls, name=name, item_cls=item_cls)

-        log.debug(f'Converting nested data to Python: {nested_data}')
+        if issubclass(cls, dict):
+            cls.__origin__ = dict
+            log.debug(f'Inferring {name!r} type: {cls}')
+            return map_type(cls, name=name, item_cls=Converter)

-        dataclass = getattr(self._instance, name)
-        if dataclass is None:
-            for field in dataclasses.fields(converter.DATACLASS):
-                if field.name not in nested_data:
-                    nested_data[field.name] = None
-            dataclass = converter.to_python_value(nested_data, target_object=dataclass)
+        log.debug(f'Inferring {name!r} type: {cls}')
+        return map_type(cls, name=name)

-        mapper = create_mapper(dataclass)
-        for name2, converter2 in mapper.attrs.items():
-            _value = nested_data.get(name2, mapper._get_default_field_value(name2))
-            value = converter2.to_python_value(
-                _value, target_object=getattr(dataclass, name2)
-            )
-            log.debug(f"'{name2}' as Python: {value!r}")
-            setattr(dataclass, name2, value)
+    @staticmethod
+    def _set_value(instance, name, converter, data, first_load):
+        log.debug(f"Converting '{name}' data with {converter}")

-        log.debug(f"Setting '{name}' value: {dataclass!r}")
-        setattr(self._instance, name, dataclass)
-
-    def _set_attribute_value(self, data, name, converter, first_load):
         file_value = data.get(name, Missing)
-        init_value = getattr(self._instance, name, Missing)
-        default_value = self._get_default_field_value(name)
+        init_value = getattr(instance, name, Missing)
+        default_value = get_default_field_value(instance, name)

         if first_load:
             log.debug(
@@ -291,21 +266,7 @@ def _set_attribute_value(self, data, name, converter, first_load):
             value = converter.to_python_value(file_value, target_object=init_value)

         log.debug(f"Setting '{name}' value: {value!r}")
-        setattr(self._instance, name, value)
-
-    def _get_default_field_value(self, name):
-        for field in dataclasses.fields(self._instance):
-            if field.name == name:
-                if not isinstance(field.default, Missing):
-                    return field.default
-
-                if not isinstance(field.default_factory, Missing):  # type: ignore
-                    return field.default_factory()  # type: ignore
-
-                if not field.init and hasattr(self._instance, '__post_init__'):
-                    return getattr(self._instance, name)
-
-        return Missing
+        setattr(instance, name, value)

     def save(self, *, include_default_values: Trilean = None, _log=True) -> None:
         if self._root:
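Note that `_infer_attr` is the previously inline inference code extracted verbatim: a homogeneous list gets a typed item converter, a mixed or empty list falls back to the generic `Converter` with a warning, and dictionaries and scalars go straight through `map_type`. This machinery backs the `auto_attr` mode. Assuming the `auto()` helper exposed by the package and the hypothetical file contents below, usage looks roughly like:

```python
# sample.yml (hypothetical contents):
#   scores:
#     - 1
#     - 2
#   meta:
#     key: value

from datafiles import auto

sample = auto('sample.yml')  # no dataclass declared; types are inferred

# 'scores' is inferred as a list of integers and 'meta' as a dictionary;
# a mixed or empty list would log "list type cannot be inferred" and
# fall back to the generic Converter.
print(sample.scores, sample.meta)
```
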
2 changes: 1 addition & 1 deletion datafiles/model.py
@@ -29,7 +29,7 @@ def __post_init__(self):
         log.debug(f'Datafile exists: {exists}')

         if exists:
-            self.datafile.load(_first=True)
+            self.datafile.load(_first_load=True)
         elif path and create:
             self.datafile.save()
18 changes: 15 additions & 3 deletions datafiles/tests/test_converters.py
@@ -15,11 +15,17 @@ class MyDataclass:
     flag: bool = False


+@dataclass
+class MyNestedDataclass:
+    name: str
+    dc: MyDataclass
+
+
 class MyNonDataclass:
     pass


-class MyNonDataclass2:
+class MyCustomString:
     pass

@@ -28,6 +34,7 @@ class MyNonDataclass2:
 MyDict = converters.Dictionary.subclass(converters.String, converters.Integer)
 MyDataclassConverter = converters.map_type(MyDataclass)
 MyDataclassConverterList = converters.map_type(List[MyDataclass])
+MyNestedDataclassConverter = converters.map_type(MyNestedDataclass)


 def describe_map_type():
@@ -139,6 +146,11 @@ def when_immutable(expect, converter, data, value):
             (MyDataclassConverter, None, MyDataclass(foobar=0)),
             (MyDataclassConverterList, None, []),
             (MyDataclassConverterList, 42, [MyDataclass(foobar=0)]),
+            (
+                MyNestedDataclassConverter,
+                None,
+                MyNestedDataclass(name='', dc=MyDataclass(foobar=0, flag=False)),
+            ),
         ],
     )
     def when_mutable(expect, converter, data, value):
@@ -284,6 +296,6 @@ def when_dataclass_with_default(expect):

 def describe_register():
     def with_new_type(expect):
-        converters.register(MyNonDataclass2, converters.String)
-        converter = converters.map_type(MyNonDataclass2)
+        converters.register(MyCustomString, converters.String)
+        converter = converters.map_type(MyCustomString)
         expect(converter) == converters.String
24 changes: 22 additions & 2 deletions datafiles/utils.py
@@ -1,18 +1,38 @@
 """Internal helper functions."""

+import dataclasses
 from contextlib import suppress
 from functools import lru_cache
 from pathlib import Path
 from pprint import pformat
 from shutil import get_terminal_size
-from typing import Dict, Union
+from typing import Dict, Optional, Union

 import log


+Trilean = Optional[bool]
+Missing = dataclasses._MISSING_TYPE
+
+
 cached = lru_cache()


+def get_default_field_value(instance, name):
+    for field in dataclasses.fields(instance):
+        if field.name == name:
+            if not isinstance(field.default, Missing):
+                return field.default
+
+            if not isinstance(field.default_factory, Missing):  # type: ignore
+                return field.default_factory()  # type: ignore
+
+            if not field.init and hasattr(instance, '__post_init__'):
+                return getattr(instance, name)
+
+    return Missing
+
+
 def prettify(value) -> str:
     """Ensure value is a dictionary pretty-format it."""
     return pformat(dictify(value))
@@ -108,4 +128,4 @@ def logbreak(message: str = "") -> None:
         line = '-' * (width - len(message) - 1) + ' ' + message
     else:
         line = '-' * width
-    log.info(line)
+    log.critical(line)
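
The relocated `get_default_field_value` helper has three outcomes: a plain default is returned as-is, a `default_factory` is invoked to produce a fresh value, and a field with neither yields the `Missing` sentinel. A quick sketch, with an illustrative `Config` dataclass:

```python
from dataclasses import dataclass, field
from typing import List

from datafiles.utils import Missing, get_default_field_value


@dataclass
class Config:
    required: int  # no default declared
    name: str = 'default'
    tags: List[str] = field(default_factory=list)


config = Config(required=1)

assert get_default_field_value(config, 'name') == 'default'    # plain default
assert get_default_field_value(config, 'tags') == []           # default_factory()
assert get_default_field_value(config, 'required') is Missing  # sentinel
```
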
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
@@ -1,7 +1,7 @@
 [tool.poetry]

 name = "datafiles"
-version = "0.8b1"
+version = "0.8b2"
 description = "File-based ORM for dataclasses."

 license = "MIT"
@@ -76,11 +76,11 @@ pytest-mock = "*"
 pytest-random = "*"
 pytest-repeat = "*"
 pytest-watch = "*"
-pytest-cov = "*"
+pytest-cov = "^2.8.1"
 pytest-profiling = "*"

 # Coverage
-coveragespace = "^3.1"
+coveragespace = "^3.1.1"

 # Documentation
 mkdocs = "~1.0"