# RNN-LSTM architecture implementation with fast.ai and PyTorch

### LSTM Architecture
![LSTM Cell](https://upload.wikimedia.org/wikipedia/commons/5/56/LSTM_cell.svg)


### Equations - LSTM with a forget gate
![LSTM equations](https://wikimedia.org/api/rest_v1/media/math/render/svg/7dee414820d5c0162ae1fff1899e58b08923944f)


**References**
* [fastbook - Chapter 12 - NLP](https://github.com/fastai/fastbook/blob/master/12_nlp_dive.ipynb)
* [Wikipedia - Long short-term memory](https://en.wikipedia.org/wiki/Long_short-term_memory)

## Setup

Download the HUMAN_NUMBERS dataset from fast.ai and preprocess it into fixed-length sequences of token ids.

In [1]:
# Install or upgrade (-U) fastai quietly (-qq); the leading `!` runs the line as a shell command from the notebook.
!pip install -Uqq fastai

[?25l[K     |█▊                              | 10kB 17.0MB/s eta 0:00:01[K     |███▌                            | 20kB 21.6MB/s eta 0:00:01[K     |█████▏                          | 30kB 10.0MB/s eta 0:00:01[K     |███████                         | 40kB 8.3MB/s eta 0:00:01[K     |████████▋                       | 51kB 4.4MB/s eta 0:00:01[K     |██████████▍                     | 61kB 4.9MB/s eta 0:00:01[K     |████████████                    | 71kB 5.2MB/s eta 0:00:01[K     |█████████████▉                  | 81kB 5.7MB/s eta 0:00:01[K     |███████████████▌                | 92kB 5.3MB/s eta 0:00:01[K     |█████████████████▎              | 102kB 5.7MB/s eta 0:00:01[K     |███████████████████             | 112kB 5.7MB/s eta 0:00:01[K     |████████████████████▊           | 122kB 5.7MB/s eta 0:00:01[K     |██████████████████████▍         | 133kB 5.7MB/s eta 0:00:01[K     |████████████████████████▏       | 143kB 5.7MB/s eta 0:00:01[K     |████████████████████████

In [2]:
from fastai.text.all import *
# Fetch the HUMAN_NUMBERS dataset; `path` is the folder whose train.txt / valid.txt are read below.
path = untar_data(URLs.HUMAN_NUMBERS)

In [3]:
# Gather the raw lines of both splits (train first, then valid) into one fastcore `L`.
lines = L()
for split_file in ('train.txt', 'valid.txt'):
    with open(path/split_file) as fh:
        lines += L(*fh.readlines())
lines

(#9998) ['one \n','two \n','three \n','four \n','five \n','six \n','seven \n','eight \n','nine \n','ten \n'...]

In [4]:
# Strip surrounding whitespace from every line and join them with ' . ' so that '.' acts as a separator token.
text = ' . '.join(line.strip() for line in lines)
text[:100]

'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'

In [5]:
# Word-level tokenization: split on single spaces.
tokens = text.split(sep=' ')
tokens[:10]

['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']

In [6]:
# Vocabulary: every distinct token, in order of first appearance.
vocab = L(tokens).unique()
vocab

(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]

In [7]:
# Encode the token stream as integer ids: a word's id is its position in `vocab`.
word2idx = {word: idx for idx, word in enumerate(vocab)}
nums = L(word2idx[tok] for tok in tokens)
nums

(#63095) [0,1,2,1,3,1,4,1,5,1...]

In [8]:
def group_chunks(ds, bs):
    """Reorder ``ds`` so that consecutive batches of size ``bs`` continue one another.

    ``ds`` is viewed as ``bs`` contiguous streams of ``len(ds) // bs`` items each.
    Output position ``i * bs + j`` holds item ``i`` of stream ``j``, so row ``j`` of
    every batch picks up where row ``j`` of the previous batch left off. The last
    ``len(ds) % bs`` items of ``ds`` are never selected.

    Returns:
        A fastcore ``L`` with ``(len(ds) // bs) * bs`` items.
    """
    n_batches = len(ds) // bs
    return L(ds[i + n_batches * j] for i in range(n_batches) for j in range(bs))

In [9]:
bs = 64  # batch size; the model classes below also read this global
sl = 16  # tokens per training sequence

# Non-overlapping windows of `sl` token ids; each target is its input shifted one token ahead.
starts = range(0, len(nums) - sl - 1, sl)
seqs = L((tensor(nums[s:s + sl]), tensor(nums[s + 1:s + 1 + sl])) for s in starts)

# First 80% of the windows for training, the rest for validation. group_chunks lays
# each split out so that rows of consecutive batches continue one another. shuffle
# stays off to keep that order. drop_last keeps every batch exactly `bs` rows.
cut = int(len(seqs) * 0.8)
train_ds = group_chunks(seqs[:cut], bs)
valid_ds = group_chunks(seqs[cut:], bs)
dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=bs, drop_last=True, shuffle=False)

## Model with PyTorch's built-in `nn.LSTM` module

This serves as a baseline to compare against our own LSTM implementation below.

In [15]:
class LMModel(Module):
    """Baseline language model built on PyTorch's ``nn.LSTM``.

    Token ids are embedded and run through an ``n_layers``-deep LSTM (batch_first).
    Every time step's output is then projected back onto the vocabulary. The (h, c)
    state is carried across batches and detached after each forward pass
    (truncated backpropagation through time).
    """
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # Initial [h, c] pair; sized with the global batch size `bs`.
        self.h = [torch.zeros(n_layers, bs, n_hidden),
                  torch.zeros(n_layers, bs, n_hidden)]

    def forward(self, x):
        """Return per-step vocabulary logits for the token ids ``x``."""
        embedded = self.i_h(x)
        out, new_state = self.rnn(embedded, self.h)
        # Keep the state for the next batch but cut it out of this batch's graph.
        self.h = [s.detach() for s in new_state]
        return self.h_o(out)

    def reset(self):
        """Zero the carried [h, c] state in place."""
        for state in self.h:
            state.zero_()

In [16]:
# Train the nn.LSTM baseline (64 hidden units, 2 layers). The ModelResetter callback
# calls the model's reset() (presumably at training/validation phase boundaries;
# see the fastai callback docs).
baseline = LMModel(len(vocab), n_hidden=64, n_layers=2)
learn = Learner(dls, baseline, loss_func=CrossEntropyLossFlat(),
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,3.055825,2.775102,0.171794,00:02
1,2.221072,1.764197,0.411377,00:02
2,1.629837,1.762449,0.462646,00:02
3,1.321504,2.075851,0.496908,00:02
4,1.096306,2.156112,0.515137,00:02
5,0.858027,2.050146,0.545166,00:02
6,0.638046,1.777854,0.598958,00:02
7,0.454945,1.834544,0.633952,00:02
8,0.296643,1.781012,0.672689,00:02
9,0.184759,1.823402,0.697591,00:02


## Model with custom LSTM implementation

`LSTMCell` from fastbook, modified to loop over every time step of the input sequence and return the stacked hidden states of all steps.

In [17]:
class LSTMCell(Module):
    """LSTM layer (adapted from fastbook's ``LSTMCell``) unrolled over a whole sequence.

    Each of the four gates is a separate ``nn.Linear`` applied to the concatenation
    ``[h, x_t]`` of the previous hidden state and the current input step.

    Args:
        ni: size of each input step (last dimension of ``input``).
        nh: size of the hidden and cell states.
    """
    def __init__(self, ni, nh):
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        """Run the cell over every time step of ``input``.

        Args:
            input: tensor indexed as ``input[:, t, :]``; presumably
                (batch, seq_len, ni). TODO confirm against callers.
            state: pair ``(h, c)``, each presumably (batch, nh).

        Returns:
            ``(outputs, (h, c))``. ``outputs`` stacks the hidden state of every step
            along dim 1. ``(h, c)`` is the state after the last step.
        """
        h, c = state
        outputs = []
        # Use the real sequence length of `input` rather than the notebook-global
        # `sl`, so the cell works for any sequence length.
        for t in range(input.shape[1]):
            # Separate name for the concatenation so `h` always means the hidden state.
            hx = torch.cat([h, input[:, t, :]], dim=1)
            forget = torch.sigmoid(self.forget_gate(hx))
            inp = torch.sigmoid(self.input_gate(hx))
            cell = torch.tanh(self.cell_gate(hx))
            out = torch.sigmoid(self.output_gate(hx))
            # Forget part of the old cell state, then add the gated candidate.
            c = c * forget + inp * cell
            h = out * torch.tanh(c)
            outputs.append(h)
        return torch.stack(outputs, dim=1), (h, c)

In [18]:
class LMModelX(Module):
    """Language model that uses the hand-written single-layer ``LSTMCell``.

    Args:
        vocab_sz: number of tokens in the vocabulary.
        n_hidden: embedding width and LSTM hidden size.
    """
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # The cell consumes embeddings, so its input size is n_hidden, not the batch size.
        self.rnn = LSTMCell(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # [h, c] state carried between batches. It is sized with the global batch size
        # `bs`, so every batch must have exactly `bs` rows (the cell concatenates
        # h with each input step along dim 1).
        self.h = [torch.zeros(bs, n_hidden) for _ in range(2)]

    def forward(self, x):
        """Return per-step vocabulary logits for the token ids ``x``.

        The new state is detached before being stored, so gradients do not flow
        into previous batches (truncated backpropagation through time).
        """
        res, h = self.rnn(self.i_h(x), self.h)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(res)

    def reset(self):
        """Zero the carried [h, c] state in place."""
        for h in self.h: h.zero_()

In [19]:
# Train the custom-LSTM model with the same loss, metric, schedule and learning
# rate as the nn.LSTM baseline.
custom = LMModelX(len(vocab), n_hidden=64)
learn = Learner(dls, custom, loss_func=CrossEntropyLossFlat(),
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.956665,2.52278,0.299479,00:01
1,1.989015,1.970153,0.318034,00:01
2,1.55081,1.760146,0.480876,00:01
3,1.289587,1.851735,0.520426,00:01
4,1.06614,1.699347,0.603109,00:01
5,0.85279,1.740918,0.633057,00:01
6,0.643045,1.649858,0.700521,00:01
7,0.441365,1.690009,0.707275,00:01
8,0.279773,1.612885,0.772868,00:01
9,0.170428,1.578884,0.778809,00:01
