In [None]:
from local.imports import *
from local.test import *
from local.core import *
from local.notebook.showdoc import show_doc

from types import MethodType

In [None]:
class BasicTransform():
    "Base transform: calling an instance runs `encode`; subclasses provide `encode` and `decode`"
    def __call__(self, *args, **kwargs):
        return self.encode(*args, **kwargs)
    def decode(self, *args, **kwargs):
        # Raise instead of delegating to `self.decode`, which would only call this method again and
        # recurse until RecursionError. NotImplementedError is, like RecursionError, a RuntimeError.
        raise NotImplementedError(type(self).__name__ + " must implement `decode`")

In [None]:
#Apply on the tuple as a whole: `encode`/`decode` receive all the items at once
class A(BasicTransform):
    "Encode adds `y` to `x` and decode subtracts it back; `y` itself is passed through"
    def encode(self, x, y):
        shifted = x + y
        return (shifted, y)
    def decode(self, x, y):
        restored = x - y
        return (restored, y)

addt = A()
t = addt(1,2)
test_eq(t, (3,2))
test_eq(addt.decode(*t), (1,2))

In [None]:
class TupleTransform(BasicTransform):
    "Run `encodes`/`decodes` on a single item, or on each item (returning a list) when called with several"
    def encode(self, x, *args, **kwargs):
        if not args: return self.encodes(x, **kwargs)
        return [self.encodes(item, **kwargs) for item in (x, *args)]

    def decode(self, x, *args, **kwargs):
        if not args: return self.decodes(x, **kwargs)
        return [self.decodes(item, **kwargs) for item in (x, *args)]

In [None]:
#export
def _cmp_instance(a, b):
    "Three-way comparison of types `a` and `b`: 0 if equal, 1 if `a` is a subclass of `b`, otherwise -1"
    if a == b: return 0
    return 1 if issubclass(a, b) else -1

# Sort key under which super-classes sort before their subclasses. This is not a consistent
# ordering: two unrelated types each compare as "less than" the other.
cmp_instance = functools.cmp_to_key(_cmp_instance)

In [None]:
td = {int:1, numbers.Number:2, numbers.Integral:3}
# Super-classes sort before their subclasses
expected_order = [numbers.Number, numbers.Integral, int]
test_eq(sorted(td, key=cmp_instance), expected_order)

In [None]:
#export
def param_anno(f, idx=0):
    "Annotation of the `idx`-th parameter of `f`, or `None` if that parameter is not annotated"
    t = list(inspect.signature(f).parameters.values())[idx].annotation
    # Identity check against the sentinel: `==` would invoke an arbitrary annotation object's `__eq__`
    return None if t is inspect.Parameter.empty else t

In [None]:
def _f(a, b:str)->float: pass
# Annotated parameter -> its annotation; unannotated parameter -> None
for idx, anno in [(1, str), (0, None)]: test_eq(param_anno(_f, idx), anno)

In [None]:
class FuncTransform(TupleTransform):
    "Apply `enc`/`dec` only to items that are instances of the annotation of their first parameter"
    def __init__(self, enc=None, dec=None):
        if enc: self.enc = enc
        if dec: self.dec = dec
        # No annotation means "apply to every item", hence the `object` fallback
        enc_type = param_anno(self.enc)
        self.t_enc = enc_type if enc_type else object
        dec_type = param_anno(self.dec)
        self.t_dec = dec_type if dec_type else object

    def encodes(self, x, *args, **kwargs):
        if not isinstance(x, self.t_enc): return x
        return self.enc(x, *args, **kwargs)
    def decodes(self, x, *args, **kwargs):
        if not isinstance(x, self.t_dec): return x
        return self.dec(x, *args, **kwargs)

    # Identity defaults, used when no function is passed to `__init__`
    def enc(self, x, *args, **kwargs): return x
    def dec(self, x, *args, **kwargs): return x

In [None]:
d = FuncTransform(operator.neg, operator.neg)
# Unannotated functions apply to every type, in both directions
for direction in (d, d.decode): test_eq(direction(1), -1)

In [None]:
def neg_int(x:numbers.Integral): return -x

d = FuncTransform(neg_int)
# Only integral items are negated; every other item passes through unchanged
for inputs, expected in [((1,), -1), ((1.,), 1.), ((1.,2,3.), [1.,-2,3.])]:
    test_eq(d(*inputs), expected)
# No decoder was given, so decoding is the identity
test_eq(d.decode(1), 1)

In [None]:
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, funcs=None):
        self.funcs = funcs or {}
        self._reset()

    @staticmethod
    def _issub(a, b):
        "`issubclass(a, b)`, but `False` instead of `TypeError` when `a` is not a class (e.g. `typing.Collection`)"
        try: return issubclass(a, b)
        except TypeError: return False

    def _reset(self):
        "Reorder `funcs` so every type comes before its super-classes, and reset `cache` to the registered types"
        # Rank each type by how many *other* registered types it is a subclass of. A strict subclass is a
        # subclass of its parent plus everything the parent subclasses, so it always ranks strictly higher.
        # Sorting by rank (descending, stable) therefore puts subclasses first. A pairwise comparator such as
        # `cmp_instance` cannot guarantee this once unrelated types are mixed in: it is not a consistent order.
        keys = list(self.funcs)
        rank = {k: sum(self._issub(k, o) for o in keys if o is not k) for k in keys}
        self.funcs = {k:self.funcs[k] for k in sorted(keys, key=rank.__getitem__, reverse=True)}
        self.cache = {**self.funcs}

    def add(self, t, f):
        "Add type `t` and function `f`"
        self.funcs[t] = f
        self._reset()

    def __repr__(self): return str(self.funcs)
    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        if k in self.cache: return self.cache[k]
        # `funcs` is ordered subclasses-first, so the first match is a most specific one
        types = [f for f in self.funcs if issubclass(k,f)]
        res = self.funcs[types[0]] if types else None
        self.cache[k] = res
        return res

In [None]:
td = {int:1, numbers.Number:2, numbers.Integral:3}
t = TypeDispatch(td)

# Exact match, no match, and match through a super-class
for typ, expected in [(int, 1), (str, None), (float, 2)]: test_eq(t[typ], expected)
# Registering a new super-class only affects the types it covers
t.add(typing.Collection, 4)
for typ, expected in [(str, 4), (int, 1)]: test_eq(t[typ], expected)

In [None]:
class TfmMeta(type):
    "Metaclass giving each class its own encode/decode `TypeDispatch` tables and turning `@Cls` on a function into a registration"
    def __new__(cls, name, bases, dct):
        res = super().__new__(cls, name, bases, dct)
        # Fresh, empty tables for every class created, so registrations on a base class are not visible on subclasses
        res.enc,res.dec = TypeDispatch(),TypeDispatch()
        return res

    def __call__(cls, *args, **kwargs):
        f = args[0] if args else None
        # `getattr`: a callable constructor argument without `__name__` (e.g. a `functools.partial`) must not raise here
        fname = getattr(f, '__name__', None)
        if isinstance(f,Callable) and fname in ('decode','encode','_'):
            d = cls.dec if fname == 'decode' else cls.enc
            # Registered functions have signature `(self, x, ...)`: dispatch on the annotation of `x` (parameter 1).
            # Parameter 0 is `self`, which is usually unannotated and would send every function to `object`.
            d.add(param_anno(f, 1) or object, f)
            return f
        return super().__call__(*args, **kwargs)

In [None]:
class Transform(BasicTransform, metaclass=TfmMeta):
    "Transform whose per-type functions are registered with `@ClassName`; applied item by item"
    def lookup(self, is_enc, x):
        "Registered encode (`is_enc`) or decode function for `type(x)`, falling back to `noops`"
        table = self.enc if is_enc else self.dec
        found = table[type(x)]
        return found if found else noops

    def _f(self, is_enc, x, *args, **kwargs):
        # Registered functions take `self` as their first parameter, so bind before calling
        method = MethodType(self.lookup(is_enc, x), self)
        return method(x, *args, **kwargs)

    def _apply(self, is_enc, x, args, kwargs):
        # Several items -> list of per-item results; a single item -> its result directly
        if not args: return self._f(is_enc, x, **kwargs)
        return [self._f(is_enc, item, **kwargs) for item in (x, *args)]

    def __call__(self, x, *args, **kwargs): return self._apply(True, x, args, kwargs)

    def decode(self, x, *args, **kwargs): return self._apply(False, x, args, kwargs)

In [None]:
# NOTE(review): `b` is only created in the next cell (`b = B()`), so on a fresh kernel this raises NameError.
# Run it after that cell to inspect `B`'s encode dispatch table.
b.enc

{<class 'object'>: <function _ at 0x7fa7a694d268>}

In [None]:
class B(Transform): pass
class C(Transform): pass
b = B()
# Nothing registered on `B` yet, so items pass through unchanged
test_eq(b(1), 1)

@B
def _(self, x:int):  return x+1
@B
def _(self, x:str): return x+'1'

b,c = B(),C()
# Dispatch uses the annotation of `x`; `C` has its own table, which is still empty
test_eq(b(1), 2)
test_eq(b('1'), '11')
test_eq(c(1), 1)
test_eq(b.decode(2), 2)
# Check the unpickled object really is a working `B` (a bare `assert obj` only tests truthiness)
b_copy = pickle.loads(pickle.dumps(b))
assert isinstance(b_copy, B)
test_eq(b_copy(1), 2)

TypeError: unsupported operand type(s) for +: 'int' and 'str'

In [None]:
@B
def decode(self, x:int): return x-1

# Only an int decoder is registered: ints are decremented, other types come back unchanged
for inp, out in [(2, 1), ('2', '2')]: test_eq(b.decode(inp), out)

In [None]:
# Build a dedicated float-negating transform instead of reusing the global `d`: at this point `d` is
# `FuncTransform(neg_int)`, which negates the int rather than the floats and would not give these values.
def _neg_float(x:float): return -x
d_float = FuncTransform(_neg_float)

start = (1.,2,3.)
t = d_float(*start)
test_eq(t, (-1.,2,-3.))
# No decoder given, so decoding is the identity
test_eq(d_float.decode(*t), (-1.,2,-3.))

In [None]:
# No annotations: the registered functions apply to every element of the tuple
class A(Transform): pass

@A
def _(self, x): return x+1

@A
def decode(self, x): return x-1

addt = A()
encoded = addt(1,2)
test_eq(encoded, (2,3))
# Decoding undoes the +1 element-wise
test_eq(addt.decode(*encoded), (1,2))

In [None]:
class FuncTransform(TupleTransform):
    "Apply `enc`/`dec` to items that are instances of the annotation of their first parameter (to all items if unannotated)"
    def __init__(self, enc=None, dec=None):
        if enc: self.enc = enc
        if dec: self.dec = dec
        # `_p1_anno` is not defined in this notebook, so construction raised NameError. Use `param_anno`,
        # falling back to `object` so that `isinstance` accepts every item when there is no annotation.
        self.t_enc = param_anno(self.enc) or object
        self.t_dec = param_anno(self.dec) or object

    def encodes(self, x, *args, **kwargs):
        "Apply `enc` to `x` if it is an instance of `t_enc`, else return `x` unchanged"
        return self.enc(x, *args, **kwargs) if isinstance(x,self.t_enc) else x
    def decodes(self, x, *args, **kwargs):
        "Apply `dec` to `x` if it is an instance of `t_dec`, else return `x` unchanged"
        return self.dec(x, *args, **kwargs) if isinstance(x,self.t_dec) else x

    # Identity defaults, used when no function is passed to `__init__`
    def enc(self, x, *args, **kwargs): return x
    def dec(self, x, *args, **kwargs): return x

In [None]:
d = FuncTransform(operator.neg, operator.neg)
# Same function in both directions, so encode and decode agree
for direction in (d, d.decode): test_eq(direction(1), -1)

In [None]:
def neg_float(x:float): return -x

d = FuncTransform(neg_float)
# Ints are not floats and pass through; floats are negated
for inp, out in [(1, 1), (1.0, -1.0)]: test_eq(d(inp), out)
# No decoder was given, so decoding is the identity
test_eq(d.decode(1), 1)

In [None]:
# Each element is handled by the function registered for its type (ints and floats here);
# tuples may also have more than two elements
class A(Transform): pass

@A
def _(self, x:int): return x+1

@A
def _(self, x:float): return x*3

@A
def decode(self, x:int): return x-1

addt = A()
start = 1.0
t = addt(start)
# A float is tripled on encode; only ints have a decoder, so decoding falls back to `noops`
test_eq(t, 3.)
test_eq(addt.decode(t), 3)

In [None]:
start = (1.,2,3.)
t = addt(*start)
# Element-wise dispatch: floats are tripled and the int incremented on encode;
# on decode only the int has a registered function
expected_enc, expected_dec = (3.,3,9.), (3.,2,9.)
test_eq(t, expected_enc)
test_eq(addt.decode(*t), expected_dec)

Use case: reusing the `encodes` method of `TensorImage` inside the `encodes` method of `TensorMask`.

This pattern is used in `PILFlip`, `PILDihedral` and `AffineCoordTfm`, if I recall correctly (**TODO**: confirm).

In [None]:
#export
def n_params(f):
    "Count of parameters of `f` that have no default value, not counting `**kwargs`"
    # NOTE(review): `*args` and keyword-only parameters without defaults are counted too, although this was
    # documented as "positional params" -- confirm which behaviour callers expect.
    # `is` for the sentinel check: `==` would call an arbitrary default's `__eq__` (e.g. an array default
    # can compare elementwise and then fail when used as a boolean).
    params = inspect.signature(f).parameters.values()
    return sum(1 for p in params
               if p.default is inspect.Parameter.empty and p.kind != inspect.Parameter.VAR_KEYWORD)