In [None]:
#default_exp core

# Core
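> Core functionality: caching `TfmdLists`, carrying `labels` through transforms, and subscribing labelling functions to transforms.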

In [None]:
#export
from fastai2.basics import *
from functools import wraps

## Cache TfmdLists

In [None]:
#export
# Keep a handle on the original initializer before patching; the patch delegates to it.
_old_tfmdlists_init = TfmdLists.__init__
@patch
def __init__(self:TfmdLists, items, tfms, *args, **kwargs):
    "Run the original `TfmdLists.__init__` (forwarding any extra positional/keyword args), then mark the list as not cached."
    _old_tfmdlists_init(self, items, tfms, *args, **kwargs)
    # `cache()` flips this to True; the patched `__getitem__` reads it to skip the transforms.
    self.cached = False

In [None]:
#export
# Keep a handle on the original `_new` before patching; the patch delegates to it.
_old_tfmdlists_new = TfmdLists._new
@patch
def _new(self:TfmdLists, items, **kwargs):
    "Build a new `TfmdLists` via the original `_new` (forwarding all `kwargs`) and carry over `self`'s cached flag."
    # Forward `kwargs`: previously they were silently dropped, so callers' options never
    # reached the original `_new`.
    tls = _old_tfmdlists_new(self, items, **kwargs)
    # Subsets/slices of a cached list hold already-transformed items, so they must stay cached.
    tls.cached = getattr(self, 'cached', False)
    return tls

In [None]:
#export
@patch
def cache(self:TfmdLists, tfms=None, pbar=True):
    """Materialize the (transformed) items once and serve them from memory afterwards.

    `tfms` is an optional extra `Pipeline` applied to each item before it is stored.
    `pbar` toggles the progress bar shown while iterating. Sets `self.cached = True`.
    """
    extra = Pipeline(tfms)
    source = progress_bar(self) if pbar else self
    self.items = [extra(item) for item in source]
    self.cached = True

In [None]:
#export
# Keep a handle on the original (transforming) lookup before patching.
_old_getitem = TfmdLists.__getitem__
@patch
def __getitem__(self:TfmdLists, idx):
    "Return stored items directly when cached; otherwise use the original transforming lookup."
    # `getattr` default: an instance can be indexed before `cached` is assigned (e.g. while the
    # original `__init__` is still running its setup), which would otherwise raise AttributeError.
    if getattr(self, 'cached', False): return super(TfmdLists, self).__getitem__(idx)
    return _old_getitem(self, idx)

In [None]:
splits = [[0],[1]]
# An un-cached TfmdLists re-runs its (random) transform on every access.
# The bound must be an int: `random.randint` rejects float bounds such as `1e6` on Python 3.12+.
lazy = TfmdLists([1,2], [partial(random.randint, b=10**6)], splits=splits)
test_ne(lazy[0], lazy[0])
test_ne(lazy.valid[0], lazy.valid[0])

In [None]:
# After caching, indexing returns the stored results, so repeated lookups agree...
lazy.cache(pbar=False)
test_eq(lazy[0], lazy[0])
test_eq(lazy.valid[0], lazy.valid[0])
# ...and the train and valid subsets still hold their own (expected to differ) values.
test_ne(lazy.train[0], lazy.valid[0])

## AttrProxy

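Instances of built-in types (e.g. `int`, `str`) don't accept new attributes, so they must be wrapped before a `labels` attribute can be attached. `_get_proxy` does this by rebuilding the value as an instance of a subclass of its type; `AttrProxy` is a `GetAttr`-based wrapper around a value.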

In [None]:
#export
class AttrProxy(GetAttr):
    "Thin `GetAttr` wrapper holding the wrapped value in `default` (presumably the attribute `GetAttr` delegates to — confirm)."
    def __init__(self, default):
        self.default = default

In [None]:
#export
def _get_proxy(x):
    if x.__class__.__module__ != 'builtins': raise ValueError('Use only with builtins')
    name = 'Proxy' + x.__class__.__name__.capitalize()
    return type(name, (x.__class__,), {})(x)

In [None]:
#export
def _add_attr(obj, name, value):
    try:                   
        setattr(obj, name, getattr(obj,'labels',value))
        return obj
    except AttributeError: return _add_attr(_get_proxy(obj), name, value)

## Maintain labels

Carry the `labels` attribute over from a transform's input to its output, so labels survive each step of a `Pipeline`.

In [None]:
#export
def _maintain_labels(old, new):
    if hasattr(old, 'labels'): new = _add_attr(new, 'labels', old.labels)
    return new

In [None]:
#export
def maintain_labels(f):
    """Wrap a `_do_call`-style callable `f(fn, x, **kwargs)` so its result keeps `x`'s `labels`.

    Idempotent: passing a callable already produced by this function returns it unchanged.
    This matters because the patched `Pipeline.__init__` wraps each transform's `_do_call`
    every time a `Pipeline` is built over it; without the guard, wrappers stack up without bound.
    """
    if getattr(f, '_maintains_labels', False): return f
    def _inner(fn, x, **kwargs):
        return _maintain_labels(x, f(fn, x, **kwargs))
    # Marker checked above to avoid re-wrapping.
    _inner._maintains_labels = True
    return _inner

Patch `Pipeline` so that every transform's `_do_call` is wrapped with `maintain_labels`.

In [None]:
#export
# figure out delegates
# TODO(review): presumably refers to fastcore `delegates`; the patched `__init__` only shows
# `*args, **kwargs`, so the original `Pipeline.__init__` signature is hidden from introspection.
# Keep a handle on the original initializer before patching; the patch delegates to it.
_old_pipe_init = Pipeline.__init__
@patch
def __init__(self:Pipeline, *args, **kwargs):
    "Build the pipeline as usual, then wrap each transform's `_do_call` with `maintain_labels`."
    _old_pipe_init(self, *args, **kwargs)
    # Mutates the transform objects themselves, so the wrapping is shared by every
    # `Pipeline` that contains the same transform.
    for o in self.fs: o._do_call = maintain_labels(o._do_call)

## Subscribe

`subscribe` injects an arbitrary function that runs on a transform's output every time the subscribed transform is called.

In [None]:
#export
# TODO: Can confirm function was called without doing "res is not x"?
@typedispatch
def subscribe(tfm:Transform, func_order=None):
    """Return a decorator that makes `f` run on `tfm`'s output every time `tfm` is called.

    `f` receives the transformed result (with labels carried over from the input) and its
    return value replaces that result. `f` only runs while `tfm.broadcast` is truthy, and only
    when the call produced a new object. If `func_order` is given, `f.__name__` is appended to
    it on each broadcasting call. The decorator returns `f` unchanged.
    Calling `subscribe` (re-)enables `tfm.broadcast`.
    """
    tfm.broadcast = True
    def _inner(f):
        # Capture the current `_do_call` when `f` is attached, not when `subscribe` was called.
        # Otherwise two decorators created before being applied would both wrap the same
        # original call, and the later one would discard the earlier subscription.
        old_call = tfm._do_call
        def _call(fn, x, **kwargs):
            res = old_call(fn, x, **kwargs)
            res = _maintain_labels(x, res)
            if tfm.broadcast:
                if func_order is not None: func_order.append(f.__name__)
                # An unchanged object is taken to mean the transform didn't apply (see TODO above).
                if res is not x: res = f(res)
            return res
        tfm._do_call = _call
        return f
    return _inner

Subscriptions can be turned off (and back on) for every transform in a `Pipeline` with `Pipeline.broadcast`.

In [None]:
#export
@patch
def broadcast(self:Pipeline, v):
    "Enable (`v=True`) or disable (`v=False`) subscribed functions on every transform in this pipeline."
    for tfm in self.fs:
        tfm.broadcast = v

## Labeller

`Labeller` wraps `subscribe` and appends the value returned by each wrapped function to an attribute called `labels` on the transform's output.

In [None]:
#export
class UniqueList(L):
    "An `L` whose `append` skips items already present (other ways of adding items are not deduplicated)."
    def append(self, o):
        if o in self.items: return
        super().append(o)

In [None]:
#export
class Labeller:
    "Decorator factory: subscribes labelling functions to a `Transform` and appends their results to `labels`."
    def __init__(self, abstain='abstain'):
        # Names of subscribed labelling functions, each recorded once (see `subscribe`).
        self.func_order = UniqueList()
        # Label stored when a labelling function returns `None`.
        self.abstain = abstain

    def __call__(self, tfm):
        "Return a decorator that subscribes a labelling function to `tfm`."
        def _decorate(f):
            labelled = self._add_label(f)
            return subscribe(tfm, self.func_order)(labelled)
        return _decorate

    def _add_label(self, f):
        "Wrap `f` so its result (or `self.abstain` if it returns `None`) is appended to `x.labels`."
        @wraps(f)
        def _labelled(x):
            label = f(x)
            if label is None: label = self.abstain
            # `_add_attr` may return a proxy copy of `x` (builtins), so keep its return value.
            x = _add_attr(x, 'labels', [])
            x.labels.append(label)
            return x
        return _labelled

Test `Labeller` with a couple of toy transforms.

In [None]:
CAT1 = 'cat1'
CAT2 = 'cat2'

In [None]:
@Transform
def neg(x:Tensor):
    "Toy transform: negates `Tensor` inputs (dispatch is driven by the annotation)."
    return -x
class IntDiv(Transform):
    "Toy transform: halves `int` inputs with floor division (dispatch is driven by the annotation)."
    def encodes(self, x:int): return x//2

In [None]:
# Toy transform instance, plus the labeller whose functions are registered in the next cell.
int_div = IntDiv()
labeller = Labeller()

In [None]:
# Two labelling functions on `neg` and one on `int_div`. Each subscription wraps the
# previous one, so labels are appended in registration order.
@labeller(neg)
def labeller_cat1(x): return CAT1
@labeller(neg)
def labeller_cat2(x): return CAT2
@labeller(int_div)
def labeller_cat3(x): return CAT1

In [None]:
pipe = Pipeline(neg)
# Both labellers subscribed to `neg` fire, in registration order.
test_eq(pipe(tensor(2)).labels, [CAT1, CAT2])

In [None]:
pipe.broadcast(False)
# With subscriptions off no labeller runs, so the output never gets a `labels` attribute.
# fastcore's `test_fail(f, msg='', contains='')`: pass the expected error text as `contains`
# (the stray `['cat1', 'cat2']` list was being passed as `msg`).
test_fail(lambda: pipe(tensor(2)).labels, contains="'Tensor' object has no attribute 'labels'")
# Re-enable subscriptions so `neg` doesn't stay silenced for later cells.
pipe.broadcast(True)

In [None]:
pipe = Pipeline([neg, int_div])
# `neg` only encodes `Tensor`s, so it leaves the int unchanged; only `int_div`'s labeller fires.
test_eq(pipe(2).labels, [CAT1])

## Tasks labels helper

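Extract the `labels` attribute of every item in a `TfmdLists` and encode it with `MultiCategorize(vocab)`; the result is cached unless `lazy=True`.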

In [None]:
#export
def tasks_labels(tls, vocab, splits=None, lazy=False):
    """Build a `TfmdLists` of each item's `labels`, encoded with `MultiCategorize(vocab)`.

    `splits` is forwarded to `TfmdLists`. Unless `lazy` is True, the result is cached
    (with a progress bar) before being returned.
    """
    label_tfms = [AttrGetter('labels'), MultiCategorize(vocab)]
    tasks = TfmdLists(tls, label_tfms, splits=splits)
    if lazy: return tasks
    tasks.cache()
    return tasks

## Export -

In [None]:
# Export the `#export` cells of the project's notebooks to the library modules (see output below).
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_data.transforms.ipynb.
Converted 03_model.majority_label_voter.ipynb.
Converted 05_text.core.ipynb.
Converted 06_text.labellers.ipynb.
Converted Untitled-Copy1.ipynb.
Converted index.ipynb.
Converted resume-Copy1.ipynb.
Converted resume.ipynb.
Converted rx_transform.ipynb.
Converted rx_transform2-Copy1.ipynb.
Converted rx_transform2.ipynb.
