In [2]:
!pip3 install biopython pandas torch matplotlib seaborn tqdm



In [1]:
from Bio.PDB import PDBParser
import os, sys
import pandas as pd
import numpy as np
from collections import defaultdict
import torch
import warnings
import time as time
import matplotlib.pyplot as plt
import seaborn as sns
sys.path.append('/n/groups/marks/projects/sequence_structure/scripts')
parent_dir = '/n/groups/marks/projects/sequence_structure/sabdab'
sys.path.append(parent_dir)
from pdb_utils import *
from structure_utils import arrange_idx
import torch
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

In [2]:
# antibodies composed of different regions: framework (FR) and complementarity-determining (CDR) regions.
# FRs and CDRs are define as sequences of numbers converted to strings

# FR1 is defined as positions 1 to 26
FR1 = [str(n) for n in list(np.arange(1,26+1))]
CDR1 = [str(n) for n in list(np.arange(27,38+1))]
FR2 = [str(n) for n in list(np.arange(39,55+1))]
CDR2 = [str(n) for n in list(np.arange(56,65+1))]
FR3 = [str(n) for n in list(np.arange(66,104+1))]
# CDR3 has additional specific positions indicating insertions in the antibody sequence
CDR3 = [str(n) for n in list(np.arange(105,117+1))] + ["111A","111B","111C","111D","111E","111F","112A","112B","112C","112D","112E","112F","112G"]
FR4 = [str(n) for n in list(np.arange(118,129+1))]

FRs = FR1 + FR2 + FR3 + FR4
CDRs = CDR1 + CDR2 + CDR3

In [3]:
df = pd.DataFrame([["FR1",FR1],
                   ["CDR1",CDR1],
                   ["FR2",FR2],
                   ["CDR2",CDR2],
                   ["FR3",FR3],
                   ["CDR3",CDR3],
                   ["FR4",FR4]],
                  columns=["ab_loc","imgt_pos"])

# dictionary for quick lookup of region type (FR or CDR) based on a given position
df = df.set_index("ab_loc")
ab_loc_dict = dict(zip(df["imgt_pos"].explode(),df["imgt_pos"].explode().index))

In [4]:
# standardizing the format of unique identifiers in the dataset

def fix_uq_id(s):
    exp = s.str.split('_', expand=True)
    exp_u = exp.iloc[:,1:].apply(lambda x: x.str.upper())
    sfix = exp[[0]].join(exp_u)
    sfixed = sfix.apply('_'.join, axis=1)
    return (sfixed)

In [5]:
# import and preprocess dataset of antibody interactions, focusing on H chain type ond date 2022_12_21

#import the list of interaction summary
DATE = '2022_12_21'
chain_type = 'H'

pdb_repo = os.path.join(parent_dir, f'{DATE}/sabdab_dataset')
int_fnm = f'{DATE}/summary_tables/{chain_type}_fil_contact_interactions.csv'
int_df = pd.read_csv(os.path.join(parent_dir, int_fnm))
int_df['uq_id'] = fix_uq_id(int_df['uq_id'])
int_df = int_df.set_index('uq_id')


In [6]:
# investigating int_df
int_df
int_df.drop_duplicates().shape

(2088, 14)

In [7]:
def _construct_profile(tup):
    pdb_repo, pdb, chain, antchain, chain_type = tup
    uq_id = f'{pdb}_{chain_type}({chain})_ant({antchain})'
    #load the contacts file
    fnm = os.path.join(pdb_repo, pdb, f'{chain_type}({chain})_ant({antchain})_contact_cnts.pt')
    if os.path.isfile(fnm):
        cnts_ld = torch.load(fnm)
        #sum each count matrix for each imgt in the count file
        cnt_num = {k:np.sum(v) for k,v in cnts_ld.items()}
        cnt_num = pd.Series(cnt_num, dtype='float64')
        cnt_num['uq_id'] = uq_id
        return (cnt_num)
    else:
        print (f'{fnm} does not have contacts counted')
        return (None)
    
#set up an iterator for easy running
N = int_df.shape[0]
iterable = zip([pdb_repo]*N, int_df['pdb'], int_df[f'{chain_type}chain'], int_df['antigen_chain'], [chain_type]*N)
    
construct_list = process_map(_construct_profile, iterable)
construct_list = [c for c in construct_list if c is not None]
c_df = pd.DataFrame(construct_list).fillna(0).set_index('uq_id')

#order it correctly
imgt_order, uq_imgt_df, imgt_ii_key, ii_imgt_key = arrange_idx(c_df.columns, giveorder=True)
c_df = c_df.loc[:, imgt_order]

### Get the depth of each column in the alignment

align_df = pd.read_csv(os.path.join(parent_dir, f'{DATE}/summary_tables/summary_{chain_type}_sabdab_alignment.csv'))
align_df['uq_id'] = fix_uq_id(align_df['uq_id'])
#sub out to the structures that are in the parsed dataset
s_align_df = align_df[align_df['uq_id'].isin(int_df.index)]

#calculate the depth of the alignment at each column
a_df = s_align_df.iloc[:, s_align_df.columns.str.match('\d+')]
align_depth = (a_df != '-').sum(0)
#remove cols that we don't have counts for some reason
align_depth_k = align_depth[pd.Series(align_depth.index).isin(uq_imgt_df.imgt).values]
#normalize depth
align_depth_km = align_depth_k / a_df.shape[0]

#sub out the matrix for which we have numbering in the sequence alignment
indexer = uq_imgt_df.reset_index(drop=False)
indexer = indexer.set_index('imgt')
keep_idx = indexer.loc[align_depth_k.index, 'index']

### Remove columns with low depth

d_tr = 0.01 #depth threshold - changed, AH
#sub out and normalize the matrix
new_idx = keep_idx[align_depth_km > d_tr]
cfil_df = c_df.loc[:,new_idx.index]
cbin_df = cfil_df.copy()
cbin_df[cbin_df > 0] = 1.0

/n/groups/marks/projects/sequence_structure/sabdab/2022_12_21/sabdab_dataset/7mtb/H(H)_ant(G)_contact_cnts.pt does not have contacts counted


0it [00:00, ?it/s]

/n/groups/marks/projects/sequence_structure/sabdab/2022_12_21/sabdab_dataset/6pi7/H(F)_ant(D)_contact_cnts.pt does not have contacts counted
/n/groups/marks/projects/sequence_structure/sabdab/2022_12_21/sabdab_dataset/3ehb/H(C)_ant(B)_contact_cnts.pt does not have contacts counted


In [8]:
#compile the metadata features to look at

#import all the metadata for the interactions
sum_fnm = f'{DATE}/sabdab_summary_all_{DATE}.tsv'
sum_df = pd.read_csv(os.path.join(parent_dir, sum_fnm), sep='\t')
sum_df = sum_df.fillna('')
sum_df['uq_id'] = sum_df['pdb'] + '_' + chain_type + '(' + sum_df[f'{chain_type}chain'] + ')_ant(' + sum_df['antigen_chain'] + ')'
sum_df = sum_df.set_index('uq_id')
sum_df = sum_df[~sum_df.index.duplicated()]

#import the antibody types
atype_fnm = os.path.join(parent_dir, 'recourses/metadata_files/Atype_summary.tsv')
atype_df = pd.read_csv(atype_fnm).fillna('')
atype_df['uq_id'] = atype_df['pdb'] + '_' + chain_type + '(' + atype_df[f'{chain_type}chain'] + ')_ant(' + atype_df['antigen_chain'] + ')'
atype_df['full_id'] = atype_df['pdb'] + '_' + atype_df['Hchain'] + '_' + atype_df['Lchain'] + '_' + atype_df['antigen_chain']
atype_df = atype_df.set_index('uq_id')
atype_df = atype_df[~atype_df.index.duplicated()]

#compile with the interaction ordering
meta_cols = ['antigen_type', 'heavy_species', 'light_species', 'heavy_subclass',
             'light_subclass', 'light_ctype']

met_df = sum_df.loc[cbin_df.index, meta_cols]
met_df['antibody_type'] = atype_df.reindex(met_df.index).loc[:,'antibody_type']

In [9]:
cbin_met_df = met_df.merge(cbin_df, on="uq_id")
cbin_met_df["FR_contacts"] = cbin_met_df[[pos for pos in FRs if pos in cbin_df.columns]].sum(axis=1)

In [10]:
print(cbin_met_df[cbin_met_df["antibody_type"] == "Fv"]["FR_contacts"].mean(), cbin_met_df[cbin_met_df["antibody_type"] == "Fv"]["FR_contacts"].median())
# print(cbin_met_df[cbin_met_df["antibody_type"] == "VHH"]["FR_contacts"].mean(), cbin_met_df[cbin_met_df["antibody_type"] == "VHH"]["FR_contacts"].median())
# print(cbin_met_df[cbin_met_df["antibody_type"] == "scFv"]["FR_contacts"].mean(), cbin_met_df[cbin_met_df["antibody_type"] == "scFv"]["FR_contacts"].median())

1.9124260355029585 1.0


In [11]:
cbin_df_contact_counts = cbin_df.sum().sort_values(ascending=False)
cbin_df_contact_counts = pd.DataFrame(cbin_df_contact_counts, columns=["contact_counts"])

cbin_df_contact_counts["ab_loc"] = ""
cbin_df_contact_counts.loc[cbin_df_contact_counts.index.isin(FR1), "ab_loc"] = "FR1"
cbin_df_contact_counts.loc[cbin_df_contact_counts.index.isin(CDR1), "ab_loc"] = "CDR1"
cbin_df_contact_counts.loc[cbin_df_contact_counts.index.isin(FR2), "ab_loc"] = "FR2"
cbin_df_contact_counts.loc[cbin_df_contact_counts.index.isin(CDR2), "ab_loc"] = "CDR2"
cbin_df_contact_counts.loc[cbin_df_contact_counts.index.isin(FR3), "ab_loc"] = "FR3"
cbin_df_contact_counts.loc[cbin_df_contact_counts.index.isin(CDR3), "ab_loc"] = "CDR3"
cbin_df_contact_counts.loc[cbin_df_contact_counts.index.isin(FR4), "ab_loc"] = "FR4"

cbin_df_contact_counts_noncdr3 = cbin_df_contact_counts[~cbin_df_contact_counts.index.isin(CDR3)]
cbin_df_contact_counts_noncdr3_gt_avg = cbin_df_contact_counts_noncdr3[cbin_df_contact_counts_noncdr3["contact_counts"] > cbin_df_contact_counts_noncdr3["contact_counts"].mean()]

# cbin_df_contact_counts_noncdr3_gt_avg[cbin_df_contact_counts_noncdr3_gt_avg["ab_loc"].str[:3] != "CDR"]

In [12]:
def get_resid(tup, imgt_pos):
    pdb_repo, pdb, chain, antchain, chain_type = tup
    
    fnm = os.path.join(pdb_repo, pdb, f'{chain_type}({chain})_ant({antchain})_contact_details.csv')
    int_details = pd.read_csv(fnm, index_col=0)
    # print(int_details)

    imgt_pos_aa = int_details[(int_details["PDB"] == pdb) & (int_details["chain_type"] == chain_type) &
                                (int_details["antibody_chain"] == chain) &
                                ((int_details["antibody_idx"] == float(imgt_pos)) | (int_details["antibody_idx"] == str(imgt_pos)))]["antibody_AA"].unique()

    
    assert len(imgt_pos_aa) <= 1
    if len(imgt_pos_aa) == 1: imgt_pos_aa = imgt_pos_aa[0]
    else: imgt_pos_aa = None

    return imgt_pos_aa

In [13]:
imgt_pos_aas_all = []

for imgt_pos in tqdm(cbin_df_contact_counts_noncdr3_gt_avg.index, total=len(cbin_df_contact_counts_noncdr3_gt_avg.index)):

    for ind, row in int_df.iterrows():
        tup = ('/n/groups/marks/projects/sequence_structure/sabdab/2022_12_21/sabdab_dataset', row["pdb"], row["Hchain"], row["antigen_chain"], chain_type)
        pdb_repo, pdb, chain, antchain, chain_type = tup
        uq_id = f'{pdb}_{chain_type}({chain})_ant({antchain})'
        try:
            imgt_pos_aa = get_resid(tup, imgt_pos)
            imgt_pos_aas_all.append([imgt_pos, uq_id, imgt_pos_aa])
        except:
            pass

100%|██████████████████████████████████████████████████████████████████████████| 22/22 [02:40<00:00,  7.31s/it]


In [14]:
imgt_pos_aas_all_df = pd.DataFrame(imgt_pos_aas_all, columns=["imgt_pos","uq_id","aa"])
imgt_pos_aas_all_df[~imgt_pos_aas_all_df["aa"].isnull()]

Unnamed: 0,imgt_pos,uq_id,aa
0,57,6fe4_H(F)_ant(A),ASN
1,57,7jmo_H(H)_ant(A),TYR
3,57,7cwn_H(K)_ant(B),SER
5,57,1nca_H(H)_ant(N),ASN
6,57,4ypg_H(H)_ant(D),SER
...,...,...,...
45834,69,4lsu_H(H)_ant(G),ARG
45837,69,7vyt_H(H)_ant(T),GLN
45846,69,5f72_H(S)_ant(K),LYS
45852,69,5umi_H(H)_ant(C),PRO


In [15]:
imgt_pos_aas_all_df.imgt_pos.value_counts()

imgt_pos
57    2085
36    2085
56    2085
67    2085
40    2085
82    2085
27    2085
52    2085
28    2085
65    2085
29    2085
63    2085
55    2085
62    2085
58    2085
35    2085
37    2085
66    2085
38    2085
59    2085
64    2085
69    2085
Name: count, dtype: int64

In [16]:
imgt_pos_w_high_contact_count_conservation = []

for imgt_pos in imgt_pos_aas_all_df["imgt_pos"].drop_duplicates():
    pos_df = imgt_pos_aas_all_df[imgt_pos_aas_all_df["imgt_pos"] == imgt_pos]

    imgt_pos_w_high_contact_count_conservation.append([imgt_pos, cbin_df_contact_counts_noncdr3_gt_avg.loc[imgt_pos]["ab_loc"],
                                                       cbin_df_contact_counts_noncdr3_gt_avg.loc[imgt_pos]["contact_counts"], pos_df["aa"].value_counts(normalize=True).iloc[0],
                                                       pos_df["aa"].value_counts(normalize=True).index[0]])

In [17]:
imgt_pos_w_high_contact_count_conservation_df = pd.DataFrame(imgt_pos_w_high_contact_count_conservation, columns=["imgt_pos","ab_loc","contact_counts","conservation_highest_pct","conservation_highest_aa"])

In [18]:
imgt_pos_w_high_contact_count_conservation_df.sort_values(by=["contact_counts"], ascending=False)

Unnamed: 0,imgt_pos,ab_loc,contact_counts,conservation_highest_pct,conservation_highest_aa
0,57,CDR2,1362.0,0.224595,TYR
1,36,CDR1,1351.0,0.380775,SER
2,64,CDR2,1333.0,0.206481,SER
3,59,CDR2,1310.0,0.225287,SER
4,38,CDR1,1135.0,0.281028,TYR
5,66,FR3,1120.0,0.345013,TYR
6,37,CDR1,1015.0,0.667988,TYR
7,35,CDR1,975.0,0.407827,SER
8,58,CDR2,879.0,0.232,SER
9,62,CDR2,877.0,0.208238,SER


In [19]:
def get_contact_info(tup, imgt_pos):
    pdb_repo, pdb, chain, antchain, chain_type = tup
    
    contact_info_file = os.path.join(pdb_repo, pdb, f'{chain_type}({chain})_ant({antchain})_contact_details.csv')
    int_details = pd.read_csv(contact_info_file, index_col=0)
    int_details["ab_aa_idx"] = int_details["antibody_AA"].astype(str) + int_details["antibody_idx"].astype(str)
    int_details["ag_aa_idx"] = int_details["antigen_AA"].astype(str) + int_details["antigen_idx"].astype(str)

    ab_aa_idx = int_details[(int_details["PDB"] == pdb) & (int_details["chain_type"] == chain_type) &
                        (int_details["antibody_chain"] == chain) &
                        ((int_details["antibody_idx"] == float(imgt_pos)) | (int_details["antibody_idx"] == str(imgt_pos)))]["ab_aa_idx"].unique()

    ag_aa_idx = int_details[(int_details["PDB"] == pdb) & (int_details["chain_type"] == chain_type) &
                        (int_details["antibody_chain"] == chain) &
                        ((int_details["antibody_idx"] == float(imgt_pos)) | (int_details["antibody_idx"] == str(imgt_pos)))]["ag_aa_idx"].to_list()

    return ab_aa_idx, ag_aa_idx

In [20]:
ag_contact_aas_all = []

for imgt_pos in tqdm(cbin_df_contact_counts_noncdr3_gt_avg.index, total=len(cbin_df_contact_counts_noncdr3_gt_avg.index)):
# for imgt_pos in tqdm(cbin_df.columns, total=len(cbin_df.columns)):
    
    for ind, row in int_df.iterrows():
        tup = ('/n/groups/marks/projects/sequence_structure/sabdab/2022_12_21/sabdab_dataset', row["pdb"], row["Hchain"], row["antigen_chain"], chain_type)
        pdb_repo, pdb, chain, antchain, chain_type = tup
        uq_id = f'{pdb}_{chain_type}({chain})_ant({antchain})'

        try:
            ab_aa, ag_aa = get_contact_info(tup, imgt_pos)
            if len(ab_aa) == 1: ab_aa = ab_aa[0]
            ag_contact_aas_all.append([uq_id, ab_aa, ag_aa])
        except:
            pass

100%|██████████████████████████████████████████████████████████████████████████| 22/22 [04:05<00:00, 11.16s/it]


In [21]:
contact_info = pd.DataFrame(ag_contact_aas_all, columns=["uq_id","ab_aa_idx","ag_aa_idx"])
contact_info = contact_info[contact_info["ag_aa_idx"].apply(lambda x: len(x)) > 0]
contact_info["ag_aa_idx"] = contact_info["ag_aa_idx"].apply(lambda arr: list(arr))

In [23]:
contact_info_expl = contact_info.explode("ag_aa_idx")
contact_info_expl = contact_info_expl[contact_info_expl["uq_id"].isin(contact_info_expl["uq_id"])]

motifs = []
for (uq_id1, ag_aa_idx1), grp1 in tqdm(contact_info_expl.groupby(["uq_id","ag_aa_idx"]), total=len(contact_info_expl.groupby(["uq_id","ag_aa_idx"]))):
    ab_motif = grp1["ab_aa_idx"].to_list()
    for (uq_id2, ag_aa_idx2), grp2 in contact_info_expl.groupby(["uq_id","ag_aa_idx"]):
        if (uq_id1 == uq_id2) & (ag_aa_idx1 == ag_aa_idx2): pass
        else:
            if ag_aa_idx1[:3] == ag_aa_idx2[:3]:
                ab_motif_intersection = set(ab_motif).intersection(set(grp2["ab_aa_idx"].to_list()))
                if len(ab_motif_intersection) > 1:
                    motifs.append([uq_id1, ag_aa_idx1[:3], ab_motif_intersection])

100%|██████████████████████████████████████████████████████████████████| 17591/17591 [1:29:25<00:00,  3.28it/s]


In [24]:
mdf = pd.DataFrame(motifs, columns=["uq_id","ag_aa","ab_motif"])

In [25]:
mdf.to_csv("all_motifs")

In [None]:
mdf = pd.DataFrame(motifs, columns=["uq_id","ag_aa","ab_motif"])
mdf["ab_motif"] = mdf["ab_motif"].apply(lambda s: sorted(list(s)))#.astype(str)
mdf
mdf.to_csv("all_motifs")
# mdf.sort_values(by=["ag_aa","ab_motif"])

# for abm, grp in mdf.groupby("ab_motif"):
#     print(grp)

# # mdf.value_counts()
# mdf[mdf.duplicated(keep=False)]

In [119]:
contact_info[contact_info["uq_id"].isin(contact_info["uq_id"].unique()[:2])]

Unnamed: 0,uq_id,ab_aa_idx,ag_aa_idx
0,6fe4_H(F)_ant(A),ASN57,"[ASP35, ASP36]"
1,7jmo_H(H)_ant(A),TYR57,"[GLY416, LYS417, ASP420, TYR421]"
2086,7jmo_H(H)_ant(A),SER36,"[LYS458, TYR473, GLN474, ALA475]"
4170,6fe4_H(F)_ant(A),GLY64,[ASP35]
4171,7jmo_H(H)_ant(A),SER64,"[THR415, GLY416, ASP420, ASN460]"
6256,7jmo_H(H)_ant(A),GLY59,"[TYR421, ARG457, LYS458, SER459, ASN460]"
8340,6fe4_H(F)_ant(A),TYR38,"[ASP35, THR37, SER50]"
8341,7jmo_H(H)_ant(A),TYR38,"[LYS417, LEU455, PHE456]"
10425,6fe4_H(F)_ant(A),ARG66,"[ASN33, ASP35, THR37, TRP48]"
10426,7jmo_H(H)_ant(A),PHE66,"[THR415, GLY416]"


In [124]:
contact_info.to_csv("all_contact_info")