# AWD-LSTM

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#export
from exp.nb_12 import *

## Data

In [None]:
path = datasets.untar_data(datasets.URLs.IMDB)

We have to preprocess the data again before pickling it, because if we try to load the previous `ll` with pickle, it will complain that the functions aren't in `__main__`.

In [None]:
il = TextList.from_files(path, include=['train', 'test', 'unsup'])
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))

In [None]:
proc_tok,proc_num = TokenizeProcessor(max_workers=8),NumericalizeProcessor()

In [None]:
ll = label_by_func(sd, lambda x: 0, proc_x = [proc_tok,proc_num])

In [None]:
# Use a context manager so the file is flushed and closed as soon as the dump finishes.
with open(path/'ll_lm.pkl', 'wb') as f: pickle.dump(ll, f)

In [None]:
# NOTE: only unpickle files you created yourself; `pickle.load` can execute arbitrary code.
with open(path/'ll_lm.pkl', 'rb') as f: ll = pickle.load(f)

In [None]:
bs,bptt = 64,70
data = lm_databunchify(ll, bs, bptt)

In [None]:
vocab = proc_num.vocab

## AWD-LSTM

### LSTM from scratch

In [None]:
class LSTMCell(nn.Module):
    "One LSTM step: maps `(input, (h, c))` to `(h_new, (h_new, c_new))`."
    def __init__(self, ni, nh):
        super().__init__()
        # Each projection produces the pre-activations of all four gates at once.
        self.ih = nn.Linear(ni, 4*nh)
        self.hh = nn.Linear(nh, 4*nh)

    def forward(self, input, state):
        hidden, cell = state
        # One fused multiplication per projection is better than four smaller ones.
        preact = self.ih(input) + self.hh(hidden)
        i_pre, f_pre, o_pre, g_pre = preact.chunk(4, 1)
        in_gate, forget_gate, out_gate = i_pre.sigmoid(), f_pre.sigmoid(), o_pre.sigmoid()
        candidate = torch.tanh(g_pre)

        new_cell = forget_gate*cell + in_gate*candidate
        new_hidden = out_gate * torch.tanh(new_cell)
        # The new hidden state doubles as this step's output.
        return new_hidden, (new_hidden, new_cell)

In [None]:
class LSTMLayer(nn.Module):
    "Applies a recurrent cell to every slice of `input` along dim 1, threading the state through."
    def __init__(self, cell, *cell_args):
        super().__init__()
        # `cell` is a constructor (e.g. a class); it is instantiated with `cell_args`.
        self.cell = cell(*cell_args)

    def forward(self, input, state):
        outputs = []
        # Feed the slices to the cell one at a time; each call returns the next state.
        for step in input.unbind(1):
            out, state = self.cell(step, state)
            outputs.append(out)
        # Re-assemble the per-step outputs along the same dim they were split on.
        return torch.stack(outputs, dim=1), state

In [None]:
lstm = LSTMLayer(LSTMCell, 300, 300)

In [None]:
x = torch.randn(64, 70, 300)
h = (torch.zeros(64, 300),torch.zeros(64, 300))

CPU

In [None]:
%timeit -n 10 y,h1 = lstm(x,h)

106 ms ± 444 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
lstm = lstm.cuda()
x = x.cuda()
h = (h[0].cuda(), h[1].cuda())

In [None]:
def time_fn(f):
    "Call `f`, then wait for pending CUDA work so that a timing around this call covers all of it."
    f()
    # Only synchronize when a GPU is present; the call raises on CPU-only machines.
    if torch.cuda.is_available(): torch.cuda.synchronize()

CUDA

In [None]:
f = partial(lstm,x,h)
time_fn(f)

In [None]:
%timeit -n 10 time_fn(f)

28.5 ms ± 37.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Builtin version

In [None]:
lstm = nn.LSTM(300, 300, 1, batch_first=True)

In [None]:
x = torch.randn(64, 70, 300)
h = (torch.zeros(1, 64, 300),torch.zeros(1, 64, 300))

In [None]:
%timeit -n 10 y,h1 = lstm(x,h)

102 ms ± 227 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
lstm = lstm.cuda()
x = x.cuda()
h = (h[0].cuda(), h[1].cuda())

In [None]:
f = partial(lstm,x,h)
time_fn(f)

In [None]:
%timeit -n 10 time_fn(f)

2.2 ms ± 42.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Jit version

In [None]:
import torch.jit as jit
from torch import Tensor

In [None]:
class LSTMCell(jit.ScriptModule):
    # TorchScript version of a single LSTM step. It holds explicit weight and bias
    # parameters (instead of `nn.Linear` layers) for the input and hidden projections.
    def __init__(self, ni, nh):
        super().__init__()
        self.ni = ni
        self.nh = nh
        # All parameters are drawn from a standard normal; the leading `4 * nh` dim
        # stacks the four gates so each projection is a single matmul.
        self.w_ih = nn.Parameter(torch.randn(4 * nh, ni))
        self.w_hh = nn.Parameter(torch.randn(4 * nh, nh))
        self.bias_ih = nn.Parameter(torch.randn(4 * nh))
        self.bias_hh = nn.Parameter(torch.randn(4 * nh))

    @jit.script_method
    def forward(self, input:Tensor, state:Tuple[Tensor, Tensor])->Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # `state` is the `(hidden, cell)` pair from the previous step.
        hx, cx = state
        # Pre-activations of all four gates, computed together.
        gates = (input @ self.w_ih.t() + self.bias_ih +
                 hx @ self.w_hh.t() + self.bias_hh)
        # Split along dim 1 into four equal chunks, one per gate.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # New cell state: keep part of the old one, add the gated candidate.
        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        # The new hidden state is returned both as the output and inside the new state.
        return hy, (hy, cy)

In [None]:
class LSTMLayer(jit.ScriptModule):
    # TorchScript version of the layer: applies `cell` to each slice of the input
    # along dim 1, passing the state from one step to the next.
    def __init__(self, cell, *cell_args):
        super().__init__()
        # `cell` is a constructor; it is instantiated with `cell_args`.
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input:Tensor, state:Tuple[Tensor, Tensor])->Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # One tensor per index along dim 1.
        inputs = input.unbind(1)
        outputs = []
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        # Stack the per-step outputs back along dim 1 and return the final state.
        return torch.stack(outputs, dim=1), state

In [None]:
lstm = LSTMLayer(LSTMCell, 300, 300)

In [None]:
x = torch.randn(64, 70, 300)
h = (torch.zeros(64, 300),torch.zeros(64, 300))

In [None]:
%timeit -n 10 y,h1 = lstm(x,h)

96.4 ms ± 3.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
lstm = lstm.cuda()
x = x.cuda()
h = (h[0].cuda(), h[1].cuda())

In [None]:
f = partial(lstm,x,h)
time_fn(f)

In [None]:
%timeit -n 10 time_fn(f)

7.98 ms ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Dropout

We want to use the AWD-LSTM from [Stephen Merity et al.](https://arxiv.org/abs/1708.02182). First, we'll need all the different kinds of dropout. Dropout consists of replacing some coefficients by 0 with probability p. To ensure that the average of the weights remains constant, we scale the coefficients that aren't nullified by a factor of `1/(1-p)`.

In [None]:
#export
def dropout_mask(x, sz, p):
    "Return a mask of size `sz`, same type as `x`, whose entries are 0 with probability `p` and `1/(1-p)` otherwise."
    keep = 1-p
    mask = x.new(*sz)
    # Sample the kept positions, then rescale them in place by 1/keep.
    mask.bernoulli_(keep)
    mask.div_(keep)
    return mask

In [None]:
x = torch.randn(10,10)
mask = dropout_mask(x, (10,10), 0.5); mask

tensor([[2., 2., 0., 2., 2., 2., 2., 0., 0., 2.],
        [2., 0., 2., 0., 0., 0., 0., 0., 0., 0.],
        [2., 2., 2., 2., 0., 0., 2., 2., 2., 0.],
        [2., 2., 2., 0., 0., 2., 2., 2., 2., 0.],
        [0., 0., 2., 2., 2., 0., 2., 0., 0., 0.],
        [2., 0., 2., 0., 0., 2., 2., 0., 0., 0.],
        [0., 0., 0., 2., 2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 2., 2., 2., 2., 0., 0.],
        [2., 0., 2., 0., 2., 2., 0., 2., 0., 2.],
        [0., 2., 0., 0., 2., 0., 2., 0., 0., 0.]])

Once we have a dropout mask `m`, applying the dropout to `x` is simply done by `x = x * m`. We create our own dropout mask and don't rely on pytorch dropout because we want the mask to vary along the batch dimension but not along the token dimension (i.e. the same coefficients are replaced by zero for each word in the sentence).

In [None]:
(x*mask).std(),x.std()

(tensor(1.2469), tensor(0.9459))

Inside an RNN, a tensor x will have three dimensions: bs, seq_len and the feature size (embedding or hidden size), so we create a dropout mask for the first and last dimensions and broadcast it along the sequence dimension.

In [None]:
#export
class RNNDropout(nn.Module):
    "Dropout with probability `p` whose mask is shared along dim 1 of a 3-d input."
    def __init__(self, p=0.5):
        super().__init__()
        self.p=p

    def forward(self, x):
        if self.training and self.p != 0.:
            # Size-1 middle dim: the mask broadcasts over every index of dim 1.
            mask_size = (x.size(0), 1, x.size(2))
            return x * dropout_mask(x.data, mask_size, self.p)
        # Evaluation mode or p == 0: the input is returned untouched.
        return x

In [None]:
dp = RNNDropout(0.3)
tst_input = torch.randn(3,3,7)
tst_input, dp(tst_input)

(tensor([[[-0.4906, -1.4757, -0.8150, -0.1549, -0.6623,  0.9590,  0.8725],
          [ 0.5281, -0.0725,  0.7462,  0.4467,  0.1519,  0.4513,  0.0380],
          [ 2.1783, -1.1480, -0.6598, -0.2936,  0.2893,  0.8233,  0.0272]],
 
         [[-1.0300,  1.0400, -0.7219, -0.1432, -0.5659,  0.1479,  1.1913],
          [-0.1042, -1.0175,  0.0429, -0.7553,  0.8381, -0.5671,  0.1904],
          [ 0.1898, -0.6735,  0.9978,  0.4496,  0.3275,  1.6205,  1.1158]],
 
         [[ 0.0995, -0.9166,  1.3238,  0.2154,  1.8205,  1.5846,  1.7698],
          [ 0.2066, -1.2835, -1.1549,  0.4163,  1.3043, -0.3958, -0.3215],
          [-0.3986, -2.2325,  0.6568,  0.7527,  1.2539, -1.0581, -0.8355]]]),
 tensor([[[-0.7008, -0.0000, -1.1643, -0.2213, -0.9462,  1.3700,  1.2464],
          [ 0.7545, -0.0000,  1.0660,  0.6381,  0.2170,  0.6447,  0.0543],
          [ 3.1118, -0.0000, -0.9426, -0.4194,  0.4133,  1.1761,  0.0389]],
 
         [[-0.0000,  1.4857, -1.0313, -0.2045, -0.8084,  0.2113,  1.7018],
          [-0

Dropout applied to the weights of the inner LSTM cell. This is a little hacky.

In [None]:
#export
import warnings

WEIGHT_HH = 'weight_hh_l0'

class WeightDropout(nn.Module):
    "Wrap `module` and apply dropout with probability `weight_p` to its weights named in `layer_names`."
    def __init__(self, module, weight_p=0., layer_names=(WEIGHT_HH,)):
        super().__init__()
        # `weight_p` must be a float: `F.dropout` compares it against 0 and 1.
        # `layer_names` is copied into a fresh list so instances never share the default.
        self.module,self.weight_p,self.layer_names = module,weight_p,list(layer_names)
        for layer in self.layer_names:
            #Keep the un-dropped weights on this wrapper as `<layer>_raw`.
            #NOTE: `nn.Parameter(w.data)` shares storage with `w`; it is not a deep copy.
            w = getattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            self.module._parameters[layer] = F.dropout(w, p=self.weight_p, training=False)

    def _setweights(self):
        "Overwrite each wrapped weight with a (possibly dropped-out) version of its raw copy."
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            # Dropout is only active when this wrapper is in training mode.
            self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=self.training)

    def forward(self, *args):
        # The weights are re-sampled at the start of every forward pass.
        self._setweights()
        with warnings.catch_warnings():
            #To avoid the warning that comes because the weights aren't flattened.
            warnings.simplefilter("ignore")
            return self.module.forward(*args)

In [None]:
module = nn.LSTM(5, 2)
dp_module = WeightDropout(module, 0.4)
getattr(dp_module.module, WEIGHT_HH)

Parameter containing:
tensor([[-0.0752, -0.4945],
        [ 0.4209, -0.1203],
        [ 0.6989,  0.0846],
        [ 0.2286, -0.2910],
        [ 0.1882,  0.2029],
        [ 0.4733, -0.0883],
        [-0.3956, -0.1842],
        [-0.2313, -0.0853]], requires_grad=True)

It's at the beginning of a forward pass that the dropout is applied to the weights.

In [None]:
tst_input = torch.randn(4,20,5)
h = (torch.zeros(1,20,2), torch.zeros(1,20,2))
x,h = dp_module(tst_input,h)
getattr(dp_module.module, WEIGHT_HH)

tensor([[-0.1254, -0.0000],
        [ 0.7015, -0.2006],
        [ 0.0000,  0.1409],
        [ 0.3810, -0.4850],
        [ 0.3137,  0.3382],
        [ 0.7889, -0.1472],
        [-0.6593, -0.3070],
        [-0.3855, -0.0000]], grad_fn=<MulBackward0>)

Dropout applied to full rows of the embedding matrix.

In [None]:
#export
class EmbeddingDropout(nn.Module):
    "Applies dropout in the embedding layer by zeroing out whole rows of the embedding matrix."

    def __init__(self, emb, embed_p):
        super().__init__()
        self.emb,self.embed_p = emb,embed_p
        self.pad_idx = self.emb.padding_idx
        if self.pad_idx is None: self.pad_idx = -1

    def forward(self, words, scale=None):
        "Look up `words` in the (possibly row-dropped) embedding matrix, multiplied by `scale` if given."
        if self.training and self.embed_p != 0:
            # One mask value per row, broadcast across the embedding dimension.
            size = (self.emb.weight.size(0),1)
            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)
            masked_embed = self.emb.weight * mask
        else: masked_embed = self.emb.weight
        # Scale out of place: when no dropout is applied `masked_embed` IS the embedding
        # parameter, so an in-place `mul_` would either raise (in-place op on a leaf that
        # requires grad) or permanently rescale the weights.
        if scale: masked_embed = masked_embed * scale
        return F.embedding(words, masked_embed, self.pad_idx, self.emb.max_norm,
                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)

In [None]:
enc = nn.Embedding(100, 7, padding_idx=1)
enc_dp = EmbeddingDropout(enc, 0.5)
tst_input = torch.randint(0,100,(8,))
enc_dp(tst_input)

tensor([[-0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
        [-0.0000, -0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
        [-2.2219,  0.1879,  0.4018,  0.6256,  1.2506, -2.2270, -0.8169],
        [-2.0820,  1.0500,  2.0398,  3.0884, -2.2652, -2.2884, -2.4420],
        [ 0.8035,  1.4288, -2.8404,  0.3058,  1.4443,  2.2933,  0.6750],
        [-2.8565,  0.3026,  0.2049, -1.6020,  1.4497,  5.2394,  0.5871],
        [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
        [-0.8294, -1.0537,  0.5131, -1.8451,  1.4449, -2.9209,  0.9390]],
       grad_fn=<EmbeddingBackward>)

### Main model

In [None]:
#export
def to_detach(h):
    "Detaches `h` from its history; `h` is a tensor or a (nested) iterable of tensors, returned as tuples."
    # `isinstance` also matches tensor subclasses such as `nn.Parameter`, which an exact
    # type comparison would treat as an iterable and split into a tuple of rows.
    if isinstance(h, torch.Tensor): return h.detach()
    return tuple(to_detach(v) for v in h)

In [None]:
#export
class AWD_LSTM(nn.Module):
    "AWD-LSTM inspired by https://arxiv.org/abs/1708.02182."
    initrange=0.1

    def __init__(self, vocab_sz, emb_sz, n_hid, n_layers, pad_token,
                 hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):
        super().__init__()
        self.bs,self.emb_sz,self.n_hid,self.n_layers = 1,emb_sz,n_hid,n_layers
        # No hidden state yet: `forward` creates it on first use, whatever the batch size.
        self.hidden = None
        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)
        # The first layer reads embeddings and the last one outputs `emb_sz` features;
        # every other size is `n_hid`.
        rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz), 1,
                        batch_first=True) for l in range(n_layers)]
        self.rnns = nn.ModuleList([WeightDropout(rnn, weight_p) for rnn in rnns])
        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.input_dp = RNNDropout(input_p)
        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])

    def forward(self, input):
        "Return `(raw_outputs, outputs)`: per-layer activations before and after hidden dropout."
        bs,sl = input.size()
        # Also reset when no state exists yet; otherwise a first batch whose size equals
        # the initial `self.bs` (1) would read an undefined hidden state.
        if bs!=self.bs or self.hidden is None:
            self.bs=bs
            self.reset()
        raw_output = self.input_dp(self.encoder_dp(input))
        new_hidden,raw_outputs,outputs = [],[],[]
        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
            raw_output, new_h = rnn(raw_output, self.hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            # Hidden dropout is applied between layers, not after the last one.
            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)
            outputs.append(raw_output)
        # Keep the state for the next call, cut off from this batch's graph.
        self.hidden = to_detach(new_hidden)
        return raw_outputs, outputs

    def _one_hidden(self, l):
        "Return one hidden state."
        nh = self.n_hid if l != self.n_layers - 1 else self.emb_sz
        # `.new` gives a tensor of the same type as the model's parameters.
        return next(self.parameters()).new(1, self.bs, nh).zero_()

    def reset(self):
        "Reset the hidden states."
        self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]

In [None]:
#export
class LinearDecoder(nn.Module):
    "Applies dropout then a linear layer to the last element of `outputs`, optionally sharing `tie_encoder`'s weight."
    def __init__(self, n_out, n_hid, output_p, tie_encoder=None, bias=True):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.output_dp = RNNDropout(output_p)
        if bias: self.decoder.bias.data.zero_()
        # Without a module to tie to, the weight gets its own Kaiming-uniform init.
        if not tie_encoder: init.kaiming_uniform_(self.decoder.weight)
        else: self.decoder.weight = tie_encoder.weight

    def forward(self, input):
        raw_outputs, outputs = input
        dropped = self.output_dp(outputs[-1]).contiguous()
        # Merge the first two dims so the linear layer sees one row per position.
        n_rows = dropped.size(0)*dropped.size(1)
        decoded = self.decoder(dropped.view(n_rows, dropped.size(2)))
        # The per-layer activations are passed along untouched.
        return decoded, raw_outputs, outputs

In [None]:
#export
class SequentialRNN(nn.Sequential):
    "An `nn.Sequential` whose `reset` is forwarded to every direct child that defines one."
    def reset(self):
        resettable = (child for child in self.children() if hasattr(child, 'reset'))
        for child in resettable: child.reset()

In [None]:
#export
def get_language_model(vocab_sz, emb_sz, n_hid, n_layers, pad_token, tie_weights=True, bias=True,
                       output_p=0.4, hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):
    "Build an `AWD_LSTM` encoder followed by a `LinearDecoder`, wrapped in a `SequentialRNN`."
    encoder = AWD_LSTM(vocab_sz, emb_sz, n_hid=n_hid, n_layers=n_layers, pad_token=pad_token,
                       hidden_p=hidden_p, input_p=input_p, embed_p=embed_p, weight_p=weight_p)
    # With `tie_weights`, the decoder shares the weight of the encoder's embedding.
    if tie_weights: tied = encoder.encoder
    else: tied = None
    decoder = LinearDecoder(vocab_sz, emb_sz, output_p, tie_encoder=tied, bias=bias)
    return SequentialRNN(encoder, decoder)

In [None]:
tok_pad = vocab.index(PAD)

In [None]:
tst_model = get_language_model(len(vocab), 300, 300, 2, tok_pad)
tst_model = tst_model.cuda()

In [None]:
x = torch.randint(0, 500, (10,5)).long()
z = tst_model(x.cuda())

In [None]:
len(z)

3

### Callbacks to train the model

In [None]:
#export
class GradientClipping(Callback):
    "Callback that clips the gradient norm of the model's parameters to `clip` in `after_backward`."
    def __init__(self, clip=None): self.clip = clip
    def after_backward(self):
        # A falsy `clip` (None or 0) turns clipping off.
        if not self.clip: return
        nn.utils.clip_grad_norm_(self.run.model.parameters(), self.clip)

In [None]:
#export
class RNNTrainer(Callback):
    "Callback that unpacks the model's extra outputs and adds the AR (`alpha`) and TAR (`beta`) penalties to the loss."
    def __init__(self, alpha, beta): self.alpha,self.beta = alpha,beta

    def after_pred(self):
        #Save the extra outputs for later and only returns the true output.
        self.raw_out,self.out = self.pred[1],self.pred[2]
        self.run.pred = self.pred[0]

    def after_loss(self):
        #AR and TAR
        # AR: penalize large activations of the last layer's output.
        if self.alpha != 0.:  self.run.loss += self.alpha * self.out[-1].float().pow(2).mean()
        if self.beta != 0.:
            h = self.raw_out[-1]
            # TAR: penalize changes between consecutive positions along dim 1. The guard
            # checks that dim (not the batch dim): with a single position the difference
            # is empty and its mean would be NaN.
            if h.size(1)>1: self.run.loss += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()

    def begin_epoch(self):
        #Shuffle the texts at the beginning of the epoch
        if hasattr(self.dl.dataset, "batchify"): self.dl.dataset.batchify()

In [None]:
#export
def cross_entropy_flat(input, target):
    "Cross-entropy between `input` and a 2-d `target`, both flattened to one row per target element."
    n_rows, n_cols = target.size()
    n_items = n_rows * n_cols
    flat_input, flat_target = input.view(n_items, -1), target.view(n_items)
    return F.cross_entropy(flat_input, flat_target)

def accuracy_flat(input, target):
    "`accuracy` of `input` against a 2-d `target`, both flattened to one row per target element."
    n_rows, n_cols = target.size()
    n_items = n_rows * n_cols
    flat_input, flat_target = input.view(n_items, -1), target.view(n_items)
    return accuracy(flat_input, flat_target)

In [None]:
emb_sz, nh, nl = 300, 300, 1
# Pass the vocab's real padding index (`tok_pad`, as for `tst_model` above) rather than a
# literal 0, so the embedding's `padding_idx` points at the PAD token.
model = get_language_model(len(vocab), emb_sz, nh, nl, tok_pad, input_p=0.6, output_p=0.4, weight_p=0.5,
                           embed_p=0.1, hidden_p=0.2)

In [None]:
cbs = [partial(AvgStatsCallback,accuracy_flat),
       CudaCallback, Recorder,
       partial(GradientClipping, clip=0.1),
       partial(RNNTrainer, alpha=2., beta=1.),
       ProgressCallback]

In [None]:
learn = Learner(model, data, cross_entropy_flat, lr=5e-3, cb_funcs=cbs, opt_func=adam_opt)

In [None]:
learn.fit(1)

## Export

In [None]:
!python notebook2script.py 12a_awd_lstm.ipynb

Converted 12a_awd_lstm.ipynb to exp/nb_12a.py
