Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ install:
- pip install coveralls

script:
- coverage run --source=src setup.py test
- coverage run --source=src --omit='src/mog_commons/backported/*' setup.py test

after_success:
- coveralls
Expand Down
2 changes: 1 addition & 1 deletion src/mog_commons/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.12'
__version__ = '0.1.13'
Empty file.
102 changes: 102 additions & 0 deletions src/mog_commons/backported/inspect2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# backported from https://github.com/python/cpython/blob/2.7/Lib/inspect.py

import sys
from inspect import ismethod, getargspec


def getcallargs(func, *positional, **named):
"""Get the mapping of arguments to values.

A dict is returned, with keys the function argument names (including the
names of the * and ** arguments, if any), and values the respective bound
values from 'positional' and 'named'."""
args, varargs, varkw, defaults = getargspec(func)
f_name = func.__name__
arg2value = {}

# The following closures are basically because of tuple parameter unpacking.
assigned_tuple_params = []

def assign(arg, value):
if isinstance(arg, str):
arg2value[arg] = value
else:
assigned_tuple_params.append(arg)
value = iter(value)
for i, subarg in enumerate(arg):
try:
subvalue = next(value)
except StopIteration:
raise ValueError('need more than %d %s to unpack' %
(i, 'values' if i > 1 else 'value'))
assign(subarg, subvalue)
try:
next(value)
except StopIteration:
pass
else:
raise ValueError('too many values to unpack')

def is_assigned(arg):
if isinstance(arg, str):
return arg in arg2value
return arg in assigned_tuple_params

if ismethod(func) and func.im_self is not None:
# implicit 'self' (or 'cls' for classmethods) argument
positional = (func.im_self,) + positional
num_pos = len(positional)
num_total = num_pos + len(named)
num_args = len(args)
num_defaults = len(defaults) if defaults else 0
for arg, value in zip(args, positional):
assign(arg, value)
if varargs:
if num_pos > num_args:
assign(varargs, positional[-(num_pos - num_args):])
else:
assign(varargs, ())
elif 0 < num_args < num_pos:
raise TypeError('%s() takes %s %d %s (%d given)' % (
f_name, 'at most' if defaults else 'exactly', num_args,
'arguments' if num_args > 1 else 'argument', num_total))
elif num_args == 0 and num_total:
if varkw:
if num_pos:
# XXX: We should use num_pos, but Python also uses num_total:
raise TypeError('%s() takes exactly 0 arguments '
'(%d given)' % (f_name, num_total))
else:
raise TypeError('%s() takes no arguments (%d given)' %
(f_name, num_total))
for arg in args:
if isinstance(arg, str) and arg in named:
if is_assigned(arg):
raise TypeError("%s() got multiple values for keyword "
"argument '%s'" % (f_name, arg))
else:
assign(arg, named.pop(arg))
if defaults: # fill in any missing values with the defaults
for arg, value in zip(args[-num_defaults:], defaults):
if not is_assigned(arg):
assign(arg, value)
if varkw:
assign(varkw, named)
elif named:
unexpected = next(iter(named))
try:
unicode
except NameError:
pass
else:
if isinstance(unexpected, unicode):
unexpected = unexpected.encode(sys.getdefaultencoding(), 'replace')
raise TypeError("%s() got an unexpected keyword argument '%s'" %
(f_name, unexpected))
unassigned = num_args - len([arg for arg in args if is_assigned(arg)])
if unassigned:
num_required = num_args - num_defaults
raise TypeError('%s() takes %s %d %s (%d given)' % (
f_name, 'at least' if defaults else 'exactly', num_required,
'arguments' if num_required > 1 else 'argument', num_total))
return arg2value
158 changes: 158 additions & 0 deletions src/mog_commons/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import division, print_function, absolute_import, unicode_literals

import sys
import six
from abc import ABCMeta, abstractmethod

if sys.version_info < (2, 7):
from mog_commons.backported.inspect2 import getcallargs
else:
from inspect import getcallargs

__all__ = [
'String',
'Unicode',
'ListOf',
'TupleOf',
'SetOf',
'DictOf',
'VarArg',
'KwArg',
'types',
]

#
# Type definitions
#
String = six.string_types + (bytes,)

Unicode = unicode if six.PY2 else str


@six.add_metaclass(ABCMeta)
class ComposableType(object):
"""Label for composable types"""

@abstractmethod
def name(self):
"""abstract method"""

@abstractmethod
def check(self, obj):
"""abstract method"""


@six.add_metaclass(ABCMeta)
class IterableOf(ComposableType):
def __init__(self, iterable_type, elem_type):
self.iterable_type = iterable_type
self.elem_type = elem_type

def name(self):
return '%s(%s)' % (self.iterable_type.__name__, _get_name(self.elem_type))

def check(self, obj):
return isinstance(obj, self.iterable_type) and all(_check_type(elem, self.elem_type) for elem in obj)


class ListOf(IterableOf):
"""Label for list element type assertion"""

def __init__(self, elem_type):
IterableOf.__init__(self, list, elem_type)


class TupleOf(IterableOf):
"""Label for tuple element type assertion"""

def __init__(self, elem_type):
IterableOf.__init__(self, tuple, elem_type)


class SetOf(IterableOf):
"""Label for set element type assertion"""

def __init__(self, elem_type):
IterableOf.__init__(self, set, elem_type)


class DictOf(ComposableType):
"""Label for dict element type assertion"""

def __init__(self, key_type, value_type):
self.key_type = key_type
self.value_type = value_type

def name(self):
return 'dict(%s->%s)' % (_get_name(self.key_type), _get_name(self.value_type))

def check(self, obj):
return isinstance(obj, dict) and all(
_check_type(k, self.key_type) and _check_type(v, self.value_type) for k, v in obj.items())


def VarArg(cls):
"""Shorthand description for var arg"""
return TupleOf(cls)


def KwArg(cls):
"""Shorthand description for keyword arg"""
return DictOf(String, cls)


#
# Helper functions
#
def _get_name(cls):
if isinstance(cls, ComposableType):
return cls.name()
if isinstance(cls, tuple):
return '(%s)' % '|'.join(_get_name(t) for t in cls)
else:
return cls.__name__


def _check_type(obj, cls):
return cls.check(obj) if isinstance(cls, ComposableType) else isinstance(obj, cls)


#
# Decorators
#
def types(*return_type, **arg_types):
"""
Assert types of the function arguments and return value.
:param return_type: expected type of the return value
:param arg_types: expected types of the arguments

:example:
@types(float, x=int, y=float, z=ListOf(int))
def f(x, y, z):
return x * y + sum(z)
"""
assert len(return_type) <= 1, 'You can specify at most one return type.'

arg_msg = '%s must be %s, not %s.'
return_msg = 'must return %s, not %s.'

def f(func):
import functools

@functools.wraps(func)
def wrapper(*args, **kwargs):
# convert args to call args
call_args = getcallargs(func, *args, **kwargs)

for arg_name, expect in arg_types.items():
assert arg_name in call_args, 'Not found argument: %s' % arg_name
actual = call_args[arg_name]
assert _check_type(actual, expect), arg_msg % (arg_name, _get_name(expect), type(actual).__name__)

ret = func(*args, **kwargs)
if return_type:
assert _check_type(ret, return_type[0]), return_msg % (_get_name(return_type[0]), type(ret).__name__)
return ret
return wrapper

return f
70 changes: 70 additions & 0 deletions tests/mog_commons/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import, unicode_literals

import six
from mog_commons import unittest
from mog_commons.types import *
from mog_commons.types import _get_name


class TestTypes(unittest.TestCase):
@staticmethod
@types(int, x=int, y=int)
def bin_func(x, y):
return x + y

@staticmethod
@types(p1=int, p2=ListOf(int), p3=int, p4=String, p5=String, k=VarArg(ListOf(DictOf(String, SetOf(int)))),
kw=KwArg(float))
def complex_func(p1, p2, p3, p4='xxx', p5='yyy', *k, **kw):
return 1

@staticmethod
@types(a=str)
def err_func1(b):
return b

@staticmethod
def err_func2():
@types(bool, int)
def f():
pass
return 1

@staticmethod
@types(bool)
def predicate():
return 1

def test_types(self):
str_type = '(basestring|str)' if six.PY2 else '(str|bytes)'

self.assertEqual(self.bin_func(10, 20), 30)
self.assertRaisesMessage(AssertionError, 'x must be int, not dict.', self.bin_func, {}, 20)
self.assertRaisesMessage(AssertionError, 'y must be int, not list.', self.bin_func, 10, [])

self.assertEqual(self.complex_func(123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}]), 1)
self.assertRaisesMessage(AssertionError, 'kw must be dict(%s->float), not dict.' % str_type,
self.complex_func, 123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}], x='12.3')

self.assertRaisesMessage(AssertionError, 'must return bool, not int.', self.predicate)

def test_types_error(self):
self.assertRaisesMessage(AssertionError, 'Not found argument: a', self.err_func1, 123)
self.assertRaisesMessage(AssertionError, 'You can specify at most one return type.', self.err_func2)

def test_get_name(self):
str_type = '(basestring|str)' if six.PY2 else '(str|bytes)'
unicode_type = 'unicode' if six.PY2 else 'str'

self.assertEqual(_get_name(int), 'int')
self.assertEqual(_get_name(String), str_type)
self.assertEqual(_get_name(Unicode), unicode_type)
self.assertEqual(_get_name(TupleOf(int)), 'tuple(int)')
self.assertEqual(_get_name(ListOf(int)), 'list(int)')
self.assertEqual(_get_name(SetOf(int)), 'set(int)')
self.assertEqual(_get_name(DictOf(int, float)), 'dict(int->float)')
self.assertEqual(_get_name(VarArg(str)), 'tuple(str)')
self.assertEqual(_get_name(KwArg(str)), 'dict(%s->str)' % str_type)
self.assertEqual(_get_name(KwArg(ListOf(DictOf(float, SetOf((str, TupleOf(Unicode))))))),
'dict(%s->list(dict(float->set((str|tuple(%s))))))' % (str_type, unicode_type))