# Predicting English word version of numbers using an RNN

We used RNNs as part of our language model in the previous lesson. Today, we will look in more detail at what RNNs are and how they work, using the problem of predicting the English word version of numbers.

Let's predict what should come next in this sequence:

*eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve...*


Jeremy created this synthetic dataset to have an easier way to check whether things are working, to debug, and to understand what is going on. When experimenting with new ideas, a smaller dataset lets you quickly get a sense of whether your ideas are promising (for other examples, see [Imagenette and Imagewoof](https://github.com/fastai/imagenette)). This dataset of numbers written out as English words is a good one for learning about RNNs. Our task today is to predict which word comes next when counting.

## Data

In [1]:
```python
from fastai.text import *  # fastai v1 API
```

In [2]:
```python
bs=64  # batch size
```

In [3]:
```python
path = untar_data(URLs.HUMAN_NUMBERS)
path.ls()
```

```
[PosixPath('/home/molly/.fastai/data/human_numbers/valid.txt'),
 PosixPath('/home/molly/.fastai/data/human_numbers/train.txt')]
```

In [4]:
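```python
def readnums(d):
    # Join every line of the file into one comma-separated string, wrapped in a one-element list
    with open(path/d) as f: return [', '.join(o.strip() for o in f.readlines())]
```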

Reading `train.txt` with `readnums` gives us one long sequence of numbers written out as English words:

In [5]:
```python
train_txt = readnums('train.txt'); train_txt[0][:80]
```

```
'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'
```

In [6]:
```python
valid_txt = readnums('valid.txt'); valid_txt[0][-80:]
```

```
' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'
```

In [7]:
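```python
train = TextList(train_txt, path=path)
valid = TextList(valid_txt, path=path)

# label_for_lm: language-model labels, i.e. the target for each token is the next token
src = ItemLists(path=path, train=train, valid=valid).label_for_lm()
data = src.databunch(bs=bs)
```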

In [8]:
```python
train[0].text[:80]
```

```
'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'
```

`bptt` stands for *back-propagation through time*. It sets how many time steps (tokens) of history each sequence in a batch contains, and therefore how many steps the gradient is propagated back through.

In [9]:
```python
data.bptt, len(data.valid_dl)
```

```
(70, 3)
```

We have 3 batches in our validation set: 13017 tokens, arranged as 64 rows of text per batch with about 70 tokens (`bptt`) per row. Since 13017 / (64 * 70) ≈ 2.9, that rounds up to 3 batches.
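To make the batch arithmetic concrete, here is a simplified NumPy sketch of how a language-model loader can cut one long token stream into batches. The helper `lm_batches` is hypothetical, not fastai's API, and fastai's own loader may differ in its details. It assumes the stream is split into `bs` contiguous rows that are then read `bptt` tokens at a time, with the targets being the inputs shifted by one token.

```python
import numpy as np

def lm_batches(tokens, bs, bptt):
    """Hypothetical helper: split one token stream into (input, target) batches."""
    tokens = np.asarray(tokens)
    n = len(tokens) // bs                    # tokens per row; the remainder is dropped
    rows = tokens[:n * bs].reshape(bs, n)    # each row is a contiguous chunk of the stream
    batches = []
    for i in range(0, n - 1, bptt):
        L = min(bptt, n - 1 - i)             # the last window can be shorter than bptt
        x = rows[:, i:i + L]                 # inputs
        y = rows[:, i + 1:i + 1 + L]         # targets: the same rows shifted by one token
        batches.append((x, y))
    return batches

# Stand-in for the 13017 validation tokens (token ids 0..13016)
batches = lm_batches(np.arange(13017), bs=64, bptt=70)
print(len(batches))                          # 3
print([x.shape for x, _ in batches])         # [(64, 70), (64, 70), (64, 62)]
```

With 13017 // 64 = 203 tokens per row, windows start at positions 0, 70 and 140, which gives the same count of 3 batches reported by `len(data.valid_dl)`.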

We store each batch in a separate variable so that we can step through them and better understand what the RNN does at each step:

In [10]:
```python
it = iter(data.valid_dl)
x1,y1 = next(it)
x2,y2 = next(it)
x3,y3 = next(it)
it.close()
```

In [11]:
```python
v = data.valid_ds.vocab
```

In [12]:
```python
data = src.databunch(bs=bs, bptt=40)  # rebuild the DataBunch with shorter, 40-token sequences
```

In [13]:
```python
x,y = data.one_batch()
x.shape,y.shape
```

```
(torch.Size([64, 40]), torch.Size([64, 40]))
```

In [14]:
```python
nv = len(v.itos); nv  # vocabulary size
```

```
40
```

In [15]:
```python
nh=56  # number of hidden activations
```

In [16]:
```python
# Loss and accuracy computed only on the last token of each sequence
def loss4(input,target): return F.cross_entropy(input, target[:,-1])
def acc4(input,target):  return accuracy(input, target[:,-1])
```

Layer names:
- `i_h`: input to hidden
- `h_h`: hidden to hidden
- `h_o`: hidden to output
- `bn`: batchnorm
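For comparison with the GRU models that follow, here is a minimal NumPy sketch of a plain RNN forward pass using the layer names above. It assumes the update `h = relu((h + i_h[x_t]) @ h_h)` and leaves out the batchnorm. The weights are random placeholders, not trained parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
nv, nh = 40, 56                        # vocab size and hidden size used in this notebook

# Random placeholder weights for the layers named above
i_h = rng.normal(0, 0.1, (nv, nh))     # input -> hidden (an embedding table)
h_h = rng.normal(0, 0.1, (nh, nh))     # hidden -> hidden (reused at every time step)
h_o = rng.normal(0, 0.1, (nh, nv))     # hidden -> output (scores over the vocab)

def rnn_forward(tokens, h):
    """tokens: (bs, seq_len) ints; h: (bs, nh) starting hidden state."""
    outs = []
    for t in range(tokens.shape[1]):                      # step through the time axis
        h = np.maximum(0, (h + i_h[tokens[:, t]]) @ h_h)  # add word t, then hidden -> hidden + ReLU
        outs.append(h @ h_o)                              # a next-word prediction at every step
    return np.stack(outs, axis=1), h

x = rng.integers(0, nv, size=(64, 40))   # a fake batch: bs=64, bptt=40
logits, h = rnn_forward(x, np.zeros((64, nh)))
print(logits.shape, h.shape)             # (64, 40, 40) (64, 56)
```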

## Adding a GRU

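When you have long time scales and deeper networks, plain RNNs become very hard to train, because the gradient has to pass through the same hidden-to-hidden layer at every step and tends to vanish or explode. One way to address this is to add small neural networks (gates) that decide how much of the green arrow and how much of the orange arrow (the input-to-hidden and hidden-to-hidden paths in the earlier diagrams) to keep. GRUs and LSTMs are two architectures built around such gates.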
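To see why, here is a small NumPy illustration (not part of the lesson) that tracks the Jacobian of the hidden state with respect to the initial hidden state through 40 steps of a simplified tanh RNN with no inputs or biases. It assumes a hidden-to-hidden matrix with spectral norm 0.9. Each step's Jacobian `diag(1 - h**2) @ W` then has norm at most 0.9, so the gradient signal reaching the first step shrinks by at least that factor every step. With larger weights, the product can instead blow up (exploding gradients).

```python
import numpy as np

rng = np.random.default_rng(0)
nh, steps = 56, 40

# Assumed hidden-to-hidden weights: a random orthogonal matrix scaled to spectral norm 0.9
q, _ = np.linalg.qr(rng.normal(size=(nh, nh)))
W = 0.9 * q

h = rng.normal(size=nh)
J = np.eye(nh)                             # d h_t / d h_0, built up with the chain rule
norms = []
for t in range(steps):
    h = np.tanh(W @ h)                     # one step of a simplified RNN (no input, no bias)
    J = (np.diag(1 - h**2) @ W) @ J        # this step's Jacobian times the running product
    norms.append(np.linalg.norm(J, 2))     # spectral norm of d h_t / d h_0

print(f"after 1 step: {norms[0]:.3f}, after {steps} steps: {norms[-1]:.2e}")
```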

In [17]:
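```python
class Model5(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)                  # input -> hidden
        self.rnn = nn.GRU(nh, nh, 1, batch_first=True)  # 1-layer GRU over (bs, seq_len, nh) inputs
        self.h_o = nn.Linear(nh,nv)                     # hidden -> output (one score per vocab token)
        self.bn = BatchNorm1dFlat(nh)
        self.h = torch.zeros(1, bs, nh).cuda()          # hidden state (num_layers, bs, nh); assumes a GPU

    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)           # res: (bs, seq_len, nh), the output at every step
        self.h = h.detach()                             # carry the state to the next batch, but stop backprop there
        return self.h_o(self.bn(res))
```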

In [18]:
```python
nv, nh
```

```
(40, 56)
```

In [19]:
```python
learn = Learner(data, Model5(), metrics=accuracy)
```

In [20]:
```python
learn.fit_one_cycle(10, 1e-2)
```

| epoch | train_loss | valid_loss | accuracy | time |
|------:|-----------:|-----------:|---------:|-----:|
| 0 | 3.379762 | 3.116946 | 0.332747 | 00:00 |
| 1 | 2.520318 | 2.208271 | 0.314909 | 00:00 |
| 2 | 1.991337 | 1.843576 | 0.476497 | 00:00 |
| 3 | 1.661237 | 1.736325 | 0.521029 | 00:00 |
| 4 | 1.358236 | 1.608955 | 0.61276 | 00:00 |
| 5 | 1.050121 | 1.424691 | 0.708659 | 00:00 |
| 6 | 0.784529 | 1.446448 | 0.757292 | 00:00 |
| 7 | 0.584348 | 1.378798 | 0.765755 | 00:00 |
| 8 | 0.441738 | 1.449597 | 0.774349 | 00:00 |
| 9 | 0.344461 | 1.422903 | 0.773242 | 00:00 |


## Let's make our own GRU

### Using PyTorch's GRUCell

Axis 0 is the batch dimension, and axis 1 is the time dimension.  We want to loop through axis 1:

In [21]:
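```python
def rnn_loop(cell, h, x):
    res = []
    for x_ in x.transpose(0,1):     # iterate over time: each x_ is (bs, nh)
        h = cell(x_, h)             # one recurrent step
        res.append(h)
    return torch.stack(res, dim=1)  # back to (bs, seq_len, nh)
```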
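The same loop can be checked in NumPy. In this sketch, `rnn_loop_np` is a hypothetical port, and `np.swapaxes` plays the role of `transpose(0,1)`. With a toy cell that just adds each input to the running state, the stacked outputs should equal a cumulative sum along axis 1. That confirms the loop walks the time axis and stacks the results back in (batch, time, features) order.

```python
import numpy as np

def rnn_loop_np(cell, h, x):
    """NumPy version of rnn_loop: x is (batch, time, features)."""
    res = []
    for x_ in np.swapaxes(x, 0, 1):   # iterate over axis 1 (time); each x_ is (batch, features)
        h = cell(x_, h)
        res.append(h)
    return np.stack(res, axis=1)      # back to (batch, time, features)

# Toy "cell" that just adds the input to the running state
add_cell = lambda x_, h: h + x_

x = np.arange(2 * 3 * 4).reshape(2, 3, 4).astype(float)   # batch=2, time=3, features=4
out = rnn_loop_np(add_cell, np.zeros((2, 4)), x)
print(out.shape)                                # (2, 3, 4)
print(np.allclose(out, np.cumsum(x, axis=1)))   # True
```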

In [22]:
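```python
class Model6(Model5):
    def __init__(self):
        super().__init__()                      # also builds Model5's layers (its nn.GRU goes unused here)
        self.rnnc = nn.GRUCell(nh, nh)          # a single GRU step; we write the loop over time ourselves
        self.h = torch.zeros(bs, nh).cuda()     # a GRUCell's hidden state has no num_layers axis: (bs, nh)

    def forward(self, x):
        res = rnn_loop(self.rnnc, self.h, self.i_h(x))
        self.h = res[:,-1].detach()             # hidden state after the last time step
        return self.h_o(self.bn(res))
```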

In [23]:
```python
learn = Learner(data, Model6(), metrics=accuracy)
learn.fit_one_cycle(10, 1e-2)
```

| epoch | train_loss | valid_loss | accuracy | time |
|------:|-----------:|-----------:|---------:|-----:|
| 0 | 3.430449 | 3.348477 | 0.36543 | 00:00 |
| 1 | 2.54947 | 2.416037 | 0.312109 | 00:00 |
| 2 | 2.000668 | 1.806786 | 0.432747 | 00:00 |
| 3 | 1.619377 | 1.636431 | 0.532292 | 00:00 |
| 4 | 1.244722 | 1.361456 | 0.691276 | 00:00 |
| 5 | 0.904485 | 1.274927 | 0.726628 | 00:00 |
| 6 | 0.644521 | 1.109192 | 0.757422 | 00:00 |
| 7 | 0.461107 | 1.148634 | 0.762044 | 00:00 |
| 8 | 0.335697 | 1.220675 | 0.743359 | 00:00 |
| 9 | 0.251593 | 1.243652 | 0.745313 | 00:00 |


### With a custom GRUCell

The following `GRUCell` is based on code from [emadRad](https://github.com/emadRad/lstm-gru-pytorch/blob/master/lstm_gru.ipynb). The diagram is taken from [this illustrated guide to LSTMs and GRUs](https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation-44e9eb85bf21):

![GRU cell diagram](https://miro.medium.com/max/700/1*jhi5uOm9PvZfmxvfaCektw.png)
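For reference, these are the equations the `GRUCell` code computes, written with the code's names in mind: $r_t$ is `resetgate`, $z_t$ is `updategate`, $n_t$ is `newgate`, and the $W_i x + b_i$ and $W_h h + b_h$ terms come from `i2h` and `h2h`. This matches the formulation in PyTorch's `nn.GRU` documentation. The original paper applies the reset gate before the hidden-to-hidden multiplication instead.

$$
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= z_t \odot h_{t-1} + \left(1 - z_t\right) \odot n_t
\end{aligned}
$$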

In [24]:
```python
class GRUCell(nn.Module):
    def __init__(self, ni, nh):
        super().__init__()
        self.ni,self.nh = ni,nh
        # One Linear per input computes all three gates' pre-activations at once:
        # a single larger matmul instead of three, which is simpler and uses the GPU better
        self.i2h = nn.Linear(ni, 3*nh)
        self.h2h = nn.Linear(nh, 3*nh)

    def forward(self, x, h):
        gate_x = self.i2h(x)                # (bs, 3*nh)
        gate_h = self.h2h(h)                # (bs, 3*nh)
        i_r,i_u,i_n = gate_x.chunk(3, 1)    # split into reset / update / candidate parts
        h_r,h_u,h_n = gate_h.chunk(3, 1)

        resetgate = torch.sigmoid(i_r + h_r)
        updategate = torch.sigmoid(i_u + h_u)
        # The reset gate, computed from both input and hidden state, can suppress h_n
        newgate = torch.tanh(i_n + (resetgate*h_n))
        # A weighted average of the old and candidate states lets h pass through largely intact
        return updategate*h + (1-updategate)*newgate
```

#### My explanation

The reset gate is something like a forget gate: it lets the cell forget part of the hidden state, an idea similar to earlier gated architectures such as the LSTM. It matters for "short-term" memory, that is, how well my previous state predicts the next word. This is what the reset neurons fire on.

A GRU lets h be carried across longer sequences, because the new h is a weighted average of the old h and the candidate state rather than being pushed through a linear layer and activation function at every step. The update gate is a sigmoid, so with its inputs near 0 its output is around 0.5, rather than the ~0 mean of the weights. This helps mitigate the vanishing gradient problem: 0.5\*\*n goes to zero much more slowly than 0.1\*\*n. This allows for "long-term" memory. The update-gate weights learn to balance long-term information against short-term information.
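Here is the arithmetic behind the 0.5\*\*n versus 0.1\*\*n point, over the 40-step sequences used here. A signal scaled by the same factor at every step all but disappears unless the factor is close to 1, which is what an update gate near 1 provides along the path from the old h to the new h. This is a simplified picture that ignores the other paths through the cell.

```python
bptt = 40  # sequence length used by the models in this notebook

# Fraction of a signal that survives being scaled by the same factor at every step
retained = {factor: factor ** bptt for factor in (0.1, 0.5, 0.95)}
for factor, frac in retained.items():
    print(f"{factor}: {frac:.2e}")
# 0.1: 1.00e-40
# 0.5: 9.09e-13
# 0.95: 1.29e-01
```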

#### Excerpt from paper

When the reset gate is close to 0, the hidden state is forced to ignore the previous hidden state and reset with the current input.

On the other hand, the update gate controls how much information from the previous hidden state will carry over to the current hidden state. This acts similarly to the memory cell in the LSTM network and helps the RNN to remember long-term information.

As each hidden unit [i.e., each element of the hidden state] has separate reset and update gates, each hidden unit will learn to capture dependencies over different time scales [i.e., numbers of RNN steps]. Those units that learn to capture short-term dependencies will tend to have reset gates that are frequently active, but those that capture longer-term dependencies will have update gates that are mostly active. [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078)
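The gate behaviour described in the excerpt can be checked with a small NumPy re-implementation of the `GRUCell` equations. `gru_cell` and `gate_bias` are hypothetical helpers, the weights are random, and large input biases are used to force the gates to saturate. With the update gate near 1, the old hidden state passes through almost unchanged. With the reset and update gates near 0, the output no longer depends on the previous hidden state.

```python
import numpy as np

def sigmoid(a):
    return 1 / (1 + np.exp(-a))

def gru_cell(x, h, Wi, bi, Wh, bh):
    """One GRU step. Wi: (ni, 3*nh), Wh: (nh, 3*nh); gate order: reset, update, candidate."""
    i_r, i_u, i_n = np.split(x @ Wi + bi, 3, axis=1)
    h_r, h_u, h_n = np.split(h @ Wh + bh, 3, axis=1)
    r = sigmoid(i_r + h_r)                 # reset gate
    z = sigmoid(i_u + h_u)                 # update gate
    n = np.tanh(i_n + r * h_n)             # candidate state
    return z * h + (1 - z) * n

rng = np.random.default_rng(0)
ni = nh = 8
x = rng.normal(size=(4, ni))
h1, h2 = rng.normal(size=(4, nh)), rng.normal(size=(4, nh))   # two different previous states
Wi = rng.normal(0, 0.1, (ni, 3 * nh))
Wh = rng.normal(0, 0.1, (nh, 3 * nh))
bh = np.zeros(3 * nh)

def gate_bias(reset, update):
    """Hypothetical helper: input biases that push the reset/update gates toward 0 or 1."""
    return np.concatenate([np.full(nh, reset), np.full(nh, update), np.zeros(nh)])

# Update gate saturated near 1: the previous hidden state is carried over almost unchanged
keep = gru_cell(x, h1, Wi, gate_bias(0.0, 30.0), Wh, bh)
print(np.allclose(keep, h1, atol=1e-6))   # True

# Reset and update gates near 0: the output ignores the previous hidden state
b = gate_bias(-30.0, -30.0)
print(np.allclose(gru_cell(x, h1, Wi, b, Wh, bh), gru_cell(x, h2, Wi, b, Wh, bh), atol=1e-6))  # True
```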

In [25]:
```python
class Model7(Model6):
    def __init__(self):
        super().__init__()
        self.rnnc = GRUCell(nh,nh)  # swap in the hand-written cell
```

In [26]:
```python
learn = Learner(data, Model7(), metrics=accuracy)
learn.fit_one_cycle(10, 1e-2)
```

| epoch | train_loss | valid_loss | accuracy | time |
|------:|-----------:|-----------:|---------:|-----:|
| 0 | 3.411696 | 3.165735 | 0.279427 | 00:01 |
| 1 | 2.52379 | 2.343463 | 0.314844 | 00:01 |
| 2 | 1.998289 | 1.813654 | 0.452669 | 00:01 |
| 3 | 1.649988 | 1.6754 | 0.538607 | 00:01 |
| 4 | 1.330819 | 1.589842 | 0.610352 | 00:01 |
| 5 | 1.03336 | 1.339241 | 0.665755 | 00:01 |
| 6 | 0.781098 | 1.364926 | 0.69043 | 00:01 |
| 7 | 0.585266 | 1.415118 | 0.682943 | 00:01 |
| 8 | 0.444499 | 1.511679 | 0.68776 | 00:01 |
| 9 | 0.347861 | 1.494937 | 0.684635 | 00:01 |


In [27]:
```
! nvidia-smi
```

```
Wed Aug 18 01:19:50 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:02:00.0 Off |                  N/A |
|  0%   50C    P8    20W / 100W |    748MiB /  7979MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:03:00.0 Off |                  N/A |
|  0%   48C    P8     8W / 100W |      3MiB /  7982MiB |      0%      Defaul
```

### Connection to ULMFiT

In the previous lesson, we were essentially swapping out `self.h_o` for a classifier in order to classify text.

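An RNN is just a refactored fully connected neural network: the same input-to-hidden and hidden-to-hidden layers are applied at every time step, with their weights shared across steps.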
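Here is a minimal NumPy check of this claim. It uses random placeholder weights and assumes the same simplified plain-RNN update as before (`h = relu((h + i_h[x_t]) @ h_h)`). Writing out three layers by hand, each fed one word and all reusing the same weights, gives exactly the same result as the loop.

```python
import numpy as np

rng = np.random.default_rng(0)
nv, nh, bs = 40, 56, 8
i_h = rng.normal(0, 0.1, (nv, nh))   # placeholder input -> hidden weights (shared)
h_h = rng.normal(0, 0.1, (nh, nh))   # placeholder hidden -> hidden weights (shared)

x = rng.integers(0, nv, size=(bs, 3))   # a batch of 3-word sequences

# Written out as a feed-forward network: three layers, each reusing the same weights
h = np.zeros((bs, nh))
h = np.maximum(0, (h + i_h[x[:, 0]]) @ h_h)          # layer 1 sees word 1
h = np.maximum(0, (h + i_h[x[:, 1]]) @ h_h)          # layer 2 sees word 2
unrolled = np.maximum(0, (h + i_h[x[:, 2]]) @ h_h)   # layer 3 sees word 3

# Refactored as a loop over time: this is the RNN
h = np.zeros((bs, nh))
for t in range(x.shape[1]):
    h = np.maximum(0, (h + i_h[x[:, t]]) @ h_h)

print(np.allclose(h, unrolled))   # True
```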

You can use the same approach for any sequence labeling task (part-of-speech tagging, classifying whether material is sensitive, ...).
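In code, these uses differ only in the head placed on top of the RNN's per-step outputs. This NumPy sketch uses random stand-ins for the outputs and weights and hypothetical label counts. Sequence labeling keeps one prediction per time step, while classification first pools over time. The pooling shown, with the last, max and mean hidden states concatenated, is the concat pooling described in the ULMFiT paper. The real ULMFiT head adds further layers on top of it.

```python
import numpy as np

rng = np.random.default_rng(0)
bs, seq_len, nh = 8, 40, 56
res = rng.normal(size=(bs, seq_len, nh))   # stand-in for the RNN's output at every step

# Sequence labeling (like the language model): one prediction per time step
n_tags = 12                                # hypothetical number of labels, e.g. POS tags
W_tag = rng.normal(0, 0.1, (nh, n_tags))
tag_logits = res @ W_tag                   # (bs, seq_len, n_tags)

# Classification: pool over time, then one prediction per sequence (concat pooling)
pooled = np.concatenate([res[:, -1], res.max(axis=1), res.mean(axis=1)], axis=1)  # (bs, 3*nh)
n_classes = 2                              # hypothetical, e.g. positive / negative
W_cls = rng.normal(0, 0.1, (3 * nh, n_classes))
cls_logits = pooled @ W_cls                # (bs, n_classes)

print(tag_logits.shape, cls_logits.shape)  # (8, 40, 12) (8, 2)
```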

## fin