In [37]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import re
from collections import Counter
import requests

In [38]:
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return text.split()

def build_vocab(word_list, vocab_size=5000):
    most_common = Counter(word_list).most_common(vocab_size - 1)
    vocab = {w: i+1 for i, (w, _) in enumerate(most_common)}
    vocab["<UNK>"] = 0
    return vocab

def encode_words(word_list, vocab):
    return [vocab.get(word, vocab["<UNK>"]) for word in word_list]

url = "https://www.gutenberg.org/files/11/11-0.txt"
response = requests.get(url)
raw_text = response.text

words = preprocess_text(raw_text)
vocab = build_vocab(words)
encoded = encode_words(words, vocab)
idx2word = {i: w for w, i in vocab.items()}

In [39]:

class LanguageModelDataset(Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx : idx + self.seq_len])
        y = torch.tensor(self.data[idx + self.seq_len])
        return x, y

dataset = LanguageModelDataset(encoded, seq_len=5)
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)


In [40]:
class SimpleGRU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SimpleGRU, self).__init__()
        self.W_z = nn.Linear(input_dim, hidden_dim)
        self.U_z = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.W_r = nn.Linear(input_dim, hidden_dim)
        self.U_r = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.W_h = nn.Linear(input_dim, hidden_dim)
        self.U_h = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x, h_0=None):
        seq_len, batch_size, _ = x.size()
        h_t = torch.zeros(batch_size, self.U_h.out_features, device=x.device) if h_0 is None else h_0
        outputs = []
        for t in range(seq_len):
            x_t = x[t]
            z_t = torch.sigmoid(self.W_z(x_t) + self.U_z(h_t))
            r_t = torch.sigmoid(self.W_r(x_t) + self.U_r(h_t))
            h_tilde = torch.tanh(self.W_h(x_t) + self.U_h(r_t * h_t))
            h_t = (1 - z_t) * h_t + z_t * h_tilde
            outputs.append(h_t.unsqueeze(0))
        return torch.cat(outputs, dim=0), h_t


In [41]:
class CustomGRULM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = SimpleGRU(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, embedding_dim)
        out, _ = self.gru(x)
        return self.fc(out[-1])


In [42]:
vocab_size = len(vocab)
model = CustomGRULM(vocab_size, embedding_dim=32, hidden_dim=128)
optimizer = torch.optim.SGD(model.parameters(), lr=0.3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [43]:
for epoch in range(5):
    total_loss = 0
    for inputs, targets in loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 2448.6518
Epoch 2, Loss: 2152.6536
Epoch 3, Loss: 1983.1302
Epoch 4, Loss: 1842.9782
Epoch 5, Loss: 1745.3019


In [44]:
def predict_next_word(seed_text):
    model.eval()
    seq_len = 5
    seed_words = preprocess_text(seed_text)[-seq_len:]
    encoded_input = encode_words(seed_words, vocab)
    if len(encoded_input) < seq_len:
        encoded_input = [0] * (seq_len - len(encoded_input)) + encoded_input
    input_tensor = torch.tensor([encoded_input])
    with torch.no_grad():
        output = model(input_tensor)
        next_word_id = torch.argmax(output, dim=-1).item()
    return idx2word[next_word_id]

In [45]:
print("Input:", "she was not a bit")
print("Prediction:", predict_next_word("she was not a bit"))

Input: she was not a bit
Prediction: afraid
