In [None]:
#| hide
#| eval: false
! [ -e /content ] && pip install -Uqq xcube  # install/upgrade xcube, but only when /content exists (i.e. on Colab)

In [None]:
#| export 
from fastai.data.all import *
from fastai.text.models.core import *
from fastai.text.models.awdlstm import *
from xcube.layers import *

In [None]:
#| default_exp text.models.core

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
# Per-architecture registry: name of the hidden-size key in the config, forward/backward URLs,
# and the default config + parameter splitter for the language model and for the classifier.
_model_meta = {
    AWD_LSTM: dict(
        hid_name='emb_sz',
        url=URLs.WT103_FWD,
        url_bwd=URLs.WT103_BWD,
        config_lm=awd_lstm_lm_config,
        split_lm=awd_lstm_lm_split,
        config_clas=awd_lstm_clas_config,
        split_clas=awd_lstm_clas_split,
    ),
}

# Core XML Text Modules
> Contains the modules needed to build different XML architectures and the generic functions to get those models.

The models provided here are variations of the ones provided by [fastai](https://docs.fast.ai/text.models.core.html) with modifications tailored for XML.

## Basic Models

In [None]:
#| export
class SequentialRNN(nn.Sequential):
    "An `nn.Sequential` whose `reset` is forwarded to every child that defines one."
    def reset(self):
        for child in self.children():
            # children without a `reset` method fall back to the do-nothing `noop`
            child_reset = getattr(child, 'reset', noop)
            child_reset()

## Classification Models

In [None]:
#| export
def _pad_tensor(t, bs):
    "Append zero rows to `t` along dim 0 until it has `bs` rows; `t` is returned as-is when it already has `bs` or more."
    n_missing = bs - t.size(0)
    if n_missing <= 0: return t
    # `new_zeros` keeps the dtype and device of `t`; trailing dims are copied from `t`
    filler = t.new_zeros(n_missing, *t.shape[1:])
    return torch.cat([t, filler])

The `SentenceEncoder` below is taken from [fastai's source code](https://docs.fast.ai/text.models.core.html#sentenceencoder). It is copied here to make its components easier to understand before changing it into `AttentiveSentenceEncoder`:

In [None]:
#| export
class SentenceEncoder(Module):
    "Run `module` over a padded batch in windows of `bptt` tokens and concatenate the kept window outputs."
    def __init__(self, bptt, module, pad_idx=1, max_len=None):
        store_attr('bptt,module,pad_idx,max_len')

    def reset(self):
        "Forward `reset` to the wrapped module when it defines one."
        getattr(self.module, 'reset', noop)()

    def forward(self, input):
        bs, sl = input.size()
        self.reset()
        pad_mask = input == self.pad_idx
        kept_outs, kept_masks = [], []
        for start in range(0, sl, self.bptt):
            stop = min(start + self.bptt, sl)
            # Number of rows whose token at `start` is not padding; only the first `real_bs` rows are encoded.
            # Note: this expects that each sequence really begins on a round multiple of bptt.
            real_bs = (input[:, start] != self.pad_idx).long().sum()
            o = self.module(input[:real_bs, start:stop])
            # Every window goes through `module`, but only windows within the last `max_len` tokens are kept.
            keep = self.max_len is None or sl - start <= self.max_len
            if keep:
                kept_outs.append(o)
                kept_masks.append(pad_mask[:, start:stop])
        # Windows encoded with fewer than `bs` rows are zero-padded back to the full batch size.
        outs = torch.cat([_pad_tensor(o, bs) for o in kept_outs], dim=1)
        inps = input[:, -outs.shape[1]:]  # the tokens aligned with `outs`
        mask = torch.cat(kept_masks, dim=1)
        return inps, outs, mask

:::{.callout-warning}

This module expects the inputs padded with most of the padding first, with the sequence beginning at a round multiple of bptt (and the rest of the padding at the end). Use `pad_input_chunk` to get your data in a suitable format.

:::

In [None]:
#| export
class AttentiveSentenceEncoder(Module):
    "Encode a batch in `bptt` windows like `SentenceEncoder`, and feed windows outside `max_len` to `decoder` to update its running `hl`."
    def __init__(self, bptt, module, decoder, pad_idx=1, max_len=None, running_decoder=True):
        store_attr('bptt,module,decoder,pad_idx,max_len,running_decoder')
        # `None` when the decoder does not expose `n_lbs`
        self.n_lbs = getattr(self.decoder, 'n_lbs', None)

    def reset(self):
        "Forward `reset` to the wrapped module when it defines one."
        getattr(self.module, 'reset', noop)()

    def forward(self, input):
        bs, sl = input.size()
        self.reset()
        # Every batch starts with an all-zero `hl` of shape (bs, n_lbs) on the decoder.
        self.decoder.hl = input.new_zeros((bs, self.n_lbs))
        pad_mask = input == self.pad_idx
        kept_outs, kept_masks = [], []
        for start in range(0, sl, self.bptt):
            # Note: this expects that each sequence really begins on a round multiple of bptt.
            real_bs = (input[:, start] != self.pad_idx).long().sum()
            window = slice(start, min(start + self.bptt, sl))
            o = self.module(input[:real_bs, window])  # shape (bs, bptt, nh)
            within_max_len = self.max_len is None or sl - start <= self.max_len
            if within_max_len:
                kept_outs.append(o)
                kept_masks.append(pad_mask[:, window])
                continue
            if not self.running_decoder: continue
            # Window is outside `max_len`: run the decoder on it and store the sigmoid of its
            # first output as the new `hl` (not detached).
            hl, *_ = self.decoder((input[:real_bs, window], o, pad_mask[:real_bs, window]))
            self.decoder.hl = hl.sigmoid()

        # Windows encoded with fewer than `bs` rows are zero-padded back to the full batch size.
        outs = torch.cat([_pad_tensor(o, bs) for o in kept_outs], dim=1)
        inps = input[:, -outs.shape[1]:]  # the tokens aligned with `outs`
        mask = torch.cat(kept_masks, dim=1)
        return inps, outs, mask

In [None]:
#| export
def masked_concat_pool(output, mask, bptt):
    "Pool encoder `output` over dim 1, ignoring positions where `mask` is set, into [last_hidden, max_pool, avg_pool]"
    n_masked = mask.long().sum(dim=1)
    seq_lens = output.shape[1] - n_masked
    # number of masked positions inside the final `bptt` window, used to locate the last hidden state
    trailing_masked = mask[:, -bptt:].long().sum(dim=1)
    mask3d = mask[:, :, None]
    avg_pool = output.masked_fill(mask3d, 0).sum(dim=1)
    avg_pool.div_(seq_lens.type(avg_pool.dtype)[:, None])
    max_pool = output.masked_fill(mask3d, -float('inf')).max(dim=1)[0]
    rows = torch.arange(0, output.size(0))
    last_hidden = output[rows, -trailing_masked - 1]
    return torch.cat([last_hidden, max_pool, avg_pool], 1)  # concat pooling

In [None]:
#| export
class XPoolingLinearClassifier(Module):
    "Concat-pool the encoder output, then apply one `LinBnDrop` layer from `dims[0]` to `dims[1]` with no activation."
    def __init__(self, dims, ps, bptt, y_range=None):
        # `y_range` is accepted but not used by this class
        self.bptt = bptt
        self.layer = LinBnDrop(dims[0], dims[1], p=ps, act=None)

    def forward(self, input):
        out, mask = input
        pooled = masked_concat_pool(out, mask, self.bptt)
        # the raw encoder output is returned twice alongside the prediction
        return self.layer(pooled), out, out

Note that `XPoolingLinearClassifier` is exactly the same as fastai's [`PoolingLinearClassifier`](https://docs.fast.ai/text.models.core.html#poolinglinearclassifier), except that we do not do the feature compression from 1200 to 50 linear features.

Note: Also try `XPoolingLinearClassifier` without dropout and batch normalization. (This still needs to be verified, but from what I have found so far it does not work as well as the version with batch normalization.)

In [None]:
#| export
from xcube.layers import _create_bias

In [None]:
#| export
from xcube.utils import *

In [None]:
#| export
class LabelAttentionClassifier(Module):
    "Decoder that turns the `(inp, sentc, mask)` triple from a sentence encoder into one score per label."
    initrange=0.1
    def __init__(self, n_hidden, n_lbs, y_range=None):
        store_attr('n_hidden,n_lbs,y_range')
        self.pay_attn = XMLAttention(self.n_lbs, self.n_hidden)
        self.boost_attn = ElemWiseLin(self.n_lbs, self.n_hidden)
        self.label_bias = _create_bias((self.n_lbs,), with_zeros=False)
        # Running per-label state added to every prediction. It is a plain tensor attribute
        # (moved to the input's device in `forward`) that an encoder may overwrite between calls.
        self.hl = torch.zeros(1)

    def forward(self, sentc):
        """Return `(pred, attn, wgts)` for the encoder output `sentc = (inp, sentc, mask)`.

        `pred` has shape (bs, n_lbs); `attn` is the output of `boost_attn` and `wgts` the
        second output of `pay_attn`. Raises `TypeError` if `sentc` is not a 3-item tuple/list.
        """
        # Without this check a non-tuple input left `inp`/`mask` unbound and failed with an opaque UnboundLocalError.
        if not isinstance(sentc, (tuple, list)) or len(sentc) != 3:
            raise TypeError(f"LabelAttentionClassifier expects a 3-item tuple (inp, sentc, mask), got {type(sentc).__name__}")
        # sentc is the stuff coming out of the sentence encoder, i.e. shape (bs, max_len, nh):
        # in other words the concatenated output of the AWD_LSTM
        inp, sentc, mask = sentc
        test_eqs(inp.shape, sentc.shape[:-1], mask.shape)
        sentc = sentc.masked_fill(mask[:, :, None], 0)  # zero the encoder output at masked positions
        attn, wgts, lbs_cf = self.pay_attn(inp, sentc, mask) # shape (bs, n_lbs, n_hidden)
        attn = self.boost_attn(attn) # shape (bs, n_lbs, n_hidden)
        bs = self.hl.size(0)
        self.hl = self.hl.to(sentc.device)
        # `attn` may cover fewer rows than `hl`; zero-pad its per-label sum up to `hl`'s batch size
        pred = self.hl + _pad_tensor(attn.sum(dim=2), bs) + self.label_bias  # shape (bs, n_lbs)

        # if lbs_cf is not None:
        #     lbs_cf = _pad_tensor(lbs_cf, bs)
        #     pred.add_(lbs_cf)

        if self.y_range is not None: pred = sigmoid_range(pred, *self.y_range)
        return pred, attn, wgts

TODOS: Deb 
- ~Find out what happens with respect to RNN Regularizer callback after LabelAttentionClassifier returns a tuple of 3. (Check the learner cbs and follow the `RNNcallback`)~
- ~Check if we are losing anything by ignoring the mask in `LabelAttentionClassifier`. That is should we be ignoring the masked tokens while computing atten wgts.~  
- Change the label bias initial distribution from uniform to the one we learned separately.
- ~Implement Teacher Forcing~

In [None]:
# %%debug
attn_clas = LabelAttentionClassifier(400, 1271)
test_eq(getattrs(attn_clas, 'n_hidden', 'n_lbs'), (400, 1271))
# The mask must be boolean: `forward` hands it to `Tensor.masked_fill`, which only accepts bool masks
# (the encoders build theirs as `input == pad_idx`, so it is bool in real use too).
inps, outs, mask = torch.zeros(16, 72*20).random_(10), torch.randn(16, 72*20, 400), torch.randint(2, size=(16, 72*20)).bool()
x, *_ = attn_clas((inps, outs, mask))
test_eq(x.shape, (16, 1271))

In [None]:
#| export
def get_xmltext_classifier(arch, vocab_sz, n_class, seq_len=72, config=None, drop_mult=1., pad_idx=1, max_len=72*20, y_range=None):
    "Create an XML text classifier (`SentenceEncoder` over `arch` + `LabelAttentionClassifier`) from `arch` and its `config`"
    meta = _model_meta[arch]
    # Work on a copy so the default (or caller-supplied) config dict is never mutated.
    config = ifnone(config, meta['config_clas']).copy()
    # Scale every dropout probability (keys ending in `_p`) by `drop_mult`.
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    n_hidden = config[meta['hid_name']]
    # `output_p` is dropped rather than forwarded to `arch`; a custom config that omits it is accepted.
    config.pop('output_p', None)
    # Optional weight-initialisation function, applied to the whole model at the end.
    init = config.pop('init', None)
    encoder = SentenceEncoder(seq_len, arch(vocab_sz, **config), pad_idx=pad_idx, max_len=max_len)
    decoder = LabelAttentionClassifier(n_hidden, n_class, y_range=y_range)
    model = SequentialRNN(encoder, decoder)
    return model if init is None else model.apply(init)

In [None]:
#| hide
def awd_lstm_xclas_split(model):
    "Split an AWD-LSTM based XML classifier `model` into parameter groups for differential learning rates."
    rnn_core, head = model[0].module, model[1]
    # group 0: embedding + its dropout; then one group per (rnn, hidden dropout) pair
    groups = [nn.Sequential(rnn_core.encoder, rnn_core.encoder_dp)]
    for rnn, dp in zip(rnn_core.rnns, rnn_core.hidden_dps):
        groups.append(nn.Sequential(rnn, dp))
    # the two attention modules of the head each get their own group; `label_bias` is appended last
    groups.extend([head.pay_attn, head.boost_attn])
    return L(groups).map(params) + head.label_bias

In [None]:
#| export
def get_xmltext_classifier2(arch, vocab_sz, n_class, seq_len=72, config=None, drop_mult=1., pad_idx=1, max_len=72*20, y_range=None, running_decoder=True):
    "Create an XML text classifier from `arch` and its `config`, using an `AttentiveSentenceEncoder` that shares the `LabelAttentionClassifier` decoder"
    meta = _model_meta[arch]
    # Work on a copy so the default (or caller-supplied) config dict is never mutated.
    config = ifnone(config, meta['config_clas']).copy()
    # Scale every dropout probability (keys ending in `_p`) by `drop_mult`.
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    n_hidden = config[meta['hid_name']]
    # `output_p` is dropped rather than forwarded to `arch`; a custom config that omits it is accepted.
    config.pop('output_p', None)
    # Optional weight-initialisation function, applied to the whole model at the end.
    init = config.pop('init', None)
    # The decoder is built first because the encoder keeps a reference to it.
    decoder = LabelAttentionClassifier(n_hidden, n_class, y_range=y_range)
    encoder = AttentiveSentenceEncoder(seq_len, arch(vocab_sz, **config), decoder, pad_idx=pad_idx, max_len=max_len, running_decoder=running_decoder)
    model = SequentialRNN(encoder, decoder)
    return model if init is None else model.apply(init)

In [None]:
#| hide
# Sanity checks: the registry entry for AWD_LSTM carries `awd_lstm_clas_config` as its classifier config
assert _model_meta[AWD_LSTM]['config_clas'] == awd_lstm_clas_config
# Build a model with a 60000-token vocab and 1271 labels
model = get_xmltext_classifier2(AWD_LSTM, 60000, 1271, seq_len=72, config=awd_lstm_clas_config, 
                               drop_mult=0.1, max_len=72*40)
assert hasattr(model[0], 'decoder') # encoder knows about the decoder
assert model[0].decoder is model[1] # and it is the very same object as the model's second stage

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()