<a href="https://colab.research.google.com/github/drashyabansel/GenerativeAI/blob/main/LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Toy training corpus: a short biographical paragraph about Virat Kohli.
# It is tokenized below with plain str.split(), so punctuation stays attached to
# words (e.g. "India" and "India." become two different vocabulary entries).
# NOTE(review): the passage contains extraction artifacts -- an unclosed
# "(Hindi pronunciation", a dangling "a former / and an occasional", and a
# bracketed "[4]" marker. They are kept verbatim because editing the string would
# change the vocabulary and the training windows.
text = """Virat Kohli (Hindi pronunciation is an Indian international cricketer who currently plays Test cricket
 and ODI cricket for India. Kohli is a former T20I player and a former
 and an occasional unorthodox right arm quick bowler. He currently
 represents Royal Challengers Bengaluru in the IPL and Delhi in
 domestic cricket. He holds the record as the highest run-scorer
 in IPL, ranks third in T20I, third in ODI, and stands as the
 fourth-highest in international cricket. [4] He also holds the record
 for scoring the most centuries in ODI cricket and stands
 second in the list of most international centuries scored. Hence,
 Kohli is widely regarded as one of the greatest batsmen of all time
 and the modern era. Kohli was a key member of the Indian team that
 won the 2011 Cricket World Cup, 2013 Champions Trophy and 2024 T20
 World Cup and captained India to win the ICC Test match three """


In [None]:
import torch

In [None]:
# Show the attached GPU, its driver/CUDA versions and current memory use.
!nvidia-smi

Fri Apr 18 08:09:15 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   56C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Build the vocabulary: map each distinct whitespace-separated token to an integer id.
# Sorting makes the id assignment deterministic. Iterating a raw set() of strings
# follows per-process hash randomization (PYTHONHASHSEED), so the ids -- and with
# them every printed tensor and any saved checkpoint -- would change between
# kernel restarts.
word2idx = {word: i for i, word in enumerate(sorted(set(text.split())))}
print(word2idx)

{'captained': 0, 'key': 1, 'ICC': 2, 'India': 3, 'World': 4, 'to': 5, 'also': 6, 'Cup': 7, 'widely': 8, '2024': 9, 'record': 10, 'holds': 11, 'pronunciation': 12, 'Challengers': 13, 'Hence,': 14, 'Trophy': 15, 'Royal': 16, 'ranks': 17, 'stands': 18, 'occasional': 19, 'regarded': 20, 'ODI': 21, 'modern': 22, 'second': 23, 'domestic': 24, '2013': 25, 'international': 26, 'era.': 27, 'cricket': 28, 'T20I,': 29, 'Delhi': 30, 'team': 31, 'all': 32, 'batsmen': 33, 'fourth-highest': 34, 'win': 35, 'in': 36, 'highest': 37, 'IPL,': 38, 'most': 39, '2011': 40, 'unorthodox': 41, 'player': 42, 'run-scorer': 43, '(Hindi': 44, 'cricket.': 45, 'third': 46, 'member': 47, 'Virat': 48, 'represents': 49, 'Kohli': 50, 'three': 51, 'He': 52, 'Cricket': 53, 'bowler.': 54, '[4]': 55, 'was': 56, 'T20': 57, 'IPL': 58, 'right': 59, 'greatest': 60, 'match': 61, 'of': 62, 'India.': 63, 'quick': 64, 'Champions': 65, 'cricketer': 66, 'a': 67, 'Test': 68, 'an': 69, 'T20I': 70, 'former': 71, 'Bengaluru': 72, 'scoring

In [None]:
from torch.utils.data import Dataset

# To define a custom `Dataset`, implement three methods
1. `__init__` — store the data and any configuration
2. `__len__` — return the number of samples
3. `__getitem__` — return the sample at a given index

In [None]:
class customDataset(Dataset):
  """Sliding-window next-word dataset over a list of tokens.

  Sample ``i`` is the window ``text[i : i + seq_length]`` (as vocabulary ids)
  paired with the id of the token that immediately follows that window.

  Args:
    text: sequence of tokens (e.g. the output of ``str.split()``), not a raw string.
    word2idx: mapping from every token in ``text`` to its integer id.
    seq_length: number of context tokens per sample.
  """

  def __init__(self, text, word2idx, seq_length):
    self.text = text
    self.word2idx = word2idx
    self.seq_length = seq_length

  def __len__(self):
    # One sample per window that still has a following target token.
    # Clamp at 0: len() raises ValueError when __len__ returns a negative value,
    # which would happen if the corpus is shorter than a single window.
    return max(0, len(self.text) - self.seq_length)

  def __getitem__(self, index):
    """Return ``(inputs, target)``: int64 tensors of shape ``(seq_length,)`` and ``()``.

    Raises:
      IndexError: if ``index`` is outside ``[-len(self), len(self))``.
    """
    n = len(self)
    if index < 0:  # Python-style negative indexing
      index += n
    if not 0 <= index < n:
      raise IndexError(f"index out of range for dataset of length {n}")
    window = self.text[index:index + self.seq_length]
    sequence = [self.word2idx[word] for word in window]
    target = self.word2idx[self.text[index + self.seq_length]]
    return torch.tensor(sequence), torch.tensor(target)

# Each training sample: SEQ_LENGTH consecutive words -> the word that follows them.
SEQ_LENGTH = 5
dataset = customDataset(text.split(), word2idx, SEQ_LENGTH)

In [None]:
# Inspect one sample: (token ids of a 5-word window, token id of the word right after it).
dataset[10]

(tensor([85, 82, 68, 28, 90]), tensor(21))

In [None]:
# Inverse vocabulary (id -> token), used to decode model predictions back to words.
idx2word = dict(zip(word2idx.values(), word2idx.keys()))

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Mini-batches of 4 windows, reshuffled every epoch. A dedicated, seeded generator
# makes the shuffle order reproducible across kernel restarts without depending
# on (or disturbing) torch's global RNG.
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True,
                        generator=torch.Generator().manual_seed(0))

In [None]:
import torch.nn as nn

In [None]:
class LSTM(nn.Module):
  """Next-word model: embedding -> single-layer LSTM -> linear layer over the vocabulary.

  Args:
    vocab_size: number of token ids (embedding rows and output logits).
    embed_size: embedding dimension.
    hidden_size: dimension of the LSTM hidden and cell states.
  """

  def __init__(self, vocab_size, embed_size, hidden_size) -> None:
    super().__init__()
    self.hidden_size = hidden_size
    self.embed = nn.Embedding(vocab_size, embed_size)
    self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
    self.fc = nn.Linear(hidden_size, vocab_size)

  def forward(self, x, h0=None, c0=None):
    """Predict the token that follows each input sequence.

    Args:
      x: token ids of shape ``(batch, seq_len)`` (the LSTM uses ``batch_first=True``).
      h0, c0: optional initial hidden / cell states, each of shape
        ``(1, batch, hidden_size)``. Omit both to start from zeros.

    Returns:
      ``(logits, (h_n, c_n))``. ``logits`` has shape ``(batch, vocab_size)`` and is
      computed from the LSTM output at the last time step only.

    Raises:
      ValueError: if exactly one of ``h0`` / ``c0`` is given.
    """
    if (h0 is None) != (c0 is None):
      raise ValueError("pass both h0 and c0, or neither")
    # nn.LSTM treats a missing state as all zeros, same as explicit zero tensors.
    state = None if h0 is None else (h0, c0)
    embed = self.embed(x)
    out, (h_n, c_n) = self.lstm(embed, state)
    output = self.fc(out[:, -1, :])
    return output, (h_n, c_n)



In [None]:
# Use the GPU when one is present (this notebook was run on a Tesla T4), but fall
# back to the CPU so building the model does not crash on machines without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible weight initialization
model = LSTM(vocab_size=len(word2idx), embed_size=100, hidden_size=256).to(DEVICE)

In [None]:
from torch import optim
# Loss: cross-entropy on the raw next-word logits (CrossEntropyLoss applies
# log-softmax internally). Optimizer: plain SGD at lr=0.01 over all parameters.
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

# Train an LSTM Model

In [None]:
# Train for 50 epochs and report the mean per-sample loss of each epoch. Printing
# only the last batch's loss (which may be a partial batch) makes the curve noisy.
# Device and state sizes are read off the model, so they cannot drift from the
# values used to build it.
device = next(model.parameters()).device
num_layers, hidden_size = model.lstm.num_layers, model.lstm.hidden_size
for epoch in range(50):
  total_loss, num_samples = 0.0, 0
  for inputs, labels in dataloader:
    inputs = inputs.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    # Fresh zero state per batch: windows are independent, nothing is carried over.
    h0 = torch.zeros(num_layers, inputs.size(0), hidden_size, device=device)
    c0 = torch.zeros(num_layers, inputs.size(0), hidden_size, device=device)
    outputs, _ = model(inputs, h0, c0)
    loss = criteria(outputs, labels)
    loss.backward()
    optimizer.step()

    # CrossEntropyLoss averages over the batch; re-weight by batch size so a
    # short final batch does not skew the epoch mean.
    total_loss += loss.item() * inputs.size(0)
    num_samples += inputs.size(0)
  print(f"Epoch : {epoch} : Loss : {total_loss / max(num_samples, 1)}")

Epoch : 0 : Loss : 4.521947383880615
Epoch : 1 : Loss : 4.51588249206543
Epoch : 2 : Loss : 4.459827899932861
Epoch : 3 : Loss : 4.4485602378845215
Epoch : 4 : Loss : 4.49717903137207
Epoch : 5 : Loss : 4.433282375335693
Epoch : 6 : Loss : 4.355049133300781
Epoch : 7 : Loss : 4.374637126922607
Epoch : 8 : Loss : 4.332235813140869
Epoch : 9 : Loss : 4.428161144256592
Epoch : 10 : Loss : 4.279419422149658
Epoch : 11 : Loss : 4.438754558563232
Epoch : 12 : Loss : 4.4265217781066895
Epoch : 13 : Loss : 4.164323329925537
Epoch : 14 : Loss : 4.2708001136779785
Epoch : 15 : Loss : 4.311882495880127
Epoch : 16 : Loss : 4.119652271270752
Epoch : 17 : Loss : 4.242851734161377
Epoch : 18 : Loss : 4.194119930267334
Epoch : 19 : Loss : 4.279506683349609
Epoch : 20 : Loss : 4.386092662811279
Epoch : 21 : Loss : 4.265775680541992
Epoch : 22 : Loss : 3.9238440990448
Epoch : 23 : Loss : 3.8120718002319336
Epoch : 24 : Loss : 4.0277485847473145
Epoch : 25 : Loss : 3.3375377655029297
Epoch : 26 : Loss : 

In [None]:
# Prompt: the 5 words text.split()[-9:-4], as a (1, 5) batch of token ids.
device = next(model.parameters()).device
prompt_words = text.split()[-9:-4]
input_seq = torch.tensor([word2idx[word] for word in prompt_words]).unsqueeze(0).to(device)
# Zero initial state sized from the model itself: (num_layers, batch=1, hidden_size).
h0 = torch.zeros(model.lstm.num_layers, input_seq.size(0), model.lstm.hidden_size, device=device)
c0 = torch.zeros(model.lstm.num_layers, input_seq.size(0), model.lstm.hidden_size, device=device)

In [None]:
# Inference only: no_grad skips building the autograd graph for this forward pass.
with torch.no_grad():
  out, _ = model(input_seq, h0, c0)

In [None]:
# Id of the highest-scoring next word. dim=-1 takes the argmax over the vocabulary
# axis; a bare argmax() would index into the flattened (batch * vocab) tensor.
out.argmax(dim=-1).item()

10

In [None]:
# Show the end of the corpus, the exact words fed to the model (decoded from
# input_seq so the two can never disagree), and the model's top-1 next word.
print("Context (last 10 words) : ", text.split()[-10:])
print("Input Sequence : ", [idx2word[i] for i in input_seq[0].tolist()])

print("The next word prediction : ", idx2word[out.argmax(dim=-1).item()])

Input Sequence :  ['and', 'captained', 'India', 'to', 'win', 'the', 'ICC', 'Test', 'match', 'three']
Input Sequence :  ['captained', 'India', 'to', 'win', 'the']
The next word prediction :  record
