In [None]:
#default_exp core

# Core
> Core functionality of Subscribe: cached `TfmdLists`, propagation of `labels` through transforms, and transform subscriptions.

In [None]:
#export
from fastai2.basics import *
from functools import wraps

## Utils

In [None]:
#export
def instantiate(o):
    "Call `o` with no arguments when it is a class; otherwise hand `o` back untouched"
    if isinstance(o, type): return o()
    return o

In [None]:
#export
def split_batch(dl, b):
    "Split batch `b` into `(inputs, targets)`, each passed through `detuplify`; returns a lazy `map`"
    # Fallback when `dl` has no `n_inp`: a 1-element batch is all input, otherwise the last element is the target
    fallback = len(b) - 1 if len(b) != 1 else 1
    n_inp = getattr(dl, 'n_inp', fallback)
    parts = (b[:n_inp], b[n_inp:])
    return map(detuplify, parts)

## Cache TfmdLists

In [None]:
#export
_old_tfmdlists_init = TfmdLists.__init__
@patch
def __init__(self:TfmdLists, items, tfms, *args, **kwargs):
    "Initialise exactly like the original `TfmdLists`, with caching initially disabled"
    # Assign the flag *before* delegating: the patched `__getitem__` and `_new` read `cached`,
    # and the original init may already index or subset `self` (e.g. during transform setup — verify)
    self.cached = False
    _old_tfmdlists_init(self, items, tfms, *args, **kwargs)

In [None]:
#export
_old_tfmdlists_new = TfmdLists._new
@patch
def _new(self:TfmdLists, items, **kwargs):
    "Create a `TfmdLists` over `items` via the original `_new`, forwarding `kwargs` and copying the `cached` flag"
    # Forward every keyword argument: dropping them loses caller-supplied options (e.g. a split index)
    tls = _old_tfmdlists_new(self, items, **kwargs)
    # `cached` may not exist yet if `_new` is reached while `self.__init__` is still running
    tls.cached = getattr(self, 'cached', False)
    return tls

In [None]:
#export
@patch
def cache(self:TfmdLists, tfms=None, pbar=True):
    "Iterate `self` once, pass each element through `Pipeline(tfms)`, store the results as `self.items` and mark it cached"
    extra_tfms = Pipeline(tfms)
    # Optionally show progress while materialising every element
    source = progress_bar(self) if pbar else self
    self.items = list(map(extra_tfms, source))
    self.cached = True

In [None]:
#export
_old_getitem = TfmdLists.__getitem__
@patch
def __getitem__(self:TfmdLists, idx):
    "Once cached, return the stored items directly; otherwise defer to the original `TfmdLists.__getitem__`"
    # `cached` may be missing if indexing happens inside `__init__` before the flag is assigned
    if getattr(self, 'cached', False):
        # Explicit two-argument `super`: this function is defined outside the class body
        return super(TfmdLists, self).__getitem__(idx)
    return _old_getitem(self, idx)

In [None]:
splits = [[0],[1]]
# Integer upper bound: float arguments to `random.randint` are deprecated since Python 3.10 and raise TypeError from 3.12
lazy = TfmdLists([1,2], [partial(random.randint, b=10**6)], splits=splits)
# Not cached yet: each lookup re-runs the random transform, so two reads of the same item differ
test_ne(lazy[0], lazy[0])
test_ne(lazy.valid[0], lazy.valid[0])
test_ne(lazy.valid[0], lazy.valid[0])

In [None]:
lazy.cache(pbar=False)
# After caching, lookups return the stored values, so repeated reads must agree
first_train, first_valid = lazy[0], lazy.valid[0]
test_eq(first_train, lazy[0])
test_eq(first_valid, lazy.valid[0])
# Train and valid items were drawn independently, so they should still differ
test_ne(lazy.train[0], first_valid)

## AttrProxy

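Instances of _builtins_ types (e.g. `int`, `str`) don't accept new attributes. `add_attr` therefore wraps them in a proxy: `_get_proxy` instantiates a dynamically created subclass of the object's type. `AttrProxy` is a `GetAttr`-based wrapper intended for the same purpose.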

In [None]:
#export
class AttrProxy(GetAttr):
    "`GetAttr` subclass that keeps the wrapped object in its `default` attribute"
    def __init__(self, default):
        # presumably `GetAttr` forwards unknown attribute lookups to `default` — confirm `GetAttr._default`
        self.default = default

In [None]:
#export
def _get_proxy(x):
    if x.__class__.__module__ != 'builtins': raise ValueError('Use only with builtins')
    name = 'Proxy' + x.__class__.__name__.capitalize()
    return type(name, (x.__class__,), {})(x)

In [None]:
#export
def add_attr(obj, name, value):
    "Set attribute `name` on `obj` to `value` unless `obj` already has one; return `obj`, or a proxy copy if `obj` rejects attributes"
    try:
        # Keep an existing value of `name` rather than overwriting it
        setattr(obj, name, getattr(obj, name, value))
        return obj
    except AttributeError:
        # Builtin instances (e.g. int, str) can't take attributes: retry on a proxy built by `_get_proxy`
        return add_attr(_get_proxy(obj), name, value)

## Maintain labels

Propagate the `labels` attribute from a transform's input to its output, so labels are not lost when objects are transformed.

In [None]:
#export
def _maintain_labels(old, new):
    if hasattr(old, 'labels'): new = add_attr(new, 'labels', old.labels)
    return new

In [None]:
#export
def maintain_labels(f):
    "Wrap a `_do_call`-style function `f(fn, x, **kwargs)` so its result keeps `x.labels` (via `_maintain_labels`)"
    # Already wrapped: return it unchanged so repeated wrapping (e.g. when several
    # `Pipeline`s share transform objects) doesn't stack ever more wrapper layers
    if getattr(f, '_maintains_labels', False): return f
    @wraps(f)
    def _inner(fn, x, **kwargs):
        return _maintain_labels(x, f(fn, x, **kwargs))
    # Marker checked above to make this decorator idempotent
    _inner._maintains_labels = True
    return _inner

Patch `Pipeline` so that every transform's `_do_call` is wrapped with `maintain_labels`.

In [None]:
#export
# TODO: figure out delegates (expose the original signature instead of *args/**kwargs)
_old_pipe_init = Pipeline.__init__
@patch
def __init__(self:Pipeline, *args, **kwargs):
    "Build the `Pipeline` with the original init, then wrap each transform's `_do_call` with `maintain_labels`"
    _old_pipe_init(self, *args, **kwargs)
    for tfm in self.fs:
        wrapped = maintain_labels(tfm._do_call)
        tfm._do_call = wrapped

## Subscribe

`subscribe` is a decorator that injects an arbitrary function to run on a transform's output every time the subscribed transform is called.

In [None]:
#export
# TODO: Can confirm function was called without doing "res is not x"?
class subscribe:
    "Decorator that runs a function on a transform's output every time the transform's `_do_call` is invoked"
    def __init__(self, tfm, func_order=None):
        # tfm: object whose `_do_call` is replaced
        # func_order: optional list; the subscriber's name is appended on every listened call
        store_attr(self, 'tfm,func_order')
        # Remember the current `_do_call` so we can chain to it and restore it in `cancel`
        self.old_call,self.listen = tfm._do_call,True

    def __call__(self, f):
        def _call(fn, x, **kwargs):
            res = self.old_call(fn, x, **kwargs)
            res = _maintain_labels(x, res)
            # `listen` can be set to False to switch this subscription off
            if self.listen:
                if self.func_order is not None: self.func_order.append(f.__name__)
                # Only post-process when the transform returned a different object
                if res is not x: res = f(res)
            return res
        self.tfm._do_call = _call
        return f

    def cancel(self):
        "Undo this subscription by restoring the `_do_call` that was active when it was created"
        # Restoring `old_call` (rather than the base `Transform._do_call`) keeps any earlier
        # wrappers, such as `maintain_labels`, and respects subclass overrides of `_do_call`
        self.tfm._do_call = self.old_call

It's possible to turn subscriptions off.  
**Deprecated**: the `Pipeline.broadcast` patch below (kept commented out); use `labeller.listen` instead.

In [None]:
#export
# @patch
# def broadcast(self:Pipeline, v):
#     for f in self.fs: f.broadcast = v

## Export -

In [None]:
# Export the `#export` cells of the project's notebooks to library modules (see the "Converted ..." output below)
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 02_labeller.core.ipynb.
Converted 02a_labeller.metrics.ipynb.
Converted 03_model.majority_label_voter.ipynb.
Converted 05_text.core.ipynb.
Converted 06_text.labellers.ipynb.
Converted index.ipynb.
