diff --git a/.travis.yml b/.travis.yml index b84fe5a..a151ee5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,7 +18,7 @@ install: - pip install coveralls script: - - coverage run --source=src setup.py test + - coverage run --source=src --omit='src/mog_commons/backported/*' setup.py test after_success: - coveralls diff --git a/src/mog_commons/__init__.py b/src/mog_commons/__init__.py index e6d0c4f..377e1f6 100644 --- a/src/mog_commons/__init__.py +++ b/src/mog_commons/__init__.py @@ -1 +1 @@ -__version__ = '0.1.12' +__version__ = '0.1.13' diff --git a/src/mog_commons/backported/__init__.py b/src/mog_commons/backported/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mog_commons/backported/inspect2.py b/src/mog_commons/backported/inspect2.py new file mode 100644 index 0000000..c07df72 --- /dev/null +++ b/src/mog_commons/backported/inspect2.py @@ -0,0 +1,102 @@ +# backported from https://github.com/python/cpython/blob/2.7/Lib/inspect.py + +import sys +from inspect import ismethod, getargspec + + +def getcallargs(func, *positional, **named): + """Get the mapping of arguments to values. + + A dict is returned, with keys the function argument names (including the + names of the * and ** arguments, if any), and values the respective bound + values from 'positional' and 'named'.""" + args, varargs, varkw, defaults = getargspec(func) + f_name = func.__name__ + arg2value = {} + + # The following closures are basically because of tuple parameter unpacking. + assigned_tuple_params = [] + + def assign(arg, value): + if isinstance(arg, str): + arg2value[arg] = value + else: + assigned_tuple_params.append(arg) + value = iter(value) + for i, subarg in enumerate(arg): + try: + subvalue = next(value) + except StopIteration: + raise ValueError('need more than %d %s to unpack' % + (i, 'values' if i > 1 else 'value')) + assign(subarg, subvalue) + try: + next(value) + except StopIteration: + pass + else: + raise ValueError('too many values to unpack') + + def is_assigned(arg): + if isinstance(arg, str): + return arg in arg2value + return arg in assigned_tuple_params + + if ismethod(func) and func.im_self is not None: + # implicit 'self' (or 'cls' for classmethods) argument + positional = (func.im_self,) + positional + num_pos = len(positional) + num_total = num_pos + len(named) + num_args = len(args) + num_defaults = len(defaults) if defaults else 0 + for arg, value in zip(args, positional): + assign(arg, value) + if varargs: + if num_pos > num_args: + assign(varargs, positional[-(num_pos - num_args):]) + else: + assign(varargs, ()) + elif 0 < num_args < num_pos: + raise TypeError('%s() takes %s %d %s (%d given)' % ( + f_name, 'at most' if defaults else 'exactly', num_args, + 'arguments' if num_args > 1 else 'argument', num_total)) + elif num_args == 0 and num_total: + if varkw: + if num_pos: + # XXX: We should use num_pos, but Python also uses num_total: + raise TypeError('%s() takes exactly 0 arguments ' + '(%d given)' % (f_name, num_total)) + else: + raise TypeError('%s() takes no arguments (%d given)' % + (f_name, num_total)) + for arg in args: + if isinstance(arg, str) and arg in named: + if is_assigned(arg): + raise TypeError("%s() got multiple values for keyword " + "argument '%s'" % (f_name, arg)) + else: + assign(arg, named.pop(arg)) + if defaults: # fill in any missing values with the defaults + for arg, value in zip(args[-num_defaults:], defaults): + if not is_assigned(arg): + assign(arg, value) + if varkw: + assign(varkw, named) + elif named: + unexpected = next(iter(named)) + try: + unicode + except NameError: + pass + else: + if isinstance(unexpected, unicode): + unexpected = unexpected.encode(sys.getdefaultencoding(), 'replace') + raise TypeError("%s() got an unexpected keyword argument '%s'" % + (f_name, unexpected)) + unassigned = num_args - len([arg for arg in args if is_assigned(arg)]) + if unassigned: + num_required = num_args - num_defaults + raise TypeError('%s() takes %s %d %s (%d given)' % ( + f_name, 'at least' if defaults else 'exactly', num_required, + 'arguments' if num_required > 1 else 'argument', num_total)) + return arg2value diff --git a/src/mog_commons/types.py b/src/mog_commons/types.py new file mode 100644 index 0000000..50840eb --- /dev/null +++ b/src/mog_commons/types.py @@ -0,0 +1,158 @@ +from __future__ import division, print_function, absolute_import, unicode_literals + +import sys +import six +from abc import ABCMeta, abstractmethod + +if sys.version_info < (2, 7): + from mog_commons.backported.inspect2 import getcallargs +else: + from inspect import getcallargs + +__all__ = [ + 'String', + 'Unicode', + 'ListOf', + 'TupleOf', + 'SetOf', + 'DictOf', + 'VarArg', + 'KwArg', + 'types', +] + +# +# Type definitions +# +String = six.string_types + (bytes,) + +Unicode = unicode if six.PY2 else str + + +@six.add_metaclass(ABCMeta) +class ComposableType(object): + """Label for composable types""" + + @abstractmethod + def name(self): + """abstract method""" + + @abstractmethod + def check(self, obj): + """abstract method""" + + +@six.add_metaclass(ABCMeta) +class IterableOf(ComposableType): + def __init__(self, iterable_type, elem_type): + self.iterable_type = iterable_type + self.elem_type = elem_type + + def name(self): + return '%s(%s)' % (self.iterable_type.__name__, _get_name(self.elem_type)) + + def check(self, obj): + return isinstance(obj, self.iterable_type) and all(_check_type(elem, self.elem_type) for elem in obj) + + +class ListOf(IterableOf): + """Label for list element type assertion""" + + def __init__(self, elem_type): + IterableOf.__init__(self, list, elem_type) + + +class TupleOf(IterableOf): + """Label for tuple element type assertion""" + + def __init__(self, elem_type): + IterableOf.__init__(self, tuple, elem_type) + + +class SetOf(IterableOf): + """Label for set element type assertion""" + + def __init__(self, elem_type): + IterableOf.__init__(self, set, elem_type) + + +class DictOf(ComposableType): + """Label for dict element type assertion""" + + def __init__(self, key_type, value_type): + self.key_type = key_type + self.value_type = value_type + + def name(self): + return 'dict(%s->%s)' % (_get_name(self.key_type), _get_name(self.value_type)) + + def check(self, obj): + return isinstance(obj, dict) and all( + _check_type(k, self.key_type) and _check_type(v, self.value_type) for k, v in obj.items()) + + +def VarArg(cls): + """Shorthand description for var arg""" + return TupleOf(cls) + + +def KwArg(cls): + """Shorthand description for keyword arg""" + return DictOf(String, cls) + + +# +# Helper functions +# +def _get_name(cls): + if isinstance(cls, ComposableType): + return cls.name() + if isinstance(cls, tuple): + return '(%s)' % '|'.join(_get_name(t) for t in cls) + else: + return cls.__name__ + + +def _check_type(obj, cls): + return cls.check(obj) if isinstance(cls, ComposableType) else isinstance(obj, cls) + + +# +# Decorators +# +def types(*return_type, **arg_types): + """ + Assert types of the function arguments and return value. + :param return_type: expected type of the return value + :param arg_types: expected types of the arguments + + :example: + @types(float, x=int, y=float, z=ListOf(int)) + def f(x, y, z): + return x * y + sum(z) + """ + assert len(return_type) <= 1, 'You can specify at most one return type.' + + arg_msg = '%s must be %s, not %s.' + return_msg = 'must return %s, not %s.' + + def f(func): + import functools + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # convert args to call args + call_args = getcallargs(func, *args, **kwargs) + + for arg_name, expect in arg_types.items(): + assert arg_name in call_args, 'Not found argument: %s' % arg_name + actual = call_args[arg_name] + assert _check_type(actual, expect), arg_msg % (arg_name, _get_name(expect), type(actual).__name__) + + ret = func(*args, **kwargs) + if return_type: + assert _check_type(ret, return_type[0]), return_msg % (_get_name(return_type[0]), type(ret).__name__) + return ret + return wrapper + + return f diff --git a/tests/mog_commons/test_types.py b/tests/mog_commons/test_types.py new file mode 100644 index 0000000..67140f3 --- /dev/null +++ b/tests/mog_commons/test_types.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +from __future__ import division, print_function, absolute_import, unicode_literals + +import six +from mog_commons import unittest +from mog_commons.types import * +from mog_commons.types import _get_name + + +class TestTypes(unittest.TestCase): + @staticmethod + @types(int, x=int, y=int) + def bin_func(x, y): + return x + y + + @staticmethod + @types(p1=int, p2=ListOf(int), p3=int, p4=String, p5=String, k=VarArg(ListOf(DictOf(String, SetOf(int)))), + kw=KwArg(float)) + def complex_func(p1, p2, p3, p4='xxx', p5='yyy', *k, **kw): + return 1 + + @staticmethod + @types(a=str) + def err_func1(b): + return b + + @staticmethod + def err_func2(): + @types(bool, int) + def f(): + pass + return 1 + + @staticmethod + @types(bool) + def predicate(): + return 1 + + def test_types(self): + str_type = '(basestring|str)' if six.PY2 else '(str|bytes)' + + self.assertEqual(self.bin_func(10, 20), 30) + self.assertRaisesMessage(AssertionError, 'x must be int, not dict.', self.bin_func, {}, 20) + self.assertRaisesMessage(AssertionError, 'y must be int, not list.', self.bin_func, 10, []) + + self.assertEqual(self.complex_func(123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}]), 1) + self.assertRaisesMessage(AssertionError, 'kw must be dict(%s->float), not dict.' % str_type, + self.complex_func, 123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}], x='12.3') + + self.assertRaisesMessage(AssertionError, 'must return bool, not int.', self.predicate) + + def test_types_error(self): + self.assertRaisesMessage(AssertionError, 'Not found argument: a', self.err_func1, 123) + self.assertRaisesMessage(AssertionError, 'You can specify at most one return type.', self.err_func2) + + def test_get_name(self): + str_type = '(basestring|str)' if six.PY2 else '(str|bytes)' + unicode_type = 'unicode' if six.PY2 else 'str' + + self.assertEqual(_get_name(int), 'int') + self.assertEqual(_get_name(String), str_type) + self.assertEqual(_get_name(Unicode), unicode_type) + self.assertEqual(_get_name(TupleOf(int)), 'tuple(int)') + self.assertEqual(_get_name(ListOf(int)), 'list(int)') + self.assertEqual(_get_name(SetOf(int)), 'set(int)') + self.assertEqual(_get_name(DictOf(int, float)), 'dict(int->float)') + self.assertEqual(_get_name(VarArg(str)), 'tuple(str)') + self.assertEqual(_get_name(KwArg(str)), 'dict(%s->str)' % str_type) + self.assertEqual(_get_name(KwArg(ListOf(DictOf(float, SetOf((str, TupleOf(Unicode))))))), + 'dict(%s->list(dict(float->set((str|tuple(%s))))))' % (str_type, unicode_type))