Skip to content

Commit

Permalink
Merge 2ce1c9e into f1146be
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusvniekerk committed Feb 27, 2018
2 parents f1146be + 2ce1c9e commit d9397eb
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 31 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ python:
install:
- pip install coverage
- pip install --upgrade pytest pytest-benchmark
- pip install pytypes

script:
- |
Expand Down
15 changes: 13 additions & 2 deletions multipledispatch/conflict.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
from .utils import _toposort, groupby
from pytypes import is_subtype, is_Union, get_Union_params
from itertools import zip_longest


class AmbiguityWarning(Warning):
pass


def safe_subtype(a, b):
"""Union safe subclass"""
if is_Union(a):
return any(is_subtype(tp, b) for tp in get_Union_params(a))
else:
return is_subtype(a, b)


def supercedes(a, b):
""" A is consistent and strictly more specific than B """
return len(a) == len(b) and all(map(issubclass, a, b))
return len(a) == len(b) and all(map(safe_subtype, a, b))


def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
return (len(a) == len(b) and
all(issubclass(aa, bb) or issubclass(bb, aa)
all(safe_subtype(aa, bb) or safe_subtype(bb, aa)
for aa, bb in zip(a, b)))


Expand Down
80 changes: 59 additions & 21 deletions multipledispatch/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from warnings import warn
import inspect

import copy

from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
import itertools as itl

import itertools as itl
import pytypes
import typing


class MDNotImplementedError(NotImplementedError):
Expand Down Expand Up @@ -46,6 +51,7 @@ def restart_ordering(on_ambiguity=ambiguity_warn):
DeprecationWarning,
)


class Dispatcher(object):
""" Dispatch methods based on type signature
Expand Down Expand Up @@ -140,13 +146,17 @@ def add(self, signature, func):
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D.add((typing.Optional[str], ), lambda x: x)
>>> D(1, 2)
3
>>> D(1, 2.0)
>>> D('1', 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
NotImplementedError: Could not find signature for add: <str, float>
>>> D('s')
's'
>>> D(None)
When ``add`` detects a warning it calls the ``on_ambiguity`` callback
with a dispatcher/itself, and a set of ambiguous type signature pairs
Expand All @@ -157,24 +167,35 @@ def add(self, signature, func):
annotations = self.get_func_annotations(func)
if annotations:
signature = annotations
# Make function annotation dict

def process_union(tp):
if isinstance(tp, tuple):
t = typing.Union[tuple(process_union(e) for e in tp)]
return t
else:
return tp

# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func)
return
signatures = expand_tuples(signature)
for signature in signatures:
signature = tuple(process_union(tp) for tp in signature)

for typ in signature:
if not isinstance(typ, type):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError("Tried to dispatch on non-type: %s\n"
"In signature: <%s>\n"
"In function: %s" %
(typ, str_sig, self.name))
# make a copy of the function (if needed) and apply the function annotations

self.funcs[signature] = func
self._cache.clear()
# TODO: MAKE THIS type or typevar
for typ in signature:
try:
typing.Union[typ]
except TypeError:
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError("Tried to dispatch on non-type: %s\n"
"In signature: <%s>\n"
"In function: %s" %
(typ, str_sig, self.name))

self.funcs[signature] = func
self._cache.clear()

try:
del self._ordering
Expand All @@ -196,7 +217,11 @@ def reorder(self, on_ambiguity=ambiguity_warn):
return od

def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
try:
types = tuple([pytypes.deep_type(arg, 1, max_sample=10) for arg in args])
except:
# some things dont deeptype welkl
types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
except KeyError:
Expand Down Expand Up @@ -259,12 +284,25 @@ def dispatch(self, *types):
except StopIteration:
return None

@staticmethod
def get_type_vars(x):
if isinstance(x, typing.TypeVar):
yield x
if isinstance(x, typing.GenericMeta):
yield from x.__parameters__

def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if len(signature) == n and all(map(issubclass, types, signature)):
if len(signature) == n:
result = self.funcs[signature]
yield result
try:
typsig = typing.Tuple[signature]
typvars = list(self.get_type_vars(typsig))
if pytypes.is_subtype(typing.Tuple[types], typsig, bound_typevars={t.__name__: t for t in typvars}):
yield result
except pytypes.InputTypeError:
continue

def resolve(self, types):
""" Deterimine appropriate implementation for this type signature
Expand Down
6 changes: 3 additions & 3 deletions multipledispatch/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ def f(x):

def test_union_types():
@dispatch((A, C))
def f(x):
def hh(x):
return 1

assert f(A()) == 1
assert f(C()) == 1
assert hh(A()) == 1
assert hh(C()) == 1


def test_namespaces():
Expand Down
29 changes: 29 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from multipledispatch import dispatch
from multipledispatch.dispatcher import Dispatcher
from multipledispatch.utils import raises
import typing


def test_function_annotation_register():
Expand All @@ -30,8 +32,23 @@ def inc(x: int):
def inc(x: float):
return x - 1

@dispatch()
def inc(x: typing.Optional[str]):
return x

@dispatch()
def inc(x: typing.List[int]):
return x[0] * 4

@dispatch()
def inc(x: typing.List[str]):
return x[0] + 'b'

assert inc(1) == 2
assert inc(1.0) == 0.0
assert inc('a') == 'a'
assert inc([8]) == 32
assert inc(['a']) == 'ab'


def test_function_annotation_dispatch_custom_namespace():
Expand Down Expand Up @@ -68,6 +85,18 @@ def f(self, x: float):
assert foo.f(1.0) == 0.0


def test_diagonal_dispatch():
T = typing.TypeVar('T')
U = typing.TypeVar('U')

@dispatch()
def diag(x: T, y: T):
return 'same'

assert diag(1, 6) == 'same'
assert raises(NotImplementedError, lambda: diag(1, '1'))


def test_overlaps():
@dispatch(int)
def inc(x: int):
Expand Down
25 changes: 20 additions & 5 deletions multipledispatch/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@

import pytypes
import typing


def raises(err, lamda):
try:
lamda()
Expand All @@ -14,15 +19,25 @@ def expand_tuples(L):
>>> expand_tuples([1, 2])
[(1, 2)]
>>> expand_tuples([1, typing.Optional[str]]) #doctest: +ELLIPSIS
[(1, <... 'str'>), (1, <... 'NoneType'>)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
if pytypes.is_Union(L[0]):
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in pytypes.get_Union_params(L[0])]
elif not pytypes.is_of_type(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]


# Taken from theano/theano/gof/sched.py
Expand Down

0 comments on commit d9397eb

Please sign in to comment.