In [1]:
from collections import Counter
import re

In [3]:
with open("/content/shakespeare.txt", "r" ,encoding="utf8") as f:
  text = f.read()
text = text.lower()

In [4]:
tokens = re.findall(r"\w+|[^\w\s]", text)

In [5]:
vocab = sorted(set(tokens))

In [6]:
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(vocab)

In [7]:
indices = [word2idx[w] for w in tokens]

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader

In [9]:
class RNNLM_Dataset(Dataset):
    def __init__(self, data, context_size):
        self.data = data
        self.context_size = context_size

    def __len__(self):
        return len(self.data) - self.context_size

    def __getitem__(self, idx):
      x = self.data[idx : idx + self.context_size]          # Input sequence
      y = self.data[idx + 1 : idx + self.context_size + 1]  # Next-word targets
      return torch.tensor(x), torch.tensor(y)


In [10]:
context_size = 15
dataset = RNNLM_Dataset(indices, context_size)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [11]:
import torch
import torch.nn.functional as F
import torch.nn as nn


class RNNLM(nn.Module):
  def __init__(self, vocab_size, hidden_size=128, embed_dim=128, num_layers=1):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size,embed_dim)
    self.rnn = nn.RNN(embed_dim, hidden_size,num_layers,batch_first=True)
    self.fc = nn.Linear(hidden_size,vocab_size)

  def forward(self, x, hidden):
    emb = self.embedding(x)
    out,hidden = self.rnn(emb,hidden)
    logits = self.fc(out)
    return logits,hidden

  def init_hidden(self, batch_size):
    # Initialize hidden state to zeros
    weight = next(self.parameters())
    return torch.zeros(self.rnn.num_layers, batch_size,self.rnn.hidden_size, device=weight.device)


In [12]:
model = RNNLM(vocab_size, 128, 128, 1)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
batch_size = 64

In [14]:
for epoch in range(1,5):
  total_loss = 0.0
  hidden = model.init_hidden(batch_size)
  for x_batch,y_batch in dataloader:
    optimizer.zero_grad()
    logits, hidden = model(x_batch, hidden.detach())
    logits = logits.view(-1, vocab_size)
    y_batch = y_batch.view(-1)
    loss = criterion(logits, y_batch)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    avg = total_loss / len(dataloader)
  print(f"Epoch {epoch}/{10}, Loss: {avg:.4f}")

Epoch 1/10, Loss: 4.6980
Epoch 2/10, Loss: 3.9143
Epoch 3/10, Loss: 3.5542
Epoch 4/10, Loss: 3.3357
