# LSTM

In [134]:
import torch
import torch.nn as nn
import numpy as np

import matplotlib.pyplot as plt

# Explore LSTM

In [135]:
input_size = 9
hidden_size = 16
num_layers = 2
lstm = nn.LSTM(input_size, hidden_size, num_layers,batch_first=True)
lstm

LSTM(9, 16, num_layers=2, batch_first=True)

In [136]:
batch_size = 3
sequence_length = 8

X = torch.randn(batch_size,sequence_length,input_size)

H = torch.randn(num_layers,batch_size,hidden_size) # cell state
C = torch.randn(num_layers,batch_size,hidden_size) # hidden state

hidden_input = (H,C)

out, (h,c) = lstm(X,hidden_input)
X.shape , out.shape, h.shape, c.shape

(torch.Size([3, 8, 9]),
 torch.Size([3, 8, 16]),
 torch.Size([2, 3, 16]),
 torch.Size([2, 3, 16]))

In [137]:
h[0,0,:] # hidden state of first layer , first batch and all time steps

tensor([-0.1325, -0.1257, -0.2387,  0.1612, -0.0332,  0.0516,  0.2478,  0.0851,
         0.0167, -0.1328, -0.1525,  0.2975, -0.1094, -0.1648,  0.0196,  0.1214],
       grad_fn=<SliceBackward0>)

In [138]:
for p in lstm.named_parameters():
    if 'weight' in p[0]:
        print(f'{p[0]}: {p[1].shape}') # 64 cames from 16*4 where 4 is the number of gates and 16 is the hidden size

weight_ih_l0: torch.Size([64, 9])
weight_hh_l0: torch.Size([64, 16])
weight_ih_l1: torch.Size([64, 16])
weight_hh_l1: torch.Size([64, 16])


# Example

In [139]:
text = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer volutpat bibendum risus id molestie. Integer sit amet arcu vitae leo maximus convallis. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean eu dui feugiat, interdum metus eget, tempor ante. Maecenas luctus enim nec diam pellentesque rhoncus. Suspendisse vestibulum suscipit ultrices. Sed at gravida enim. Integer laoreet risus risus, ac semper sem feugiat ut. Mauris et massa ipsum.' \
'Sed laoreet eget ipsum efficitur venenatis. Aenean cursus porttitor suscipit. Etiam hendrerit leo consequat nulla egestas, nec tristique diam mattis. Morbi fermentum elementum nisl, id porta lorem luctus posuere. Etiam dapibus lacus sed mauris molestie, eget lacinia lacus lobortis. Etiam imperdiet vitae ante eget vestibulum. Duis purus sapien, accumsan eu sollicitudin at, faucibus eget lacus. Maecenas sit amet risus faucibus, pulvinar elit quis, iaculis erat. Vivamus condimentum volutpat orci at bibendum. Duis ut convallis magna, ut pretium ipsum. Aenean feugiat varius dui ut luctus. Pellentesque eros massa, pharetra et purus quis, ullamcorper volutpat justo. Ut nec enim eu ante pulvinar hendrerit vitae imperdiet leo.' \
'Nullam et nibh et purus faucibus fermentum. Nunc cursus maximus lacus, pretium congue tellus bibendum id. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc interdum lacus ac lacus malesuada scelerisque. Mauris urna risus, lacinia at est sed, pulvinar varius sem. Aenean enim nisi, condimentum eleifend finibus non, vehicula eu nibh. Donec pretium ligula purus, eu commodo risus iaculis eu. Fusce vehicula dolor eget rhoncus auctor. Ut iaculis, sem a faucibus dignissim, odio magna scelerisque nulla, a pulvinar mi lectus quis ante. Vestibulum sem enim, ullamcorper non augue sodales, fermentum dignissim augue. Phasellus vitae mollis ipsum, a sodales nunc. In tempor sagittis risus vitae ornare. Etiam tristique tellus vehicula, iaculis quam sed, facilisis nunc. Nunc suscipit nisi eu suscipit viverra. Interdum et malesuada fames ac ante ipsum primis in faucibus.'

In [140]:
print(text)

Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer volutpat bibendum risus id molestie. Integer sit amet arcu vitae leo maximus convallis. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean eu dui feugiat, interdum metus eget, tempor ante. Maecenas luctus enim nec diam pellentesque rhoncus. Suspendisse vestibulum suscipit ultrices. Sed at gravida enim. Integer laoreet risus risus, ac semper sem feugiat ut. Mauris et massa ipsum.Sed laoreet eget ipsum efficitur venenatis. Aenean cursus porttitor suscipit. Etiam hendrerit leo consequat nulla egestas, nec tristique diam mattis. Morbi fermentum elementum nisl, id porta lorem luctus posuere. Etiam dapibus lacus sed mauris molestie, eget lacinia lacus lobortis. Etiam imperdiet vitae ante eget vestibulum. Duis purus sapien, accumsan eu sollicitudin at, faucibus eget lacus. Maecenas sit amet risus faucibus, pulvinar elit quis, iaculis erat. Vivamus condimentum volutpat orci at bibendum. Duis ut convallis magna, ut pr

In [141]:
alphabet = set(text.lower())
len(alphabet)

25

In [142]:
alphabet_map = {c:i for i,c in enumerate(alphabet)}
alphabet_map

{'t': 0,
 'c': 1,
 ' ': 2,
 'u': 3,
 'v': 4,
 'a': 5,
 's': 6,
 'g': 7,
 'f': 8,
 'i': 9,
 'e': 10,
 'l': 11,
 'n': 12,
 'h': 13,
 'r': 14,
 '.': 15,
 'j': 16,
 'o': 17,
 'd': 18,
 ',': 19,
 'x': 20,
 'p': 21,
 'b': 22,
 'q': 23,
 'm': 24}

In [143]:
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
data = torch.zeros((len(text),1),dtype=torch.int64,device=device)

for i,c in enumerate(text.lower()):
    data[i,0] = alphabet_map[c]

data

tensor([[11],
        [17],
        [14],
        ...,
        [ 3],
        [ 6],
        [15]], device='mps:0')

In [186]:
class LSTMModel(nn.Module):
    def __init__(self, input_size,embedding_size, hidden_size, num_layers, num_classes):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, num_classes)
       
        
    def forward(self, x,h=None):
       
        embedding = self.embedding(x)
        out,hidden  = self.lstm(embedding,h) 
        # Decode the hidden state of the last time step
        out = self.fc(out[-1])
        return out,hidden

In [187]:
hidden_size = 512
num_layers = 3
sequence_length = 80
embedding_size = 64
input_size = len(alphabet)

model = LSTMModel(len(alphabet),embedding_size, hidden_size, num_layers, len(alphabet)).to(device)
model

LSTMModel(
  (embedding): Embedding(25, 64)
  (lstm): LSTM(64, 512, num_layers=3)
  (fc): Linear(in_features=512, out_features=25, bias=True)
)

In [188]:
lossFun = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.001)

In [189]:
offset = 0
x = data[offset:offset+sequence_length]
y = data[offset+sequence_length]
x.shape , y.shape
y

tensor([3], device='mps:0')

In [191]:
pred = model(x)
pred[0].shape # since we have 80 time step, we need only the last one to predict

torch.Size([1, 25])

In [192]:
def trainLSTM(model,data,lossFun,optimizer,epochs):
    losses = np.zeros(epochs)
    accuracy = np.zeros(epochs)
    for epoch in range(epochs):
        txtLoss = []
        txtAcc = []

        for offset in range(len(data)-sequence_length):
            x = data[offset:offset+sequence_length].to(device)
            y = data[offset+sequence_length].to(device)
            pred = model(x)[0]
            loss = lossFun(pred,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            txtLoss.append(loss.item())
            acc = (torch.argmax(pred,dim=1) == y).float()
            txtAcc.append(acc.item())

        accuracy[epoch] = np.mean(txtAcc)
        losses[epoch] = np.mean(txtLoss)
        print(f'Loss {epoch+1}/{epochs}: {losses[epoch]} Accuracy {epoch+1}/{epochs}: {accuracy[epoch]}')    
            


In [193]:
trainLSTM(model,data,lossFun,optimizer,10)

Loss 1/10: 2.6824845738312386 Accuracy 1/10: 0.18920282542885974
Loss 2/10: 2.291729949234527 Accuracy 2/10: 0.2815338042381433
Loss 3/10: 2.0235883114818582 Accuracy 3/10: 0.3521695257315843
Loss 4/10: 1.744490340944355 Accuracy 4/10: 0.43693239152371344
Loss 5/10: 1.3967397079701318 Accuracy 5/10: 0.5332996972754793
Loss 6/10: 1.0752430218155027 Accuracy 6/10: 0.641271442986882
Loss 7/10: 0.7934677864220305 Accuracy 7/10: 0.7346115035317861
Loss 8/10: 0.6053192191701051 Accuracy 8/10: 0.7936427850655903
Loss 9/10: 0.4393341816199679 Accuracy 9/10: 0.8541876892028254
Loss 10/10: 0.36431626299200937 Accuracy 10/10: 0.8859737638748738


In [212]:
number2char = {v:k for k,v in alphabet_map.items()}

In [213]:
t = ''
for l in x:
    t+=number2char[l.item()]
t    

'lorem ipsum dolor sit amet, consectetur adipiscing elit. integer volutpat bibend'

# Generate new text

In [214]:
generated_length =300
hidden = None


first_letter = torch.tensor(alphabet_map['b']).reshape(1,1).to(device)

generated_txt = number2char[first_letter.item()]
for i in range(generated_length):
    pred,hidden = model(first_letter,hidden)
    pred = torch.argmax(pred,dim=1)
    first_letter = pred.reshape(1,1)
    char = number2char[pred.item()]
    generated_txt += char
generated_txt

'bi at bibendum risus id molestie, interdum et malesuada fames ac ante ipsum primis in faucibus quis sit amet, facilisis nunc. nunc suscipit viverra. irterdum et malesuada fames ac ante ipsum primis in faucibus quis sit amet, facilisis nunc. nunc suscipit viverra. irterdum et malesuada fames ac ante i'

In [236]:
test_text = 'Proin interdum, nisi in tempor mollis, risus erat condimentum ipsum, id fringilla elit est ut massa. Donec at auctor ipsum, eget laoreet risus. Donec nisi sapien, euismod id massa id, accumsan accumsan ex. Nulla facilisi. Ut vulputate lacus libero, in tristique nisl viverra id. Nunc euismod tincidunt odio. Quisque gravida id est eget venenatis.'.lower()
part_test = test_text[:61]
part_test

'proin interdum, nisi in tempor mollis, risus erat condimentum'

In [237]:
part_test_tensor = torch.zeros((len(part_test),1),dtype=torch.int64,device=device)

for i,c in enumerate(part_test.lower()):
    part_test_tensor[i,0] = alphabet_map[c]

hidden = None

pred,hidden = model(part_test_tensor,hidden)
pred = torch.argmax(pred,dim=1)
first_letter = pred.reshape(1,1)
char = number2char[pred.item()]
part_test += char

for i in range(generated_length):
    pred,hidden = model(first_letter,hidden)
    pred = torch.argmax(pred,dim=1)
    first_letter = pred.reshape(1,1)
    char = number2char[pred.item()]
    part_test += char
part_test  


'proin interdum, nisi in tempor mollis, risus erat condimentum dignissim augue. phasellus vehicula, iaculis quam sed, facilisis nunc. nunc suscipit viverra. irterdum et malesuada fames ac ante ipsum primis in faucibus quis sit amet, facilisis nunc. nunc suscipit viverra. irterdum et malesuada fames ac ante ipsum primis in faucibus quis sit amet, facilisis nunc.'