# final layer cross-chain attention - average

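This notebook extracts the final-layer cross-chain attention of BALM-paired for 1,000 antibodies from our test dataset, summarizes for each antibody how much of that attention falls on CDR versus non-CDR positions, and exports the results as a CSV file.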

## setup

In [None]:
import os
import torch
import pandas as pd
from tqdm.notebook import tqdm
from transformers import (
    RobertaTokenizer, 
    RobertaForMaskedLM
)
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from random import sample

## load model

In [None]:
# Location of the pretrained checkpoint handed to `from_pretrained`;
# point this at your local copy of the model.
model_path = "./BALM-paired/"

In [None]:
# Instantiate the masked-LM model from the checkpoint at `model_path`.
model = RobertaForMaskedLM.from_pretrained(
    model_path,
)

## load + tokenize data

In [None]:
# CSV file with the test antibodies; point this at your local copy.
data_path = "./test_dataset_1000seqs.csv"

In [None]:
# Read the dataset and, in the same step, swap the '<cls><cls>' chain
# delimiter in the 'text' column for the '</s>' token.
df_selected = pd.read_csv(data_path).assign(
    text=lambda frame: frame['text'].str.replace('<cls><cls>', '</s>'),
)

In [None]:
# Pull the three columns used downstream out of the dataframe as plain lists.
seqs = df_selected['text'].tolist()
seq_names = df_selected['sequence_id'].tolist()
cdrs = df_selected['cdr_mask'].tolist()

In [None]:
# Load the tokenizer from the repository-relative tokenizer directory.
tokenizer = RobertaTokenizer.from_pretrained(
    "../tokenizer",
)

In [None]:
# Tokenize every sequence on its own (no batching or padding), requesting
# 'pt' tensors so each encoding can be passed straight to the model.
tokenized_data = [tokenizer(seq, return_tensors='pt') for seq in tqdm(seqs)]

In [None]:
# Bundle everything needed per antibody into
# (name, raw sequence, tokenized encoding, CDR mask) tuples.
inputs = [
    (seq_name, seq, encoding, cdr_mask)
    for seq_name, seq, encoding, cdr_mask in zip(seq_names, seqs, tokenized_data, cdrs)
]

## functions for processing attention

In [None]:
def avg_heads(cc_attention_df):
    """Average the per-position cross-chain attention over all heads.

    For every head the long-format table is pivoted into a
    ``position1 x position2`` matrix. The heavy->light and light->heavy
    quadrants are each reduced to the mean attention *received* by each
    position, and the two profiles are concatenated (heavy first, then
    light). The per-head profiles are then averaged over however many heads
    are present in the input.

    Args:
        cc_attention_df: DataFrame with the columns ``"head"``,
            ``"position1"``, ``"position2"`` and ``"attention"``. It must
            contain only cross-chain pairs, and a ``position2`` value of 1
            must belong to the heavy chain, because the heavy-chain length
            is read off that column.

    Returns:
        pandas Series indexed by position holding the head-averaged
        attention received from the other chain.

    Raises:
        ValueError: if ``cc_attention_df`` contains no heads.
    """
    head_profiles = []
    for head in sorted(cc_attention_df["head"].unique()):
        cc_attention_sq = pd.pivot(
            data=cc_attention_df[cc_attention_df["head"] == head],
            index="position1",
            columns="position2",
            values="attention",
        )

        # Heavy-chain rows have no cross-chain entry for position2 == 1, so
        # the number of NaNs in that column equals the heavy-chain length.
        hlen = int(cc_attention_sq[1].isna().sum())
        hl_sqdf = cc_attention_sq.iloc[:hlen, hlen:]    # rows: heavy, cols: light
        lh_sqdf = cc_attention_sq.iloc[hlen:, :hlen].T  # rows: heavy, cols: light

        light = hl_sqdf.mean(axis=0)  # attention to the light chain (from the heavy)
        heavy = lh_sqdf.mean(axis=1)  # attention to the heavy chain (from the light)

        head_profiles.append(pd.concat([heavy, light]))

    if not head_profiles:
        raise ValueError("cc_attention_df contains no attention heads to average")

    # Divide by the number of heads actually present so the result is a true
    # mean for any model, not only for 16-head models.
    return sum(head_profiles) / len(head_profiles)

In [None]:
def calculate_attention(seq_id, cdrs, attention_by_pos):
    """Summarize how attention is split between CDR and non-CDR positions.

    ``cdrs`` and ``attention_by_pos`` are walked in lockstep (stopping at the
    shorter of the two). A position counts as CDR when ``float(flag) == 1``.

    Args:
        seq_id: identifier stored in the ``Sequence_id`` column.
        cdrs: iterable of per-position flags, e.g. a string of digits.
        attention_by_pos: iterable of per-position attention values.

    Returns:
        A one-row DataFrame with the position counts, the summed CDR and
        non-CDR attention, the percentage of positions that are CDR, and the
        percentage of attention on CDR and on non-CDR positions.
    """
    flagged = [(float(flag) == 1, value) for flag, value in zip(cdrs, attention_by_pos)]
    cdr_values = [value for is_cdr, value in flagged if is_cdr]
    other_values = [value for is_cdr, value in flagged if not is_cdr]

    n_positions = len(flagged)
    n_cdr_positions = len(cdr_values)
    cdr_attention = sum(cdr_values, 0.0)
    other_attention = sum(other_values, 0.0)

    attention_total = cdr_attention + other_attention
    cdr_attention_pct = cdr_attention / attention_total * 100
    other_attention_pct = other_attention / attention_total * 100
    cdr_position_pct = n_cdr_positions / n_positions * 100

    summary = {
        'Sequence_id': seq_id,
        'Num-Total-Pos': n_positions,
        'Num-CDR-Pos': n_cdr_positions,
        'CDR-atten': cdr_attention,
        'Non-CDR-atten': other_attention,
        'CDR-seq%': cdr_position_pct,
        'CDR-atten%': cdr_attention_pct,
        'Non-CDR-atten%': other_attention_pct,
    }
    return pd.DataFrame([list(summary.values())], columns=list(summary.keys()))

## generate and export attention matrix

In [None]:
# Empty accumulator table; the attention loop that follows adds one
# summary row per antibody with exactly these columns.
results = pd.DataFrame(
    columns=[
        'Sequence_id',
        'Num-Total-Pos',
        'Num-CDR-Pos',
        'CDR-atten',
        'Non-CDR-atten',
        'CDR-seq%',
        'CDR-atten%',
        'Non-CDR-atten%',
    ],
)

In [None]:
# Set the model to evaluation mode
model.eval()

# Forward pass through the model (no gradients are needed to read attention)
with torch.no_grad():
    # The per-antibody mask is bound to `cdr_mask` rather than `cdrs` so the
    # loop does not overwrite the module-level `cdrs` list.
    for name, seq, tokens, cdr_mask in tqdm(inputs):
        print(f"Input: {name}")

        outputs = model(
            **tokens,
            output_attentions=True,
            output_hidden_states=True,
        )

        # parse the sequence
        # NOTE(review): the index arithmetic assumes one token per residue, a
        # single leading special token (index 0) and a single separator token
        # between the chains -- confirm against the tokenizer.
        h, l = seq.split('</s>')
        h_positions = list(range(1, len(h) + 1))
        l_positions = list(range(len(h) + 2, len(h) + 2 + len(l)))
        all_positions = h_positions + l_positions

        # Position -> chain lookup built once per sequence, so the pair loop
        # below does O(1) dict lookups instead of scanning a list per pair.
        region_by_position = dict.fromkeys(h_positions, "heavy")
        region_by_position.update(dict.fromkeys(l_positions, "light"))

        # Get the attention values for each layer and attention head
        attentions = outputs.attentions
        num_heads = attentions[0].size(1)

        # Index of the final layer, derived from the model output so that a
        # model with a different depth still uses its last layer.
        layer = len(attentions) - 1
        all_attentions = []
        layer_attentions = attentions[layer]  # for last layer only
        for head in tqdm(range(num_heads)):  # for each head in that layer
            head_attentions = layer_attentions[0, head]
            # One bulk conversion to Python floats per head instead of one
            # `.item()` call per (position1, position2) pair.
            head_values = head_attentions.tolist()
            for p1 in all_positions:
                p1_region = region_by_position[p1]
                p1_values = head_values[p1]
                for p2 in all_positions:
                    p2_region = region_by_position[p2]
                    comp_type = f"intra-{p1_region}" if p1_region == p2_region else "cross-chain"
                    all_attentions.append(
                        {
                            "position1": p1,
                            "position2": p2,
                            "comparison": comp_type,
                            "attention": p1_values[p2],
                            "layer": layer,
                            "head": head,
                        }
                    )

        # Convert to dataframe
        attention_df = pd.DataFrame(all_attentions)

        # Cross-chain attention
        # Strip the "--" marker from the mask before pairing it with positions
        # (presumably the heavy/light delimiter -- verify against the dataset).
        cdr_mask = cdr_mask.replace("--", "")
        atten_cc = attention_df[attention_df["comparison"] == "cross-chain"]
        layer_avg = avg_heads(atten_cc)
        cc_df = calculate_attention(name, cdr_mask, layer_avg)

        # Results are still extended after every antibody, so an interrupted
        # run keeps its partial output. The first row replaces the empty seed
        # frame instead of being concatenated with it, because concatenating
        # with an empty frame is deprecated and can change the result dtypes.
        results = cc_df if results.empty else pd.concat([results, cc_df], axis=0)

In [None]:
# Mean of every numeric summary column. The string `Sequence_id` column is
# dropped first because `DataFrame.mean()` raises TypeError on it in
# pandas >= 2.0. The cast to float keeps this working even if the columns
# ended up with object dtype.
results.drop(columns=['Sequence_id']).astype(float).mean()

In [None]:
# Standard deviation of every numeric summary column. The string
# `Sequence_id` column is dropped first because `DataFrame.std()` raises
# TypeError on it in pandas >= 2.0. The cast to float keeps this working
# even if the columns ended up with object dtype.
results.drop(columns=['Sequence_id']).astype(float).std()

In [None]:
# Write the per-antibody summary table to disk. The output directory is
# created first because `to_csv` fails if it does not exist.
output_path = './attention-results/BALM-paired_1kattention-results.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
results.to_csv(output_path, index=False)