In [53]:
from fastai.text.all import *

import pdb

In [1]:
# !wget https://s3.amazonaws.com/text-datasets/nietzsche.txt

--2023-03-29 11:36:46--  https://s3.amazonaws.com/text-datasets/nietzsche.txt
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.77.102, 52.216.62.152, 52.217.107.94, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.77.102|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 600901 (587K) [text/plain]
Saving to: ‘nietzsche.txt’


2023-03-29 11:36:47 (1,22 MB/s) - ‘nietzsche.txt’ saved [600901/600901]



In [3]:
# Read the whole corpus into memory. The context manager closes the file handle
# deterministically, and the explicit encoding avoids relying on the platform
# default (the corpus contains non-ASCII characters such as 'Æ' and 'ë').
with open('nietzsche.txt', encoding='utf-8') as f:
    text = f.read()

In [4]:
# Distinct characters of the corpus, in sorted order.
chars = sorted(list(set(text)))
# One extra slot is counted on top of the distinct corpus characters.
vocab_size = len(chars)+1
print('total chars:', vocab_size)

total chars: 85


In [5]:
# Reserve index 0 for the "\0" character. The guard keeps this cell idempotent:
# re-running it must not insert a second "\0" and shift every character index.
if "\0" not in chars:
    chars.insert(0, "\0")

''.join(chars)

'\x00\n !"\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyzÆäæéë'

In [6]:
# Lookup tables between characters and their integer ids (position in `chars`).
char_indices = {c: i for i, c in enumerate(chars)}  # character -> id
indices_char = {i: c for i, c in enumerate(chars)}  # id -> character

In [7]:
# Character-level tokenisation: one token per character of the corpus.
tokens = L([c for c in text])
tokens

(#600893) ['P','R','E','F','A','C','E','\n','\n','\n'...]

In [8]:
# Numericalise: replace every character token with its integer id.
nums = L(char_indices[i] for i in tokens)
nums

(#600893) [40,42,29,30,25,27,29,1,1,1...]

In [9]:
# Sequence length: number of context characters used to predict the next one.
sl = 3

# Preview of the samples as characters: (window of `sl` characters, following character).
L([
    (tokens[i:i+sl], tokens[i+sl]) for i in range(len(tokens)-sl)
])

(#600890) [(['P', 'R', 'E'], 'F'),(['R', 'E', 'F'], 'A'),(['E', 'F', 'A'], 'C'),(['F', 'A', 'C'], 'E'),(['A', 'C', 'E'], '\n'),(['C', 'E', '\n'], '\n'),(['E', '\n', '\n'], '\n'),(['\n', '\n', '\n'], 'S'),(['\n', '\n', 'S'], 'U'),(['\n', 'S', 'U'], 'P')...]

In [10]:
# Same sliding windows as above, built from the integer ids instead of characters.
L([
    (tensor(nums[i:i+sl]), nums[i+sl]) for i in range(len(tokens)-sl)
])

(#600890) [(tensor([40, 42, 29]), 30),(tensor([42, 29, 30]), 25),(tensor([29, 30, 25]), 27),(tensor([30, 25, 27]), 29),(tensor([25, 27, 29]), 1),(tensor([27, 29,  1]), 1),(tensor([29,  1,  1]), 1),(tensor([1, 1, 1]), 43),(tensor([ 1,  1, 43]), 45),(tensor([ 1, 43, 45]), 40)...]

In [11]:
# Work on the first 100,000 ids only.
d = nums[:100000]


# Overlapping samples (stride 1): (tensor of `sl` ids, id of the character that follows).
data = L([
    (tensor(d[i:i+sl]), d[i+sl]) for i in range(len(d)-sl)
])

In [12]:
# Batch size used by the DataLoaders below.
bs = 1024
# Index separating the first 80% of the samples (training) from the rest (validation).
cut = int(len(data) * 0.8)
# Build the train/validation loaders from the two slices, with shuffling turned off.
dls = DataLoaders.from_dsets(data[:cut], data[cut:], bs=bs, shuffle=False)

In [13]:
# Grab a single batch and check the shape of the inputs.
xb, yb = dls.one_batch()
xb.shape

torch.Size([1024, 3])

In [14]:
# Targets of that batch: one next-character id per sample.
yb

tensor([30, 25, 27,  ..., 56, 58, 72])

In [15]:
class Char3Model(nn.Module):
    """Predict the next character from exactly three preceding characters.

    The recurrence is written out by hand: one shared embedding (`i_h`), one
    shared hidden-to-hidden layer (`h_h`) applied once per character, and an
    output layer (`h_o`) applied to the final hidden activation.
    """

    def __init__(self, vocab_size, n_hidden):
        super().__init__()
        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        self.h_h = nn.Linear(n_hidden, n_hidden)       # hidden -> hidden, shared by all steps
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, x):
        """Return logits [bs, vocab_size] for token ids `x` of shape [bs, 3]."""
        first, second, third = x[:, 0], x[:, 1], x[:, 2]

        # The first step has no previous state, so only the embedding is fed in.
        state = F.tanh(self.h_h(self.i_h(first)))
        # Later steps add the running state to the next character's embedding.
        state = F.tanh(self.h_h(state + self.i_h(second)))
        state = F.tanh(self.h_h(state + self.i_h(third)))

        return self.h_o(state)
    

In [17]:
# Three-character model with 256 hidden units.
model = Char3Model(vocab_size, 256)

# Pair the loaders and the model; F.cross_entropy scores the logits against the target ids.
learner = Learner(dls, model, loss_func=F.cross_entropy)

In [18]:
# Train for 10 epochs with the one-cycle schedule (maximum learning rate 1e-2).
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,2.681352,2.357966,00:01
1,2.244841,2.148472,00:01
2,2.06292,2.075403,00:01
3,1.956914,2.02321,00:01
4,1.883873,1.977676,00:01
5,1.811974,1.936352,00:01
6,1.742999,1.882813,00:01
7,1.674462,1.853813,00:01
8,1.610426,1.828449,00:01
9,1.56737,1.824165,00:01




In [19]:
class CharXModel(nn.Module):
    """Next-character model that loops over however many context characters it is given.

    Same layers as the unrolled three-character model, but the recurrence is a
    loop over the columns of the input instead of three hand-written steps.
    """

    def __init__(self, vocab_size, n_hidden):
        super().__init__()
        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        self.h_h = nn.Linear(n_hidden, n_hidden)       # hidden -> hidden, shared by all steps
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, vocab_size] for token ids `xs` of shape [bs, sl]."""
        # A scalar zero stands in for the initial state; it broadcasts on the first add.
        state = 0.
        for column in xs.unbind(dim=1):  # one [bs] vector of ids per timestep
            state = F.tanh(self.h_h(state + self.i_h(column)))
        return self.h_o(state)
        

In [20]:
# Same training recipe as before, now with the looped model:
# 256 hidden units, 10 one-cycle epochs, maximum learning rate 1e-2.
model = CharXModel(vocab_size, 256)
learner = Learner(dls, model, loss_func=F.cross_entropy)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,2.695577,2.357954,00:01
1,2.247527,2.149469,00:01
2,2.061855,2.082031,00:01
3,1.95481,2.016302,00:01
4,1.880038,1.979227,00:01
5,1.808372,1.940966,00:01
6,1.739033,1.891341,00:01
7,1.671465,1.86209,00:01
8,1.608615,1.834718,00:01
9,1.565807,1.830057,00:01


In [24]:
class CharXConcatModel(nn.Module):
    """Looped next-character model that concatenates state and input instead of adding them.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of the hidden state.
    """

    def __init__(self, vocab_size, n_hidden):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden

        self.i_h = nn.Embedding(vocab_size, n_hidden)          # token id -> hidden vector
        self.h_h = nn.Linear(n_hidden + n_hidden, n_hidden)    # [state, embedding] -> new state
        self.h_o = nn.Linear(n_hidden, vocab_size)             # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, vocab_size] for token ids `xs` of shape [bs, sl]."""
        batch_size = xs.shape[0]
        # Allocate the initial state from a parameter so it lives on the same
        # device (and has the same dtype) as the model; a bare torch.zeros would
        # stay on the CPU and break torch.cat once the model is moved to a GPU.
        h = self.i_h.weight.new_zeros(batch_size, self.n_hidden)
        for i in range(xs.shape[1]):
            joined = torch.cat([h, self.i_h(xs[:, i])], dim=1)  # [bs, 2 * n_hidden]
            h = torch.tanh(self.h_h(joined))                    # [bs, n_hidden]

        return self.h_o(h)
        

In [25]:
# Train the concatenating variant with the same recipe
# (256 hidden units, 10 one-cycle epochs, maximum learning rate 1e-2).
model = CharXConcatModel(vocab_size, 256)
learner = Learner(dls, model, loss_func=F.cross_entropy)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,2.653754,2.339791,00:01
1,2.222025,2.131947,00:01
2,2.050427,2.069796,00:01
3,1.95551,2.014739,00:01
4,1.885468,1.984746,00:01
5,1.81921,1.940857,00:01
6,1.75103,1.881467,00:01
7,1.678259,1.850601,00:01
8,1.610384,1.822538,00:01
9,1.563779,1.818214,00:01


In [26]:
class CharXRNNModel(nn.Module):
    """Next-character model whose recurrence is delegated to `nn.RNN`.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of the RNN hidden state.
    """

    def __init__(self, vocab_size, n_hidden):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        # Sized from n_hidden (previously hard-coded to 256, which broke any other width).
        self.rnn = nn.RNN(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, vocab_size] for token ids `xs` of shape [bs, sl]."""
        bs = xs.shape[0]
        x = self.i_h(xs)                                # [bs, sl, n_hidden]
        # Initial state created on the same device/dtype as the activations.
        h0 = x.new_zeros(1, bs, self.n_hidden)          # [1, bs, n_hidden]

        # nn.RNN defaults to time-major input, hence the transpose.
        _, h = self.rnn(x.transpose(0, 1), h0)          # h: [1, bs, n_hidden]

        # Drop only the layer dimension; a bare squeeze() would also remove the
        # batch dimension when bs == 1.
        return self.h_o(h.squeeze(0))
        

In [27]:
# Train the nn.RNN-based model with the same recipe
# (256 hidden units, 10 one-cycle epochs, maximum learning rate 1e-2).
model = CharXRNNModel(vocab_size, 256)
learner = Learner(dls, model, loss_func=F.cross_entropy)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,2.640129,2.325046,00:01
1,2.211674,2.126548,00:01
2,2.048181,2.076583,00:01
3,1.955322,2.021035,00:01
4,1.888071,1.984448,00:01
5,1.82317,1.936843,00:01
6,1.750477,1.887074,00:01
7,1.676031,1.849825,00:01
8,1.606963,1.822414,00:01
9,1.55989,1.818767,00:01


In [28]:
class CustomRNN(nn.Module):
    """Hand-written single-layer RNN that adds the input to the previous state.

    Each step computes `tanh(l_hidden(x_t + h))`. Because input and state are
    summed before the linear layer, they must have compatible shapes.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.n_in, self.n_out = n_in, n_out
        self.l_hidden = nn.Linear(n_in, n_out)

    def forward(self, xs, h):
        """Run the recurrence over the leading (time) dimension of `xs`.

        Returns the stacked per-step states and the final state.
        """
        state = h
        states = []
        for step_input in xs:  # iterating a tensor walks its first dimension
            state = torch.tanh(self.l_hidden(step_input + state))
            states.append(state)
        return torch.stack(states), state


class CharXCustomRNNModel(nn.Module):
    """Next-character model built on the hand-written `CustomRNN` layer.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of the recurrent state.
    """

    def __init__(self, vocab_size, n_hidden):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        # Sized from n_hidden (previously hard-coded to 256, which broke any other width).
        self.rnn = CustomRNN(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, vocab_size] for token ids `xs` of shape [bs, sl]."""
        bs = xs.shape[0]

        x = self.i_h(xs)                         # [bs, sl, n_hidden]
        # Initial state created on the same device/dtype as the activations.
        h = x.new_zeros(bs, self.n_hidden)       # [bs, n_hidden]

        # CustomRNN iterates over the first dimension, so feed it time-major data.
        _, h = self.rnn(x.transpose(0, 1), h)    # h: [bs, n_hidden]

        return self.h_o(h)
        

In [29]:
# Train the model that uses the hand-written RNN layer, same recipe as before
# (256 hidden units, 10 one-cycle epochs, maximum learning rate 1e-2).
model = CharXCustomRNNModel(vocab_size, 256)
learner = Learner(dls, model, loss_func=F.cross_entropy)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,2.68484,2.356459,00:01
1,2.248914,2.138613,00:01
2,2.062666,2.080703,00:01
3,1.955841,2.020806,00:01
4,1.879938,1.98384,00:01
5,1.813352,1.939183,00:01
6,1.743211,1.88777,00:01
7,1.674205,1.856273,00:01
8,1.61271,1.830978,00:01
9,1.569182,1.826869,00:01


In [30]:
class CustomConcatRNN(nn.Module):
    """Hand-written single-layer RNN that concatenates the input with the previous state.

    Each step computes `tanh(l_hidden(cat([x_t, h])))`.

    Args:
        n_in: feature width of each input step.
        n_out: width of the hidden state.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out

        # The layer consumes cat([x, h]), whose width is n_in + n_out. The former
        # 2 * n_in was only correct when n_in == n_out.
        self.l_hidden = nn.Linear(n_in + n_out, n_out)

    def forward(self, xs, h):
        """Run the recurrence over the leading (time) dimension of `xs`.

        Args:
            xs: inputs of shape [sl, bs, n_in].
            h: initial state of shape [bs, n_out].

        Returns:
            (stacked states [sl, bs, n_out], final state [bs, n_out]).
        """
        hs = []
        for x in xs:
            inp = torch.cat([x, h], dim=1)  # [bs, n_in + n_out]
            h = torch.tanh(self.l_hidden(inp))
            hs.append(h)
        return torch.stack(hs), h


class CharXCustomConcatRNNModel(nn.Module):
    """Next-character model built on the hand-written `CustomConcatRNN` layer.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of the recurrent state.
    """

    def __init__(self, vocab_size, n_hidden):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        # Sized from n_hidden (previously hard-coded to 256, which broke any other width).
        self.rnn = CustomConcatRNN(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, vocab_size] for token ids `xs` of shape [bs, sl]."""
        bs = xs.shape[0]

        x = self.i_h(xs)                         # [bs, sl, n_hidden]
        # Initial state created on the same device/dtype as the activations.
        h = x.new_zeros(bs, self.n_hidden)       # [bs, n_hidden]

        # The custom layer iterates over the first dimension, so feed it time-major data.
        _, h = self.rnn(x.transpose(0, 1), h)    # h: [bs, n_hidden]

        return self.h_o(h)
        

In [31]:
# Train the concatenating custom-RNN model, same recipe as before
# (256 hidden units, 10 one-cycle epochs, maximum learning rate 1e-2).
model = CharXCustomConcatRNNModel(vocab_size, 256)
learner = Learner(dls, model, loss_func=F.cross_entropy)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,2.663217,2.350113,00:01
1,2.235993,2.12818,00:01
2,2.060384,2.071265,00:01
3,1.964063,2.01411,00:01
4,1.893441,1.983125,00:01
5,1.824177,1.94024,00:01
6,1.754602,1.892925,00:01
7,1.682467,1.856571,00:01
8,1.613159,1.8282,00:01
9,1.566036,1.823506,00:01


### Now let's make this more efficient by using a different way to load the data

In [32]:
# Non-overlapping windows (stride `sl`). Each input window is paired with the
# same window shifted one position ahead, so every input position has a target.
data = [
    (tensor(d[i:i+sl]), tensor(d[i+1:i+1+sl])) for i in range(0, len(d)-sl, sl)
]

In [33]:
# First sample: (input window, window shifted by one).
data[0]

(tensor([40, 42, 29]), tensor([42, 29, 30]))

In [34]:
# Second sample: its input starts `sl` positions after the first sample's input.
data[1]

(tensor([30, 25, 27]), tensor([25, 27, 29]))

In [35]:
# Rebuild the loaders on the new sample list: batch size 1024, first 80% of the
# samples for training and the rest for validation, shuffling turned off.
bs = 1024
cut = int(len(data) * 0.8)
dls = DataLoaders.from_dsets(data[:cut], data[cut:], bs=bs, shuffle=False)

In [36]:
# Inspect one batch: inputs and targets are both windows of `sl` ids.
dls.one_batch()

(tensor([[40, 42, 29],
         [30, 25, 27],
         [29,  1,  1],
         ...,
         [ 2, 69, 65],
         [54, 62, 67],
         [58, 71,  8]]),
 tensor([[42, 29, 30],
         [25, 27, 29],
         [ 1,  1,  1],
         ...,
         [69, 65, 54],
         [62, 67, 58],
         [71,  8,  2]]))

In [37]:
def multi_output_ce(inp, targ):
    """Cross-entropy over every timestep of a time-major prediction tensor.

    `inp` holds logits laid out as [sl, bs, vocab]; `targ` holds ids laid out
    as [bs, sl]. The targets are flipped to time-major order so both flatten
    into the same element order before the loss is taken.
    """
    # Three-way unpack doubles as a check that the logits are 3-D.
    _, _, n_classes = inp.shape

    flat_targets = targ.transpose(0, 1).contiguous().view(-1)  # [sl * bs]
    flat_logits = inp.view(-1, n_classes)                      # [sl * bs, vocab]

    return F.cross_entropy(flat_logits, flat_targets)


class CharXRNNModelMultiOutput(nn.Module):
    """nn.RNN character model that emits a prediction at every timestep.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of the RNN hidden state.
    """

    def __init__(self, vocab_size, n_hidden):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        # Sized from n_hidden (previously hard-coded to 256, which broke any other width).
        self.rnn = nn.RNN(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return time-major logits [sl, bs, vocab_size] for ids `xs` of shape [bs, sl]."""
        bs = xs.shape[0]
        x = self.i_h(xs)                                # [bs, sl, n_hidden]
        # Initial state created on the same device/dtype as the activations.
        h0 = x.new_zeros(1, bs, self.n_hidden)          # [1, bs, n_hidden]

        # nn.RNN defaults to time-major input, hence the transpose.
        hs, _ = self.rnn(x.transpose(0, 1), h0)         # hs: [sl, bs, n_hidden]

        # Project every timestep's state, not just the last one.
        return self.h_o(hs)
        

In [38]:
# Train the every-timestep model; the custom loss flattens the time-major
# logits and the targets before applying cross-entropy.
model = CharXRNNModelMultiOutput(vocab_size, 256)
learner = Learner(dls, model, loss_func=multi_output_ce)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,3.322781,2.724751,00:00
1,2.799685,2.410069,00:00
2,2.5621,2.33468,00:00
3,2.42309,2.302087,00:00
4,2.330526,2.276958,00:00
5,2.260746,2.254164,00:00
6,2.204988,2.23214,00:00
7,2.157377,2.215478,00:00
8,2.116749,2.20741,00:00
9,2.086792,2.202936,00:00


### So this does speed things up; however, the loss is noticeably worse. Let's improve this by making the model stateful

In [39]:
# Same split as before, now with drop_last=True so no short final batch is produced.
dls = DataLoaders.from_dsets(data[:cut], data[cut:], bs=bs, shuffle=False, drop_last=True)

class CharXRNNModelMultiOutputStateful(nn.Module):
    """Every-timestep nn.RNN model that carries its hidden state across batches.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of the RNN hidden state.
        bs: batch size the carried state is allocated for.
    """

    def __init__(self, vocab_size, n_hidden, bs):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden

        # Registered as a non-persistent buffer: it moves with .to()/.cuda()/.double()
        # like the parameters, but is kept out of state_dict (as before).
        self.register_buffer('h', torch.zeros(1, bs, n_hidden), persistent=False)

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        # Sized from n_hidden (previously hard-coded to 256, which broke any other width).
        self.rnn = nn.RNN(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return time-major logits [sl, bs, vocab_size] for ids `xs` of shape [bs, sl]."""
        x = self.i_h(xs)  # [bs, sl, n_hidden]

        h0 = self.h
        if h0.shape[1] != xs.shape[0]:
            # A batch of a different size cannot continue the stored state;
            # start it from zeros instead of failing inside nn.RNN.
            h0 = x.new_zeros(h0.shape[0], xs.shape[0], self.n_hidden)

        hs, h = self.rnn(x.transpose(0, 1), h0)
        # hs [sl, bs, n_hidden], h [1, bs, n_hidden]

        # Keep the values but cut the graph, so backprop stops at the batch boundary.
        self.h = h.detach()

        return self.h_o(hs)

    def reset(self):
        """Zero the carried hidden state (e.g. between epochs or before validation)."""
        self.h = torch.zeros_like(self.h)
        

In [40]:
# Train the stateful model; its third argument is the batch size its carried state is sized for.
model = CharXRNNModelMultiOutputStateful(vocab_size, 256, 1024)
learner = Learner(dls, model, loss_func=multi_output_ce)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,3.257765,2.633426,00:00
1,2.797917,2.42302,00:00
2,2.598141,2.384452,00:00
3,2.484302,2.346257,00:00
4,2.408065,2.340763,00:00
5,2.35111,2.321306,00:00
6,2.303579,2.307419,00:00
7,2.261971,2.29159,00:00
8,2.224066,2.286012,00:00
9,2.195254,2.281302,00:00


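### Still not much better. That's because consecutive batches don't line up: row i of one batch is not continued by row i of the next batch, so the carried hidden state belongs to unrelated text.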

In [41]:
def group_chunks(ds, bs):
    """Reorder `ds` so that row j of consecutive batches reads consecutive samples.

    With m = len(ds) // bs batches, batch i is made of the samples at indices
    i, i + m, i + 2m, ... Samples beyond m * bs are dropped.
    """
    n_batches = len(ds) // bs
    return L(
        ds[batch + n_batches * row]
        for batch in range(n_batches)
        for row in range(bs)
    )

In [42]:
# Reorder the training and validation samples separately with group_chunks,
# then batch them in that order (no shuffling, no short final batch).
cut = int(len(data) * 0.8)
dls = DataLoaders.from_dsets(
    group_chunks(data[:cut], bs), 
    group_chunks(data[cut:], bs), 
    bs=bs, drop_last=True, shuffle=False)

In [43]:
# Retrain the stateful model, this time on the regrouped loaders.
model = CharXRNNModelMultiOutputStateful(vocab_size, 256, 1024)
learner = Learner(dls, model, loss_func=multi_output_ce)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,3.227303,2.585609,00:00
1,2.682618,2.227214,00:00
2,2.367397,2.08143,00:00
3,2.159767,2.018955,00:00
4,2.010766,1.978363,00:00
5,1.900762,1.963945,00:00
6,1.81079,1.920839,00:00
7,1.73722,1.89898,00:00
8,1.675749,1.8834,00:00
9,1.63058,1.878325,00:00


### Now it's fast and has good performance!

### Let's simplify a couple of things by reordering and using batch_first=True on the RNN

In [44]:
def multi_output_ce(inp, targ):
    """Cross-entropy over every timestep of a batch-first prediction tensor.

    Args:
        inp: logits of shape [bs, sl, vocab].
        targ: target ids of shape [bs, sl].

    Returns:
        Scalar mean cross-entropy over all bs * sl positions.
    """
    n_classes = inp.shape[-1]

    # reshape (rather than view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose; for contiguous input it is the same zero-copy view.
    flat_logits = inp.reshape(-1, n_classes)  # [bs * sl, vocab]
    flat_targets = targ.reshape(-1)           # [bs * sl]

    return F.cross_entropy(flat_logits, flat_targets)


class CharXRNNModelMultiOutputStatefulBatchFirst(nn.Module):
    """Stateful every-timestep nn.RNN model that works on batch-first tensors.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of the RNN hidden state.
        bs: batch size the carried state is allocated for.
    """

    def __init__(self, vocab_size, n_hidden, bs):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden

        # Registered as a non-persistent buffer: it moves with .to()/.cuda()/.double()
        # like the parameters, but is kept out of state_dict (as before).
        self.register_buffer('h', torch.zeros(1, bs, n_hidden), persistent=False)

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        # Sized from n_hidden (previously hard-coded to 256, which broke any other width).
        self.rnn = nn.RNN(n_hidden, n_hidden, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, sl, vocab_size] for token ids `xs` of shape [bs, sl]."""
        x = self.i_h(xs)  # [bs, sl, n_hidden]

        h0 = self.h
        if h0.shape[1] != xs.shape[0]:
            # A batch of a different size cannot continue the stored state;
            # start it from zeros instead of failing inside nn.RNN.
            h0 = x.new_zeros(h0.shape[0], xs.shape[0], self.n_hidden)

        # batch_first only affects the input/output sequence layout; the state
        # stays [1, bs, n_hidden].
        hs, h = self.rnn(x, h0)

        # Keep the values but cut the graph, so backprop stops at the batch boundary.
        self.h = h.detach()

        return self.h_o(hs)

    def reset(self):
        """Zero the carried hidden state (e.g. between epochs or before validation)."""
        self.h = torch.zeros_like(self.h)
        

In [45]:
# Train the batch-first stateful model with the batch-first loss defined in the same section.
model = CharXRNNModelMultiOutputStatefulBatchFirst(vocab_size, 256, 1024)
learner = Learner(dls, model, loss_func=multi_output_ce)
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,3.23635,2.580962,00:00
1,2.685823,2.23114,00:00
2,2.370497,2.09285,00:00
3,2.164947,2.016285,00:00
4,2.011562,1.967001,00:00
5,1.896052,1.94108,00:00
6,1.806493,1.91636,00:00
7,1.731916,1.888724,00:00
8,1.669649,1.874724,00:00
9,1.624042,1.869331,00:00


### That means we can also use the CrossEntropyLossFlat loss 

In [None]:
# Same model, but with fastai's CrossEntropyLossFlat in place of the hand-written loss.
model = CharXRNNModelMultiOutputStatefulBatchFirst(vocab_size, 256, 1024)
learner = Learner(dls, model, loss_func=CrossEntropyLossFlat())
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,3.180236,2.551947,00:00
1,2.657809,2.230387,00:00
2,2.354864,2.102542,00:00
3,2.152959,2.018605,00:00
4,2.005427,1.971653,00:00
5,1.89443,1.939518,00:00
6,1.805748,1.914371,00:00
7,1.732585,1.88928,00:00
8,1.670935,1.873836,00:00
9,1.625509,1.868808,00:00


In [58]:
class CharXMultiRNNModelMultiOutputStateful(nn.Module):
    """Stateful, batch-first, every-timestep model with a stacked (multi-layer) nn.RNN.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of each RNN layer's hidden state.
        bs: batch size the carried state is allocated for.
        n_layers: number of stacked RNN layers.
    """

    def __init__(self, vocab_size, n_hidden, bs, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_hidden = n_hidden

        # One state slice per layer. Registered as a non-persistent buffer so it
        # moves with .to()/.cuda()/.double() but stays out of state_dict (as before).
        self.register_buffer('h', torch.zeros(n_layers, bs, n_hidden), persistent=False)

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, sl, vocab_size] for token ids `xs` of shape [bs, sl]."""
        x = self.i_h(xs)  # [bs, sl, n_hidden]

        h0 = self.h
        if h0.shape[1] != xs.shape[0]:
            # A batch of a different size cannot continue the stored state;
            # start it from zeros instead of failing inside nn.RNN.
            h0 = x.new_zeros(h0.shape[0], xs.shape[0], self.n_hidden)

        hs, h = self.rnn(x, h0)
        # hs [bs, sl, n_hidden] (top layer only), h [n_layers, bs, n_hidden]

        # Keep the values but cut the graph, so backprop stops at the batch boundary.
        self.h = h.detach()

        return self.h_o(hs)

    def reset(self):
        """Zero the carried hidden state (e.g. between epochs or before validation)."""
        self.h = torch.zeros_like(self.h)
        

In [59]:
# Train the two-layer stacked RNN (last argument is the number of layers).
model = CharXMultiRNNModelMultiOutputStateful(vocab_size, 256, 1024, 2)
learner = Learner(dls, model, loss_func=CrossEntropyLossFlat())
learner.fit_one_cycle(10, 1e-2)

epoch,train_loss,valid_loss,time
0,3.171021,2.566434,00:01
1,2.641303,2.217872,00:01
2,2.32472,2.058521,00:01
3,2.101378,1.972325,00:01
4,1.941963,1.942842,00:01
5,1.822112,1.925534,00:01
6,1.723762,1.877729,00:01
7,1.637169,1.855244,00:01
8,1.560163,1.83758,00:01
9,1.501047,1.832978,00:01


### LSTMs

![lstm.png](lstm.png)

### Create custom LSTMCell according to the diagram above:

In [259]:
class LSTMCell(nn.Module):
    """Hand-written LSTM cell whose four gates read the concatenation [h, input].

    `ni` is the input width and `nh` the hidden width; every gate is a linear
    layer from ni + nh features to nh features.
    """

    def __init__(self, ni, nh):
        super().__init__()
        self.ni, self.nh = ni, nh

        gate_in = ni + nh
        self.forget_gate = nn.Linear(gate_in, nh)
        self.input_gate = nn.Linear(gate_in, nh)
        self.cell_gate = nn.Linear(gate_in, nh)
        self.output_gate = nn.Linear(gate_in, nh)

    def forward(self, inp, state):
        """Advance one step.

        `inp` is [bs, ni]; `state` is the pair (h, c), each [bs, nh].
        Returns the new hidden state and the new (h, c) pair.
        """
        h_prev, c_prev = state
        joined = torch.cat([h_prev, inp], dim=1)  # [bs, nh + ni]

        # Forget part of the old cell state, then write the gated candidate into it.
        keep = torch.sigmoid(self.forget_gate(joined))
        write = torch.sigmoid(self.input_gate(joined))
        candidate = torch.tanh(self.cell_gate(joined))
        c_next = keep * c_prev + write * candidate

        # The output gate decides how much of the squashed cell state is exposed.
        expose = torch.sigmoid(self.output_gate(joined))
        h_next = expose * torch.tanh(c_next)

        return h_next, (h_next, c_next)

### And build the model:

In [264]:
class LSTMNet(nn.Module):
    """Stateful character model that steps the hand-written `LSTMCell` over a sequence.

    Args:
        vocab_size: number of distinct token ids.
        n_hidden: width of the embedding and of the LSTM hidden/cell state.
        bs: batch size the carried (h, c) state is allocated for.
    """

    def __init__(self, vocab_size, n_hidden, bs):
        super().__init__()

        # Carried (h, c) pair, each [bs, n_hidden].
        self.state = [torch.zeros(bs, n_hidden), torch.zeros(bs, n_hidden)]

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        self.lstm = LSTMCell(n_hidden, n_hidden)
        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, sl, vocab_size] for token ids `xs` of shape [bs, sl]."""
        x = self.i_h(xs)  # [bs, sl, n_hidden]

        # The state is a plain attribute, so bring it to the activations' device
        # explicitly (a no-op when it is already there).
        state = [s.to(x.device) for s in self.state]

        # Step over every timestep actually present. The previous version was
        # hard-wired to exactly three and silently ignored any further ones.
        outs = []
        for step in range(x.shape[1]):
            out, state = self.lstm(x[:, step, :], state)
            outs.append(out)

        # Keep the values but cut the graph, so backprop stops at the batch boundary.
        self.state = [s.detach() for s in state]

        return self.h_o(torch.stack(outs, dim=1))
        

In [265]:
# LSTM model with 256 hidden units and state sized for batches of 1024.
model = LSTMNet(vocab_size, 256, 1024)
learner = Learner(dls, model, loss_func=CrossEntropyLossFlat())

In [266]:
# 10 one-cycle epochs; note the higher maximum learning rate (5e-2) than the RNN runs.
learner.fit_one_cycle(10, 5e-2)

epoch,train_loss,valid_loss,time
0,2.864736,2.338626,00:02
1,2.422578,2.156112,00:02
2,2.224717,2.10926,00:02
3,2.092041,2.061114,00:02
4,1.985551,2.021141,00:02
5,1.891941,1.976507,00:02
6,1.804668,1.929145,00:02
7,1.718539,1.906536,00:02
8,1.639218,1.879113,00:02
9,1.576758,1.875072,00:02


### Factor out the LSTMLayer

In [263]:
class LSTMLayer(nn.Module):
    """Run a step-wise cell over the time dimension of a batch-first sequence.

    `cell` is a class (or factory); it is instantiated with `cell_args`. The
    resulting module is called as `cell(step_input, state)` and must return
    `(output, new_state)`.
    """

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    def forward(self, input, state):
        """Return (outputs stacked along dim 1, final state)."""
        collected = []
        for step_input in input.unbind(1):  # one slice per timestep
            step_out, state = self.cell(step_input, state)
            collected.append(step_out)
        return torch.stack(collected, dim=1), state

In [267]:
class LSTMNet1(nn.Module):
    """Stateful character model whose time loop lives in `LSTMLayer`.

    `lstm_cell` is the cell class to use; it is constructed with
    (n_hidden, n_hidden) by the layer.
    """

    def __init__(self, vocab_size, n_hidden, bs, lstm_cell):
        super().__init__()

        # Carried (h, c) pair, each [bs, n_hidden].
        self.h = [torch.zeros(bs, n_hidden), torch.zeros(bs, n_hidden)]

        self.i_h = nn.Embedding(vocab_size, n_hidden)          # token id -> hidden vector
        self.lstm = LSTMLayer(lstm_cell, n_hidden, n_hidden)   # steps the cell over time
        self.h_o = nn.Linear(n_hidden, vocab_size)             # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, sl, vocab_size] for token ids `xs` of shape [bs, sl]."""
        embedded = self.i_h(xs)  # [bs, sl, n_hidden]

        hidden_seq, new_state = self.lstm(embedded, self.h)
        # Keep the values but cut the graph, so backprop stops at the batch boundary.
        self.h = [part.detach() for part in new_state]

        return self.h_o(hidden_seq)
        

In [268]:
# Train the refactored network with the concatenating LSTMCell plugged in.
model = LSTMNet1(vocab_size, 256, 1024, LSTMCell)
learner = Learner(dls, model, loss_func=CrossEntropyLossFlat())
learner.fit_one_cycle(10, 5e-2)

epoch,train_loss,valid_loss,time
0,2.872561,2.352835,00:02
1,2.415869,2.148014,00:02
2,2.20981,2.093563,00:02
3,2.07361,2.049113,00:02
4,1.966854,2.011589,00:02
5,1.874174,1.973702,00:02
6,1.784474,1.925455,00:02
7,1.695859,1.900301,00:02
8,1.615563,1.87636,00:02
9,1.55256,1.871616,00:02


### LSTMCell that "adds" instead of "concats" the input and hidden state together

In [269]:
class LSTMCellAdd(nn.Module):
    """LSTM cell variant whose gates read `h + input` instead of their concatenation.

    Because state and input are summed and then fed to Linear(ni, nh) layers,
    the two must have compatible widths.
    """

    def __init__(self, ni, nh):
        super().__init__()
        self.ni, self.nh = ni, nh

        self.forget_gate = nn.Linear(ni, nh)
        self.input_gate = nn.Linear(ni, nh)
        self.cell_gate = nn.Linear(ni, nh)
        self.output_gate = nn.Linear(ni, nh)

    def forward(self, inp, state):
        """Advance one step; returns the new hidden state and the new (h, c) pair."""
        h_prev, c_prev = state
        mixed = h_prev + inp  # summed, not concatenated

        # Forget part of the old cell state, then write the gated candidate into it.
        keep = torch.sigmoid(self.forget_gate(mixed))
        write = torch.sigmoid(self.input_gate(mixed))
        candidate = torch.tanh(self.cell_gate(mixed))
        c_next = keep * c_prev + write * candidate

        # The output gate decides how much of the squashed cell state is exposed.
        expose = torch.sigmoid(self.output_gate(mixed))
        h_next = expose * torch.tanh(c_next)

        return h_next, (h_next, c_next)


In [270]:
# Same network and recipe, with the additive cell plugged in.
model = LSTMNet1(vocab_size, 256, 1024, LSTMCellAdd)
learner = Learner(dls, model, loss_func=CrossEntropyLossFlat())
learner.fit_one_cycle(10, 5e-2)

epoch,train_loss,valid_loss,time
0,2.906466,2.36446,00:01
1,2.43081,2.150088,00:01
2,2.213123,2.09214,00:01
3,2.072024,2.052454,00:01
4,1.963137,2.020215,00:01
5,1.869209,1.966592,00:01
6,1.778399,1.934537,00:01
7,1.688962,1.90055,00:01
8,1.60755,1.877616,00:01
9,1.544713,1.873836,00:01


### PyTorch is using a different LSTMCell:

![torch_lstm.png](torch_lstm.png)

In [171]:
class LSTMCellTorch(nn.Module):
    """LSTM cell with separate input and hidden projections per gate.

    Every gate is `act(W_inp @ inp + W_h @ h)`, with one linear layer for the
    input and one for the previous hidden state.

    Args:
        ni: width of the input.
        nh: width of the hidden and cell state.
    """

    def __init__(self, ni, nh):
        super().__init__()
        self.ni = ni
        self.nh = nh

        # The *_h layers act on the hidden state, so their input width is nh.
        # They were declared with ni, which only worked when ni == nh.
        self.forget_gate_inp = nn.Linear(ni, nh)
        self.forget_gate_h = nn.Linear(nh, nh)

        self.input_gate_inp = nn.Linear(ni, nh)
        self.input_gate_h = nn.Linear(nh, nh)

        self.cell_gate_inp = nn.Linear(ni, nh)
        self.cell_gate_h = nn.Linear(nh, nh)

        self.output_gate_inp = nn.Linear(ni, nh)
        self.output_gate_h = nn.Linear(nh, nh)

    def forward(self, inp, state):
        """Advance one step.

        Args:
            inp: input of shape [bs, ni].
            state: pair (h, c), each of shape [bs, nh].

        Returns:
            (new hidden state, (new hidden state, new cell state)).
        """
        h, c = state

        i = (self.input_gate_inp(inp) + self.input_gate_h(h)).sigmoid()    # input gate
        f = (self.forget_gate_inp(inp) + self.forget_gate_h(h)).sigmoid()  # forget gate
        g = (self.cell_gate_inp(inp) + self.cell_gate_h(h)).tanh()         # candidate values
        o = (self.output_gate_inp(inp) + self.output_gate_h(h)).sigmoid()  # output gate

        c = f * c + i * g
        output = o * c.tanh()

        return output, (output, c)

In [271]:
# Same network and recipe with the split-projection cell; the loss here is multi_output_ce.
model = LSTMNet1(vocab_size, 256, 1024, LSTMCellTorch)
learner = Learner(dls, model, loss_func=multi_output_ce)
learner.fit_one_cycle(10, 5e-2)

epoch,train_loss,valid_loss,time
0,2.855768,2.33383,00:02
1,2.400611,2.139885,00:02
2,2.200271,2.086586,00:02
3,2.067042,2.050319,00:02
4,1.95979,2.010012,00:02
5,1.867491,1.971201,00:02
6,1.777975,1.918855,00:02
7,1.689493,1.895441,00:02
8,1.60682,1.871696,00:02
9,1.541801,1.86778,00:02


### To use the nn.LSTMCell we have to change our LSTMLayer, since nn.LSTMCell has a slightly different API

In [274]:
class LSTMLayerTorch(nn.Module):
    """Time loop for cells that return the pair (h, c) rather than (output, state).

    The cell is called as `cell(step_input, state)`; its two return values are
    re-packed into the next step's state, and the first one is also collected
    as that step's output.
    """

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    def forward(self, input, state):
        """Return (outputs stacked along dim 1, final (h, c) state)."""
        collected = []
        for step_input in input.unbind(1):  # one slice per timestep
            h_next, c_next = self.cell(step_input, state)
            state = (h_next, c_next)
            collected.append(h_next)
        return torch.stack(collected, dim=1), state

class LSTMNetTorch(nn.Module):
    """Stateful character model that steps `nn.LSTMCell` through `LSTMLayerTorch`."""

    def __init__(self, vocab_size, n_hidden, bs):
        super().__init__()

        # Carried (h, c) pair, each [bs, n_hidden].
        self.h = [torch.zeros(bs, n_hidden), torch.zeros(bs, n_hidden)]

        self.i_h = nn.Embedding(vocab_size, n_hidden)                 # token id -> hidden vector
        self.lstm = LSTMLayerTorch(nn.LSTMCell, n_hidden, n_hidden)   # steps the cell over time
        self.h_o = nn.Linear(n_hidden, vocab_size)                    # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, sl, vocab_size] for token ids `xs` of shape [bs, sl]."""
        embedded = self.i_h(xs)  # [bs, sl, n_hidden]

        hidden_seq, new_state = self.lstm(embedded, self.h)
        # Keep the values but cut the graph, so backprop stops at the batch boundary.
        self.h = [part.detach() for part in new_state]

        return self.h_o(hidden_seq)

In [275]:
# Network built on nn.LSTMCell: 256 hidden units, state sized for batches of 1024.
model = LSTMNetTorch(vocab_size, 256, 1024)
learner = Learner(dls, model, loss_func=multi_output_ce)

In [276]:
# 10 one-cycle epochs with maximum learning rate 5e-2.
learner.fit_one_cycle(10, 5e-2)

epoch,train_loss,valid_loss,time
0,2.870847,2.355901,00:02
1,2.41933,2.163442,00:01
2,2.215185,2.105365,00:02
3,2.080186,2.051028,00:01
4,1.971298,2.011527,00:01
5,1.877012,1.957065,00:02
6,1.787701,1.921162,00:01
7,1.701306,1.894734,00:02
8,1.621861,1.872823,00:02
9,1.559144,1.868233,00:02


### Let's now use the native LSTM module

In [281]:
class LSTMNetTorch2(nn.Module):
    """Stateful character model built on the native, batch-first `nn.LSTM`."""

    def __init__(self, vocab_size, n_hidden, bs, n_layers):
        super().__init__()

        # Carried (h, c) pair; each tensor is [n_layers, bs, n_hidden].
        state_shape = (n_layers, bs, n_hidden)
        self.h = [torch.zeros(*state_shape), torch.zeros(*state_shape)]

        self.i_h = nn.Embedding(vocab_size, n_hidden)  # token id -> hidden vector
        self.lstm = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)

        self.h_o = nn.Linear(n_hidden, vocab_size)     # hidden -> vocabulary logits

    def forward(self, xs):
        """Return logits [bs, sl, vocab_size] for token ids `xs` of shape [bs, sl]."""
        embedded = self.i_h(xs)  # [bs, sl, n_hidden]

        hidden_seq, new_state = self.lstm(embedded, self.h)
        # Keep the values but cut the graph, so backprop stops at the batch boundary.
        self.h = [part.detach() for part in new_state]

        return self.h_o(hidden_seq)

In [282]:
# Train the native nn.LSTM model with a single layer (last argument).
model = LSTMNetTorch2(vocab_size, 256, 1024, 1)
learner = Learner(dls, model, loss_func=multi_output_ce)
learner.fit_one_cycle(10, 5e-2)

epoch,train_loss,valid_loss,time
0,2.871213,2.338685,00:02
1,2.41554,2.146561,00:02
2,2.217885,2.092503,00:02
3,2.079339,2.04381,00:02
4,1.972076,2.004945,00:02
5,1.879082,1.96896,00:02
6,1.790795,1.926751,00:02
7,1.70488,1.891241,00:02
8,1.624404,1.866451,00:02
9,1.560885,1.863476,00:02


### After this, we could continue with more NLP-based material and get into transformers, etc. We would do that by following this [course](https://www.youtube.com/playlist?list=PLtmWHNX-gukKocXQOkQjuVxglSDYWsSh9)

Two great resources:

- https://colah.github.io/posts/2015-08-Understanding-LSTMs/
- http://karpathy.github.io/2015/05/21/rnn-effectiveness/