In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab

In [None]:
#export 
from fastai.data.all import *
from fastai.text.models.core import *
from fastai.text.models.awdlstm import *
from xcube.layers import *

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#default_exp text.models.core

# Core text modules

> Contains the modules common to the different architectures and the generic functions used to build models

In [None]:
#export
# Per-architecture metadata, keyed by the architecture class.
# `get_text_classifier` reads 'hid_name' (the config key holding the encoder's
# output width) and 'config_clas' (the default classifier config); the other
# entries are not read anywhere in this module.
_model_meta = {
    AWD_LSTM: {
        'hid_name': 'emb_sz',
        'url': URLs.WT103_FWD,
        'url_bwd': URLs.WT103_BWD,
        'config_lm': awd_lstm_lm_config,
        'split_lm': awd_lstm_lm_split,
        'config_clas': awd_lstm_clas_config,
        'split_clas': awd_lstm_clas_split,
    },
}

## Basic Models

In [None]:
#export
class SequentialRNN(nn.Sequential):
    "An `nn.Sequential` that forwards `reset` to each direct child, falling back to `noop` for children without one."
    def reset(self):
        for child in self.children():
            # Look the method up dynamically: not every child defines `reset`.
            reset_fn = getattr(child, 'reset', noop)
            reset_fn()

## Classification Models

In [None]:
#export
def _pad_tensor(t, bs):
    if t.size(0) < bs: return torch.cat([t, t.new_zeros(bs-t.size(0), *t.shape[1:])])
    return t

In [None]:
#export
class SentenceEncoder(Module):
    """Create an encoder over `module` that can process a full sentence.

    The input is consumed in consecutive chunks of `bptt` columns. `module` is
    reset once at the start of `forward` and then called on every chunk in order.
    When `max_len` is set, only the outputs of chunks that start within the last
    `max_len` columns are kept.
    """
    def __init__(self, bptt, module, pad_idx=1, max_len=None):
        store_attr('bptt,module,pad_idx,max_len')

    def reset(self):
        "Call `reset` on the wrapped module when it defines one."
        getattr(self.module, 'reset', noop)()

    def forward(self, input):
        "Return `(outputs, mask)`, concatenated over the kept chunks; `mask` is True where `input == pad_idx`."
        n_rows, seq_len = input.size()
        self.reset()
        pad_mask = input == self.pad_idx
        kept_outs, kept_masks = [], []
        for start in range(0, seq_len, self.bptt):
            stop = min(start + self.bptt, seq_len)
            #Note: this expects that sequence really begins on a round multiple of bptt
            # Number of rows whose first token in this chunk is not padding; only
            # the first `n_active` rows are fed to the module
            # (assumes those rows come first in the batch -- TODO confirm with the collate fn).
            n_active = (input[:, start] != self.pad_idx).long().sum()
            chunk_out = self.module(input[:n_active, start:stop])
            # The module is always run (it was called above), but chunks that start
            # more than `max_len` columns before the end are not kept.
            if self.max_len is not None and seq_len - start > self.max_len: continue
            kept_outs.append(chunk_out)
            kept_masks.append(pad_mask[:, start:stop])
        # Bring every kept chunk back to the full batch size before joining on the sequence axis.
        padded_outs = [_pad_tensor(o, n_rows) for o in kept_outs]
        return torch.cat(padded_outs, dim=1), torch.cat(kept_masks, dim=1)

> Warning: This module expects the inputs padded with most of the padding first, with the sequence beginning at a round multiple of bptt (and the rest of the padding at the end). Use `pad_input_chunk` to get your data in a suitable format.

In [None]:
#export
def masked_concat_pool(output, mask, bptt):
    """Pool `SentenceEncoder` outputs into one vector [last_hidden, max_pool, avg_pool].

    `mask` is True at padded positions. Padded positions are excluded from the
    max and the average. `last_hidden` is taken just before the padding counted
    in the final `bptt` positions of each row.
    """
    pad_3d = mask[:, :, None]
    n_valid = output.shape[1] - mask.long().sum(dim=1)
    n_tail_pad = mask[:, -bptt:].long().sum(dim=1)
    # Average over non-padded positions only: zero the padding, sum, divide by the valid count.
    mean_pool = output.masked_fill(pad_3d, 0).sum(dim=1)
    mean_pool.div_(n_valid.type(mean_pool.dtype)[:, None])
    # -inf at padded positions so they can never win the max.
    peak_pool = output.masked_fill(pad_3d, -float('inf')).max(dim=1).values
    row_ids = torch.arange(0, output.size(0))
    last_hidden = output[row_ids, -n_tail_pad - 1]
    return torch.cat([last_hidden, peak_pool, mean_pool], 1) #Concat pooling.

In [None]:
#export
class PoolingLinearClassifier(Module):
    """Create a linear classifier with pooling.

    `dims` lists the layer widths (input first, output last) and `ps` gives one
    dropout probability per linear block. Every block but the last uses ReLU.
    When `y_range` is given, a `SigmoidRange` is appended.
    """
    def __init__(self, dims, ps, bptt, y_range=None):
        n_blocks = len(dims) - 1
        if len(ps) != n_blocks: raise ValueError("Number of layers and dropout values do not match.")
        # One ReLU instance is shared by all hidden blocks; the final block has no activation.
        relu = nn.ReLU(inplace=True)
        acts = [relu] * (n_blocks - 1) + [None]
        blocks = []
        for n_in, n_out, p, act in zip(dims[:-1], dims[1:], ps, acts):
            blocks.append(LinBnDrop(n_in, n_out, p=p, act=act))
        if y_range is not None: blocks.append(SigmoidRange(*y_range))
        self.layers = nn.Sequential(*blocks)
        self.bptt = bptt

    def forward(self, input):
        "Pool the `(outputs, mask)` pair and classify it; the raw encoder outputs are returned twice alongside the result."
        enc_out, pad_mask = input
        pooled = masked_concat_pool(enc_out, pad_mask, self.bptt)
        return self.layers(pooled), enc_out, enc_out

In [None]:
#export
class CAML2(Module):
    """Single-block classification head over concat-pooled encoder outputs.

    Only `dims[0]` (input width) and `dims[1]` (output width) are read, and `ps`
    is passed as the one dropout probability of the block. `y_range` is accepted
    but not used.
    """
    def __init__(self, dims, ps, bptt, y_range=None):
        n_in, n_out = dims[0], dims[1]
        self.layer = LinBnDrop(n_in, n_out, p=ps, act=None)
        self.bptt = bptt

    def forward(self, input):
        "Pool the `(outputs, mask)` pair and apply the linear block; the raw encoder outputs are returned twice alongside the result."
        enc_out, pad_mask = input
        pooled = masked_concat_pool(enc_out, pad_mask, self.bptt)
        return self.layer(pooled), enc_out, enc_out

In [None]:
#export
class CAML3(Module):
    """Label-wise attention classification head.

    Every label owns an embedding vector. Each encoder position is scored
    against every label embedding, the scores are softmax-normalised over the
    sequence axis, and the resulting per-label context vectors are passed to
    `Lin1BnDrop`.

    `dims[0]` is expected to be three times the encoder width (the size
    `get_text_classifier` builds for concat pooling), so it is divided by 3 to
    recover the per-position feature size. `dims[-1]` is the number of labels.
    `ps` is the dropout probability given to `Lin1BnDrop`. `y_range` is accepted
    but not used, and `bptt` is stored but not read by `forward`.
    """
    def __init__(self, dims, ps, bptt, y_range=None):
        self.lbs = dims[-1]
        self.fts = dims[0]//3
        self.layers = Lin1BnDrop(self.lbs, self.fts, p=ps, act=None)
        self.bptt = bptt
        self.emb_label = nn.Embedding(self.lbs, self.fts)

    def forward(self, input):
        "Return `(scores, out, out)` for an `(out, mask)` pair, where `mask` is True at padded positions."
        out, mask = input
        # Score of every position against every label embedding.
        attn_wgts = out @ self.emb_label.weight.transpose(0, 1)
        if mask is not None:
            # Padded positions must not take any attention mass. A large finite
            # negative is used instead of -inf so that a fully padded row yields a
            # uniform distribution rather than NaN.
            attn_wgts = attn_wgts.masked_fill(mask[:, :, None], torch.finfo(attn_wgts.dtype).min)
        # Normalise over the sequence axis (dim 1), separately for each label.
        attn_wgts = F.softmax(attn_wgts, 1)
        # One attention-weighted context vector per label.
        ctx = attn_wgts.transpose(1, 2) @ out

        x = self.layers(ctx)
        # Keep only the first two axes; presumably `Lin1BnDrop` leaves a trailing
        # singleton axis here -- TODO confirm against xcube.layers.
        x = x.view(x.shape[0], x.shape[1])
        return x, out, out

In [None]:
#export
def get_text_classifier(arch, vocab_sz, n_class, seq_len=72, config=None, drop_mult=1., lin_ftrs=None,
                       ps=None, pad_idx=1, max_len=72*20, y_range=None):
    """Create a text classifier from `arch` and its `config`.

    The model is a `SequentialRNN` made of a `SentenceEncoder` wrapping
    `arch(vocab_sz, **config)` and a `CAML3` head.

    - `arch` must be a key of `_model_meta`.
    - `config` defaults to the architecture's `'config_clas'` entry. It is copied
      before being modified, and every entry whose key ends in `_p` is multiplied
      by `drop_mult`.
    - `seq_len` is used as the `bptt` of both the encoder and the head.
    - `pad_idx` and `max_len` are passed to the `SentenceEncoder`.
    - `y_range` is passed to the head.
    - `lin_ftrs` and `ps` are kept so existing callers do not break, but they are
      ignored: the head has no hidden layers and its dropout probability is taken
      from the config's `'output_p'` entry.

    If the config has an `'init'` entry, it is applied to the whole model with
    `model.apply`.
    """
    meta = _model_meta[arch]
    config = ifnone(config, meta['config_clas']).copy()
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    # The head receives [3 * hidden size, number of classes].
    layers = [config[meta['hid_name']] * 3, n_class]
    # 'output_p' and 'init' are popped because they are not arguments of `arch`.
    output_p = config.pop('output_p')
    init = config.pop('init') if 'init' in config else None
    encoder = SentenceEncoder(seq_len, arch(vocab_sz, **config), pad_idx=pad_idx, max_len=max_len)
    decoder = CAML3(layers, output_p, bptt=seq_len, y_range=y_range)
    model = SequentialRNN(encoder, decoder)
    return model if init is None else model.apply(init)

In [None]:
# Run nbdev's `notebook2script` over the project notebooks (each one it processes is listed as "Converted ..." in the output).
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_text.models.core.ipynb.
Converted 03_text.learner.ipynb.
Converted 04_metrics.ipynb.
Converted index.ipynb.
