# 1. Importing Libraries

In [1]:
import torch
import os
import re
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import OrderedDict

# 2. Define Variable Initializations

In [22]:
# Set random seed
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1234)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
params = {
    "batch_size": 20,
    "seq_length": 20,
    "layers": 2,
    "decay": 2,
    "rnn_size": 200,
    "dropout": 0.0,
    "init_weight": 0.1,
    "lr": 1.0,
    "vocab_size": 10000,
    "max_epoch": 4,
    "max_max_epoch": 13,
    "max_grad_norm": 5,
}

# 3. Extract and Transform Data

In [23]:
ptb_path = "./data/"

vocab_idx = 0
vocab_map = {}

In [24]:
def tokenize(text):
    """ Mimic Lua's stringx.split() behavior. """
    return re.findall(r"\S+", text)  # Split by whitespace but keep everything else intact.

def load_data(fname):
    global vocab_idx
    
    print(f"\n[Loading Data] File: {fname}")
    
    # Read file and replace newlines with <eos>
    with open(fname, 'r', encoding='utf-8') as f:
        data = f.read()
    
    data = data.replace('\n', '<eos>')  # Match Lua behavior
    data = data.split()  # Equivalent to stringx.split in Lua
            
    x = torch.zeros(len(data), dtype=torch.long)
    
    for i in range(len(data)):
        if data[i] not in vocab_map:
            vocab_idx += 1
            vocab_map[data[i]] = vocab_idx
        x[i] = vocab_map[data[i]]
    return x

def replicate(x_inp, batch_size):
    """ Replicates and shifts data exactly like Lua """
    s = x_inp.size(0)
    x = torch.zeros((s // batch_size, batch_size), dtype=torch.long)

    for i in range(batch_size):
        start = round(i * s / batch_size)  # Mimic Lua rounding
        finish = start + x.size(0)
        x[:, i] = x_inp[start:finish]
    return x

def traindataset(batch_size):
    x = load_data(os.path.join(ptb_path, "ptb.train.txt"))
    x = replicate(x, batch_size)
    return x

def testdataset(batch_size):
    x = load_data(os.path.join(ptb_path, "ptb.test.txt"))
    x = x.view(-1, 1).expand(-1, batch_size).clone()
    return x

def validdataset(batch_size):
    x = load_data(os.path.join(ptb_path, "ptb.valid.txt"))
    x = replicate(x, batch_size)
    return x

datasets = {
    "traindataset": traindataset,
    "validdataset": validdataset,
    "testdataset": testdataset
}

In [26]:
# Call dataset functions

data_train = traindataset(params["batch_size"])
data_valid = validdataset(params["batch_size"])
data_test = testdataset(params["batch_size"])


[Loading Data] File: ./data/ptb.train.txt

[Loading Data] File: ./data/ptb.valid.txt

[Loading Data] File: ./data/ptb.test.txt


# 4. LSTM Cell Class 

In [None]:
# Custom LSTM implementation
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.W = nn.Linear(input_size, 4*hidden_size)
        self.U = nn.Linear(hidden_size, 4*hidden_size)

    def forward(self, x, prev_c, prev_h):
        gates = self.W(x) + self.U(prev_h)
        i, f, o, g = gates.chunk(4, dim=-1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        next_c = f * prev_c + i * g
        next_h = o * torch.tanh(next_c)
        return next_c, next_h

# 5. LSTM Model

In [None]:
# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, params):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(params["vocab_size"], params["rnn_size"])
        self.lstm_cells = nn.ModuleList([
            LSTMCell(params["rnn_size"], params["rnn_size"]) for _ in range(params["layers"])
        ])
        self.fc = nn.Linear(params["rnn_size"], params["vocab_size"])
        self.dropout = nn.Dropout(params["dropout"])

    def forward(self, x, hidden_states):
        x = self.embedding(x)
        next_hidden_states = []
        for i, cell in enumerate(self.lstm_cells):
            prev_c, prev_h = hidden_states[i]
            next_c, next_h = cell(x, prev_c, prev_h)
            x = next_h  # Input for next layer
            next_hidden_states.append((next_c, next_h))
        x = self.dropout(x)
        x = self.fc(x)
        return torch.log_softmax(x, dim=-1), next_hidden_states

    def init_hidden(self, batch_size):
        return [(torch.zeros(batch_size, params["rnn_size"]).to(device),
                 torch.zeros(batch_size, params["rnn_size"]).to(device))
                for _ in range(params["layers"])]

In [None]:
# Initialize model
model = LSTMModel(params).to(device)
for param in model.parameters():
    param.data.uniform_(-params["init_weight"], params["init_weight"])

criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=params["lr"])

# 6. Train and Evaluate Model

In [None]:
def train(data,params):
    model.train()
    hidden_states = model.init_hidden(params["batch_size"])
    total_loss = 0
    num_batches = max(1, data.size(0) // params["seq_length"] - 1)  # Avoid division by zero

    for batch in range(num_batches):
        x = data[batch * params["seq_length"]:(batch + 1) * params["seq_length"]]
        y = data[batch * params["seq_length"] + 1:(batch + 1) * params["seq_length"] + 1]
        x, y = x.to(device), y.to(device)

        model.zero_grad()
        output, hidden_states = model(x, hidden_states)

        loss = criterion(output.view(-1, params["vocab_size"]), y.view(-1))
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), params["max_grad_norm"])
        optimizer.step()

        total_loss += loss.item()
        
        # Detach hidden states to prevent backprop through entire history
        hidden_states = [(c.detach(), h.detach()) for (c, h) in hidden_states]

        if batch % 100 == 0:
            print(f"Batch {batch}/{num_batches}, Loss: {loss.item():.4f}")

    return torch.exp(total_loss / num_batches)  # Return perplexity

def evaluate(data, params):
    model.eval()
    hidden_states = model.init_hidden(params["batch_size"])
    total_loss = 0
    num_batches = max(1, data.size(0) // params["seq_length"] - 1)  # Avoid division by zero

    with torch.no_grad():
        for batch in range(num_batches):
            x = data[batch * params["seq_length"]:(batch + 1) * params["seq_length"]]
            y = data[batch * params["seq_length"] + 1:(batch + 1) * params["seq_length"] + 1]
            x, y = x.to(device), y.to(device)

            output, hidden_states = model(x, hidden_states)
            loss = criterion(output.view(-1, params["vocab_size"]), y.view(-1))
            total_loss += loss.item()

            # Detach hidden states to free memory
            hidden_states = [(c.detach(), h.detach()) for (c, h) in hidden_states]

    return torch.exp(total_loss / num_batches)  # Return perplexity


In [None]:
# Training loop
for epoch in range(params["max_max_epoch"]):
    train_ppl = train(data_train,params)
    valid_ppl = evaluate(data_valid,params)
    print(f"Epoch {epoch+1}, Train PPL: {train_ppl:.2f}, Valid PPL: {valid_ppl:.2f}")
    if epoch > params["max_epoch"]:
        params["lr"] /= params["decay"]

test_ppl = evaluate(data_test)
print(f"Test PPL: {test_ppl:.2f}")

Batch 0/2322, Loss: 8.4566
Batch 100/2322, Loss: 6.4724
Batch 200/2322, Loss: 6.4335
Batch 300/2322, Loss: 6.3697
Batch 400/2322, Loss: 6.5697
Batch 500/2322, Loss: 6.4503
Batch 600/2322, Loss: 6.4747
Batch 700/2322, Loss: 6.7417
Batch 800/2322, Loss: 6.6693
Batch 900/2322, Loss: 6.3771
Batch 1000/2322, Loss: 6.4083
Batch 1100/2322, Loss: 6.6248
Batch 1200/2322, Loss: 6.2482
Batch 1300/2322, Loss: 6.2818
Batch 1400/2322, Loss: 6.3309
Batch 1500/2322, Loss: 6.4394
Batch 1600/2322, Loss: 6.4286
Batch 1700/2322, Loss: 6.4815
Batch 1800/2322, Loss: 6.0974
Batch 1900/2322, Loss: 6.1978
Batch 2000/2322, Loss: 6.4504
Batch 2100/2322, Loss: 6.2663
Batch 2200/2322, Loss: 6.1753
Batch 2300/2322, Loss: 6.3246


IndexError: index out of range in self