In [None]:
#default_exp dispatch

In [None]:
#export
from __future__ import annotations
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *
from fastcore.meta import delegates

from collections import defaultdict
from plum import Function, Dispatcher

In [None]:
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *

# Type dispatch

> Multiple dispatch, extending [plum](https://github.com/wesselb/plum)

Type dispatch, or [multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based on the input types it receives. This is a prominent feature in some programming languages like [Julia](https://docs.julialang.org/en/v1/manual/methods/).

Type dispatch allows you to have a common API for functions that do similar tasks. This is especially useful in data science, where the same operation (e.g. normalize, categorize) requires an implementation that depends on its input type (e.g. numpy array, pandas dataframe, pytorch tensor).

Fastcore uses and extends the wonderful [plum](https://github.com/wesselb/plum) library's implementation of multiple dispatch for Python. Be sure to view their [informative documentation](https://github.com/wesselb/plum#basic-usage) as well.

In [None]:
#export
def _eval_annotations(f):
    "Copy `f` and evaluate its annotations for plum, turning tuples of types into `Union`s (supports the backported `|`)"
    res = copy_func(f)
    for name, hint in type_hints(res).items():
        # a tuple annotation such as `(int,str)` is treated as a union of its members
        if isinstance(hint, tuple): hint = Union[hint]
        res.__annotations__[name] = hint
    return res

In [None]:
#hide
# `|` between types ends up as a `typing.Union`
def f(x:int|str) -> float: pass
test_eq(_eval_annotations(f).__annotations__, {'x': typing.Union[int, str], 'return': float})
# a tuple of types is converted to a `typing.Union` as well
def f(x:(int,str)) -> float: pass
test_eq(_eval_annotations(f).__annotations__, {'x': typing.Union[int, str], 'return': float})
# a function without annotations keeps an empty `__annotations__`
def f(x): pass
test_eq(_eval_annotations(f).__annotations__, {})

In [None]:
#export
def _pt_repr(o):
    "Concise repr of plum types"
    n = type(o).__name__
    if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]"
    if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'
    if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'
    if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'
    if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'
    if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))
    assert len(o.get_types()) == 1
    return o.get_types()[0].__name__

In [None]:
#hide
from typing import Dict, List, Iterable, Sequence, Tuple
from plum.type import VarArgs, ptype

# plain types and unions (union members are sorted by name)
test_eq(_pt_repr(ptype(int)), 'int')
test_eq(_pt_repr(ptype(Union[int, str])), 'int|str')
# parametric containers: tuple/list/dict are lower-cased, the others keep their name
test_eq(_pt_repr(ptype(Tuple[int, str])), 'tuple[int,str]')
test_eq(_pt_repr(ptype(List[int])), 'list[int]')
test_eq(_pt_repr(ptype(Sequence[int])), 'Sequence[int]')
test_eq(_pt_repr(ptype(Iterable[int])), 'Iterable[int]')
test_eq(_pt_repr(ptype(Dict[str, int])), 'dict[str,int]')
test_eq(_pt_repr(ptype(VarArgs[str])), 'VarArgs[str]')
# arbitrarily nested types are rendered recursively
test_eq(_pt_repr(ptype(Dict[Tuple[Union[int,str],float], List[Tuple[object]]])),
        'dict[tuple[int|str,float],list[tuple[object]]]')

## FastFunction -

In [None]:
#export
class FastFunction(Function):
    # `plum.Function` with a compact repr, annotation evaluation on `dispatch`, and indexing by types
    def __repr__(self):
        "One line per registered method: `name: (arg types) -> return type`"
        lines = []
        for sig, (meth, ret) in self.methods.items():
            arg_types = ','.join(_pt_repr(t) for t in sig.types)
            lines.append(f"{meth.__name__}: ({arg_types}) -> {_pt_repr(ret)}")
        return '\n'.join(lines)

    @delegates(Function.dispatch)
    def dispatch(self, f=None, **kwargs):
        "Register `f` after passing it through `_eval_annotations`"
        evaluated = _eval_annotations(f)
        return super().dispatch(evaluated, **kwargs)

    def __getitem__(self, ts):
        "Return the most-specific matching method with fewest parameters"
        types = L(ts)
        # the shortest registered signature determines how many types are needed
        fewest = min(len(sig) for sig in self.methods.keys())
        # pad any missing trailing positions with `object`
        for _ in range(fewest - len(types)): types.append(object)
        return self.invoke(*types)

`FastFunction` extends `plum.Function` with the following functionality.

`FastFunction` has a concise `repr`:

In [None]:
def f(x: int) -> float: pass
# wrap `f` in a `FastFunction` and register `f` itself as its first method
f = FastFunction(f).dispatch(f)
# the bare expression displays the `FastFunction` repr
f

f: (int) -> float

`FastFunction` supports fastcore's backport of the `|` operator on types:

In [None]:
def f1(x):          return 'obj'
def f2(x: int|str): return 'int|str'
# register the untyped fallback first, then the `int|str` method
f = FastFunction(f1).dispatch(f1).dispatch(f2)

test_eq(f(0),   'int|str')
test_eq(f(''),  'int|str')
# a float matches neither `int` nor `str`, so the untyped method handles it
test_eq(f(0.0), 'obj')

Indexing a `FastFunction` works like [`plum.Function.invoke`](https://github.com/wesselb/plum#directly-invoke-a-method) but returns the most-specific matching method with the fewest parameters:

In [None]:
def f1(a: int,   b,      c):    return 'int, 3 args'
def f2(a: int,   b,      c, d): return 'int, 4 args'
def f3(a: float, b,      c):    return 'float, 3 args'
def f4(a: float, b: str, c):    return 'float, str, 3 args'
f = FastFunction(f1).dispatch(f1).dispatch(f2).dispatch(f3).dispatch(f4)

# the types given in `[...]` are padded with `object` up to the shortest registered signature (3 here)
test_eq(f[int](0,0,0),        'int, 3 args')
test_eq(f[float](0,0,0),      'float, 3 args')
# NOTE(review): this assertion repeats the previous one; possibly a different index was intended
test_eq(f[float](0,0,0),      'float, 3 args')
test_eq(f[float, str](0,0,0), 'float, str, 3 args')

## FastDispatcher -

In [None]:
#export
class FastDispatcher(Dispatcher):
    # `plum.Dispatcher` that builds `FastFunction`s, evaluates annotations before registration,
    # and can attach dispatched methods to an already-created class via `to`
    def _get_function(self, method, owner):
        "Adapted from `Dispatcher._get_function` to use `FastFunction`"
        name = method.__name__
        # methods are namespaced per owning class; free functions share one namespace
        if owner:
            if owner not in self._classes: self._classes[owner] = {}
            namespace = self._classes[owner]
        else: namespace = self._functions
        if name not in namespace: namespace[name] = FastFunction(method, owner=owner)
        return namespace[name]

    @delegates(Dispatcher.__call__, but='method')
    def __call__(self, f, **kwargs):
        "Register `f` (after `_eval_annotations`) and return what `Dispatcher.__call__` returns"
        return super().__call__(_eval_annotations(f), **kwargs)

    def _to(self, cls, nm, f, **kwargs):
        "Register a copy of `f` as `cls.nm`, set it on `cls`, and return the resolved function"
        nf = copy_func(f)
        nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner
        pf = self(nf, **kwargs)
        # plum uses __set_name__ to resolve a plum.Function's owner
        # since we assign after class creation, __set_name__ must be called directly
        # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
        pf.__set_name__(cls, nm)
        pf = pf.resolve()
        setattr(cls, nm, pf)
        return pf

    def to(self, cls):
        "Decorator: dispatch `f` to `cls.f`; `kwargs` given to the decorator always apply to `f`"
        def _inner(f, **kwargs):
            nm = f.__name__
            # check __dict__ to avoid inherited methods but use getattr so pf.__get__ is called, which plum relies on
            if nm in cls.__dict__:
                pf = getattr(cls, nm)
                # an ordinary method is first wrapped with default options: `kwargs` describe `f`,
                # not the pre-existing method
                if not hasattr(pf, 'dispatch'): pf = self._to(cls, nm, pf)
                # forward `kwargs` so they are not silently dropped when `cls.nm` already exists
                pf.dispatch(f, **kwargs)
            else: pf = self._to(cls, nm, f, **kwargs)
            return pf
        return _inner

typedispatch = FastDispatcher()

`FastDispatcher` extends `plum.Dispatcher` with the following functionality.

Dispatching with `FastDispatcher` returns a `FastFunction`:

In [None]:
# untyped fallback method for `f`
@typedispatch
def f(x): return 'obj'

assert isinstance(f, FastFunction)

`FastDispatcher` supports fastcore's backport of the `|` operator on types:

In [None]:
# adds an `int|str` method to the `f` registered in the previous cell
@typedispatch
def f(x:int|str): return 'int|str'

test_eq(f(0),   'int|str')
test_eq(f(''),  'int|str')
# floats still reach the untyped `f(x)` from the previous cell
test_eq(f(0.0), 'obj')

... `FastDispatcher.multi` works too:

In [None]:
# `multi` is given one signature per list: `(bool)` and `(list)`
@typedispatch.multi([bool],[list])
def f(x: bool|list): return 'bool|list'
@typedispatch
def f(x: int): return 'int'

test_eq(f(True), 'bool|list')
test_eq(f([]),   'bool|list')
test_eq(f(0),    'int')

`FastDispatcher.to` lets you dynamically dispatch to class instance methods:

In [None]:
class A:
    # dispatched fallback defined inside the class body
    @typedispatch
    def f(self, x): return 'obj'

# add an `int` method to `A.f` after the class has been created
@typedispatch.to(A)
def f(self, x:int): return 'int'

a = A()
test_eq(a.f(0), 'int')
test_eq(a.f(''), 'obj')

### Tests -

In [None]:
#hide
#Call `to` twice consecutively
class A: pass

# first call: `A` has no `f` yet, so `A.f` is created
@typedispatch.to(A)
def f(self, x:int): return 'int'

a = A()
test_eq(a.f(0), 'int')

# second call: `f` is now in `A.__dict__`, so the new method is added to it
@typedispatch.to(A)
def f(self, x:str): return 'str'

test_eq(a.f(''), 'str')

In [None]:
#hide
#Call `to` on an ordinary method (not a `FastFunction`)
class A:
    # ordinary method, not decorated with `typedispatch`
    def f(self, x): return 'obj'

# `to` wraps the existing `A.f` and then adds the `int` method to it
@typedispatch.to(A)
def f(self, x:int): return 'int'

a = A()
test_eq(a.f(0), 'int')
test_eq(a.f(''), 'obj')

In [None]:
#hide
#Dispatching in a subclass when there is a matching inherited method doesn't alter the base class
#but still dispatches to it
class A:
    def f(self, x): return 'A'
Af = A.f
class B(A):
    @typedispatch
    def f(self, x:int): return 'B'
# `A.f` is still the very same object as before `B` was defined
test_is(Af, A.f)
b = B()
test_eq(b.f(0), 'B')
# no `B.f` method accepts a `str`, so the inherited `A.f` is used
test_eq(b.f(''), 'A')

## Casting

Now that we can dispatch on types, let's make it easier to cast objects to a different type.

In [None]:
#export
_all_=['cast']

In [None]:
#export
def retain_meta(x, res, as_copy=False):
    "Return `res`, after calling `res.set_meta(x, as_copy=as_copy)` when `res` has a `set_meta`"
    if not hasattr(res, 'set_meta'): return res
    res.set_meta(x, as_copy=as_copy)
    return res

In [None]:
#export
def default_set_meta(self, x, as_copy=False):
    "Set `self._meta` from `x._meta` (a shallow copy if `as_copy`) unless `self` already has `_meta`; return `self`"
    # nothing to transfer, or `self` already carries its own metadata
    if not hasattr(x, '_meta') or hasattr(self, '_meta'): return self
    source = x._meta
    self._meta = copy(source) if as_copy else source
    return self

In [None]:
#export
@typedispatch
def cast(x, typ):
    "Cast `x` to type `typ` (may also change `x` inplace)"
    # give `typ` a chance to pre-process `x` when it defines `_before_cast`
    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x
    if risinstance('ndarray', res): res = res.view(typ)
    elif hasattr(res, 'as_subclass'): res = res.as_subclass(typ)
    else:
        # prefer re-classing in place; if that is rejected, construct a new `typ` from `res`.
        # `Exception` (rather than a bare `except`) lets KeyboardInterrupt/SystemExit propagate
        try: res.__class__ = typ
        except Exception: res = typ(res)
    return retain_meta(x, res)

This works both for plain Python classes...

In [None]:
mk_class('_T1', 'a')   # mk_class constructs the class `_T1`, which is instantiated with `a=...` below
class _T2(_T1): pass

t = _T1(a=1)
t2 = cast(t, _T2)
assert t2 is t            # t2 refers to the same object as t
assert isinstance(t, _T2) # t also changed in-place
assert isinstance(t2, _T2)

# the cast object compares equal, in type and value, to a freshly built `_T2`
test_eq_type(_T2(a=1), t2)


...as well as for arrays and tensors.

In [None]:
class _T1(ndarray): pass

t = array([1])
# arrays take the `res.view(typ)` branch of `cast`
t2 = cast(t, _T1)
test_eq(array([1]), t2)
test_eq(_T1, type(t2))

To customize casting for other types, define a separate `cast` function with `typedispatch` for your type.

In [None]:
#export
def retain_type(new, old=None, typ=None, as_copy=False):
    "Cast `new` to `typ`, or to the type of `old` when `old` is an instance of `type(new)`"
    if new is None: return None
    # at least one of `old`/`typ` is needed to know the target type
    assert not (old is None and typ is None)
    if typ is None:
        # e.g. old is TensorImage, new is Tensor - if `old` is not an instance of `type(new)` then do nothing
        if not isinstance(old, type(new)): return new
        typ = old if isinstance(old, type) else type(old)
    # Do nothing if `new` is already an instance of the requested type (or the target is `NoneType`)
    already_ok = typ==NoneType or isinstance(new, typ)
    if already_ok: return new
    casted = cast(new, typ)
    return retain_meta(old, casted, as_copy=as_copy)

In [None]:
class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
# a plain tuple is cast to the requested subclass `_T`
c = retain_type(b, typ=_T)
test_eq_type(c, a)

If `old` has a `_meta` attribute, its content is passed along when casting `new` to the type of `old`. In the example below, only the `_meta` entry `a` is kept, but not `other_attr`, because `other_attr` is not in `_meta`:

In [None]:
class _A():
    # use `default_set_meta` so `retain_meta` can copy `_meta` onto instances
    set_meta = default_set_meta
    def __init__(self, t): self.t=t

class _B1(_A):
    def __init__(self, t, a=1):
        super().__init__(t)
        self._meta = {'a':a}
        self.other_attr = 'Hello' # will not be kept after casting.

x = _B1(1, a=2)
b = _A(1)
# `b` is cast to the type of `x` and receives `x._meta`
c = retain_type(b, old=x)
test_eq(c._meta, {'a': 2})
assert not getattr(c, 'other_attr', None)

In [None]:
#export
def retain_types(new, old=None, typs=None):
    "Cast each item of `new` to type of matching item in `old` if it's a superclass"
    # leaves are handled by `retain_type`
    if not is_listy(new): return retain_type(new, old, typs)
    if typs is None:
        # container type: that of `old` when it is an instance of `type(new)`, otherwise keep `type(new)`
        use_old = old is not None and isinstance(old, type(new))
        container = type(old) if use_old else type(new)
    elif isinstance(typs, dict):
        # `{container_type: [item types]}`, the format produced by `explode_types`
        container = first(typs.keys())
        typs = typs[container]
    else: container,typs = typs,None
    # recurse item by item; `old` and `typs` are cycled alongside `new`
    items = L(new, old, typs).map_zip(retain_types, cycled=True)
    return container(items)

In [None]:
class T(tuple): pass

# types are taken from the matching (nested) items of `old`
t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

# the same result using an explicit nested `typs` description instead of `old`
t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

In [None]:
#export
def explode_types(o):
    "Return the type of `o`, as a nested `{type: [item types]}` dictionary for things that are listy"
    if is_listy(o): return {type(o): list(map(explode_types, o))}
    return type(o)

In [None]:
# the result has the same nested shape that `retain_types` accepts as `typs`
test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})

## Export -

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_test.ipynb.
Converted 01_basics.ipynb.
Converted 02_foundation.ipynb.
Converted 03_xtras.ipynb.
Converted 03a_parallel.ipynb.
Converted 03b_net.ipynb.
Converted 04_dispatch.ipynb.
Converted 05_transform.ipynb.
Converted 06_docments.ipynb.
Converted 07_meta.ipynb.
Converted 08_script.ipynb.
Converted index.ipynb.
Converted parallel_win.ipynb.
