In [None]:
#| hide
#| eval: false
! [ -e /content ] && pip install -Uqq xcube  # install/upgrade xcube only when running on Colab (detected by the presence of /content)

In [None]:
#| export 
from fastai.data.all import *
from fastai.text.models.core import *
from fastai.text.models.awdlstm import *
from xcube.layers import *

In [None]:
#| default_exp text.models.core

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
# Per-architecture metadata looked up by `get_xmltext_classifier`: the config key holding the hidden size,
# the WT103 forward/backward URLs, and the default configs/splitters for the LM and classifier variants.
_model_meta = {
    AWD_LSTM: dict(
        hid_name='emb_sz',
        url=URLs.WT103_FWD,
        url_bwd=URLs.WT103_BWD,
        config_lm=awd_lstm_lm_config,
        split_lm=awd_lstm_lm_split,
        config_clas=awd_lstm_clas_config,
        split_clas=awd_lstm_clas_split,
    ),
}

# Core XML Text Modules
> Contains the modules needed to build different XML architectures, and the generic functions to create those models.

The models provided here are variations of the ones provided by [fastai](https://docs.fast.ai/text.models.core.html) with modifications tailored for XML.

## Basic Models

In [None]:
#| export
class SequentialRNN(nn.Sequential):
    "An `nn.Sequential` whose `reset` is forwarded to every direct child that defines one."
    def reset(self):
        for child in self.children():
            # children without a `reset` fall back to `noop`, so calling is always safe
            child_reset = getattr(child, 'reset', noop)
            child_reset()

## Classification Models

In [None]:
#| export
def _pad_tensor(t, bs):
    if t.size(0) < bs: return torch.cat([t, t.new_zeros(bs-t.size(0), *t.shape[1:])])
    return t

In [None]:
#| export
class AttentiveSentenceEncoder(Module):
    "Encode a full sentence by feeding `module` consecutive `bptt`-sized chunks; returns all kept outputs plus their padding mask."
    def __init__(self, bptt, module, pad_idx=1, max_len=None): store_attr('bptt,module,pad_idx,max_len')
    def reset(self): getattr(self.module, 'reset', noop)()

    def forward(self, input):
        """Encode `input` of size (bs, sl) chunk by chunk.

        Returns `(outs, mask)`: `outs` concatenates (along dim 1) the outputs of the kept chunks, each
        zero-padded back to `bs` rows, and `mask` is True where `input == pad_idx` in those chunks.
        When `max_len` is set, only chunks starting within the last `max_len` positions are kept.
        """
        bs, sl = input.size()
        self.reset()
        pad_mask = input == self.pad_idx
        kept_outs, kept_masks = [], []
        for start in range(0, sl, self.bptt):
            stop = min(start + self.bptt, sl)
            # Number of rows that are not padding at `start`; assumes those rows come first in the batch and
            # that the sequence really begins on a round multiple of bptt.
            n_active = (input[:, start] != self.pad_idx).long().sum()
            chunk_out = self.module(input[:n_active, start:stop])
            # Every chunk goes through `module` (so any recurrent state it keeps advances); only recent ones are kept.
            if self.max_len is not None and sl - start > self.max_len: continue
            kept_outs.append(chunk_out)
            kept_masks.append(pad_mask[:, start:stop])
        outs = torch.cat([_pad_tensor(o, bs) for o in kept_outs], dim=1)
        return outs, torch.cat(kept_masks, dim=1)

:::{.callout-warning}

This module expects the inputs padded with most of the padding first, with the sequence beginning at a round multiple of bptt (and the rest of the padding at the end). Use `pad_input_chunk` to get your data in a suitable format.

:::

In [None]:
#| export
def masked_concat_pool(output, mask, bptt):
    "Pool encoder `output` into one vector per row, [last_hidden, max_pool, avg_pool], ignoring positions where `mask` is True."
    expanded_mask = mask[:, :, None]
    # number of real (non-padding) tokens per row
    n_tokens = output.shape[1] - mask.long().sum(dim=1)
    # trailing padding is assumed to lie within the final `bptt` positions
    n_trailing_pad = mask[:, -bptt:].long().sum(dim=1)
    summed = output.masked_fill(expanded_mask, 0).sum(dim=1)
    avg_pool = summed / n_tokens.type(summed.dtype)[:, None]
    max_pool = output.masked_fill(expanded_mask, -float('inf')).max(dim=1)[0]
    rows = torch.arange(0, output.size(0))
    last_hidden = output[rows, -n_trailing_pad - 1]  # hidden state of the last real token
    return torch.cat([last_hidden, max_pool, avg_pool], 1)

In [None]:
#| export
class OurPoolingLinearClassifier(Module):
    "Concat-pool the encoder outputs and map them to `dims[1]` scores with a single `LinBnDrop` layer (dropout `ps`)."
    def __init__(self, dims, ps, bptt, y_range=None):
        self.layer = LinBnDrop(dims[0], dims[1], p=ps, act=None)
        self.bptt = bptt
        # Squash scores into `y_range` when given, as `LabelAttentionClassifier` does (it used to be ignored silently).
        self.y_range = y_range

    def forward(self, input):
        "`input` is the `(outputs, mask)` pair from the encoder; returns `(scores, outputs, outputs)`."
        out, mask = input
        x = masked_concat_pool(out, mask, self.bptt)
        x = self.layer(x)
        if self.y_range is not None: x = sigmoid_range(x, *self.y_range)
        return x, out, out

Note that `OurPoolingLinearClassifier` is exactly the same as fastai's [`PoolingLinearClassifier`](https://docs.fast.ai/text.models.core.html#poolinglinearclassifier), except that we skip the intermediate feature compression from 1200 to 50 linear features. 

**TODO:** Also try `OurPoolingLinearClassifier` without dropout and batch normalization. (Still to be verified, but in my experiments so far it did not work as well as the version with batch normalization.)

In [None]:
#| export
class LabelAttentionClassifier(Module):
    "Label-wise attention head: each of `n_lbs` labels pools the encoder outputs via `XMLAttention` and is scored with its own row of `final_lin`."
    initrange=0.1  # half-width of the uniform range used to initialise `final_lin`'s weights
    def __init__(self, n_hidden, n_lbs, y_range=None):
        store_attr('n_hidden,n_lbs,y_range')
        self.pay_attn = XMLAttention(self.n_lbs, self.n_hidden)
        self.final_lin = nn.Linear(self.n_hidden, self.n_lbs)
        uniform_init = partial(torch.nn.init.uniform_, a=-self.initrange, b=self.initrange)
        init_default(self.final_lin, func=uniform_init)

    def forward(self, input):
        "`input` is an `(outputs, mask)` pair; returns `(scores, outputs, outputs)` with `scores` of shape (bs, n_lbs)."
        # NOTE(review): the padding mask is discarded here, so `XMLAttention` also sees padded positions — confirm intended.
        out, _mask = input
        ctx = self.pay_attn(out)  # shape (bs, n_lbs, n_hidden)
        # Row-wise dot product: label j's context vector with row j of `final_lin.weight`, plus that label's bias.
        logits = (ctx * self.final_lin.weight).sum(dim=2) + self.final_lin.bias
        if self.y_range is None: return logits, out, out
        return sigmoid_range(logits, *self.y_range), out, out

In [None]:
attn_clas = LabelAttentionClassifier(400, 1271)
test_eq(getattrs(attn_clas, 'n_hidden', 'n_lbs'), (400, 1271))
# The encoder yields a boolean padding mask (`input == pad_idx`), so the fixture uses the same dtype.
outs, mask = (torch.randn(16, 72*20, 400), torch.randint(2, size=(16, 72*20)).bool())
x, *_ = attn_clas((outs, mask))
test_eq(x.shape, (16, 1271))

# With `y_range`, scores are squashed into that range.
attn_clas_rng = LabelAttentionClassifier(400, 1271, y_range=(0, 1))
x_rng, *_ = attn_clas_rng((outs[:2], mask[:2]))
test_eq(x_rng.shape, (2, 1271))
assert ((x_rng >= 0) & (x_rng <= 1)).all()

In [None]:
#| export
def get_xmltext_classifier(arch, vocab_sz, n_class, seq_len=72, config=None, drop_mult=1., pad_idx=1, max_len=72*20, y_range=None):
    """Create an XML text classifier: `arch` wrapped in an `AttentiveSentenceEncoder`, followed by a `LabelAttentionClassifier`.

    - `arch`: encoder architecture; must be a key of `_model_meta` (e.g. `AWD_LSTM`).
    - `vocab_sz`: vocabulary size passed to `arch`.
    - `n_class`: number of labels scored by the head.
    - `seq_len`: chunk length (bptt) the sentence encoder feeds to `arch`.
    - `config`: overrides `_model_meta[arch]['config_clas']`; it is copied, never mutated.
    - `drop_mult`: multiplier applied to every config entry whose key ends in `_p`.
    - `pad_idx`: token index treated as padding.
    - `max_len`: the encoder keeps only chunks starting within the last `max_len` positions.
    - `y_range`: if given, the head squashes its scores into this range.

    Returns a `SequentialRNN(encoder, decoder)`; if the config has an `init` entry it is applied to every submodule.
    """
    meta = _model_meta[arch]
    config = ifnone(config, meta['config_clas']).copy()
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    n_hidden = config[meta['hid_name']]
    # `output_p` is presumably not an `arch` argument, and the label-attention head takes no dropout,
    # so remove it before `config` is splatted into `arch`; tolerate custom configs that omit it.
    config.pop('output_p', None)
    init = config.pop('init', None)
    encoder = AttentiveSentenceEncoder(seq_len, arch(vocab_sz, **config), pad_idx=pad_idx, max_len=max_len)
    decoder = LabelAttentionClassifier(n_hidden, n_class, y_range=y_range)
    model = SequentialRNN(encoder, decoder)
    return model if init is None else model.apply(init)

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()