# Extract attention matrices

This notebook extracts the attention matrices for every protein sequence, both wild-type and variant. 

- 11/11/2025: This is the initial set of cells in the notebook that extract the attention matrices. Further cells will compare differences in attention matrices between variant and wild-type sequences.
  

# Import libraries

In [9]:
from transformers import AutoTokenizer, EsmModel
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Define filepaths and model used, tokenise sequences
The model used here is the ESM2 650M-parameter model. You can use a different one in the pipeline, but this project uses the 650M model. You should use the same model that was used for EvoProtGrad's expert step.

In [12]:
# Results table: the `wt_or_var` column marks each row as wild-type ("wt")
# or variant ("var"), and `sequence` holds the protein sequence.
filepath = '../supplementary_data/results_table.csv'
df = pd.read_csv(filepath)

# ESM-2 checkpoint; the model is loaded with output_attentions=True so that
# forward passes expose the attention weights alongside the usual outputs.
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmModel.from_pretrained(model_name, output_attentions=True)

# Wild-type rows as an independent frame, with sequences forced to str.
df_wt = (
    df.loc[df['wt_or_var'] == "wt"]
    .assign(sequence=lambda frame: frame["sequence"].astype(str))
)

# Variant rows, prepared the same way.
df_var = (
    df.loc[df['wt_or_var'] == "var"]
    .assign(sequence=lambda frame: frame["sequence"].astype(str))
)

FileNotFoundError: [Errno 2] No such file or directory: '../supplementary_data/results_table.csv'

# Extract matrices

This function extracts attention matrices and, by default, plots the last layer, last batch item, head 13, i.e. the index `[-1][-1][13]` into the attention tuple passed to `sns.heatmap`. A second heatmap shows the last head of the same layer (`[-1][-1][-1]`).

In [15]:
def extract_matrices(DataFrame, layer=-1, head=13):
    """Run each sequence through the model and plot its attention maps.

    For every row the sequence is tokenised, passed through the module-level
    ``model`` without gradient tracking, and the returned attention tensors
    are trimmed to drop the first and last token positions (the special
    tokens added by ``add_special_tokens=True``). Two heatmaps are drawn side
    by side for the chosen layer: the requested ``head`` and the last head.

    Parameters
    ----------
    DataFrame : pandas.DataFrame
        Must provide ``sequence``, ``pdb_id`` and ``wt_or_var`` columns.
        The frame is not modified.
    layer : int, optional
        Index into the per-layer attention tuple. Defaults to ``-1`` (last).
    head : int, optional
        Index of the attention head shown in the left panel. Defaults to 13.

    Returns
    -------
    None
        Results are only plotted and printed; nothing is returned.

    Notes
    -----
    Relies on the module-level ``tokenizer`` and ``model`` objects.
    """
    for row in DataFrame.itertuples(index=False):
        # Cast the single value instead of re-casting the whole column on
        # every iteration, which also avoided mutating the caller's frame.
        sequence = str(row.sequence)
        inputs = tokenizer(sequence, return_tensors="pt", add_special_tokens=True)
        print(f"analysing {row.pdb_id}")

        # Forward pass with attentions; no gradients are needed for inspection.
        with torch.no_grad():
            outputs = model(**inputs)

        # One tensor per layer, indexed as (batch, head, query, key).
        # Slice off the first and last positions to keep only amino acids.
        attentions = [att[:, :, 1:-1, 1:-1] for att in outputs.attentions]

        print(f"Number of layers: {len(attentions)}")

        # Chosen layer, last item of the batch -> indexed as (head, query, key).
        layer_attention = attentions[layer][-1]
        label = f"{row.pdb_id} {row.wt_or_var}"

        # Separate axes per heatmap: drawing both on the implicit current axes
        # made the second map overwrite the first and stacked the colorbars.
        fig, (ax_head, ax_last) = plt.subplots(1, 2, figsize=(14, 6))

        sns.heatmap(
            layer_attention[head].detach().cpu().numpy(),
            cmap="viridis",
            ax=ax_head,
        )
        ax_head.set_title(f"{label} - head {head}")

        sns.heatmap(
            layer_attention[-1].detach().cpu().numpy(),
            cmap="viridis",
            ax=ax_last,
        )
        ax_last.set_title(f"{label} - last head")

        plt.show()
        # Release the figure so long loops do not accumulate open figures.
        plt.close(fig)

        # Drop the leading batch axis before reporting the per-layer shape.
        attentions = [att.squeeze(0) for att in attentions]

        print(f"Shape of attention for layer 0: {attentions[0].shape}")

# extract_matrices plots and prints as it goes and returns None, so call it
# directly rather than printing its return value.
extract_matrices(df_wt)
extract_matrices(df_var)

NameError: name 'df_wt' is not defined