In [None]:
#| hide
#| eval: false
! [ -e /content ] && pip install -Uqq xcube  # install/upgrade xcube only when /content exists (i.e. running on Colab)

In [None]:
#| export 
from fastai.data.all import *
from fastai.text.models.core import *
from fastai.text.models.awdlstm import *
from xcube.layers import *

In [None]:
#| default_exp text.models.core

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
# Per-architecture metadata read by `get_xmltext_classifier` (it uses `hid_name` and `config_clas`).
# `hid_name` names the config key holding the encoder's output feature size; the other entries are
# the WT103 weight URLs and the default configs / parameter-group split functions for LM and classifier.
_model_meta = {
    AWD_LSTM: dict(
        hid_name='emb_sz',
        url=URLs.WT103_FWD,
        url_bwd=URLs.WT103_BWD,
        config_lm=awd_lstm_lm_config,
        split_lm=awd_lstm_lm_split,
        config_clas=awd_lstm_clas_config,
        split_clas=awd_lstm_clas_split,
    ),
}

# Core XML Text Modules
> Contains the modules needed to build different XML architectures, and the generic functions to create those models.

The models provided here are variations of the ones provided by [fastai](https://docs.fast.ai/text.models.core.html) with modifications tailored for XML.

## Basic Models

In [None]:
#| export
class SequentialRNN(nn.Sequential):
    "A sequential pytorch module that passes the reset call to its children."
    def reset(self):
        # Children without a `reset` method fall back to fastai's `noop`, i.e. are left untouched
        for child in self.children():
            child_reset = getattr(child, 'reset', noop)
            child_reset()

## Classification Models

In [None]:
#| export
def _pad_tensor(t, bs):
    if t.size(0) < bs: return torch.cat([t, t.new_zeros(bs-t.size(0), *t.shape[1:])])
    return t

In [None]:
#| export
class SentenceEncoder(Module):
    "Create an encoder over `module` that can process a full sentence."
    def __init__(self, bptt, module, pad_idx=1, max_len=None):
        store_attr('bptt,module,pad_idx,max_len')

    def reset(self):
        "Reset the wrapped `module` if it defines `reset`."
        getattr(self.module, 'reset', noop)()

    def forward(self, input):
        bs, sl = input.size()
        self.reset()
        pad_mask = input == self.pad_idx
        kept_outs, kept_masks = [], []
        # Feed the sequence to `module` one `bptt`-sized chunk at a time
        for start in range(0, sl, self.bptt):
            stop = min(start + self.bptt, sl)
            # Note: this expects that sequence really begins on a round multiple of bptt.
            # Only the first `n_live` rows are fed, which presumes rows that have a real token
            # at `start` come first in the batch.
            n_live = (input[:, start] != self.pad_idx).long().sum()
            chunk_out = self.module(input[:n_live, start:stop])
            # Keep only chunks that start within the last `max_len` positions (all of them if `max_len` is None)
            if self.max_len is None or sl - start <= self.max_len:
                kept_outs.append(chunk_out)
                kept_masks.append(pad_mask[:, start:stop])
        # Chunks fed with fewer rows are zero-padded back to the full batch size before concatenating
        full_outs = torch.cat([_pad_tensor(o, bs) for o in kept_outs], dim=1)
        full_mask = torch.cat(kept_masks, dim=1)
        return full_outs, full_mask

:::{.callout-warning}

This module expects the inputs padded with most of the padding first, with the sequence beginning at a round multiple of bptt (and the rest of the padding at the end). Use `pad_input_chunk` to get your data in a suitable format.

:::

Under development (currently behaves exactly like `SentenceEncoder`):

In [None]:
#| export
class AttentiveSentenceEncoder(SentenceEncoder):
    "Create an encoder over `module` that can process a full sentence."
    # Under development: this was a verbatim copy of `SentenceEncoder` (same `__init__`, `reset`
    # and `forward`). Inheriting keeps the two in sync until an attention-specific `forward` is
    # written; override methods here when this encoder starts to differ.

Examples:

In [None]:
#| hide
#| eval: false
# Copy the default AWD-LSTM classifier config and drop `output_p`, which is not an encoder argument
# (`get_xmltext_classifier` pops it the same way before calling `arch`)
config = awd_lstm_clas_config.copy()
del config['output_p']
config
# bptt=72; with max_len=72*20 only chunks starting in the last 72*20 positions are kept
encoder = SentenceEncoder(72, AWD_LSTM(vocab_sz=100, **config), pad_idx=1, max_len=72*20)
encoder

In [None]:
#| export
def masked_concat_pool(output, mask, bptt):
    "Pool `MultiBatchEncoder` outputs into one vector [last_hidden, max_pool, avg_pool]"
    # `mask` is True at padded positions; a trailing axis lets it broadcast over the feature dim
    pad = mask[:, :, None]
    n_tokens = output.shape[1] - mask.long().sum(dim=1)   # non-padded positions per row
    n_end_pad = mask[:, -bptt:].long().sum(dim=1)         # padded positions inside the last `bptt` window
    summed = output.masked_fill(pad, 0).sum(dim=1)
    avg_pool = summed / n_tokens.type(summed.dtype)[:, None]
    max_pool = output.masked_fill(pad, -float('inf')).max(dim=1).values
    # The last real step sits just before the trailing padding (assumes trailing padding is
    # confined to the final `bptt` positions)
    rows = torch.arange(0, output.size(0))
    last_hidden = output[rows, -n_end_pad - 1]
    return torch.cat([last_hidden, max_pool, avg_pool], 1)  # Concat pooling.

In [None]:
#| hide
#| eval: false
# x = to_device(torch.randint(low=0, high=100, size=(128, 85))) # if you want to send it to gpu
# Random (unseeded) batch of 128 sequences of 85 token ids in [0, 100), matching vocab_sz=100 above
x = torch.randint(low=0, high=100, size=(128, 85)) 
x.device
# `out` holds the encoder outputs for every kept step; `mask` is True where the input was `pad_idx`
out, mask = encoder(x)
out.shape, mask.shape

In [None]:
#| export
class PoolingLinearClassifier(Module):
    "Create a linear classifier with pooling"
    def __init__(self, dims, ps, bptt, y_range=None):
        n_blocks = len(dims) - 1
        if len(ps) != n_blocks: raise ValueError("Number of layers and dropout values do not match.")
        # ReLU after every hidden block, no activation after the final (output) block
        acts = [nn.ReLU(inplace=True)] * (n_blocks - 1) + [None]
        blocks = []
        for n_in, n_out, p, act in zip(dims[:-1], dims[1:], ps, acts):
            blocks.append(LinBnDrop(n_in, n_out, p=p, act=act))
        if y_range is not None: blocks.append(SigmoidRange(*y_range))
        self.layers = nn.Sequential(*blocks)
        self.bptt = bptt

    def forward(self, input):
        out, mask = input
        pooled = masked_concat_pool(out, mask, self.bptt)
        # The raw encoder output is returned twice, mirroring fastai's classifier head
        return self.layers(pooled), out, out

In [None]:
#| hide
#| eval: false
# One row per sequence: [last_hidden, max_pool, avg_pool] concatenated, i.e. 3x the encoder feature size
x = masked_concat_pool(out, mask, bptt=72)
x.shape

The output of `masked_concat_pool` is fed into the decoder. So let's now look at the decoder, which compresses the incoming features (1200 in this case) to 50 linear features and then outputs one score per class (6594 classes in this example).

In [None]:
#| hide
#| eval: false
# 1200 pooled features -> 50 hidden features -> 6594 class scores, with one dropout value per block
layers = [1200, 50, 6594]
ps = [0.04, 0.1]
# decoder = PoolingLinearClassifier(layers, ps, bptt=72).cuda() # if gpu available
decoder = PoolingLinearClassifier(layers, ps, bptt=72)
decoder

# The decoder returns (scores, out, out); only the scores are kept here
preds, *_ = decoder((out, mask))

preds.shape

Breaking down the `PoolingLinearClassifier.__init__`:

Note that when `PoolingLinearClassifier` is created above, the `dims` argument of its `__init__` is `layers`.

In [None]:
#| hide
#| eval: false
# Reproduce `PoolingLinearClassifier.__init__` step by step with the values used above
dims = layers
print(f"{dims = }")

print(f"{ps = }")

# Also note that `bptt` is `seq_len`

bptt = 72
print(f"{bptt = }")

y_range = None

# Same sanity check (and message) as in `PoolingLinearClassifier.__init__`
if len(ps) != len(dims) - 1: raise ValueError("Number of layers and dropout values do not match.")

# ReLU after every hidden block, no activation on the output block
acts = [nn.ReLU(inplace=True)] * (len(dims) - 2) + [None]
acts

for i, o, p, a in zip(dims[:-1], dims[1:], ps, acts):
    print(f"{i = }, {o = }, {p = }, {a = }")

layers = [LinBnDrop(i, o, p=p, act=a) for i, o, p, a in zip(dims[:-1], dims[1:], ps, acts)]
layers

In [None]:
#| export
class OurPoolingLinearClassifier(Module):
    "Concat-pool the encoder outputs and map them straight to `dims[1]` scores with a single `LinBnDrop`."
    def __init__(self, dims, ps, bptt, y_range=None):
        # Unlike `PoolingLinearClassifier`, there is no hidden compression block: only dims[0] -> dims[1]
        # is used, and `ps` is a single dropout probability rather than a list.
        self.layer = LinBnDrop(dims[0], dims[1], p=ps, act=None)
        self.bptt = bptt
        # `y_range` used to be accepted but silently ignored; apply it like `PoolingLinearClassifier` does.
        self.sig_range = None if y_range is None else SigmoidRange(*y_range)

    def forward(self, input):
        out, mask = input
        x = masked_concat_pool(out, mask, self.bptt)
        x = self.layer(x)
        if self.sig_range is not None: x = self.sig_range(x)
        return x, out, out

Note that `OurPoolingLinearClassifier` is the same as fastai's [`PoolingLinearClassifier`](https://docs.fast.ai/text.models.core.html#poolinglinearclassifier), except that it skips the feature compression from 1200 to 50 linear features and maps the pooled features directly to the classes.

In [None]:
#| hide
#| eval: false
# 1200 pooled features mapped directly to 6594 class scores, single dropout probability 0.04
decoder = OurPoolingLinearClassifier(dims=[1200, 6594], ps=0.04, bptt=72)

decoder

Note: also try `OurPoolingLinearClassifier` without dropout and batch normalization. (This still needs to be verified, but in my experiments it did not work as well as the version with batch normalization.)

In [None]:
#| export
class LabelAttentionClassifier(Module):
    "Label-attention head: each label attends over the encoder outputs through a learned label embedding."
    def __init__(self, dims, ps, bptt, y_range=None):
        self.fts = dims[0]   # feature size of each step of the encoder output
        self.lbs = dims[-1]  # number of labels
        # Keep this creation order: it fixes RNG consumption and therefore the initial weights.
        # With `ln=False`, LinBnDrop contributes only its normalisation/dropout part.
        self.layers = LinBnDrop(self.lbs, ln=False, p=ps, act=None)
        self.bptt = bptt  # stored for signature parity; unused in `forward` (as is `y_range`)
        # Label queries; their width must equal the feature size of `out` for the matmul in `forward`
        self.emb_label = Embedding(self.lbs, self.fts)
        self.final_lin = nn.Linear(self.fts, self.lbs)

    def forward(self, input):
        out, _ = input
        queries = self.emb_label.weight.transpose(0, 1)
        # Scores normalised over dim 1 (presumably the sequence axis if `out` is (bs, seq_len, fts) — confirm)
        attn_wgts = F.softmax(out @ queries, 1)
        ctx = attn_wgts.transpose(1, 2) @ out   # one attention-weighted context vector per label
        normed = self.layers(ctx)
        # Per-label dot product with that label's row of `final_lin.weight`.
        # NOTE(review): unlike LabelAttentionClassifier2/3, `final_lin.bias` is never added here.
        logits = (self.final_lin.weight * normed).sum(dim=2)
        return logits, out, out

In [None]:
#| export
class LabelAttentionClassifier2(Module):
    "Label-attention head with a raw label-embedding parameter and a biased per-label output."
    initrange=0.1
    def __init__(self, dims, ps, bptt, y_range=None):
        self.fts = dims[0]   # feature size of each step of the encoder output
        self.lbs = dims[-1]  # number of labels
        # Keep this creation order: it fixes RNG consumption and therefore the initial weights
        self.layers = LinBnDrop(self.lbs, ln=False, p=ps, act=None)
        self.bptt = bptt
        # One query per label; its width must equal the feature size of `out` for the matmul in `forward`
        self.emb_label = self._init_param(self.lbs, self.fts)
        self.final_lin = nn.Linear(self.fts, self.lbs)
        self.final_lin.weight.data.uniform_(-self.initrange, self.initrange)
        self.final_lin.bias.data.zero_()

    def _init_param(self, *sz):
        "A trainable tensor of shape `sz` drawn from a normal with mean 0 and std 0.01."
        weights = torch.zeros(sz).normal_(0, 0.01)
        return nn.Parameter(weights)

    def forward(self, input):
        out, _ = input
        # Scores normalised over dim 1 (presumably the sequence axis if `out` is (bs, seq_len, fts) — confirm).
        # Squashing the scores with `sigmoid_range` before the softmax was tried and did not help.
        attn_wgts = F.softmax(out @ self.emb_label.transpose(0, 1), 1)
        ctx = attn_wgts.transpose(1, 2) @ out   # one attention-weighted context vector per label
        normed = self.layers(ctx)
        # Per-label dot product with that label's row of `final_lin.weight`, plus that label's bias
        logits = (self.final_lin.weight * normed).sum(dim=2) + self.final_lin.bias
        return logits, out, out

In [None]:
#| export
class LabelAttentionClassifier3(Module):
    "Label-attention head built on `XMLAttention`, producing one biased score per label."
    initrange=0.1
    def __init__(self, dims, ps, bptt, y_range=None):
        self.fts = dims[0]   # feature size of each step of the encoder output
        self.lbs = dims[-1]  # number of labels
        # With `ln=False`, LinBnDrop contributes only its normalisation/dropout part (dropout `ps`)
        self.layers = LinBnDrop(self.lbs, ln=False, p=ps, act=None)
        self.bptt = bptt  # stored for parity with the other heads; unused in `forward` (as is `y_range`)
        self.attn = XMLAttention(self.lbs, self.fts, 0.0)
        self.final_lin = nn.Linear(self.fts, self.lbs)
        init_default(self.final_lin,
                     func=partial(torch.nn.init.uniform_, a=-self.initrange, b=self.initrange))

    def forward(self, input):
        out, _ = input
        ctx = self.attn(out)  # presumably (bs, lbs, fts): one context vector per label — confirm in XMLAttention
        x = self.layers(ctx)
        # Per-label dot product with that label's row of `final_lin.weight`, plus its bias.
        # Bug fix: this used `ctx`, so the output of `self.layers` was computed and then discarded.
        x = (self.final_lin.weight * x).sum(dim=2) + self.final_lin.bias
        return x, out, out

In [None]:
#| hide
#| eval:false
# decoder = LabelAttentionClassifier([1200, 6594], ps=0.04, bptt=72).cuda() # if gpu available
# 400 = per-step feature size of `out` (no concat pooling here), 6594 labels
decoder = LabelAttentionClassifier([400, 6594], ps=0.04, bptt=72)
decoder

# The mask is not used by this head, so `None` is passed in its place
preds, *_ = decoder((out, None))
preds.shape

Breaking down `LabelAttentionClassifier` to make sure we understand each line:

In [None]:
#| hide
#| eval:false
# Label embedding matrix: one row per label (6594 x 400 here)
decoder.emb_label.weight.shape

out.shape, out.device

# Score of every step of `out` against every label embedding
attn_wgts = out @ decoder.emb_label.weight.transpose(0,1)
attn_wgts.shape, attn_wgts.device

# Normalise the scores over dim 1, as in `LabelAttentionClassifier.forward`
attn_wgts = F.softmax(attn_wgts, 1)

# Uncomment to free the (large) attention tensor on GPU:
# attn_wgts = None
# import gc
# gc.collect()
# torch.cuda.empty_cache()

out[:, :, None].shape

attn_wgts.transpose(1,2).shape

# One attention-weighted context vector per label
ctx = attn_wgts.transpose(1,2) @ out
ctx.shape

# Toy illustration of `torch.split` (chunks of 2 rows along dim 0)
a = torch.arange(10).reshape(5,2)

a, a.shape

for a_split in torch.split(a, 2): print(a_split, a_split.shape, end='\n****\n')

In [None]:
#| export
def get_xmltext_classifier(arch, vocab_sz, n_class, seq_len=72, config=None, drop_mult=1., lin_ftrs=None,
                       ps=None, pad_idx=1, max_len=72*20, y_range=None):
    """Create an XML text classifier from `arch` and its `config`.

    The model is `SequentialRNN(AttentiveSentenceEncoder(...), LabelAttentionClassifier3(...))`.
    - `config`: defaults to `_model_meta[arch]['config_clas']`; it is copied, never mutated, and must contain `output_p`.
    - `drop_mult`: multiplies every `*_p` entry of the (copied) config.
    - `lin_ftrs`: unused by the label-attention head; kept for API compatibility with fastai's `get_text_classifier`.
    - `ps`: head dropout. A single number overrides the config's `output_p`; `None` (or a fastai-style list)
      falls back to `output_p`, as before.
    - `y_range`: forwarded to the head (which currently ignores it).
    Returns the model, with `config['init']` applied to all its submodules when present.
    """
    meta = _model_meta[arch]
    config = ifnone(config, meta['config_clas']).copy()
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    # Head input is the per-step feature size (fastai's PoolingLinearClassifier would instead need
    # `[hid * 3] + lin_ftrs + [n_class]` because of concat pooling)
    layers = [config[meta['hid_name']], n_class]
    # `output_p` and `init` configure the head / the whole model, not `arch`, so remove them before calling `arch`
    output_p = config.pop('output_p')
    if not isinstance(ps, (int, float)): ps = output_p
    init = config.pop('init', None)
    encoder = AttentiveSentenceEncoder(seq_len, arch(vocab_sz, **config), pad_idx=pad_idx, max_len=max_len)
    decoder = LabelAttentionClassifier3(layers, ps, bptt=seq_len, y_range=y_range)
    model = SequentialRNN(encoder, decoder)
    return model if init is None else model.apply(init)

## Export -

In [None]:
#| hide
# Write the `#| export` cells of this notebook out to the library module
from nbdev import nbdev_export
nbdev_export()