In [None]:
#default_exp core

# Core
> Core functionality for subscribing labelling functions to transforms.

In [None]:
#export
from fastai2.basics import *

## AttrProxy

In [None]:
#export
class AttrProxy(GetAttr):
    "Proxy object holding `default`; assumes fastcore `GetAttr` forwards unknown attribute lookups to `self.default` (verify). Not used elsewhere in this notebook."
    def __init__(self, default):
        self.default = default

In [None]:
#export
def _get_proxy(x):
    if x.__class__.__module__ != 'builtins': raise ValueError('Use only with builtins')
    name = 'Proxy' + x.__class__.__name__.capitalize()
    return type(name, (x.__class__,), {})(x)

In [None]:
#export
def _add_attr(obj, name, value):
    try:                   
        setattr(obj, name, getattr(obj,'labels',value))
        return obj
    # It's not possible to set attributes on builtin types, so we wrap with a proxy
    except AttributeError: return _add_attr(_get_proxy(obj), name, value)

## Maintain labels

In [None]:
#export
def _maintain_labels(old, new):
    if hasattr(old, 'labels'): new = _add_attr(new, 'labels', old.labels)
    return new

In [None]:
#export
def maintain_labels(f):
    "Wrap a `_do_call`-style function `f(fn, x, **kwargs)` so its result inherits the `labels` of input `x`."
    def _wrapped(fn, x, **kwargs):
        out = f(fn, x, **kwargs)
        return _maintain_labels(x, out)
    return _wrapped

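Patch `Pipeline.__init__` so that every transform in a pipeline carries the `labels` of its input over to its output (via `maintain_labels`).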

In [None]:
#export
# NOTE(review): consider fastcore `@delegates` so the patched signature mirrors the original `Pipeline.__init__`
_old_init = Pipeline.__init__
@patch
def __init__(self:Pipeline, *args, **kwargs):
    "Run the original `Pipeline.__init__`, then make each transform's `_do_call` propagate `labels` (wrapping each transform at most once)."
    _old_init(self, *args, **kwargs)
    for o in self.fs:
        # Transforms can be shared by several pipelines (and this cell may be re-run). Without this guard,
        # every new `Pipeline` would stack another identical wrapper onto the same transform.
        if getattr(o, '_maintains_labels', False): continue
        o._do_call = maintain_labels(o._do_call)
        o._maintains_labels = True

## Subscribe

In [None]:
#export
# TODO: Can confirm function was called without doing "res is not x"?
@typedispatch
def subscribe(tfm:Transform):
    "Return a decorator that runs `f` on each output of `tfm` (with `labels` maintained) while `tfm.broadcast` is truthy."
    prev_call = tfm._do_call
    tfm.broadcast = True
    def _register(f):
        def _call(fn, x, **kwargs):
            out = _maintain_labels(x, prev_call(fn, x, **kwargs))
            # The identity check stands in for "the transform was actually applied" (see TODO above);
            # `tfm.broadcast` is read on every call so it can be toggled later.
            if tfm.broadcast and out is not x: out = f(out)
            return out
        tfm._do_call = _call
        return f
    return _register

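Subscriptions can be turned off (and back on) for every transform in a pipeline. The flag is stored on the transforms themselves, so it also affects any other pipeline that shares them.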

In [None]:
#export
@patch
def broadcast(self:Pipeline, v):
    "Enable (`v=True`) or disable (`v=False`) subscriptions on every transform of this pipeline."
    # The flag is set on the transform objects, so pipelines sharing a transform are affected too
    for tfm in self.fs:
        tfm.broadcast = v

## Labeller

In [None]:
#export
class Labeller:
    "Decorator factory: `@labeller(tfm)` subscribes a labelling function to `tfm` and records its output in `x.labels`."
    def __init__(self, abstain='abstain'):
        # Label recorded when a labelling function returns `None`
        self.abstain = abstain

    def __call__(self, tfm):
        "Return a decorator that subscribes a labelling function `f` to transform `tfm`."
        def _inner(f):
            return subscribe(tfm)(self._add_label(f))
        return _inner

    def _add_label(self, f):
        "Wrap labelling function `f` so it appends `f(x)` (or `self.abstain` for `None`) to `x.labels` and returns `x`."
        from functools import wraps
        # `wraps` keeps the labelling function's name/docstring on the object bound by the decorator
        @wraps(f)
        def _inner(x):
            label = ifnone(f(x), self.abstain)
            x = _add_attr(x, 'labels', [])
            # Rebind instead of `append`: `_maintain_labels` gives a transform's output the *same* list object
            # as its input, so an in-place append would also change the input's labels.
            x.labels = [*x.labels, label]
            return x
        return _inner

Test `Labeller` with arbitrary transforms.

In [None]:
CAT1 = 'cat1'
CAT2 = 'cat2'

In [None]:
@Transform
def neg(x:Tensor):
    "Negate `x`; presumably only applied to `Tensor` inputs via the `Transform` type annotation."
    return -x

class IntDiv(Transform):
    "Halve integers using floor division."
    def encodes(self, x:int): return x//2

In [None]:
int_div = IntDiv()
labeller = Labeller()  # default abstain value: 'abstain'

In [None]:
@labeller(neg)
def labeller_cat1(x): return CAT1
@labeller(neg)
def labeller_cat2(x): return CAT2
# Distinct name: reusing `labeller_cat1` here would silently shadow the `neg` labeller defined above
@labeller(int_div)
def labeller_int_div_cat1(x): return CAT1

In [None]:
pipe = Pipeline(neg)
# Both `neg` labellers fire, in subscription order
test_eq(pipe(tensor(2)).labels, [CAT1, CAT2])

In [None]:
pipe.broadcast(False)
# With broadcasting off no labeller runs, so the output has no `labels` attribute
test_fail(lambda: pipe(tensor(2)).labels, contains="'Tensor' object has no attribute 'labels'")
# Re-enable: the flag lives on the shared `neg` transform and would otherwise stay off in later cells
pipe.broadcast(True)

In [None]:
pipe = Pipeline([neg, int_div])
# `neg` does not apply to an int, so only the `int_div` labeller contributes
test_eq(pipe(2).labels, [CAT1])

## Export -

In [None]:
# Export the `#export` cells of the notebooks in this folder to library modules (the output lists each converted notebook)
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_data.transforms.ipynb.
Converted 03_model.majority_label_voter.ipynb.
Converted 05_text.core.ipynb.
Converted 06_text.labellers.ipynb.
Converted Untitled-Copy1.ipynb.
Converted index.ipynb.
Converted resume-Copy1.ipynb.
Converted resume.ipynb.
Converted rx_transform.ipynb.
Converted rx_transform2-Copy1.ipynb.
Converted rx_transform2.ipynb.
