Skip to content
This repository has been archived by the owner on Aug 1, 2021. It is now read-only.

Commit

Permalink
[types] cleanup model loading and make consume greedy
Browse files Browse the repository at this point in the history
  • Loading branch information
b1naryth1ef committed Jun 23, 2017
1 parent 359795e commit cca53f0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 32 deletions.
51 changes: 29 additions & 22 deletions disco/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def name(self, name):
def has_default(self):
return self.default is not None

def try_convert(self, raw, client):
def try_convert(self, raw, client, **kwargs):
try:
return self.deserializer(raw, client)
return self.deserializer(raw, client, **kwargs)
except Exception as e:
six.reraise(ConversionError, ConversionError(self, raw, e))

Expand All @@ -94,11 +94,16 @@ def type_to_deserializer(typ):
if isinstance(typ, Field) or inspect.isclass(typ) and issubclass(typ, Model):
return typ
elif isinstance(typ, BaseEnumMeta):
return lambda raw, _: typ.get(raw)
def _f(raw, client, **kwargs):
return typ.get(raw)
return _f
elif typ is None:
return lambda x, y: None
def _f(*args, **kwargs):
return None
else:
return lambda raw, _: typ(raw)
def _f(raw, client, **kwargs):
return typ(raw)
return _f

@staticmethod
def serialize(value, inst=None):
Expand All @@ -111,8 +116,8 @@ def serialize(value, inst=None):
return inst.cast(value)
return value

def __call__(self, raw, client):
return self.try_convert(raw, client)
def __call__(self, raw, client, **kwargs):
return self.try_convert(raw, client, **kwargs)


class DictField(Field):
Expand All @@ -132,7 +137,7 @@ def serialize(value, inst=None):
if k not in (inst.ignore_dump if inst else [])
}

def try_convert(self, raw, client):
def try_convert(self, raw, client, **kwargs):
return HashMap({
self.key_de(k, client): self.value_de(v, client) for k, v in six.iteritems(raw)
})
Expand All @@ -145,7 +150,7 @@ class ListField(Field):
def serialize(value, inst=None):
return list(map(Field.serialize, value))

def try_convert(self, raw, client):
def try_convert(self, raw, client, **kwargs):
return [self.deserializer(i, client) for i in raw]


Expand All @@ -157,7 +162,7 @@ def __init__(self, value_type, key, **kwargs):
self.value_de = self.type_to_deserializer(value_type)
self.key = key

def try_convert(self, raw, client):
def try_convert(self, raw, client, **kwargs):
return HashMap({
getattr(b, self.key): b for b in (self.value_de(a, client) for a in raw)
})
Expand Down Expand Up @@ -274,8 +279,9 @@ def __init__(self, *args, **kwargs):
obj, self.client = args
else:
obj = kwargs
kwargs = {}

self.load(obj)
self.load(obj, **kwargs)
self.validate()

def after(self, delay):
Expand All @@ -289,31 +295,32 @@ def validate(self):
def _fields(self):
return self.__class__._fields

def load(self, obj, consume=False, skip=None):
return self.load_into(self, obj, consume, skip)
def load(self, *args, **kwargs):
return self.load_into(self, *args, **kwargs)

@classmethod
def load_into(cls, inst, obj, consume=False, skip=None):
def load_into(cls, inst, obj, consume=False):
for name, field in six.iteritems(cls._fields):
should_skip = skip and name in skip
try:
raw = obj[field.src_name]

if consume and not should_skip:
raw = obj.pop(field.src_name, UNSET)
else:
raw = obj.get(field.src_name, UNSET)
if consume and not isinstance(raw, dict):
del obj[field.src_name]
except KeyError:
raw = UNSET

# If the field is unset/none, and we have a default we need to set it
if (raw in (None, UNSET) or should_skip) and field.has_default():
if raw in (None, UNSET) and field.has_default():
default = field.default() if callable(field.default) else field.default
setattr(inst, field.dst_name, default)
continue

# Otherwise if the field is UNSET and has no default, skip conversion
if raw is UNSET or should_skip:
if raw is UNSET:
setattr(inst, field.dst_name, raw)
continue

value = field.try_convert(raw, inst.client)
value = field.try_convert(raw, inst.client, consume=consume)
setattr(inst, field.dst_name, value)

def update(self, other, ignored=None):
Expand Down
46 changes: 36 additions & 10 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,58 @@
from disco.types.base import Model, Field


class _M(Model):
class _A(Model):
a = Field(int)
b = Field(float)
c = Field(str)


class _B(Model):
a = Field(int)
b = Field(float)
c = Field(str)


class _C(Model):
a = Field(_A)
b = Field(_B)


class TestModel(TestCase):
def test_model_simple_loading(self):
inst = _M(dict(a=1, b=1.1, c='test'))
inst = _A(dict(a=1, b=1.1, c='test'))
self.assertEquals(inst.a, 1)
self.assertEquals(inst.b, 1.1)
self.assertEquals(inst.c, 'test')

def test_model_load_into(self):
inst = _M()
_M.load_into(inst, dict(a=1, b=1.1, c='test'))
inst = _A()
_A.load_into(inst, dict(a=1, b=1.1, c='test'))
self.assertEquals(inst.a, 1)
self.assertEquals(inst.b, 1.1)
self.assertEquals(inst.c, 'test')

def test_model_loading_consume(self):
obj = dict(a=5, b=33.33, c='wtf')
inst = _M()
obj = {
'a': {
'a': 1,
'b': 2.2,
'c': '3',
'd': 'wow',
},
'b': {
'a': 3,
'b': 2.2,
'c': '1',
'z': 'wtf'
},
'g': 'lmao'
}

inst = _C()
inst.load(obj, consume=True)

self.assertEquals(obj, {})
self.assertEquals(inst.a, 5)
self.assertEquals(inst.b, 33.33)
self.assertEquals(inst.c, 'wtf')
self.assertEquals(inst.a.a, 1)
self.assertEquals(inst.b.c, '1')

self.assertEquals(obj, {'a': {'d': 'wow'}, 'b': {'z': 'wtf'}, 'g': 'lmao'})

0 comments on commit cca53f0

Please sign in to comment.