Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mog_commons/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.13'
__version__ = '0.1.14'
20 changes: 17 additions & 3 deletions src/mog_commons/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
__all__ = [
'String',
'Unicode',
'Option',
'ListOf',
'TupleOf',
'SetOf',
Expand Down Expand Up @@ -101,6 +102,11 @@ def KwArg(cls):
return DictOf(String, cls)


def Option(cls):
"""Shorthand description for a type allowing NoneType"""
return cls + (type(None),) if isinstance(cls, tuple) else (cls, type(None))


#
# Helper functions
#
Expand All @@ -114,7 +120,12 @@ def _get_name(cls):


def _check_type(obj, cls):
return cls.check(obj) if isinstance(cls, ComposableType) else isinstance(obj, cls)
if isinstance(cls, ComposableType):
return cls.check(obj)
elif isinstance(cls, tuple):
return any(_check_type(obj, t) for t in cls)
else:
return isinstance(obj, cls)


#
Expand Down Expand Up @@ -147,12 +158,15 @@ def wrapper(*args, **kwargs):
for arg_name, expect in arg_types.items():
assert arg_name in call_args, 'Not found argument: %s' % arg_name
actual = call_args[arg_name]
assert _check_type(actual, expect), arg_msg % (arg_name, _get_name(expect), type(actual).__name__)
if not _check_type(actual, expect):
raise TypeError(arg_msg % (arg_name, _get_name(expect), type(actual).__name__))

ret = func(*args, **kwargs)
if return_type:
assert _check_type(ret, return_type[0]), return_msg % (_get_name(return_type[0]), type(ret).__name__)
if not _check_type(ret, return_type[0]):
raise TypeError(return_msg % (_get_name(return_type[0]), type(ret).__name__))
return ret

return wrapper

return f
32 changes: 28 additions & 4 deletions tests/mog_commons/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,49 @@ def err_func2():
@types(bool, int)
def f():
pass

return 1

@staticmethod
@types(bool)
def predicate():
return 1

@types(x=Option((int, float)))
def optional_func1(self, x):
return x

@types(int, xs=Option(ListOf(int)))
def optional_func2(self, xs=None):
return len(xs or [])

class Foo(object):
pass

def test_types(self):
str_type = '(basestring|str)' if six.PY2 else '(str|bytes)'

self.assertEqual(self.bin_func(10, 20), 30)
self.assertRaisesMessage(AssertionError, 'x must be int, not dict.', self.bin_func, {}, 20)
self.assertRaisesMessage(AssertionError, 'y must be int, not list.', self.bin_func, 10, [])
self.assertRaisesMessage(TypeError, 'x must be int, not dict.', self.bin_func, {}, 20)
self.assertRaisesMessage(TypeError, 'y must be int, not list.', self.bin_func, 10, [])

self.assertEqual(self.complex_func(123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}]), 1)
self.assertRaisesMessage(AssertionError, 'kw must be dict(%s->float), not dict.' % str_type,
self.assertRaisesMessage(TypeError, 'kw must be dict(%s->float), not dict.' % str_type,
self.complex_func, 123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}], x='12.3')

self.assertRaisesMessage(AssertionError, 'must return bool, not int.', self.predicate)
self.assertRaisesMessage(TypeError, 'must return bool, not int.', self.predicate)

self.assertEqual(self.optional_func1(None), None)
self.assertEqual(self.optional_func1(123), 123)
self.assertEqual(self.optional_func1(1.23), 1.23)
self.assertRaisesMessage(TypeError, 'x must be (int|float|NoneType), not Foo.',
self.optional_func1, self.Foo())

self.assertEqual(self.optional_func2(), 0)
self.assertEqual(self.optional_func2(None), 0)
self.assertEqual(self.optional_func2([1, 2, 3]), 3)
self.assertRaisesMessage(TypeError, 'xs must be (list(int)|NoneType), not list.',
self.optional_func2, [1, 2, 3.4])

def test_types_error(self):
self.assertRaisesMessage(AssertionError, 'Not found argument: a', self.err_func1, 123)
Expand Down