# RNN-GRU architecture implementation with fast.ai and PyTorch

### GRU Architecture
![GRU Cell](https://upload.wikimedia.org/wikipedia/commons/3/37/Gated_Recurrent_Unit%2C_base_type.svg)


### Equations - GRU - fully gated version
![GRU equations](https://wikimedia.org/api/rest_v1/media/math/render/svg/56d278fc80bd8febad40b3550de6d77e883e9c0b)


**References**
* [fastbook - Chapter 12 - NLP](https://github.com/fastai/fastbook/blob/master/12_nlp_dive.ipynb)
* [Wikipedia - Gated recurrent unit](https://en.wikipedia.org/wiki/Gated_recurrent_unit)

## Setup

Download and preprocess the HUMAN_NUMBERS dataset from fast.ai.

In [1]:
!pip install -Uqq fastai

[K     |████████████████████████████████| 194kB 5.7MB/s 
[K     |████████████████████████████████| 61kB 5.3MB/s 
[?25h

In [2]:
from fastai.text.all import *
# Fetch the HUMAN_NUMBERS dataset; `path` is the local folder holding
# train.txt and valid.txt, which the next cell reads.
path = untar_data(URLs.HUMAN_NUMBERS)

In [3]:
# Gather every line of both splits into a single list, train first, then valid.
# The train/valid split is redone later, after tokenization.
lines = L()
for split_file in ('train.txt', 'valid.txt'):
    with open(path/split_file) as f: lines += L(*f.readlines())
lines

(#9998) ['one \n','two \n','three \n','four \n','five \n','six \n','seven \n','eight \n','nine \n','ten \n'...]

In [4]:
# Strip the trailing whitespace/newline from each line and chain all lines into
# one stream, with ' . ' marking the boundary between consecutive numbers.
text = ' . '.join(map(str.strip, lines))
text[:100]

'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'

In [5]:
# Tokenize on single spaces; the '.' separators become tokens of their own.
tokens = text.split(' ')
tokens[:10]

['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']

In [6]:
# Vocabulary: the distinct tokens found in the stream.
vocab = L(*tokens).unique()
vocab

(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]

In [7]:
# Numericalize: give each vocabulary word its position in `vocab`, then
# replace every token of the stream by that integer id.
word2idx = {word: idx for idx, word in enumerate(vocab)}
nums = L(word2idx[tok] for tok in tokens)
nums

(#63095) [0,1,2,1,3,1,4,1,5,1...]

In [8]:
def group_chunks(ds, bs):
    """Reorder `ds` so that consecutive batches of size `bs` continue each other.

    The dataset is treated as `bs` contiguous streams of `len(ds) // bs` items
    each. Output position `step*bs + stream` holds `ds[step + stream_len*stream]`,
    so when the result is loaded unshuffled with batch size `bs`, row `stream`
    of one batch is immediately followed by the next item of the same stream
    in the next batch. A model that carries hidden state across batches
    relies on this ordering.

    The trailing `len(ds) % bs` items are dropped. Returns a new `L`.
    """
    stream_len = len(ds) // bs
    return L(ds[step + stream_len*stream]
             for step in range(stream_len)
             for stream in range(bs))

In [9]:
bs, sl = 64, 16   # batch size, sequence length (tokens per sample)

# Each sample is (window of `sl` tokens, the same window shifted right by one),
# i.e. the target at every position is the next token. Windows do not overlap.
window_starts = range(0, len(nums)-sl-1, sl)
seqs = L((tensor(nums[s:s+sl]), tensor(nums[s+1:s+sl+1])) for s in window_starts)

# First 80% of the windows for training, the rest for validation.
cut = int(len(seqs) * 0.8)
train_ds = group_chunks(seqs[:cut], bs)
valid_ds = group_chunks(seqs[cut:], bs)

# shuffle=False and drop_last=True keep the ordering produced by group_chunks
# intact and every batch at exactly `bs` rows.
dls = DataLoaders.from_dsets(train_ds, valid_ds,
                             bs=bs, drop_last=True, shuffle=False)

## Model with built-in GRU function

A baseline built on PyTorch's built-in `nn.GRU`, used to compare against our own GRU implementation.

In [None]:
class LMModel_GRU(Module):
    """Stateful language model: embedding -> stacked `nn.GRU` -> linear decoder.

    The hidden state is stored on the instance and carried from one batch to
    the next (detached, so gradients do not flow across batches). Call
    `reset()` to zero it, e.g. through the `ModelResetter` callback.

    Args:
        vocab_sz: number of distinct tokens (embedding rows / output logits).
        n_hidden: embedding width and GRU hidden size.
        n_layers: number of stacked GRU layers.
    """
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.GRU(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # NOTE: sized from the notebook-level batch size `bs`; every batch fed
        # to the model must therefore have exactly `bs` rows.
        self.h = torch.zeros(n_layers, bs, n_hidden)

    def forward(self, x):
        """Return next-token logits for every position of the token-id batch `x`."""
        # `self.h` is a plain attribute, not a parameter or buffer, so it does
        # not follow the model when it is moved to another device. Align it
        # with the input here (a no-op when it is already there).
        h0 = self.h.to(x.device)
        outs, h = self.rnn(self.i_h(x), h0)
        # Keep the values for the next batch but cut the autograd history.
        self.h = h.detach()
        return self.h_o(outs)

    def reset(self):
        """Zero the carried hidden state in place."""
        self.h.zero_()

In [None]:
# Baseline: 2-layer nn.GRU model with 64 hidden units.
# ModelResetter is the callback that invokes the model's `reset()` method.
# NOTE(review): the call below requests 15 epochs, but the recorded output
# lists only epochs 0-9 - it may come from an earlier run; confirm.
learn = Learner(dls, LMModel_GRU(len(vocab), 64, 2), 
                loss_func=CrossEntropyLossFlat(),
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.885426,2.543878,0.428223,00:01
1,2.036111,2.203238,0.358317,00:01
2,1.48737,1.974242,0.450521,00:01
3,1.047475,1.728969,0.583333,00:01
4,0.614403,1.36007,0.7146,00:01
5,0.321952,1.097459,0.762533,00:01
6,0.165975,1.264157,0.806722,00:01
7,0.08844,1.257791,0.788411,00:01
8,0.052492,1.417744,0.822673,00:01
9,0.033698,1.491579,0.820557,00:01


## Model with custom GRU implementation

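`GRUCell` implements the fully gated GRU equations shown above (from Wikipedia). Despite its name, it processes a whole sequence, looping over the time steps internally.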

In [13]:
class GRUCell(Module):
    """Single-layer, fully gated GRU unrolled over a whole sequence.

    Per time step, with `[a, b]` denoting concatenation on the feature axis:

        r_t   = sigmoid(W_r [h_{t-1}, x_t] + b_r)
        z_t   = sigmoid(W_z [h_{t-1}, x_t] + b_z)
        h_hat = tanh(W_h [r_t * h_{t-1}, x_t] + b_h)
        h_t   = (1 - z_t) * h_{t-1} + z_t * h_hat

    Args:
        ni: number of input features per time step.
        nh: hidden state size.
    """
    def __init__(self, ni, nh):
        # Every gate consumes the hidden state concatenated with the input,
        # hence `ni + nh` input features.
        self.reset_gate = nn.Linear(ni + nh, nh)
        self.update_gate = nn.Linear(ni + nh, nh)
        self.candidate = nn.Linear(ni + nh, nh)

    def forward(self, input, h):
        """Run the GRU over `input` starting from hidden state `h`.

        Args:
            input: tensor indexed as (batch, time, feature) with `ni` features.
            h: initial hidden state, (batch, nh).

        Returns:
            A tuple `(outputs, h)`: the hidden states of all time steps stacked
            along dim 1, and the final hidden state.
        """
        outputs = []
        # Take the number of steps from the input itself rather than from the
        # notebook-level `sl`, so any sequence length works.
        for t in range(input.shape[1]):
            x = input[:, t, :]
            h_x = torch.cat([h, x], dim=1)

            rt = torch.sigmoid(self.reset_gate(h_x))
            zt = torch.sigmoid(self.update_gate(h_x))

            # The candidate state sees the previous state scaled by the reset gate.
            ht_hat = torch.tanh(self.candidate(torch.cat([h * rt, x], dim=1)))

            # The update gate blends the previous state with the candidate.
            h = (1 - zt) * h + zt * ht_hat
            outputs.append(h)
        return torch.stack(outputs, dim=1), h

In [16]:
class LMModelX(Module):
    """Stateful language model built on the custom `GRUCell`.

    Pipeline: embedding -> `GRUCell` (single layer) -> linear decoder. The
    hidden state is stored on the instance and carried, detached, from one
    batch to the next. Call `reset()` to zero it, e.g. through the
    `ModelResetter` callback.

    Args:
        vocab_sz: number of distinct tokens (embedding rows / output logits).
        n_hidden: embedding width and GRU hidden size.
    """
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        # The GRU input is the embedding, so its input size is `n_hidden`.
        # (Passing the batch size here only works while bs == n_hidden.)
        self.rnn = GRUCell(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # NOTE: sized from the notebook-level batch size `bs`; every batch fed
        # to the model must therefore have exactly `bs` rows.
        self.h = torch.zeros(bs, n_hidden)

    def forward(self, x):
        """Return next-token logits for every position of the token-id batch `x`."""
        # `self.h` is a plain attribute, not a parameter or buffer, so it does
        # not follow the model when it is moved to another device. Align it
        # with the input here (a no-op when it is already there).
        h0 = self.h.to(x.device)
        res, h = self.rnn(self.i_h(x), h0)
        # Keep the values for the next batch but cut the autograd history.
        self.h = h.detach()
        return self.h_o(res)

    def reset(self):
        """Zero the carried hidden state in place."""
        self.h.zero_()

In [17]:
# Custom-GRU model with 64 hidden units, trained with the same settings as
# the nn.GRU baseline so the two runs are comparable.
# NOTE(review): the call below requests 15 epochs, but the recorded output
# lists only epochs 0-9 - it may come from an earlier run; confirm.
learn = Learner(dls, LMModelX(len(vocab), 64), 
                loss_func=CrossEntropyLossFlat(), 
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.879817,2.598815,0.286296,00:01
1,1.943929,1.971485,0.321777,00:01
2,1.560044,1.852278,0.425049,00:01
3,1.229757,1.579086,0.482178,00:01
4,0.94483,1.440558,0.56429,00:01
5,0.694256,1.065961,0.642741,00:01
6,0.489304,1.03793,0.694661,00:01
7,0.339288,1.020411,0.734456,00:01
8,0.242954,0.880779,0.773438,00:01
9,0.178799,0.906347,0.786458,00:01
