In [1]:
import pandas as pd

df_train = pd.read_csv("epitope3d_dataset_180_Train.csv")

# A five-row slice of the training table serves as a small, fast worked example.
df_example = df_train.head(5)
df_example

Unnamed: 0,PDB,Epitope
0,1YBW,"A:LEU:487, A:ILE:483, A:HIS:496, A:ASP:639, A:..."
1,5X59,"A:VAL:26, A:LYS:27, A:LYS:543, A:ARG:190, A:GL..."
2,3PMT,"A:TRP:567, A:LYS:590, A:GLU:598, A:TYR:597, A:..."
3,3IT8,"C:GLU:23, C:PRO:139, B:GLU:23, B:GLN:67, B:GLU..."
4,4O38,"B:GLU:56, B:ASN:183, B:SER:62, B:GLU:124, B:GL..."


In [None]:
# ESM-IF1 embedding generation requires heteroatom (HETATM) residues to be
# stripped from the structure first. The example PDB files are already
# processed, so this step is not needed for them.
# Adapted from https://stackoverflow.com/questions/25718201/remove-heteroatoms-from-pdb
from Bio.PDB import Select  # previously missing: defining the class raised NameError on a fresh kernel


class NonHetSelect(Select):
    """Bio.PDB selector that keeps only standard (non-hetero) residues.

    Intended for use with ``Bio.PDB.PDBIO.save(filename, NonHetSelect())``.
    """

    def accept_residue(self, residue):
        # Biopython residue ids are (hetero_flag, resseq, icode); the hetero
        # flag is a blank " " for standard residues and non-blank for waters
        # and hetero groups, which are therefore rejected (0).
        return 1 if residue.id[0] == " " else 0

In [2]:
# calculate the euclidean distance between two nodes (Ca coordinates)
def euclidean_dist(x, y):
    return ((x[:, None] - y) ** 2).sum(-1).sqrt()

# Build graph edges between every pair of nodes closer than a distance threshold.
def edge_connection(coord_list, threshold):
    """Return a (2, E) edge_index of all node pairs (i, j), i != j, with distance < threshold.

    ``coord_list`` is a tensor of node coordinates (one row per node). Both
    directions of each pair are included, and self-loops are excluded.
    """
    pairwise = euclidean_dist(coord_list, coord_list)

    # a node is always at distance 0 from itself; push the diagonal to +inf
    # so the threshold test below can never create a self-loop
    pairwise.fill_diagonal_(float("inf"))

    close_enough = pairwise < threshold
    # nonzero gives (E, 2) index pairs; transpose to PyG's (2, E) layout
    return close_enough.nonzero(as_tuple=False).transpose(0, 1)

In [3]:
from esm_embedding import esm_if_2_embedding
from Bio.PDB.DSSP import dssp_dict_from_pdb_file, residue_max_acc
from torch_geometric.data import Data
import torch

import warnings
warnings.filterwarnings('ignore')

def _flatten_chains(per_chain):
    """Concatenate a list of per-chain sequences into one flat list, keeping chain order."""
    return [item for chain_items in per_chain for item in chain_items]


def _rsa_from_dssp(dssp_dict, node_ids):
    """Return one relative solvent accessibility (RSA) value per node id.

    Node ids are split on ":" into (chain, residue name, residue number),
    e.g. "A:LEU:487". RSA is the DSSP absolute accessibility divided by the
    Sander maximum accessibility for that residue type. A node that cannot be
    resolved (residue absent from the DSSP output, a non-integer residue
    number, or a residue name missing from the Sander table) gets RSA 0.0 and
    a printed warning instead of aborting the whole structure.
    """
    rsa_values = []
    for node in node_ids:
        chain, res_name, res_id = node.split(":")
        try:
            # DSSP dict keys look like ('A', (' ', 53, ' ')): chain id, then the
            # Biopython residue id (hetero flag, number, insertion code)
            key = (chain, (" ", int(res_id), " "))
            # element 2 of a DSSP record is the absolute accessible surface area
            rsa_values.append(dssp_dict[key][2] / residue_max_acc["Sander"][res_name])
        except (KeyError, ValueError) as err:
            rsa_values.append(0.0)
            print(f"RSA lookup failed for {node} ({type(err).__name__}); appending rsa: 0")
    return rsa_values


def generate_graph(df, path, distance_threshold, RSA_threshold):
    """Build one torch_geometric ``Data`` graph per row of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows with a "PDB" id column and an "Epitope" column holding a
        comma-separated list of residue ids such as "A:LEU:487".
    path : str
        Directory holding ``{PDB}.pdb``; passed to ``esm_if_2_embedding`` and
        used for the DSSP run.
    distance_threshold : float
        Two residues are connected when their coordinate distance is strictly
        below this value (coordinates presumably in Angstrom). Self-loops are
        never created.
    RSA_threshold : float
        Residues with RSA >= this value are flagged True in ``train_mask``
        (the surface residues).

    Returns
    -------
    list of Data
        One graph per row with attributes ``coords``, ``node_id``,
        ``node_attrs`` (the two ESM representations concatenated along dim 1),
        ``edge_index``, ``y`` (1 = annotated epitope residue, else 0),
        ``num_nodes``, ``name``, ``train_mask`` (bool) and ``rsa`` (list of
        floats).
    """
    pyg_data_list = []
    for _, row in df.iterrows():
        pdb_id = row["PDB"]
        print("PDB is :", pdb_id)

        # set for O(1) membership; stripping tolerates "," as well as ", " separators
        epitope_set = {res.strip() for res in row["Epitope"].split(",") if res.strip()}

        esm_if_rep, esm2_rep, node_list, coord_list = esm_if_2_embedding(pdb_id, path)
        esm_node_features = torch.concat((esm_if_rep, esm2_rep), dim=1)

        # per-chain lists -> one list covering the whole structure
        node_all_list = _flatten_chains(node_list)
        coord_all_list = torch.tensor(_flatten_chains(coord_list))

        # edges between residues closer than distance_threshold (no self-loops)
        edges = edge_connection(coord_all_list, threshold=distance_threshold)

        # node labels from the epitope annotation; explicit dtype keeps an
        # empty structure from yielding a float tensor
        y_list = torch.tensor(
            [int(node in epitope_set) for node in node_all_list], dtype=torch.long
        )

        # RSA feature from DSSP accessibility (requires the DSSP executable)
        dssp_dict = dssp_dict_from_pdb_file(f"{path}/{pdb_id}.pdb")[0]
        rsa_list = _rsa_from_dssp(dssp_dict, node_all_list)

        # surface residues are those with RSA >= RSA_threshold; select them
        # downstream by indexing with data.train_mask
        train_mask = torch.tensor(
            [rsa >= RSA_threshold for rsa in rsa_list], dtype=torch.bool
        )

        data = Data(
            coords=coord_all_list,
            node_id=node_all_list,
            node_attrs=esm_node_features,
            edge_index=edges.contiguous(),
            y=y_list,
            num_nodes=len(node_all_list),
            name=pdb_id,
            train_mask=train_mask,
            rsa=rsa_list,
        )
        pyg_data_list.append(data)

    return pyg_data_list



In [4]:
# Build graphs for the example subset. The "Using cache found in ..." lines in
# the output are torch.hub messages printed each time an ESM model is loaded
# (several appear for 3IT8, apparently one per protein chain).
example_pyg_list = generate_graph(
    df=df_example,
    path="Example_PDB",
    distance_threshold=10,
    RSA_threshold=0.15,
)

PDB is : 1YBW


Using cache found in /home/sjchoi/.cache/torch/hub/facebookresearch_esm_main


PDB is : 5X59


Using cache found in /home/sjchoi/.cache/torch/hub/facebookresearch_esm_main


Key Error... appending rsa: 0
PDB is : 3PMT


Using cache found in /home/sjchoi/.cache/torch/hub/facebookresearch_esm_main


PDB is : 3IT8


Using cache found in /home/sjchoi/.cache/torch/hub/facebookresearch_esm_main
Using cache found in /home/sjchoi/.cache/torch/hub/facebookresearch_esm_main
Using cache found in /home/sjchoi/.cache/torch/hub/facebookresearch_esm_main


PDB is : 4O38


Using cache found in /home/sjchoi/.cache/torch/hub/facebookresearch_esm_main


In [5]:
# Inspect the first example graph (1YBW): node/edge counts and attribute shapes.
example_pyg_list[0]

Data(edge_index=[2, 4866], y=[246], coords=[246, 3], node_id=[246], node_attrs=[246, 1792], num_nodes=246, name='1YBW', train_mask=[246], rsa=[246])