This notebook was part of [Lesson 7](https://course.fast.ai/videos/?lesson=7) of the Practical Deep Learning for Coders course.

# Predicting English word version of numbers using an RNN
In this notebook I've combined the two notebooks `6-rnn-english-numbers.ipynb` and `6-rnn-english-numbers_GRU.ipynb`.

We were using RNNs as part of our language model in the previous lesson.  Today, we will dive into more details of what RNNs are and how they work.  We will do this using the problem of trying to predict the English word version of numbers.

Let's predict what should come next in this sequence:

*eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve...*


Jeremy created this synthetic dataset to have a better way to check whether things are working, to debug, and to understand what is going on. When experimenting with new ideas, it can be nice to have a smaller dataset, so you can quickly get a sense of whether your ideas are promising (for other examples, see [Imagenette and Imagewoof](https://github.com/fastai/imagenette)). This English-word numbers dataset will serve as a good dataset for learning about RNNs.  Our task today will be to predict which word comes next when counting.

### In deep learning, there are 2 types of numbers

**Parameters** are numbers that are learned.  **Activations** are numbers that are calculated (by affine functions & element-wise non-linearities).

When you learn about any new concept in deep learning, ask yourself: is this a parameter or an activation?
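
As a minimal sketch of the distinction (plain Python; the values and the names `weights`, `bias`, and `neuron` are made up for illustration), the weights and bias below play the role of parameters, while the value returned by `neuron` is an activation:

```python
# Parameters: numbers that would be learned during training.
# Here they are fixed by hand purely for illustration.
weights = [0.5, -1.0]
bias = 0.25

def relu(z):
    """Element-wise non-linearity (applied to a single number here)."""
    return max(0.0, z)

def neuron(inputs):
    """Affine function followed by a non-linearity: the result is an activation."""
    pre_activation = sum(w * x for w, x in zip(weights, inputs)) + bias
    return relu(pre_activation)

# 0.5*2.0 + (-1.0)*0.5 + 0.25 = 0.75
activation = neuron([2.0, 0.5])
print(activation)  # 0.75
```

Training changes `weights` and `bias`; the activation is recomputed from them for every new input.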

Note to self: Point out the hidden state, going from the version without a for-loop to the for loop.  This is the step where people get confused.

## 0. Download, explore and prepare the data

In [1]:
from fastai.text import *

In [2]:
bs=64

In [3]:
path = untar_data(URLs.HUMAN_NUMBERS)
path.ls()

[WindowsPath('C:/Users/cross-entropy/.fastai/data/human_numbers/models'),
 WindowsPath('C:/Users/cross-entropy/.fastai/data/human_numbers/train.txt'),
 WindowsPath('C:/Users/cross-entropy/.fastai/data/human_numbers/valid.txt')]

In [4]:
def readnums(file): 
    return [', '.join(o.strip() for o in open(path/file).readlines())]

#### train.txt is a sequence of numbers from 1 to 7999 written out as English words:

In [5]:
train_txt = readnums('train.txt')
print(train_txt[0][:80])
print(train_txt[0][-80:])
print(len(train_txt[0]))

one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt
even thousand nine hundred ninety eight, seven thousand nine hundred ninety nine
280597


#### valid.txt is a sequence of numbers from 8001 to 9999 written out as English words:

In [6]:
valid_txt = readnums('valid.txt')
print(valid_txt[0][0:80])
print(valid_txt[0][-80:])
print(len(valid_txt[0]))

eight thousand one, eight thousand two, eight thousand three, eight thousand fou
 nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine
74882


#### make a databunch; batch size of 64 lines of length bptt = 70 tokens per line

In [7]:
train = TextList(train_txt, path=path)
valid = TextList(valid_txt, path=path)

src = ItemLists(path=path, train=train, valid=valid).label_for_lm()
data = src.databunch(bs=bs)

In [8]:
train[0].text[:80]

'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'

In [9]:
len(data.valid_ds[0][0].data)

13017

In [10]:
len(data.train_ds[0][0].data)

50079

`bptt` stands for *back-propagation through time*.  This tells us how many steps of history we are considering.

In [11]:
data.bptt, len(data.valid_dl)

(70, 3)

In [12]:
len(data.train_dl)

12

We have 3 batches in our validation set:

13017 tokens, with about 70 tokens per line of text and 64 lines of text per batch.

In [13]:
13017/70/bs

2.905580357142857

In [14]:
50079/70/bs

11.178348214285714
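
The two divisions above can be wrapped in a small helper. Rounding up is an assumption made here (it is not taken from the fastai source), but it reproduces the observed lengths of 3 and 12. The name `n_batches` is hypothetical:

```python
import math

bs, bptt = 64, 70

def n_batches(n_tokens, bs=bs, bptt=bptt):
    """Number of (bs x bptt) batches needed to cover n_tokens, rounding up."""
    return math.ceil(n_tokens / (bs * bptt))

print(n_batches(13017))  # 3  (validation set)
print(n_batches(50079))  # 12 (training set)
```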

We will store each batch in a separate variable, so we can walk through this to understand better what the RNN does at each step:

In [15]:
it = iter(data.valid_dl)
x1,y1 = next(it)
x2,y2 = next(it)
x3,y3 = next(it)
it.close()

In [16]:
x1

tensor([[ 2, 19, 11,  ..., 36,  9, 19],
        [ 9, 19, 11,  ..., 24, 20,  9],
        [11, 27, 18,  ...,  9, 19, 11],
        ...,
        [20, 11, 20,  ..., 11, 20, 10],
        [20, 11, 20,  ..., 24,  9, 20],
        [20, 10, 26,  ..., 20, 11, 20]], device='cuda:0')

`numel()` is a [PyTorch method](https://pytorch.org/docs/stable/torch.html#torch.numel) to return the number of elements in a tensor:

In [17]:
x1.numel()+x2.numel()+x3.numel()

13440

In [18]:
print(x1.numel())
print(x2.numel())
print(x3.numel())

4480
4480
4480


In [19]:
x1.shape, y1.shape

(torch.Size([64, 70]), torch.Size([64, 70]))

In [20]:
x2.shape, y2.shape

(torch.Size([64, 70]), torch.Size([64, 70]))

In [21]:
x3.shape, y3.shape

(torch.Size([64, 70]), torch.Size([64, 70]))

#### vocabulary has 40 tokens

In [22]:
v = data.valid_ds.vocab

In [23]:
print(len(v.itos))
v.itos

40


['xxunk',
 'xxpad',
 'xxbos',
 'xxeos',
 'xxfld',
 'xxmaj',
 'xxup',
 'xxrep',
 'xxwrep',
 ',',
 'hundred',
 'thousand',
 'one',
 'two',
 'three',
 'four',
 'five',
 'six',
 'seven',
 'eight',
 'nine',
 'twenty',
 'thirty',
 'forty',
 'fifty',
 'sixty',
 'seventy',
 'eighty',
 'ninety',
 'ten',
 'eleven',
 'twelve',
 'thirteen',
 'fourteen',
 'fifteen',
 'sixteen',
 'seventeen',
 'eighteen',
 'nineteen',
 'xxfake']

In [24]:
x1[0,:]

tensor([ 2, 19, 11, 12,  9, 19, 11, 13,  9, 19, 11, 14,  9, 19, 11, 15,  9, 19,
        11, 16,  9, 19, 11, 17,  9, 19, 11, 18,  9, 19, 11, 19,  9, 19, 11, 20,
         9, 19, 11, 29,  9, 19, 11, 30,  9, 19, 11, 31,  9, 19, 11, 32,  9, 19,
        11, 33,  9, 19, 11, 34,  9, 19, 11, 35,  9, 19, 11, 36,  9, 19],
       device='cuda:0')

In [25]:
y1[0,:]

tensor([19, 11, 12,  9, 19, 11, 13,  9, 19, 11, 14,  9, 19, 11, 15,  9, 19, 11,
        16,  9, 19, 11, 17,  9, 19, 11, 18,  9, 19, 11, 19,  9, 19, 11, 20,  9,
        19, 11, 29,  9, 19, 11, 30,  9, 19, 11, 31,  9, 19, 11, 32,  9, 19, 11,
        33,  9, 19, 11, 34,  9, 19, 11, 35,  9, 19, 11, 36,  9, 19, 11],
       device='cuda:0')

In [26]:
v.itos[19], v.itos[11], v.itos[12], v.itos[9], v.itos[19], v.itos[11],v.itos[13],v.itos[9]

('eight', 'thousand', 'one', ',', 'eight', 'thousand', 'two', ',')

In [27]:
v.textify(x1[0])

'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight'

In [28]:
v.textify(y1[0])

'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand'
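
The targets are simply the inputs shifted left by one token, which is what makes this a next-word prediction task. A small plain-Python sketch (the token list is a shortened, hand-written version of the first row):

```python
# A shortened, hand-written version of the first validation row.
tokens = "xxbos eight thousand one , eight thousand two , eight".split()

x = tokens[:-1]  # inputs: everything except the last token
y = tokens[1:]   # targets: everything except the first token

# Each target is the token that follows the corresponding input.
for inp, tgt in zip(x, y):
    print(f"{inp:>8} -> {tgt}")
```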

In [29]:
v.textify(x1[1,:])

', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine ,'

In [30]:
v.textify(y1[1,:])

'eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight'

In [31]:
v.textify(x2[0])

'thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two ,'

In [32]:
v.textify(x2[1])

'eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand'

In [33]:
v.textify(x3[0])

'eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five , eight thousand forty six , eight'

In [34]:
v.textify(x1[1])

', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine ,'

In [35]:
v.textify(x2[1])

'eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand'

In [36]:
v.textify(x3[1])

'seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight thousand eighty seven , eight thousand eighty'

In [37]:
v.textify(x3[-1])

'ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand nine hundred ninety six , nine thousand nine hundred ninety seven , nine thousand nine hundred ninety eight , nine thousand nine hundred ninety nine xxbos eight thousand one , eight'
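
Row *i* of each batch picks up where row *i* of the previous batch stopped, and the last row wraps around to the start of the text. A simplified sketch of such a layout, assuming the token stream is cut into `bs` contiguous rows and then sliced into `bptt`-wide batches (the helper `lm_batches` is hypothetical, and it drops leftover tokens rather than wrapping around as the output above suggests fastai does):

```python
def lm_batches(stream, bs, bptt):
    """Yield batches of shape [bs][bptt] where row i continues across batches."""
    row_len = len(stream) // bs  # simplification: leftover tokens are dropped
    rows = [stream[r * row_len:(r + 1) * row_len] for r in range(bs)]
    for start in range(0, row_len, bptt):
        yield [row[start:start + bptt] for row in rows]

stream = list(range(24))  # stand-in for a stream of token ids
batches = list(lm_batches(stream, bs=2, bptt=4))

print(batches[0])  # [[0, 1, 2, 3], [12, 13, 14, 15]]
print(batches[1])  # [[4, 5, 6, 7], [16, 17, 18, 19]]
```

Because row 0 of the second batch continues row 0 of the first, a hidden state carried from one batch to the next sees an unbroken sequence.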

In [38]:
data.show_batch(ds_type=DatasetType.Valid)

idx,text
0,"thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand sixty , eight thousand sixty"
1,"eight , eight thousand eighty nine , eight thousand ninety , eight thousand ninety one , eight thousand ninety two , eight thousand ninety three , eight thousand ninety four , eight thousand ninety five , eight thousand ninety six , eight thousand ninety seven , eight thousand ninety eight , eight thousand ninety nine , eight thousand one hundred , eight thousand one hundred one , eight thousand one"
2,"thousand one hundred twenty four , eight thousand one hundred twenty five , eight thousand one hundred twenty six , eight thousand one hundred twenty seven , eight thousand one hundred twenty eight , eight thousand one hundred twenty nine , eight thousand one hundred thirty , eight thousand one hundred thirty one , eight thousand one hundred thirty two , eight thousand one hundred thirty three , eight thousand"
3,"three , eight thousand one hundred fifty four , eight thousand one hundred fifty five , eight thousand one hundred fifty six , eight thousand one hundred fifty seven , eight thousand one hundred fifty eight , eight thousand one hundred fifty nine , eight thousand one hundred sixty , eight thousand one hundred sixty one , eight thousand one hundred sixty two , eight thousand one hundred sixty three"
4,"thousand one hundred eighty three , eight thousand one hundred eighty four , eight thousand one hundred eighty five , eight thousand one hundred eighty six , eight thousand one hundred eighty seven , eight thousand one hundred eighty eight , eight thousand one hundred eighty nine , eight thousand one hundred ninety , eight thousand one hundred ninety one , eight thousand one hundred ninety two , eight thousand"


We will iteratively consider a few different models, building up to a more traditional RNN.

## 1. Single fully connected model: predict the next word

#### make a databunch; batch size 64 lines of bptt = 3 tokens  per line

In [39]:
data = src.databunch(bs=bs, bptt=3)

In [40]:
print(len(data.train_dl))
print(len(data.valid_dl))

261
68


In [41]:
bs

64

In [42]:
x,y = data.one_batch()
x.shape,y.shape

(torch.Size([64, 3]), torch.Size([64, 3]))

#### number of tokens in vocabulary, i.e. length of vocabulary vector

In [43]:
nv = len(v.itos)
nv

40

#### size of hidden layer

In [44]:
nh=64

In [47]:
def loss4(input,target): return F.cross_entropy(input, target[:,-1])
def acc4 (input,target): return accuracy(input, target[:,-1])
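
`loss4` and `acc4` compare the model's output only with the last target token in each line (`target[:,-1]`). As a rough plain-Python sketch of what cross-entropy measures for a single example (the helper `cross_entropy` below is hypothetical and is not fastai's or PyTorch's implementation), a model that spreads its prediction evenly over all 40 vocabulary tokens gets a loss of ln 40 ≈ 3.69, which is a useful reference point when reading the first-epoch losses:

```python
import math

def cross_entropy(logits, target):
    """Negative log of the softmax probability assigned to `target`."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[target]

nv = 40  # vocabulary size, as in the notebook

# Equal logits: every token gets probability 1/40.
uniform_loss = cross_entropy([0.0] * nv, target=7)
print(round(uniform_loss, 3))  # 3.689

# A large logit on the correct token drives the loss toward 0.
confident = [0.0] * nv
confident[7] = 10.0
confident_loss = cross_entropy(confident, target=7)
print(confident_loss < 0.01)  # True
```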

#### first token in each line of the batch

In [48]:
x[:,0]

tensor([13, 13, 10,  9, 18,  9, 11, 11, 13, 19, 16, 23, 24,  9, 12,  9, 13, 14,
        15, 11, 10, 22, 15,  9, 10, 14, 11, 16, 10, 28, 11,  9, 20,  9, 15, 15,
        11, 18, 10, 28, 23, 24,  9, 16, 10, 16, 19, 20, 12, 10, 22, 16, 17, 17,
        17, 11, 24, 10,  9, 15, 16,  9, 18, 11])

Layer names:
- `i_h`: input to hidden
- `h_h`: hidden to hidden
- `h_o`: hidden to output
- `bn`: batchnorm

In [49]:
class Model0(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)  # green arrow
        self.h_h = nn.Linear(nh,nh)     # brown arrow
        self.h_o = nn.Linear(nh,nv)     # blue arrow
        self.bn = nn.BatchNorm1d(nh)
        
    def forward(self, x):
        h = self.bn(F.relu(self.i_h(x[:,0])))
        if x.shape[1]>1:
            h = h + self.i_h(x[:,1])
            h = self.bn(F.relu(self.h_h(h)))
        if x.shape[1]>2:
            h = h + self.i_h(x[:,2])
            h = self.bn(F.relu(self.h_h(h)))
        return self.h_o(h)

In [50]:
learn = Learner(data, Model0(), loss_func=loss4, metrics=acc4)

In [51]:
learn.fit_one_cycle(6, 1e-4)

epoch,train_loss,valid_loss,acc4,time
0,3.508088,3.562798,0.099954,00:05
1,2.605002,2.912117,0.348575,00:03
2,2.040771,2.520888,0.427849,00:03
3,1.840195,2.36176,0.428539,00:03
4,1.77074,2.316277,0.429688,00:03
5,1.756233,2.310652,0.412914,00:03


## Same thing with a loop

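Let's refactor this to use a for-loop.  This does essentially the same thing as before; the one difference is that the loop also passes the first token's embedding through `h_h`: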

In [52]:
class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)  # green arrow
        self.h_h = nn.Linear(nh,nh)     # brown arrow
        self.h_o = nn.Linear(nh,nv)     # blue arrow
        self.bn = nn.BatchNorm1d(nh)
        
    def forward(self, x):
        h = torch.zeros(x.shape[0], nh).to(device=x.device)
        for i in range(x.shape[1]):
            h = h + self.i_h(x[:,i])
            h = self.bn(F.relu(self.h_h(h)))
        return self.h_o(h)

The first version corresponds to the unrolled RNN diagram, and the for-loop version to the rolled diagram.
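
The unrolled/rolled contrast can also be seen in a toy scalar recurrence. This is a sketch only: `step` is a made-up stand-in for "add the input, apply a layer, apply ReLU", not the model above.

```python
def step(h, x):
    """Toy stand-in for one recurrent step: combine state and input, then ReLU."""
    return max(0.0, 0.5 * (h + x))

def unrolled(x):
    """Three steps written out by hand."""
    h = 0.0
    h = step(h, x[0])
    h = step(h, x[1])
    h = step(h, x[2])
    return h

def rolled(x):
    """The same computation as a loop; works for any sequence length."""
    h = 0.0
    for x_i in x:
        h = step(h, x_i)
    return h

seq = [1.0, -2.0, 3.0]
print(unrolled(seq), rolled(seq))  # 1.5 1.5
```

The loop reuses the same `step` (and, in the real model, the same weights) at every position, which is what lets it handle sequences of any length.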

In [53]:
learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)

In [54]:
learn.fit_one_cycle(6, 1e-4)

epoch,train_loss,valid_loss,acc4,time
0,3.517269,3.554246,0.095588,00:03
1,2.63206,2.776406,0.451287,00:03
2,2.081343,2.32191,0.462546,00:03
3,1.874864,2.156537,0.465763,00:03
4,1.801396,2.106496,0.466222,00:03
5,1.785916,2.09968,0.466452,00:03


Our accuracy is about the same, since we are doing the same thing as before.

## 2. Multi fully connected model: i.e. predict an arbitrary number of next words

Before, we were just predicting the last word in a line of text.  Given 70 tokens, what is token 71?  That approach was throwing away a lot of data.  Why not predict token 2 from token 1, then predict token 3, then predict token 4, and so on?  We will modify our model to do this.
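
To make the contrast concrete, here is a plain-Python sketch (the short token window and the `next_token` value are made up for illustration) that lists every (history, target) pair in one window; the earlier approach used only the last pair:

```python
# A made-up window of tokens and the token that follows it.
window = "eight thousand one , eight thousand two".split()
next_token = ","

# Targets are the window shifted by one, as in the x/y tensors.
targets = window[1:] + [next_token]

# One training signal per position instead of one per window.
pairs = [(window[:k + 1], targets[k]) for k in range(len(window))]

for history, target in pairs:
    print(" ".join(history), "->", target)

last_only = pairs[-1]  # what the single-output models were trained on
```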

In [55]:
data = src.databunch(bs=bs, bptt=20)

In [56]:
x,y = data.one_batch()
x.shape,y.shape

(torch.Size([64, 20]), torch.Size([64, 20]))

In [57]:
class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)
        self.h_h = nn.Linear(nh,nh)
        self.h_o = nn.Linear(nh,nv)
        self.bn = nn.BatchNorm1d(nh)
        
    def forward(self, x):
        h = torch.zeros(x.shape[0], nh).to(device=x.device)
        res = []
        for i in range(x.shape[1]):
            h = h + self.i_h(x[:,i])
            h = F.relu(self.h_h(h))
            res.append(self.h_o(self.bn(h)))
        return torch.stack(res, dim=1)

In [58]:
learn = Learner(data, Model2(), metrics=accuracy)

In [59]:
learn.fit_one_cycle(10, 1e-4, pct_start=0.1)

epoch,train_loss,valid_loss,accuracy,time
0,3.6316,3.497233,0.069176,00:02
1,3.430653,3.248969,0.172585,00:02
2,3.199888,3.036891,0.340554,00:02
3,2.981416,2.870872,0.405327,00:02
4,2.795409,2.750232,0.424716,00:02
5,2.649936,2.668055,0.438281,00:02
6,2.54428,2.617352,0.443679,00:02
7,2.473483,2.59065,0.445526,00:02
8,2.430873,2.580477,0.445881,00:02
9,2.40879,2.578957,0.445881,00:02


Note that our accuracy is worse now, because we are doing a harder task.  When we predict word k (k<70), we have less history to help us than when we were only predicting word 71.

## 3. Maintain state with an RNN

To address this issue, let's keep the hidden state from the previous line of text, so we are not starting over again on each new line of text.

In [61]:
class Model3(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)
        self.h_h = nn.Linear(nh,nh)
        self.h_o = nn.Linear(nh,nv)
        self.bn = nn.BatchNorm1d(nh)
        self.h = torch.zeros(bs, nh).cuda()
        
    def forward(self, x):
        res = []
        h = self.h
        for i in range(x.shape[1]):
            h = h + self.i_h(x[:,i])
            h = F.relu(self.h_h(h))
            res.append(self.bn(h))
        self.h = h.detach()
        res = torch.stack(res, dim=1)
        res = self.h_o(res)
        return res
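
The key change is that `self.h` survives between calls to `forward`, and `detach()` cuts the gradient history so that back-propagation stops at the batch boundary. The state-carrying part can be sketched without PyTorch (the scalar `step` function and both toy classes are hypothetical stand-ins, not the model above): feeding a sequence in two chunks to the stateful version gives the same final state as feeding it in one piece.

```python
def step(h, x):
    """Toy recurrent step on scalars."""
    return 0.5 * h + x

class Stateless:
    def forward(self, xs):
        h = 0.0  # starts from scratch on every call
        for x in xs:
            h = step(h, x)
        return h

class Stateful:
    def __init__(self):
        self.h = 0.0
    def forward(self, xs):
        h = self.h  # continue from the previous call
        for x in xs:
            h = step(h, x)
        self.h = h  # the PyTorch model stores h.detach() here
        return h

stateful = Stateful()
stateful.forward([1.0, 2.0])
chunked = stateful.forward([3.0, 4.0])
whole = Stateless().forward([1.0, 2.0, 3.0, 4.0])
print(chunked, whole)  # 6.125 6.125
```

This only works because the batches are laid out so that each row continues the same row of the previous batch.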

In [62]:
learn = Learner(data, Model3(), metrics=accuracy)

In [63]:
learn.fit_one_cycle(20, 3e-3)

epoch,train_loss,valid_loss,accuracy,time
0,3.480024,3.5633,0.10554,00:01
1,2.932248,2.695795,0.368679,00:01
2,2.240176,2.114701,0.314986,00:01
3,1.822065,2.11982,0.316477,00:01
4,1.625373,2.182368,0.317756,00:01
5,1.532871,2.154896,0.319247,00:01
6,1.474447,2.100534,0.328693,00:01
7,1.355359,1.795488,0.416477,00:01
8,1.2058,1.702325,0.505895,00:01
9,1.057464,1.595995,0.560511,00:01


Now we are getting greater accuracy than before!

## Use PyTorch's nn.RNN

Let's refactor the above to use PyTorch's RNN.  This is what you would use in practice, but now you know the inside details!

Question: why is PyTorch's RNN so much better than ours?

In [65]:
class Model4(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)
        self.rnn = nn.RNN(nh,nh, batch_first=True)
        self.h_o = nn.Linear(nh,nv)
        self.bn = BatchNorm1dFlat(nh)
        self.h = torch.zeros(1, bs, nh).cuda()
        
    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(self.bn(res))

In [66]:
learn = Learner(data, Model4(), metrics=accuracy)

In [67]:
learn.fit_one_cycle(20, 3e-3)

epoch,train_loss,valid_loss,accuracy,time
0,3.693117,3.327585,0.22848,00:00
1,2.923024,2.223051,0.467116,00:00
2,2.182852,1.992746,0.347301,00:00
3,1.794714,2.088521,0.316903,00:00
4,1.605301,2.043347,0.320028,00:00
5,1.443225,1.676813,0.464134,00:00
6,1.241937,1.732475,0.501776,00:00
7,1.058475,1.779596,0.459659,00:00
8,0.9085,1.725978,0.48473,00:00
9,0.775708,1.589084,0.518821,00:00


## 4. 2-layer GRU gives a dramatic boost to accuracy!

When you have long time scales and deeper networks, these RNNs become impossible to train.  One way to address this is to add mini-NNs that decide how much of the green arrow and how much of the orange arrow to keep.  These mini-NNs can be GRUs or LSTMs.  We will cover more details of this in a later lesson.
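
The "how much to keep" idea can be sketched with a single scalar gate. This is an illustration under simplifying assumptions, not a GRU: in a real cell the gate value is itself computed from the input and the hidden state by learned layers, whereas here `gate_logit` is passed in by hand, and the name `gated_update` is made up.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gated_update(h_prev, candidate, gate_logit):
    """Blend the old state and a new candidate; the gate lies in (0, 1)."""
    z = sigmoid(gate_logit)
    return z * h_prev + (1.0 - z) * candidate

print(gated_update(1.0, 5.0, 0.0))    # 3.0: gate at 0.5, an even mix
print(gated_update(1.0, 5.0, 20.0))   # close to 1.0: keep the old state
print(gated_update(1.0, 5.0, -20.0))  # close to 5.0: take the candidate
```

When the gate stays near 1, the state passes through a step almost unchanged, which helps information survive over many time steps.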

In [68]:
class Model5(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)
        self.rnn = nn.GRU(nh, nh, 2, batch_first=True)
        self.h_o = nn.Linear(nh,nv)
        self.bn = BatchNorm1dFlat(nh)
        self.h = torch.zeros(2, bs, nh).cuda()
        
    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(self.bn(res))

In [69]:
learn = Learner(data, Model5(), metrics=accuracy)

In [70]:
learn.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.74281,2.211538,0.455824,00:00
1,1.531137,1.083616,0.68821,00:00
2,0.735757,0.861519,0.81321,00:00
3,0.36466,0.621228,0.862855,00:00
4,0.185709,0.6929,0.85277,00:00
5,0.100498,0.848657,0.822869,00:00
6,0.058173,0.873973,0.813707,00:00
7,0.035563,0.889189,0.800923,00:00
8,0.023415,0.920494,0.800071,00:00
9,0.017021,0.914783,0.794673,00:00


## Let's make our own GRU

### Using PyTorch's GRUCell

Axis 0 is the batch dimension, and axis 1 is the time dimension.  We want to loop through axis 1:

In [76]:
def rnn_loop(cell, h, x):
    res = []
    for x_ in x.transpose(0,1):
        h = cell(x_, h)
        res.append(h)
    return torch.stack(res, dim=1)
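
The same loop can be sketched on nested Python lists (a toy analogue, not a replacement for `rnn_loop`; `rnn_loop_lists` and the running-sum `cell` are made up for illustration). `zip(*x)` plays the role of `x.transpose(0,1)`, and the final `zip(*res)` plays the role of `torch.stack(res, dim=1)`:

```python
def rnn_loop_lists(cell, h, x):
    """x: batch-major nested list [batch][time]; h: one state per batch row."""
    res = []
    for x_t in zip(*x):  # iterate over time steps (axis 1)
        h = [cell(x_i, h_i) for x_i, h_i in zip(x_t, h)]
        res.append(h)
    # res is time-major; transpose back to [batch][time]
    return [list(row) for row in zip(*res)]

def cell(x, h):
    return h + x  # toy cell: running sum of the inputs

out = rnn_loop_lists(cell, h=[0, 0], x=[[1, 2, 3], [10, 20, 30]])
print(out)  # [[1, 3, 6], [10, 30, 60]]
```

The output keeps the hidden state after every time step, which is why the model can take the last one (`res[:,-1]`) as the state to carry into the next batch.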

In [77]:
class Model6(Model5):
    def __init__(self):
        super().__init__()
        self.rnnc = nn.GRUCell(nh, nh)
        self.h = torch.zeros(bs, nh).cuda()
        
    def forward(self, x):
        res = rnn_loop(self.rnnc, self.h, self.i_h(x))
        self.h = res[:,-1].detach()
        return self.h_o(self.bn(res))

In [78]:
learn = Learner(data, Model6(), metrics=accuracy)
learn.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.807608,2.667815,0.427344,00:01
1,1.906226,1.840542,0.45973,00:01
2,1.381995,1.91732,0.569176,00:01
3,0.857038,1.700449,0.684801,00:01
4,0.473643,1.715642,0.717756,00:01
5,0.259735,1.780169,0.738778,00:01
6,0.149085,1.828086,0.739702,00:01
7,0.09135,1.878771,0.750284,00:01
8,0.061516,1.879429,0.739915,00:01
9,0.046344,1.966933,0.734588,00:01


### With a custom (non-PyTorch) GRUCell

The following is based on code from [emadRad](https://github.com/emadRad/lstm-gru-pytorch/blob/master/lstm_gru.ipynb):

In [82]:
class GRUCell(nn.Module):
    def __init__(self, ni, nh):
        super().__init__()
        
        self.ni,self.nh = ni,nh
        # each layer produces the reset, update and new-gate pre-activations in one go
        self.i2h = nn.Linear(ni, 3*nh)
        self.h2h = nn.Linear(nh, 3*nh)
    
    def forward(self, x, h):
        gate_x = self.i2h(x).squeeze()
        gate_h = self.h2h(h).squeeze()
        # split the 3*nh outputs into reset, update and new-gate parts
        i_r,i_u,i_n = gate_x.chunk(3, 1)
        h_r,h_u,h_n = gate_h.chunk(3, 1)
        
        resetgate = torch.sigmoid(i_r + h_r)
        updategate = torch.sigmoid(i_u + h_u)
        newgate = torch.tanh(i_n + (resetgate*h_n))
        # blend the previous hidden state with the new candidate
        return updategate*h + (1-updategate)*newgate
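
In equation form, the cell above computes the following. The symbols are notation chosen here rather than taken from the source: $W_{i\cdot}, b_{i\cdot}$ are the three chunks of `i2h`, $W_{h\cdot}, b_{h\cdot}$ are the three chunks of `h2h`, $\sigma$ is the sigmoid, and $\odot$ is element-wise multiplication.

$$
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot n_t
\end{aligned}
$$

Here $r_t$ is `resetgate`, $z_t$ is `updategate`, and $n_t$ is `newgate`. The last line is the blend between the previous hidden state and the new candidate.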

In [83]:
class Model7(Model6):
    def __init__(self):
        super().__init__()
        self.rnnc = GRUCell(nh,nh)

In [84]:
learn = Learner(data, Model7(), metrics=accuracy)
learn.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.729431,2.500164,0.41044,00:02
1,1.884727,1.90512,0.40277,00:03
2,1.410755,2.112573,0.567543,00:03
3,0.888971,1.937243,0.69098,00:03
4,0.497026,1.894348,0.72081,00:02
5,0.270738,1.885511,0.742827,00:03
6,0.153405,2.10153,0.768466,00:03
7,0.093191,1.958075,0.769673,00:03
8,0.061403,1.874458,0.76875,00:03
9,0.044888,1.982405,0.763992,00:03


### Connection to ULMFiT

In the previous lesson, we were essentially swapping out `self.h_o` with a classifier in order to do classification on text.

RNNs are just a refactored, fully-connected neural network.

You can use the same approach for any sequence labeling task (part of speech, classifying whether material is sensitive, ...).


## fin