In [1]:
import torch, dgl
import numpy as np
from dgl.data.utils import save_graphs, load_graphs

from tqdm import tqdm
from glob import glob

from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.*')

from atom_feature import *
from utils import load_obj, calculate_pair_distance
from collections import defaultdict

In [2]:
proteins = load_obj('/data/PLBA/pocket_8A.pickle')
ligands = load_obj("/data/PLBA/ligand.pickle")

train_y = { line.split()[0].lower(): torch.tensor( [float(line.split()[1])] ).float() for line in open('../data/BindingAffinity/PDBBind_v2020.txt').readlines() }
test_y  = { line.split()[0].lower(): torch.tensor(  [float(line.split()[1])] ).float() for line in open('../data/BindingAffinity/core-set.txt').readlines() }

In [3]:
def mol_to_graph( mol ):
    n     = mol.GetNumAtoms()
    coord = get_mol_coordinate(mol)
    h     = get_atom_feature(mol)
    adj   = get_bond_feature(mol).to_sparse(sparse_dim=2)
    
    u = adj.indices()[0]
    v = adj.indices()[1]
    e = adj.values()

    g = dgl.DGLGraph()
    g.add_nodes(n)
    g.add_edges(u, v)
    
    g.ndata['feats'] = h
    g.ndata['coord'] = coord
    g.edata['feats'] = e
    
    g.ndata['pos_enc'] = dgl.random_walk_pe(g, 20)

    return g

def complex_to_graph(pmol, lmol):
    pcoord = get_mol_coordinate(pmol)
    lcoord = get_mol_coordinate(lmol)
    ccoord = torch.cat( [pcoord, lcoord] )
    
    npa = pmol.GetNumAtoms()
    nla = lmol.GetNumAtoms()
    
    distance = calculate_pair_distance(pcoord, lcoord)
    u, v = torch.where( distance < 5 ) ### u - src protein node, v - dst ligand node

    distance = distance[ u, v ].unsqueeze(-1)
    
    interact_feature = get_interact_feature( pmol, lmol, u, v  )
    distance_feature = get_distance_feature(distance).squeeze(-1)

    e = torch.cat( [interact_feature, distance_feature], dim=1)
    e = torch.cat( [e, e] )
    
    distance = torch.cat( [ distance, distance] )
    
    u, v = torch.cat( [u, v+npa] ), torch.cat( [v+npa, u] )
    
    g = dgl.DGLGraph()
    g.add_nodes( npa + nla )
    g.add_edges( u, v )

    g.ndata['coord'] = ccoord
    g.edata['feats'] = e
    g.edata['distance'] = distance
    
    return g

In [4]:
decoy_ligands = { f.split('/')[-2]: f for f in glob('/data/CASF-2016/decoys_docking/*/*') if '.sdf' in f }

In [12]:
def get_ligand_name_from_sdf( sdf ):
    names = []
    lines = open( sdf ).readlines()
    for idx, line in enumerate( lines ):
        if line[:3] == ' Op':
            lignad_name = lines[idx -1 ]
            names.append(lignad_name[:-1)
    return names

In [20]:
for ppdb in tqdm(decoy_ligands):
    pmol = proteins[ppdb]
    
    lmol_names = get_ligand_name_from_sdf( decoy_ligands[ppdb] )
    
    lmols = Chem.SDMolSupplier( decoy_ligands[ppdb] )
        
    pmol = proteins[ppdb]
    gp = mol_to_graph(pmol)
    
    for idx, lmol in enumerate(lmols):
        try:
            gl = mol_to_graph(lmol)
            gc = complex_to_graph(pmol, lmol)

            save_graphs(f'/data/PLBA/docking-power_graph/{ppdb}/{ppdb}_protein_{lmol_names[idx]}.bin', gp)
            save_graphs(f'/data/PLBA/docking-power_graph/{ppdb}/{ppdb}_ligand_{lmol_names[idx]}.bin', gl)
            save_graphs(f'/data/PLBA/docking-power_graph/{ppdb}/{ppdb}_complex_{lmol_names[idx]}.bin', gc)

        except Exception as E:        
            print(E, ppdb, lmol_names[idx])
    break

  0%|                                                                                                                                                                  | 0/285 [00:00<?, ?it/s]
