# BoolQ - BERT - Main file

In [None]:
from datasets import load_dataset
import torch
import numpy as np
import pickle
import json
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import io
import PIL
from PIL import Image, ImageEnhance
import pickle
import gc

In [None]:
# Preprocessing pipeline applied to every image before it is passed to the CNN
from torchvision import transforms

input_size = 224  # side length, in pixels, of the square image the CNN receives

transform = transforms.Compose(
    [
        # Bring every image to the same input_size x input_size resolution
        transforms.Resize((input_size, input_size)),
        # Convert the PIL image to a tensor
        transforms.ToTensor(),
        # Per-channel normalisation with fixed mean / standard deviation
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
class SimpleCNN(nn.Module):
    """Five-stage convolutional classifier for square 3-channel images.

    Every stage is ``Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2, 2)``;
    the channel width grows 3 -> 16 -> 32 -> 64 -> 128 -> 256. The pooled
    feature map is flattened and mapped to ``num_classes`` logits by a single
    linear layer (``fc1``).

    The sub-modules are registered as ``conv1..conv5``, ``act1..act5`` and
    ``pool1..pool5`` so that checkpoints saved with those attribute names
    keep loading.

    Args:
        num_classes: number of output logits.
        input_size: height and width of the (square) input image; it fixes
            the number of input features of ``fc1``.
    """

    # Channel count entering the first stage followed by each stage's output width
    _STAGE_CHANNELS = (3, 16, 32, 64, 128, 256)

    def __init__(self, num_classes, input_size=input_size):
        super().__init__()

        # Register conv/act/pool for stages 1..5 in order
        stage_widths = zip(self._STAGE_CHANNELS, self._STAGE_CHANNELS[1:])
        for stage, (in_channels, out_channels) in enumerate(stage_widths, start=1):
            setattr(
                self,
                f"conv{stage}",
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            setattr(self, f"act{stage}", nn.ReLU())
            setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Five 2x2 poolings each halve the spatial size: input_size // 2**5
        size_after_conv = input_size // 32
        flat_features = 256 * size_after_conv * size_after_conv
        self.fc1 = nn.Linear(flat_features, num_classes)

    def forward(self, x):
        """Run the five conv stages, flatten, and return the class logits."""
        for stage in range(1, 6):
            conv = getattr(self, f"conv{stage}")
            act = getattr(self, f"act{stage}")
            pool = getattr(self, f"pool{stage}")
            x = pool(act(conv(x)))
        # Keep the batch dimension, merge channels and spatial dims
        flat = torch.flatten(x, 1)
        return self.fc1(flat)

In [None]:
# Load the BoolQ dataset; its splits are accessed by name (e.g. dataset['train'])
dataset = load_dataset("boolq")

In [None]:
# One [question, passage] pair per training example, in dataset order
input_pairs = [
    [example['question'], example['passage']]
    for example in dataset['train']
]

In [None]:
# Model link and tokenizer
model_link = "rycecorn/bert-fine-tuned-boolq"
tokenizer = AutoTokenizer.from_pretrained(model_link)
# output_attentions=True is required: the attention tensors are read from
# outputs.attentions when extracting a head's attention matrix
model = AutoModelForSequenceClassification.from_pretrained(model_link, output_attentions=True)

In [None]:
def get_self_attention_matrix(input_sentence, attention_head):
    """Return one head's self-attention matrix for a single BoolQ example.

    The question/passage pair is tokenized (truncated to 512 tokens), run
    through the module-level ``model`` without gradients, and the attention
    weights of the requested head are extracted.

    Args:
        input_sentence: index into the module-level ``input_pairs`` list.
        attention_head: flat head index counted across layers, i.e.
            ``layer * heads_per_layer + head`` (0-based). Negative values
            count from the end, as with normal sequence indexing.

    Returns:
        A 2-D NumPy array (sequence length x sequence length) with the
        attention weights of the selected head. The array owns its memory.

    Raises:
        IndexError: if ``input_sentence`` or ``attention_head`` is out of range.
    """
    question, passage = input_pairs[input_sentence]
    tokenized_input = tokenizer(question, passage, truncation=True, padding=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**tokenized_input)

    # outputs.attentions holds one (batch, heads, seq, seq) tensor per layer.
    # Index the wanted layer/head directly instead of stacking every layer:
    # with a single example in the batch this is the same element the
    # flattened (layers * heads, seq, seq) view would select.
    attentions = outputs.attentions
    heads_per_layer = attentions[0].size(1)
    layer_index, head_index = divmod(attention_head, heads_per_layer)

    # .copy() detaches the result from the layer tensor's storage; a plain
    # .numpy() view would keep the whole tensor alive for as long as the
    # returned array exists.
    selected_attention_head = attentions[layer_index][0, head_index].cpu().numpy().copy()

    # Release the forward-pass tensors before returning
    del tokenized_input, outputs, attentions
    torch.cuda.empty_cache()
    gc.collect()
    return selected_attention_head

In [None]:
def classify_head_pattern(attention_pattern):
    """Render an attention matrix as a heat-map image and prepare it for the CNN.

    Despite its name, this function does not run the classifier: it only
    produces the input tensor. The matrix is drawn with matplotlib, saved
    to an in-memory PNG, cropped to a fixed pixel box, contrast-enhanced,
    converted to greyscale and back to RGB, and finally passed through the
    module-level ``transform``.

    Args:
        attention_pattern: 2-D array accepted by ``plt.imshow``.

    Returns:
        The transformed image with an extra leading batch dimension of 1.
    """
    # Draw the heat map on a fresh 10 x 8 inch figure
    plt.figure(figsize=(10, 8))
    plt.imshow(attention_pattern, cmap='magma', interpolation='nearest')

    # Serialise the figure to PNG bytes in memory, then dispose of the figure
    with io.BytesIO() as png_stream:
        plt.savefig(png_stream, format='png')
        plt.clf()
        plt.close()
        png_bytes = png_stream.getvalue()

    # Re-open the PNG with PIL and prepare it for the CNN.
    # NOTE(review): the fixed crop box presumably isolates the plotted area
    # for this figure size and the active dpi -- confirm if either changes.
    rendered = PIL.Image.open(io.BytesIO(png_bytes))
    heatmap_only = rendered.crop(box= (205, 96, 820, 713))
    boosted = PIL.ImageEnhance.Contrast(heatmap_only).enhance(4.0)
    three_channel = boosted.convert("L").convert('RGB')
    img = transform(three_channel).unsqueeze(0)

    # Drop intermediates and collect garbage before returning
    del attention_pattern, png_bytes, rendered, heatmap_only, boosted, three_channel
    torch.cuda.empty_cache()
    gc.collect()

    return img

In [None]:
# Load the trained attention-pattern classifier. The checkpoint is used as a
# callable model object, i.e. it is a whole pickled module rather than a
# state_dict, so:
#   * weights_only=False is required (PyTorch >= 2.6 defaults to
#     weights_only=True, which refuses to unpickle arbitrary classes);
#   * the SimpleCNN class must already be defined when this cell runs.
# map_location='cpu' keeps the model on the CPU, where the input tensors live,
# even if the checkpoint was saved from a GPU.
# SECURITY NOTE: full unpickling can execute arbitrary code -- only load
# checkpoint files from a trusted source.
model_cnn = torch.load('ahr_cnn_75_acc.pth', map_location='cpu', weights_only=False)
model_cnn = model_cnn.eval()  # inference only

In [None]:
# Spot check: classify the attention pattern of one randomly chosen
# (sentence, head) combination
sentence_number = random.randint(0, 999)
head_number = random.randint(0, 73)
print(f'Input num: {sentence_number}')
print(f'Head num: {head_number}')

# Attention matrix -> rendered image tensor
head_check = get_self_attention_matrix(sentence_number, head_number)
test_img = classify_head_pattern(head_check)

# Forward pass without gradients; the arg-max over the logits is the class
with torch.no_grad():
    output = model_cnn(test_img)
    predicted = torch.max(output, 1)[1]
    print(f'Predicted class: {predicted.item()}') # Labels for classifier are range 0-4

In [None]:
# Head numbers whose attention patterns get classified. They are 1-based:
# the counting loop subtracts 1 before using them as a flat head index.
heads_to_prune = [34, 42, 29, 60, 59, 58, 43, 40, 44, 55, 53, 27, 51, 48, 50, 25, 45, 33, 49, 38, 52, 54, 31, 56, 39, 32, 36, 35, 37, 57, 28, 47, 63, 46, 70, 30, 16, 26, 15, 18, 64, 7, 68, 8, 71, 20, 62, 5, 24, 72, 66, 61, 12, 14, 6, 65, 67, 17, 1, 69, 23, 41, 10, 21, 9, 2, 19, 4]
# One five-slot counter per head number 1..144; slot k counts how often the
# CNN predicted class k (labels 0-4) for that head. Only heads listed in
# heads_to_prune are ever updated, every other key keeps its all-zero counter.
attention_pattern_dict = {i: [0, 0, 0, 0, 0] for i in range(1, 145)}

In [None]:
%%time
# For every listed head, classify the attention pattern of 80 randomly drawn
# training examples and tally the predicted class in attention_pattern_dict.
for head in tqdm(heads_to_prune, desc='head_count'):
    head_num_in_model = head - 1  # head numbers are 1-based, the flat index is 0-based
    for sample_idx in range(80):
        sentence_number = random.randint(0, len(input_pairs) - 1)

        # Attention matrix -> rendered image tensor -> CNN prediction
        head_attention_matrix = get_self_attention_matrix(sentence_number, head_num_in_model)
        matrix_to_img = classify_head_pattern(head_attention_matrix)
        with torch.no_grad():
            output = model_cnn(matrix_to_img)
            predicted = torch.max(output, 1)[1]  # predicted.item() to get class

        # Count the predicted pattern for this head
        attention_pattern_dict[head][predicted.item()] += 1

        # Free per-sample tensors before the next iteration
        del head_attention_matrix, matrix_to_img, output
        torch.cuda.empty_cache()
        gc.collect()

# Persist the per-head counts for the analysis cells
with open('analysis_BERT_boolq_attention_prune_count.pickle', 'wb') as f:
    pickle.dump(attention_pattern_dict, f)

# Plot patterns for each head pruned

### Get pattern count for each head

In [None]:
# Per-head pattern counts written by the counting loop
file_path = 'analysis_BERT_boolq_attention_prune_count.pickle'

# Read the pickled dictionary back into memory
with open(file_path, mode='rb') as count_file:
    head_attention_count = pickle.load(count_file)

### Get heads pruned at each interval

In [6]:
# Mapping of pruning interval -> list of head numbers pruned at that interval
with open('BERT_BoolQ_heads_pruned.pkl', mode='rb') as intervals_file:
    head_prune_intervals = pickle.load(intervals_file)

{0.05: [44, 43, 38, 41, 58, 65, 80],
 0.1: [45, 63, 51, 81, 39, 91, 96],
 0.15: [92, 49, 53, 103, 100, 88, 52, 50],
 0.2: [66, 75, 55, 74, 76, 57, 97],
 0.25: [72, 87, 73, 69, 62, 99, 56],
 0.3: [101, 37, 83, 118, 47, 67, 98],
 0.35: [61, 102, 59, 68, 86, 64, 90],
 0.4: [77, 60, 138, 144, 136, 84, 120, 133],
 0.45: [131, 122, 115, 139, 104, 40, 124],
 0.5: [126, 125, 132, 78, 143, 93, 105],
 0.55: [135, 128, 31, 79, 127, 82, 129],
 0.6: [141, 130, 106, 123, 108, 22, 121],
 0.65: [13, 95, 94, 54, 137, 109, 28, 70],
 0.7: [107, 21, 85, 134, 32, 18, 119],
 0.75: [112, 71, 27, 48, 116, 89, 114],
 0.8: [140, 113, 111, 16, 142, 8, 35],
 0.85: [33, 6, 5, 36, 10, 29, 7],
 0.9: [117, 1, 20, 24, 30, 15, 2, 26],
 0.95: [110, 12, 9, 46, 23, 3, 14]}

### Create dictionary mapping each pruning interval to the summed pattern counts of the heads pruned in that interval

In [None]:
# One five-slot pattern counter per pruning interval 0.0, 0.05, ..., 0.95
pruned_heads_attention_pattern_count = {
    interval: [0, 0, 0, 0, 0]
    for interval in np.around(np.arange(0, 0.96, 0.05), 2)
}

# Add up, slot by slot, the pattern counts of every head pruned in an interval
for interval, pruned_heads in head_prune_intervals.items():
    for head in pruned_heads:
        interval_totals = pruned_heads_attention_pattern_count[interval]
        pruned_heads_attention_pattern_count[interval] = [
            head_count + running_total
            for head_count, running_total in zip(head_attention_count[head], interval_totals)
        ]

### Get cumulative pattern at each interval

In [None]:
# Running totals of the pattern counts over the sorted keys
def calculate_cumulative(attention_patterns):
    """Return the element-wise running sum of the count lists, by ascending key.

    Args:
        attention_patterns: mapping of sortable keys to lists of five counts.

    Returns:
        A new dict with the same keys, inserted in ascending order; each
        value is a fresh list holding the sum of all count lists whose key
        is less than or equal to that key. The input is not modified.
    """
    running_totals = [0, 0, 0, 0, 0]
    cumulative_by_key = {}

    for key in sorted(attention_patterns):
        counts = attention_patterns[key]
        running_totals = [total + count for total, count in zip(running_totals, counts)]
        # Store a separate list so the entries never alias each other
        cumulative_by_key[key] = list(running_totals)

    return cumulative_by_key

# Running totals of the per-interval pattern counts, keyed by pruning interval
cumulative_patterns = calculate_cumulative(pruned_heads_attention_pattern_count)
cumulative_patterns  # bare expression: shown as the notebook cell output

### Plot for proportion of patterns at each pruning interval

In [None]:
# Alias for the per-interval (non-cumulative) pattern counts; the plotting
# cells read their input from `data`
data = pruned_heads_attention_pattern_count

In [None]:
# Stacked bar chart: share of each attention pattern among the heads pruned
# at each pruning interval

# Split the dictionary into x labels (intervals) and an
# (n_intervals, n_patterns) matrix of counts
labels = list(data.keys())
values = np.array(list(data.values()))

# Row-wise proportions. An interval without any counted pattern (such as the
# 0.0 entry, which has no pruned heads) has a zero row total; its proportions
# are left at 0 instead of evaluating 0/0, which would produce NaN and a
# RuntimeWarning.
totals = values.sum(axis=1, keepdims=True)
proportions = np.divide(
    values,
    totals,
    out=np.zeros(values.shape, dtype=float),
    where=totals != 0,
)

# Define new category labels
category_labels = ['Vertical', 'Diagonal', 'Vert + Diag', 'Block', 'Homogenous']

# Plot
fig, ax = plt.subplots(figsize=(10, 6))  # 10 x 6 inch figure

# Define bar width and positions
bar_width = 0.8
x = np.arange(len(labels))

# Plot stacked bar chart with proportions: each category is drawn on top of
# the accumulated height of the previous ones
bottom = np.zeros(len(labels))
for i in range(proportions.shape[1]):
    ax.bar(x, proportions[:, i], bar_width, bottom=bottom, label=category_labels[i])
    bottom += proportions[:, i]

# Add labels and title
ax.set_xlabel('Key')
ax.set_ylabel('Proportion')
ax.set_title('Prune proportions - BERT - BoolQ')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()

# Add space between bars: shrink every bar by 0.1 and shift it by 0.05 so it
# stays centred on its tick
for bar in ax.patches:
    bar.set_x(bar.get_x() + 0.05)
    bar.set_width(bar.get_width() - 0.1)

plt.savefig('./outputs/pruned_props_BERT_BoolQ.png')
plt.show()


### Plot for cumulative proportion of patterns at each pruning interval

In [None]:
# Stacked bar chart: share of each attention pattern among all heads pruned
# up to and including each pruning interval

# Split the dictionary into x labels (intervals) and an
# (n_intervals, n_patterns) matrix of cumulative counts
labels = list(cumulative_patterns.keys())
values = np.array(list(cumulative_patterns.values()))

# Row-wise proportions. A row whose total is zero (such as the 0.0 entry,
# before any head has been pruned) keeps proportions of 0 instead of
# evaluating 0/0, which would produce NaN and a RuntimeWarning.
totals = values.sum(axis=1, keepdims=True)
proportions = np.divide(
    values,
    totals,
    out=np.zeros(values.shape, dtype=float),
    where=totals != 0,
)

# Define new category labels
category_labels = ['Vertical', 'Diagonal', 'Vert + Diag', 'Block', 'Homogenous']

# Plot
fig, ax = plt.subplots(figsize=(10, 6))  # 10 x 6 inch figure

# Define bar width and positions
bar_width = 0.8
x = np.arange(len(labels))

# Plot stacked bar chart with proportions: each category is drawn on top of
# the accumulated height of the previous ones
bottom = np.zeros(len(labels))
for i in range(proportions.shape[1]):
    ax.bar(x, proportions[:, i], bar_width, bottom=bottom, label=category_labels[i])
    bottom += proportions[:, i]

# Add labels and title
ax.set_xlabel('Key')
ax.set_ylabel('Proportion')
ax.set_title('Prune cumulative proportions - BERT - BoolQ')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()

# Add space between bars: shrink every bar by 0.1 and shift it by 0.05 so it
# stays centred on its tick
for bar in ax.patches:
    bar.set_x(bar.get_x() + 0.05)
    bar.set_width(bar.get_width() - 0.1)

plt.savefig('./outputs/cumu_props_BERT_BoolQ.png')
plt.show()


### Plot for pruned patterns trend over time

In [None]:
# Line chart: raw count of each attention pattern pruned at every interval

# x: pruning intervals; y_values: one series per pattern category
# (the per-interval count lists transposed)
x = list(data.keys())
y_values = list(zip(*data.values()))
pattern_labels = ['Vertical', 'Diagonal', 'Vert + Diag', 'Block', 'Homogenous']

plt.figure(figsize=(10, 6))

# Attach each label to its own line so the legend cannot drift out of sync
# with the plotting order
for pattern_label, y in zip(pattern_labels, y_values):
    plt.plot(x, y, label=pattern_label)

plt.xlabel('X-axis')
plt.ylabel('Values')
plt.title('Pruned Patterns - BERT - BoolQ')
plt.xticks(x)  # Show all ticks on the x-axis
plt.legend(loc='lower left', fontsize='small')
plt.grid(True)
# Saved relative to the working directory, like the other figures of this
# notebook, instead of a machine-specific absolute path
plt.savefig('./outputs/pruned_patterns_BERT_BoolQ.png')
plt.show()