In [50]:
from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import json
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import PIL.Image
import random
# import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import py_environment
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.environments import wrappers
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import array_spec
from tf_agents.trajectories import trajectory
from tf_agents.trajectories import time_step as ts
from tf_agents.utils import common

from molgym.agents.preprocessing import MorganFingerprints
from molgym.envs.rewards import RewardFunction
from molgym.envs.rewards.multiobjective import AdditiveReward
from molgym.envs.rewards.oneshot import OneShotScore
from molgym.envs.rewards.tuned import LogisticCombination
from molgym.envs.rewards.rdkit import LogP, QEDReward, SAScore, CycleLength
from molgym.envs.rewards.mpnn import MPNNReward
from rdkit.Chem import GetPeriodicTable, MolFromSmiles
from molgym.mpnn.layers import custom_objects
from molgym.utils.conversions import convert_nx_to_smiles, convert_rdkit_to_nx, convert_smiles_to_nx
from tensorflow.keras.models import load_model


tf.compat.v1.enable_v2_behavior()
print(tf.version.VERSION)

num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 1000  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

2.4.0-dev20200802


In [51]:
# from molgym.agents.preprocessing import MorganFingerprints
# from molgym.utils.conversions import convert_rdkit_to_nx
# from molgym.envs.actions.utils import get_valid_actions
# emb_sz = 64 
# processor =  MorganFingerprints(emb_sz)
# from rdkit import Chem
# # m = Chem.MolFromSmiles('CCC(CC)COC(=O)C(C)NP(=O)(OCC1C(C(C(O1)(C#N)C2=CC=C3N2N=CN=C3N)O)O)OC4=CC=CC=C4')
# m = Chem.MolFromSmiles('C#C')

# graph = convert_rdkit_to_nx(m)
# print(list(graph.nodes(data=True)))
# valid_actions = get_valid_actions('C#C', set(['C', 'O', 'N', 'F']), False, False, None, True, None)
# print(np.random.choice(5))
# # embedding = processor.get_features([graph])
# # print(embedding[0])

In [77]:
class MolecularGraphEncoder:
    def __init__(self, opt='mf', emb_sz=64):
        self.opt = opt 
        self.processor =  MorganFingerprints(emb_sz)
        return
    
    def encode(self, smiles_str):
        graph = convert_smiles_to_nx(smiles_str)
        if self.opt == 'mf': # Morgan fingerprint
            embedding = self.processor.get_features([graph])
            return embedding[0]
        else:
            raise NotImplementedError

class MolDesignEnv(py_environment.PyEnvironment):
    def __init__(self):
        self.action_space = ['C', 'O', 'N', 'F'] # or it can be a functional group as well
        self.emb_dims = 128
        self.MAX_ATOM_COUNT = 64
        self.MAX_BOND_COUNT = 100
        self.graph_encoder = MolecularGraphEncoder()
        self.init_reward_func()

        # Action is a 4-tuple (src_node_idx, relation_type, dst_node_idx, dst_node_type)
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=self.MAX_ATOM_COUNT-1, name='action')
#             shape=(1, 4), dtype=np.int32, minimum=0, maximum=self.MAX_ATOM_COUNT-1, name='action')

        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(1,), dtype=np.float32, minimum=0, name='observation')

    #     m = Chem.MolFromSmiles('C')
        self._state = 'C' #convert_rdkit_to_nx(m)
        self._episode_ended = False
    
    def init_reward_func(self):
        # Get the list of elements
        #  We want those where SMILES supports implicit valences
        mpnn_dir = os.path.join('../notebooks', 'mpnn-training')
        with open(os.path.join(mpnn_dir, 'atom_types.json')) as fp:
            atom_types = json.load(fp)
        with open(os.path.join(mpnn_dir, 'bond_types.json')) as fp:
            bond_types = json.load(fp)
        pt = GetPeriodicTable()
        elements = [pt.GetElementSymbol(i) for i in atom_types]
        elements = [e for e in elements if MolFromSmiles(e) is not None]

        # Prepare the one-shot model. We the molecules to compare against and the comparison model
        with open(os.path.join('../seed-molecules', 'top_100_pIC50.json')) as fp:
            comparison_mols = [convert_smiles_to_nx(s) for s in json.load(fp)]
        oneshot_dir = '../similarity'
        oneshot_model = load_model(os.path.join(oneshot_dir, 'oneshot_model.h5'), custom_objects=custom_objects)
        with open(os.path.join(oneshot_dir, 'atom_types.json')) as fp:
            os_atom_types = json.load(fp)
        with open(os.path.join(oneshot_dir, 'bond_types.json')) as fp:
            os_bond_types = json.load(fp)

        # Making all of the reward functions
        # model = load_model(os.path.join(mpnn_dir, 'best_model.h5'), custom_objects=custom_objects)

        rewards = {
                'logP': LogP(maximize=True),
        #         'ic50': MPNNReward(model, atom_types, bond_types, maximize=True),
                'QED': QEDReward(maximize=True),
                'SA': SAScore(maximize=False),
                'cycles': CycleLength(maximize=False),
                'oneshot': OneShotScore(oneshot_model, os_atom_types, os_bond_types, comparison_mols, maximize=True)
            }

        # Load in the ranges for reward functions, used in making multi-objective searches
        with open('reward_ranges.json') as fp:
            ranges = json.load(fp)

        opt_reward = 'QED'
        # Make the reward function
        if opt_reward == 'ic50':
            self.reward_func = rewards['ic50']
        elif opt_reward == 'logP':
            self.reward_func = AdditiveReward([{'reward': rewards[r], **ranges[r]} for r in ['logP', 'SA', 'cycles']])
        elif opt_reward == "QED":
            self.reward_func = AdditiveReward([{'reward': rewards[r], **ranges[r]} for r in ['QED', 'SA', 'cycles']])
        elif opt_reward == "MO":
            self.reward_func = AdditiveReward([{'reward': rewards[r], **ranges[r]} for r in ['ic50', 'QED', 'SA', 'cycles']])
        elif opt_reward == "oneshot":
            self.reward_func = rewards['oneshot']
        elif opt_reward == "tuned":
            self.reward_func = LogisticCombination(rewards['ic50'], rewards['oneshot'])
        else:
            raise ValueError(f'Reward function not defined: {args.reward}')
        return

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def _reset(self):
        self._state = 'C'
        self._episode_ended = False
        return ts.restart(self.graph_encoder.encode(self._state))

    def _step(self, action):

        if self._episode_ended:
            # The last action ended the episode. Ignore the current action and start
            # a new episode.
            return self.reset()

        mol = Chem.MolFromSmiles(self._state)
        graph = convert_rdkit_to_nx(mol)
        
        # This can change if we are dealing with a coarsened graph and nodes 
        # represent functional groups and NOT atoms
        atom_count = graph.number_of_nodes() 
        num_bonds = graph.number_of_edges()
        # Make sure episodes don't go on forever.
        if atom_count == self.MAX_ATOM_COUNT or num_bonds == self.MAX_BOND_COUNT:
            self._episode_ended = True
        else:
            valid_actions = get_valid_actions(self._state, set(['C', 'O', 'N', 'F']), False, False, None, True, None)
            if len(valid_actions) == 0:
                self._episode_ended = True
            else:  
                self._state = random.sample(valid_actions, 1)[0]

        reward = self.reward_func(graph)
        if self._episode_ended:
            return ts.termination(self.graph_encoder.encode(self._state), reward)
        else:
            return ts.transition(self.graph_encoder.encode(self._state), reward=0.0, discount=1.0)


In [78]:
# # Get the list of elements
# #  We want those where SMILES supports implicit valences
# mpnn_dir = os.path.join('../notebooks', 'mpnn-training')
# with open(os.path.join(mpnn_dir, 'atom_types.json')) as fp:
#     atom_types = json.load(fp)
# with open(os.path.join(mpnn_dir, 'bond_types.json')) as fp:
#     bond_types = json.load(fp)
# pt = GetPeriodicTable()
# elements = [pt.GetElementSymbol(i) for i in atom_types]
# elements = [e for e in elements if MolFromSmiles(e) is not None]

# # Prepare the one-shot model. We the molecules to compare against and the comparison model
# with open(os.path.join('../seed-molecules', 'top_100_pIC50.json')) as fp:
#     comparison_mols = [convert_smiles_to_nx(s) for s in json.load(fp)]
# oneshot_dir = '../similarity'
# oneshot_model = load_model(os.path.join(oneshot_dir, 'oneshot_model.h5'), custom_objects=custom_objects)
# with open(os.path.join(oneshot_dir, 'atom_types.json')) as fp:
#     os_atom_types = json.load(fp)
# with open(os.path.join(oneshot_dir, 'bond_types.json')) as fp:
#     os_bond_types = json.load(fp)
        
# # Making all of the reward functions
# # model = load_model(os.path.join(mpnn_dir, 'best_model.h5'), custom_objects=custom_objects)

# rewards = {
#         'logP': LogP(maximize=True),
# #         'ic50': MPNNReward(model, atom_types, bond_types, maximize=True),
#         'QED': QEDReward(maximize=True),
#         'SA': SAScore(maximize=False),
#         'cycles': CycleLength(maximize=False),
#         'oneshot': OneShotScore(oneshot_model, os_atom_types, os_bond_types, comparison_mols, maximize=True)
#     }

# # Load in the ranges for reward functions, used in making multi-objective searches
# with open('reward_ranges.json') as fp:
#     ranges = json.load(fp)

# opt_reward = 'QED'
# # Make the reward function
# if opt_reward == 'ic50':
#     reward = rewards['ic50']
# elif opt_reward == 'logP':
#     reward = AdditiveReward([{'reward': rewards[r], **ranges[r]} for r in ['logP', 'SA', 'cycles']])
# elif opt_reward == "QED":
#     reward = AdditiveReward([{'reward': rewards[r], **ranges[r]} for r in ['QED', 'SA', 'cycles']])
# elif opt_reward == "MO":
#     reward = AdditiveReward([{'reward': rewards[r], **ranges[r]} for r in ['ic50', 'QED', 'SA', 'cycles']])
# elif opt_reward == "oneshot":
#     reward = rewards['oneshot']
# elif opt_reward == "tuned":
#     reward = LogisticCombination(rewards['ic50'], rewards['oneshot'])
# else:
#     raise ValueError(f'Reward function not defined: {args.reward}')
    
# print(reward(graph))

In [79]:
environment = MolDesignEnv()
action = np.array(1, dtype=np.int32)
time_step = environment.reset()
# print(time_step)
num_time_steps = 0
while not time_step.is_last():
    time_step = environment.step(action)
    num_time_steps += 1
print('Number of steps: %d' % num_time_steps)

RDKit ERROR: [04:13:42] SMILES Parse Error: syntax error while parsing: Si
RDKit ERROR: [04:13:42] SMILES Parse Error: Failed parsing SMILES 'Si' for input: 'Si'
RDKit ERROR: [04:13:42] SMILES Parse Error: syntax error while parsing: Mn
RDKit ERROR: [04:13:42] SMILES Parse Error: Failed parsing SMILES 'Mn' for input: 'Mn'
RDKit ERROR: [04:13:42] non-ring atom 1 marked aromatic
RDKit ERROR: [04:13:42] SMILES Parse Error: syntax error while parsing: Cu
RDKit ERROR: [04:13:42] SMILES Parse Error: Failed parsing SMILES 'Cu' for input: 'Cu'


Number of steps: 5


In [80]:
train_env = tf_py_environment.TFPyEnvironment(MolDesignEnv())
eval_env = tf_py_environment.TFPyEnvironment(MolDesignEnv())

print('Observation Spec:')
print(train_env.time_step_spec().observation)
print('Reward Spec:')
print(train_env.time_step_spec().reward)
print('Action Spec:')
print(train_env.action_spec())

fc_layer_params = (100,)

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()

RDKit ERROR: [04:13:43] SMILES Parse Error: syntax error while parsing: Si
RDKit ERROR: [04:13:43] SMILES Parse Error: Failed parsing SMILES 'Si' for input: 'Si'
RDKit ERROR: [04:13:43] SMILES Parse Error: syntax error while parsing: Mn
RDKit ERROR: [04:13:43] SMILES Parse Error: Failed parsing SMILES 'Mn' for input: 'Mn'
RDKit ERROR: [04:13:43] non-ring atom 1 marked aromatic
RDKit ERROR: [04:13:43] SMILES Parse Error: syntax error while parsing: Cu
RDKit ERROR: [04:13:43] SMILES Parse Error: Failed parsing SMILES 'Cu' for input: 'Cu'
RDKit ERROR: [04:13:44] SMILES Parse Error: syntax error while parsing: Si
RDKit ERROR: [04:13:44] SMILES Parse Error: Failed parsing SMILES 'Si' for input: 'Si'
RDKit ERROR: [04:13:44] SMILES Parse Error: syntax error while parsing: Mn
RDKit ERROR: [04:13:44] SMILES Parse Error: Failed parsing SMILES 'Mn' for input: 'Mn'
RDKit ERROR: [04:13:44] non-ring atom 1 marked aromatic
RDKit ERROR: [04:13:44] SMILES Parse Error: syntax error while parsing: Cu
RDK

Observation Spec:
BoundedTensorSpec(shape=(1,), dtype=tf.float32, name='observation', minimum=array(0., dtype=float32), maximum=array(3.4028235e+38, dtype=float32))
Reward Spec:
TensorSpec(shape=(), dtype=tf.float32, name='reward')
Action Spec:
BoundedTensorSpec(shape=(), dtype=tf.int32, name='action', minimum=array(0, dtype=int32), maximum=array(63, dtype=int32))


In [81]:
def compute_avg_return(environment, policy, num_episodes=10):

    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
        episode_return += time_step.reward
        total_return += episode_return

    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]

def collect_step(environment, policy, buffer):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)

    # Add trajectory to the replay buffer
    buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
    for _ in range(steps):
        collect_step(env, policy, buffer)

In [82]:
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec())
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_max_length)

#@test {"skip": true}
try:
    %%time
except:
    pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

    # Collect a few steps using collect_policy and save to the replay buffer.
    for _ in range(collect_steps_per_iteration):
        collect_step(train_env, agent.collect_policy, replay_buffer)
    
    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    step = agent.train_step_counter.numpy()

    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss))

    if step % eval_interval == 0:
        avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
        print('step = {0}: Average Return = {1}'.format(step, avg_return))
        returns.append(avg_return)

ValueError: <tf_agents.networks.q_network.QNetwork object at 0x7f80efe6d990>: Inconsistent dtypes or shapes between `inputs` and `input_tensor_spec`.
dtypes:
<dtype: 'float64'>
vs.
<dtype: 'float32'>.
shapes:
(1, 64)
vs.
(1,).