In [11]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

from time import time
import pandas as pd
import numpy as np
import tqdm
import random
from collections import defaultdict
import argparse

import torch
import torch.nn as nn
from torch import cuda
from torch.optim import Adam
from torch.utils.data import Dataset

from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss,DataLoader
from torchkge import KnowledgeGraph,DistMultModel,TransEModel,TransRModel
from torchkge.models.bilinear import HolEModel,ComplExModel

from utils import *

In [19]:
import pandas as pd
from rdkit import Chem 
from rdkit.Chem import Lipinski
from rdkit.Chem import Descriptors
import tqdm

def rule_five_calculator(df):
    # INDEX = []
    num_H_acc =[]
    num_H_don = []
    num_rota = []
    mw = []
    logp = []

    for i in tqdm.tqdm(range(len(df))):
        smi = df.iloc[i]['smi']
        mol = Chem.MolFromSmiles(smi)

        num_H_acc.append(Lipinski.NumHAcceptors(mol))
        num_H_don.append(Lipinski.NumHDonors(mol))
        num_rota.append(Lipinski.NumRotatableBonds(mol))
        mw.append(Descriptors.MolWt(mol))
        logp.append(Descriptors.MolLogP(mol))
        # if num_H_acc < 10 and num_H_don <5 and mw <=500 and num_rota < 10 and logp>=1 and logp <=5:
        #     INDEX.append(i)
        #     num += 1

    df['num_H_acc'] = num_H_acc
    df['num_H_don'] = num_H_don
    df['num_rota'] = num_rota
    df['mw'] = mw
    df['logp'] = logp

    return df



from rdkit import Chem
from rdkit.Chem import FilterCatalog
# Initialize filter parameters
param = FilterCatalog.FilterCatalogParams()
# Add PAINS filters
param.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS)
# Create the filter catalog
filt = FilterCatalog.FilterCatalog(param)

def PAINS_calculator(df):
    pains = []
    for i in tqdm.tqdm(range(len(df))):
        smi = df.iloc[i]['smi']
        mol = Chem.MolFromSmiles(smi)
        pains.append(filt.HasMatch(mol))
    
    df['pains'] = pains

    return df

In [12]:
'''
This section is user defined !!!
'''
h_dim = 300

data_path = "../processed_data/target_inference_2/"
save_model_path = '../best_model/target_inference_2/'
output_path = "../results/target_inference_2/"

ent_list = ['Protein:ENPP1']

In [20]:
# load processed_data after training
cause = pd.read_csv(data_path + 'cause.txt',sep='\t',names=['from','rel','to'])
ent2id = np.load(data_path + 'ent2id.npy', allow_pickle=True).item()
rel2id = np.load(data_path + 'rel2id.npy', allow_pickle=True).item()

h_cand = [v for k,v in ent2id.items() if k.startswith('CID:')]
h_cand_ent = [k for k,v in ent2id.items() if k.startswith('CID:')]

t_cand = [v for k,v in ent2id.items() if k.startswith(('Protein:','TF:','RBP:'))]
t_cand_ent = [k for k,v in ent2id.items() if k.startswith(('Protein:','TF:','RBP:'))]

In [21]:
# count virtual screening score
vs_dict = {}

for ent in tqdm.tqdm(ent_list):
    results = []
    for i in range(5):
        model = DistMultModel(h_dim, len(ent2id), len(rel2id))
        model.load_state_dict(torch.load(save_model_path + "pertkg{}.pt".format(i)))
        if cuda.is_available():
            cuda.empty_cache()
            model.cuda()
        model.normalize_parameters()
        model.eval()
        with torch.no_grad():
            ent_emb,rel_emb = model.get_embeddings() # (n_ent, emb_dim)
            score = inference(ent,
                            ent2id,rel2id,
                            ent_emb,rel_emb,
                            h_cand,t_cand,
                            'virtual_screening')
            results.append(score)

    average_list = [sum(x) / len(x) for x in zip(*results)]
    vs_dict['{}'.format(ent)] = average_list

100%|██████████| 2/2 [00:07<00:00,  3.52s/it]


In [22]:
# output
comp_map_table = pd.read_csv('../map_file/compound_map_table.txt',sep='\t')
cid2smi = dict(zip(comp_map_table['cid'],comp_map_table['smi']))
h_cand_smi = [cid2smi[x] for x in h_cand_ent]

for k,v in vs_dict.items():
    vs_score = v
    df = pd.DataFrame({'compound':h_cand_ent,
                       'vs_score':vs_score,
                       'smi':h_cand_smi})
    df = df.sort_values(by='vs_score', ascending=False)

    # filter
    df = rule_five_calculator(df)
    df = PAINS_calculator(df)

    df.to_csv(output_path + '{}.txt'.format(k),sep='\t',index=False,header=True)

100%|██████████| 11062/11062 [00:10<00:00, 1060.70it/s]
100%|██████████| 11062/11062 [00:35<00:00, 310.13it/s]
100%|██████████| 11062/11062 [00:10<00:00, 1072.93it/s]
100%|██████████| 11062/11062 [00:35<00:00, 311.57it/s]
