## Multimethod test

In [None]:
#default_exp data.pipeline

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.notebook.showdoc import show_doc

In [None]:
from multimethod import multidispatch,multimeta,multimethod

In [None]:
#hide
# Pick the CUDA device index from the DEFAULT_GPU env var; unset or empty falls back to device 0
torch.cuda.set_device(int(os.environ.get('DEFAULT_GPU') or 0))

In [None]:
#export
def get_func(t, name, *args, **kwargs):
    "Look up attribute `name` on `t` (falling back to `noop`), binding any extra `args`/`kwargs` with `partial`"
    func = getattr(t, name, noop)
    if args or kwargs:
        # Only wrap when there is something to bind; otherwise hand back the attribute itself
        return partial(func, *args, **kwargs)
    return func

In [None]:
# `operator.neg` with 2 bound via `partial`, then called with no arguments
test_eq(get_func(operator, 'neg', 2)(), -2)
# Any attribute works, here the `__call__` of `operator.neg`, with no arguments bound
test_eq(get_func(operator.neg, '__call__')(2), -2)

In [None]:
from multimethod import DispatchError

In [None]:
#export
class Transform(metaclass=multimeta):
    "Apply `encodes` when called and `decodes` on `decode`, optionally dispatching per element type (see `accept_types`)"
    # One type per element of `x`, set by `accept_types`; `None` means `x` is handled as a whole
    t = None
    def __init__(self,encodes=None,decodes=None):
        # Use the functions passed in, else the class-defined methods, else `noop`
        self.encodes = getattr(self, 'encodes', noop) if encodes is None else encodes
        self.decodes = getattr(self, 'decodes', noop) if decodes is None else decodes

    def _apply(self, fs, x):
        "Apply `fs` to `x`: one resolved function per element when `self.t` is set, else a single call on the unpacked `x`"
        if self.t: fs = [self._get_func(fs,t_) for t_ in self.t]
        # NB: `zip` silently stops at the shorter of `fs` and `x`
        if is_listy(fs): return tuple(f(x_) for f,x_ in zip(fs,x))
        return fs(*L(x))

    def _get_func(self,f,t):
        "Resolve `f` for the declared element type `t`"
        mm = getattr(f, '__func__', None)
        # Only a bound multimethod (subscriptable by a type signature) can dispatch on `t`.
        # Plain callables, e.g. a function passed to `__init__` or the `noop` fallback, have no
        # `__func__` to subscript, so they are used unchanged for every element.
        if not hasattr(mm, '__getitem__'): return f
        try: f = mm[object,t]
        except DispatchError: return noop  # nothing registered for `t`: fall back to `noop`
        return partial(f,self)  # the dispatched implementation is unbound, so re-bind `self`

    def accept_types(self, t):
        "Declare the type of each element of the inputs, enabling per-element dispatch in `_apply`"
        # We can't create encodes/decodes here since patching might change things later
        # So we call _get_func in _apply instead
        self.t = t

    # `filt` is accepted but not used by these methods
    def __call__(self, x, filt=None): return self._apply(self.encodes, x)
    def decode  (self, x, filt=None): return self._apply(self.decodes, x)
    def __getitem__(self, x): return self(x) # So it can be used as a `Dataset`

In [None]:
# Build a transform from plain functions: negate on encode, negate again on decode
tfm = Transform(operator.neg, decodes=operator.neg)
start = 4
t = tfm(start)
test_eq(t, -4)
test_eq(t, tfm[start]) #You can use a transform as a dataset
test_eq(tfm.decode(t), start)

In [None]:
class _AddOne(Transform):
    "Test transform: add one on encode, subtract one on decode"
    def encodes(self, x):
        return x + 1
    def decodes(self, x):
        return x - 1

addt = _AddOne()
test_eq(addt(start), 5)
# Calls the `decodes` hook directly rather than going through `decode`
test_eq(addt.decodes(addt(start)), start)

In [None]:
# One declared type per tuple element; dispatch uses these declared types, not the runtime
# types of the values (which are ints here), and each element is transformed separately
addt.accept_types([float,float])
t = addt([1,2])
test_eq(t, (2,3))
test_eq(addt.decode(t), (1,2))

In [None]:
class _Add(Transform):
    "Test transform on a pair: encode adds `y` to `x`, decode subtracts it back; `y` is passed through"
    def encodes(self, x, y):
        total = x + y
        return (total, y)
    def decodes(self, x, y):
        original = x - y
        return (original, y)

# No `accept_types`: the list is unpacked into `encodes(x, y)` / `decodes(x, y)` as positional args
addt = _Add()
t = addt([1,2])
test_eq(t, (3,2))
test_eq(addt.decode(t), (1,2))

In [None]:
def transform(cls):
    "Decorator factory: register the decorated function on `cls.encodes` or `cls.decodes`, chosen by its name"
    def _inner(f):
        if   f.__name__=='encodes': cls.encodes.register(f)
        elif f.__name__=='decodes': cls.decodes.register(f)
        else: raise ValueError('Function must be "encodes" or "decodes"')
        # Return `f` so the decorated name stays bound to the function instead of becoming `None`
        return f
    return _inner

In [None]:
# Dispatch on the declared type of each tuple element: integers get +1, floats get *2
# Also note that your tuples can have more than two elements
# The repeated `encodes` definitions are intentional: `multimeta` registers each one as another
# implementation of a single multimethod, selected by the parameter annotation
class _AddOne(Transform):
    def encodes(self, x:numbers.Integral): return x+1
    def encodes(self, x:float): return x*2
    # The float `decodes` is added afterwards with `@transform(_AddOne)`
    def decodes(self, x:numbers.Integral): return x-1

addt = _AddOne()
# Declared element types float, int, float: encodes will compute (1*2, 2+1, 3*2)
addt.accept_types([float, int, float])
start = [1,2,3]

In [None]:
@transform(_AddOne)
def decodes(self, x:float):
    "Float implementation of `_AddOne.decodes`: undo the doubling applied by the float `encodes`"
    return x / 2

In [None]:
# float 1 -> 2, int 2 -> 3, float 3 -> 6; decode reverses each (floats via the patched `decodes`)
t = addt(start)
test_eq(t, (2,3,6))
test_eq(addt.decode(t), start)