# 1. Libraries and Additional Functions

In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from collections import Counter
import re, os, pickle
import numpy as np
import matplotlib.pyplot as plt

In [35]:
def save_PKL(curr_DICT,path):
    with open(path, 'wb') as file:
        pickle.dump(curr_DICT, file)
    return curr_DICT

def load_PKL(path):
    with open(path, 'rb') as file:
        curr_DICT = pickle.load(file)
    return curr_DICT

def setup_device():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return device

# 2. Define Variable Initializations

In [36]:
# Set random seed
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1234)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
params_unregularized = {
    "model_name": "lang_model_LSTM_unregularized",
    "save_interval": 10,
    "batch_size": 20,
    "seq_length": 20,
    "layers": 2,
    "decay": 2,
    "rnn_size": 200,
    "dropout": 0.0,
    "init_weight": 0.1,
    "lr": 1.0,
    "vocab_size": 12000,
    "max_epoch": 4,
    "max_max_epoch": 13,
    "max_grad_norm": 5,
}

In [37]:
params_regularized = {
    "model_name": "lang_model_LSTM_regularized",
    "save_interval": 10,
    "batch_size": 20,
    "seq_length": 35,
    "layers": 2,
    "decay": 1.2,
    "rnn_size": 650,
    "dropout": 0.5,
    "init_weight": 0.05,
    "lr": 1.0,
    "vocab_size": 12000,
    "max_epoch": 6,
    "max_max_epoch": 39,
    "max_grad_norm": 5,
}


# 3. Extract and Transform Data

In [38]:
def build_vocab(data,vocab_size):
    vocab_map = {}
    vocab_idx = 0
    
    word_freq = Counter(word for word in data)
    
    most_common_words = [word for word, _ in word_freq.most_common(vocab_size - 1)]  # Leave space for <unk>
    
    for word in most_common_words:
        vocab_map[word] = vocab_idx
        vocab_idx += 1
    
    vocab_map['<unk>'] = vocab_idx
    return vocab_map

def load_data(data_type,params):    
    data = load_dataset("tiny_shakespeare", trust_remote_code=True)[data_type]["text"][0]
    
    data = data.replace('\n',' <eos> ')
    data = data.replace("'"," '")
    data = data.replace("--"," --")
    data = re.sub(r'[,!?;:]', '', data)
    data = re.sub(r'[.,!?;:]+(?=\s<eos>)', '', data)
    data = data.lower()
    data = data.split()
    
    vocab_map = build_vocab(data,params["vocab_size"])
            
    x = torch.tensor([vocab_map.get(word, vocab_map['<unk>']) for word in data], dtype=torch.long)
    return x

def replicate(x_inp,batch_size):
    s = x_inp.size(0)
    x = torch.zeros((s//batch_size, batch_size), dtype=torch.long)

    for i in range(batch_size):
        start = i*s//batch_size  # Mimic Lua rounding
        finish = start + x.size(0)
        x[:, i] = x_inp[start:finish]
    return x

def traindataset(params):
    x = load_data("train",params)
    x = replicate(x,params["batch_size"])
    return x

def validdataset(params):
    x = load_data("validation",params)
    x = replicate(x,params["batch_size"])
    return x

def testdataset(params):
    x = load_data("test",params)
    x = x.view(-1, 1).expand(-1,params["batch_size"]).clone()
    return x

def getdatasets(params):
    data_train = traindataset(params)
    data_valid = validdataset(params)
    data_test = testdataset(params)
    return data_train, data_valid, data_test

In [39]:
# Call dataset functions
data_train_unregularized, data_valid_unregularized, data_test_unregularized = getdatasets(params_unregularized)

In [40]:
data_train_regularized, data_valid_regularized, data_test_regularized = getdatasets(params_regularized)

# 4. LSTM Cell Class 

In [41]:
# LSTM Cell implementation
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.W = nn.Linear(input_size, 4*hidden_size)
        self.U = nn.Linear(hidden_size, 4*hidden_size)
    
    def forward(self, x, prev_c, prev_h):
        gates = self.W(x) + self.U(prev_h)
        i, f, o, g = gates.chunk(4, dim=-1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        next_c = f * prev_c + i * g
        next_h = o * torch.tanh(next_c)
        return next_c, next_h

# 5. LSTM Model

In [42]:

# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, params):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(params["vocab_size"], params["rnn_size"])
        self.lstm_cells = nn.ModuleList([
            LSTMCell(params["rnn_size"], params["rnn_size"]) for _ in range(params["layers"])
        ])
        self.fc = nn.Linear(params["rnn_size"], params["vocab_size"])
        self.dropout = nn.Dropout(params["dropout"])
        self.init_weights(params)

    def init_weights(self, params):
        for param in self.parameters():
            param.data.uniform_(-params["init_weight"], params["init_weight"])

    def forward(self, x, hidden_states):
        x = self.embedding(x)
        next_hidden_states = []
        for i, cell in enumerate(self.lstm_cells):
            prev_c, prev_h = hidden_states[i]
            next_c, next_h = cell(x, prev_c, prev_h)
            x = next_h  # Input for next layer
            next_hidden_states.append((next_c, next_h))
        x = self.dropout(x)
        x = self.fc(x)
        return torch.log_softmax(x, dim=-1), next_hidden_states

    def init_hidden(self, batch_size, params):
        return [(torch.zeros(batch_size, params["rnn_size"]).to(device),
                 torch.zeros(batch_size, params["rnn_size"]).to(device))
                for _ in range(params["layers"])]
        
    def save(self, params, epoch_num, metrics_DICT, best_model=False):
        if best_model:
            model_path = f"models/{params['model_name']}_best.pth"
            metrics_path = f"data/output_data/{params['model_name']}_best.pkl"
        else:
            model_path = f"models/{params['model_name']}_epoch_{epoch_num + 1}.pth"
            metrics_path = f"data/output_data/{params['model_name']}_epoch_{epoch_num + 1}.pkl"
        torch.save(self.state_dict(), model_path)
        save_PKL(metrics_DICT, metrics_path)
        return
    
def check_and_load_model(params,epoch_num,best_model=False,load_model=True):
    if best_model:
        model_path = f"models/{params['model_name']}_best.pth"
        metrics_path = f"data/output_data/{params['model_name']}_best.pkl"
    else:
        model_path = f"models/{params['model_name']}_epoch_{epoch_num + 1}.pth"
        metrics_path = f"data/output_data/{params['model_name']}_epoch_{epoch_num + 1}.pkl"
    model = LSTMModel(params)
    if os.path.exists(model_path) and load_model:
        model.load_state_dict(torch.load(model_path))
    if os.path.exists(model_path) and load_model:
        metrics_DICT = load_PKL(metrics_path)
    else:
        metrics_DICT = {"epoch_list": [], "train_ppl": [], "valid_ppl": [], "test_ppl": 0, "best_valid_ppl": np.inf}
    device = setup_device()
    model.to(device) 
    return model,metrics_DICT

In [43]:
# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epoch_num_unregularized = 0
criterion = nn.NLLLoss()
model_unregularized,metrics_unregularized_DICT = check_and_load_model(params_unregularized,epoch_num_unregularized)
optimizer_unregularized = optim.SGD(model_unregularized.parameters(), lr=params_unregularized["lr"])

In [44]:
# Initialize model
epoch_num_regularized = 0
model_regularized,metrics_regularized_DICT = check_and_load_model(params_regularized,epoch_num_regularized)
optimizer_regularized = optim.SGD(model_regularized.parameters(), lr=params_regularized["lr"])

# 6. Train and Evaluate Model

In [45]:
def fp(data, model, params, hidden_states):
    model.train()
    x = data[:-1]
    y = data[1:]
    x, y = x.to(device), y.to(device)
    output, hidden_states = model(x, hidden_states)
    loss = criterion(output.view(-1, params["vocab_size"]), y.view(-1))
    return loss, hidden_states

def bp(loss, model, params, optimizer):
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), params["max_grad_norm"])
    optimizer.step()
    import gc
    gc.collect()

def run_epoch(data, model, params, optimizer):
    hidden_states = model.init_hidden(params["batch_size"],params)
    total_loss = 0
    num_batches = max(1, data.size(0) // params["seq_length"] - 1)
    for batch in range(num_batches):
        x = data[batch * params["seq_length"]:(batch + 1) * params["seq_length"]]
        y = data[batch * params["seq_length"] + 1:(batch + 1) * params["seq_length"] + 1]
        loss, hidden_states = fp(x, model, params, hidden_states)
        bp(loss, model, params, optimizer)
        total_loss += loss.item()
        hidden_states = [(c.detach(), h.detach()) for (c, h) in hidden_states]
    return torch.exp(torch.tensor(total_loss / num_batches))

def train(data_train, data_valid, model, metrics_DICT, params, optimizer):
    if not metrics_DICT["train_ppl"]:
        epoch_list = []
        train_ppl_list = []
        valid_ppl_list = []
    else:
        epoch_list = metrics_DICT["epoch_list"].copy()
        train_ppl_list = metrics_DICT["train_ppl"].copy()
        valid_ppl_list = metrics_DICT["valid_ppl"].copy()
    curr_epoch = len(metrics_DICT["epoch_list"])
    best_val_ppl = metrics_DICT["best_valid_ppl"]
    max_num_epochs = params["max_max_epoch"]
    for epoch in range(curr_epoch,params["max_max_epoch"]):
        train_ppl = run_epoch(data_train, model, params, optimizer)
        valid_ppl = evaluate(data_valid, model, metrics_DICT, params)[0]
        print(f'Epoch {epoch + 1}/{max_num_epochs}, Train PPL: {train_ppl:.4f}, Val PPL: {valid_ppl:.4f}')
        epoch_list.append(epoch + 1)
        train_ppl_list.append(train_ppl)
        valid_ppl_list.append(valid_ppl)
        if epoch > params["max_epoch"]:
            params["lr"] /= params["decay"]
        if (epoch + 1) % params["save_interval"] == 0 or epoch == params["max_max_epoch"] - 1:
            metrics_DICT["epoch_list"] = epoch_list.copy()
            metrics_DICT["train_ppl"] = train_ppl_list.copy()
            metrics_DICT["valid_ppl"] = valid_ppl_list.copy()
            model.save(params, epoch, metrics_DICT)
        if valid_ppl < best_val_ppl:
            best_val_ppl = valid_ppl
            metrics_DICT["epoch_list"] = epoch_list.copy()
            metrics_DICT["train_ppl"] = train_ppl_list.copy()
            metrics_DICT["valid_ppl"] = valid_ppl_list.copy()
            metrics_DICT["best_valid_ppl"] = best_val_ppl
            metrics_DICT["best_epoch"] = epoch + 1
            model.save(params, epoch, metrics_DICT, best_model=True)
    return model, params, metrics_DICT

def evaluate(data, model, metrics_DICT, params, best_model=False):
    model.eval()
    hidden_states = model.init_hidden(params["batch_size"],params)
    total_loss = 0
    num_batches = max(1, data.size(0) // params["seq_length"] - 1)
    with torch.no_grad():
        for batch in range(num_batches):
            x = data[batch * params["seq_length"]:(batch + 1) * params["seq_length"]]
            y = data[batch * params["seq_length"] + 1:(batch + 1) * params["seq_length"] + 1]
            loss, hidden_states = fp(x, model, params, hidden_states)
            total_loss += loss.item()
            hidden_states = [(c.detach(), h.detach()) for (c, h) in hidden_states]
    eval_ppl = torch.exp(torch.tensor(total_loss/num_batches))
    metrics_DICT["test_ppl"] = eval_ppl
    
    if best_model:
        epoch = metrics_unregularized_DICT["best_epoch"]
    else:
        epoch = len(metrics_DICT["epoch_list"])
    model.save(params, epoch, metrics_DICT, best_model)
    return eval_ppl, metrics_DICT

In [None]:

# Training loop
model_unregularized, metrics_unregularized_DICT = train(data_train_unregularized, data_valid_unregularized, model_unregularized, metrics_unregularized_DICT, params_unregularized, optimizer_unregularized)

# Evaluation loop
best_model=True
model_unregularized, metrics_unregularized_DICT = check_and_load_model(params_unregularized,metrics_unregularized_DICT["best_epoch"],best_model)
test_eval_ppl_unregularized, metrics_unregularized_DICT = evaluate(data_test_unregularized, model_unregularized, metrics_unregularized_DICT, params_unregularized, best_model)
print(f"Test PPL: {test_eval_ppl_unregularized:.2f}")

Epoch 1/13, Train PPL: 490.3062, Val PPL: 327.2083
Epoch 2/13, Train PPL: 442.6985, Val PPL: 332.6233
Epoch 3/13, Train PPL: 391.6645, Val PPL: 364.2330
Epoch 4/13, Train PPL: 342.1205, Val PPL: 405.4155
Epoch 5/13, Train PPL: 307.6555, Val PPL: 441.9352


In [None]:
# Training loop
model_regularized, metrics_regularized_DICT = train(data_train_regularized, data_valid_regularized, model_regularized, metrics_regularized_DICT, params_regularized, optimizer_regularized)

# Evaluation loop
best_model=True
model_regularized, metrics_regularized_DICT = check_and_load_model(params_regularized,metrics_regularized_DICT["best_epoch"],best_model)
test_eval_ppl_regularized, metrics_regularized_DICT = evaluate(data_test_regularized, model_regularized, metrics_regularized_DICT, params_regularized, best_model)
print(f"Test PPL: {test_eval_ppl_regularized:.2f}")

# 7. Plotting Results

In [None]:
#Visualize the train and validation loss
def plot_train_valid(params,metrics_DICT):
    plt.figure();
    plt.plot(metrics_DICT['epoch_list'],metrics_DICT['train_loss_list'], label=f'Train PPL', color='blue', linestyle='--', marker='o');
    plt.plot(metrics_DICT['epoch_list'],metrics_DICT['val_loss_list'], label=f'Validation PPL', color='green', linestyle='-', marker='x');
    plt.title(f'{params["model_name"]} Training and Validation PPL');
    plt.xlabel('Epochs');
    plt.ylabel('Loss');
    plt.legend();
    plt.grid();
    plt.xlim(0,max(metrics_DICT['epoch_list'])+1);
    return

In [None]:
plot_train_valid(params_unregularized,metrics_unregularized_DICT)

In [None]:
plot_train_valid(params_regularized,metrics_regularized_DICT)