In [1]:
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
# Text data loading and preprocessing
class TextDataset(Dataset):
    def __init__(self, file_path:str, vocab = None, build_vocab = False):
        
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        sentences = [re.sub(r'[^a-zA-Z\s]', '', line.strip()) for line in lines if line.strip() != '']
        sentences = [sentence.lower() for sentence in sentences]
        tokenized_sentences = [sentence.split() for sentence in sentences]
        tokenized_sentences = [tokens for tokens in tokenized_sentences if len(tokens) >= 2]
        self.sentences = tokenized_sentences
        
        if build_vocab:
            self.build_vocab(self.sentences)
        else:
            self.vocab = vocab
        
        # Create data pairs: the input is all words except the last, and the target is the last word
        self.data = []
        for tokens in self.sentences:
            input_tokens = tokens[:-1]
            target_token = tokens[-1]
            # convert words to their index in the vocab; if isn’t found, use the index for <UNK>
            input_indices = [self.vocab.get(token, self.vocab.get("<UNK>")) for token in input_tokens]
            target_index = self.vocab.get(target_token, self.vocab.get("<UNK>"))
            self.data.append((input_indices, target_index))
    
    def build_vocab(self, sentences:list):
        vocab = {"<PAD>": 0, "<UNK>": 1}
        idx = 2
        for tokens in sentences:
            for token in tokens:
                if token not in vocab:
                    vocab[token] = idx
                    idx += 1
        self.vocab = vocab
        
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx:int):
        return self.data[idx]

# Padding data
def collate_fn(batch:list) -> tuple:
    inputs, targets = zip(*batch)
    max_len = max(len(seq) for seq in inputs)
    padded_inputs = []
    # for each sequence, add 0's at the beginning so they all have the same length
    for seq in inputs:
        padded_seq = [0] * (max_len - len(seq)) + seq
        padded_inputs.append(padded_seq)
    return torch.tensor(padded_inputs, dtype=torch.long), torch.tensor(targets, dtype=torch.long)

# Create the training and test datasets using the file paths
train_dataset = TextDataset("wiki.train.txt", build_vocab=True)
vocab = train_dataset.vocab 
test_dataset = TextDataset("wiki.test.txt", vocab=vocab, build_vocab=False)

# Create DataLoaders to load the data in batches
batch_size = 256
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [3]:
class LSTMCell(nn.Module):
    def __init__(self, input_size:int, hidden_size:int):
        super(LSTMCell, self).__init__()
        self.hidden_size = hidden_size
        # weights for the gates
        self.Wf = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wi = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wc = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wo = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x:torch.Tensor, hidden:tuple) -> tuple:
        h_prev, c_prev = hidden
        combined_input = torch.cat((x, h_prev), dim=1)
        # forget gate
        ft = torch.sigmoid(self.Wf(combined_input))
        # input gate
        it = torch.sigmoid(self.Wi(combined_input))
        ct_tilde = torch.tanh(self.Wc(combined_input))
        # new memory cell state
        ct = ft * c_prev + it * ct_tilde
        # output gate
        ot = torch.sigmoid(self.Wo(combined_input))
        ht = ot * torch.tanh(ct)
        return ht, ct

class LSTMModel(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = LSTMCell(input_size, hidden_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_length, _ = x.size()
        h_t = torch.zeros(batch_size, self.hidden_size)
        c_t = torch.zeros(batch_size, self.hidden_size)
        # unrolling the LSTM over the sequence
        for t in range(seq_length):
            h_t, c_t = self.lstm_cell(x[:, t, :], (h_t, c_t))
        return h_t


class TextGenModel(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, hidden_size: int):
        super(TextGenModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_model = LSTMModel(embedding_dim, hidden_size)
        # final layer maps hidden state to logits over the vocabulary
        self.fc = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embeds = self.embedding(x) 
        lstm_out = self.lstm_model(embeds)
        logits = self.fc(lstm_out)
        return logits

In [4]:
# Hyperparameters
embedding_dim = 50
hidden_size = 50
vocab_size = len(vocab)

# Instantiate the model, loss function, and optimizer
model = TextGenModel(vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_size=hidden_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 2
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if batch_idx % 20 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

Epoch 1, Batch 0, Loss: 10.2302
Epoch 1, Batch 20, Loss: 10.1823
Epoch 1, Batch 40, Loss: 9.8095
Epoch 1, Batch 60, Loss: 8.7315
Epoch 1, Batch 80, Loss: 8.7918
Epoch 2, Batch 0, Loss: 8.0929
Epoch 2, Batch 20, Loss: 7.8539
Epoch 2, Batch 40, Loss: 7.8226
Epoch 2, Batch 60, Loss: 7.8888
Epoch 2, Batch 80, Loss: 8.1555


In [5]:
# Evaluation loop
model.eval()
total_test_loss = 0.0
with torch.no_grad():
    for inputs, targets in test_dataloader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        total_test_loss += loss.item()
avg_test_loss = total_test_loss / len(test_dataloader)
print(f"Test Cross Entropy Loss: {avg_test_loss:.4f}")

Test Cross Entropy Loss: 7.8042
