In [None]:
# Import general packages
import math
import sys
import os
import re
import glob
import json
import gzip
import itertools
from tqdm.auto import tqdm
import copy
from pathlib import PurePath
from collections import Counter

# Parallel processing
from joblib import Parallel, delayed

# import data and math packages
import numpy as np
import pandas as pd
from pandarallel import pandarallel
 
import networkx as nx
import matplotlib.pyplot as plt

from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from abnumber import Chain
from abnumber.exceptions import ChainParseError, MultipleDomainsChainParseError

# working with structures
import prody as pr
import pymol

import requests
# Working with sequences
import pyfastx
import swalign
# Smith-Waterman local aligner (swalign), built once when this cell runs.
# choose your own values here… 2 and -1 are common.
# No gap penalties are passed to LocalAlignment here.
# NOTE(review): `sw` is not referenced anywhere in the visible part of this notebook — confirm it is still needed.
match = 2  # score for a matching position
mismatch = -1  # score for a mismatching position
scoring = swalign.NucleotideScoringMatrix(match, mismatch)
sw = swalign.LocalAlignment(scoring)  # you can also choose gap penalties, etc...

import torch
from allennlp.commands.elmo import ElmoEmbedder
from pathlib import Path

# Build the ElmoEmbedder used later (as `seqvec`) to embed CDR sequences.
# NOTE(review): the model directory is a hard-coded absolute path and the embedder is
# pinned to CUDA device 0 — both must be adapted on another machine.
model_dir = Path('/nfs/baron1/nolde/zhalevsky/uniref50_v2/')  # Seqvec data
weights = model_dir / 'weights.hdf5'  # model weights file
options = model_dir / 'options.json'  # model options file
seqvec  = ElmoEmbedder(options,weights,cuda_device=0) # cuda_device=-1 for CPU

import Bio.PDB 
import Bio.PDB.ccealign
from Bio.PDB.ccealign import run_cealign

# Part 3

In [None]:
# Read the database produced by part 2
# Read the Excel file with the Abs' epitope assignments
# Calculate the CDRs of the Abs from the Excel file
# Filter the Excel file by unique CDRs
# Remove entries with the _Omi suffix in the epitope label
# Find the intersection of the pdb database and the filtered Excel file
# Compare the list of residues in contact with the RBD against the contacts specific to each epitope
# Exclude Abs with <2 epitope-specific contacts
# Create a pymol file for Abs epitope visualization
# Manually reassign the Abs' epitope class based on the pymol visualization
# Create a pymol file for Abs epitope visualization (after reassignment)
# Calculate centers (3 points for each epitope class) for every epitope class
# Assign a class to each Ab from the covadab database
# Output the database ready for the RBD-AIM software

In [None]:
# I/O file locations
# The I/O directory can be redirected without editing the notebook by setting the
# COVIDMAP_DATADIR environment variable; when it is unset the original path is used.
datadir = os.environ.get('COVIDMAP_DATADIR',
                         '/nfs/baron1/nolde/zhalevsky/covidmap_v16')  # path to I/O directory
dbfname = 'covidab_pdb_rbd_p4.json'                  # input file (output of part 2)
adddbf = '41586_2022_4980_MOESM3_ESM.xlsx'           # xls file from supplement of https://doi.org/10.1038/s41586-022-04980-y
ref_rbd = 'ref.pdb'                                  # reference RBD structure from pdb 7LOP
ref_fn = str(PurePath(datadir, ref_rbd))
ref_ace = 'ace2.pdb'                                 # Ace2 structure of Ace2-RBD complex (6m0j) fitted to refrbd
ace2_fn = str(PurePath(datadir, ref_ace))
pse1 = str(PurePath(datadir, 'df_filt.pse'))         # Pymol pse file with class of database intersection
pse2 = str(PurePath(datadir, 'df_filt2.pse'))        # Pymol pse file with class of database intersection after class reassignment
fn_centers = str(PurePath(datadir, 'centers_all.npy'))       # numpy array of cluster centers
fn_covadadb = str(PurePath(datadir, 'covadab_classes.json')) # output covadab part of database
fn_combo = str(PurePath(datadir, 'combo_classes.json')) # output database for RBD-AIM
psefin = str(PurePath(datadir, 'df_all.pse'))        # Pymol pse file with class of covadab database

In [None]:
# Part-2 database (JSON) and the supplementary Excel table (two header rows, '--' read as missing).
covadab = pd.read_json(PurePath(datadir, dbfname))
jul = pd.read_excel(
    PurePath(datadir, adddbf),
    header=[0, 1],
    na_values='--',
)
# Collapse the two header levels into one space-separated, stripped column name.
flat_names = []
for header_parts in jul.columns.values:
    flat_names.append(' '.join(header_parts).strip())
jul.columns = flat_names

In [None]:
# For each label fragment, take the first column of `jul` whose name contains it.
kname, hchain, lchain, kepit, kid = [
    jul.filter(like=fragment).keys()[0]
    for fragment in ('Name', 'Heavy chain AA', 'Light chain AA', 'Epitope', 'ID')
]

In [None]:
# Keep only the rows that have both a heavy- and a light-chain sequence.
missing_chain = jul[hchain].isnull() | jul[lchain].isnull()
jul1 = jul.loc[~missing_chain].copy()
print(len(jul1))

In [None]:
def calc_cdrhl(row, hid, lid):
    """Return the six CDR sequences of the heavy/light chain pair stored in `row`.

    Both chains are parsed with abnumber's ``Chain`` using the IMGT scheme,
    restricted to human species and without germline assignment.

    Parameters
    ----------
    row : mapping-like (e.g. a DataFrame row)
        Must provide the chain sequences under the keys `hid` and `lid`.
    hid, lid : hashable
        Keys of the heavy-chain and light-chain sequence in `row`.

    Returns
    -------
    tuple
        ``(CDRH1, CDRH2, CDRH3, CDRL1, CDRL2, CDRL3)``. When either chain cannot
        be parsed, the two input sequences are printed and a tuple of six ``None``
        is returned. The fixed width matters: the caller expands the result into
        six columns, which a bare ``None`` would break.
    """
    try:
        chainh = Chain(row[hid], scheme='imgt', allowed_species='human', assign_germline=False)
        chainl = Chain(row[lid], scheme='imgt', allowed_species='human', assign_germline=False)
    except ChainParseError:
        print(row[hid], row[lid])
        return (None,) * 6
    return (chainh.cdr1_seq, chainh.cdr2_seq, chainh.cdr3_seq,
            chainl.cdr1_seq, chainl.cdr2_seq, chainl.cdr3_seq)

In [None]:
# Compute the six CDR columns of jul1 row by row with calc_cdrhl.
# NOTE(review): pandarallel is initialised here, but the apply below goes through tqdm's
# progress_apply, not pandarallel's parallel_apply — confirm whether a parallel run was intended.
pandarallel.initialize(progress_bar=True)
tqdm.pandas()
jul1[['CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3']] = jul1.progress_apply(calc_cdrhl, axis=1, 
                                                        args=[hchain, lchain], result_type='expand')

In [None]:
# All rows of jul1 that share chain sequences and CDRs with at least one other row.
dup_subset = [hchain, lchain, 'CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3']
is_dup = jul1.duplicated(subset=dup_subset, keep=False)
dups = jul1[is_dup]
print(len(dups))

In [None]:
# Display name, epitope and heavy-chain CDRs of the duplicated rows, ordered by the heavy-chain CDRs.
dups.sort_values(by = ['CDRH1', 'CDRH2', 'CDRH3'])[[kname, kepit,  'CDRH1', 'CDRH2', 'CDRH3']]

In [None]:
def _has_single_epitope(group):
    """True when the group carries exactly one distinct (non-null) epitope label."""
    return len(group[kepit].value_counts()) == 1


# Drop every heavy/light chain pair whose rows disagree on the epitope.
jul2 = jul1.groupby([hchain, lchain]).filter(_has_single_epitope)
print(len(jul2), len(jul1))

In [None]:
# Rows of jul2 that still share chain sequences and CDRs with at least one other row.
dup_subset = [hchain, lchain, 'CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3']
is_dup = jul2.duplicated(subset=dup_subset, keep=False)
dups = jul2[is_dup]
print(len(dups))

In [None]:
# Display name, epitope and heavy-chain CDRs of the remaining duplicated rows, ordered by the heavy-chain CDRs.
dups.sort_values(by = ['CDRH1', 'CDRH2', 'CDRH3'])[[kname, kepit, 'CDRH1', 'CDRH2', 'CDRH3' ]]

In [None]:
# Keep the first row of every group with identical chain sequences and CDRs.
# jul2 is rebound instead of being modified in place: it was derived from jul1 by
# filtering, and in-place edits of such derived frames can trigger pandas'
# SettingWithCopyWarning.
jul2 = jul2.drop_duplicates(
    subset=[
        hchain, lchain,
        'CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3'
    ],
    keep='first')

In [None]:
# One row per unique set of six CDRs (the first occurrence is kept).
cdr_subset = ['CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3']
jul3 = jul2.drop_duplicates(subset=cdr_subset, keep='first')

In [None]:
# Display the rows whose epitope label contains '_Omi' (they are removed in the next step).
jul3[jul3[kepit].str.contains('_Omi')]

In [None]:
# Remove the rows whose epitope label contains '_Omi'.
is_omi = jul3[kepit].str.contains('_Omi')
jul4 = jul3.loc[~is_omi].copy()
print(len(jul4))

In [None]:
# Sorted list of the distinct epitope labels; POS_class is a label's position in that list.
Epit = sorted(set(jul4[kepit]))
class_of_label = {label: position for position, label in enumerate(Epit)}
jul4['POS_class'] = [class_of_label[label] for label in jul4[kepit]]


In [None]:
# Attach the Excel-derived POS_class to the part-2 entries that have exactly the same six CDRs.
gvecs1 = covadab
merge_keys = ['CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3']
df_merge = gvecs1.merge(
    jul4[merge_keys + ['POS_class']],
    on=merge_keys,
    how='inner',
)

In [None]:
# Number of entries matched by CDRs versus the size of the part-2 database.
print(len(df_merge), len(gvecs1))

In [None]:
# Number of matched entries per POS_class value 0..11, then the names of the entries in class 8.
for i in range(12):
    print(i, len(df_merge[df_merge['POS_class']==i]))
print(df_merge[df_merge['POS_class']==8]['Name'])

In [None]:
def dublen(x1, x2):
    """Return the number of distinct items that occur in both `x1` and `x2`."""
    return len(set(x1) & set(x2))

# Residues regarded as specific for each class; the position in the list is the POS_class value.
Epit_cont = [(417, 456, 475), (485, 486, 487), (452, 484, 490), (346, 444, 446, 452), (444, 446, 499),
            (339, 345, 346), (346, 348, 452), (356, 357, 468), (357, 391, 462, 516, 518), (383, 385, 386),
            (376, 378, 408), (405, 503, 504)]
# Collect the names of entries whose 'cont' list shares fewer than 2 residues with
# the residue set of their own class.
excl_names = []
for i, class_residues in enumerate(Epit_cont):
    cur = df_merge[df_merge['POS_class']==i]
    for ind, row in cur.iterrows():
        if dublen(row['cont'], class_residues) < 2:
            print(i, row['Name'], row['pdbid'], row['cont'])
            excl_names.append(row['Name'])
# .copy() makes df_filt an independent frame: a later cell assigns into it with .loc,
# which on a plain boolean-mask slice raises pandas' SettingWithCopyWarning.
df_filt = df_merge[~df_merge['Name'].isin(excl_names)].copy()
len(df_filt), len(df_merge)

In [None]:
# PyMOL session for the filtered entries: every antibody becomes one object of three
# bonded pseudo-atoms (PS0 - PS1 - PS2) coloured by its POS_class.
allcl = ('A', 'B', 'C', 'D1', 'D2', 'E1', 'E2.1', 'E2.2', 'E3', 'F1', 'F2', 'F3')

colors = ['0x005a60', '0x2fbead', '0xfe5e44', '0xd1a684', '0xfde74c', '0xfe7d0e',
          '0xe29462', '0x4a4a4a', '0x909393', '0x2daaf0', '0x7565ff', '0xd458fb']

pymol.cmd.reinitialize()
for structure_fn in (ref_fn, ace2_fn):
    pymol.cmd.load(structure_fn)
pymol.cmd.color('grey70')
for i, color in enumerate(colors):
    df = df_filt[df_filt['POS_class']==i]
    for ind, row in df.iterrows():
        abname = row['Name']
        ab = f'{abname}_{i}'
        # pseudo-atoms are created in the order PS2, PS1, PS0
        for bead, column in (('PS2', 'ab_ca'), ('PS1', 'cont_cdr'), ('PS0', 'cont_rbd')):
            pos = row[column]
            posString = "[%3.2f,%3.2f,%3.2f]" % (pos[0], pos[1], pos[2])
            pymol.cmd.pseudoatom(ab, name=f'{bead}_{ab}', vdw=1.0, pos=posString)
        pymol.cmd.bond(atom1=f"name PS0_{ab}", atom2=f"name PS1_{ab}")
        pymol.cmd.bond(atom1=f"name PS1_{ab}", atom2=f"name PS2_{ab}")
        pymol.cmd.color(color, ab)
        pymol.cmd.show('spheres', ab)

pymol.cmd.save(pse1)
   

In [None]:
# Manual class reassignment based on pse file
manual_class = {
    'CR3022': 9,
    'S2A4': 9,
    'CV2-75': 9,
    'S2X35': 10,
    'BD667': 6,
    'NT-193': 0,
    'XGv282': 3,
    'BG1-24': 2,
    'BG7-20': 2,
    'Fab-15033': 0,
    'BD-667': 6,
}
for ab_name, new_class in manual_class.items():
    df_filt.loc[df_filt['Name']==ab_name, 'POS_class'] = new_class
In [None]:
# PyMOL session for the filtered entries with their current POS_class: every antibody
# becomes one object of three bonded pseudo-atoms (PS0 - PS1 - PS2) coloured by class.
allcl = ('A', 'B', 'C', 'D1', 'D2', 'E1', 'E2.1', 'E2.2', 'E3', 'F1', 'F2', 'F3')

colors = ['0x005a60', '0x2fbead', '0xfe5e44', '0xd1a684', '0xfde74c', '0xfe7d0e',
          '0xe29462', '0x4a4a4a', '0x909393', '0x2daaf0', '0x7565ff', '0xd458fb']

pymol.cmd.reinitialize()
for structure_fn in (ref_fn, ace2_fn):
    pymol.cmd.load(structure_fn)
pymol.cmd.color('grey70')
for i, color in enumerate(colors):
    df = df_filt[df_filt['POS_class']==i]
    for ind, row in df.iterrows():
        abname = row['Name']
        ab = f'{abname}_{i}'
        # pseudo-atoms are created in the order PS2, PS1, PS0
        for bead, column in (('PS2', 'ab_ca'), ('PS1', 'cont_cdr'), ('PS0', 'cont_rbd')):
            pos = row[column]
            posString = "[%3.2f,%3.2f,%3.2f]" % (pos[0], pos[1], pos[2])
            pymol.cmd.pseudoatom(ab, name=f'{bead}_{ab}', vdw=1.0, pos=posString)
        pymol.cmd.bond(atom1=f"name PS0_{ab}", atom2=f"name PS1_{ab}")
        pymol.cmd.bond(atom1=f"name PS1_{ab}", atom2=f"name PS2_{ab}")
        pymol.cmd.color(color, ab)
        pymol.cmd.show('spheres', ab)

pymol.cmd.save(pse2)

In [None]:
def center_pos(X, keep_frac=0.67, max_iter=10):
    """Return a trimmed centre of the points in `X`, ignoring roughly 1/3 of outliers.

    Starting from the mean of all points, the centre is repeatedly recomputed as
    the mean of the ``ceil(keep_frac * N)`` points closest to the current centre
    (ties at the cut-off distance are all kept) until it stops moving or
    `max_iter` refinements have been made.

    Parameters
    ----------
    X : array-like of shape (N, D)
        Point coordinates, one point per row.
    keep_frac : float, default 0.67
        Fraction of the points kept when recomputing the centre.
    max_iter : int, default 10
        Maximum number of refinement steps.

    Returns
    -------
    numpy.ndarray of shape (D,)
        The most recently computed centre.

    Raises
    ------
    ValueError
        If `X` contains no points.
    """
    X = np.asarray(X, dtype=float)
    N = len(X)
    if N == 0:
        raise ValueError('center_pos: cannot compute the center of an empty set of points')
    # number of closest points kept; clamped so the cut-off index below is always valid
    Ncut = min(N, max(1, math.ceil(keep_frac * N)))
    Xc = np.mean(X, axis=0)
    for _ in range(max_iter):
        dist = np.linalg.norm(X - Xc, axis=-1)
        dcut = np.sort(dist)[Ncut - 1]
        Xc_new = np.mean(X[dist <= dcut], axis=0)
        converged = np.linalg.norm(Xc - Xc_new) == 0
        # always keep the latest estimate, so a run that exhausts max_iter does not
        # return a centre that is one refinement behind
        Xc = Xc_new
        if converged:
            break
    return Xc

In [None]:
# Trimmed centre of every class for each of the three anchor points.
# allcenter[i, e] is the centre of class e for anchor i (0: 'ab_ca', 1: 'cont_cdr', 2: 'cont_rbd').
# The centres are computed from df_filt, i.e. after the contact-based exclusion and the
# manual class reassignment. df_merge still carries the unfiltered, un-reassigned
# labels, so using it here would make both curation steps irrelevant for the output.
allcenter = np.empty((3, 12, 3))
for i, posatom in enumerate(('ab_ca', 'cont_cdr', 'cont_rbd')):
    for e in range(12):
        df_e = df_filt[df_filt['POS_class']==e]
        if df_e.empty:
            raise ValueError(f'no entries left in class {e}; cannot compute its {posatom} center')
        _cmat = df_e[posatom].to_list()
        allcenter[i, e] = center_pos(np.array(_cmat))
np.save(fn_centers, allcenter)

In [None]:
# Placeholder values: -1 for the three per-anchor class columns, 0 for the agreement flag.
for class_column in ('Class_CA', 'Class_cdr', 'Class_rbd'):
    covadab[class_column] = -1
covadab['good'] = 0

In [None]:
def setclass(X, center):
    """Return the index of the entry of `center` with the smallest Euclidean distance to `X`."""
    offsets = X - center
    distances = np.linalg.norm(offsets, axis=-1)
    return np.argmin(distances)

In [None]:
# Assign every covadab entry to the nearest class centre, separately for each anchor
# point (allcenter[0]: 'ab_ca', allcenter[1]: 'cont_cdr', allcenter[2]: 'cont_rbd').
# NOTE(review): 'POS_class' is not initialised together with the other class columns;
# if covadab has no such column yet, it is created by the first .at assignment — confirm the resulting dtype is acceptable.
for ind, row in covadab.iterrows():
    covadab.at[ind, 'Class_CA'] = setclass(np.array(row['ab_ca']), allcenter[0])
    covadab.at[ind, 'Class_cdr'] = setclass(np.array(row['cont_cdr']), allcenter[1])
    covadab.at[ind, 'Class_rbd'] = setclass(np.array(row['cont_rbd']), allcenter[2])
    # the class obtained from the 'cont_rbd' point is used as the entry's POS_class
    covadab.at[ind, 'POS_class'] = int(covadab.at[ind, 'Class_rbd'])
    # 'good' is set to 1 only when all three anchor points give the same class
    if (covadab.at[ind, 'Class_CA'] == covadab.at[ind, 'Class_cdr']) & \
    (covadab.at[ind, 'Class_CA'] == covadab.at[ind, 'Class_rbd']):
        covadab.at[ind, 'good'] = 1

In [None]:
# Number of entries whose three per-anchor classes agree versus the total number of entries.
print(len(covadab[covadab['good']==1]), len(covadab))

In [None]:
# Number of covadab entries per assigned POS_class, printed in ascending class order.
counter = Counter(covadab['POS_class']) 
print(sorted(counter.items()))

In [None]:
# Left-join jul4 with the CDR sets present in df_merge; the '_merge' indicator column
# tells which jul4 rows have no counterpart there ('left_only').
cdr_join_keys = ['CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3']
jul_new = jul4.merge(
    df_merge[cdr_join_keys],
    on=cdr_join_keys,
    how='left',
    suffixes=(None, '_y'),
    indicator=True,
)

In [None]:
# Number of jul4 rows that have no CDR match in df_merge.
print(len(jul_new[jul_new['_merge'] == 'left_only']))

In [None]:
# Rows of the Excel table without a CDR match in df_merge; their CDRs are embedded with seqvec.
gvecs = jul_new[jul_new['_merge'] == 'left_only'].copy()
cdrkeys = [f'CDR{c}{i}' for c, i in itertools.product(['H', 'L'], range(1, 4))]

# Placeholder columns for the per-CDR and the concatenated embeddings.
for k in cdrkeys:
    gvecs[f'e{k}'] = None

gvecs['eCDRH'] = None
gvecs['eCDRL'] = None

gvecs['eCDRHL'] = None
gvecs['eCDRHL3'] = None


for index, row in tqdm(gvecs.iterrows(), total=len(gvecs)):

    # Sequences are sent to the embedder ordered by length (stable, so ties keep the
    # cdrkeys order) and are paired with their CDR key by position. Keying the results
    # by sequence text instead would merge two CDRs that happen to have the same
    # sequence and the whole row would then be skipped.
    ordered_keys = sorted(cdrkeys, key=lambda key: len(row[key]))
    seqs = [list(row[key]) for key in ordered_keys]

    embedding = seqvec.embed_sentences(seqs) # returns: List-of-Lists with shape [3,L,1024]
    embeds = {}
    for k, embed_ in zip(ordered_keys, embedding):
        # sum over the first axis, then average over the next one -> one vector per CDR
        embed__ = torch.tensor(embed_).sum(dim=0).mean(dim=0)
        embed = embed__.cpu().detach().numpy()
        embeds[k] = embed

        gvecs.at[index, f'e{k}'] = embed

    # leave the concatenated columns empty when not every CDR received an embedding
    if len(embeds) != len(cdrkeys):
        continue

    gvecs.at[index, 'eCDRH'] = np.hstack([
        embeds['CDRH1'],
        embeds['CDRH2'],
        embeds['CDRH3']
    ])

    gvecs.at[index, 'eCDRL'] = np.hstack([
        embeds['CDRL1'],
        embeds['CDRL2'],
        embeds['CDRL3']
    ])

    gvecs.at[index, 'eCDRHL'] = np.hstack([
        gvecs.at[index, 'eCDRH'],
        gvecs.at[index, 'eCDRL'],
    ])

    gvecs.at[index, 'eCDRHL3'] = np.hstack([
        embeds['CDRH3'],
        embeds['CDRL3']])

In [None]:
# Entries that come only from the Excel table are flagged with good = -1.
gvecs['good'] = -1
# 'good' and the per-CDR embedding columns are kept in the selection: the output table
# built from combo_classes selects them, and leaving them out here would discard both the
# flag set above and the embeddings computed for these rows.
keep_cols = [kname, hchain, lchain, 'POS_class',
             'CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3',
             'good',
             'eCDRH1', 'eCDRH2', 'eCDRH3', 'eCDRL1', 'eCDRL2', 'eCDRL3']
gvecs = gvecs[keep_cols].rename(
    columns={kname: 'Name', hchain: 'VHorVHH', lchain: 'VL'})
# covadab rows come first, so for identical CDR sets the covadab entry is the one kept.
combo_classes = pd.concat((covadab, gvecs), axis=0).drop_duplicates(
    subset=['CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3'],
    keep='first').reset_index()

In [None]:
# Number of rows in the combined table.
len(combo_classes)

In [None]:
# Column names of the combined table.
combo_classes.keys()

In [None]:
# All rows of the combined table whose Name occurs more than once.
name_repeated = combo_classes.duplicated(subset=['Name'], keep=False)
dups = combo_classes[name_repeated]
print(len(dups))

In [None]:
# Number of names, number of rows and number of distinct names in the combined table.
len(combo_classes['Name'].to_list()), len(combo_classes), len(set(combo_classes['Name'].to_list()))

In [None]:
# Show name, CDRs and class of the rows with a repeated Name.
print(dups[['Name', 'CDRL1', 'CDRL2', 'CDRL3', 'CDRH1', 'CDRH2', 'CDRH3', 'POS_class']])

In [None]:
# Keep only the first row for every Name.
combo_classes = combo_classes.drop_duplicates(subset=['Name'])
print(len(combo_classes))

In [None]:
# Output table: chain columns renamed to Heavy/Light and restricted to the listed columns.
output_columns = [
    'index', 'Name', 'Heavy', 'Light', 'POS_class',
    'CDRH1', 'CDRH2', 'CDRH3', 'CDRL1', 'CDRL2', 'CDRL3',
    'good',
    'eCDRH1', 'eCDRH2', 'eCDRH3', 'eCDRL1', 'eCDRL2', 'eCDRL3',
]
renamed = combo_classes.rename(columns={'VHorVHH': 'Heavy', 'VL': 'Light'})
newdf = renamed[output_columns]

In [None]:
# Write the classified covadab table and the combined output table as JSON files.
covadab.to_json(fn_covadadb)
newdf.to_json(fn_combo)

In [None]:
# PyMOL session for all covadab entries: every antibody becomes one object of three
# bonded pseudo-atoms (PS0 - PS1 - PS2) coloured by its assigned POS_class.
allcl = ('A', 'B', 'C', 'D1', 'D2', 'E1', 'E2.1', 'E2.2', 'E3', 'F1', 'F2', 'F3')

colors = ['0x005a60', '0x2fbead', '0xfe5e44', '0xd1a684', '0xfde74c', '0xfe7d0e',
          '0xe29462', '0x4a4a4a', '0x909393', '0x2daaf0', '0x7565ff', '0xd458fb']

pymol.cmd.reinitialize()
for structure_fn in (ref_fn, ace2_fn):
    pymol.cmd.load(structure_fn)
pymol.cmd.color('grey70')
for i, color in enumerate(colors):
    df = covadab[covadab['POS_class']==i]
    for ind, row in df.iterrows():
        abname = row['Name']
        ab = f'{abname}_{i}'
        # pseudo-atoms are created in the order PS2, PS1, PS0
        for bead, column in (('PS2', 'ab_ca'), ('PS1', 'cont_cdr'), ('PS0', 'cont_rbd')):
            pos = row[column]
            posString = "[%3.2f,%3.2f,%3.2f]" % (pos[0], pos[1], pos[2])
            pymol.cmd.pseudoatom(ab, name=f'{bead}_{ab}', vdw=1.0, pos=posString)
        pymol.cmd.bond(atom1=f"name PS0_{ab}", atom2=f"name PS1_{ab}")
        pymol.cmd.bond(atom1=f"name PS1_{ab}", atom2=f"name PS2_{ab}")
        pymol.cmd.color(color, ab)
        pymol.cmd.show('spheres', ab)

pymol.cmd.save(psefin)