In [1]:
import torch
import torch.nn as nn
import nltk
from nltk.corpus import treebank
from collections import Counter
from nltk.tokenize import TreebankWordTokenizer
from loguru import logger
import time
import argparse
import matplotlib.pyplot as plt
import os
import pickle as pkl
from safetensors.torch import save_file, load_file
import json

In [2]:
# Keep NLTK resources inside the project's data directory.
nltk_data_path = os.path.join(os.getcwd(), 'data', 'nltk_data')
# Guarded append: re-running this cell must not add duplicate search-path entries.
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# Check for each resource itself instead of for the directory: the directory can
# exist while a resource is missing (e.g. after an interrupted download), and a
# directory-only check would then never repair it.
required_nltk_resources = {
    'treebank': 'corpora/treebank',
    'punkt': 'tokenizers/punkt',
}
missing_nltk_packages = []
for package_name, resource_path in required_nltk_resources.items():
    try:
        nltk.data.find(resource_path)
    except LookupError:
        missing_nltk_packages.append(package_name)

# Download only what is actually missing.
if missing_nltk_packages:
    print(f"Downloading NLTK data to: {nltk_data_path}")
    for package_name in missing_nltk_packages:
        nltk.download(package_name, download_dir=nltk_data_path)

In [3]:
def load_and_preprocess_data():
    """
    Read the Penn Treebank sample from NLTK and split it into three parts.

    Every corpus sentence (a list of words) is joined with single spaces,
    lowercased and tokenized again with TreebankWordTokenizer.

    Returns:
        tuple: Three lists of tokenized sentences, in this order:
            - train_data: first 80% of the sentences
            - test_data: the next 10% (from 80% up to 90%)
            - val_data: the remaining sentences (from 90% to the end)

    The split is positional (corpus order); nothing is shuffled.
    """
    word_splitter = TreebankWordTokenizer()

    # Each corpus entry is a list of words: rebuild the sentence string,
    # lowercase it, then tokenize it again.
    corpus = [
        word_splitter.tokenize(' '.join(words).lower())
        for words in treebank.sents()
    ]

    # Boundaries of the 80/10/10 split, computed once.
    total = len(corpus)
    train_end = int(total * 0.8)
    test_end = int(total * 0.9)

    train_data = corpus[:train_end]
    test_data = corpus[train_end:test_end]
    val_data = corpus[test_end:]

    return train_data, test_data, val_data

In [4]:
def build_voab(data, min_freq=2):
    """
    Derive the vocabulary mappings from tokenized sentences.

    Args:
        data (list): Tokenized sentences; each sentence is a list of tokens
        min_freq (int, optional): A token is kept only if it occurs at least
            this many times across `data`. Defaults to 2.

    Returns:
        tuple: (word_to_idx, idx_to_word)
            - word_to_idx: token -> integer index
            - idx_to_word: integer index -> token

    Indices 0-3 are reserved, in this order, for the special tokens
    <unk> (unknown word), <pad> (padding), <bos> (beginning of sentence)
    and <eos> (end of sentence). Kept tokens follow in order of first
    appearance in `data`.
    """
    token_counts = Counter()
    for sentence in data:
        token_counts.update(sentence)

    # Special tokens first, then every sufficiently frequent token.
    vocabulary = ['<unk>', '<pad>', '<bos>', '<eos>']
    for token, count in token_counts.items():
        if count >= min_freq:
            vocabulary.append(token)

    idx_to_word = dict(enumerate(vocabulary))
    word_to_idx = {token: position for position, token in idx_to_word.items()}

    return word_to_idx, idx_to_word


In [5]:
def process_data(data, word_to_idx):
    """
    Turn tokenized sentences into index tensors framed by <bos>/<eos>.

    Args:
        data (list): Tokenized sentences; each sentence is a list of tokens
        word_to_idx (dict): Token -> integer index mapping; must contain the
            '<unk>', '<bos>' and '<eos>' entries

    Returns:
        list: One 1-D torch tensor per sentence, laid out as
            [<bos>, token indices..., <eos>]

    Tokens missing from `word_to_idx` are mapped to the '<unk>' index.
    """
    def encode(sentence):
        # Out-of-vocabulary tokens fall back to the <unk> index.
        token_ids = [word_to_idx.get(tok, word_to_idx['<unk>']) for tok in sentence]
        return torch.tensor([word_to_idx['<bos>'], *token_ids, word_to_idx['<eos>']])

    return [encode(sentence) for sentence in data]

In [6]:
def create_batches(data, word_to_idx, batch_size=32):
    """
    Group index tensors into padded batches.

    Args:
        data (list): 1-D torch tensors, one per sentence. The list is NOT
            modified; a sorted copy is used internally.
        word_to_idx (dict): Token -> integer index mapping; must contain '<pad>'
        batch_size (int, optional): Maximum number of sequences per batch.
            Defaults to 32.

    Returns:
        list: One 2-D tensor per batch with shape
            (sequences_in_batch, longest_sequence_in_batch). The last batch
            may hold fewer than `batch_size` sequences. An empty `data`
            yields an empty list.

    Sequences are ordered longest-first before batching so that sequences of
    similar length share a batch and little padding is needed. Shorter
    sequences are right-padded with the '<pad>' index.
    """
    pad_idx = word_to_idx['<pad>']

    # sorted() instead of list.sort(): the caller's list keeps its order.
    ordered = sorted(data, key=len, reverse=True)

    batches = []
    for start in range(0, len(ordered), batch_size):
        chunk = ordered[start:start + batch_size]
        # pad_sequence right-pads every sequence to the longest one in the
        # chunk and stacks them; the result keeps the dtype of the inputs.
        batches.append(
            torch.nn.utils.rnn.pad_sequence(chunk, batch_first=True, padding_value=pad_idx)
        )
    return batches

In [7]:
class Network(nn.Module):
    """
    Recurrent language model: embedding -> LSTM/GRU -> dropout -> linear.

    Args:
        vocab_size (int): Size of the vocabulary
        embed_size (int): Dimension of word embeddings
        hidden_size (int): Number of features in the hidden state
        num_layers (int): Number of recurrent layers
        cell (str, optional): Recurrent layer type, 'lstm' or 'gru'. Defaults to 'lstm'
        dropout (float, optional): Dropout probability. Defaults to 0.5

    Attributes:
        layers (int): Number of recurrent layers
        hidden_size (int): Size of hidden state
        embed (nn.Embedding): Word embedding layer
        cell (nn.LSTM | nn.GRU): Recurrent layer
        dropout (nn.Dropout): Dropout layer
        fc (nn.Linear): Final linear layer that maps to vocabulary size

    Raises:
        ValueError: If `cell` is neither 'lstm' nor 'gru'.
    """
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, cell='lstm', dropout=0.5):
        super(Network, self).__init__()
        self.layers = num_layers
        self.hidden_size = hidden_size

        self.embed = nn.Embedding(vocab_size, embed_size)

        # nn.LSTM / nn.GRU consume a whole sequence (and stack `num_layers`
        # layers); nn.LSTMCell / nn.GRUCell would only compute a single step.
        rnn_types = {'lstm': nn.LSTM, 'gru': nn.GRU}
        if cell not in rnn_types:
            # Fail at construction time instead of leaving self.cell as None
            # and crashing later inside forward().
            raise ValueError(f"Unsupported cell type {cell!r}; expected one of {sorted(rnn_types)}")
        # Inter-layer dropout only exists between stacked layers, so it is
        # disabled for a single-layer network.
        self.cell = rnn_types[cell](embed_size, hidden_size, num_layers, batch_first=True,
                                    dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, sequence_length)
            hidden (optional): Initial recurrent state: an (h0, c0) tuple for
                an LSTM, a single tensor for a GRU. With None the recurrent
                layer starts from an all-zero state.

        Returns:
            tuple:
                - logits (torch.Tensor): Output logits of shape (batch_size, sequence_length, vocab_size)
                - hidden: Final recurrent state ((h_n, c_n) tuple for an LSTM, h_n tensor for a GRU)
        """
        # Dropout on the embeddings regularises the input representation.
        embeds = self.dropout(self.embed(x))
        # `hidden` is handed over unchanged: with None both nn.LSTM and nn.GRU
        # create their own zero state in the right format. Building an
        # (h0, c0) tuple here would break the GRU, which expects one tensor.
        output, hidden = self.cell(embeds, hidden)
        output = self.dropout(output)
        logits = self.fc(output)
        return logits, hidden

In [8]:
def train(model, train_batches, criterion, optimizer, device):
    """
    Train the model for one epoch.

    Args:
        model (nn.Module): The neural network model; called as `model(inputs)`
            and expected to return `(logits, state)`
        train_batches: Sized iterable of 2-D index tensors
            (sequences x tokens), one tensor per batch
        criterion: Loss function applied to flattened logits and targets
        optimizer: Optimizer for updating model parameters
        device: Device the batch slices are moved to before the forward pass

    Returns:
        float: Average of the per-batch losses for this epoch

    Raises:
        ValueError: If `train_batches` is empty.
    """
    num_batches = len(train_batches)
    if num_batches == 0:
        raise ValueError("train_batches is empty; nothing to train on")

    model.train()
    total_loss = 0.0

    for batch in train_batches:
        optimizer.zero_grad()
        # Next-token prediction: the input is the sequence without its last
        # token, the target is the same sequence shifted left by one.
        inputs = batch[:, :-1].to(device)
        targets = batch[:, 1:].to(device)
        outputs, _ = model(inputs)
        loss = criterion(outputs.reshape(-1, outputs.size(-1)),
                         targets.reshape(-1))
        loss.backward()

        # Clip the global gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # Accumulate inside the loop so every batch contributes to the
        # epoch average, not just the last one.
        total_loss += loss.item()

    return total_loss / num_batches


def evaluate(model, eval_batches, criterion, device):
    """
    Compute the average loss of the model over a set of batches.

    Args:
        model (nn.Module): The neural network model; called as `model(inputs)`
            and expected to return `(logits, state)`
        eval_batches: Sized iterable of 2-D index tensors
            (sequences x tokens), one tensor per batch
        criterion: Loss function applied to flattened logits and targets
        device: Device the batch slices are moved to before the forward pass

    Returns:
        float: Average of the per-batch losses

    The model is switched to eval mode and left in that mode.
    """
    model.eval()
    batch_losses = []

    # eval() only changes layer behaviour (e.g. dropout is disabled); no_grad()
    # additionally stops autograd from recording the forward pass.
    with torch.no_grad():
        for batch in eval_batches:
            source, gold = batch[:, :-1].to(device), batch[:, 1:].to(device)
            logits, _ = model(source)
            flat_logits = logits.reshape(-1, logits.size(-1))
            batch_losses.append(criterion(flat_logits, gold.reshape(-1)).item())

    return sum(batch_losses, 0.0) / len(eval_batches)

def calculate_perplexity(loss):
    """
    Convert a loss value into perplexity, i.e. exp(loss).

    Args:
        loss: Loss value (a Python number in this notebook)

    Returns:
        torch.Tensor: Tensor holding exp(loss)
    """
    loss_tensor = torch.tensor(loss)
    return torch.exp(loss_tensor)

In [9]:
def visualize(train_metric, valid_metric, title, xlabel, ylabel, figname):
    """
    Plot a training curve and a validation curve together and save the figure.

    Args:
        train_metric (list): Values plotted under the legend label 'train'
        valid_metric (list): Values plotted under the legend label 'valid'
        title (str): Title of the plot
        xlabel (str): Label for x-axis
        ylabel (str): Label for y-axis
        figname (str): Filename to save the plot

    Returns:
        None

    The curves are drawn with the pyplot state interface (current figure),
    which is closed after saving.
    """
    for series, curve_label in ((train_metric, 'train'), (valid_metric, 'valid')):
        plt.plot(series, label=curve_label)

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend(loc='upper right')

    plt.savefig(figname)
    plt.close()

In [10]:
# function to load vocabulary mappings
def load_vocab(vocab_path):
    """
    Load the vocabulary mappings from a JSON file.

    Args:
        vocab_path (str): Path to a JSON file holding an object with the keys
            'word_to_idx' and 'idx_to_word'

    Returns:
        tuple: (word_to_idx, idx_to_word)
            - word_to_idx: token -> index, as stored in the file
            - idx_to_word: index -> token, with the keys converted to int
    """
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab_dict = json.load(f)

    # JSON object keys are always strings, so the integer indices come back
    # as '0', '1', ... Restore them to int so that lookups with integer token
    # ids (idx_to_word[idx]) work on the loaded mapping.
    idx_to_word = {int(idx): word for idx, word in vocab_dict['idx_to_word'].items()}

    return vocab_dict['word_to_idx'], idx_to_word

# function to load tensors from safetensors file
def load_tensors(file_path):
    """
    Read the 'train', 'valid' and 'test' entries from a safetensors file.

    Args:
        file_path (str): Path passed to safetensors' `load_file`

    Returns:
        tuple: (train, valid, test) -- the values stored under those three keys

    Raises:
        KeyError: If one of the three keys is missing from the loaded mapping.
    """
    archive = load_file(file_path)
    train_split = archive['train']
    valid_split = archive['valid']
    test_split = archive['test']
    return train_split, valid_split, test_split


In [11]:
def generate_text(model, word_to_idx, idx_to_word, device, seed_text='the', max_length=20):
    """Greedily extend a seed text with the trained language model.

    Args:
        model: The trained language model; called as `model(tensor)` and
            expected to return `(logits, state)`
        word_to_idx (dict): Dictionary mapping words to indices
        idx_to_word (dict): Dictionary mapping integer indices to words
        device: Device the input tensor is moved to
        seed_text (str, optional): Initial text to condition generation on;
            lowercased and split on whitespace. Defaults to 'the'.
        max_length (int, optional): Maximum number of words to generate. Defaults to 20.

    Returns:
        str: Seed words followed by the generated words, joined by single
            spaces (the leading <bos> is not included)

    At every step the whole sequence so far is fed to the model and the
    highest-scoring next token is taken. Generation stops early when that
    token is <eos>, which is not added to the output.
    """
    model.eval()

    # Unknown seed words fall back to <unk>; the sequence always starts with <bos>.
    seed_ids = [word_to_idx.get(tok, word_to_idx['<unk>']) for tok in seed_text.lower().split()]
    token_ids = [word_to_idx['<bos>'], *seed_ids]

    with torch.no_grad():
        for _ in range(max_length):
            context = torch.tensor(token_ids).unsqueeze(0).to(device)
            logits, _ = model(context)
            # Scores at the last position of the only sequence in the batch.
            predicted = logits[0, -1].argmax().item()
            if predicted == word_to_idx['<eos>']:
                break
            token_ids.append(predicted)

    # token_ids[0] is <bos>; leave it out of the text
    return ' '.join(idx_to_word[token_id] for token_id in token_ids[1:])

In [25]:

def get_impact_function(model, valid_batches, criterion, device):
    """Sum the loss gradients of all validation batches into one flat vector.

    For every batch the criterion's loss is back-propagated and the resulting
    parameter gradients are added up per parameter. The model is put in eval
    mode first (and left in it).

    Args:
        model: The trained language model
        valid_batches: Validation data batches
        criterion: Loss function
        device: Device to run computations on (cuda/cpu/mps)

    Returns:
        torch.Tensor: 1-D tensor with the accumulated gradients of all
            trainable parameters, flattened and concatenated in
            `model.named_parameters()` order
    """
    # One zero-filled accumulator per trainable parameter.
    accumulated = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            accumulated[name] = torch.zeros_like(param.data)

    model.eval()

    for batch in valid_batches:
        # Start each batch from clean gradients so only this batch is added.
        model.zero_grad()
        source = batch[:, :-1].to(device)
        gold = batch[:, 1:].to(device)

        logits, _ = model(source)
        flat_logits = logits.reshape(-1, logits.size(-1))
        criterion(flat_logits, gold.reshape(-1)).backward()

        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            accumulated[name] += param.grad.detach()

    return torch.cat([grad_sum.view(-1) for grad_sum in accumulated.values()])
    
def get_fisher_matrix(model, train_batches, criterion, device, epsilon=1e-7, lambda_reg=1E-3):
    """Compute a damped, inverted diagonal Fisher approximation as a flat vector.

    For every batch the criterion's loss is back-propagated and the squared
    parameter gradients are summed per parameter. Each sum is divided by the
    total number of sequences seen, `lambda_reg` is added as damping, and the
    result is inverted element-wise: 1 / (sum_sq / n + lambda_reg + epsilon).
    The model is put in eval mode first (and left in it).

    Note that the gradient squared here is the gradient of each batch's loss,
    not of each individual sequence.

    Args:
        model: The trained language model
        train_batches: Training data batches (2-D index tensors)
        criterion: Loss function
        device: Device to run computations on (cuda/cpu/mps)
        epsilon (float, optional): Small constant added to the denominator.
            Defaults to 1e-7.
        lambda_reg (float, optional): Damping term added before inversion.
            Defaults to 1e-3.

    Returns:
        torch.Tensor: 1-D tensor with the inverted values of all trainable
            parameters, flattened and concatenated in
            `model.named_parameters()` order
    """
    model.eval()
    # Per-parameter accumulators for the squared gradients.
    fisher_matrix = {name: torch.zeros_like(param.data)
                     for name, param in model.named_parameters() if param.requires_grad}
    n_samples = 0

    for batch in train_batches:
        model.zero_grad()
        inputs = batch[:, :-1].to(device)
        targets = batch[:, 1:].to(device)
        n_samples += inputs.size(0)
        outputs, _ = model(inputs)
        loss = criterion(outputs.reshape(-1, outputs.size(-1)),
                         targets.reshape(-1))
        loss.backward()

        # accumulate the squared gradients (diagonal approximation)
        for name, param in model.named_parameters():
            if param.grad is not None:
                fisher_matrix[name] += torch.pow(param.grad.detach(), 2)

    # normalize the fisher values per parameter, damp and invert
    for name in fisher_matrix:
        temp = (fisher_matrix[name] / (n_samples)) + lambda_reg * 1.0
        fisher_matrix[name] = 1 / (temp + epsilon)

    grads = torch.cat([value.view(-1) for name, value in fisher_matrix.items()])

    return grads


def get_fisher_influence_function(model, train_batches, valid_batches, criterion, device):
    """Score every training sequence with a Fisher-weighted gradient product.

    The score of one training sequence is
        sum(valid_grads * fisher_values * grads)
    where `valid_grads` comes from `get_impact_function` (validation set),
    `fisher_values` from `get_fisher_matrix` (training set) and `grads` is the
    flattened loss gradient of that single sequence.

    Args:
        model: The trained language model
        train_batches: Training data batches; each row of each batch is
            scored separately
        valid_batches: Validation data batches
        criterion: Loss function
        device: Device to run computations on (cuda/cpu/mps)

    Returns:
        torch.Tensor: 1-D tensor with one score per training sequence, in
            batch order and row order within each batch. Its shape is also
            printed.
    """
    inverse_fisher = get_fisher_matrix(model, train_batches, criterion, device)
    validation_gradient = get_impact_function(model, valid_batches, criterion, device)

    # This product is the same for every training sequence, so build it once.
    weighting = validation_gradient * inverse_fisher

    scores = []
    for batch in train_batches:
        for row in range(batch.size(0)):
            model.zero_grad()
            # Keep the batch dimension: a single-sequence batch of shape (1, tokens).
            sample = batch[row: row + 1]
            source = sample[:, :-1].to(device)
            gold = sample[:, 1:].to(device)

            logits, _ = model(source)
            criterion(logits.reshape(-1, logits.size(-1)), gold.reshape(-1)).backward()

            sample_gradient = torch.cat([
                param.grad.detach().view(-1)
                for _, param in model.named_parameters()
                if param.requires_grad
            ])
            scores.append(torch.sum(weighting * sample_gradient))

    influence_values = torch.tensor(scores)
    print(f"influence_values: {influence_values.shape}")
    return influence_values


In [13]:
trial_name = 'lstm-1'

# Fix the RNG seed so weight initialisation and dropout masks are repeatable
# across "Restart & Run All".
seed = 42
torch.manual_seed(seed)

# makedirs also creates the parent 'logs' directory and tolerates re-runs
log_dir = os.path.join('logs', trial_name)
os.makedirs(log_dir, exist_ok=True)

# configure the logger
logger.add(os.path.join(log_dir, 'logs.log'), format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}", level="INFO")

if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
logger.info(f"Using device: {device}")

# load and tokenize the data
# load_and_preprocess_data() returns (train, test, val) in that order, so the
# names here must follow the same order.
train_data, test_data, valid_data = load_and_preprocess_data()
logger.info("Loaded data")
# build a vocabulary (from the training split only)
word_to_idx, idx_to_word = build_voab(train_data)
logger.info(f"Built vocabulary of size {len(word_to_idx)}")

# convert to tensors and add <bos> and <eos>
train_tensors = process_data(train_data, word_to_idx)
valid_tensors = process_data(valid_data, word_to_idx)
test_tensors = process_data(test_data, word_to_idx)
logger.info("processed all data")

# create data directory if it doesn't exist
os.makedirs('data', exist_ok=True)

# save vocabulary mappings
# NOTE: JSON stores object keys as strings, so the idx_to_word keys are
# written as '0', '1', ... in the file.
vocab_dict = {
    'word_to_idx': {word: int(idx) for word, idx in word_to_idx.items()},
    'idx_to_word': {int(idx): word for idx, word in idx_to_word.items()}
}

vocab_path = os.path.join('data', 'treebank_vocab.json')
with open(vocab_path, 'w') as f:
    json.dump(vocab_dict, f)

logger.info("Saved vocabulary mappings to JSON file")


# pad the tokens and create batches
batch_size = 32
train_batches = create_batches(train_tensors, word_to_idx, batch_size)
valid_batches = create_batches(valid_tensors, word_to_idx, batch_size)
test_batches = create_batches(test_tensors, word_to_idx, batch_size)
logger.info("Created batches for train, test, valid tensors")

# save the batches using torch.save (lists of tensors, kept on the CPU)
tensors_dict = {
    'train': [tensor.to('cpu') for tensor in train_batches],
    'valid': [tensor.to('cpu') for tensor in valid_batches],
    'test': [tensor.to('cpu') for tensor in test_batches]
}

torch.save(
    tensors_dict,
    os.path.join('data', 'treebank_batches_tensors.pt')
)
logger.info("Saved processed tensors to file")


# Initialize the model and training components
vocab_size = len(word_to_idx)
embed_size = 300
hidden_size = 512
num_layers = 1
dropout = 0.5
lr = 0.001

model = Network(vocab_size, embed_size, hidden_size, num_layers, cell='lstm', dropout=dropout).to(device)
# padding positions do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


# Training loop
num_epochs = 25
train_ppls = []
valid_ppls = []
train_losses = []
valid_losses = []
best_valid_ppl = float('inf')
best_model_path = os.path.join(log_dir, 'best_model.pt')

for e in range(num_epochs):
    train_loss = train(model, train_batches, criterion, optimizer, device)
    valid_loss = evaluate(model, valid_batches, criterion, device)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    # .item(): keep plain floats for comparison, logging and plotting
    train_ppl = calculate_perplexity(train_loss).item()
    valid_ppl = calculate_perplexity(valid_loss).item()

    train_ppls.append(train_ppl)
    valid_ppls.append(valid_ppl)

    logger.info(f'Epoch {e+1}/{num_epochs}:')
    logger.info(f'Train loss: {train_loss:.2f}, Perplexity: {train_ppl:.2f}')
    logger.info(f'Valid loss: {valid_loss:.2f}, Perplexity: {valid_ppl:.2f}')

    # keep only the checkpoint with the lowest validation perplexity
    if valid_ppl < best_valid_ppl:
        best_valid_ppl = valid_ppl
        torch.save(model.state_dict(), best_model_path)
        logger.info(f'New best validation perplexity: {valid_ppl:.2f}')

visualize(train_ppls, valid_ppls, f'Perplexity for {trial_name}', 'epochs', 'ppl', os.path.join(log_dir, 'perplexity.png'))
visualize(train_losses, valid_losses, f'Loss for {trial_name}', 'epochs', 'loss', os.path.join(log_dir, 'losses.png'))

# Load the best model for final evaluation
# map_location keeps loading independent of the device the checkpoint was saved from
model.load_state_dict(torch.load(best_model_path, map_location=device))

test_loss = evaluate(model, test_batches, criterion, device)
test_ppl = calculate_perplexity(test_loss).item()
logger.info(f"Test Perplexity with best model: {test_ppl:.2f}")



[32m2025-05-01 14:23:59.439[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m16[0m - [1mUsing device: mps[0m
[32m2025-05-01 14:24:00.060[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m20[0m - [1mLoaded data[0m
[32m2025-05-01 14:24:00.067[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mBuilt vocabulary of size 4899[0m
[32m2025-05-01 14:24:00.094[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mprocessed all data[0m
[32m2025-05-01 14:24:00.102[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m45[0m - [1mSaved vocabulary mappings to JSON file[0m
[32m2025-05-01 14:24:00.115[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m53[0m - [1mCreated batches for train, test, valid tensors[0m
[32m2025-05-01 14:24:00.123[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m66[0m - [1mSaved processed tensors to file[0m
Consider using tens

In [14]:
# Greedy-decode a sample from the seed word "the" (at most 20 generated tokens)
# with the model currently held in `model`; the returned string is the cell output.
generate_text(model, word_to_idx, idx_to_word, device, seed_text="the", max_length=20)

'the company said 0 it has redeemed its rights .'

In [15]:

# Total number of scalar entries across all parameter tensors of the model.
n_params = sum(param.numel() for param in model.parameters())
print(f"Number of parameters: {n_params}")


Number of parameters: 5649959


In [None]:
# compute fisher influence function
# Produces one score per training sequence. The function runs a separate
# forward/backward pass for every row of every training batch, so this cell
# is slow compared with the rest of the notebook.
fisher_influence_values = get_fisher_influence_function(model, train_batches, valid_batches, criterion, device)

In [27]:
# Persist the influence scores in the trial's log directory (safetensors
# format); the tensor is stored under the key 'fisher_influence_values'.
influence_values = {'fisher_influence_values': fisher_influence_values}
save_file(influence_values, os.path.join(log_dir, 'fisher_influence_values.safetensors'))