Skip to content

Commit

Permalink
support PEP484 function argument annotation (#65)
Browse files Browse the repository at this point in the history
* support PEP484 function argument annotation

cc:
Asana:
by: 顏孜羲 <joseph.yen@gmail.com>

* fix python2 compatibility

* make Dispatcher.register() supports function annotation

* add python 3 only test

* Add tests for methods and overlapping signatures

* fix compatibility of function annotation to methods

cc:
Asana:
by: 顏孜羲 <joseph.yen@gmail.com>

* add test_overlaps_conflict_annotation
  • Loading branch information
d2207197 authored and mrocklin committed Nov 13, 2017
1 parent aabb2da commit 4dd36b1
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 26 deletions.
7 changes: 6 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ install:
- pip install coverage

script:
- nosetests --with-doctest
- |
if [[ $(bc <<< "$TRAVIS_PYTHON_VERSION >= 3.3") -eq 1 ]]; then
nosetests --with-doctest
else
nosetests --with-doctest -I '.*_3only.py$'
fi
after_success:
- |
Expand Down
8 changes: 4 additions & 4 deletions multipledispatch/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from contextlib import contextmanager
from warnings import warn
import inspect
from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn

from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn

global_namespace = dict()

Expand Down Expand Up @@ -54,11 +52,13 @@ def dispatch(*types, **kwargs):
on_ambiguity = kwargs.get('on_ambiguity', ambiguity_warn)

types = tuple(types)

def _(func):
name = func.__name__

if ismethod(func):
dispatcher = inspect.currentframe().f_back.f_locals.get(name,
dispatcher = inspect.currentframe().f_back.f_locals.get(
name,
MethodDispatcher(name))
else:
if name not in namespace:
Expand Down
66 changes: 52 additions & 14 deletions multipledispatch/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import inspect
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
import itertools as itl


class MDNotImplementedError(NotImplementedError):
""" A NotImplementedError for multiple dispatch """
Expand Down Expand Up @@ -39,7 +41,6 @@ def restart_ordering(on_ambiguity=ambiguity_warn):
dispatcher.reorder(on_ambiguity=on_ambiguity)



class Dispatcher(object):
""" Dispatch methods based on type signature
Expand Down Expand Up @@ -102,6 +103,32 @@ def _(func):
return func
return _

@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return sig.parameters.values()

@classmethod
def get_func_annotations(cls, func):
""" get annotations of function positional paremeters
"""
params = cls.get_func_params(func)
if params:
Parameter = inspect.Parameter

params = (param for param in params
if param.kind in
(Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD))

annotations = tuple(
param.annotation
for param in params)

if all(ann is not Parameter.empty for ann in annotations):
return annotations

def add(self, signature, func, on_ambiguity=ambiguity_warn):
""" Add new types/method pair to dispatcher
Expand All @@ -120,6 +147,12 @@ def add(self, signature, func, on_ambiguity=ambiguity_warn):
with a dispatcher/itself, and a set of ambiguous type signature pairs
as inputs. See ``ambiguity_warn`` for an example.
"""
# Handle annotations
if not signature:
annotations = self.get_func_annotations(func)
if annotations:
signature = annotations

# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
Expand All @@ -129,7 +162,7 @@ def add(self, signature, func, on_ambiguity=ambiguity_warn):
for typ in signature:
if not isinstance(typ, type):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
else str(c) for c in signature)
raise TypeError("Tried to dispatch on non-type: %s\n"
"In signature: <%s>\n"
"In function: %s" %
Expand All @@ -148,7 +181,6 @@ def reorder(self, on_ambiguity=ambiguity_warn):
else:
_unresolved_dispatchers.add(self)


def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
try:
Expand All @@ -157,25 +189,23 @@ def __call__(self, *args, **kwargs):
func = self.dispatch(*types)
if not func:
raise NotImplementedError(
'Could not find signature for %s: <%s>' %
(self.name, str_signature(types)))
'Could not find signature for %s: <%s>' %
(self.name, str_signature(types)))
self._cache[types] = func
try:
return func(*args, **kwargs)

except MDNotImplementedError:
funcs = self.dispatch_iter(*types)
next(funcs) # burn first
next(funcs) # burn first
for func in funcs:
try:
return func(*args, **kwargs)
except MDNotImplementedError:
pass
raise NotImplementedError("Matching functions for "
"%s: <%s> found, but none completed successfully"
% (self.name, str_signature(types)))


"%s: <%s> found, but none completed successfully"
% (self.name, str_signature(types)))

def __str__(self):
return "<dispatched %s>" % self.name
Expand Down Expand Up @@ -239,7 +269,6 @@ def __setstate__(self, d):
self.ordering = ordering(self.funcs)
self._cache = dict()


@property
def __doc__(self):
docs = ["Multiply dispatched method: %s" % self.name]
Expand Down Expand Up @@ -293,6 +322,13 @@ class MethodDispatcher(Dispatcher):
See Also:
Dispatcher
"""

@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return itl.islice(sig.parameters.values(), 1, None)

def __get__(self, instance, owner):
self.obj = instance
self.cls = owner
Expand All @@ -315,13 +351,15 @@ def str_signature(sig):
"""
return ', '.join(cls.__name__ for cls in sig)


def warning_text(name, amb):
""" The text for ambiguity warnings """
text = "\nAmbiguities exist in dispatched function %s\n\n"%(name)
text = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
text += "The following signatures may result in ambiguous behavior:\n"
for pair in amb:
text += "\t" + ', '.join('['+str_signature(s)+']' for s in pair) + "\n"
text += "\t" + \
', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
text += "\n\nConsider making the following additions:\n\n"
text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
+ ')\ndef %s(...)'%name for s in amb])
+ ')\ndef %s(...)' % name for s in amb])
return text
15 changes: 8 additions & 7 deletions multipledispatch/tests/test_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import warnings

from multipledispatch.dispatcher import (Dispatcher, MethodDispatcher,
halt_ordering, restart_ordering, MDNotImplementedError)
from multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
MethodDispatcher, halt_ordering,
restart_ordering)
from multipledispatch.utils import raises


Expand Down Expand Up @@ -75,13 +76,13 @@ def _init_obj(self, datum):
def test_on_ambiguity():
f = Dispatcher('f')

identity = lambda x: x
def identity(x): return x

ambiguities = [False]

def on_ambiguity(dispatcher, amb):
ambiguities[0] = True


f.add((object, object), identity, on_ambiguity=on_ambiguity)
assert not ambiguities[0]
f.add((object, float), identity, on_ambiguity=on_ambiguity)
Expand Down Expand Up @@ -134,7 +135,7 @@ def three(x, y):
assert one.__doc__.strip() in f.__doc__
assert two.__doc__.strip() in f.__doc__
assert f.__doc__.find(one.__doc__.strip()) < \
f.__doc__.find(two.__doc__.strip())
f.__doc__.find(two.__doc__.strip())
assert 'object, object' in f.__doc__
assert master_doc in f.__doc__

Expand Down Expand Up @@ -190,6 +191,7 @@ def test_source_raises_on_missing_function():

def test_halt_method_resolution():
g = [0]

def on_ambiguity(a, b):
g[0] += 1

Expand All @@ -209,7 +211,6 @@ def func(*args):

assert g == [1]

print(list(f.ordering))
assert set(f.ordering) == set([(int, object), (object, int)])


Expand Down Expand Up @@ -266,7 +267,7 @@ def _(x):
else:
raise MDNotImplementedError()

assert f('hello') == 'default' # default behavior
assert f('hello') == 'default' # default behavior
assert f(2) == 'even' # specialized behavior
assert f(3) == 'default' # fall bac to default behavior
assert raises(NotImplementedError, lambda: f(1, 2))
Expand Down
94 changes: 94 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# import sys

# from nose import SkipTest

from multipledispatch import dispatch
from multipledispatch.dispatcher import Dispatcher


def test_function_annotation_register():
f = Dispatcher('f')

@f.register()
def inc(x: int):
return x + 1

@f.register()
def inc(x: float):
return x - 1

assert f(1) == 2
assert f(1.0) == 0.0


def test_function_annotation_dispatch():
@dispatch()
def inc(x: int):
return x + 1

@dispatch()
def inc(x: float):
return x - 1

assert inc(1) == 2
assert inc(1.0) == 0.0


def test_function_annotation_dispatch_custom_namespace():
namespace = {}

@dispatch(namespace=namespace)
def inc(x: int):
return x + 2

@dispatch(namespace=namespace)
def inc(x: float):
return x - 2

assert inc(1) == 3
assert inc(1.0) == -1.0

assert namespace['inc'] == inc
assert set(inc.funcs.keys()) == set([(int,), (float,)])


def test_method_annotations():
class Foo():
@dispatch()
def f(self, x: int):
return x + 1

@dispatch()
def f(self, x: float):
return x - 1

foo = Foo()

assert foo.f(1) == 2
assert foo.f(1.0) == 0.0


def test_overlaps():
@dispatch(int)
def inc(x: int):
return x + 1

@dispatch(float)
def inc(x: float):
return x - 1

assert inc(1) == 2
assert inc(1.0) == 0.0


def test_overlaps_conflict_annotation():
@dispatch(int)
def inc(x: str):
return x + 1

@dispatch(float)
def inc(x: int):
return x - 1

assert inc(1) == 2
assert inc(1.0) == 0.0

0 comments on commit 4dd36b1

Please sign in to comment.