# final layer cross-chain attention - by region

This notebook extracts the final-layer cross-chain attention of ft-ESM for 1000 antibodies from our test dataset, averages it by FR and CDR region, and exports the results as a CSV file.

## setup

In [None]:
import os
import torch
import pandas as pd
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer, 
    EsmForMaskedLM
)
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from random import sample

## load model

In [None]:
# Directory holding the ft-ESM checkpoint; replace with the actual model path.
model_path = './ft-ESM/'

In [None]:
# Load the ft-ESM masked-language-model checkpoint from model_path.
model = EsmForMaskedLM.from_pretrained(model_path)

## load + tokenize data

In [None]:
# CSV file with the test sequences; replace with the actual data path.
data_path = './test_dataset_1000seqs.csv'

In [None]:
# Load the test set; the 'text', 'sequence_id' and 'cdr_mask' columns are read from this frame.
df_selected = pd.read_csv(data_path)

In [None]:
# Pull the three columns used downstream out of the frame as plain Python lists.
seqs, seq_names, cdrs = (
    list(df_selected[column]) for column in ('text', 'sequence_id', 'cdr_mask')
)

In [None]:
# The tokenizer comes from the base ESM-2 650M checkpoint name, not from model_path.
# NOTE(review): assumes ft-ESM kept the base model's vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

In [None]:
# Tokenize every sequence on its own (one tokenizer call, PyTorch tensors, per sequence).
tokenized_data = [tokenizer(s, return_tensors='pt') for s in tqdm(seqs)]

In [None]:
# One (sequence_id, sequence, tokenized input, CDR mask) tuple per antibody.
inputs = list(zip(seq_names, seqs, tokenized_data, cdrs))

## functions for processing attention

In [None]:
# average heads for cross-chain attention
def avg_heads(cc_attention_df):
    """Average cross-chain attention over heads, per sequence position.

    Args:
        cc_attention_df: long-format DataFrame with columns ``position1``,
            ``position2``, ``attention`` and ``head`` that contains only
            cross-chain pairs. Heavy-chain positions must sort before
            light-chain positions, because the heavy/light split is done
            positionally on the pivoted matrix.

    Returns:
        A pandas Series indexed by position (heavy positions first, then
        light). For each head, a heavy position gets the mean of the entries
        whose ``position2`` is that position and whose ``position1`` is a
        light position; light positions are handled symmetrically. The
        per-head values are then averaged over however many heads are present
        in the input. An empty Series is returned when the input has no heads.
    """
    head_dfs = []
    for head in sorted(cc_attention_df["head"].unique()):
        cc_attention_sq = pd.pivot(
            data=cc_attention_df[cc_attention_df["head"] == head],
            index="position1",
            columns="position2",
            values="attention",
        )

        # The first column belongs to the first heavy position. Intra-chain
        # pairs are absent from the input, so that column is NaN exactly on
        # the heavy rows: counting its NaNs gives the heavy-chain length.
        hlen = int(cc_attention_sq.iloc[:, 0].isna().sum())
        hl_sqdf = cc_attention_sq.iloc[:hlen, hlen:]    # rows heavy, columns light
        lh_sqdf = cc_attention_sq.iloc[hlen:, :hlen].T  # transposed: rows heavy, columns light

        light = hl_sqdf.mean(axis=0)
        heavy = lh_sqdf.mean(axis=1)

        head_dfs.append(pd.concat([heavy, light]))

    if not head_dfs:
        return pd.Series(dtype=float)

    # Divide by the real head count instead of a hard-coded 20 so the result
    # is a true mean for models with any number of attention heads.
    return sum(head_dfs) / len(head_dfs)

In [None]:
# attention by cdr
def atten_by_cdr(seq_id, cdrs, layer_avg):
    """Summarise per-position attention by CDR group and non-CDR region.

    Args:
        seq_id: identifier copied into the ``Sequence_id`` column.
        cdrs: iterable of ``0``/``1`` flags (e.g. a string such as
            ``"0011100..."``), one per position. Each run of consecutive 1s
            is one CDR group; 0s are non-CDR positions.
        layer_avg: per-position attention values, consumed in lockstep with
            ``cdrs``. ``zip`` stops at the shorter of the two, so a length
            mismatch is silently truncated.

    Returns:
        A one-row DataFrame with position counts, summed attention, percentage
        columns, and for the non-CDR region and each of the first six CDR
        groups (written in order of appearance to ``H1``..``L3``) the ratio
        (share of total attention) / (share of total positions).

    Raises:
        IndexError: if the mask contains fewer than six CDR groups.
    """
    cdr_list = [int(bit) for bit in cdrs]
    total_pos = len(cdr_list)

    current_group_sum = 0
    current_group_length = 0

    group_sums = []
    group_lengths = []
    zero_sums = 0

    for bit, value in zip(cdr_list, layer_avg):
        if bit == 1:
            current_group_sum += value
            current_group_length += 1
        elif bit == 0:
            zero_sums += value
            # A non-CDR position closes any open CDR run. The run is detected
            # by its length, not its sum: a run whose attention adds up to
            # exactly 0 must still be recorded, otherwise its length leaks
            # into the next run and one group goes missing.
            if current_group_length:
                group_sums.append(current_group_sum)
                group_lengths.append(current_group_length)
                current_group_sum = 0
                current_group_length = 0

    if current_group_length:  # unfinished group at the end of the mask
        group_sums.append(current_group_sum)
        group_lengths.append(current_group_length)

    num_cdr_pos = sum(group_lengths)
    cdr_atten = sum(group_sums)

    # calculate whole sequence %s
    cdr_seq_perc = num_cdr_pos / total_pos * 100
    total_atten = cdr_atten + zero_sums
    non_cdr_perc = zero_sums / total_atten * 100
    cdr_perc = cdr_atten / total_atten * 100

    # normalize groups for ratio: % of total attention / % of total sequence
    avg_non_cdr = (zero_sums * total_pos) / ((total_pos - num_cdr_pos) * total_atten)
    groups_relative = [
        (group_sum * total_pos) / (group_len * total_atten)
        for group_sum, group_len in zip(group_sums, group_lengths)
    ]

    if len(groups_relative) < 6:
        raise IndexError(
            f"expected at least 6 CDR groups in the mask for {seq_id!r}, "
            f"found {len(groups_relative)}"
        )

    # reformat results
    data = [[seq_id, total_pos, num_cdr_pos, cdr_atten, zero_sums,
             cdr_seq_perc, cdr_perc, non_cdr_perc, avg_non_cdr,
             *groups_relative[:6]]]
    df = pd.DataFrame(data, columns=['Sequence_id',
                                     'Num-Total-Pos',
                                     'Num-CDR-Pos',
                                     'CDR-atten',
                                     'Non-CDR-atten',
                                     'CDR-seq%',
                                     'CDR-atten%',
                                     'Non-CDR-atten%',
                                     'Avg_Non_CDR',
                                     "H1", "H2", "H3", "L1", "L2", "L3"])
    return df

## generate and export attention matrix

In [None]:
# Empty accumulator: summary columns first, then one ratio column per CDR group.
results = pd.DataFrame(
    columns=[
        'Sequence_id',
        'Num-Total-Pos', 'Num-CDR-Pos',
        'CDR-atten', 'Non-CDR-atten',
        'CDR-seq%', 'CDR-atten%', 'Non-CDR-atten%',
        'Avg_Non_CDR',
    ] + ["H1", "H2", "H3", "L1", "L2", "L3"]
)

In [None]:
# Set the model to evaluation mode
model.eval()

# Per-sequence result frames are gathered here and concatenated once after the
# loop; concatenating inside the loop re-copies the accumulated frame each time.
per_seq_results = []

# Forward pass through the model (no gradients needed for attention extraction)
with torch.no_grad():
    # The loop variable is named cdr_mask so it does not overwrite the global
    # `cdrs` list that `inputs` was built from.
    for name, seq, tokens, cdr_mask in tqdm(inputs):

        # Only the attention maps are used below, so hidden states are not requested.
        outputs = model(
            **tokens,
            output_attentions=True,
        )

        # parse the sequence: heavy and light chains are joined by '<cls><cls>'.
        # Heavy residues take token indices 1..len(h); light residues start at
        # len(h) + 3, i.e. two indices are skipped for the separator.
        h, l = seq.split('<cls><cls>')
        h_positions = list(range(1, len(h) + 1))
        l_positions = list(range(len(h) + 3, len(h) + 3 + len(l)))
        all_positions = h_positions + l_positions

        # Resolve each position's chain once instead of scanning h_positions
        # for every (p1, p2) pair inside the nested loops.
        region_of = {p: "heavy" for p in h_positions}
        region_of.update({p: "light" for p in l_positions})

        # Get the attention values for each layer and attention head
        attentions = outputs.attentions
        num_heads = attentions[0].size(1)

        # Last layer only; derived from the model output rather than hard-coded.
        layer = len(attentions) - 1
        all_attentions = []
        layer_attentions = attentions[layer]
        for head in range(num_heads):  # for each head in that layer
            # Convert the head's matrix to nested Python lists once, instead
            # of calling .item() on every single element.
            head_attentions = layer_attentions[0, head].tolist()
            for p1 in all_positions:
                p1_region = region_of[p1]
                attention_row = head_attentions[p1]
                for p2 in all_positions:
                    p2_region = region_of[p2]
                    comp_type = f"intra-{p1_region}" if p1_region == p2_region else "cross-chain"
                    all_attentions.append(
                        {
                            "position1": p1,
                            "position2": p2,
                            "comparison": comp_type,
                            "attention": attention_row[p2],
                            "layer": layer,
                            "head": head
                        }
                    )

        # Convert to dataframe
        attention_df = pd.DataFrame(all_attentions)

        # Cross-chain attention by cdr group.
        # "--" is stripped from the mask — presumably the placeholder for the
        # two separator tokens; confirm against the dataset.
        cdr_mask = cdr_mask.replace("--", "")
        atten_cc = attention_df[attention_df["comparison"] == "cross-chain"]
        layer_avg = avg_heads(atten_cc)
        cc_df = atten_by_cdr(name, cdr_mask, layer_avg)

        per_seq_results.append(cc_df)

# Append all new rows to whatever `results` already holds, in one concat.
results = pd.concat([results, *per_seq_results], axis=0)

In [None]:
# Create the output directory first, so the export cannot fail on a missing
# folder after the whole attention pass has already been computed.
output_path = './attention-results/ft-ESM_1kattention-byregion.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
results.to_csv(output_path, index=False)

## plot

In [None]:
# Keep only the per-region ratio columns and relabel the non-CDR column as 'FR'.
res = results[['Avg_Non_CDR', 'H1', 'H2', 'H3', 'L1', 'L2', 'L3']].rename(
    columns={'Avg_Non_CDR': 'FR'}
)

In [None]:
# Column-wise mean and standard deviation over the rows of res, one value per region.
mean = res.mean(axis=0)
std = res.std(axis=0)

In [None]:
# One bar colour for the first bar, a shared colour for the remaining six.
colors = ['#833f94'] + ['#259c8d'] * 6

In [None]:
# Bar chart of the mean attention ratio per region, one bar per column of res.
plt.figure(figsize=[4, 3])
plt.bar(
    mean.index, 
    mean, 
    # error bars are not passed here; they are drawn by plt.errorbar below
    width = 0.8,
    alpha = 0.95,
    color=colors,
)
# Hide the top and right spines and label the y axis.
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylabel("Attention Ratio", fontsize=10)
ax.tick_params(axis="both", labelsize=9)
# Overlay +/- one standard deviation as error bars; markersize=0 hides the
# point markers so only the bars and caps are visible.
plt.errorbar(mean.index, mean, yerr=std, 
             fmt='.', elinewidth=1.5, markersize=0, capsize=2, color='#59565a')
# Remove the tick marks on the bottom axis.
plt.tick_params(bottom = False)
plt.tight_layout()
# Save the figure next to the notebook at 300 dpi.
plt.savefig("./ft-ESM_CDR-plot.jpg", dpi=300)