In [89]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
from torch_geometric.nn import global_mean_pool

import gym
from gym import spaces
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem import rdchem
from rdkit.Chem import rdMolDescriptors
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
import deepchem as dc

In [98]:
class MolecularGNN(nn.Module):
    def __init__(self, in_feats, hidden_feats, out_feats, num_tasks):
        super(MolecularGNN, self).__init__()
        self.conv1 = GCNConv(in_feats, hidden_feats)
        self.conv2 = GCNConv(hidden_feats, out_feats)
        self.linear = nn.Linear(out_feats, num_tasks)

    def forward(self, state, action, hidden=None):
        """Applies the GNN to the molecule state and the one-hot encoded action.

        Args:
            state (torch.Tensor): A tensor of shape (batch_size, num_nodes, node_feature_size)
                representing the current state of the molecule.
            action (torch.Tensor): A tensor of shape (batch_size, num_bond_types) representing
                the current one-hot encoded action.
            hidden (torch.Tensor): A tensor of shape (batch_size, gnn_hidden_size) representing the
                hidden state from the previous step.

        Returns:
            torch.Tensor: A tensor of shape (batch_size, gnn_hidden_size) representing the hidden
                state of the GNN after processing the input state and action.
        """
        state = self.node_encoder(state)
        action = self.action_encoder(action)
        input_ = torch.cat([state, action], dim=1)

        if hidden is None:
            hidden = torch.zeros((input_.size(0), self.gnn_hidden_size), device=input_.device)

        hidden = self.gnn(input_, hidden)

        return hidden

In [99]:
# Load Tox21 dataset from MoleculeNet
tasks, datasets, transformers = dc.molnet.load_tox21()
train_dataset, valid_dataset, test_dataset = datasets

In [100]:
def mol_to_pyg_graph(mol, target):
    if mol is None:
        return None

    # Get atom features
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([atom.GetAtomicNum(), atom.GetDegree(), atom.GetTotalValence()])
    atom_features = torch.tensor(atom_features, dtype=torch.float)

    # Get bond features and adjacency info
    bond_indices = []
    bond_features = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_indices.append((start, end))
        bond_indices.append((end, start))
        bond_features.append(bond.GetBondTypeAsDouble())
        bond_features.append(bond.GetBondTypeAsDouble())
    bond_indices = torch.tensor(bond_indices, dtype=torch.long).t().contiguous()
    bond_features = torch.tensor(bond_features, dtype=torch.float).view(-1, 1)

    # Create PyTorch Geometric Data object
    data = Data(x=atom_features, edge_index=bond_indices, edge_attr=bond_features, y=target)
    return data

In [93]:
# Convert the datasets into PyTorch Geometric graph representations with target labels
train_graphs = [mol_to_pyg_graph(Chem.MolFromSmiles(smiles), torch.tensor([label], dtype=torch.float32)) for smiles, label in zip(train_dataset.ids, train_dataset.y)]
valid_graphs = [mol_to_pyg_graph(Chem.MolFromSmiles(smiles), torch.tensor([label], dtype=torch.float32)) for smiles, label in zip(valid_dataset.ids, valid_dataset.y)]
test_graphs = [mol_to_pyg_graph(Chem.MolFromSmiles(smiles), torch.tensor([label], dtype=torch.float32)) for smiles, label in zip(test_dataset.ids, test_dataset.y)]




In [94]:
# Filter out invalid molecules
train_graphs = [graph for graph in train_graphs if graph is not None]
valid_graphs = [graph for graph in valid_graphs if graph is not None]
test_graphs = [graph for graph in test_graphs if graph is not None]

In [101]:
in_feats = 3
hidden_feats = 64
out_feats = 128
num_tasks = 12  # Number of tasks in Tox21 dataset
gnn = MolecularGNN(in_feats, hidden_feats, out_feats, num_tasks)

In [102]:
# Train the GNN model
# Replace train_gnn() with the appropriate training function and dataset
def train_gnn(gnn, train_graphs, epochs=50, batch_size=32, learning_rate=0.01):
    gnn.train()
    optimizer = torch.optim.Adam(gnn.parameters(), lr=learning_rate)
    data_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        for batch in data_loader:
            optimizer.zero_grad()
            output = gnn(batch)
            # Assuming a simple regression task with mean squared error loss
            loss = F.mse_loss(output, batch.y)
            loss.backward()
            optimizer.step()

In [103]:
train_gnn(gnn, train_graphs)



TypeError: MolecularGNN.forward() missing 1 required positional argument: 'action'

In [104]:
import numpy as np
import gym
from gym import spaces
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors

class MoleculeEnvironment(gym.Env):
    def __init__(self, gnn, target_mol_weight):
        super(MoleculeEnvironment, self).__init__()
        
        self.gnn = gnn
        self.target_mol_weight = target_mol_weight

        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)

        self.current_mol = None

    def reset(self):
        initial_mol = Chem.MolFromSmiles("C")
        self.current_mol = initial_mol
        initial_observation = self._get_observation()
        return initial_observation

    def step(self, action):
        # Calculate the new observation, reward, and done status
        self.current_mol = self.apply_action(self.current_mol, action)
        self.current_smiles = Chem.MolToSmiles(self.current_mol)
        mol_vec = self.featurize(self.current_mol)
        action_vec = self.action_featurizer(action)
        obs = self.gnn.forward(mol_vec, action_vec, self.hidden)
        self.hidden = obs
        mw = rdMolDescriptors.CalcExactMolWt(self.current_mol)
        reward = -abs(self.target_mol_weight - mw)
        return obs, reward, False, {}

    def _take_action(self, action):
        # Generate the new molecule based on the selected action
        new_mol = self.gnn(self.current_mol, action)
        if new_mol is not None:
            self.current_mol = new_mol

    def _get_observation(self):
        mol = self.current_mol
        mw = rdMolDescriptors.CalcExactMolWt(mol)
        observation = np.array([mw], dtype=np.float32)
        return observation

    def _get_reward(self):
        mol = self.current_mol
        mw = rdMolDescriptors.CalcExactMolWt(mol)
        reward = -abs(self.target_mol_weight - mw)
        return reward

    def _is_done(self):
        mol = self.current_mol
        mw = rdMolDescriptors.CalcExactMolWt(mol)
        done = np.isclose(self.target_mol_weight, mw, atol=0.1)
        return done

    def render(self, mode="human"):
        if mode == "human":
            mol = self.current_mol
            smiles = Chem.MolToSmiles(mol)
            print(smiles)
        else:
            raise NotImplementedError("Only 'human' mode is supported for rendering.")


In [105]:
# Create the Molecule Environment
target_mol_weight = 10
env = MoleculeEnvironment(gnn, target_mol_weight)
check_env(env)

AttributeError: 'MoleculeEnvironment' object has no attribute 'apply_action'

In [106]:

# Train the RL agent (PPO)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)

# Generate a novel molecule
obs = env.reset()
done = False
while not done:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()

# Print the final generated molecule
print("Generated molecule:", env.render())

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


AttributeError: 'MoleculeEnvironment' object has no attribute 'apply_action'