In [49]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import Draw

import networkx as nx

from utils.graph_utils import *
from predict_logp.predict_logp import *
import torch_geometric as pyg

In [50]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Loading generated molecules

In [51]:
# Column names assigned to the header-less CSV logs read below.
colnames = [
    'SMILE', 'rew_valid', 'rew_qed', 'rew_sa', 'final_stat', 'rew_env',
    'rew_d_step', 'rew_d_final', 'cur_ep_et', 'flag_steric_strain_filter',
    'flag_zinc_molecule_filter', 'stop',
]


def _load_generation_log(csv_path):
    """Read one header-less CSV log into a DataFrame labelled with `colnames`."""
    return pd.read_csv(csv_path, header=None, names=colnames)


df_logpen = _load_generation_log("../../Downloads/molecule_gen/molecule_zinc_logppen.csv")
df_qed_condition = _load_generation_log("../../Downloads/molecule_gen/molecule_zinc_qed_conditional.csv")
df_qedsa = _load_generation_log("../../Downloads/molecule_gen/molecule_zinc_qedsa.csv")
# NOTE(review): the file is named "test_conditional" while the variable is `df_qed` — confirm this is the intended input.
df_qed = _load_generation_log("../../Downloads/molecule_gen/molecule_zinc_test_conditional.csv")

In [52]:
def _drop_iteration_rows(df):
    """Return `df` without rows whose SMILE field is missing or contains "Iteration".

    `str.contains` yields NaN for missing values, and negating such a mask
    with `~` raises; `na=False` keeps the mask strictly boolean. Rows with a
    missing SMILE are dropped explicitly because they carry no molecule.
    """
    smiles = df["SMILE"]
    is_iteration_row = smiles.str.contains("Iteration", na=False)
    return df[smiles.notna() & ~is_iteration_row]


df_logpen = _drop_iteration_rows(df_logpen)
df_qed_condition = _drop_iteration_rows(df_qed_condition)
df_qedsa = _drop_iteration_rows(df_qedsa)
df_qed = _drop_iteration_rows(df_qed)

In [53]:
# Keep only rows where both flag columns are True, then order by "final_stat", largest first.
# NOTE(review): the ordering key is final_stat, not rew_qed — confirm this is the intended ranking.
mol_filter = 'flag_steric_strain_filter == True & flag_zinc_molecule_filter == True'


def _filter_and_rank(df):
    """Apply `mol_filter` to `df` and sort the surviving rows by final_stat, descending."""
    passing = df.query(mol_filter)
    return passing.sort_values("final_stat", ascending=False)


df_logpen = _filter_and_rank(df_logpen)
df_qed_condition = _filter_and_rank(df_qed_condition)
df_qedsa = _filter_and_rank(df_qedsa)
df_qed = _filter_and_rank(df_qed)

In [54]:
# (rows, columns) of each filtered frame, in the order logpen, qed_condition, qedsa, qed.
tuple(frame.shape for frame in (df_logpen, df_qed_condition, df_qedsa, df_qed))

((10321, 12), (4823, 12), (53423, 12), (88242, 12))

In [70]:
# Show df_logpen after filtering and sorting; the rendering is truncated to the first and last rows.
df_logpen

Unnamed: 0,SMILE,rew_valid,rew_qed,rew_sa,final_stat,rew_env,rew_d_step,rew_d_final,cur_ep_et,flag_steric_strain_filter,flag_zinc_molecule_filter,stop
10356,CCCCC(CC(C)CC)CC(C)(CCCC)C(CCCC)C(Cl)C(C)CC,2.0,0.195422,0.620199,1.327454,3.316815,2.399182e-07,0.020973,3.671941,True,True,True
8965,CCCCC(CC)(CCC)CC(CC)(CCC)C(CC)CCC,2.0,0.282128,0.666009,1.146174,3.135536,6.174223e-04,0.284012,3.770076,True,True,True
4470,CCCCCC(C)C(C)=CCCCCC(C)(C)C(C)CC,2.0,0.254441,0.715845,1.092173,3.081535,2.920695e-03,0.147920,3.570815,True,True,True
5788,C=C(C)C(CC)C(CC)C(=C(CC)OC(C(C)C)C(C)CCCCC)C(C)C,2.0,0.153201,0.619376,1.072826,3.062188,3.829380e-06,0.000680,3.494623,True,True,True
9620,CCC=CC=C(C)C(Cl)=CC(=CC(Cl)=C1C=CC(C)=C1C(C)C)CC,2.0,0.390786,0.657766,1.025221,3.014583,3.900398e-03,0.036986,3.879980,True,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...
379,CC=CC1=C(N)C2=NNC(=O)C(=C3NC(C)(C4=CC5=C4C1=C5...,2.0,0.328158,0.322280,-18.288178,-16.277540,0.000000e+00,2.492597,-10.811681,True,True,False
1010,CC=C1C=COC2C3CCC4=CCC(C)C5CC(OF)C(=C(CNC)C(CCC...,2.0,0.468325,0.273927,-18.510834,-16.500196,0.000000e+00,2.088650,-11.615422,True,True,False
1384,CCC=NC1(C)NC2CC(=NC(=O)N(C)C3=NC2=CC=CN=C3)C=C...,2.0,0.688034,0.341180,-18.705708,-16.716347,0.000000e+00,2.256023,-12.219404,True,True,True
412,O=CC1=CC2=C3C=C4CC=CNCC5=NP=NC=C(C=NC(=CC6=NN=...,2.0,0.429090,0.229343,-18.880343,-16.890981,0.000000e+00,2.413172,-11.205498,True,True,True


In [57]:
# Upper bound on how many rows (taken from the top of each sorted frame) are kept for scoring.
N_TOP_RANKED = 5000

logpen_smiles = df_logpen["SMILE"].iloc[:N_TOP_RANKED].values
qed_condition_smiles = df_qed_condition["SMILE"].iloc[:N_TOP_RANKED].values
qedsa_smiles = df_qedsa["SMILE"].iloc[:N_TOP_RANKED].values
qed_smiles = df_qed["SMILE"].iloc[:N_TOP_RANKED].values

# Forward pass through the trained GNN (dock-score model)

In [58]:
# NOTE: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints that come from a trusted source.
# map_location puts the network on DEVICE (the device the graph batches are
# moved to before the forward pass), whatever device it was saved from.
gcn_net = torch.load(
    "dock_score_models/default_run/dock_score/best_model.pth",
    map_location=DEVICE,
)

In [59]:
gcn_net

GNN_MyGAT(
  (layers): ModuleList(
    (0): MyGATConv(121, 512, heads=1)
    (1): MyGATConv(512, 512, heads=1)
    (2): MyGATConv(512, 512, heads=1)
    (3): MyGATConv(512, 512, heads=1)
    (4): MyGATConv(512, 512, heads=1)
    (5): MyGATConv(512, 512, heads=1)
    (6): MyGATConv(512, 512, heads=1)
  )
  (final_layer): Linear(in_features=512, out_features=1, bias=True)
  (act): ReLU()
)

In [60]:
# Wrap each SMILES array in a MolData whose first argument is a list of zeros of matching length.
logpen_data, qed_condition_data, qedsa_data, qed_data = (
    MolData([0] * len(smiles), smiles)
    for smiles in (logpen_smiles, qed_condition_smiles, qedsa_smiles, qed_smiles)
)

In [61]:
# Settings shared by all four loaders.
loader_options = dict(collate_fn=my_collate, batch_size=512, num_workers=24)

logpen_dataloader = DataLoader(logpen_data, **loader_options)
qed_condition_dataloader = DataLoader(qed_condition_data, **loader_options)
qedsa_dataloader = DataLoader(qedsa_data, **loader_options)
qed_dataloader = DataLoader(qed_data, **loader_options)

In [62]:
def _predict_scores(model, dataloader, device=DEVICE):
    """Run `model` over every batch of `dataloader` and return all predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(g1, g2.edge_index)`` for each batch. It is switched
        to evaluation mode before the first batch.
    dataloader : iterable
        Yields ``(g1, y, g2)`` triples; ``y`` is not used.
    device : str
        Device both graph batches are moved to before the forward pass.

    Returns
    -------
    torch.Tensor
        The per-batch predictions concatenated in loader order, held on the
        CPU and detached from autograd.
    """
    # Start from an empty tensor so a loader with no batches still returns an
    # (empty) tensor instead of making torch.cat fail on an empty list.
    collected = [torch.empty(0)]
    # Inference only: evaluation mode, and no autograd graph kept per batch.
    model.eval()
    with torch.no_grad():
        for g1, _, g2 in dataloader:
            g1 = g1.to(device)
            g2 = g2.to(device)
            y_pred = model(g1, g2.edge_index)
            # Bring each batch back to the CPU so every piece that is
            # concatenated lives on the same device.
            collected.append(y_pred.cpu())
    # One concatenation at the end instead of re-copying the running result
    # on every iteration.
    return torch.cat(collected)


logpen_scores = _predict_scores(gcn_net, logpen_dataloader)
qed_condition_scores = _predict_scores(gcn_net, qed_condition_dataloader)
qedsa_scores = _predict_scores(gcn_net, qedsa_dataloader)
qed_scores = _predict_scores(gcn_net, qed_dataloader)



In [63]:
# Convert the prediction tensors to NumPy arrays. `.cpu()` is a no-op for
# tensors already on the CPU and is required before `.numpy()` for tensors
# that live on another device.
logpen_scores = logpen_scores.detach().cpu().numpy()
qed_condition_scores = qed_condition_scores.detach().cpu().numpy()
qedsa_scores = qedsa_scores.detach().cpu().numpy()
qed_scores = qed_scores.detach().cpu().numpy()

In [65]:
# Positions of the ten smallest predicted scores in each score array (ascending argsort).
top_logpen_mols, top_qed_condition_mols, top_qedsa_mols, top_qed_mols = (
    scores.argsort()[:10]
    for scores in (logpen_scores, qed_condition_scores, qedsa_scores, qed_scores)
)

In [66]:
# Look up the SMILES strings at the selected positions, one array per run.
top_logpen_smiles, top_qed_condition_smiles, top_qedsa_smiles, top_qed_smiles = (
    smiles[picked]
    for smiles, picked in (
        (logpen_smiles, top_logpen_mols),
        (qed_condition_smiles, top_qed_condition_mols),
        (qedsa_smiles, top_qedsa_mols),
        (qed_smiles, top_qed_mols),
    )
)

In [67]:
def _smiles_to_mols(smiles_seq):
    """Parse every SMILES string in `smiles_seq` with Chem.MolFromSmiles, keeping order."""
    return list(map(Chem.MolFromSmiles, smiles_seq))


top_logpen_molecules = _smiles_to_mols(top_logpen_smiles)
top_qed_condition_molecules = _smiles_to_mols(top_qed_condition_smiles)
top_qedsa_molecules = _smiles_to_mols(top_qedsa_smiles)
top_qed_molecules = _smiles_to_mols(top_qed_smiles)

# Draw one grid per run (3 molecules per row, 300x300 cells) and save it as a PNG.
for grid_mols, png_name in (
    (top_logpen_molecules, 'ToplogpenMolecules.png'),
    (top_qed_condition_molecules, 'TopqedcondMolecules.png'),
    (top_qedsa_molecules, 'TopqedsaMolecules.png'),
    (top_qed_molecules, 'TopqedMolecules.png'),
):
    img = Draw.MolsToGridImage(grid_mols, subImgSize=(300, 300), molsPerRow=3, useSVG=False)
    img.save(png_name)

In [74]:
# rew_qed values of the selected logpen rows, picked by position in the sorted frame.
df_logpen['rew_qed'].iloc[top_logpen_mols].values

array([0.20509399, 0.09190346, 0.27486826, 0.11773375, 0.16423396,
       0.11022592, 0.5820904 , 0.17021597, 0.30879826, 0.33916267])

In [75]:
# Predicted scores at the selected positions, in ascending order.
logpen_scores.take(top_logpen_mols)

array([-21.731936, -19.865137, -18.112387, -17.971996, -17.462837,
       -16.802916, -16.79058 , -16.767763, -16.709795, -16.265676],
      dtype=float32)