In [26]:
import os
# import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
# from rdkit.Chem import Descriptors
import time
import pickle
import matplotlib.pyplot as plt


from fragdockrl import load_ep
from fragdockrl import utils, rl_utils

In [21]:
# Working directory; all other input locations are resolved relative to it.
d_dir = '.'
data_dir = f'{d_dir}/../data'      # reaction / building-block pickles
ep_dir = f'{d_dir}/ep'             # episode logs read by load_ep
protein_dir = f'{d_dir}/../fix'    # receptor / reference-ligand structures

In [6]:


# Input pickles: building blocks, reaction SMIRKS, and building-block molecules.
# NOTE(review): these are presumably unpickled by load_reaction_data — only load trusted files.
building_block_file = f'{data_dir}/bb_reaction.pkl'
reaction_file = f'{data_dir}/smirks_reactant.pkl'
m_bb_file = f'{data_dir}/m_bb.pkl'

dd = utils.load_reaction_data(building_block_file, reaction_file, m_bb_file)
(df_bb, df_reaction, reactant_id_dict, mol_bb_dict) = dd
reaction_key = reactant_id_dict.keys()

# Keep only building blocks with at least one truthy flag in the columns from
# position 2 onward (presumably per-reaction compatibility flags — confirm);
# the row labelled 0 is always kept regardless of its flags.
b_gg = df_bb[df_bb.columns[2:]].any(axis=1)
b_gg.loc[0] = True
df_bb = df_bb[b_gg]

In [4]:

# Load the episode logs from ep_dir.
(
    p_list,          # 2-D array, one row per episode (column meanings listed in the comment cell below)
    ep_simple_list,  # per-episode dicts with at least 'idx', 'ep' and 'dock_score' keys
    shot_list,       # not used in this part of the notebook
    ep_tree_dict,    # not used in this part of the notebook
) = load_ep(ep_dir)

In [22]:

# Reference ligand / starting fragment setup (values presumably mirror config.txt — confirm).
# protein_pdb_file = protein_dir + '/2P2IA_receptor.pdb'
# protein_pdbqt_file = protein_dir + '/2P2IA_receptor.pdbqt'
ligand_pdb_file = protein_dir + '/2P2IA_608.pdb'
start_smi = 'NCc1ccncc1'    # starting fragment for every episode
smi_ref_com = 'Cc1ccncc1'   # reference common substructure
ref_atom_idx_ref = 0

m_ref = Chem.MolFromPDBFile(ligand_pdb_file, removeHs=True)
# MolFromPDBFile signals parse/sanitization failure by returning None rather than
# raising; fail here with a clear message instead of deep inside prep_ref_mol_simple.
if m_ref is None:
    raise ValueError(f'Could not read reference ligand from {ligand_pdb_file!r}')

# Build the start fragment and reference substructure against the reference ligand,
# then keep heavy-atom-only copies.
m_start_h = utils.prep_ref_mol_simple(start_smi, m_ref)
m_ref_com_h = utils.prep_ref_mol_simple(smi_ref_com, m_ref)
m_start = Chem.RemoveHs(m_start_h)
m_ref_com = Chem.RemoveHs(m_ref_com_h)

# Filter thresholds applied to generated molecules / docking results.
cut_para_dict = {
    'num_rb': 12,
    'num_heavy_atoms': 60,
    'mol_wt': 650,
    'timeout_docking': 120,
    'cut_score': 3,  # clip docking score
    'cut_rmsd': 2.5,
}

# Property-penalty weights (all disabled here). The variable name keeps its original
# spelling ("penelty") so any code referring to it keeps working.
# penelty_para_dict = {'w_logp': 1.0, 'w_mw': 1.0, 'w_ha': 1.0, 'w_hd': 1.0,}
penelty_para_dict = {'w_logp': 0.0, 'w_mw': 0.0, 'w_ha': 0.0, 'w_hd': 0.0}

In [None]:
#p_list.append([idx, cumulative_reward, final_reward, dock_score, dock_rmsd, mol_wt, num_rb, logp, num_hd, num_ha, num_ring, tpsa, num_heavy_atoms])

In [None]:
# Histogram of p_list column 2 in four consecutive groups of num_ba episodes.
# NOTE(review): the column comment above lists index 2 as final_reward (index 3 is
# dock_score); the x-label calls it "Docking score" — confirm which is intended.
num_ba = 20000
group_bounds = [(k * num_ba, (k + 1) * num_ba) for k in range(4)]
a_list = [p_list[start:end, 2] for start, end in group_bounds]
# Legend labels are epoch ranges, assuming 200 episodes per epoch.
leg_list = ['%s-%s' % (start // 200 + 1, end // 200) for start, end in group_bounds]

_ = plt.hist(a_list, bins=np.arange(-3, 13.5, 2), density=True)

plt.xlabel('Docking score', fontsize=16)
plt.ylabel('Density', fontsize=16)
plt.legend(leg_list, title='Epochs')
plt.tight_layout()
plt.show()

In [None]:
# Per-epoch mean of p_list column j (column 1 = cumulative_reward in the column
# listing), using blocks of num_ba episodes per epoch for 400 epochs.
num_ba = 200
j = 1
dd_list = np.array([
    [epoch + 1, p_list[epoch * num_ba:(epoch + 1) * num_ba, j].mean()]
    for epoch in range(400)
])

plt.plot(dd_list[:, 0], dd_list[:, 1])
plt.xlabel('Epochs', fontsize=16)
plt.ylabel('Average Total reward', fontsize=16)
plt.tight_layout()
plt.show()

In [9]:
# Count episodes whose building-block sequence was already produced by an earlier
# episode. The signature is the '_'-joined ids (step[0]) of the steps whose flag
# step[1] is truthy. mm_dict remembers (idx, dock_score) of the first episode
# per signature; gg_list records the running duplicate count every 200 episodes.
count = 0
mm_dict = dict()
gg_list = list()
for i, ep_ss in enumerate(ep_simple_list[:80000]):
    idx = ep_ss['idx']
    ep_list = ep_ss['ep']
    dock_score = ep_ss['dock_score']
    ss2 = '_'.join('%d' % step[0] for step in ep_list if step[1])
    if ss2 not in mm_dict:
        mm_dict[ss2] = (idx, dock_score)
    else:
        count += 1
    if (i + 1) % 200 == 0:
        gg_list.append(((i + 1) / 200, count))
print(count)
gg_list = np.array(gg_list, dtype=int)

14396


In [10]:
# Tabulate cumulative duplicates per epoch. NOTE(review): 'VGFR2' presumably names
# the VEGFR2 target; it is kept as-is because it is a runtime column name.
df = pd.DataFrame(data=gg_list, columns=['epoch', 'VGFR2'])
# Optional export: df.to_csv('cumu_vgfr2.csv', index=False)

In [None]:
# Cumulative duplicate episodes versus epoch.
fig, ax = plt.subplots()
ax.plot(gg_list[:, 0], gg_list[:, 1])
ax.set_xlabel('Epochs', fontsize=16)
ax.set_ylabel('Cumulative Duplicate', fontsize=16)
fig.tight_layout()
plt.show()

In [None]:
# Temperature schedule: T(g) = temperature0 * temp_reduce**g + temperature_min,
# i.e. exponential decay toward the floor temperature_min.
temperature0 = 0.45
temp_reduce = 0.99
temperature_min = 0.05
num_gen = 400  # number of generations plotted

generations = np.arange(num_gen)
temperatures = temperature0 * np.power(temp_reduce, generations) + temperature_min
# Same layout as before: column 0 = generation index, column 1 = temperature.
temperature_list = np.column_stack((generations, temperatures))

# Label and finalize the figure like the other plotting cells in this notebook.
plt.plot(temperature_list[:, 0], temperature_list[:, 1])
plt.xlabel('Generation', fontsize=16)
plt.ylabel('Temperature', fontsize=16)
plt.tight_layout()
plt.show()

In [14]:
# Episode indices ordered by p_list column 2, largest value first.
# NOTE(review): column 2 is listed as final_reward in the column comment; the name
# dock_score_list is kept because later cells may refer to it.
dock_score_list = p_list[:, 2]
idx_sort = np.argsort(-dock_score_list)

In [None]:
# Print the ten top-ranked episodes: the raw dict, then its idx, dock score and steps.
for rank in range(10):
    ep_idx0 = idx_sort[rank]
    ep_simple = ep_simple_list[ep_idx0]
    print(ep_simple)
    ep_idx, ep_r, ep_dock = ep_simple['idx'], ep_simple['ep'], ep_simple['dock_score']
    print(ep_idx, ep_dock, ep_r)

In [None]:
# Replay the ten top-ranked episodes step by step, starting from m_start, to
# reconstruct the product molecule and print each applied reaction.
# Leaves m_new / smi_new / bb_select_list / bb_select_list_id of the last episode in the
# namespace (presumably consumed by a later cell, e.g. for drawing — confirm).
for ep_idx0 in idx_sort[0:10]:
    ep_simple = ep_simple_list[ep_idx0]
    # Conformer-free copy of the start fragment, used as the first "building block" entry.
    m_0 = copy.copy(m_start)
    m_0.RemoveAllConformers()
    
    # m_new is the growing molecule; it is replaced after every applied reaction.
    m_new = copy.copy(m_start)
    smi_new = Chem.MolToSmiles(m_new)
    print('start_SMILES:',smi_new)
    
    ep_idx = ep_simple['idx']
    ep_r = ep_simple['ep']
    ep_dock = ep_simple['dock_score']
#    print(ep_idx, ep_dock, ep_r)
    # Parallel lists: molecules and catalog ids of the start fragment + chosen building blocks.
    bb_select_list = list()
    bb_select_list.append(m_0)
    bb_select_list_id = list()
    bb_select_list_id.append('start')
    
    for ep_r0 in ep_r:
        # ep_r0[0] is used as the building-block row label in df_bb / mol_bb_dict;
        # steps whose flag ep_r0[1] is falsy are skipped.
        rea_id = ep_r0[0]
        if not ep_r0[1]:
            continue
        # Candidate reactions for the current molecule. From the usage below, each pr
        # presumably holds pr[1] = df_reaction row label, pr[2] = position of m_new
        # in the reaction (1 = first reactant), pr[3] = df_bb reaction-flag column — confirm.
        pr_list = rl_utils.possible_reaction(m_new, reactant_id_dict, df_reaction)
        df00 = df_bb.loc[rea_id]
        smi_bb = df00['SMILES']
        # Entries of this building block's row that equal True; `in` on a Series tests
        # its index, so `pr[3] in rea2_tmp` means "column pr[3] is True for this block".
        rea2_tmp = df00[df00==True]
        m_cc = mol_bb_dict[rea_id]
        b_id = df_bb.at[(rea_id,'Catalog_ID')]
    #    print(b_id)
        bb_select_list_id.append(b_id)
        bb_select_list.append(m_cc)
    #    print(ep_r0)
        # Apply the first candidate reaction this building block supports.
        # NOTE(review): if no pr matches, the block is still appended above but m_new is
        # left unchanged, with no message — the printed final SMILES can then be stale.
        for pr in pr_list:
            if pr[3] in rea2_tmp:
    #            print(pr)
                reaction_id = pr[1]
                reactant_num = pr[2]
    #            print(reaction_id)
                smirks = df_reaction.at[reaction_id, 'smirks']
                reaction_name = df_reaction.loc[reaction_id]['name']
                rxn = AllChem.ReactionFromSmarts(smirks)
    #            print(smirks)
                # Reactant order follows the reaction template: m_new goes first when reactant_num == 1.
                if reactant_num == 1:
                    m_new_list = rxn.RunReactants((m_new, m_cc))
                    
                else:
                    m_new_list = rxn.RunReactants((m_cc, m_new))
                # Keep only the first product of the first product set.
                # NOTE(review): raises IndexError if RunReactants returns no products.
                m_new = m_new_list[0][0]
                m_new = Chem.RemoveAllHs(m_new)
                smi_new = Chem.MolToSmiles(m_new)
                print(rea_id, b_id, smi_bb, reaction_id, reaction_name, smirks)
                break
    print('final_SMILES:',smi_new)
    print('\n')