In [None]:
#default_exp data_transforms

In [None]:
from fastai_local.test import *
from fastai_local.core import *

# Basic data loading

> Creates low-level transforms and data sources

## Convenience functions

In [None]:
# export
def opt_call(f, fname='__call__', *args, **kwargs):
    "Call `f.{fname}(*args, **kwargs)`, falling back to `noop(*args, **kwargs)` when `f` has no attribute `fname`"
    method = getattr(f, fname, noop)
    return method(*args, **kwargs)

In [None]:
# Calling the unbound `list.sort` on `a` sorts it in place
a = [2,1]
opt_call(list, 'sort', a)
test_eq(a, [1,2])

# `list` has no `foobar`, so `opt_call` falls back to `noop`, which hands `[2]` back unchanged
test_eq(opt_call(list, 'foobar', [2]), [2])
# `operator.neg` is callable, so `'__call__'` invokes it with the extra argument
test_eq(opt_call(operator.neg, '__call__', 2), -2)

## Transform -

In [None]:
# export
class Transform():
    "A function that `encodes` if `filt` matches, and optionally `decodes`, with an optional `setup`"
    order,filt = 0,None

    def __init__(self, **kwargs):
        "Store every keyword argument (e.g. `encodes`, `decodes`, `filt`, `order`) as an attribute"
        for k,v in kwargs.items(): setattr(self, k, v)

    @classmethod
    def create(cls, f, filt=None):
        "classmethod: Turn `f` into a `Transform` unless it already is one"
        # Anything with a `decode` method is treated as transform-like and passed through untouched
        if hasattr(f,'decode') or isinstance(f,Transform): return f
        # `__init__` only takes keyword arguments, so `f` must be passed as `encodes`
        # (a positional `cls(f)` raises TypeError); `filt` is forwarded instead of being dropped.
        return cls(encodes=f, filt=filt)

    def __call__(self, o, filt=None, **kwargs):
        "Call `self.encodes` unless `self.filt` is set and differs from `filt`"
        # A transform with a `filt` returns `o` unchanged for any other `filt`, including `None`
        if self.filt is not None and self.filt!=filt: return o
        return self.encodes(o, **kwargs)

    def decode(self, o, filt=None, **kwargs):
        "Call `self.decodes` unless `self.filt` is set and differs from `filt`"
        if self.filt is not None and self.filt!=filt: return o
        return self.decodes(o, **kwargs)

    def __repr__(self):
        # A plain `Transform` is described by its `encodes` function, a subclass by its class.
        # `getattr` keeps `repr` from raising when no `encodes` was supplied.
        return str(getattr(self,'encodes',None)) if self.__class__==Transform else str(self.__class__)

    def decodes(self, o, *args, **kwargs):
        "Default decoding: return `o` unchanged"
        return o

In a transformation pipeline, some steps need to be reversible — for instance, if you turn a string (such as *dog*) into an int (such as *1*) for modeling, then for display purposes you'll want to turn it back into a string (e.g. when showing a prediction). In addition, you may wish to run a transformation only on a particular data subset, such as the training set.

`Transform` provides all this functionality. `filt` is some dataset index (e.g. provided by `DataSource`): if a transform has `filt` set, it is applied only when called with the same `filt`, and otherwise returns its input unchanged. You provide `encodes` and optional `decodes` functions for your code. You can pass `encodes` and `decodes` functions directly to the constructor to quickly create simple transforms, or (more commonly) you can subclass `Transform` and define these methods.

In [None]:
def add(x, a=1):
    "Add `a` to `x`"
    return x + a
def add_undo(x, a=1):
    "Inverse of `add`: subtract `a` from `x`"
    return x - a

# Quick transforms: hand `encodes`/`decodes` straight to the constructor
addt = Transform(decodes=add_undo, encodes=add)

start = 4
t = addt(start)
test_eq(t, 5)
test_eq(addt.decode(5), start)

In [None]:
class AddTfm(Transform):
    "Example subclass: `encodes` adds `a`, `decodes` subtracts it"
    def encodes(self, x, a=1):
        return x + a

    def decodes(self, x, a=1):
        return x - a

# Subclassing `Transform` and defining the methods directly
addt = AddTfm()
start = 4
t = addt(start)
test_eq(t, 5)
test_eq(addt.decode(5), start)

In [None]:
# Render the API documentation for `Transform.__call__`
show_doc(Transform.__call__)

<h4 id="Transform.__call__" class="doc_header"><code>__call__</code><a class="source_link" data-toggle="collapse" data-target="#Transform-__call__-pytest" style="float:right; padding-right:10px">[test]</a></h4>

> <code>__call__</code>(**`o`**, **`filt`**=***`None`***, **\*\*`kwargs`**)

<div class="collapse" id="Transform-__call__-pytest"><div class="card card-body pytest_card"><a type="button" data-toggle="collapse" data-target="#Transform-__call__-pytest" class="close" aria-label="Close"><span aria-hidden="true">&times;</span></a><p>No tests found for <code>__call__</code>. To contribute a test please refer to <a href="/dev/test.html">this guide</a> and <a href="https://forums.fast.ai/t/improving-expanding-functional-tests/32929">this discussion</a>.</p></div></div>

Call `self.encodes` unless `filt` is passed and it doesn't match `self.filt`  

In [None]:
# Render the API documentation for `Transform.decode`
show_doc(Transform.decode)

<h4 id="Transform.decode" class="doc_header"><code>decode</code><a class="source_link" data-toggle="collapse" data-target="#Transform-decode-pytest" style="float:right; padding-right:10px">[test]</a></h4>

> <code>decode</code>(**`o`**, **`filt`**=***`None`***, **\*\*`kwargs`**)

<div class="collapse" id="Transform-decode-pytest"><div class="card card-body pytest_card"><a type="button" data-toggle="collapse" data-target="#Transform-decode-pytest" class="close" aria-label="Close"><span aria-hidden="true">&times;</span></a><p>No tests found for <code>decode</code>. To contribute a test please refer to <a href="/dev/test.html">this guide</a> and <a href="https://forums.fast.ai/t/improving-expanding-functional-tests/32929">this discussion</a>.</p></div></div>

Call `self.decodes` unless `filt` is passed and it doesn't match `self.filt`  

In [None]:
# Render the API documentation for `Transform.create`
show_doc(Transform.create)

<h4 id="Transform.create" class="doc_header"><code>create</code><a class="source_link" data-toggle="collapse" data-target="#Transform-create-pytest" style="float:right; padding-right:10px">[test]</a></h4>

> <code>create</code>(**`f`**, **`filt`**=***`None`***)

<div class="collapse" id="Transform-create-pytest"><div class="card card-body pytest_card"><a type="button" data-toggle="collapse" data-target="#Transform-create-pytest" class="close" aria-label="Close"><span aria-hidden="true">&times;</span></a><p>No tests found for <code>create</code>. To contribute a test please refer to <a href="/dev/test.html">this guide</a> and <a href="https://forums.fast.ai/t/improving-expanding-functional-tests/32929">this discussion</a>.</p></div></div>

classmethod: Turn `f` into a `Transform` unless it already is one  

## Pipeline -

In [None]:
class Pipeline():
    "A pipeline of transforms, composed and applied for encode/decode, and setup one at a time"
    def __init__(self, tfms):
        # `_tfms` holds pending transforms; `setup` moves them into the active list `tfms` one by one
        self.tfms,self._tfms = [],[Transform.create(t) for t in listify(tfms)]

    def __call__(self, x, **kwargs):
        "Encode `x` by applying each active transform in order"
        return self._apply(x, **kwargs)
    def decode(self, x, **kwargs):
        "Decode `x` by applying each active transform's `decode` in reverse order"
        return self._apply(x, rev=True, fname='decode', **kwargs)
    def _apply(self, x, rev=False, fname='__call__', **kwargs):
        # `opt_call` falls back to `noop` for transforms that lack `fname`
        tfms = reversed(self.tfms) if rev else self.tfms
        for f in tfms: x = opt_call(f, fname, x, **kwargs)
        return x

    def __repr__(self): return str(self.tfms)
    def delete(self, idx):
        "Remove the active transform at position `idx`"
        del self.tfms[idx]
    def remove(self, tfm):
        "Remove `tfm` from the active transforms"
        self.tfms.remove(tfm)

    def setup(self, items=None):
        "Activate the pending transforms, sorted by `order`, calling each one's `setup` on `items`"
        self.add(self._tfms, items)
    def add(self, tfm, items=None):
        "Append `tfm` (one or a list), sorted by `order`, calling each one's `setup` via `opt_call`"
        # We only add one at a time so that each setup has access to correct tfm subset
        for t in sorted(listify(tfm), key=lambda o: getattr(o, 'order', 0)):
            self.tfms.append(t)
            opt_call(t, 'setup', items)

    def __getattr__(self, k):
        "Look up `k` on the active transforms, most recently added first"
        # `__getattr__` only runs when normal lookup fails. If `tfms` itself is missing (e.g. an
        # instance created via `__new__` by `copy`/`pickle`), reading `self.tfms` here would
        # re-enter `__getattr__` forever, so raise AttributeError instead.
        tfms = self.__dict__.get('tfms')
        if tfms is None: raise AttributeError(k)
        for t in reversed(tfms):
            a = getattr(t, k, None)
            if a is not None: return a
        raise AttributeError(k)

A list of transforms is often applied in a particular order, and decoded by applying them in reverse order. `Pipeline` provides this functionality: transforms are sorted by their `order` attribute, and their `setup` methods are called one at a time, so each `setup` sees only the transforms added before it. NB: `setup` must be run before encoding/decoding, since transforms only become active when it is called.

Here are some simple examples:

In [None]:
# test
def add(x, a=1): return x+a
def multiply(x, a=2): return x*a
def square(x): return x**2
def add_undo(x, a=1): return x-a
def multiply_undo(x, a=2): return x/a
tadd  = Transform(encodes=add, decodes=add_undo, order=2)
tmult = Transform(encodes=multiply, decodes=multiply_undo, order=1)
tsqr  = Transform(encodes=square, order=0)

# Transforms run sorted by `order` (square, multiply, add) regardless of list position
pln = Pipeline([tadd,tmult,tsqr])
pln.setup()
start = 2
t = pln(start)
test_eq(t, ((start**2) * 2) + 1)
# Decoding runs in reverse; `tsqr` uses the default pass-through `decodes`, so we get back `start**2`
test_eq(pln.decode(t), start**2)

## TfmList

In [None]:
#export
class TfmList():
    "Apply each transform group in `ttfms` to the same input, returning a list with one result per group"
    def __init__(self, ttfms, tfm=noop):
        # NOTE(review): the `tfm` argument is never used; the comprehension variable below shadows it.
        # NOTE(review): `Transforms` is not defined in this section; confirm whether `Pipeline` is meant.
        # `activ` is the group currently being set up, or None outside of `setup`
        self.activ,self.ttfms = None,[Transforms(tfm) for tfm in listify(ttfms)]

    def __call__(self, o, **kwargs):
        "Return a list of each group applied to `o`; during `setup`, apply only the active group"
        if self.activ is not None: return self.activ(o, **kwargs)
        return [t(o, **kwargs) for t in self.ttfms]

    def decode(self, o, **kwargs):
        "Decode each element of `o` with the group at the same position"
        return [t.decode(p, **kwargs) for p,t in zip(o,self.ttfms)]

    def setup(self, o):
        "Call `setup(o)` on each group in turn, marking it as the active group while it runs"
        try:
            for tfm in self.ttfms:
                self.activ = tfm
                tfm.setup(o)
        finally:
            # Always clear the active group, even if a `setup` raises
            self.activ = None

    def show(self, o, **kwargs):
        "Display `o` via `show_xs`, passing along the transform groups"
        return show_xs(o, self.ttfms, **kwargs)
    def __repr__(self): return f'TfmList({self.ttfms})'

    @property
    def xt(self):
        "The first transform group (presumably the input transforms)"
        return self.ttfms[0]
    @property
    def yt(self):
        "The second transform group (presumably the target transforms)"
        return self.ttfms[1]