Building a next-word prediction model with an LSTM (a recurrent neural network) that predicts the next word from the preceding context

In [1]:
text = """Virat Kohli (Hindi pronunciation is an Indian international cricketer who currently plays Test cricket
        and ODI cricket for India. Kohli is a former T20I player and a former
        captain of the Indian national cricket team. He is a right-handed batsman
        and an occasional unorthodox right arm quick bowler. He currently
        represents Royal Challengers Bengaluru in the IPL and Delhi in
        domestic cricket. He holds the record as the highest run-scorer
        in IPL, ranks third in T20I, third in ODI, and stands as the
        fourth-highest in international cricket.[4] He also holds the
        record for scoring the most centuries in ODI cricket and stands
        second in the list of most international centuries scored. Hence,
        Kohli is widely regarded as one of the greatest batsmen of all time
        and the modern era. Kohli was a key member of the Indian team that
        won the 2011 Cricket World Cup, 2013 Champions Trophy and 2024 T20
        World Cup and captained India to win the ICC Test match three"""

In [2]:
import torch

In [3]:
# !nvidia-smi

## Converting the text into a labeled dataset

In [4]:
word2idx = {word:i for i, word in enumerate(set(text.split()))}
word2idx

{'Cup,': 0,
 'right-handed': 1,
 'unorthodox': 2,
 'most': 3,
 'batsmen': 4,
 'widely': 5,
 'won': 6,
 '2011': 7,
 'win': 8,
 'pronunciation': 9,
 'cricket': 10,
 'one': 11,
 'international': 12,
 'Challengers': 13,
 'time': 14,
 'scoring': 15,
 'an': 16,
 'T20I': 17,
 'currently': 18,
 'run-scorer': 19,
 'also': 20,
 'for': 21,
 'World': 22,
 'of': 23,
 'record': 24,
 'Indian': 25,
 'Trophy': 26,
 'arm': 27,
 'modern': 28,
 'ODI': 29,
 'ODI,': 30,
 'Hence,': 31,
 'India': 32,
 'was': 33,
 'team': 34,
 '2013': 35,
 'fourth-highest': 36,
 'match': 37,
 'former': 38,
 '(Hindi': 39,
 'cricket.[4]': 40,
 'the': 41,
 'captained': 42,
 'is': 43,
 'centuries': 44,
 'third': 45,
 'as': 46,
 'Royal': 47,
 'and': 48,
 'Cup': 49,
 'captain': 50,
 'Test': 51,
 'a': 52,
 'highest': 53,
 'Virat': 54,
 'team.': 55,
 'bowler.': 56,
 'member': 57,
 'all': 58,
 'greatest': 59,
 'Kohli': 60,
 'to': 61,
 'list': 62,
 'India.': 63,
 'holds': 64,
 'quick': 65,
 '2024': 66,
 'era.': 67,
 'that': 68,
 'domest
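
Because `set` iteration order for strings can differ between Python sessions (string hashing is randomized by default), the ids printed above may not be reproduced on a rerun. A minimal sketch, assuming a stable mapping is wanted, sorts the unique words before numbering them. The `sample` string and the `build_vocab` helper are illustrative names, not part of the notebook.

```python
def build_vocab(words):
    """Map each unique word to an id, in sorted order so ids are stable across runs."""
    return {word: i for i, word in enumerate(sorted(set(words)))}

# Illustrative sample, not the notebook's text
sample = "He holds the record and holds the list"
vocab = build_vocab(sample.split())
print(vocab)
```

Sorting costs little on a vocabulary this small and keeps saved checkpoints usable, since the embedding rows stay tied to the same words from one session to the next.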

In [5]:
from torch.utils.data import Dataset

## To create a custom dataset in torch, define three methods:
 1. __init__
 2. __len__
 3. __getitem__

In [7]:
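class customDataset(Dataset):
  def __init__(self, text, word2idx, seq_length):
    self.text = text              # list of words (already split)
    self.word2idx = word2idx      # word -> integer id
    self.seq_length = seq_length  # number of context words per sample

  def __len__(self):
    # one sample per position that still has a target word after the window
    return len(self.text) - self.seq_length

  def __getitem__(self,index):
    # context window of seq_length word ids
    sequence = [self.word2idx[word] for word in self.text[index:index+self.seq_length]]
    # the word immediately after the window is the label
    target = self.word2idx[self.text[index+self.seq_length]]

    return torch.tensor(sequence), torch.tensor(target)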

In [8]:
dataset = customDataset(text.split(),word2idx,20)

In [9]:
dataset[10]

(tensor([18, 80, 51, 10, 48, 29, 10, 21, 63, 60, 43, 52, 38, 17, 73, 48, 52, 38,
         50, 23]),
 tensor(41))
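
To see what each dataset item holds, the same windowing can be reproduced on a tiny sentence in plain Python. This is a sketch with an illustrative sentence and a window of 3 rather than the notebook's 20; `make_pairs` is a hypothetical helper.

```python
def make_pairs(words, seq_length):
    """Return (context, target) pairs: seq_length consecutive words and the word after them."""
    pairs = []
    for i in range(len(words) - seq_length):
        context = words[i:i + seq_length]
        target = words[i + seq_length]
        pairs.append((context, target))
    return pairs

# Illustrative sentence, not the notebook's text
words = "he is a right-handed batsman".split()
pairs = make_pairs(words, seq_length=3)
for context, target in pairs:
    print(context, "->", target)
```

The number of pairs is `len(words) - seq_length`, which matches what `__len__` returns in the dataset class.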

In [10]:
idx2word = {i:word for word,i in word2idx.items()}

In [None]:
# for m in [84, 10, 56, 95, 42, 80, 95, 86, 33, 91]:
#   print(idx2word[m])

In [11]:
from torch.utils.data import DataLoader

In [12]:
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [13]:
import torch.nn as nn

In [14]:
class LSTMmodel(nn.Module):
  def __init__(self, vocab_size, embed_size, hidden_size):
    super().__init__()
    self.embed = nn.Embedding(vocab_size, embed_size)
    # 10 stacked LSTM layers; h0 and c0 passed to forward() must match this depth
    self.lstm = nn.LSTM(embed_size,hidden_size,10,batch_first=True)
    self.fc = nn.Linear(hidden_size,vocab_size)

  def forward(self, x, h0, c0):
    embed = self.embed(x)                       # (batch, seq_len, embed_size)
    out,(h_n, c_n) = self.lstm(embed, (h0,c0))  # out: (batch, seq_len, hidden_size)
    output = self.fc(out[:,-1,:])               # logits from the last time step: (batch, vocab_size)
    return output, (h_n, c_n)
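
For a sense of scale, the number of trainable parameters in the stacked LSTM can be worked out by hand. The sketch below assumes PyTorch's standard `nn.LSTM` layout (four gates, input-to-hidden and hidden-to-hidden weight matrices, and two bias vectors per layer) and the sizes this notebook passes in (embedding 128, hidden 256, 10 layers); `lstm_param_count` is a hypothetical helper.

```python
def lstm_param_count(input_size, hidden_size, num_layers):
    """Count weights and biases of a unidirectional stacked LSTM."""
    total = 0
    for layer in range(num_layers):
        # the first layer reads embeddings, deeper layers read the previous layer's output
        in_features = input_size if layer == 0 else hidden_size
        # 4 gates, each with input weights, recurrent weights and two bias vectors
        total += 4 * (in_features * hidden_size + hidden_size * hidden_size + 2 * hidden_size)
    return total

print(lstm_param_count(128, 256, 10))  # 5132288
print(lstm_param_count(128, 256, 1))   # 395264
```

With roughly five million LSTM parameters and a corpus of fewer than two hundred words, a much shallower network (one or two layers) may be a more tractable starting point.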

In [16]:
model = LSTMmodel(len(word2idx),embed_size= 128,hidden_size = 256).to('cuda')

In [17]:
import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

## Training the LSTM model

In [18]:
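%%timeit
# %%timeit executes the cell body several times, so the 10-epoch loop
# (and its epoch log) repeats while the same model keeps training;
# %%time would time a single run.

for epoch in range(10):
  for inputs, labels in dataloader:
    optimizer.zero_grad()
    inputs = inputs.to('cuda')
    labels = labels.to('cuda')
    # zero initial hidden and cell states: (num_layers, batch, hidden_size)
    h0 = torch.zeros(10,inputs.size(0),256).to('cuda')
    c0 = torch.zeros(10,inputs.size(0),256).to('cuda')
    outputs, _ = model(inputs,h0,c0)
    loss = criterion(outputs,labels)
    loss.backward()
    optimizer.step()
  # loss of the last batch in the epoch, not an epoch average
  print(f"Epoch {epoch} : Loss : {loss.item()}")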

Epoch 0 : Loss : 4.591339111328125
Epoch 1 : Loss : 4.60338830947876
Epoch 2 : Loss : 4.588670253753662
Epoch 3 : Loss : 4.606540203094482
Epoch 4 : Loss : 4.592824459075928
Epoch 5 : Loss : 4.593382835388184
Epoch 6 : Loss : 4.571767330169678
Epoch 7 : Loss : 4.5704450607299805
Epoch 8 : Loss : 4.5910210609436035
Epoch 9 : Loss : 4.5783305168151855
Epoch 0 : Loss : 4.584235668182373
Epoch 1 : Loss : 4.580408096313477
Epoch 2 : Loss : 4.576195240020752
Epoch 3 : Loss : 4.5593085289001465
Epoch 4 : Loss : 4.573576927185059
Epoch 5 : Loss : 4.564978122711182
Epoch 6 : Loss : 4.562230110168457
Epoch 7 : Loss : 4.573721408843994
Epoch 8 : Loss : 4.585414409637451
Epoch 9 : Loss : 4.556032180786133
Epoch 0 : Loss : 4.591492176055908
Epoch 1 : Loss : 4.579452991485596
Epoch 2 : Loss : 4.561901569366455
Epoch 3 : Loss : 4.58981990814209
Epoch 4 : Loss : 4.57583475112915
Epoch 5 : Loss : 4.590928077697754
Epoch 6 : Loss : 4.563047409057617
Epoch 7 : Loss : 4.546274185180664
Epoch 8 : Loss : 4.
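
A useful reference point for the log above: a classifier that assigns equal probability to each of V classes has a cross-entropy of ln V. The sketch below computes that baseline and inverts it for one logged value. `uniform_cross_entropy` is an illustrative helper, and the vocabulary sizes are example values, not read from the notebook.

```python
import math

def uniform_cross_entropy(num_classes):
    """Cross-entropy (in nats) of a uniform prediction over num_classes classes."""
    return math.log(num_classes)

# Baseline loss for a few example vocabulary sizes
for v in (90, 100, 110):
    print(v, round(uniform_cross_entropy(v), 3))

# Number of equally likely classes implied by a loss of 4.58
print(round(math.exp(4.58), 1))
```

Losses that hover around 4.55 to 4.60, as logged here, sit near this baseline for a vocabulary on the order of a hundred words, which suggests the predictions are still close to uniform. Possible remedies, offered as suggestions rather than verified fixes for this notebook, include fewer LSTM layers, a larger learning rate, or an adaptive optimizer such as `optim.Adam`.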

In [19]:
input_seq = torch.tensor([word2idx[word] for word in text.split()[-10:]]).unsqueeze(0).to('cuda')
h0 = torch.zeros(10,input_seq.size(0),256).to('cuda')
c0 = torch.zeros(10,input_seq.size(0),256).to('cuda')
with torch.no_grad():  # inference only, no gradients needed
  output, _ = model(input_seq,h0,c0)
predicted_word = torch.argmax(output).item()
print(f"The next predicted word is : {idx2word[predicted_word]}")

The next predicted word is : the
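
Predicting several words in a row amounts to feeding each prediction back in as context. The sketch below shows that loop in plain Python. `predict_next` is a hypothetical callable standing in for a wrapper around the trained model, and the toy lookup table exists only to make the example runnable.

```python
def generate(predict_next, seed_words, n_words, context_size=10):
    """Greedily extend seed_words by n_words, using the last context_size words as context."""
    words = list(seed_words)
    for _ in range(n_words):
        context = words[-context_size:]
        words.append(predict_next(context))
    return words

# Toy stand-in for the model: looks only at the last word (assumption for illustration)
toy_table = {"the": "Indian", "Indian": "team", "team": "won", "won": "the"}

def toy_predict_next(context):
    return toy_table.get(context[-1], "the")

print(" ".join(generate(toy_predict_next, ["the"], 4)))
```

With the real model, `predict_next` would convert the context words to ids, run the forward pass, and map the argmax id back through `idx2word`.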
