Skip to content

Commit

Permalink
Merge 25968b7 into 9e3c87d
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobHayes committed Apr 18, 2021
2 parents 9e3c87d + 25968b7 commit 6b893dc
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 39 deletions.
14 changes: 3 additions & 11 deletions .travis.yml
@@ -1,30 +1,22 @@

language: python
python:
- "pypy"
- "pypy3"
matrix:
include:
- { python: '2.7', env: }
- { arch: arm64, python: '2.7' }
- { python: '3.4', env: }
- { python: '3.5', env: }
- { python: '3.6', env: }
- { python: '3.7', env: }
- { python: '3.8', env: }
- { arch: arm64, python: '3.8' }
- { python: '3.9', env: }
- { arch: arm64, python: '3.9' }
install:
- pip install --upgrade pip
- pip install coverage
- pip install --upgrade pytest pytest-benchmark

script:
- |
if [[ $(bc <<< "$TRAVIS_PYTHON_VERSION >= 3.3") -eq 1 ]]; then
py.test --doctest-modules multipledispatch
else
py.test --doctest-modules --ignore=multipledispatch/tests/test_dispatcher_3only.py multipledispatch
fi
py.test --doctest-modules multipledispatch
after_success:
- |
Expand Down
30 changes: 18 additions & 12 deletions multipledispatch/dispatcher.py
@@ -1,3 +1,4 @@
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, get_type_hints
from warnings import warn
import inspect
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
Expand Down Expand Up @@ -95,7 +96,10 @@ def variadic_signature_matches(types, full_signature):
return all(variadic_signature_matches_iter(types, full_signature))


class Dispatcher(object):
DISPATCHED_RETURN = TypeVar("DISPATCHED_RETURN")


class Dispatcher(Generic[DISPATCHED_RETURN]):
""" Dispatch methods based on type signature
Use ``dispatch`` to add implementations
Expand All @@ -119,14 +123,16 @@ class Dispatcher(object):
"""
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'

def __init__(self, name, doc=None):
def __init__(self, name: str, doc: Optional[None] = None) -> None:
self.name = self.__name__ = name
self.funcs = {}
self.doc = doc

self._cache = {}

def register(self, *types, **kwargs):
def register(
self, *types: type, **kwargs: Any
) -> Callable[[Callable[..., DISPATCHED_RETURN]], Callable[..., DISPATCHED_RETURN]]:
""" register dispatcher with new implementation
>>> f = Dispatcher('f')
Expand Down Expand Up @@ -171,19 +177,19 @@ def get_func_annotations(cls, func):
if params:
Parameter = inspect.Parameter

params = (param for param in params
if param.kind in
(Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD))

hints = get_type_hints(func)
annotations = tuple(
param.annotation
for param in params)
hints.get(param.name, Parameter.empty)
for param in params
if param.kind in (
Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD
)
)

if all(ann is not Parameter.empty for ann in annotations):
return annotations

def add(self, signature, func):
def add(self, signature: Tuple[type, ...], func: Callable[..., DISPATCHED_RETURN]) -> None:
""" Add new types/method pair to dispatcher
>>> D = Dispatcher('add')
Expand Down Expand Up @@ -263,7 +269,7 @@ def reorder(self, on_ambiguity=ambiguity_warn):
on_ambiguity(self, amb)
return od

def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> DISPATCHED_RETURN:
types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
Expand Down
33 changes: 17 additions & 16 deletions multipledispatch/tests/test_core.py
@@ -1,14 +1,15 @@
from multipledispatch import dispatch
from multipledispatch import dispatch as orig_dispatch
from multipledispatch.utils import raises
from functools import partial
import pytest

test_namespace = dict()

orig_dispatch = dispatch
dispatch = partial(dispatch, namespace=test_namespace)
@pytest.fixture(name="dispatch")
def fixture_dispatch():
return partial(orig_dispatch, namespace=dict())


def test_singledispatch():
def test_singledispatch(dispatch):
@dispatch(int)
def f(x):
return x + 1
Expand All @@ -28,7 +29,7 @@ def f(x):
assert raises(NotImplementedError, lambda: f('hello'))


def test_multipledispatch(benchmark):
def test_multipledispatch(dispatch):
@dispatch(int, int)
def f(x, y):
return x + y
Expand All @@ -48,7 +49,7 @@ class D(C): pass
class E(C): pass


def test_inheritance():
def test_inheritance(dispatch):
@dispatch(A)
def f(x):
return 'a'
Expand All @@ -62,7 +63,7 @@ def f(x):
assert f(C()) == 'a'


def test_inheritance_and_multiple_dispatch():
def test_inheritance_and_multiple_dispatch(dispatch):
@dispatch(A, A)
def f(x, y):
return type(x), type(y)
Expand All @@ -78,7 +79,7 @@ def f(x, y):
assert raises(NotImplementedError, lambda: f(B(), B()))


def test_competing_solutions():
def test_competing_solutions(dispatch):
@dispatch(A)
def h(x):
return 1
Expand All @@ -90,7 +91,7 @@ def h(x):
assert h(D()) == 2


def test_competing_multiple():
def test_competing_multiple(dispatch):
@dispatch(A, B)
def h(x, y):
return 1
Expand All @@ -102,7 +103,7 @@ def h(x, y):
assert h(D(), B()) == 2


def test_competing_ambiguous():
def test_competing_ambiguous(dispatch):
@dispatch(A, C)
def f(x, y):
return 2
Expand All @@ -115,7 +116,7 @@ def f(x, y):
# assert raises(Warning, lambda : f(C(), C()))


def test_caching_correct_behavior():
def test_caching_correct_behavior(dispatch):
@dispatch(A)
def f(x):
return 1
Expand All @@ -129,7 +130,7 @@ def f(x):
assert f(C()) == 2


def test_union_types():
def test_union_types(dispatch):
@dispatch((A, C))
def f(x):
return 1
Expand All @@ -156,7 +157,7 @@ def foo(x):

"""
Fails
def test_dispatch_on_dispatch():
def test_dispatch_on_dispatch(dispatch):
@dispatch(A)
@dispatch(C)
def q(x):
Expand All @@ -167,7 +168,7 @@ def q(x):
"""


def test_methods():
def test_methods(dispatch):
class Foo(object):
@dispatch(float)
def f(self, x):
Expand All @@ -188,7 +189,7 @@ def g(self, x):
assert foo.g(1) == 4


def test_methods_multiple_dispatch():
def test_methods_multiple_dispatch(dispatch):
class Foo(object):
@dispatch(A, A)
def f(x, y):
Expand Down
4 changes: 4 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Expand Up @@ -17,6 +17,10 @@ def inc(x: int):
def inc(x: float):
return x - 1

@f.register()
def inc(x: 'float'):
return x - 1

assert f(1) == 2
assert f(1.0) == 0.0

Expand Down

0 comments on commit 6b893dc

Please sign in to comment.