# AWD-LSTM

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
#export
from exp.nb_12 import *

## Data

In [None]:
# Get the IMDB dataset via `datasets.untar_data`; `path` is the local folder it was extracted to.
path = datasets.untar_data(datasets.URLs.IMDB)

We have to preprocess the data again before pickling it: if we try to unpickle the `ll` saved in the previous notebook, pickle will complain that the functions it references aren't defined in `__main__`.

In [None]:
# Language-model corpus: every text file from the train, test and unsup folders.
il = TextList.from_files(path, include=['train', 'test', 'unsup'])
# Randomly hold out (about) 10% of the texts for validation.
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))

In [None]:
# A language model needs no labels: every text gets the dummy label 0, while the texts
# themselves are tokenized and then numericalized.
ll = label_by_func(sd, lambda x: 0, proc_x = [TokenizeProcessor(), NumericalizeProcessor()])

In [None]:
# Cache the processed data; `with` makes sure the file is flushed and closed.
with open(path/'ll_lm.pkl', 'wb') as f: pickle.dump(ll, f)

In [None]:
# Reload the cached data. Only unpickle files you created yourself: pickle.load can run arbitrary code.
with open(path/'ll_lm.pkl', 'rb') as f: ll = pickle.load(f)

In [None]:
# Batch size and back-propagation-through-time length (tokens per training sequence).
bs,bptt = 64,70
data = lm_databunchify(ll, bs, bptt)

In [None]:
# Vocabulary built by the NumericalizeProcessor (second entry of `proc_x` above).
vocab = ll.train.x.processors[1].vocab

## AWD-LSTM

### LSTM

In [None]:
class LSTMCell(nn.Module):
    "A single LSTM step written out by hand, computing all four gates with one matmul per input."
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = nn.Parameter(torch.empty(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(4 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size))
        self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize every weight from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)), as `nn.LSTM` does.

        An unscaled standard-normal init gives gate pre-activations with a std of roughly
        sqrt(input_size + hidden_size), which saturates the sigmoids/tanh and kills gradients.
        """
        stdv = self.hidden_size ** -0.5
        for w in self.parameters(): nn.init.uniform_(w, -stdv, stdv)

    def forward(self, input, state):
        """One time step.

        input: `(bs, input_size)`; state: `(hx, cx)`, each `(bs, hidden_size)`.
        Returns `(hy, (hy, cy))` so it can be unrolled by `LSTMLayer`.
        """
        hx, cx = state
        #One big multiplication for all the gates is better than 4 smaller ones
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        # Gate order (input, forget, cell, output) matches nn.LSTM's weight layout.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)

In [None]:
class LSTMLayer(nn.Module):
    "Unrolls a recurrent `cell` along dimension 1 of its input, threading the state through every step."
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    def forward(self, input, state):
        """Apply the cell to each slice `input[:, t]` in order.

        Returns the per-step outputs stacked back along dimension 1, and the final state.
        """
        step_outputs = []
        for step_input in input.unbind(1):
            step_out, state = self.cell(step_input, state)
            step_outputs.append(step_out)
        return torch.stack(step_outputs, dim=1), state

In [None]:
# Hand-written LSTM unrolled in a Python loop: 300 input features -> 300 hidden units.
lstm = LSTMLayer(LSTMCell, 300, 300)

In [None]:
# 64 sequences of 70 steps with 300 features, and zero initial (h, c) states of shape (bs, hidden).
x = torch.randn(64, 70, 300)
h = (torch.zeros(64, 300),torch.zeros(64, 300))

In [None]:
# Time the Python-loop LSTM on CPU.
%timeit -n 10 y,h1 = lstm(x,h)

103 ms ± 611 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Move the layer, the input and both states to the GPU.
lstm = lstm.cuda()
x = x.cuda()
h = (h[0].cuda(), h[1].cuda())

In [None]:
%timeit -n 10 y,h1 = lstm(x,h)

25.8 ms ± 5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# PyTorch's built-in LSTM for comparison: 1 layer, batch-first like our hand-written layer.
lstm = nn.LSTM(300, 300, 1, batch_first=True)

In [None]:
# nn.LSTM takes states of shape (num_layers, bs, hidden), hence the extra leading 1.
x = torch.randn(64, 70, 300)
h = (torch.zeros(1, 64, 300),torch.zeros(1, 64, 300))

In [None]:
# Time nn.LSTM on CPU.
%timeit -n 10 y,h1 = lstm(x,h)

104 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Move nn.LSTM, the input and both states to the GPU.
lstm = lstm.cuda()
x = x.cuda()
h = (h[0].cuda(), h[1].cuda())

In [None]:
%timeit -n 10 y,h1 = lstm(x,h)

The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached.
2.76 ms ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


JIT (TorchScript) version

In [None]:
import torch.jit as jit

In [None]:
class LSTMCell(jit.ScriptModule):
    "TorchScript version of the hand-written LSTM cell: same parameters and math, compiled `forward`."
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # NOTE(review): unscaled standard-normal init is fine for timing, but it saturates the gates
        # if this cell is ever trained (nn.LSTM uses U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))).
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = nn.Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        # TorchScript reads the signature comment above; keep it directly under the def line.
        hx, cx = state
        # One matmul per input produces all four gates at once.
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        # Output and new (h, c) state; the output is the new hidden state.
        return hy, (hy, cy)

In [None]:
class LSTMLayer(jit.ScriptModule):
    "TorchScript version of `LSTMLayer`: unrolls `cell` along dimension 1 of the input."
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        # The loop is compiled by TorchScript, which only accepts a restricted Python subset.
        inputs = input.unbind(1)
        outputs = []
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        # Per-step outputs stacked back along dimension 1, plus the final (h, c) state.
        return torch.stack(outputs, dim=1), state

In [None]:
# Scripted layer built from the scripted cell (both classes were redefined above).
lstm = LSTMLayer(LSTMCell, 300, 300)

In [None]:
# Same benchmark input and zero (h, c) states as for the Python version.
x = torch.randn(64, 70, 300)
h = (torch.zeros(64, 300),torch.zeros(64, 300))

In [None]:
# Time the scripted LSTM on CPU.
%timeit -n 10 y,h1 = lstm(x,h)

98.9 ms ± 3.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Move the scripted layer, the input and both states to the GPU.
lstm = lstm.cuda()
x = x.cuda()
h = (h[0].cuda(), h[1].cuda())

In [None]:
%timeit -n 10 y,h1 = lstm(x,h)

The slowest run took 6.56 times longer than the fastest. This could mean that an intermediate result is being cached.
12 ms ± 13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Dropout

We want to use the AWD-LSTM from [Stephen Merity et al.](https://arxiv.org/abs/1708.02182). First, we'll need several different kinds of dropout. Dropout consists of replacing some coefficients by 0 with probability p. To keep the expected value of each coefficient unchanged, we multiply the coefficients that aren't zeroed by a factor `1/(1-p)`.

In [None]:
#export
def dropout_mask(x, sz, p):
    """Return a dropout mask of shape `sz` with the same dtype/device as `x`.

    Each entry is 0 with probability `p` and `1/(1-p)` otherwise, so multiplying by the mask
    keeps the expected value unchanged. For `p == 1` every entry is 0; the generic formula
    would compute 0/0 and return NaNs.
    """
    if p == 1: return x.new_zeros(sz)
    return x.new_empty(sz).bernoulli_(1-p).div_(1-p)

In [None]:
# Demo: with p=0.5 every entry of the mask is either 0 or 1/(1-0.5) = 2.
x = torch.randn(10,10)
dropout_mask(x, (10,10), 0.5)

tensor([[0., 2., 2., 0., 0., 0., 0., 0., 2., 0.],
        [2., 2., 2., 2., 2., 0., 2., 2., 2., 2.],
        [0., 0., 0., 2., 2., 0., 0., 0., 0., 0.],
        [2., 0., 0., 2., 0., 2., 0., 2., 0., 2.],
        [0., 2., 0., 2., 0., 0., 2., 2., 0., 2.],
        [0., 0., 0., 2., 2., 0., 2., 2., 2., 0.],
        [0., 0., 0., 2., 0., 0., 0., 2., 0., 2.],
        [2., 0., 2., 0., 2., 0., 2., 0., 2., 0.],
        [0., 2., 0., 0., 0., 2., 0., 0., 0., 0.],
        [2., 2., 2., 0., 2., 0., 0., 2., 0., 0.]])

Once we have a dropout mask `m`, applying the dropout to `x` is simply done by `x = x * m`. We create our own dropout mask and don't rely on PyTorch's dropout because we want a different mask for each sequence of the batch but the same mask along the sequence dimension (i.e. the same coefficients are replaced by zero for each word of a sentence).

Inside the RNN (whose LSTMs are built with `batch_first=True`), a tensor `x` has three dimensions: `bs, seq_len, n_features`, so we create a dropout mask of size `(bs, 1, n_features)` and broadcast it along the sequence dimension.

In [None]:
#export
class RNNDropout(nn.Module):
    """Dropout that reuses the same mask at every time step.

    Expects batch-first input `(bs, seq_len, n_features)`, the layout produced by the
    `batch_first=True` LSTMs of the model. One mask is drawn per sequence (dim 0) and feature
    (dim 2) and broadcast along the sequence (dim 1), so a zeroed coefficient stays zeroed for
    every word of that sentence.
    """
    def __init__(self, p=0.5):
        super().__init__()
        self.p=p

    def forward(self, x):
        if not self.training or self.p == 0.: return x
        m = dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)
        return x * m

In [None]:
# Demo on a (3, 3, 7) tensor; the module is in training mode by default, so dropout is active.
dp = RNNDropout(0.3)
tst_input = torch.randn(3,3,7)
tst_input, dp(tst_input)

(tensor([[[-0.4715, -1.3210,  0.5567, -0.1550,  0.6257, -0.5021,  1.2980],
          [-1.3992, -2.8178, -0.3803,  0.0100,  1.6873,  0.5790, -1.8532],
          [-0.1775,  1.2794,  1.8237, -0.7937, -1.5842,  0.5972, -1.0857]],
 
         [[-1.8920, -0.9604, -0.6604, -1.0397, -0.9365,  0.0288, -0.8288],
          [ 0.9694,  2.8095,  0.1415, -0.1870, -0.6186, -0.2414,  0.1933],
          [-0.9919, -1.3077, -0.1721, -1.8967,  0.8264,  0.6205,  0.9652]],
 
         [[-1.5855, -0.9235,  0.8438, -0.0199, -0.1605,  1.0278, -0.7152],
          [ 2.4180, -0.9915, -1.4827, -0.4177,  0.3145,  0.0549,  0.7595],
          [ 0.1461, -0.9218,  0.0850,  0.1781, -0.0421,  0.4008,  0.0343]]]),
 tensor([[[-0.0000, -0.0000,  0.0000, -0.0000,  0.8938, -0.0000,  1.8543],
          [-1.9989, -4.0254, -0.5434,  0.0142,  0.0000,  0.0000, -0.0000],
          [-0.2536,  0.0000,  0.0000, -1.1338, -2.2632,  0.8532, -1.5511]],
 
         [[-0.0000, -0.0000, -0.0000, -0.0000, -1.3378,  0.0000, -1.1840],
          [ 1

Dropout applied to the weights of the inner LSTM (by default only the hidden-to-hidden weights `weight_hh_l0`). This is a little hacky.

In [None]:
#export
import warnings

class WeightDropout(nn.Module):
    """Wraps `module` and applies dropout to some of its weights at every forward pass.

    Each weight named in `layer_names` is moved onto this wrapper as the trainable parameter
    `<name>_raw`. Before each call to the wrapped module, a copy of it is installed back on
    `module` under the original name: dropped out with probability `weight_p` in training mode,
    unchanged in eval mode.
    """
    def __init__(self, module, weight_p=0., layer_names=('weight_hh_l0',)):
        super().__init__()
        # list(...) so instances never share (or mutate) the default sequence.
        self.module,self.weight_p,self.layer_names = module,weight_p,list(layer_names)
        for layer in self.layer_names:
            w = getattr(self.module, layer)
            # The raw weight becomes a parameter of the wrapper (sharing the original storage) ...
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            # ... and is unregistered from the wrapped module, so it is neither optimized nor saved
            # twice, and so that a plain (non-Parameter) tensor can be stored under its name.
            delattr(self.module, layer)
            setattr(self.module, layer, w.data.clone())
        if isinstance(self.module, nn.RNNBase):
            # cuDNN weight flattening would rewrite the storage of the installed tensors; they are
            # replaced on every forward anyway, so make it a no-op.
            self.module.flatten_parameters = self._do_nothing

    def _do_nothing(self): return

    def _setweights(self):
        "Install a fresh copy of each raw weight on the wrapped module (dropped out in training)."
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            if self.training: w = F.dropout(raw_w, p=self.weight_p, training=True)
            else: w = raw_w.clone()
            # F.dropout can hand back `raw_w` itself (e.g. p=0); never install the Parameter.
            if isinstance(w, nn.Parameter): w = w.clone()
            # setattr (rather than writing module._parameters) also goes through nn.RNNBase's
            # __setattr__, which in recent PyTorch versions keeps its cached weight list in sync.
            setattr(self.module, layer, w)

    def forward(self, *args):
        self._setweights()
        with warnings.catch_warnings():
            #To avoid the warning that comes because the weights aren't flattened.
            warnings.simplefilter("ignore")
            return self.module.forward(*args)

In [None]:
# Wrap a small LSTM (5 inputs, 2 hidden units) so its default layer `weight_hh_l0` gets dropout p=0.4.
module = nn.LSTM(5, 2)
dp_module = WeightDropout(module, 0.4)
getattr(dp_module.module, 'weight_hh_l0')

Parameter containing:
tensor([[-0.0351,  0.0559],
        [-0.1025,  0.1650],
        [ 0.6003,  0.0238],
        [-0.3934, -0.5116],
        [-0.6447,  0.6633],
        [ 0.2746,  0.6176],
        [ 0.1577,  0.1489],
        [-0.1887, -0.5948]], requires_grad=True)

It's at the beginning of a forward pass that the dropout is applied to the weights.

In [None]:
# A forward pass (seq_len 4, batch 20, 5 features: nn.LSTM defaults to batch_first=False)
# installs a freshly dropped-out copy of the weights on the wrapped module.
tst_input = torch.randn(4,20,5)
h = (torch.zeros(1,20,2), torch.zeros(1,20,2))
x,h = dp_module(tst_input,h)
getattr(dp_module.module, 'weight_hh_l0')

tensor([[-0.0584,  0.0000],
        [-0.1708,  0.2749],
        [ 1.0005,  0.0000],
        [-0.6557, -0.8526],
        [-1.0746,  0.0000],
        [ 0.0000,  1.0294],
        [ 0.0000,  0.0000],
        [-0.3146, -0.9914]], grad_fn=<MulBackward0>)

Dropout applied to full rows of the embedding matrix.

In [None]:
#export
class EmbeddingDropout(nn.Module):
    "Applies dropout to whole rows of the embedding matrix, i.e. drops entire word vectors."

    def __init__(self, emb, embed_p):
        super().__init__()
        self.emb,self.embed_p = emb,embed_p
        # nn.Embedding already stores padding_idx as None or a non-negative index. Passing a -1
        # sentinel to F.embedding would make it the *last* row, silently freezing that word.
        self.pad_idx = self.emb.padding_idx

    def forward(self, words, scale=None):
        """Embed `words` with a row-wise dropped-out weight (training mode only).

        `scale`, when truthy, multiplies the (masked) weight before the lookup.
        """
        if self.training and self.embed_p != 0:
            # One mask value per vocabulary row, broadcast across the embedding dimension.
            size = (self.emb.weight.size(0),1)
            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)
            masked_embed = self.emb.weight * mask
        else: masked_embed = self.emb.weight
        # Out of place: in eval mode `masked_embed` *is* the embedding weight and must not be mutated.
        if scale: masked_embed = masked_embed * scale
        return F.embedding(words, masked_embed, self.pad_idx, self.emb.max_norm,
                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)

In [None]:
# 100 words of size 7 (index 1 is padding); in training mode each row is dropped with probability 0.5.
enc = nn.Embedding(100, 7, padding_idx=1)
enc_dp = EmbeddingDropout(enc, 0.5)
tst_input = torch.randint(0,100,(8,))
enc_dp(tst_input)

tensor([[-0.8062,  1.9744, -1.3620, -0.2249, -1.4037,  2.3976, -2.9599],
        [ 0.5693, -0.7265, -0.2522,  2.9447, -0.1223, -2.0842,  1.9399],
        [ 2.7000, -2.5391,  1.0501, -5.4233, -2.7581, -0.5289, -0.0114],
        [ 1.5184, -2.5519,  0.9132,  1.0460, -2.2969,  0.1667, -1.7071],
        [ 0.3400, -0.1362, -2.3155,  0.6989, -3.4174,  0.8296, -0.4206],
        [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
        [-0.0501, -1.4007,  2.2240, -0.6240,  2.2514,  1.2842, -1.4562],
        [-0.3684, -1.1094, -2.0388, -1.3518,  0.4788,  2.6264,  4.7393]],
       grad_fn=<EmbeddingBackward>)

### Main model

In [None]:
#export
def to_detach(h):
    """Detaches `h` from its history.

    `h` is a tensor or an arbitrarily nested list/tuple of tensors; nested containers come back
    as tuples. `isinstance` (not an exact type check) so tensor subclasses such as
    `nn.Parameter` are detached as a whole instead of being iterated row by row.
    """
    if isinstance(h, torch.Tensor): return h.detach()
    return tuple(to_detach(v) for v in h)

In [None]:
#export
class AWD_LSTM(nn.Module):
    "AWD-LSTM inspired by https://arxiv.org/abs/1708.02182."
    initrange=0.1

    def __init__(self, vocab_sz, emb_sz, n_hid, n_layers, pad_token,
                 hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):
        """Build the embedding, `n_layers` weight-dropped batch-first LSTMs and the dropouts.

        vocab_sz, emb_sz: vocabulary size and size of each word vector.
        n_hid: hidden size of every LSTM except the last, whose hidden size is `emb_sz`
               (so a decoder can share the embedding weights).
        pad_token: passed to `nn.Embedding` as `padding_idx`.
        hidden_p, input_p, embed_p, weight_p: dropout probabilities for the outputs between
               layers, the embedded input, whole embedding rows and the wrapped LSTM weights.
        """
        super().__init__()
        self.bs,self.emb_sz,self.n_hid,self.n_layers = 1,emb_sz,n_hid,n_layers
        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)
        self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz), 1,
                             batch_first=True) for l in range(n_layers)]
        self.rnns = nn.ModuleList([WeightDropout(rnn, weight_p) for rnn in self.rnns])
        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.input_dp = RNNDropout(input_p)
        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])
        # Hidden states are created lazily by `forward` (through `reset`), on the parameters' device.
        self.hidden = None

    def forward(self, input):
        """Run token ids `input` of shape `(bs, sl)` through every layer.

        Returns `(raw_outputs, outputs)`: the per-layer LSTM outputs before and after the
        between-layer dropout (the last layer's output is not dropped). The final hidden states
        are kept, detached, in `self.hidden` and reused by the next call.
        """
        bs,sl = input.size()
        # (Re)build the hidden state on the first call (even when bs == 1), when the batch size
        # changes, or when the stored state is on another device than the input (e.g. after `.cuda()`).
        if self.hidden is None or bs!=self.bs or self.hidden[0][0].device != input.device:
            self.bs=bs
            self.reset()
        raw_output = self.input_dp(self.encoder_dp(input))
        new_hidden,raw_outputs,outputs = [],[],[]
        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
            raw_output, new_h = rnn(raw_output, self.hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)
            outputs.append(raw_output)
        # Truncated BPTT: keep the state values but cut the graph between batches.
        self.hidden = to_detach(new_hidden)
        return raw_outputs, outputs

    def _one_hidden(self, l):
        "Return one zero hidden state `(1, bs, nh)` for layer `l`, on the parameters' device/dtype."
        nh = self.n_hid if l != self.n_layers - 1 else self.emb_sz
        return next(self.parameters()).new_zeros(1, self.bs, nh)

    def reset(self):
        "Reset the hidden states to zeros: one `(h, c)` pair per layer."
        self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]

In [None]:
#export
class LinearDecoder(nn.Module):
    "Linear head placed on top of an AWD-LSTM: maps the last layer's outputs to `n_out` scores."
    initrange=0.1

    def __init__(self, n_out, n_hid, output_p, tie_encoder=None, bias=True):
        """n_hid -> n_out linear layer with uniform init, zero bias, and optional weight tying.

        When `tie_encoder` is given, the decoder reuses its `weight` (weight tying).
        """
        super().__init__()
        lin = nn.Linear(n_hid, n_out, bias=bias)
        lin.weight.data.uniform_(-self.initrange, self.initrange)
        self.decoder = lin
        self.output_dp = RNNDropout(output_p)
        if bias: lin.bias.data.zero_()
        if tie_encoder: lin.weight = tie_encoder.weight

    def forward(self, input):
        """Decode the `(raw_outputs, outputs)` pair returned by the encoder.

        The last layer's output `(bs, sl, n_hid)` is dropped out, flattened to `(bs*sl, n_hid)`
        and projected; returns `(decoded, raw_outputs, outputs)`.
        """
        raw_outputs, outputs = input
        last = self.output_dp(outputs[-1]).contiguous()
        n_rows = last.size(0) * last.size(1)
        decoded = self.decoder(last.view(n_rows, last.size(2)))
        return decoded, raw_outputs, outputs

In [None]:
#export
class SequentialRNN(nn.Sequential):
    "An `nn.Sequential` whose `reset()` is forwarded to every direct child that defines one."
    def reset(self):
        for child in self.children():
            if not hasattr(child, 'reset'): continue
            child.reset()

In [None]:
#export
def get_language_model(vocab_sz, emb_sz, n_hid, n_layers, pad_token, tie_weights=True, bias=True,
                       output_p=0.4, hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5):
    """Build a full AWD-LSTM language model: `AWD_LSTM` encoder followed by a `LinearDecoder`.

    With `tie_weights`, the decoder shares the encoder's embedding matrix.
    """
    encoder = AWD_LSTM(vocab_sz, emb_sz, n_hid=n_hid, n_layers=n_layers, pad_token=pad_token,
                       hidden_p=hidden_p, input_p=input_p, embed_p=embed_p, weight_p=weight_p)
    tied = encoder.encoder if tie_weights else None
    decoder = LinearDecoder(vocab_sz, emb_sz, output_p, tie_encoder=tied, bias=bias)
    return SequentialRNN(encoder, decoder)

In [None]:
# Tiny model: vocab 500, embeddings 20, 100 hidden units, 2 layers, pad token 0; moved to the GPU.
tst_model = get_language_model(500, 20, 100, 2, 0)
tst_model = tst_model.cuda()

In [None]:
# A batch of 10 sequences of 5 token ids.
x = torch.randint(0, 500, (10,5)).long()
z = tst_model(x.cuda())

In [None]:
# The decoder returns (decoded, raw_outputs, outputs).
len(z)

3

### Callbacks to train the model

In [None]:
#export
class GradientClipping(Callback):
    "After each backward pass, clips the model's global gradient norm to `clip` (skipped when `clip` is falsy)."
    def __init__(self, clip=None): self.clip = clip
    def after_backward(self):
        if not self.clip: return
        nn.utils.clip_grad_norm_(self.run.model.parameters(), self.clip)

In [None]:
#export
class RNNTrainer(Callback):
    """Splits the AWD-LSTM model output and adds AR/TAR regularization to the loss.

    alpha: weight of AR (activation regularization) on the last layer's output.
    beta: weight of TAR (temporal activation regularization) on the last layer's raw output.
    """
    def __init__(self, alpha, beta): self.alpha,self.beta = alpha,beta

    def after_pred(self):
        # The model returns (decoded, raw_outputs, outputs): keep the extras for the regularizers
        # and only pass the decoded predictions on to the loss function.
        self.raw_out,self.out = self.pred[1],self.pred[2]
        self.run.pred = self.pred[0]

    def after_loss(self):
        # AR: penalize large activations of the last layer's output.
        if self.alpha != 0.:  self.run.loss += self.alpha * self.out[-1].float().pow(2).mean()
        # TAR: penalize changes between consecutive time steps of the last layer's raw output.
        if self.beta != 0.:
            h = self.raw_out[-1]
            # h is batch-first (bs, seq_len, n_hid), as produced by AWD_LSTM. The time difference
            # needs at least two steps; otherwise h[:,1:] is empty and its mean is NaN.
            if h.size(1) > 1: self.run.loss += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()

    def begin_epoch(self):
        #Shuffle the texts at the beginning of the epoch
        if hasattr(self.dl.dataset, "batchify"): self.dl.dataset.batchify()

In [None]:
#export
def cross_entropy_flat(input, target):
    "Cross-entropy between `input` and a `(bs, sl)` target, both flattened to `bs*sl` rows."
    bs, sl = target.shape
    n_tokens = bs * sl
    flat_input = input.view(n_tokens, -1)
    flat_target = target.view(n_tokens)
    return F.cross_entropy(flat_input, flat_target)

def accuracy_flat(input, target):
    "`accuracy` of `input` against a `(bs, sl)` target, both flattened to `bs*sl` rows."
    bs, sl = target.shape
    n_tokens = bs * sl
    return accuracy(input.view(n_tokens, -1), target.view(n_tokens))

In [None]:
# Single-layer model with 300-dim embeddings and hidden state.
# NOTE(review): pad_token=0 — confirm this matches the padding index in `vocab`.
emb_sz, nh, nl = 300, 300, 1
model = get_language_model(len(vocab), emb_sz, nh, nl, 0, input_p=0.6, output_p=0.4, weight_p=0.5, 
                           embed_p=0.1, hidden_p=0.2)

In [None]:
# Callbacks: flat-accuracy stats, GPU, recording, gradient clipping at 0.1, and AR/TAR (alpha=2, beta=1).
cbs = [partial(AvgStatsCallback,accuracy_flat),
       CudaCallback,
       Recorder,
       partial(GradientClipping, clip=0.1),
       partial(RNNTrainer, alpha=2., beta=1.),
       ProgressCallback]

In [None]:
# Adam with lr 5e-3 and the flattened cross-entropy loss.
learn = Learner(model, data, cross_entropy_flat, lr=5e-3, cb_funcs=cbs, opt_func=adam_opt)

In [None]:
# Train for one epoch.
learn.fit(1)

## Export

In [None]:
# Export the `#export` cells of this notebook to a Python module.
!python notebook2script.py 12a_awd_lstm.ipynb

Converted 12a_awd_lstm.ipynb to exp/nb_12a.py
