In [None]:
#default_exp labeller.core

In [None]:
#export
from fastai2.basics import *
from pigboat.basics import *
from functools import wraps

# Labeller
> `Labeller` wraps `subscribe`: every function it decorates is subscribed to a transform, and its return value (or `abstain` when it returns `None`) is appended to a list attribute called `labels` on the object being transformed.

In [None]:
#export
class UniqueList(L):
    "An `L` whose `append` ignores an item equal to one already stored (only `append` is overridden)."
    def append(self, o):
        # Linear membership test on the underlying items; duplicates are dropped silently.
        if o in self.items: return
        super().append(o)

In [None]:
#export
class Labeller:
    """Decorator factory that attaches labelling functions to transforms.

    `labeller(tfm)` returns a decorator. The decorated function is wrapped so that
    its return value -- or `abstain` when it returns `None` -- is appended to the
    `labels` attribute of its argument. The wrapper is then registered through
    `subscribe(tfm, self.func_order)`, and the decorated name is bound to whatever
    that subscription returns.
    """
    def __init__(self, abstain='abstain'):
        # Label recorded whenever a labelling function returns `None`.
        self.abstain = abstain
        # Handed to every subscription; emptied on each new registration and on `reset`.
        self.func_order = UniqueList()
        # Every subscription created by this labeller, so they can be toggled or cancelled.
        self.subs = L()

    def __call__(self, tfm):
        "Return a decorator that subscribes the decorated labelling function to `tfm`."
        def _register(func):
            self.func_order.clear()
            subscription = subscribe(tfm, self.func_order)
            self.subs.append(subscription)
            labelled = self._add_label(func)
            return subscription(labelled)
        return _register

    def reset(self):
        "Cancel and forget every subscription, then empty `func_order`."
        for subscription in self.subs:
            subscription.cancel()
        self.subs.clear()
        self.func_order.clear()

    def listen(self, v):
        "Set the `listen` flag of every subscription to `v`."
        for subscription in self.subs:
            subscription.listen = v

    def _add_label(self, f):
        "Wrap `f` so its result (or `self.abstain` if it is `None`) is appended to `x.labels`."
        @wraps(f)
        def _labelled(x):
            result = f(x)
            label = self.abstain if result is None else result
            # NOTE(review): judging by the tests, `add_attr` keeps an existing `labels`
            # list and only supplies `[]` when it is missing -- confirm in pigboat.
            x = add_attr(x, 'labels', [])
            x.labels.append(label)
            return x
        return _labelled

Test `Labeller` with arbitrary transforms.

In [None]:
# Category names returned by the test labelling functions.
CAT1 = 'cat1'
CAT2 = 'cat2'

In [None]:
@Transform
def neg(x:Tensor):
    "Negate `x`; registered for `Tensor` inputs only."
    return -x

class IntDiv(Transform):
    "Floor-divide `int` inputs by 2."
    def encodes(self, x:int):
        return x // 2

In [None]:
# One shared `Labeller` for the tests below, plus an `IntDiv` instance to subscribe to.
labeller, int_div = Labeller(), IntDiv()

In [None]:
# Register three labelling functions: two subscribe to `neg`, one to `int_div`.
# Their names and registration order are checked by the tests below
# (`func_order` and the order of `labels`).
@labeller(neg)
def labeller_cat1(x): return CAT1
@labeller(neg)
def labeller_cat2(x): return CAT2
# Subscribed to `int_div`, so it only fires in pipelines that include that transform.
@labeller(int_div)
def labeller_cat3(x): return CAT1

In [None]:
# Running a tensor through `neg` triggers both `neg` subscribers, in registration order.
pipe = Pipeline(neg)
res = pipe(tensor(2))
test_eq(res.labels, ['cat1', 'cat2'])
test_eq(labeller.func_order, ['labeller_cat1', 'labeller_cat2'])

In [None]:
# With listening switched off, no labels are attached, so reading `.labels` on the
# output must raise. The expected message is checked via `contains=`; the stray
# `['cat1', 'cat2']` previously passed positionally as `msg` has been dropped.
labeller.listen(False)
test_fail(lambda: pipe(tensor(2)).labels, contains="'Tensor' object has no attribute 'labels'")

In [None]:
# Re-enable listening. An int is not dispatched to `neg` (it only handles `Tensor`),
# then `int_div` halves it and its single subscriber adds one label.
labeller.listen(True)
pipe = Pipeline([neg, int_div])
out = pipe(2)
test_eq(out.labels, ['cat1'])
# TODO: `func_order` should be ['labeller_cat3'] here, but names from earlier runs
# remain in the list, so this check currently fails:
# test_eq(labeller.func_order, ['labeller_cat3'])

In [None]:
# After `reset` every subscription is cancelled, so the output carries no `labels`.
# The expected error text must go to `contains=`: passed positionally it became
# `msg` and the message was never actually checked.
labeller.reset()
pipe = Pipeline([neg, int_div])
test_fail(lambda: pipe(tensor(2)).labels, contains="'Tensor' object has no attribute 'labels'")

## Tasks labels helper

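Build a `TfmdLists` that extracts the `labels` attribute of each item in `tls` and encodes it with `MultiCategorize(vocab)`; the result is cached unless `lazy=True`.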

In [None]:
#export
def tasks_labels(tls, vocab, splits=None, lazy=False):
    "`TfmdLists` of each item's `labels`, multi-categorized over `vocab`; cached unless `lazy`."
    tfms = [AttrGetter('labels'), MultiCategorize(vocab)]
    tasks = TfmdLists(tls, tfms, splits=splits)
    if lazy: return tasks
    # Eagerly compute and store the transformed items.
    tasks.cache()
    return tasks

## Export -

In [None]:
# Convert the project's notebooks into library modules with nbdev
# (the cell output lists each converted notebook).
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_data.transforms.ipynb.
Converted 02_labeller.core.ipynb.
Converted 03_model.majority_label_voter.ipynb.
Converted 05_text.core.ipynb.
Converted 06_text.labellers.ipynb.
Converted Untitled-Copy1.ipynb.
Converted index.ipynb.
Converted resume-Copy1.ipynb.
Converted resume.ipynb.
Converted rx_transform.ipynb.
Converted rx_transform2-Copy1.ipynb.
Converted rx_transform2.ipynb.
