diff --git a/src/mog_commons/__init__.py b/src/mog_commons/__init__.py index 377e1f6..112abf1 100644 --- a/src/mog_commons/__init__.py +++ b/src/mog_commons/__init__.py @@ -1 +1 @@ -__version__ = '0.1.13' +__version__ = '0.1.14' diff --git a/src/mog_commons/types.py b/src/mog_commons/types.py index 50840eb..8190473 100644 --- a/src/mog_commons/types.py +++ b/src/mog_commons/types.py @@ -12,6 +12,7 @@ __all__ = [ 'String', 'Unicode', + 'Option', 'ListOf', 'TupleOf', 'SetOf', @@ -101,6 +102,11 @@ def KwArg(cls): return DictOf(String, cls) +def Option(cls): + """Shorthand description for a type allowing NoneType""" + return cls + (type(None),) if isinstance(cls, tuple) else (cls, type(None)) + + # # Helper functions # @@ -114,7 +120,12 @@ def _get_name(cls): def _check_type(obj, cls): - return cls.check(obj) if isinstance(cls, ComposableType) else isinstance(obj, cls) + if isinstance(cls, ComposableType): + return cls.check(obj) + elif isinstance(cls, tuple): + return any(_check_type(obj, t) for t in cls) + else: + return isinstance(obj, cls) # @@ -147,12 +158,15 @@ def wrapper(*args, **kwargs): for arg_name, expect in arg_types.items(): assert arg_name in call_args, 'Not found argument: %s' % arg_name actual = call_args[arg_name] - assert _check_type(actual, expect), arg_msg % (arg_name, _get_name(expect), type(actual).__name__) + if not _check_type(actual, expect): + raise TypeError(arg_msg % (arg_name, _get_name(expect), type(actual).__name__)) ret = func(*args, **kwargs) if return_type: - assert _check_type(ret, return_type[0]), return_msg % (_get_name(return_type[0]), type(ret).__name__) + if not _check_type(ret, return_type[0]): + raise TypeError(return_msg % (_get_name(return_type[0]), type(ret).__name__)) return ret + return wrapper return f diff --git a/tests/mog_commons/test_types.py b/tests/mog_commons/test_types.py index 67140f3..f2edfc3 100644 --- a/tests/mog_commons/test_types.py +++ b/tests/mog_commons/test_types.py @@ -29,6 +29,7 @@ def err_func2(): @types(bool, int) def f(): pass + return 1 @staticmethod @@ -36,18 +37,41 @@ def f(): def predicate(): return 1 + @types(x=Option((int, float))) + def optional_func1(self, x): + return x + + @types(int, xs=Option(ListOf(int))) + def optional_func2(self, xs=None): + return len(xs or []) + + class Foo(object): + pass + def test_types(self): str_type = '(basestring|str)' if six.PY2 else '(str|bytes)' self.assertEqual(self.bin_func(10, 20), 30) - self.assertRaisesMessage(AssertionError, 'x must be int, not dict.', self.bin_func, {}, 20) - self.assertRaisesMessage(AssertionError, 'y must be int, not list.', self.bin_func, 10, []) + self.assertRaisesMessage(TypeError, 'x must be int, not dict.', self.bin_func, {}, 20) + self.assertRaisesMessage(TypeError, 'y must be int, not list.', self.bin_func, 10, []) self.assertEqual(self.complex_func(123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}]), 1) - self.assertRaisesMessage(AssertionError, 'kw must be dict(%s->float), not dict.' % str_type, + self.assertRaisesMessage(TypeError, 'kw must be dict(%s->float), not dict.' % str_type, self.complex_func, 123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}], x='12.3') - self.assertRaisesMessage(AssertionError, 'must return bool, not int.', self.predicate) + self.assertRaisesMessage(TypeError, 'must return bool, not int.', self.predicate) + + self.assertEqual(self.optional_func1(None), None) + self.assertEqual(self.optional_func1(123), 123) + self.assertEqual(self.optional_func1(1.23), 1.23) + self.assertRaisesMessage(TypeError, 'x must be (int|float|NoneType), not Foo.', + self.optional_func1, self.Foo()) + + self.assertEqual(self.optional_func2(), 0) + self.assertEqual(self.optional_func2(None), 0) + self.assertEqual(self.optional_func2([1, 2, 3]), 3) + self.assertRaisesMessage(TypeError, 'xs must be (list(int)|NoneType), not list.', + self.optional_func2, [1, 2, 3.4]) def test_types_error(self): self.assertRaisesMessage(AssertionError, 'Not found argument: a', self.err_func1, 123)