# EECS 595 HW3: Part 5: Fine-Tuning BERT for Classification

In Parts 1 to 3, you built and pre-trained a tiny BERT-like model for the language of product reviews. Pre-training is exciting and necessary, but the real power comes from adapting a pre-trained model's parameters to a particular task. Now that you have a copy of that model saved, let's adapt it for the same classification task as in Homework 2: sentiment analysis.

To start, like in Part 4, you'll need to import the relevant code from BERT into this notebook. You can either do this with a copy/paste if your code is finalized, or if you have your code as a script somewhere (e.g., for submitting to Great Lakes for training), you can import the classes and methods with `import` statements from that file, e.g., something like `from myfile import BERT`. The latter approach is probably better since it will make this notebook much less cluttered and reduce any chances for typos/missing pieces when copying.

*Part 5* will have you test out the "classification" mode of the `BERT` model you implemented. Here, you'll once again need to write a `Dataset` and `collator` function to read in the sentiment data. These are _very_ common steps you'll need as a practitioner (since you'll mostly be starting from pre-trained models), so we wanted to give you practice. The code will look similar to what you had in Part 3 too.

**Important Note:** You can run this notebook on your CPU to test your initial pre-trained model for classification. However, all of the tasks should be done using the fully trained version you get from Great Lakes.


In [None]:
import os
import math
import numpy as np
import random
import logging

# Bring in PyTorch
import torch
import torch.nn as nn

# Most of the examples have typing on the signatures for readability
from typing import Optional, Callable, List, Tuple
from copy import deepcopy
# For data loading
from torch.utils.data import Dataset, IterableDataset, TensorDataset, DataLoader
import json
import glob
import gzip
import bz2
#import wandb

from sklearn.metrics import precision_recall_fscore_support

import pandas as pd

import matplotlib.pyplot as plt

# For progress and timing
from tqdm.auto import tqdm, trange
import time

# check if gpu is available
device = 'cpu' 
if torch.backends.mps.is_available():
    device = 'mps'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using '{device}' device")

## Part 5.0: Be able to run BERT

You'll need to have the `BERT` class in this notebook to run BERT, so import (or copy) all the necessary code in the cell below. It's better to import the classes/functions from the file you used to submit to Great Lakes, since that code is known to work. This can/should be the same code you used in Part 4.0 during your BERT exploration.

In [None]:
################################################################
#                     TODO: YOUR CODE HERE                     #
#
# 1. Create a tokenizer for the BERT model
# 2. Import (or copy) all the necessary code to run the BERT model. 
#
################################################################
 


## Part 5.1 Build a `Dataset` for Sentiment Classification

To train the model, we'll create a new `Dataset` that turns our text data into sequences of token IDs. This code will look similar to what we have in the MLM `Dataset` class.
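
As a rough illustration of the map-style `Dataset` pattern described above, here is a minimal sketch. It assumes a hypothetical whitespace tokenizer (`toy_encode`), a made-up vocabulary, and in-memory rows; your version would use your BERT tokenizer and read the CSV files instead.

```python
import torch
from torch.utils.data import Dataset

# Hypothetical stand-in vocabulary and tokenizer (assumptions for illustration);
# the real notebook uses the tokenizer you created for BERT.
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 1, "great": 2, "book": 3, "bad": 4}

def toy_encode(text):
    # Map each whitespace-separated token to an ID, falling back to [UNK]
    return [TOY_VOCAB.get(tok, TOY_VOCAB["[UNK]"]) for tok in text.lower().split()]

class ToySentimentDataset(Dataset):
    def __init__(self, rows, max_len=128):
        # rows: list of (sentence, label) pairs held in memory for illustration
        self.max_len = max_len
        # Tokenize once up front and truncate to max_len tokens
        self.examples = [(toy_encode(s)[:max_len], int(y)) for s, y in rows]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ids, label = self.examples[idx]
        input_ids = torch.tensor(ids, dtype=torch.long)
        # Every real (unpadded) token gets a 1; padding is added later by the collator
        attention_mask = torch.ones_like(input_ids)
        return input_ids, attention_mask, label

toy_dataset = ToySentimentDataset([("great book", 1), ("bad bad plot", 0)], max_len=2)
```

Tokenizing once in the constructor (or in a separate preparation step) trades memory for speed, since `__getitem__` then only does a lookup.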

In [None]:

class_train_path = './sentiment.train.csv' # column names: sentence, label
class_test_path = './sentiment.dev.csv' # column names: sentence, label

class ClassificationDataset(Dataset):
    def __init__(self, tokenizer, data_filename: str, max_len=128):

        ############################################################
        #             TODO: YOUR CODE                              #
        # 
        # 1 save the arguments and load the data from the specified file
        ############################################################        
        pass

    def __len__(self):
        return len(self.data)
    
    def tokenize_and_prepare(self):
        ############################################################
        #             TODO: YOUR CODE                              #
        #
        # 1. Tokenize the data and prepare it for the model
        #
        # NOTE: for memory efficiency, you can delete the data field 
        # after tokenizing the data and just retain the tokenized version
        ############################################################

        pass

    def __getitem__(self, idx):
        ############################################################
        #                   TODO: YOUR CODE                        #
        #
        # 1. Look up the tokenized data and label for the specified index
        # 2. Create an attention mask for the data
        # 3. Return the ids, attention mask, and label
        ############################################################
        pass

## Part 5.2: Implement a collate function for sentiment data

Just like in our MLM training for BERT, we'll need another collator function that turns sequences into batches. Your collate function will look similar to what you had for MLM.
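
One possible way to pad variable-length sequences into a batch is sketched below with `torch.nn.utils.rnn.pad_sequence`. The function name `toy_collate_fn` and the padding ID of 0 are assumptions for illustration; use the padding ID that your tokenizer actually defines.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def toy_collate_fn(batch, pad_id=0):
    # pad_id=0 is an assumption; use your tokenizer's actual padding ID
    input_ids, attention_masks, labels = zip(*batch)
    # Pad every sequence on the right to the length of the longest one in the batch
    padded_ids = pad_sequence(list(input_ids), batch_first=True, padding_value=pad_id)
    # Padded positions get a mask value of 0 so the model can ignore them
    padded_mask = pad_sequence(list(attention_masks), batch_first=True, padding_value=0)
    return padded_ids, padded_mask, torch.tensor(labels, dtype=torch.long)

# Two toy examples of different lengths
toy_batch = [
    (torch.tensor([5, 6, 7]), torch.ones(3, dtype=torch.long), 1),
    (torch.tensor([8]), torch.ones(1, dtype=torch.long), 0),
]
ids, mask, labels = toy_collate_fn(toy_batch)
print(ids.shape, mask.shape, labels.shape)
```

Padding to the longest sequence in each batch, rather than to a fixed `max_len`, usually keeps batches smaller and training faster.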

In [None]:
def classification_collate_fn(batch):
    '''
    Collate function for the classification dataset.

    Args:
    - batch: list of tuples of the form (input_ids, attention_mask, label)
    '''
    ############################################################
    #                     TODO: YOUR CODE                      #
    # 1. Pad the input_ids and attention_mask
    # 2. Return the input_ids, attention_mask, and labels as a tuple
    #
    ############################################################
    pass

In [None]:
# test the collate function
batch_size = 8
classification_dataset = ClassificationDataset(tokenizer, class_train_path)
classification_dataset.tokenize_and_prepare()
classification_dataloader = DataLoader(classification_dataset, batch_size=batch_size, shuffle=False, collate_fn=classification_collate_fn)


for input_ids, attention_mask, labels in classification_dataloader:
    print(input_ids.shape)
    print(attention_mask.shape)
    print(labels.shape)
    break

## Part 5.3 Load your pre-trained model

Load the parameters in the saved `state_dict` into a new instance of a `BERT` model. You did something similar in Homework 2 when you loaded the saved embeddings for your attention classifier. See PyTorch's [documentation](https://pytorch.org/tutorials/beginner/saving_loading_models.html) for some guidance here. Depending on how and when you saved your model, you _might_ run into some complaints about missing parameters, so you might need to change the `strict` argument when loading.

This code will be similar to the code in Part 4.1 of the BERT exploration notebook _except_ that you need to make sure the `BERT` model's `mode` is set to "classification".
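
To show what the `strict` argument does, here is a small sketch with two hypothetical toy modules (`ToyEncoder` and `ToyClassifier`, which are not part of the assignment code). The checkpoint lacks the classification head, so a strict load would raise an error, while `strict=False` loads what matches and reports the rest.

```python
import io
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    # Hypothetical stand-in for the pre-trained model that was saved
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)

class ToyClassifier(nn.Module):
    # Same encoder parameters plus a classification head the checkpoint lacks
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.classifier = nn.Linear(4, 2)

# Save the "pre-trained" parameters to an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.save(ToyEncoder().state_dict(), buffer)
buffer.seek(0)

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
state_dict = torch.load(buffer, map_location="cpu")
model = ToyClassifier()
result = model.load_state_dict(state_dict, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

Checking the reported missing and unexpected keys is worthwhile: with `strict=False`, a typo in a parameter name would otherwise go unnoticed and leave that layer randomly initialized.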

In [None]:
################################################################
#                     TODO: YOUR CODE HERE                     #
#
# 1. Create a BERT model for classification
# 2. Load the pre-trained BERT model parameters from your saved file
# 3. Move the model to the appropriate device if needed (e.g. GPU, MPS)
#
################################################################




## Part 5.4 Load in the development data

During training, we'll want to evaluate our model: how well is it doing over time? We can use the performance metrics to do model selection and choose our "best" model for evaluation or use in production.

The first part is relatively straightforward: load in the development data like you loaded in the training data. We'll use it later when training for some evaluation.

The second part will be to write our `evaluate_model` function that will take in a model and a `DataLoader` (e.g., for the development data) and return the scores. We'll call this function periodically during training to get an updated estimate of how well the model is learning and performing.
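
To make concrete what the scores mean, here is a hand-rolled sketch of precision, recall, and F1 for a binary task. The function name `binary_prf` and the toy labels are made up for illustration, and treating label 1 as the positive class is an assumption. In the notebook itself you would more likely call `precision_recall_fscore_support` from `sklearn.metrics`, which should agree with this when used with `average='binary'`.

```python
def binary_prf(y_true, y_pred, positive=1):
    # Count true positives, false positives, and false negatives
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    # Guard against division by zero when a class is never predicted or never present
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Toy gold labels and predictions: 2 true positives, 1 false positive, 1 false negative
toy_true = [1, 0, 1, 1, 0, 0]
toy_pred = [1, 0, 0, 1, 1, 0]
precision, recall, f1 = binary_prf(toy_true, toy_pred)
print(precision, recall, f1)
```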

In [None]:
################################################################
#             TODO: YOUR CODE                                  #
# 1. Load a classification dataset for the dev data
# 2. Create a dataloader for the dev data
#
################################################################


def evaluate_model(model, dataloader, device):
    '''
    Returns the precision, recall, and f1-score of the model on the data in the specified dataloader.
    '''

    ################################################################
    #             TODO: YOUR CODE                                  #
    #
    # 1. Set the model to evaluation mode
    # 2. Iterate through the dataloader and make predictions
    # 3. Calculate the precision, recall, and f1-score
    #
    # HINT: You can use sklearn.metrics to calculate the metrics
    ################################################################

    pass



## Part 5.5 Fine-Tune a BERT Classifier

Let's put those pre-trained parameters to work! Using your BERT model for classification, this part will have you write a training loop. The training loop will look similar to those you have done before at this point.

For this part of the assignment, you can train it entirely on your laptop. The dataset should fit and be relatively fast to train. For some reference, on a MacBook M1, training with a batch size of 8 takes around 7 minutes for one epoch using 'mps' and 25 minutes with 'cpu'. The model and data both fit in 8GB.

When training, we'll periodically call `evaluate_model` to score the current parameters. It's common to choose one metric to define what is "best" and then save those parameters so that at the end of training, you can load back in whichever model did best. We'll do the same here using the F1 score. Be sure to save this classifier model to a different filename (different from the MLM pre-trained model's one!) so you can load it back in.

We'll add a bit more fancy Weights & Biases instrumentation for training our model here. Specifically, in addition to reporting loss, we'll also periodically score the model on the evaluation set and report those numbers to `wandb` every 1000 steps. When training large models, it's helpful to not only get a sense of the loss but also the model's performance on the task you actually want it to do. Often loss and task-specific performance are closely correlated, but not always! The plots on `wandb` can help show us when we can stop training by looking at when the task-specific performance converges; the model's loss may continue to go down, but this could mean that the model is overfitting. See the wandb [documentation](https://docs.wandb.ai/guides/track/launch) for examples on how to log metrics.
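
The "evaluate every so often and keep the best checkpoint" pattern can be sketched as follows. Everything here is a stand-in: the `nn.Linear` model, the fixed list of fake F1 scores in place of `evaluate_model`, the tiny `reporting_interval`, and the filename are all assumptions for illustration, and a plain list takes the place of `wandb` logging.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)   # hypothetical stand-in for the BERT classifier
reporting_interval = 2    # the text above suggests 1000; kept small for illustration
fake_f1_scores = iter([0.60, 0.72, 0.70])  # stand-in for evaluate_model results

best_f1 = -1.0
# Hypothetical filename; keep it distinct from the MLM pre-trained checkpoint
best_path = os.path.join(tempfile.mkdtemp(), "best_classifier.pt")
history = []

for step in range(1, 7):
    # ... forward pass, loss.backward(), and optimizer.step() would go here ...
    if step % reporting_interval == 0:
        f1 = next(fake_f1_scores)
        # With wandb, this is where you could call wandb.log({"dev_f1": f1}, step=step)
        history.append({"step": step, "dev_f1": f1})
        if f1 > best_f1:
            # Only overwrite the checkpoint when the dev F1 improves
            best_f1 = f1
            torch.save(model.state_dict(), best_path)

print(best_f1, history)
```

If your `evaluate_model` puts the model in evaluation mode with `model.eval()`, remember to call `model.train()` again before resuming training so that dropout is re-enabled.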

In [None]:
################################################################
#             TODO: YOUR CODE                                  #
#
# 1. Define the hyperparameters, loss function, optimizer, etc. for training
# 2. Initialize wandb 
#
################################################################


# Keep track of the losses for quick plotting after
losses = []

# Train the model
for epoch in range(num_epochs):
    for input_ids, attention_mask, labels in tqdm(classification_dataloader):

        ################################################################
        #             TODO: YOUR CODE                                  #
        #
        # 1. Predict the instances in the batch
        # 2. Compute the loss and update the weights
        # 3. Every `reporting_interval` batches, score the model on the 
        #    dev data using the `evaluate_model` function and report those scores to wandb
        # 4. Keep track of the model parameters with the best f1 score and save it to disk
        #
        ################################################################


plt.plot(losses)

## Part 5.6 Evaluate on the test data

Once you've finished training, let's see how well our best model does on the test data. Do the following steps:

1. Load in the best model's parameters from your saved file so that we can use it for classification on the test data
2. Create a `Dataset` instance and `DataLoader` for the test data
3. Generate the prediction file for the test data (this code will look similar to `evaluate_model`)
4. Upload the predictions to Kaggle.

The test data format to upload to Kaggle is the same as in Homework 2.


In [None]:
################################################################
#             TODO: YOUR CODE                                  #
################################################################

## Optional Visualization

If you're curious, try putting in some example text and visualizing what the classification model's different heads are looking at. You can also try contrasting the attention focus with that of the original pre-trained BERT (you'll need to load this in separately).

In [None]:
################################################################
# The visualization code below is provided for you. You can use it to visualize the attention weights of the BERT model.
# It takes lists of layer/head indices so that you can compare the attention weights across different layers/heads

def attention_visualizer(sentence, model, tokenizer, layers=[-1], heads=[-1]):
    model.eval()
    with torch.no_grad():
        # Encode the sentence
        tokenized_input = tokenizer.encode(sentence)
        ids = tokenized_input.ids
        
        # Forward pass to get attention weights
        _, attns = model(torch.tensor([ids]).to(device))
        print(attns.shape)
        
        # Determine the number of layers and heads to compare
        num_layers = len(layers)
        num_heads = len(heads)
        
        # Set up the figure for subplots
        fig, axs = plt.subplots(num_layers, num_heads, figsize=(num_heads*5.5, num_layers*5.5))
        
        # Handle the case for single subplot to maintain consistency
        if num_layers == 1 and num_heads == 1:
            axs = np.array([[axs]])
        
        # Convert axs to an array for easy indexing if it's not already
        if not isinstance(axs, np.ndarray):
            axs = np.array(axs)
        
        # Ensure axs is 2D
        if axs.ndim == 1:
            axs = np.expand_dims(axs, axis=0 if num_layers == 1 else 1)
        
        for i, layer in enumerate(layers):
            for j, head in enumerate(heads):
                # attns has shape [B, L, H, T_q, T_k]; attns[0] selects the only sentence in the batch
                attn = attns[0][layer, head, :, :]
                
                # Extract the attention weights for visualization
                attn_weights = attn.squeeze(0).cpu().numpy()
                
                # Get the tokens for labels
                tokens = tokenized_input.tokens
                
                # Plot the attention heatmap
                cax = axs[i, j].matshow(attn_weights, cmap='viridis')
                fig.colorbar(cax, ax=axs[i, j], fraction=0.046, pad=0.04)
                
                # Set the tick labels
                axs[i, j].set_xticks(np.arange(len(tokens)))
                axs[i, j].set_yticks(np.arange(len(tokens)))
                axs[i, j].set_xticklabels(tokens, rotation=90, fontsize=12)
                axs[i, j].set_yticklabels(tokens, fontsize=12)
                axs[i, j].set_xlabel('Key Sequences', fontsize=14)
                axs[i, j].set_ylabel('Query Sequences', fontsize=14)
                axs[i, j].set_title(f'Layer: {layer}, Head: {head}', fontsize=16)
        
        plt.tight_layout()
        plt.show()

# Example usage
sentence = "I liked the book and characters because I could relate to it."
# Assuming 'bert' and 'tokenizer' are defined and initialized
# You should replace 'layers' and 'heads' with the specific indices you want to visualize
attention_visualizer(sentence, bert, tokenizer, layers=[0, 1], heads=[0, 1])
