# BoolQ - BERT - Main file

In [None]:
from datasets import load_dataset
import torch
import numpy as np
import pickle
import json
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import io
import PIL
from PIL import Image, ImageEnhance
import pickle
import gc

In [None]:
# Transform for CNN
from torchvision import transforms
input_size = 224

transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),  # Resize all images to the same size
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize images
])

In [None]:
class SimpleCNN(nn.Module):
    """Small image classifier: five conv -> ReLU -> 2x2 max-pool stages and one linear head.

    Channel widths grow 3 -> 16 -> 32 -> 64 -> 128 -> 256, and every stage
    halves the spatial size (floor division). A square ``input_size`` image
    therefore leaves the last stage as a 256 x (input_size // 32) x
    (input_size // 32) feature map. ``fc1`` maps that map to ``num_classes``
    logits.

    Args:
        num_classes: number of output logits.
        input_size: side length of the square input images. The default is
            the module-level ``input_size`` captured when the class is defined.
    """

    # Channel count entering/leaving each conv stage.
    _CHANNELS = (3, 16, 32, 64, 128, 256)

    def __init__(self, num_classes, input_size=input_size):
        super().__init__()
        # Register submodules as conv<k>/act<k>/pool<k> (k = 1..5) in the same
        # order as one explicit attribute per layer would. Parameter names,
        # initialization order and saved checkpoints are therefore unchanged.
        widths = self._CHANNELS
        for k, (in_ch, out_ch) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{k}", nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))
            setattr(self, f"act{k}", nn.ReLU())
            setattr(self, f"pool{k}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Five floor-halvings of the side length equal a single // 32.
        side = input_size // 2 ** (len(widths) - 1)
        self.fc1 = nn.Linear(widths[-1] * side * side, num_classes)

    def forward(self, x):
        for k in range(1, len(self._CHANNELS)):
            stage = (
                getattr(self, f"conv{k}"),
                getattr(self, f"act{k}"),
                getattr(self, f"pool{k}"),
            )
            for layer in stage:
                x = layer(x)
        return self.fc1(torch.flatten(x, 1))

In [None]:
dataset = load_dataset(path="boolq")  # BoolQ splits (the 'train' split is used below)

In [None]:
# One [question, passage] list per training example, in split order.
input_pairs = [
    [question, passage]
    for question, passage in zip(dataset['train']['question'], dataset['train']['passage'])
]

In [None]:
# Hub checkpoint whose attention patterns are analysed. The model is loaded
# with output_attentions=True so each forward pass also returns every layer's
# attention weights.
model_link = "rycecorn/distil-bert-fine-tuned-boolq"

tokenizer = AutoTokenizer.from_pretrained(model_link)
model = AutoModelForSequenceClassification.from_pretrained(
    model_link,
    output_attentions=True,
)

In [None]:
def get_self_attention_matrix(input_sentence, attention_head):
    """Return one attention matrix of ``model`` for a BoolQ training example.

    The (question, passage) pair at ``input_pairs[input_sentence]`` is
    tokenized as a sentence pair and run through the module-level ``model``,
    which is loaded with ``output_attentions=True``. The per-layer attention
    tensors are stacked and flattened so that every head of every layer gets
    one consecutive, 0-based index. ``attention_head`` selects from that flat
    list.

    Args:
        input_sentence: index into the module-level ``input_pairs``.
        attention_head: flat 0-based head index over all layers' heads.

    Returns:
        A 2-D numpy array (seq_len x seq_len) with the selected head's
        attention weights.

    Raises:
        IndexError: if ``attention_head`` is negative or not smaller than the
            number of flattened heads.
    """
    question, passage = input_pairs[input_sentence]
    tokenized_input = tokenizer(question, passage, truncation=True, padding=True,
                                max_length=512, return_tensors='pt')
    # Run on whatever device the model currently lives on.
    tokenized_input = tokenized_input.to(model.device)
    with torch.no_grad():
        outputs = model(**tokenized_input)

    # Stack the per-layer attention tensors, then collapse every leading
    # dimension. Only the last two (query x key positions) are kept, giving
    # one 2-D matrix per (layer, head); the batch here is a single pair.
    attention_outputs = torch.stack(outputs.attentions)
    flattened_attention_matrices = attention_outputs.view(
        -1, attention_outputs.size(3), attention_outputs.size(4)
    )

    num_heads = flattened_attention_matrices.size(0)
    # Explicit check so negative indices fail loudly instead of wrapping.
    if not 0 <= attention_head < num_heads:
        raise IndexError(
            f"attention_head {attention_head} out of range [0, {num_heads})"
        )
    selected_attention_head = flattened_attention_matrices[attention_head].cpu().numpy()

    # Drop references to the large intermediates before collecting.
    del tokenized_input, outputs, attention_outputs, flattened_attention_matrices
    torch.cuda.empty_cache()
    gc.collect()
    return selected_attention_head

In [None]:
def classify_head_pattern(attention_pattern):
    """Render an attention matrix as a heatmap and turn it into a CNN input tensor.

    Despite its name, this function does not classify anything. It only
    prepares the image:

    - plots ``attention_pattern`` with ``imshow`` (magma colormap) on a
      10x8-inch figure and renders it to PNG in memory;
    - crops a fixed pixel box, boosts contrast 4x and converts to grayscale
      replicated to RGB;
    - applies the module-level ``transform``.

    The crop box is in pixels of the rendered PNG, so it only matches the
    heatmap for this figure size and the default save DPI. Presumably these
    settings match how the CNN's training images were made; TODO confirm
    before changing either.

    Args:
        attention_pattern: 2-D array accepted by ``imshow``, e.g. the return
            value of ``get_self_attention_matrix``.

    Returns:
        A float tensor with a leading batch dimension of 1, shape (1, 3, H, W).
        With the ``transform`` defined in this notebook, H = W = input_size.
    """
    fig = plt.figure(figsize=(10, 8))
    buf = io.BytesIO()
    try:
        plt.imshow(attention_pattern, cmap='magma', interpolation='nearest')
        fig.savefig(buf, format='png')
    finally:
        # Close the figure even if plotting fails, so repeated calls don't
        # accumulate open figures.
        plt.close(fig)
    buf.seek(0)

    # PIL decodes lazily from the buffer, so all image work stays inside the
    # `with`. Exiting the block closes both the image and the buffer.
    with buf, Image.open(buf) as rendered:
        cropped = rendered.crop(box=(205, 96, 820, 713))
        enhanced = ImageEnhance.Contrast(cropped).enhance(4.0)
        gray = enhanced.convert("L")
        tensor = transform(gray.convert('RGB'))

    img = tensor.unsqueeze(0)  # add batch dimension

    # Drop intermediate images before collecting garbage.
    del cropped, enhanced, gray, tensor
    torch.cuda.empty_cache()
    gc.collect()

    return img

In [None]:
# The file holds a whole pickled model: it is called directly on an image
# tensor later on. A state_dict would not be callable like that.
# - weights_only=False is required: PyTorch >= 2.6 defaults to weights_only=True,
#   which refuses to unpickle arbitrary nn.Module objects.
# - Unpickling can execute arbitrary code, so only load trusted files.
# - Presumably the pickled class is SimpleCNN, which must then be defined
#   before this cell runs; TODO confirm.
# - map_location='cpu' matches the CPU tensors produced by classify_head_pattern.
model_cnn = torch.load('ahr_cnn_75_acc.pth', map_location='cpu', weights_only=False)
model_cnn.eval()  # inference mode

In [None]:
# Pick a random training example (among the first 1000) and a random attention
# head, render that head's pattern, and let the CNN predict its pattern class.
sentence_number = random.randint(0, 999)
# get_self_attention_matrix takes a 0-based index over all layers' heads
# flattened together, i.e. num_hidden_layers * num_attention_heads of them.
# random.randint includes BOTH endpoints, hence the "- 1". The previous
# randint(0, 73) could yield out-of-range indices.
total_heads = model.config.num_hidden_layers * model.config.num_attention_heads
head_number = random.randint(0, total_heads - 1)
print(f'Input num: {sentence_number}')
print(f'Head num: {head_number}')

head_check = get_self_attention_matrix(sentence_number, head_number)
test_img = classify_head_pattern(head_check)

with torch.no_grad():
    output = model_cnn(test_img)
    _, predicted = torch.max(output, 1)
    print(f'Predicted class: {predicted.item()}')  # Labels for classifier are range 0-4

In [None]:
# Head ids selected for pruning, in the order given.
# NOTE(review): the values span 1..72, and the dict keys below are also 1..72.
# get_self_attention_matrix, however, takes a 0-based flat head index.
# Confirm which convention these ids use before indexing with them.
heads_to_prune = [
    34, 42, 29, 60, 59, 58, 43, 40, 44, 55,
    53, 27, 51, 48, 50, 25, 45, 33, 49, 38,
    52, 54, 31, 56, 39, 32, 36, 35, 37, 57,
    28, 47, 63, 46, 70, 30, 16, 26, 15, 18,
    64, 7, 68, 8, 71, 20, 62, 5, 24, 72,
    66, 61, 12, 14, 6, 65, 67, 17, 1, 69,
    23, 41, 10, 21, 9, 2, 19, 4,
]
# Per-head tally with one zeroed counter per CNN class (5 slots).
# Each head gets its own fresh list.
attention_pattern_dict = {head: [0] * 5 for head in range(1, 73)}