Skip to content

Commit

Permalink
Lazy dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
Mystic-Mirage committed Jul 6, 2020
1 parent 3acb624 commit 2dd329e
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 6 deletions.
13 changes: 13 additions & 0 deletions docs/source/resolution.rst
Expand Up @@ -161,6 +161,19 @@ For example, here's a function that takes a ``float`` followed by any number
>>> f(2.0, '4', 6, 8)
20.0
Lazy Dispatch
-------------

You may need to refer to your own class while defining it. Just use its name as
a string and ``multipledispatch`` will resolve a name to a class during runtime

.. code::
class MyInteger(int):
@dispatch('MyInteger')
def add(self, x):
return self + x
Ambiguities
-----------

Expand Down
5 changes: 5 additions & 0 deletions multipledispatch/conflict.py
@@ -1,3 +1,5 @@
import itertools

from .utils import _toposort, groupby
from .variadic import isvariadic

Expand All @@ -8,6 +10,9 @@ class AmbiguityWarning(Warning):

def supercedes(a, b):
""" A is consistent and strictly more specific than B """
if any(isinstance(x, str) for x in itertools.chain(a, b)):
# skip due to lazy types
return False
if len(a) < len(b):
# only case is if a is empty and b is variadic
return not a and len(b) == 1 and isvariadic(b[-1])
Expand Down
48 changes: 42 additions & 6 deletions multipledispatch/dispatcher.py
Expand Up @@ -117,14 +117,16 @@ class Dispatcher(object):
>>> f(3.0)
2.0
"""
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc', \
'_lazy'

def __init__(self, name, doc=None):
self.name = self.__name__ = name
self.funcs = {}
self.doc = doc

self._cache = {}
self._lazy = False

def register(self, *types, **kwargs):
""" register dispatcher with new implementation
Expand Down Expand Up @@ -214,9 +216,8 @@ def add(self, signature, func):
return

new_signature = []

for index, typ in enumerate(signature, start=1):
if not isinstance(typ, (type, list)):
if not isinstance(typ, (type, list, str)):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError("Tried to dispatch on non-type: %s\n"
Expand All @@ -237,9 +238,12 @@ def add(self, signature, func):
'To use a variadic union type place the desired types '
'inside of a tuple, e.g., [(int, str)]'
)
new_signature.append(Variadic[typ[0]])
else:
new_signature.append(typ)
typ = Variadic[typ[0]]

if isinstance(typ, str):
self._lazy = True

new_signature.append(typ)

self.funcs[tuple(new_signature)] = func
self._cache.clear()
Expand All @@ -264,6 +268,9 @@ def reorder(self, on_ambiguity=ambiguity_warn):
return od

def __call__(self, *args, **kwargs):
if self._lazy:
self._unlazy()

types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
Expand Down Expand Up @@ -359,6 +366,7 @@ def __setstate__(self, d):
self.funcs = d['funcs']
self._ordering = ordering(self.funcs)
self._cache = dict()
self._lazy = any(isinstance(t, str) for t in itl.chain(*d['funcs']))

@property
def __doc__(self):
Expand Down Expand Up @@ -400,6 +408,31 @@ def source(self, *args, **kwargs):
""" Print source code for the function corresponding to inputs """
print(self._source(*args))

def _unlazy(self):
funcs = {}
for signature, func in self.funcs.items():
new_signature = []
for typ in signature:
if isinstance(typ, str):
for frame_info in inspect.stack():
frame = frame_info[0]
scope = dict(frame.f_globals)
scope.update(frame.f_locals)
if typ in scope:
typ = scope[typ]
break
else:
raise NameError("name '%s' is not defined" % typ)
new_signature.append(typ)

new_signature = tuple(new_signature)
funcs[new_signature] = func

self.funcs = funcs
self.reorder()

self._lazy = False


def source(func):
s = 'File: %s\n\n' % inspect.getsourcefile(func)
Expand Down Expand Up @@ -427,6 +460,9 @@ def __get__(self, instance, owner):
return self

def __call__(self, *args, **kwargs):
if self._lazy:
self._unlazy()

types = tuple([type(arg) for arg in args])
func = self.dispatch(*types)
if not func:
Expand Down
72 changes: 72 additions & 0 deletions multipledispatch/tests/test_dispatcher.py
@@ -1,6 +1,7 @@

import warnings

from multipledispatch import dispatch
from multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
MethodDispatcher)
from multipledispatch.conflict import ambiguities
Expand Down Expand Up @@ -421,3 +422,74 @@ def _3(*objects):
assert f('a', ['a']) == 2
assert f(1) == 3
assert f() == 3


def test_lazy_methods():
class A(object):
@dispatch(int)
def get(self, _):
return 'int'

@dispatch('A')
def get(self, _):
"""Self reference"""
return 'A'

@dispatch('B')
def get(self, _):
"""Yet undeclared type"""
return 'B'

class B(object):
pass

class C(A):
@dispatch('D')
def get(self, _):
"""Non-existent type"""
return 'D'

a = A()
b = B()
c = C()

assert a.get(1) == 'int'
assert a.get(a) == 'A'
assert a.get(b) == 'B'
assert raises(NameError, lambda: c.get(1))


def test_lazy_functions():
f = Dispatcher('f')
f.add((int,), inc)
f.add(('Int',), dec)

assert raises(NameError, lambda: f(1))

class Int(int):
pass

assert f(1) == 2
assert f(Int(1)) == 0


def test_lazy_serializable():
f = Dispatcher('f')
f.add((int,), inc)
f.add(('Int',), dec)

import pickle
assert isinstance(pickle.dumps(f), (str, bytes))

g = pickle.loads(pickle.dumps(f))

assert f.funcs == g.funcs
assert f._lazy == g._lazy

assert raises(NameError, lambda: f(1))

class Int(int):
pass

assert g(1) == 2
assert g(Int(1)) == 0
21 changes: 21 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Expand Up @@ -4,6 +4,7 @@

from multipledispatch import dispatch
from multipledispatch.dispatcher import Dispatcher
from multipledispatch.utils import raises


def test_function_annotation_register():
Expand Down Expand Up @@ -92,3 +93,23 @@ def inc(x: int):

assert inc(1) == 2
assert inc(1.0) == 0.0


def test_lazy_annotations():
f = Dispatcher('f')

@f.register()
def inc(x: int):
return x + 1

@f.register()
def dec(x: 'Int'):
return x - 1

assert raises(NameError, lambda: f(1))

class Int(int):
pass

assert f(1) == 2
assert f(Int(1)) == 0

0 comments on commit 2dd329e

Please sign in to comment.