Skip to content

Commit

Permalink
v0.5.0 - protecting defaults() -> _defaults()
Browse files Browse the repository at this point in the history
  • Loading branch information
kpe committed Mar 27, 2019
1 parent 0a8fdf3 commit 51a3768
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
2 changes: 1 addition & 1 deletion params/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

from .params import Params

__version__ = '0.4.1'
__version__ = '0.5.0'
27 changes: 14 additions & 13 deletions params/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,42 +29,43 @@ class MyParams(Params):

def __init__(self, *args, **kwargs):
super(Params, self).__init__()
self.update(self.__class__.defaults())
self.update(self._defaults())
self.update(dict(*args))
self.update(kwargs)

def __getattribute__(self, attr):
if attr != 'defaults' and attr in self.defaults():
if not attr.startswith("_") and attr in self._defaults():
return self.__getitem__(attr)
return object.__getattribute__(self, attr)

def __setattr__(self, key, value):
self.__setitem__(key, value)

def __setitem__(self, key, value):
if key not in self.defaults():
if key not in self._defaults():
raise AttributeError("Setting unexpected parameter '{}' "
"in Params instance '{}'".format(key, self.__class__.__name__))
super(Params, self).__setitem__(key, value)

@classmethod
def defaults(cls):
def _defaults(cls):
""" Aggregate all class fields in the class hierarchy to a dict. """

if '_defaults' in cls.__dict__:
return cls.__dict__['_defaults']
if '__defaults' in cls.__dict__:
return cls.__defaults

result = {}
for base in cls.__bases__:
if issubclass(base, Params):
result.update(base.defaults())
result.update(base._defaults())

result.update(dict(filter(lambda t: (not t[0].startswith('_') and
not callable(getattr(cls, t[0]))),
cls.__dict__.items())))
for attr, value in cls.__dict__.items():
if attr.startswith("_") or callable(getattr(cls, attr)):
continue
result[attr] = value

cls._defaults = result
return cls._defaults
cls.__defaults = result
return cls.__defaults

@classmethod
def from_dict(cls, args, return_instance=True, return_unused=True):
Expand All @@ -86,7 +87,7 @@ def is_not_none(x):
cls_args, unused_args = {}, {}
if args:
# extract unused args
keys = cls.defaults().keys()
keys = cls._defaults().keys()
cls_args, unused_args = zip(*list(map(lambda p: (p, None) if p[0] in keys else (None, p),
args.items())))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MyParams(Params):
class ParamsConstructionTest(unittest.TestCase):
def test_defaults(self):
expected = {'param_a': True, 'param_b': 1}
self.assertEqual(expected, MyParams.defaults())
self.assertEqual(expected, dict(MyParams()))

params = MyParams()
self.assertEqual(expected, dict(params))
Expand Down

0 comments on commit 51a3768

Please sign in to comment.