In [None]:
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("..")
import os
import json
import numpy as np
import pandas as pd
import functools
from dqn import molecules
from dqn import deep_q_networks
from dqn.py.SA_Score import sascorer

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Draw, Descriptors, QED

import matplotlib.pyplot as plt
import tensorflow as tf
from pathlib import Path

In [None]:
def latest_ckpt(path):
    """Return the largest checkpoint step found in a run directory.

    TensorFlow writes each checkpoint as several files sharing the prefix
    ``ckpt-<step>`` (e.g. ``ckpt-100.index``, ``ckpt-100.data-00000-of-00001``).
    Only names of exactly that form count; other entries (``checkpoint``,
    ``config.json``, ...) are ignored.

    Args:
      path: Directory to scan, as a ``pathlib.Path`` or a string.

    Returns:
      The highest ``<step>`` as an int.

    Raises:
      ValueError: if the directory contains no ``ckpt-<step>`` entries.
    """
    steps = []
    for entry in Path(path).iterdir():
        prefix, sep, rest = entry.name.partition('-')
        if prefix != 'ckpt' or not sep:
            continue
        # Everything before the first '.' is the step, which also tolerates
        # multi-suffix names such as 'ckpt-300.meta.tmp1a2b'.
        step = rest.split('.', 1)[0]
        if step.isdecimal():
            steps.append(int(step))
    if not steps:
        raise ValueError('No ckpt-<step> files found in %s' % path)
    return max(steps)

In [None]:
# Root of the target-SAS experiment runs. Set MOLDQN_TARGET_SAS_DIR to point at
# a different location instead of editing this machine-specific default.
RESULTS_ROOT = os.environ.get(
    'MOLDQN_TARGET_SAS_DIR', '/Users/odin/sherlock_scratch/moldqn2/target_sas')
# One run directory per (molecule index, target SAS) pair, e.g. mol1_target_4.8.
# Any '%' in the root is escaped so it survives the %-formatting applied later.
basepath = os.path.join(RESULTS_ROOT.replace('%', '%%'), 'mol%i_target_%.1f')

# Sanity check: newest checkpoint step of one run.
path = Path(basepath % (1, 4.8))
latest_ckpt(path)

In [None]:
# Starting molecules (SMILES) for the target-SAS experiment. The position of
# each entry matters: the evaluation loop uses the same index both to pick the
# initial molecule and to locate its `mol<i>_target_<sas>` run directory, so
# entries must not be reordered or removed without re-matching the runs.
all_molecules = ["CCCN(C)N=Nc1ccc(cc1)C(=O)O",
 "CN1CCC[C@H]1c2cccnc2",
 "CCCCCC(O)c1cccc(OCc2cccc(c2)C(=O)OC)c1",
 "CCc1c(C)[nH]c2CCC(CN3CCOCC3)C(=O)c12",
 "COc1cc(cc(OC)c1OC)C(=O)N2CCN(C(COC(=O)CC(C)(C)C)C2)C(=O)c3cc(OC)c(OC)c(OC)c3",
 "Cc1cc(C)cc(c1)N2C(=O)Cc3ccccc3C2=O",
 "CCCCCCCCCCCCc1ccc(OCCCC(C)(C)C(=O)O)cc1OCCCC(C)(C)C(=O)O",
 "COc1ccc(C[C@@H](C)NC[C@H](O)c2ccc(O)c(NC=O)c2)cc1",
 "CC12CC3CC(C)(C1)CC(N)(C3)C2",
 "CC(C)NCC(O)COC(=O)c1ccc(NC(=O)C)cc1",
 "CCC1=C(CNC1=O)c2ccc(cc2)n3ccnc3",
 "CN(C)CCCn1cc(C2=C(C(=O)NC2=O)c3cn(CCOCCO)c4ccccc34)c5ccccc15",
 "Cc1c(ccc2nc(N)nc(N)c12)C(=O)NC(CCC(=O)O)C(=O)O",
 "COc1cc2nc(nc(N)c2cc1OC)N3CCN(CC3)C(=O)Nc4ccccc4",
 "O=C1CCC(N2C(=O)c3ccccc3C2=O)C(=O)N1",
 "OC(=O)c1ccccc1",
 "CCCCCCCC1=CC(=CC(=O)O1)OC",
 "CCOc1ccc2c(c1)c(CCNC(=O)C3CC3)c4c5ccccc5CCCn24",
 "CC(Cc1ccc(O)c(O)c1)C(C)Cc2ccc(O)c(O)c2",
 "CCN1c2ccccc2Cc3c(O)ncnc13",
 "CN1C(=O)C2(OCCO2)c3ccccc13",
 "CN1C(=O)NC2=C(N(C)C(=O)N2)C1=O",
 "CN(C)c1ncnc2c1ncn2Cc3cccc(C)c3",
 "Cc1cccc(CC2CCc3nc(N)nc(N)c3C2)c1",
 "CC(C)NCC(O)COC(=O)c1ccc(CO)cc1",
 "OC(=O)CN(CCN(CC(=O)O)CC(=O)O)CC(=O)O",
 "CN(CC=C)CC(N(C)CC=C)C(=O)Nc1c(C)cccc1C",
 "OCCCc1cc2OCCc2cc1O",
 "Cc1cc(CCCCCCCOc2ccc(cc2)C3=NC(C)(C)CO3)on1",
 "CCCCCCN1CCN2CC(c3ccccc3)c4ccccc4C2C1",
 "COc1c2OC(=O)C=Cc2c(COCCCO)c3ccoc13"]

In [None]:
def eval(model_dir, idx):
  """Run one greedy episode with a trained MolDQN agent.

  Restores the newest ``ckpt-<step>`` checkpoint in ``model_dir`` and, starting
  from ``all_molecules[idx]``, takes ``hparams.max_steps_per_episode`` actions
  chosen by the restored Q-network with ``epsilon=0.0``.

  Note: this name shadows Python's builtin ``eval``; it is kept because the
  evaluation loop calls it by this name.

  Args:
    model_dir: Run directory holding the checkpoint files and, optionally, a
      ``config.json`` of hyperparameters. When the run has no ``config.json``,
      the ``config.json`` in the parent directory of ``model_dir`` is used.
    idx: Index into the module-level ``all_molecules`` list selecting the
      starting molecule.

  Returns:
    ``(ckpt, result)``: the restored checkpoint step and the object returned
    by the last ``environment.step`` call (callers read ``result.state`` as a
    SMILES string).

  Raises:
    ValueError: if ``hparams.max_steps_per_episode`` is not positive, since no
      step would be taken and there would be no result to return.
  """
  ckpt = latest_ckpt(Path(model_dir))

  hparams_file = os.path.join(model_dir, 'config.json')
  # Runs without their own config share the one in the experiment root.
  shared_hparams_file = os.path.join(
      os.path.dirname(os.path.normpath(model_dir)), 'config.json')
  try:
    fh = open(hparams_file, 'r')
  except FileNotFoundError:
    fh = open(shared_hparams_file, 'r')
  # `with` closes the handle even if the JSON or the hparams are invalid.
  with fh:
    hp_dict = json.load(fh)
  hparams = deep_q_networks.get_hparams(**hp_dict)
  if hparams.max_steps_per_episode <= 0:
    raise ValueError('max_steps_per_episode must be positive, got %r'
                     % hparams.max_steps_per_episode)

  environment = molecules.Molecule(
      atom_types=set(hparams.atom_types),
      init_mol=all_molecules[idx],
      allow_removal=hparams.allow_removal,
      allow_no_modification=hparams.allow_no_modification,
      allowed_ring_sizes=set(hparams.allowed_ring_sizes),
      allow_bonds_between_rings=hparams.allow_bonds_between_rings,
      max_steps=hparams.max_steps_per_episode)

  # Drop graphs left over from previous calls before building a new one.
  tf.reset_default_graph()
  dqn = deep_q_networks.DeepQNetwork(
      input_shape=(hparams.batch_size, hparams.fingerprint_length + 1),
      q_fn=functools.partial(
          deep_q_networks.multi_layer_model, hparams=hparams),
      optimizer=hparams.optimizer,
      grad_clipping=hparams.grad_clipping,
      num_bootstrap_heads=hparams.num_bootstrap_heads,
      gamma=hparams.gamma,
      epsilon=0.0)

  with tf.Session() as sess:
    dqn.build()
    model_saver = tf.train.Saver(max_to_keep=hparams.max_num_checkpoints)
    model_saver.restore(sess, os.path.join(model_dir, 'ckpt-%i' % ckpt))
    environment.initialize()
    for _ in range(hparams.max_steps_per_episode):
      # Each observation is a fingerprint with the remaining step count appended.
      steps_left = hparams.max_steps_per_episode - environment.num_steps_taken

      # NOTE(review): a new bootstrap head is drawn at every step; bootstrapped
      # DQN usually fixes one head per episode — confirm this is intended.
      if hparams.num_bootstrap_heads:
        head = np.random.randint(hparams.num_bootstrap_heads)
      else:
        head = 0
      valid_actions = list(environment.get_valid_actions())
      observations = np.vstack([
          np.append(deep_q_networks.get_fingerprint(act, hparams), steps_left)
          for act in valid_actions])
      action = valid_actions[dqn.get_action(
          observations, head=head, update_epsilon=0.0)]
      result = environment.step(action)
  return ckpt, result


In [None]:
# Evaluate every starting molecule against each target SAS and collect one row
# per (molecule, target) run.
TARGET_SAS_VALUES = (2.5, 4.8)

all_results = []
# Iterate the list itself rather than a hardcoded count so the loop cannot drift
# out of sync with `all_molecules`.
for i, init_smiles in enumerate(all_molecules):
    # The original molecule's score does not depend on the target; compute once.
    ori_sas = sascorer.calculateScore(Chem.MolFromSmiles(init_smiles))
    for target in TARGET_SAS_VALUES:
        ckpt, result = eval(basepath % (i, target), i)
        sas = sascorer.calculateScore(Chem.MolFromSmiles(result.state))
        all_results.append(
            (i, ckpt, init_smiles, result.state, ori_sas, target, sas))

In [None]:
# Column order mirrors the tuples appended in the evaluation loop.
RESULT_COLUMNS = [
    'index',
    'ckpt',
    'original_molecule',
    'generated_molecule',
    'original_sas',
    'target_sas',
    'sas',
]
df = pd.DataFrame(all_results, columns=RESULT_COLUMNS)
# Persist the raw results so the plots can be regenerated without re-running
# the (slow) evaluation loop.
df.to_csv('target_sas_results.csv')

In [None]:
# Achieved SA score vs. the starting molecule's SA score, one series per target.
# A dashed line in the series colour marks each target value.
fig, ax = plt.subplots()
for target, group in df.groupby('target_sas'):
    points = ax.scatter(group['original_sas'], group['sas'],
                        label='target_sas=%g' % target)
    ax.axhline(target, linestyle='--', linewidth=1,
               color=points.get_facecolor()[0])
ax.set(xlabel='SA score of original molecule',
       ylabel='SA score of generated molecule',
       title='Generated vs. original SA score by target')
ax.legend()
plt.show()