Set the path to the attention-head-redundancy (AHR) matrix and the model to use below

In [None]:
# Pickled attention-head-redundancy (AHR) similarity scores for BERT on BoolQ
file_path = './outputs/attention_head_redundancy/boolq_cosine_sim_BERT_1000.pkl'
# Hugging Face Hub id of the model whose attention heads are pruned
model_link = "rycecorn/bert-fine-tuned-boolq"

# Import libraries and systems check

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
import pickle
import time
from tqdm import tqdm

In [None]:
import transformers

# Log the installed transformers version so results can be reproduced later
print(f"{transformers.__version__}")

In [None]:
# Report whether PyTorch can see a CUDA device (evaluation below falls back to CPU otherwise)
print(f"{torch.cuda.is_available()}")

# Load redundancy scores from pickle and construct the redundancy DataFrame

In [None]:
import itertools
import math

# NOTE(review): pickle.load can execute arbitrary code -- only load redundancy files you generated.
with open(file_path, 'rb') as file:
    boolq_cosine_sim_ft = pickle.load(file)

# Average each head pair's similarity across the stored entries. zip(*) transposes the outer
# sequence so every tuple holds one pair's values -- presumably one value per evaluated
# sample (TODO confirm against the script that wrote the pickle).
average_redundancies = [sum(values) / len(values) for values in zip(*boolq_cosine_sim_ft)]

# Each unordered head pair (a < b) appears once, so the pair count is n * (n - 1) / 2.
# Derive n from it instead of hard-coding 144, and fail loudly on a mismatched file rather
# than raising an opaque IndexError or silently ignoring trailing pairs.
num_pairs = len(average_redundancies)
num_heads = (1 + math.isqrt(1 + 8 * num_pairs)) // 2
if num_heads * (num_heads - 1) // 2 != num_pairs:
    raise ValueError(
        f"{num_pairs} pair similarities do not correspond to any whole number of heads"
    )

# Structure for storing (head1, head2, similarity); head ids are 1-indexed. The order from
# itertools.combinations is the row-major a < b ordering this notebook assumes for the pickle.
test_store = [
    [a + 1, b + 1, similarity]
    for (a, b), similarity in zip(itertools.combinations(range(num_heads), 2), average_redundancies)
]

# Dataframe that stores every head vs every other head for analysis.
# Each unordered pair is stored once, so append the mirrored (b, a) rows; grouping by
# head_a then covers each head's similarity to all other heads.
original_corr_df = pd.DataFrame(test_store, columns=['head_a', 'head_b', 'similarity'])
double_corr_df = pd.DataFrame(test_store, columns=['head_a', 'head_b', 'similarity'])
reverse_corr_df = double_corr_df[['head_b', 'head_a', 'similarity']].copy()
reverse_corr_df = reverse_corr_df.rename(columns={'head_b': 'head_a', 'head_a': 'head_b'})

all_pairs_corr = pd.concat([original_corr_df, reverse_corr_df], axis=0)

## Order heads by their average redundancy score against all other heads

In [None]:
# Mean similarity of each head against every other head, most redundant head first
sorted_avg_corr = (
    all_pairs_corr
    .groupby('head_a')
    .mean()
    .sort_values(by='similarity', ascending=False)
)

# Helper functions

## Dictionary mapping each head to its layer index and its head index within that layer

In [None]:
# Create dictionary with any number of heads and layers
def get_layers_heads(heads_each_layer, layers_total):
    """Map global 1-indexed head ids to 1-indexed ``(layer, head_in_layer)`` pairs.

    Heads are numbered layer by layer: ids ``1..heads_each_layer`` belong to layer 1,
    the next ``heads_each_layer`` ids to layer 2, and so on.

    Args:
        heads_each_layer: Number of attention heads in every layer.
        layers_total: Number of layers.

    Returns:
        Dict ``{global_head_id: (layer, head_in_layer)}`` with
        ``heads_each_layer * layers_total`` entries.
    """
    values_layers_heads = {}
    for head in range(1, heads_each_layer * layers_total + 1):
        # Split by heads per layer, not by the layer count; the two only coincide when
        # they are equal (e.g. 12 x 12 in BERT-base).
        layer_offset, head_offset = divmod(head - 1, heads_each_layer)
        values_layers_heads[head] = (layer_offset + 1, head_offset + 1)
    return values_layers_heads

## Get exact heads to prune at each interval

In [None]:
def get_heads_to_prune(redundant_heads, prune_ratios=None):
    """Return cumulative prefixes of *redundant_heads*, one per pruning ratio.

    Args:
        redundant_heads: Head ids ordered from most to least redundant.
        prune_ratios: Fractions of heads to prune. Defaults to 0.05, 0.10, ..., 0.95
            (``np.arange(0.05, 0.96, 0.05)``).

    Returns:
        One list per ratio holding the first ``round(len(redundant_heads) * ratio)``
        heads, so each entry extends the previous one when the ratios increase.
    """
    if prune_ratios is None:
        prune_ratios = np.arange(0.05, 0.96, 0.05)
    number_of_heads = len(redundant_heads)
    # int() guards against round() returning a NumPy float, which is not a valid slice bound.
    return [redundant_heads[:int(round(number_of_heads * ratio))] for ratio in prune_ratios]

## Batch & preprocess the data for DataLoader

In [None]:
def collate_fn(batch):
    """Stack a list of tokenized examples into a dict of batched tensors for the DataLoader."""
    def stacked(field):
        # One tensor per example, joined along a new leading batch dimension
        return torch.stack([torch.tensor(example[field]) for example in batch])

    collated = {field: stacked(field) for field in ('input_ids', 'attention_mask', 'token_type_ids')}
    collated['labels'] = torch.tensor([example['labels'] for example in batch])
    return collated

def preprocess_function(batch):
    """Tokenize a batch of question/passage pairs and attach integer labels.

    Uses the module-level ``tokenizer``; every sequence is truncated/padded to 512
    tokens. ``labels`` is 1 where ``answer`` is truthy and 0 otherwise.
    """
    encoded = tokenizer(
        batch["question"],
        batch["passage"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    encoded['labels'] = [int(bool(answer)) for answer in batch['answer']]
    return encoded


# Load validation set

## Load model link and tokenizer

In [None]:
# Tokenizer for the model selected in the first cell (``model_link``). Do not re-assign the
# link here: a hard-coded value would silently override that choice and could pair the
# AHR matrix with the wrong model.
tokenizer = AutoTokenizer.from_pretrained(model_link)

In [None]:
%%time
# Dataset prep: download BoolQ, tokenize it in batches, and wrap the validation split
from datasets import load_dataset

dataset = load_dataset("boolq")
tokenized_datasets = dataset.map(preprocess_function, batched=True)

validation_set = tokenized_datasets['validation']
validation_loader = DataLoader(
    validation_set,
    batch_size=16,
    collate_fn=collate_fn,
)

# Create layer and head index dictionary

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(model_link)
heads_each_layer = model.config.num_attention_heads
layers = len(model.bert.encoder.layer)

# Global head id -> (layer, head-within-layer), both 1-indexed
values_layers_heads = get_layers_heads(heads_each_layer, layers)

print(f"The model has {layers} layers with {heads_each_layer} heads in each layer")

# Set pruning conditions

In [None]:
# Heads ordered from most to least redundant, and the cumulative heads to prune at each ratio
heads_in_order = sorted_avg_corr.index.tolist()
pruning_interval_heads = get_heads_to_prune(heads_in_order)

# Pruning ratios; these must line up one-to-one with the default ratios of get_heads_to_prune
prune_percentage = np.arange(0.05, 0.96, 0.05)
if len(prune_percentage) != len(pruning_interval_heads):
    raise ValueError(
        f"{len(prune_percentage)} pruning ratios but {len(pruning_interval_heads)} head intervals"
    )

# Creates dictionary of the interval and heads to be pruned for that interval.
# Has format {prune_ratio (rounded to 2 dp): [heads_to_be_pruned]}
prune_dict = {
    round(percentage, 2): heads
    for percentage, heads in zip(prune_percentage, pruning_interval_heads)
}

print(prune_dict)

## Heads newly pruned at each interval (for plotting)

In [None]:
def extract_new_items(lists):
    """Return, for each list, only the items that were not in the previous list.

    The first entry is a copy of ``lists[0]`` because nothing precedes it. Used to find
    which heads are newly pruned at each pruning interval.

    Args:
        lists: Sequence of lists of hashable items, e.g. cumulative heads pruned at
            increasing ratios.

    Returns:
        A list with the same length as *lists*; empty when *lists* is empty.
    """
    if not lists:
        return []
    result = [list(lists[0])]
    for previous_list, current_list in zip(lists, lists[1:]):
        seen = set(previous_list)  # O(1) membership instead of rescanning the list
        result.append([item for item in current_list if item not in seen])
    return result

In [None]:
new_heads_each_interval = extract_new_items(pruning_interval_heads)
# Pruning ratios rounded to 2 decimal places for clean keys
prune_percentage = np.around(np.arange(0.05, 0.96, 0.05), 2)
dict_new_heads_each_interval = {
    ratio: new_heads_each_interval[idx] for idx, ratio in enumerate(prune_percentage)
}

dict_new_heads_each_interval

In [None]:
# Persist {prune_ratio: heads newly pruned at that ratio} to a pickle file for plotting
with open('BERT_BoolQ_heads_pruned.pkl', 'wb') as handle:
    pickle.dump(dict_new_heads_each_interval, handle)

# Main process

## Prune and evaluate

In [None]:
# Model-specific sizes: each attention head owns head_size consecutive output features
# of the query/key/value projections.
num_attention_heads = model.config.num_attention_heads
head_size = model.config.hidden_size // model.config.num_attention_heads

def prune_heads(model, heads_to_prune, head_size, layer_head_map=None):
    """Zero out the query/key/value projections of the given attention heads in place.

    A head's query, key and value are the output features
    ``[head_index * head_size, (head_index + 1) * head_size)`` of each projection.
    PyTorch stores ``nn.Linear.weight`` as ``(out_features, in_features)``, so those
    features are weight *rows*; zeroing columns instead would wipe input hidden
    dimensions for every head in the layer.

    Args:
        model: BERT-style model exposing ``model.bert.encoder.layer[i].attention.self``
            with ``query``, ``key`` and ``value`` linear layers.
        heads_to_prune: Global 1-indexed head ids.
        head_size: Output features per head (hidden_size // num_attention_heads).
        layer_head_map: Mapping ``{head_id: (layer, head_in_layer)}``, both 1-indexed.
            Defaults to the module-level ``values_layers_heads``.
    """
    if layer_head_map is None:
        layer_head_map = values_layers_heads
    for head_to_prune in heads_to_prune:
        layer, head_in_layer = layer_head_map[head_to_prune]
        layer_index = layer - 1
        head_index = head_in_layer - 1
        start, stop = head_index * head_size, (head_index + 1) * head_size

        self_attention = model.bert.encoder.layer[layer_index].attention.self
        for matrix in ('query', 'key', 'value'):
            projection = getattr(self_attention, matrix)
            # .data edits the parameters in place without recording autograd history
            projection.weight.data[start:stop, :] = 0
            projection.bias.data[start:stop] = 0

def evaluate_model(model, validation_loader, device):
    """Evaluate *model* on *validation_loader* and time its forward passes.

    Accuracy and loss are averaged over examples, not batches, so a smaller final
    batch is not over-weighted.

    Args:
        model: Classification model called as ``model(input_ids, token_type_ids=...,
            attention_mask=..., labels=...)`` and returning an object with ``loss``
            and ``logits``.
        validation_loader: Iterable of batches shaped like ``collate_fn`` output (dicts
            with ``input_ids``, ``attention_mask``, ``token_type_ids``, ``labels`` tensors).
        device: Torch device the model and each batch are moved to.

    Returns:
        ``(accuracy, loss, comp_time)``: example-weighted accuracy and loss (NaN when the
        loader yields no examples), and the summed wall-clock seconds from the start of
        each forward pass through argmax.
    """
    model.to(device)
    model.eval()

    total_correct = 0
    total_loss = 0.0
    total_examples = 0
    final_comp_time = 0.0

    for batch in validation_loader:
        b_input_ids = batch['input_ids'].to(device)
        b_attention_mask = batch['attention_mask'].to(device)
        b_token_type_ids = batch['token_type_ids'].to(device)
        b_labels = batch['labels'].to(device)
        batch_size = b_labels.size(0)

        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = model(b_input_ids, token_type_ids=b_token_type_ids,
                            attention_mask=b_attention_mask, labels=b_labels)
        # Assumes outputs.loss is a mean over the batch (CrossEntropyLoss default) -- TODO
        # confirm for other models; re-weight by batch size for an example-level mean.
        total_loss += outputs.loss.item() * batch_size
        predictions = torch.argmax(outputs.logits, dim=-1)
        final_comp_time += time.perf_counter() - start_time

        total_correct += (predictions == b_labels).sum().item()
        total_examples += batch_size

    if total_examples == 0:
        return float('nan'), float('nan'), final_comp_time
    return total_correct / total_examples, total_loss / total_examples, final_comp_time

# Main loop: for every pruning ratio, restore the unpruned weights, prune, then evaluate
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pruning_results = []

# A freshly loaded copy supplies the unpruned weights, so re-running this cell never
# starts from an already-pruned `model`.
initial_model = AutoModelForSequenceClassification.from_pretrained(model_link)
initial_state_dict = initial_model.state_dict()

for prune_ratio in tqdm(prune_percentage, desc='Pruning ratios'):
    model.load_state_dict(initial_state_dict)  # undo the previous iteration's pruning

    heads_prune_this_iteration = prune_dict[round(prune_ratio, 2)]
    prune_heads(model, heads_prune_this_iteration, head_size)

    avg_val_accuracy, avg_val_loss, final_comp_time = evaluate_model(
        model, validation_loader, device
    )
    pruning_results += [
        (round(prune_ratio, 2), len(heads_prune_this_iteration), avg_val_accuracy, final_comp_time)
    ]


### Baseline performance

In [None]:
# Baseline: evaluate the model with every attention head intact.
# Reload pristine weights so any pruning applied earlier in the session is discarded.
initial_model = AutoModelForSequenceClassification.from_pretrained(model_link)
initial_state_dict = initial_model.state_dict()
model.load_state_dict(initial_state_dict)

base_accuracy, base_loss, base_comp_time = evaluate_model(model, validation_loader, device)

print(f"Base accuracy: {base_accuracy}")
print(f"Base inference time: {base_comp_time}")

## Visualize and save performance plot

In [None]:
import os

# Save pruning & prune evaluation as dataframe (one row per pruning ratio)
pruning_df = pd.DataFrame(pruning_results, columns=['Pruning ratio', 'Num heads pruned', 'Average validation accuracy', 'Comp time'])
# to_csv does not create missing parent directories, so make sure the output folder exists.
os.makedirs('./outputs/pruning', exist_ok=True)
pruning_df.to_csv('./outputs/pruning/boolq_prune_df.csv', index=False)

In [None]:
x = pruning_df['Pruning ratio']

plt.plot(x, pruning_df['Average validation accuracy'], label='Pruning')
plt.title('BoolQ / BERT-base')
plt.xlabel('Pruning ratio', fontsize=10)
plt.ylabel('Accuracy')
plt.xticks(np.arange(start=0.05, stop=1, step=0.05), fontsize=8)
# Unpruned reference line, taken from the baseline evaluation (base_accuracy) rather than
# a hard-coded number, so it always matches the model and metric being plotted.
plt.axhline(base_accuracy, color='r', linestyle='--', label='No pruning')
plt.grid(True)
plt.ylim(0.35, 0.8)  # change to your model performance baselines
plt.legend()
plt.savefig('./outputs/pruning/prune_boolq_BERT-FT.png')
plt.show()