In [None]:
#default_exp data.transform

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.notebook.showdoc import show_doc

from types import MethodType

# Transforms

In [None]:
#export
def anno_ret(func):
    "Get the return annotation of `func`; a tuple annotation (`Tuple[...]` or `tuple[...]`) gives a list of its element types"
    ann = typing.get_type_hints(func)
    if not ann: return None
    typ = ann.get('return')
    # `__origin__ is tuple` holds for both `typing.Tuple[...]` and builtin `tuple[...]` generics,
    # whereas the private `_name` attribute only exists on the `typing` version. A bare `Tuple`
    # may have no `__args__`, in which case it is returned as-is rather than crashing.
    args = getattr(typ, '__args__', None)
    if getattr(typ, '__origin__', None) is tuple and args is not None: return list(args)
    return typ

In [None]:
#hide
def _ret_float(x) -> float: return x
def _ret_pair(x) -> Tuple[float,float]: return x
def _ret_none(x) -> None: return x
def _ret_unannotated(x): return x

# each function paired with the value `anno_ret` should report for it
for fn,expected in ((_ret_float,float), (_ret_pair,[float,float]), (_ret_none,NoneType), (_ret_unannotated,None)):
    test_eq(anno_ret(fn), expected)

In [None]:
#export
def _cmp_subclass(a, b):
    "Three-way comparison of types: `0` if equal, `1` if `a` is a sub-class of `b`, `-1` otherwise"
    # NOTE: unrelated types give `-1` in *both* directions, so this is not a consistent ordering;
    # `sorted` with it is not guaranteed to put every sub-class after all of its super-classes.
    if a == b: return 0
    return 1 if issubclass(a, b) else -1

cmp_instance = functools.cmp_to_key(_cmp_subclass)

In [None]:
# super-classes sort before their sub-classes
unordered_types = [int, numbers.Number, numbers.Integral]
test_eq(sorted(unordered_types, key=cmp_instance), [numbers.Number, numbers.Integral, int])

In [None]:
#export
def _p1_anno(f):
    "Get the annotation of first param of `f`"
    ann = [o for n,o in typing.get_type_hints(f).items() if n!='return']
    return ann[0] if ann else object

In [None]:
def _no_annos(a, b): pass
def _ret_only(a, b)->str: pass
def _second_annotated(a, b:str)->float: pass
def _both_annotated(a:int, b:int)->float: pass

# the return annotation is ignored; the first *annotated* parameter wins
for fn,expected in ((_no_annos,object), (_ret_only,object), (_second_annotated,str), (_both_annotated,int)):
    test_eq(_p1_anno(fn), expected)

In [None]:
#export
class TransformBase():
    "Delegates (`__call__`,`decode`) to (`encodes`,`decodes`) if `filt` matches"
    filt=None
    def __init__(self, filt=None, whole_tuple=False): self.filt,self.whole_tuple=filt,whole_tuple
    def __call__(self, *args, **kwargs): return self.call(True, *args, **kwargs)
    def decode  (self, *args, **kwargs): return self.call(False, *args, **kwargs)
    def encodes(self, x, *args, **kwargs): return x
    def decodes(self, x, *args, **kwargs): return x

    def call(self, is_enc, x, *args, filt=None, **kwargs):
        "Apply the encoding (`is_enc`) or decoding function to `x`, element-wise when `func` returns a list"
        f = self.func(is_enc, x, filt=filt)
        if is_listy(f): return tuple(self._do_call(f_, x_, *args, filt=filt, **kwargs) for f_,x_ in zip(f,x))
        return self._do_call(f, x, *args, filt=filt, **kwargs)

    def func(self, is_enc, x, filt=None):
        "`None` if `filt` doesn't match; else the function for `x`, or a list of them (one per item) if `x` is listy"
        if filt!=self.filt and filt is not None: return None
        f = self.encodes if is_enc else self.decodes
        if self.whole_tuple: return f
        # only apply `f` to items that are instances of its first (annotated) parameter's type
        t = _p1_anno(f)
        f_ = lambda o: f if isinstance(o,t) else None
        return [f_(x_) for x_ in x] if is_listy(x) else f_(x)

    @staticmethod
    def _cast(res, typ):
        "Cast `res` to `typ`, unless it already has exactly that type or `typ` is `NoneType` (i.e. `->None`)"
        return typ(res) if (type(res) != typ) and typ!=NoneType else res

    def _do_call(self, f, x, *args, filt=None, **kwargs):
        "Return `x` if `f` is `None`, else `f(x)` cast to `f`'s return annotation or, failing that, to `type(x)`"
        if f is None: return x
        res = f(x, *args, **kwargs)
        typ_r = ifnone(anno_ret(f), type(x))
        if isinstance(typ_r, list):
            # `anno_ret` gives a list of element types for a `Tuple[...]` annotation; calling that
            # list directly raised `TypeError`, so cast each returned item to its own type instead.
            if len(typ_r)==2 and typ_r[1] is Ellipsis: typ_r = typ_r[:1]*len(res)  # `Tuple[X, ...]`
            if len(typ_r)!=len(res):
                raise TypeError(f"{getattr(f,'__name__',f)} is annotated to return {len(typ_r)} items but returned {len(res)}")
            return tuple(self._cast(r, t) for r,t in zip(res, typ_r))
        return self._cast(res, typ_r)

In [None]:
#export
class ShowTitle:
    "Mixin whose `show` displays `str(self)` as a title through `show_title`"
    def show(self, ctx=None, **kwargs):
        # extra `kwargs` are accepted for interface compatibility but not forwarded
        return show_title(str(self), ctx=ctx)

class Int(int, ShowTitle):
    "`int` that can `show` itself"

class Float(float, ShowTitle):
    "`float` that can `show` itself"

class Str(str, ShowTitle):
    "`str` that can `show` itself"

In [None]:
# `show` prints the value's string form
for obj,expected in ((Str('s'),'s'), (Int(1),'1')): test_stdout(obj.show, expected)

In [None]:
class A(TransformBase):
    # the explicit `->Int` annotation wins over the input type: 2/2 == 1.0 becomes Int(1)
    def encodes(self, x)->Int: return x/2

tfm = A()
test_eq_type(tfm(2), Int(1))
# no `decodes` override, so decoding returns its input
test_eq_type(tfm.decode(2.0), 2.0)

In [None]:
class A(TransformBase):
    def encodes(self, x): return x/2

tfm = A()
# without a return annotation the result is cast back to the input's type
for inp,expected in ((Int(2),Int(1)), (2,1)): test_eq_type(tfm(inp), expected)

In [None]:
class A(TransformBase):
    def encodes(self, x)->None: return x/2

tfm = A()
# `->None` disables casting, so the float produced by `/` is kept even for an int input
for inp in (2, 2.): test_eq_type(tfm(inp), 1.)

In [None]:
class A(TransformBase): 
    # `encodes` only sees `int` items (first-param annotation) and its result is cast to `Int`
    def encodes(self, x:int)->Int: return x+1
    def decodes(self, x:int): return x-1

f = A()
test_eq_type(f(1), Int(2))
# a float doesn't match the `int` annotation, so it is returned unchanged
test_eq_type(f(1.), 1.)
# tuples are processed element-wise
t = f((1.,2))
test_eq_type(t, (1.,Int(3)))
test_eq(f.decode(t), (1,2))

# once `filt` is set, a call passing a different (non-None) `filt` leaves the input untouched
f.filt = 1
test_eq(f((1.,2), filt=1), (1.,3))
test_eq_type(f((1.,2), filt=0), (1.,2))

In [None]:
class TransformWhole(TransformBase):
    "`TransformBase` whose `encodes`/`decodes` receive the whole tuple instead of each element"
    def __init__(self, filt=None):
        super().__init__(filt=filt, whole_tuple=True)

In [None]:
# Apply `encodes`/`decodes` to the tuple as a whole
class A(TransformWhole): 
    # `xy` is the whole tuple, not one element of it
    def encodes(self, xy): x,y=xy; return (x+y,y)
    def decodes(self, xy): x,y=xy; return (x-y,y)

f = A()
t = f((1,2))
test_eq(t, (3,2))
test_eq(f.decode(t), (1,2))
# with `filt` set, only calls passing the same `filt` (or no `filt`) apply the transform
f.filt = 1
test_eq(f((1,2), filt=1), (3,2))
test_eq(f((1,2), filt=0), (1,2))

In [None]:
#export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, *funcs):
        self.funcs,self.cache = {},{}
        for f in funcs: self.add(f)

    @staticmethod
    def _is_sub(a, b):
        "Is every type in `a` (a type or tuple of types) a sub-class of `b`? Keys `issubclass` rejects count as unrelated"
        try: return all(issubclass(o, b) for o in (a if isinstance(a, tuple) else (a,)))
        except TypeError: return False

    def _sorted_keys(self):
        "Keys of `funcs` ordered so that each one comes before all of its super-classes"
        # Topological sort. Sorting with `cmp_instance` is unreliable: unrelated types compare as -1
        # in both directions, which can leave a super-class ahead of its sub-class, and tuple keys
        # such as `(float,int)` made `issubclass` raise `TypeError`.
        rest,res = list(self.funcs),[]
        while rest:
            # first remaining key that no other remaining key is a sub-class of (index 0 on a cycle)
            idx = next((i for i,o in enumerate(rest)
                        if not any(j!=i and self._is_sub(p,o) for j,p in enumerate(rest))), 0)
            res.append(rest.pop(idx))
        return res

    def _reset(self):
        "Re-order `funcs` most-specific first and reset `cache` to the registered keys"
        self.funcs = {k:self.funcs[k] for k in self._sorted_keys()}
        self.cache = {**self.funcs}

    def add(self, f):
        "Add function `f`, keyed by the annotation of its first annotated parameter (`object` if none)"
        self.funcs[_p1_anno(f) or object] = f
        self._reset()

    def __repr__(self): return str(self.funcs)
    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`; since `funcs` is most-specific first, this is the closest match"
        if k in self.cache: return self.cache[k]
        types = [f for f in self.funcs if issubclass(k,f)]
        res = self.funcs[types[0]] if types else None
        self.cache[k] = res
        return res

In [None]:
def f_col(x:typing.Collection): pass
def f_nin(x:numbers.Integral): pass
def f_num(x:numbers.Number): pass

t = TypeDispatch(f_nin,f_num)
# the most specific registered super-class wins; unmatched types give `None`
for typ,expected in ((int,f_nin), (str,None), (float,f_num)): test_eq(t[typ], expected)
t.add(f_col)
for typ,expected in ((str,f_col), (int,f_nin)): test_eq(t[typ], expected)

In [None]:
#export
class TfmMeta(type):
    "Metaclass giving each class its own dispatchers and letting the class double as a registration decorator"
    def __new__(cls, name, bases, dct):
        res = super().__new__(cls, name, bases, dct)
        # one (decode, encode) pair of dispatchers per class, indexed by `is_enc`
        res.fs = (TypeDispatch(),TypeDispatch())
        return res

    def __call__(cls, *args, **kwargs):
        "Register a function named `encode`/`_` (encoder) or `decode` (decoder) on `cls` and return it; otherwise instantiate `cls`"
        f = args[0] if args else None
        # callables without `__name__` (e.g. `functools.partial`) previously raised `AttributeError`;
        # they are now treated as ordinary constructor arguments
        if isinstance(f,Callable) and getattr(f, '__name__', None) in ('decode','encode','_'):
            d = cls.fs[f.__name__ != 'decode']
            d.add(f)
            return f
        return super().__call__(*args, **kwargs)

In [None]:
#export
class Transform(TransformBase, metaclass=TfmMeta):
    "`TransformBase` that picks `encodes`/`decodes` per item type through `TypeDispatch`"
    def __init__(self, enc=None, dec=None, filt=None):
        super().__init__(filt)
        if enc or dec:
            # functions passed here get per-instance dispatchers, shadowing the class-level `fs`
            self.fs = (TypeDispatch(),TypeDispatch())
            if enc: self.fs[True].add(enc)
            if dec: self.fs[False].add(dec)

    def lookup(self, is_enc, x):
        "Function registered for `type(x)` (or `noops` if none), bound to `self`"
        fn = self.fs[is_enc][type(x)] or noops
        return MethodType(fn, self)

    def func(self, is_enc, x, filt=None):
        "`None` if `filt` doesn't match, else the bound function for `x` (a list of them, one per item, if `x` is listy)"
        if filt!=self.filt and filt is not None: return None
        if not is_listy(x): return self.lookup(is_enc, x)
        return [self.lookup(is_enc, o) for o in x]

In [None]:
def neg_int(self, x:numbers.Integral): return -x

tfm = Transform(neg_int)
# only `Integral` items are negated; everything else passes through
for inp,expected in ((1,-1), (1.,1.), ((1.,2,3.),(1.,-2,3.))): test_eq(tfm(inp), expected)
test_eq(tfm.decode((1,2)), (1,2))

In [None]:
def float_to_int(self, x:(float,int))->Int: return x

tfm = Transform(float_to_int)
# a tuple annotation matches any of its types; matched results are cast to `Int`
for inp,expected in ((1.,Int(1)), (1,Int(1)), ('1','1'), ((1,'1'),(Int(1),'1'))):
    test_eq_type(tfm(inp), expected)
test_eq(tfm.decode(1), 1)

In [None]:
# `TfmMeta` gives each subclass its own empty dispatchers; with nothing registered, `f(1)` returns 1 unchanged.
# `B` and `C` are reused by the following cells.
class B(Transform): pass
class C(Transform): pass
f = B()
test_eq(f(1), 1)

In [None]:
# `@B` registers each function on class `B` (via `TfmMeta.__call__`) rather than creating an instance;
# dispatch is on the annotation of the first annotated parameter
@B
def _(self, x:int): return x+1
@B
def _(self, x:str): return x+'1'

b,c = B(),C()
test_eq(b(1), 2)
test_eq(b('1'), '11')
# registration on `B` doesn't leak into sibling class `C`
test_eq(c(1), 1)
test_eq(b((1,2)), (2,3))
test_eq(b.decode(2), 2)
# round-trip through pickle and check the copy still dispatches (a bare `assert obj` only checked truthiness)
b_unpickled = pickle.loads(pickle.dumps(b))
test_eq(b_unpickled(1), 2)
test_eq(b_unpickled('1'), '11')

In [None]:
# a function named `decode` goes into `B`'s decode dispatcher; the existing instance `b` (previous cell)
# picks it up because the dispatchers live on the class
@B
def decode(self, x:int): return x-1
test_eq(b.decode(2), 1)
test_eq(b.decode('2'), '2')

In [None]:
# redefining `B` creates a fresh class with fresh (empty) dispatchers
class B(Transform): pass
@B
def _(self, x:int)->Int: return x+1
@B
def _(self, x:str): return x+'1'
@B
def decode(self, x:Int): return x/2

f = B()
start = (1.,2,'3')
t = f(start)
# the float has no registered encoder and is returned unchanged; the int result is cast to `Int`
test_eq_type(t, (1.,Int(3),'31'))
# only the `Int` item is decoded: 3/2 is cast back to the input type `Int`, giving Int(1)
test_eq(f.decode(t), (1.,Int(1),'31'))

In [None]:
# unannotated functions are keyed on `object`, so they apply to every item
class A(Transform): pass
@A
def _(self, x): return x+1
@A
def decode(self, x): return x-1

f = A()
t = f((1,2))
test_eq(t, (2,3))
test_eq(f.decode(t), (1,2))

In [None]:
# each item is dispatched to the function registered for its type
class A(Transform): pass
@A
def _(self, x:numbers.Integral): return x+1
@A
def _(self, x:float): return x*3
@A
def decode(self, x:int): return x-1

f = A()
start = 1.0
t = f(start)
test_eq(t, 3.)
# `t` is the float 3.0 and only `int` has a decoder, so it comes back unchanged (3.0 == 3)
test_eq(f.decode(t), 3)

In [None]:
# NOTE: reuses `f` (the `A` transform) from the previous cell; only the int item is decoded
start = (1.,2,3.)
t = f(start)
test_eq(t, (3.,3,9.))
test_eq(f.decode(t), (3.,2,9.))

## Export -

In [None]:
#hide
from local.notebook.export import notebook2script
# convert every notebook (`all_fs=True`), not only this one, as the "Converted ..." output shows
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 02_data_pipeline-Copy1.ipynb.
Converted 02_data_pipeline.ipynb.
Converted 02_transforms.ipynb.
Converted 02a_pipeline.ipynb.
Converted 03_data_external.ipynb.
Converted 04_data_core.ipynb.
Converted 05_data_source.ipynb.
Converted 06_vision_core.ipynb.
Converted 07_pets_tutorial-meta.ipynb.
Converted 07_pets_tutorial.ipynb.
Converted 08_vision_augment.ipynb.
Converted 09_data_block-Copy1.ipynb.
Converted 09_data_block.ipynb.
Converted 10_layers.ipynb.
Converted 11_optimizer.ipynb.
Converted 12_learner.ipynb.
Converted 13_callback_schedule.ipynb.
Converted 14_callback_hook.ipynb.
Converted 15_callback_progress.ipynb.
Converted 16_callback_tracker.ipynb.
Converted 17_callback_fp16.ipynb.
Converted 30_text_core.ipynb.
Converted 90_notebook_core.ipynb.
Converted 91_notebook_export.ipynb.
Converted 92_notebook_showdoc.ipynb.
Converted 93_notebook_export2html.ipynb.
Converted 94_index.ipynb.
Converted 95_synth_learner.ipynb.


In [None]:
# get rid of this later if we don't need it
def n_params(f):
    "Count of params of `f` with no default value, excluding `**kwargs` (`*args` and required keyword-only params are counted)"
    empty = inspect.Parameter.empty
    # `is` rather than `==`: comparing would call the default's own `__eq__`, which can raise
    # (array-like defaults) or wrongly report equality
    return sum(1 for p in inspect.signature(f).parameters.values()
               if p.default is empty and p.kind != inspect.Parameter.VAR_KEYWORD)