# extracting cross-chain attention

This code generates the attention matrix for 5 therapeutic monoclonal antibodies (found in the therapeutic-mAbs.csv file) and exports the results as a CSV file for each antibody.

The resulting csv files can be read into the Plot-Cross-Chain-Attention.ipynb file to plot the attention values.

### setup

In [None]:
import os
import torch
import pandas as pd
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM
)

### load model

In [None]:
%%bash
# Download the fine-tuned ESM-2 model from Zenodo, unless it has already been
# unpacked into ./models.
# Stop at the first failing command so a failed download is never unpacked.
set -e
if [ ! -d "./models/ESM2-650M_paired-fine-tuning" ]; then
    # tar -C fails if the target directory is missing, so create it first
    mkdir -p ./models
    # -f makes curl exit non-zero on an HTTP error instead of saving the error page as the archive
    curl -f -o 'ESM2-650M_paired-fine-tuned.tar.gz' -L 'https://zenodo.org/record/8253367/files/ESM2-650M_paired-fine-tuned.tar.gz?download=1'
    tar xzvf 'ESM2-650M_paired-fine-tuned.tar.gz' -C ./models
    rm 'ESM2-650M_paired-fine-tuned.tar.gz'
fi

In [None]:
# Load the paired fine-tuned ESM-2 checkpoint from the local models directory,
# then move it onto the GPU.
model = AutoModelForMaskedLM.from_pretrained('./models/ESM2-650M_paired-fine-tuning/')
model = model.to('cuda')

If you want to load the 650M parameter ESM-2 model prior to fine-tuning instead, uncomment the following code.

In [None]:
# model = AutoModelForMaskedLM.from_pretrained(
#    "facebook/esm2_t33_650M_UR50D"
# ).to('cuda')

### load + tokenize data

In [None]:
# Load therapeutic antibody sequences
df = pd.read_csv('./therapeutic-mAbs.csv')
seq_df = df[["Therapeutic", "Heavy Sequence", "Light Sequence"]].set_index("Therapeutic")

# Concat heavy and light chain sequences
seqs = []
for h, l in zip(seq_df['Heavy Sequence'], seq_df['Light Sequence']):
    seqs.append("{}<cls><cls>{}".format(h, l))
seq_names = list(seq_df.index.values)

In [None]:
# The tokenizer is loaded from the base ESM-2 650M checkpoint identifier,
# not from the local fine-tuned model directory.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

In [None]:
# Tokenize every paired sequence on its own (one encoding per antibody)
# and move the resulting tensors to the GPU
tokenized_data = [
    tokenizer(paired_seq, return_tensors='pt').to('cuda')
    for paired_seq in tqdm(seqs)
]

# Per-field view of the encodings (not used by the rest of this notebook)
i = {
    'input_ids': [encoding['input_ids'] for encoding in tokenized_data],
    'attention_mask': [encoding['attention_mask'] for encoding in tokenized_data],
}

# Bundle (name, paired sequence, encoding) triples for the attention pass
inputs = list(zip(seq_names, seqs, tokenized_data))

## generate and export attention matrix

In [None]:
# Set the model to evaluation mode
model.eval()

# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists before the first export
os.makedirs('./attention-results', exist_ok=True)

# Forward pass through the model (inference only, so no gradients are tracked)
with torch.no_grad():
    for name, seq, tokens in tqdm(inputs):
        print(f"Input: {name}")

        # Only the attention weights are read below, so hidden states are
        # not requested from the model
        outputs = model(**tokens, output_attentions=True)

        # Split the paired sequence back into its heavy and light chains
        heavy, light = seq.split('<cls><cls>')
        n_heavy, n_light = len(heavy), len(light)

        # Expected token layout for "<heavy><cls><cls><light>" once the
        # tokenizer has added its own leading <cls> and trailing <eos>:
        #   0                                -> leading <cls>
        #   1 .. n_heavy                     -> heavy chain residues
        #   n_heavy + 1, n_heavy + 2         -> the two <cls> separators
        #   n_heavy + 3 .. n_heavy+2+n_light -> light chain residues
        #   n_heavy + n_light + 3            -> trailing <eos>
        # Verify that layout instead of silently exporting shifted positions.
        n_tokens = tokens['input_ids'].shape[-1]
        expected_tokens = n_heavy + n_light + 4
        if n_tokens != expected_tokens:
            raise ValueError(
                f"Unexpected token count for {name}: got {n_tokens}, "
                f"expected {expected_tokens} (one token per residue plus 4 special tokens)"
            )

        h_positions = list(range(1, n_heavy + 1))
        # Light residues start after BOTH separator tokens (hence +3, not +2)
        l_positions = list(range(n_heavy + 3, n_heavy + 3 + n_light))
        all_positions = h_positions + l_positions
        n_pos = len(all_positions)

        # Get the attention values for each layer and attention head
        attentions = outputs.attentions
        num_layers = len(attentions)
        num_heads = attentions[0].size(1)

        # Columns that are identical for every layer are built once.
        # Row order within a layer is head -> position1 -> position2,
        # i.e. position2 varies fastest.
        positions = torch.tensor(all_positions, dtype=torch.long)
        position1 = positions.repeat_interleave(n_pos).repeat(num_heads).numpy()
        position2 = positions.repeat(n_pos * num_heads).numpy()
        head = torch.arange(num_heads).repeat_interleave(n_pos * n_pos).numpy()

        regions = ["heavy"] * n_heavy + ["light"] * n_light
        comparison = [
            f"intra-{region1}" if region1 == region2 else "cross-chain"
            for region1 in regions
            for region2 in regions
        ] * num_heads

        # Index tensor must live on the same device as the attention tensors
        positions_on_device = positions.to(attentions[0].device)

        # Extract attention values for each attention head in every layer
        layer_frames = []
        for layer in tqdm(range(num_layers)):  # for each layer
            # Keep only residue rows/columns: (heads, n_pos, n_pos)
            residue_attention = (
                attentions[layer][0]
                .index_select(1, positions_on_device)
                .index_select(2, positions_on_device)
            )
            layer_frames.append(
                pd.DataFrame(
                    {
                        "position1": position1,
                        "position2": position2,
                        "comparison": comparison,
                        # One device-to-host copy per layer; cast to float64
                        # so values match what Tensor.item() would return
                        "attention": residue_attention.reshape(-1).double().cpu().numpy(),
                        "layer": layer,
                        "head": head,
                    }
                )
            )

        # Convert to dataframe
        attention_df = pd.concat(layer_frames, ignore_index=True)

        # Export to csv
        attention_df.to_csv(f'./attention-results/{name}.csv', index=False)