# Antibody-antigen binding distribution by different features

In [None]:
!pip3 install biopython pandas torch matplotlib seaborn tqdm

In [None]:
from Bio.PDB import PDBParser
import os, sys
import pandas as pd
import numpy as np
from collections import defaultdict
import torch
import warnings
import time as time
import matplotlib.pyplot as plt
import seaborn as sns
sys.path.append('/n/groups/marks/projects/sequence_structure/scripts')
parent_dir = '/n/groups/marks/projects/sequence_structure/sabdab'
sys.path.append(parent_dir)
from pdb_utils import *
from structure_utils import arrange_idx
import torch
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

In [None]:
# An antibody variable domain alternates framework (FR) and
# complementarity-determining (CDR) regions. Each region below is defined as
# the list of its position labels, stored as strings.

FR1 = [str(pos) for pos in range(1, 27)]      # positions 1-26
CDR1 = [str(pos) for pos in range(27, 39)]    # positions 27-38
FR2 = [str(pos) for pos in range(39, 56)]     # positions 39-55
CDR2 = [str(pos) for pos in range(56, 66)]    # positions 56-65
FR3 = [str(pos) for pos in range(66, 105)]    # positions 66-104
# CDR3 covers 105-117 plus the insertion labels 111A-111F and 112A-112G
CDR3 = (
    [str(pos) for pos in range(105, 118)]
    + [f"111{code}" for code in "ABCDEF"]
    + [f"112{code}" for code in "ABCDEFG"]
)
FR4 = [str(pos) for pos in range(118, 130)]   # positions 118-129

# all framework labels / all CDR labels, in region order
FRs = [*FR1, *FR2, *FR3, *FR4]
CDRs = [*CDR1, *CDR2, *CDR3]

In [None]:
# One row per region; "imgt_pos" holds that region's list of position labels.
df = pd.DataFrame(
    {
        "ab_loc": ["FR1", "CDR1", "FR2", "CDR2", "FR3", "CDR3", "FR4"],
        "imgt_pos": [FR1, CDR1, FR2, CDR2, FR3, CDR3, FR4],
    }
).set_index("ab_loc")

# Flatten to one (region, position) pair per entry, then invert it into a
# position -> region lookup table.
_imgt_by_region = df["imgt_pos"].explode()
ab_loc_dict = {pos: region for region, pos in _imgt_by_region.items()}

In [None]:
# standardizing the format of unique identifiers in the dataset

def fix_uq_id(s):
    """Upper-case everything after the first underscore of each identifier.

    The part before the first ``'_'`` is kept as written; the remainder is
    upper-cased, e.g. ``'1abc_h(a)_ant(b)'`` becomes ``'1abc_H(A)_ANT(B)'``.

    Each element is handled on its own, so identifiers with different
    numbers of underscores can be mixed in one Series (a column-wise
    split/join pads the shorter ones with ``None`` and then fails to join).

    Args:
        s: pandas Series of identifier strings.

    Returns:
        An unnamed object-dtype Series with the same index as ``s``.
        Non-string entries (e.g. missing values) are returned unchanged.
    """
    def _normalize(uq_id):
        # leave missing / non-string entries alone rather than raising
        if not isinstance(uq_id, str):
            return uq_id
        head, sep, tail = uq_id.partition('_')
        return head + sep + tail.upper()

    return pd.Series([_normalize(uq_id) for uq_id in s], index=s.index, dtype=object)

In [None]:
# Load the filtered contact-interaction summary for one chain type (H here)
# from the dated dataset snapshot and index it by a normalised unique id.

DATE = '2022_12_21'
chain_type = 'H'

# root folder holding the per-structure files of this snapshot
pdb_repo = os.path.join(parent_dir, f'{DATE}/sabdab_dataset')

# interaction summary table, re-indexed by the cleaned-up uq_id
int_fnm = f'{DATE}/summary_tables/{chain_type}_fil_contact_interactions.csv'
int_df = (
    pd.read_csv(os.path.join(parent_dir, int_fnm))
    .assign(uq_id=lambda frame: fix_uq_id(frame['uq_id']))
    .set_index('uq_id')
)


Construct Profile

In [None]:
def _construct_profile(tup):
    pdb_repo, pdb, chain, antchain, chain_type = tup
    uq_id = f'{pdb}_{chain_type}({chain})_ant({antchain})'
    #load the contacts file
    fnm = os.path.join(pdb_repo, pdb, f'{chain_type}({chain})_ant({antchain})_contact_cnts.pt')
    if os.path.isfile(fnm):
        cnts_ld = torch.load(fnm)
        #sum each count matrix for each imgt in the count file
        cnt_num = {k:np.sum(v) for k,v in cnts_ld.items()}
        cnt_num = pd.Series(cnt_num, dtype='float64')
        cnt_num['uq_id'] = uq_id
        return (cnt_num)
    else:
        print (f'{fnm} does not have contacts counted')
        return (None)
    
# One task tuple per interaction row, in the order _construct_profile unpacks:
# (dataset root, pdb id, antibody chain, antigen chain, chain type).
N = len(int_df)
iterable = zip(
    [pdb_repo] * N,
    int_df['pdb'],
    int_df[f'{chain_type}chain'],
    int_df['antigen_chain'],
    [chain_type] * N,
)

# Run the tasks through process_map; complexes without a counts file come
# back as None and are dropped.
construct_list = [prof for prof in process_map(_construct_profile, iterable) if prof is not None]

# Rows = complexes, columns = position keys; keys missing for a complex become 0.
c_df = pd.DataFrame(construct_list)
c_df = c_df.fillna(0)
c_df = c_df.set_index('uq_id')

# Put the columns in the order returned by arrange_idx
# (presumably IMGT numbering order -- confirm against structure_utils).
imgt_order, uq_imgt_df, imgt_ii_key, ii_imgt_key = arrange_idx(c_df.columns, giveorder=True)
c_df = c_df.loc[:, imgt_order]

### Get the depth of each column in the alignment

align_df = pd.read_csv(os.path.join(parent_dir, f'{DATE}/summary_tables/summary_{chain_type}_sabdab_alignment.csv'))
align_df['uq_id'] = fix_uq_id(align_df['uq_id'])
# restrict to the structures that are in the parsed interaction table
s_align_df = align_df[align_df['uq_id'].isin(int_df.index)]

# Keep only the columns whose name starts with a digit (the position columns).
# The pattern is a raw string so that \d reaches the regex engine instead of
# being read as an (invalid) string escape sequence.
a_df = s_align_df.iloc[:, s_align_df.columns.str.match(r'\d+')]
# depth of a column = number of sequences whose cell is not '-'
align_depth = (a_df != '-').sum(0)
# drop alignment columns that have no counterpart in the contact counts
align_depth_k = align_depth[pd.Series(align_depth.index).isin(uq_imgt_df.imgt).values]
# normalise depth to a fraction of the aligned sequences
align_depth_km = align_depth_k / a_df.shape[0]

# For every retained alignment column, look up the original index it has in
# uq_imgt_df; the result is a Series indexed by the column's position label.
indexer = uq_imgt_df.reset_index(drop=False)
indexer = indexer.set_index('imgt')
keep_idx = indexer.loc[align_depth_k.index, 'index']

### Remove columns with low depth

# Positions present in at most this fraction of the aligned sequences are
# discarded (align_depth_km is depth divided by the number of sequences).
d_tr = 0.01 # depth threshold; value changed by AH
# Boolean-mask keep_idx with the depth test; both Series are indexed by
# position label, so the mask lines up label by label.
new_idx = keep_idx[align_depth_km > d_tr]
# keep only the sufficiently deep position columns of the count matrix
cfil_df = c_df.loc[:,new_idx.index]
# Binarise on a copy so cfil_df keeps the raw counts: any positive count
# becomes 1.0 (contact present), zeros stay 0.
cbin_df = cfil_df.copy()
cbin_df[cbin_df > 0] = 1.0

In [None]:
# Assemble per-complex metadata, aligned row-for-row with cbin_df.

# --- summary table for all interactions (tab-separated) ---
sum_fnm = f'{DATE}/sabdab_summary_all_{DATE}.tsv'
sum_df = pd.read_csv(os.path.join(parent_dir, sum_fnm), sep='\t').fillna('')
# same id layout as the contact profiles: <pdb>_<type>(<chain>)_ant(<antigen chain>)
sum_df['uq_id'] = sum_df['pdb'] + '_' + chain_type + '(' + sum_df[f'{chain_type}chain'] + ')_ant(' + sum_df['antigen_chain'] + ')'
sum_df = sum_df.set_index('uq_id')
# keep the first row of every repeated id
sum_df = sum_df.loc[~sum_df.index.duplicated()]

# --- antibody type table ---
# NOTE(review): the file is named .tsv but is read with the default (comma)
# separator -- confirm the file really is comma-delimited.
atype_fnm = os.path.join(parent_dir, 'recourses/metadata_files/Atype_summary.tsv')
atype_df = pd.read_csv(atype_fnm).fillna('')
atype_df['uq_id'] = atype_df['pdb'] + '_' + chain_type + '(' + atype_df[f'{chain_type}chain'] + ')_ant(' + atype_df['antigen_chain'] + ')'
atype_df['full_id'] = atype_df['pdb'] + '_' + atype_df['Hchain'] + '_' + atype_df['Lchain'] + '_' + atype_df['antigen_chain']
atype_df = atype_df.set_index('uq_id')
atype_df = atype_df.loc[~atype_df.index.duplicated()]

# --- metadata in the row order of the binary contact matrix ---
meta_cols = [
    'antigen_type',
    'heavy_species',
    'light_species',
    'heavy_subclass',
    'light_subclass',
    'light_ctype',
]

met_df = sum_df.loc[cbin_df.index, meta_cols]
# complexes missing from the type table get NaN here
met_df['antibody_type'] = atype_df['antibody_type'].reindex(met_df.index)

Distribution of Contact Counts

In [None]:
sns.set_theme(style='whitegrid', font_scale=1.2)

# number of positions in contact with the antigen, one value per complex
_contacts_per_complex = cbin_df.sum(axis=1)
sns.histplot(x=_contacts_per_complex, binwidth=1, kde=True, color="b")

plt.xlabel("Number of Antigen-Contacting Positions")
plt.ylabel(f"Number of V{chain_type} Complexes")
plt.tight_layout()
plt.savefig(f"/n/groups/marks/projects/public_epitope/outputs/figures/dist_all_contacts-{chain_type}.png")

In [None]:
# Overlaid density histograms of framework-contact counts, one per antibody type.

# cbin_met_df is built here because the cell that defines it sits further
# down the notebook; without this a top-to-bottom run fails with NameError.
cbin_met_df = met_df.merge(cbin_df, on="uq_id")
cbin_met_df["FR_contacts"] = cbin_met_df[[pos for pos in FRs if pos in cbin_df.columns]].sum(axis=1)

sns.set_theme(style='whitegrid', font_scale=1.2)
# (antibody type, colour) pairs; the type doubles as the legend label
for ab_type, color in (("Fv", "b"), ("VHH", "r"), ("scFv", "g")):
    sns.histplot(x=cbin_met_df.loc[cbin_met_df["antibody_type"] == ab_type, "FR_contacts"],
                 binwidth=1, stat="density", kde=True, color=color, label=ab_type, alpha=0.5)
plt.xlabel("Number of Antigen-Contacting Framework Positions")
plt.ylabel(f"Density of V{chain_type} Complexes")
plt.legend()
plt.tight_layout()
plt.savefig(f"/n/groups/marks/projects/public_epitope/outputs/figures/dist_fr_contacts-{chain_type}.png")

Distribution of Framework Region (FR) Contacts

In [None]:
# metadata and binary contacts side by side, matched on uq_id
cbin_met_df = pd.merge(met_df, cbin_df, on="uq_id")

# per complex: number of framework positions in contact, counting only the
# framework labels that are actually columns of cbin_df
_fr_cols = [pos for pos in FRs if pos in cbin_df.columns]
cbin_met_df["FR_contacts"] = cbin_met_df[_fr_cols].sum(axis=1)

In [None]:
# mean and median number of framework contacts among Fv complexes
_fv_fr_contacts = cbin_met_df.loc[cbin_met_df["antibody_type"] == "Fv", "FR_contacts"]
print(_fv_fr_contacts.mean(), _fv_fr_contacts.median())
# the same statistics for the other types: replace "Fv" above with "VHH" or "scFv"

In [None]:
# fraction of Fv complexes with at most n framework contacts
n = 2
_fv_rows = cbin_met_df[cbin_met_df["antibody_type"] == "Fv"]
print(len(_fv_rows[_fv_rows["FR_contacts"] <= n]) / len(_fv_rows))
# for scFv: filter on "scFv" instead of "Fv"

Most Common Interacting Residues

In [None]:
# Number of complexes contacting each position, most-contacted first.
cbin_df_contact_counts = pd.DataFrame(
    cbin_df.sum().sort_values(ascending=False),
    columns=["contact_counts"],
)

# Tag every position with its region; labels outside all regions keep "".
cbin_df_contact_counts["ab_loc"] = ""
for _region, _labels in (
    ("FR1", FR1),
    ("CDR1", CDR1),
    ("FR2", FR2),
    ("CDR2", CDR2),
    ("FR3", FR3),
    ("CDR3", CDR3),
    ("FR4", FR4),
):
    cbin_df_contact_counts.loc[cbin_df_contact_counts.index.isin(_labels), "ab_loc"] = _region

# Everything except CDR3, then the subset contacted more often than the
# non-CDR3 average.
cbin_df_contact_counts_noncdr3 = cbin_df_contact_counts[~cbin_df_contact_counts.index.isin(CDR3)]
_noncdr3_mean = cbin_df_contact_counts_noncdr3["contact_counts"].mean()
cbin_df_contact_counts_noncdr3_gt_avg = cbin_df_contact_counts_noncdr3[
    cbin_df_contact_counts_noncdr3["contact_counts"] > _noncdr3_mean
]

# cbin_df_contact_counts_noncdr3_gt_avg[cbin_df_contact_counts_noncdr3_gt_avg["ab_loc"].str[:3] != "CDR"]

Residues at Most Common Positions

In [None]:
def get_resid(tup, imgt_pos):
    """Return the antibody residue recorded at one position of one complex.

    Reads ``<pdb_repo>/<pdb>/<chain_type>(<chain>)_ant(<antchain>)_contact_details.csv``
    and selects the rows for this PDB id, chain type and antibody chain
    whose ``antibody_idx`` equals ``imgt_pos``.

    Args:
        tup: ``(pdb_repo, pdb, chain, antchain, chain_type)``.
        imgt_pos: position label. It is matched as a string and, when it
            parses as a number, also numerically, so both ``"55"`` and
            insertion-coded labels such as ``"111A"`` are accepted.

    Returns:
        The single ``antibody_AA`` value found at that position, or ``None``
        when no row matches.

    Raises:
        AssertionError: more than one distinct residue is recorded there.
        Errors from ``pd.read_csv`` (e.g. a missing file) propagate.
    """
    pdb_repo, pdb, chain, antchain, chain_type = tup

    fnm = os.path.join(pdb_repo, pdb, f'{chain_type}({chain})_ant({antchain})_contact_details.csv')
    int_details = pd.read_csv(fnm, index_col=0)

    # The column may have been parsed as numbers or as strings, so test both
    # representations; labels that are not numeric only get the string test.
    antibody_idx = int_details["antibody_idx"]
    at_position = antibody_idx == str(imgt_pos)
    try:
        at_position = at_position | (antibody_idx == float(imgt_pos))
    except ValueError:
        pass

    selected = int_details[
        (int_details["PDB"] == pdb)
        & (int_details["chain_type"] == chain_type)
        & (int_details["antibody_chain"] == chain)
        & at_position
    ]
    imgt_pos_aa = selected["antibody_AA"].unique()

    # Raised explicitly (same exception type as before) so the check still
    # runs when Python is started with -O.
    if len(imgt_pos_aa) > 1:
        raise AssertionError(
            f'{fnm}: position {imgt_pos} has several residues: {list(imgt_pos_aa)}'
        )

    return imgt_pos_aa[0] if len(imgt_pos_aa) == 1 else None

In [None]:
# Collect, for every highly contacted non-CDR3 position and every complex in
# int_df, the residue found at that position. Each entry of imgt_pos_aas_all
# is [imgt_pos, uq_id, aa]:
#   imgt_pos: the position label that was looked up
#   uq_id:    '<pdb>_<chain_type>(<antibody chain>)_ant(<antigen chain>)'
#   aa:       the residue returned by get_resid, or None when it found none
imgt_pos_aas_all = []
# (imgt_pos, uq_id, error) for every lookup that raised, so failures stay
# visible instead of disappearing silently
skipped_lookups = []

for imgt_pos in tqdm(cbin_df_contact_counts_noncdr3_gt_avg.index, total=len(cbin_df_contact_counts_noncdr3_gt_avg.index)):

    for ind, row in int_df.iterrows():
        pdb = row["pdb"]
        chain = row[f"{chain_type}chain"]
        antchain = row["antigen_chain"]
        # reuse the dataset root and chain type configured at the top of the
        # notebook rather than a hard-coded copy of them
        tup = (pdb_repo, pdb, chain, antchain, chain_type)
        uq_id = f'{pdb}_{chain_type}({chain})_ant({antchain})'
        try:
            imgt_pos_aa = get_resid(tup, imgt_pos)
        except Exception as err:
            # best effort: a failed lookup is recorded and skipped. Catching
            # Exception (not everything) keeps Ctrl-C able to stop the loop.
            skipped_lookups.append((imgt_pos, uq_id, repr(err)))
            continue
        imgt_pos_aas_all.append([imgt_pos, uq_id, imgt_pos_aa])

if skipped_lookups:
    print(f'{len(skipped_lookups)} lookups failed and were skipped (details in skipped_lookups)')

In [None]:
# one row per (position, complex) lookup
imgt_pos_aas_all_df = pd.DataFrame(
    imgt_pos_aas_all,
    columns=["imgt_pos", "uq_id", "aa"],
)
# display only the lookups that found a residue
imgt_pos_aas_all_df[imgt_pos_aas_all_df["aa"].notna()]

In [None]:
# For every looked-up position: its region, its contact count, and the most
# frequent residue together with the fraction of complexes carrying it.
imgt_pos_w_high_contact_count_conservation = []

for imgt_pos in imgt_pos_aas_all_df["imgt_pos"].drop_duplicates():
    pos_df = imgt_pos_aas_all_df[imgt_pos_aas_all_df["imgt_pos"] == imgt_pos]
    pos_summary = cbin_df_contact_counts_noncdr3_gt_avg.loc[imgt_pos]

    # value_counts ignores missing residues and sorts by frequency, so the
    # first entry is the most common residue
    aa_freq = pos_df["aa"].value_counts(normalize=True)
    if aa_freq.empty:
        # no residue was found for this position in any complex
        top_aa, top_frac = None, float("nan")
    else:
        top_aa, top_frac = aa_freq.index[0], aa_freq.iloc[0]

    imgt_pos_w_high_contact_count_conservation.append(
        [imgt_pos, pos_summary["ab_loc"], pos_summary["contact_counts"], top_frac, top_aa]
    )

In [None]:
# tabulate the per-position summary rows
imgt_pos_w_high_contact_count_conservation_df = pd.DataFrame(
    imgt_pos_w_high_contact_count_conservation,
    columns=[
        "imgt_pos",
        "ab_loc",
        "contact_counts",
        "conservation_highest_pct",
        "conservation_highest_aa",
    ],
)

In [None]:
# display the summary with the most-contacted positions first
imgt_pos_w_high_contact_count_conservation_df.sort_values("contact_counts", ascending=False)

Distribution of Non-CDR3 and CDR3 Contacts by Frequency

In [None]:
# plot individual complexes
# first, get individual contact counts per complex and count that number for each complex
# then, filter those counts for CDR3
# then plot

In [None]:
# Per complex: how many contacts fall on non-CDR3 positions and how many on
# CDR3 positions.
all_complexes = cbin_df.to_numpy()

# positions whose region is anything other than CDR3
non_cdr3_df = cbin_df_contact_counts[cbin_df_contact_counts['ab_loc'] != 'CDR3']

# non-CDR3 positions as integers (kept because later cells read this name)
arr_noncdr3_pos = [int(el) for el in non_cdr3_df.index.to_numpy()]

# Classify columns by their LABEL. The position of a column inside the
# matrix is not its position number (columns were filtered and include
# insertion labels), so comparing the column index with the position
# numbers puts contacts in the wrong bucket.
is_noncdr3_col = cbin_df.columns.isin(non_cdr3_df.index)

noncdr3_counts_per_complex = all_complexes[:, is_noncdr3_col].sum(axis=1).tolist()
cdr3_counts_per_complex = all_complexes[:, ~is_noncdr3_col].sum(axis=1).tolist()

In [None]:
# Same per-complex split as above, but keeping the complex id next to each
# count. cbin_df is indexed by complex id.
all_complexes = cbin_df.to_numpy()
complex_ids = cbin_df.index

# positions whose region is anything other than CDR3
non_cdr3_df = cbin_df_contact_counts[cbin_df_contact_counts['ab_loc'] != 'CDR3']

# non-CDR3 positions as integers (read again by later cells)
arr_noncdr3_pos = [int(el) for el in non_cdr3_df.index.to_numpy()]

# Select columns by label, not by their index inside the matrix: the two
# differ because columns were filtered and include insertion labels.
is_noncdr3_col = cbin_df.columns.isin(non_cdr3_df.index)

# refilled here instead of being left as empty lists
noncdr3_counts_per_complex = all_complexes[:, is_noncdr3_col].sum(axis=1).tolist()
cdr3_counts_per_complex = all_complexes[:, ~is_noncdr3_col].sum(axis=1).tolist()

# (complex id, count) pairs, in cbin_df row order
noncdr3_counts_with_ids = list(zip(complex_ids, noncdr3_counts_per_complex))
cdr3_counts_with_ids = list(zip(complex_ids, cdr3_counts_per_complex))

# convert to DataFrames
noncdr3_df_with_ids = pd.DataFrame(noncdr3_counts_with_ids, columns=['Complex_ID', 'NonCDR3_Count'])
cdr3_df_with_ids = pd.DataFrame(cdr3_counts_with_ids, columns=['Complex_ID', 'CDR3_Count'])

noncdr3_df_with_ids

In [None]:
# Find the complexes with the most contacts: rank both tables by count,
# highest first, then keep the first ten rows of each.
noncdr3_df_sorted = noncdr3_df_with_ids.sort_values('NonCDR3_Count', ascending=False)
cdr3_df_sorted = cdr3_df_with_ids.sort_values('CDR3_Count', ascending=False)

top10_noncdr3, top10_cdr3 = noncdr3_df_sorted.head(10), cdr3_df_sorted.head(10)

Distribution of Non-CDR3 and CDR3 Contacts by Position

In [None]:
# Contact count for each position = sum of that position's column over all
# complexes. (Slicing with [:i] selects rows, not a column, and starting the
# range at 1 skips the first column; summing along axis 0 covers every
# column exactly once.)
contacts_per_position = all_complexes.sum(axis=0).tolist()

In [None]:
# Build, for the histograms below, one list entry per observed contact: the
# position number is appended once for every complex that contacts it.
# (A cbin_df.drop_duplicates() call used to sit here; its result was
# discarded, so it had no effect and is omitted.)

# Numeric position labels that are columns of cbin_df, in ascending order.
# They are taken from the columns themselves; the number of rows (complexes)
# says nothing about which positions exist.
# NOTE(review): insertion-coded columns such as "111A" are left out, as
# before, because they have no place on a numeric axis -- confirm this is
# intended.
positions = sorted((label for label in cbin_df.columns if str(label).isdigit()), key=int)

noncdr3_contact_positions = []
cdr3_contact_positions = []

# set for constant-time membership tests
_noncdr3_positions = set(arr_noncdr3_pos)

for label in positions:
    pos = int(label)
    # number of complexes with a contact (value 1.0) at this position
    n_contacts = int((cbin_df[label] == float(1)).sum())
    if pos in _noncdr3_positions:
        noncdr3_contact_positions.extend([pos] * n_contacts)
    else:
        cdr3_contact_positions.extend([pos] * n_contacts)


In [None]:
# Two histograms of contacted positions. Each figure plots the list that
# matches its title and output file name.

# plot the distribution of contacts per non-CDR3 position
plt.figure(figsize=(10, 6))
sns.histplot(x=noncdr3_contact_positions, kde=True, color="b")

plt.xlabel("Position")
plt.ylabel("Number of Antigen Contacts") # across all complexes
plt.title("Distribution of Antigen-Contacting Non-CDR3 Residue Positions")
plt.tight_layout()
plt.savefig("/n/groups/marks/projects/public_epitope/outputs/figures/distribution_antigen_contacting_noncdr3_by_position.png")
plt.show()

# plot the distribution of contacts per CDR3 position
plt.figure(figsize=(10, 6))
sns.histplot(x=cdr3_contact_positions, kde=True, color="b")

plt.xlabel("Position")
plt.ylabel("Number of Antigen Contacts")
plt.title("Distribution of Antigen-Contacting CDR3 Residue Positions")
plt.tight_layout()
plt.savefig("/n/groups/marks/projects/public_epitope/outputs/figures/distribution_antigen_contacting_cdr3_by_position.png")
plt.show()