# MedGAN

Model 2 | Subset ZINC15-II

Based on the principles of Wasserstein Generative Adversarial Networks and Graph Convolutional Networks, MedGAN generates new quinoline-scaffold molecules from molecular graphs.

In [None]:
# import packages

import numpy as np
import os
import pandas as pd
import datetime
import gc
import io
import json
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import networkx as nx
from rdkit.Chem import rdmolops
from rdkit.Chem.rdmolops import AddHs
from rdkit.Chem.rdmolops import GetMolFrags
from rdkit import Chem, RDLogger
from rdkit import DataStructs
from rdkit.Chem import AllChem, Draw, Descriptors
from rdkit.Chem import AtomValenceException
from rdkit.Chem import Descriptors
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import rdmolfiles
from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.DataStructs.cDataStructs import TanimotoSimilarity
import rdkit.RDLogger as rdl
from tensorflow import keras
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.utils import plot_model
import tensorflow as tf
from scipy.spatial.distance import pdist, squareform
import concurrent.futures
from tqdm import tqdm
import base64
from IPython.display import display
from IPython.display import Image
from multiprocessing import Pool, cpu_count
from PIL import Image
import gzip
import pickle
import psutil
import pygraphviz as pgv
import time
# Silence all RDKit log channels ("rdApp.*") and TensorFlow's Python logger
# below ERROR level.
RDLogger.DisableLog("rdApp.*")
tf.get_logger().setLevel('ERROR')
# RDKit's Python-side logger: only CRITICAL messages are emitted.
logger = rdl.logger()
logger.setLevel(rdl.CRITICAL)

## Data Preprocessing: Filtering and Sampling of Quinoline Molecules

The following section is dedicated to loading and processing quinoline molecule data from the ZINC15 dataset. Initially, the script checks whether a pre-filtered dataset (non_duplicate_filtered_quinolines_zinc15_50atoms.csv) exists. If it doesn't, the original dataset (non_duplicate_quinolines_zinc15.csv) is loaded and molecules are filtered based on specific criteria:

Only molecules composed exclusively of Carbon (C), Nitrogen (N), Oxygen (O), Hydrogen (H), Fluorine (F), Sulfur (S), and Chlorine (Cl) atoms are retained.
    
Only molecules with at most 50 atoms, counting hydrogens, are retained.

Post filtering, duplicates are removed, and the processed data is saved for future use. If the filtered dataset already exists, it is directly loaded.

To keep results reproducible and data sizes manageable, a random subsample of 1,000,000 molecules (100,000 in this version; see the note below) is drawn from this dataset with a fixed random seed. This subsample is saved for subsequent analyses. Finally, as a quick check, one molecule is displayed together with its SMILES string and heavy-atom count.

Note: for the Cloud Ocean run, the sample was reduced to 100,000 molecules due to memory restrictions.

In [None]:
# Input/output locations for the quinoline SMILES data:
# the raw export, and a cached copy of the filtered molecules.
csv_path = "../data/non_duplicate_quinolines_zinc15.csv"
filtered_csv_path = "../data/non_duplicate_filtered_quinolines_zinc15_50atoms.csv"

# Filtering criteria used by is_valid_smiles(): an upper bound on the number of
# atoms (hydrogens included) and a whitelist of element symbols.
max_atoms = 50
allowed_atoms = set("C N O H F S Cl".split())

In [None]:
# functions

def has_allowed_atoms(mol, allowed_atoms):
    """Return True when every atom of ``mol`` has a symbol in ``allowed_atoms``.

    Evaluation stops at the first disallowed atom; a molecule without atoms
    passes trivially.
    """
    return all(atom.GetSymbol() in allowed_atoms for atom in mol.GetAtoms())

def total_atoms(mol):
    """Count the atoms of ``mol`` plus the hydrogens RDKit reports on them.

    Every atom contributes 1 plus its ``GetTotalNumHs()`` value.
    """
    count = 0
    for atom in mol.GetAtoms():
        count += 1 + atom.GetTotalNumHs()
    return count

def is_valid_smiles(smiles):
    """Return ``smiles`` if it passes the dataset filter, otherwise ``None``.

    A SMILES string passes when it
      * contains no ``*`` character,
      * can be parsed by RDKit,
      * after adding explicit hydrogens, only contains elements from the
        module-level ``allowed_atoms`` set, and
      * has at most ``max_atoms`` atoms (hydrogens included).
    """
    # Cheap string test first: no need to parse molecules we reject anyway.
    if '*' in smiles:
        return None
    molecule = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None for unparsable input. Check before AddHs,
    # which cannot be called on None.
    if molecule is None:
        return None
    molecule = Chem.AddHs(molecule)
    if has_allowed_atoms(molecule, allowed_atoms) and total_atoms(molecule) <= max_atoms:
        return smiles
    return None

def get_cpu_usage():
    """Return the system-wide CPU utilisation in percent, as reported by psutil.

    ``interval=None`` (psutil's default) makes the call non-blocking; psutil
    compares against the previous call.
    """
    return psutil.cpu_percent(interval=None)

In [None]:
# If the filtered CSV doesn't exist, read the original CSV, apply the filter and
# cache the result; otherwise load the cached filtered data.
if not os.path.isfile(filtered_csv_path):
    # header=None: every row is treated as data; column 0 holds the SMILES.
    # NOTE(review): with header=None the column labels are integers, so this
    # usecols filter keeps every column.
    data = pd.read_csv(csv_path, usecols=lambda col: col != 'index', header=None)

    valid_smiles = []
    for smiles in tqdm(data[0]):
        try:
            result = is_valid_smiles(smiles)
        except Exception as e:
            # Best effort: report and skip entries RDKit cannot handle.
            print("Error processing item:", e)
            continue
        if result is not None:
            valid_smiles.append(result)

    # Remove duplicates while keeping first-seen order. A plain set() has a
    # hash-seed-dependent iteration order for strings, which would make the
    # seeded subsample below differ between interpreter sessions.
    valid_smiles = list(dict.fromkeys(valid_smiles))

    data = pd.DataFrame(valid_smiles, columns=['smiles'])

    # Save the filtered data to a CSV file
    data.to_csv(filtered_csv_path, index=False)
else:
    data = pd.read_csv(filtered_csv_path)

# Reproducible random subsample (fixed seed). Capped at the dataset size because
# DataFrame.sample(replace=False) raises when n exceeds the number of rows.
subsample_size = 100000
data = data.sample(n=min(subsample_size, len(data)), random_state=1, replace=False)
data.reset_index(drop=True, inplace=True)

# Save the subsample to a CSV file (creating the folder if needed)
subsample_path = "../data/data_zinc15_subset-ii/quinolines_filtered_subsample.csv"
os.makedirs(os.path.dirname(subsample_path), exist_ok=True)
data.to_csv(subsample_path, index=False)

In [None]:
# Quick sanity check: show one molecule from the subsample, its SMILES and its
# heavy-atom count.
sample_index = 100
smiles_test = data['smiles'][sample_index]
print("SMILES:", smiles_test)
molecule = Chem.AddHs(Chem.MolFromSmiles(smiles_test))
print("Num heavy atoms:", molecule.GetNumHeavyAtoms())
molecule  # last expression of the cell: rendered by the notebook

## Convert SMILES to Graph Representation

This code provides a suite of tools for representing molecules as graphs and vice versa. The main components are:

- Mappings: Dictionaries (atom_mapping, bond_mapping) translate atom and bond types between their string names and numerical indices.

Conversion functions:

- smiles_to_graph(): Converts SMILES strings into graph matrices.

- graph_to_molecule(): Reconstructs molecules from their graph matrices.

- graph_to_networkx(): Transforms graph matrices into NetworkX graph objects for visualization or algorithmic analysis.

- Visualization: plot_graph() visually represents a molecule using the NetworkX graph, with atoms and bonds color-coded.

- Data Preprocessing: Molecular data is chunked and processed in parallel to convert batches of SMILES strings into graph representations, which are then saved as compressed files for efficiency.

If the molecules have already been converted to graphs and saved, the graphs are loaded from the compressed files instead.

Overall, these tools facilitate the transition between chemical molecular structures and their graph representations, offering a foundation for graph-based molecular analyses or neural networks.

In [None]:
# Element symbols in channel order; index i <-> symbol, stored in both
# directions so the same dict encodes and decodes.
_ATOM_SYMBOLS = ["C", "N", "O", "H", "F", "S", "Cl"]
atom_mapping = {
    key: value
    for index, symbol in enumerate(_ATOM_SYMBOLS)
    for key, value in ((symbol, index), (index, symbol))
}

# Bond-type names in channel order: name -> index, and index -> RDKit BondType.
_BOND_TYPE_NAMES = ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"]
bond_mapping = {
    key: value
    for index, name in enumerate(_BOND_TYPE_NAMES)
    for key, value in ((name, index), (index, getattr(Chem.BondType, name)))
}

NUM_ATOMS = 50    # maximum number of atoms per graph (hydrogens included)
ATOM_DIM = 7 + 1  # 7 element types + 1 "non-atom" padding column
BOND_DIM = 4 + 1  # 4 bond types + 1 "non-bond" channel
LATENT_DIM = 64   # length of the generator's noise vector

def smiles_to_graph(smiles):
    """Convert a SMILES string into dense ``(adjacency, features)`` arrays.

    Hydrogens are added explicitly before encoding.

    Returns:
        adjacency: float32 array of shape (BOND_DIM, NUM_ATOMS, NUM_ATOMS),
            channels-first. Channel ``bond_mapping[name]`` is 1 at (i, j) and
            (j, i) for every bond of that type. The last channel is 1 wherever
            no bond channel is set, including the diagonal and padding.
        features: float32 array of shape (NUM_ATOMS, ATOM_DIM). Row i is the
            one-hot element type (``atom_mapping``) of atom i. Padding rows have
            the last ("non-atom") column set.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``, or the molecule with
            hydrogens has more than NUM_ATOMS atoms.
        KeyError: if an element or bond type is missing from the mappings.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    molecule = Chem.AddHs(molecule)

    num_atoms = molecule.GetNumAtoms()
    if num_atoms > NUM_ATOMS:
        raise ValueError(
            f"Molecule has {num_atoms} atoms (with hydrogens), more than "
            f"NUM_ATOMS={NUM_ATOMS}: {smiles!r}"
        )

    adjacency = np.zeros((BOND_DIM, NUM_ATOMS, NUM_ATOMS), "float32")
    features = np.zeros((NUM_ATOMS, ATOM_DIM), "float32")

    # One-hot element type per atom. Index directly rather than building an
    # identity matrix for every atom.
    for atom in molecule.GetAtoms():
        features[atom.GetIdx(), atom_mapping[atom.GetSymbol()]] = 1

    # Visit each bond once and write it symmetrically.
    for bond in molecule.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_type_idx = bond_mapping[bond.GetBondType().name]
        adjacency[bond_type_idx, [i, j], [j, i]] = 1

    # Where no bond, add 1 to last channel (indicating "non-bond").
    # Notice: channels-first
    adjacency[-1, np.sum(adjacency, axis=0) == 0] = 1

    # Where no atom, add 1 to last column (indicating "non-atom")
    features[np.where(np.sum(features, axis=1) == 0)[0], -1] = 1

    return adjacency, features

def graph_to_molecule(graph):
    """Decode an ``(adjacency, features)`` graph into a sanitized RDKit molecule.

    Args:
        graph: pair ``(adjacency, features)`` of NumPy arrays, laid out like the
            output of ``smiles_to_graph``.
            - adjacency is channels-first (BOND_DIM, N, N), and its last channel
              means "no bond".
            - features is (N, ATOM_DIM), and its last column means "no atom".
            Bonds are read only where an entry equals exactly 1, so the
            adjacency is assumed to be one-hot/binary. The sampler in this file
            applies argmax + one_hot before calling.

    Returns:
        The ``Chem.RWMol`` after ``Chem.SanitizeMol``, or ``None`` if
        sanitization reports a failure or raises.

    NOTE(review): if no atom survives the filtering below, an empty molecule is
    sanitized and presumably returned rather than ``None``. Confirm this is
    intended wherever validity is computed.
    """
    # Unpack graph
    adjacency, features = graph

    # RWMol is a molecule object intended to be edited
    molecule = Chem.RWMol()

    # Remove "no atoms" & atoms with no bonds:
    # - keep atoms whose argmax type is not the padding column, AND
    # - that have at least one entry in a real bond channel. The sum runs over
    #   channels and rows, i.e. per column; this assumes a symmetric adjacency.
    keep_idx = np.where(
        (np.argmax(features, axis=1) != ATOM_DIM - 1)
        & (np.sum(adjacency[:-1], axis=(0, 1)) != 0)
    )[0]
    features = features[keep_idx]
    # Restrict both atom axes of every channel to the kept atoms; the new atom
    # indices are positions within keep_idx.
    adjacency = adjacency[:, keep_idx, :][:, :, keep_idx]

    # Add atoms to molecule (in keep_idx order, so RDKit indices match the
    # filtered adjacency)
    for atom_type_idx in np.argmax(features, axis=1):
        atom = Chem.Atom(atom_mapping[atom_type_idx])
        _ = molecule.AddAtom(atom)

    # Add bonds between atoms in molecule; based on the upper triangles of the [symmetric] adjacency tensor.
    # np.triu on a 3-D array applies to the last two axes, so each channel's
    # (i, j) pair is visited once with i <= j.
    (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)
    for (bond_ij, atom_i, atom_j) in zip(bonds_ij, atoms_i, atoms_j):
        # Skip self-pairs and the "no bond" channel.
        if atom_i == atom_j or bond_ij == BOND_DIM - 1:
            continue
        bond_type = bond_mapping[bond_ij]
        molecule.AddBond(int(atom_i), int(atom_j), bond_type)

    # Sanitize the molecule
    try:
        # Attempt to sanitize the molecule. With catchErrors=True RDKit reports
        # failures through the returned flag instead of raising, so the except
        # clauses below act as a safety net.
        flag = Chem.SanitizeMol(molecule, catchErrors=True)

        # Check the flag returned by the sanitization
        if flag != Chem.SanitizeFlags.SANITIZE_NONE:
            return None

    except AtomValenceException as e:
        # If an AtomValenceException error occurred, print the error and return None
        print(f"AtomValenceException during molecule sanitization: {e}")
        return None

    except Exception as e:
        # If any other unexpected error occurred, print the error and return None
        print(f"Unexpected error during molecule sanitization: {e}")
        return None

    # If sanitization was successful, return the molecule
    return molecule

def graph_to_networkx(adjacency, features):
    """Build an undirected NetworkX graph from one ``(adjacency, features)`` pair.

    Args:
        adjacency: sequence/array of BOND_DIM square matrices (channels-first),
            as NumPy arrays or eager TensorFlow tensors. The last channel
            ("no bond") is ignored.
        features: (NUM_ATOMS, ATOM_DIM) array or eager tensor. Each row's argmax
            is the atom type. Rows whose argmax is the last ("non-atom") column
            are not added as nodes.

    Returns:
        ``nx.Graph`` with these attributes:
        - each atom node has ``atom_type``, its element symbol;
        - each edge has ``bond_type``, the bond-type *name* string ("SINGLE",
          "DOUBLE", "TRIPLE" or "AROMATIC"). These names are what the plotting
          colour maps are keyed on.

    Note:
        An edge is added for every positive entry of a bond channel. An edge may
        therefore reference an index that was not added as an atom node, and
        NetworkX then creates that node without an ``atom_type`` attribute.
    """
    G = nx.Graph()

    # np.asarray accepts both NumPy arrays and eager TF tensors.
    atom_type_indices = np.argmax(np.asarray(features), axis=1)
    for i, atom_type_idx in enumerate(atom_type_indices):
        if atom_type_idx != ATOM_DIM - 1:
            G.add_node(i, atom_type=atom_mapping[atom_type_idx])

    for bond_type, adjacency_matrix in enumerate(adjacency):
        if bond_type == BOND_DIM - 1:
            continue
        # bond_mapping[index] is an RDKit BondType enum. Store its name so the
        # attribute is a plain string the colour dictionaries can match.
        bond_type_str = bond_mapping[bond_type].name
        rows, cols = np.where(np.asarray(adjacency_matrix) > 0)
        for u, v in zip(rows, cols):
            G.add_edge(u, v, bond_type=bond_type_str)

    return G

def plot_graph(G, title=''):
    """Draw a molecule graph next to a table listing its atom types.

    Nodes are coloured by their ``atom_type`` attribute and edges by their
    ``bond_type`` attribute. Unknown or missing node types fall back to purple,
    and unknown edge types fall back to black. A node without an ``atom_type``
    attribute (e.g. one created implicitly by an edge) no longer raises; it is
    listed as "Unknown" in the table.

    Args:
        G: NetworkX graph whose edges carry ``bond_type``, e.g. from
            ``graph_to_networkx``.
        title: title of the graph panel.
    """
    atom_colors = {
        "C": "blue",
        "O": "red",
        "N": "yellow",
        "H": "grey",
        "F": "pink",
        "S": "orange",
        "Cl": "green"
    }

    bond_colors = {
        "SINGLE": "black",
        "DOUBLE": "red",
        "TRIPLE": "blue",
        "AROMATIC": "purple"
    }

    node_colors = [atom_colors.get(G.nodes[node].get("atom_type"), "purple") for node in G.nodes()]
    edge_colors = [bond_colors.get(G.edges[edge].get("bond_type"), "black") for edge in G.edges()]

    pos = nx.kamada_kawai_layout(G)
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    nx.draw(G, pos, with_labels=True, node_color=node_colors, node_size=800, edge_color=edge_colors, ax=ax1)
    nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d["bond_type"] for u, v, d in G.edges(data=True)}, ax=ax1)
    ax1.set_title(title)

    # Legend for atom colours
    atom_legend_handles = [
        mlines.Line2D([], [], color=color, marker="o", linewidth=0, markersize=8, label=atom_type)
        for atom_type, color in atom_colors.items()
    ]
    ax1.legend(handles=atom_legend_handles, loc="lower right")

    # Table of atom number -> atom type in the right-hand panel
    atom_df = pd.DataFrame({
        "Atom Number": list(G.nodes()),
        "Atom Type": [G.nodes[node].get("atom_type", "Unknown") for node in G.nodes()],
    })
    ax2.table(cellText=atom_df.values, colLabels=atom_df.columns, loc='center')
    ax2.axis('off')

    plt.show()

In [None]:
save_folder = "../data/data_zinc15_subset-ii/"
chunk_size = 100000
# Make sure the tensor cache folder exists before any chunk is saved into it.
os.makedirs(save_folder, exist_ok=True)


def process_smiles(smiles):
    """Convert one SMILES string into its ``(adjacency, features)`` tuple.

    Defined at module level so multiprocessing workers can receive it.
    """
    return smiles_to_graph(smiles)


# Ceiling division: a trailing partial chunk is still processed. Floor division
# dropped it, and yielded 0 chunks (then an empty concatenate) when the data
# had fewer rows than chunk_size.
num_chunks = (len(data) + chunk_size - 1) // chunk_size

def process_chunk(i):
    """Return ``(adjacency_tensor, feature_tensor)`` for chunk ``i`` of ``data``.

    Chunk ``i`` covers rows ``[i * chunk_size, (i + 1) * chunk_size)`` of
    ``data['smiles']``. Results are cached in ``save_folder`` as compressed
    ``.npz`` files. If both files exist they are loaded; otherwise the chunk is
    converted in parallel with ``process_smiles`` and then saved.

    NOTE(review): the cache is keyed only by chunk index, so files from a
    previous, different subsample are loaded as-is. Delete them when the data
    changes.
    """
    adjacency_tensor_file = os.path.join(save_folder, f'adjacency_tensor_{i}.npz')
    feature_tensor_file = os.path.join(save_folder, f'feature_tensor_{i}.npz')

    if os.path.isfile(adjacency_tensor_file) and os.path.isfile(feature_tensor_file):
        # np.load on .npz returns an open archive; the context manager closes it
        # once the array has been read into memory.
        with np.load(adjacency_tensor_file) as archive:
            adjacency_tensor = archive['arr_0']
        with np.load(feature_tensor_file) as archive:
            feature_tensor = archive['arr_0']

        print(f"Loaded tensors from chunk {i}")
        return adjacency_tensor, feature_tensor

    # Tensors not cached yet: convert this slice of the data.
    start = i * chunk_size
    end = (i + 1) * chunk_size
    chunk = data['smiles'][start:end]

    # Pool.imap keeps input order; chunksize > 1 reduces per-item IPC overhead.
    with Pool(cpu_count()) as p:
        processed_chunk_data = list(tqdm(
            p.imap(process_smiles, chunk, chunksize=256),
            total=len(chunk),
            desc=f'Processing chunk {i}',
        ))

    # Stack per-molecule arrays into (num_molecules, ...) tensors
    adjacency_tensor, feature_tensor = zip(*processed_chunk_data)
    adjacency_tensor = np.array(adjacency_tensor)
    feature_tensor = np.array(feature_tensor)

    # Save tensors to disk (compressed)
    np.savez_compressed(adjacency_tensor_file, adjacency_tensor)
    np.savez_compressed(feature_tensor_file, feature_tensor)

    print(f"Processed and saved tensors from chunk {i}")
    # Explicitly drop the per-molecule list to free memory early
    del processed_chunk_data

    return adjacency_tensor, feature_tensor

# Load (or build) every chunk, then stack them into dataset-wide arrays.
chunk_results = [
    process_chunk(chunk_index)
    for chunk_index in tqdm(range(num_chunks), desc='Loading data chunks')
]
adjacency_tensors = [adjacency for adjacency, _ in chunk_results]
feature_tensors = [features for _, features in chunk_results]
del chunk_results

adjacency_tensor = np.concatenate(adjacency_tensors)
feature_tensor = np.concatenate(feature_tensors)

print("adjacency_tensor.shape =", adjacency_tensor.shape)
print("feature_tensor.shape =", feature_tensor.shape)

## WGAN R-GCN implementation

This code provides a detailed walkthrough of the implementation of WGAN R-GCN using the Keras library. We break down the components into two main sections: the Graph Generator and the Graph Discriminator.

The Graph Generator function, GraphGenerator, is responsible for creating a graph representation of a molecule given latent space inputs.

The Graph Discriminator function, GraphDiscriminator, assesses the "realness" of a molecule graph. It uses graph convolution layers to process the adjacency and feature matrices, and provides a scalar output representing the authenticity of the input graph.

In [None]:
def GraphGenerator(
    dense_units, dropout_rate, latent_dim, adjacency_shape, feature_shape,
):
    """Build the generator network: latent vector -> ``[adjacency, features]``.

    Args:
        dense_units: widths of the hidden Dense layers. Each is followed by
            LeakyReLU(0.01) and Dropout.
        dropout_rate: dropout rate applied after every hidden layer.
        latent_dim: length of the input noise vector.
        adjacency_shape: output adjacency shape, e.g. (BOND_DIM, NUM_ATOMS, NUM_ATOMS).
        feature_shape: output feature shape, e.g. (NUM_ATOMS, ATOM_DIM).

    Returns:
        ``keras.Model`` named "Generator" with outputs ``[x_adjacency, x_features]``.
        - The adjacency is symmetrised over its last two axes and softmax-normalised
          over bond channels (axis 1).
        - The features are softmax-normalised over atom types (axis 2).
    """
    # The input size comes from the latent_dim argument.
    z = keras.layers.Input(shape=(latent_dim,))
    # Propagate through one or more densely connected layers
    x = z
    for units in dense_units:
        x = keras.layers.Dense(units)(x)
        x = keras.layers.LeakyReLU(alpha=0.01)(x)
        x = keras.layers.Dropout(dropout_rate)(x)

    # Map x to a [continuous] adjacency tensor. Dense needs a plain int unit
    # count, hence int(np.prod(...)).
    x_adjacency = keras.layers.Dense(int(np.prod(adjacency_shape)))(x)
    x_adjacency = keras.layers.Reshape(adjacency_shape)(x_adjacency)
    # Symmetrise over the two atom axes so A[b, i, j] == A[b, j, i]
    x_adjacency = (x_adjacency + tf.transpose(x_adjacency, (0, 1, 3, 2))) / 2
    # Distribution over bond types for every atom pair
    x_adjacency = keras.layers.Softmax(axis=1)(x_adjacency)

    # Map x to a [continuous] feature tensor
    x_features = keras.layers.Dense(int(np.prod(feature_shape)))(x)
    x_features = keras.layers.Reshape(feature_shape)(x_features)
    # Distribution over atom types for every atom slot
    x_features = keras.layers.Softmax(axis=2)(x_features)

    return keras.Model(inputs=z, outputs=[x_adjacency, x_features], name="Generator")

# Generator hyper-parameters
generator_config = dict(
    dense_units=[128, 256, 512],
    dropout_rate=0.50,
    latent_dim=LATENT_DIM,
    adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
    feature_shape=(NUM_ATOMS, ATOM_DIM),
)
generator = GraphGenerator(**generator_config)
generator.summary()

# Persist a text summary and an architecture diagram of the generator.
with open('generator_summary.txt', 'w') as summary_file:
    generator.summary(print_fn=lambda line: summary_file.write(line + '\n'))

plot_model(generator, to_file='generator_model.png', show_shapes=True, show_layer_names=True)

In [None]:
class RelationalGraphConvLayer(keras.layers.Layer):
    """Relational graph convolution over dense, channels-first bond tensors.

    For every bond channel b, the layer aggregates neighbour features
    ``A_b @ H`` and projects them with a channel-specific kernel ``W_b``
    (plus an optional bias). It then sums over channels and applies
    ``activation``.

    Inputs: ``[adjacency, features]``. The indexing in ``build``/``call``
    implies (batch, bond_dim, num_atoms, num_atoms) and
    (batch, num_atoms, atom_dim) (TODO confirm against callers).
    Output: (batch, num_atoms, units).
    """

    def __init__(
        self,
        units=128,
        activation="relu",
        use_bias=False,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.units = units
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)

    def build(self, input_shape):
        """Create one (atom_dim x units) kernel (and optional bias) per bond channel."""
        bond_dim = input_shape[0][1]
        atom_dim = input_shape[1][2]

        self.kernel = self.add_weight(
            shape=(bond_dim, atom_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            trainable=True,
            name="W",
            dtype=tf.float32,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(bond_dim, 1, self.units),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                trainable=True,
                name="b",
                dtype=tf.float32,
            )

        self.built = True

    def call(self, inputs, training=False):
        adjacency, features = inputs
        # Aggregate neighbour features per bond channel: features gets a
        # broadcast channel axis -> (batch, bond_dim, num_atoms, atom_dim)
        x = tf.matmul(adjacency, features[:, None, :, :])
        # Channel-specific linear transformation -> (..., units)
        x = tf.matmul(x, self.kernel)
        if self.use_bias:
            x += self.bias
        # Sum over bond channels
        x_reduced = tf.reduce_sum(x, axis=1)
        # Apply non-linear transformation
        return self.activation(x_reduced)

    def get_config(self):
        """Return constructor arguments so Keras can re-create the layer when saving/loading."""
        config = super().get_config()
        config.update({
            "units": self.units,
            "activation": keras.activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "kernel_initializer": keras.initializers.serialize(self.kernel_initializer),
            "bias_initializer": keras.initializers.serialize(self.bias_initializer),
            "kernel_regularizer": keras.regularizers.serialize(self.kernel_regularizer),
            "bias_regularizer": keras.regularizers.serialize(self.bias_regularizer),
        })
        return config


def GraphDiscriminator(
    gconv_units, dense_units, dropout_rate, adjacency_shape, feature_shape
):
    """Build the critic: ``[adjacency, features]`` -> one unbounded score per graph.

    Args:
        gconv_units: output widths of the stacked RelationalGraphConvLayers.
        dense_units: widths of the Dense layers after pooling. Each is followed
            by LeakyReLU(0.01) and Dropout.
        dropout_rate: dropout rate after every Dense layer.
        adjacency_shape: per-sample adjacency shape, e.g. (BOND_DIM, NUM_ATOMS, NUM_ATOMS).
        feature_shape: per-sample feature shape, e.g. (NUM_ATOMS, ATOM_DIM).

    Returns:
        ``keras.Model`` with inputs ``[adjacency, features]`` and a float32
        output of width 1.
    """
    adjacency = keras.layers.Input(shape=adjacency_shape)
    features = keras.layers.Input(shape=feature_shape)

    # Graph convolutions; every layer sees the same adjacency.
    hidden = features
    for n_units in gconv_units:
        conv = RelationalGraphConvLayer(n_units)
        hidden = conv([adjacency, hidden])

    # Average over the atom axis: one vector per molecule.
    pooled = keras.layers.GlobalAveragePooling1D()(hidden)

    for n_units in dense_units:
        pooled = keras.layers.Dense(n_units)(pooled)
        pooled = keras.layers.LeakyReLU(alpha=0.01)(pooled)
        pooled = keras.layers.Dropout(dropout_rate)(pooled)

    # "Realness" score of each input molecule
    score = keras.layers.Dense(1, dtype="float32")(pooled)
    return keras.Model(inputs=[adjacency, features], outputs=score)

# Critic hyper-parameters (a wider alternative is gconv_units=[512, 512, 512, 512])
discriminator_config = dict(
    gconv_units=[128, 128, 128, 128],
    dense_units=[512, 512],
    dropout_rate=0.50,
    adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
    feature_shape=(NUM_ATOMS, ATOM_DIM),
)
discriminator = GraphDiscriminator(**discriminator_config)
discriminator.summary()

# Persist a text summary and an architecture diagram of the critic.
with open('discriminator_summary.txt', 'w') as summary_file:
    discriminator.summary(print_fn=lambda line: summary_file.write(line + '\n'))

plot_model(discriminator, to_file='discriminator_model.png', show_shapes=True, show_layer_names=True)

## Generative Adversarial Network for Graphs

The GraphWGAN class implements a Generative Adversarial Network (GAN) to generate graphs. This GAN leverages Wasserstein distance with gradient penalty for training stability.

The train_step method trains the GAN for one step. This involves training the discriminator to differentiate between real and generated graphs, and training the generator to produce graphs that the discriminator cannot differentiate from real graphs.

After training, you can visualize the generated graphs using the plot_generated_graph method. It visualizes the graphs with different node and edge colors according to the atom type and bond type, respectively.

In [None]:
class GraphWGAN(keras.Model):
    """Wasserstein GAN with gradient penalty (WGAN-GP) for molecular graphs.

    The generator maps latent vectors to ``[adjacency, features]``. The
    discriminator (critic) scores such a pair. ``train_step`` alternates critic
    updates and generator updates.

    Args:
        generator: Keras model ``z -> [adjacency, features]``.
        discriminator: Keras model ``[adjacency, features] -> score``.
        discriminator_steps: critic updates per ``train_step`` call.
        generator_steps: generator updates per ``train_step`` call.
        gp_weight: weight of the gradient-penalty term in the critic loss.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(
        self,
        generator,
        discriminator,
        discriminator_steps=1,
        generator_steps=1,
        gp_weight=10,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator
        self.discriminator_steps = discriminator_steps
        self.generator_steps = generator_steps
        self.gp_weight = gp_weight
        # Latent size is read from the generator's input layer
        self.latent_dim = self.generator.input_shape[-1]
        self.epoch = 0
        self.num_samples = 1
        self.metric_wgan_gen_loss = keras.metrics.Mean(name="wgan_gen_loss")

    def compile(self, optimizer_generator, optimizer_discriminator, **kwargs):
        """Attach one optimizer per sub-network and create the loss trackers."""
        super().compile(**kwargs)
        self.optimizer_generator = optimizer_generator
        self.optimizer_discriminator = optimizer_discriminator
        self.metric_generator = keras.metrics.Mean(name="loss_gen")
        self.metric_discriminator = keras.metrics.Mean(name="loss_dis")

    def train_step(self, inputs):
        """Run the critic updates, then the generator updates, for one batch.

        Returns:
            The current metric values, plus:
            - ``time``: wall-clock seconds spent in this step.
            - ``cpu_usage``: system CPU percent from ``get_cpu_usage``.
            Both are computed with TF ops (``tf.timestamp``/``tf.py_function``),
            so they are re-evaluated on every step even when Keras traces
            ``train_step`` into a ``tf.function``. Plain Python calls would run
            only once, at trace time.
        """
        start_time = tf.timestamp()

        if isinstance(inputs[0], tuple):
            inputs = inputs[0]

        graph_real = inputs
        self.batch_size = tf.shape(inputs[0])[0]

        # Train the discriminator for one or more steps
        for _ in range(self.discriminator_steps):
            z = tf.random.normal((self.batch_size, self.latent_dim))
            with tf.GradientTape() as tape:
                graph_generated = self.generator(z, training=True)
                loss = self._loss_discriminator(graph_real, graph_generated)
            grads = tape.gradient(loss, self.discriminator.trainable_weights)
            self.optimizer_discriminator.apply_gradients(zip(grads, self.discriminator.trainable_weights))
            self.metric_discriminator.update_state(loss)

        # Train the generator for one or more steps
        for _ in range(self.generator_steps):
            z = tf.random.normal((self.batch_size, self.latent_dim))

            with tf.GradientTape() as tape:
                graph_generated = self.generator(z, training=True)
                loss_wgan_generator = self._loss_generator(graph_generated)
                self.metric_wgan_gen_loss.update_state(loss_wgan_generator)
            grads = tape.gradient(loss_wgan_generator, self.generator.trainable_weights)
            self.optimizer_generator.apply_gradients(zip(grads, self.generator.trainable_weights))
            self.metric_generator.update_state(loss_wgan_generator)

        step_time = tf.timestamp() - start_time
        # psutil is plain Python; py_function runs it at execution time.
        cpu_usage = tf.py_function(get_cpu_usage, inp=[], Tout=tf.float32)

        logs = {m.name: m.result() for m in self.metrics}
        logs['time'] = step_time
        logs['cpu_usage'] = cpu_usage

        return logs

    def _loss_discriminator(self, graph_real, graph_generated):
        """Critic loss: mean(fake scores) - mean(real scores) + gp_weight * penalty."""
        logits_real = self.discriminator(graph_real, training=True)
        logits_generated = self.discriminator(graph_generated, training=True)
        loss = tf.reduce_mean(logits_generated) - tf.reduce_mean(logits_real)
        loss_gp = self._gradient_penalty(graph_real, graph_generated)
        return loss + loss_gp * self.gp_weight

    def _loss_generator(self, graph_generated):
        """Generator loss: negative mean critic score of generated graphs."""
        logits_generated = self.discriminator(graph_generated, training=True)
        return -tf.reduce_mean(logits_generated)

    def _gradient_penalty(self, graph_real, graph_generated):
        """Penalise critic-gradient norms that deviate from 1 on interpolated graphs.

        Norms are taken over the bond-channel axis (axis 1) for the adjacency and
        over the atom-type axis (axis 2) for the features, then averaged.
        """
        # Unpack graphs
        adjacency_real, features_real = graph_real
        adjacency_generated, features_generated = graph_generated

        # Random per-sample interpolation between real and generated graphs
        alpha = tf.random.uniform([self.batch_size])
        alpha = tf.reshape(alpha, (self.batch_size, 1, 1, 1))
        adjacency_interp = (adjacency_real * alpha) + (1 - alpha) * adjacency_generated
        alpha = tf.reshape(alpha, (self.batch_size, 1, 1))
        features_interp = (features_real * alpha) + (1 - alpha) * features_generated

        # Compute the logits of interpolated graphs
        with tf.GradientTape() as tape:
            tape.watch(adjacency_interp)
            tape.watch(features_interp)
            logits = self.discriminator(
                [adjacency_interp, features_interp], training=True
            )

        # Gradients with respect to the interpolated inputs
        grads = tape.gradient(logits, [adjacency_interp, features_interp])
        grads_adjacency_penalty = (1 - tf.norm(grads[0], axis=1)) ** 2
        grads_features_penalty = (1 - tf.norm(grads[1], axis=2)) ** 2
        return tf.reduce_mean(
            tf.reduce_mean(grads_adjacency_penalty, axis=(-2, -1))
            + tf.reduce_mean(grads_features_penalty, axis=(-1))
        )

    def save_model(self, folder_path="training_models_model2_zinc15_ii/WGAN"):
        """Save generator and discriminator under ``folder_path``."""
        os.makedirs(folder_path, exist_ok=True)
        self.generator.save(os.path.join(folder_path, "generator"))
        self.discriminator.save(os.path.join(folder_path, "discriminator"))

    def load_model(self, folder_path="training_models_model2_zinc15_ii/WGAN"):
        """Replace generator and discriminator with models loaded from ``folder_path``."""
        self.generator = keras.models.load_model(os.path.join(folder_path, "generator"))
        self.discriminator = keras.models.load_model(os.path.join(folder_path, "discriminator"))

    def plot_generated_graph(self, num_samples, epoch):
        """Sample ``num_samples`` graphs, draw them, and save one PNG per sample.

        Files are written to ``output_graphs_model2_zinc15_ii/`` as
        ``generated_graph_epoch_{epoch}_sample_{k}.png``.
        """
        z = tf.random.normal((num_samples, self.latent_dim))
        graph = self.generator.predict(z)

        # Discretise: argmax over bond channels / atom types, back to one-hot;
        # zero the diagonal of every channel (no self-pairs).
        adjacency = tf.argmax(graph[0], axis=1)
        adjacency = tf.one_hot(adjacency, depth=BOND_DIM, axis=1)
        adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
        features = tf.argmax(graph[1], axis=2)
        features = tf.one_hot(features, depth=ATOM_DIM, axis=2)

        networkx_graphs = [
            graph_to_networkx(adjacency[i].numpy(), features[i].numpy())
            for i in range(num_samples)
        ]

        output_dir = "output_graphs_model2_zinc15_ii"
        os.makedirs(output_dir, exist_ok=True)

        atom_colors = {
            "C": "blue",
            "O": "red",
            "N": "yellow",
            "H": "grey",
            "F": "pink",
            "S": "orange",
            "Cl": "green"
        }

        bond_colors = {
            "SINGLE": "black",
            "DOUBLE": "red",
            "TRIPLE": "blue",
            "AROMATIC": "purple"
        }

        atom_legend_handles = [
            mlines.Line2D([], [], color=color, marker="o", linewidth=1, markersize=8, label=atom_type)
            for atom_type, color in atom_colors.items()
        ]

        for i, G in enumerate(networkx_graphs):
            # Nodes created only by edges have no atom_type -> purple fallback
            node_colors = [atom_colors.get(G.nodes[node].get("atom_type", None), "purple") for node in G.nodes()]
            edge_colors = [bond_colors.get(G.edges[edge].get("bond_type", None), "black") for edge in G.edges()]

            pos = nx.kamada_kawai_layout(G)
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))
            nx.draw(G, pos, with_labels=True, node_color=node_colors, node_size=300, edge_color=edge_colors, ax=ax1)
            nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d["bond_type"] for u, v, d in G.edges(data=True)}, ax=ax1)
            ax1.set_title(f"Generated Graph {i + 1}")
            ax1.legend(handles=atom_legend_handles, loc="lower right")
            ax2.axis('off')

            # Save before showing: after plt.show() some backends have already
            # released the figure, and savefig would write a blank image.
            filename = f"generated_graph_epoch_{epoch}_sample_{i + 1}.png"
            fig.savefig(os.path.join(output_dir, filename))
            plt.show()
            # Free the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

## Training metrics

This code logs to TensorBoard various metrics on the progress and performance of the Generative Adversarial Network (GAN) that generates molecular structures. It computes metrics such as validity, uniqueness, novelty, presence of the quinoline scaffold, and scaffold similarity, as well as training diagnostics for weights and layers.

In [None]:
# Load the complete filtered quinoline set (not just the training subsample)
# as a set of SMILES strings.
quinolines_df = pd.read_csv("../data/non_duplicate_filtered_quinolines_zinc15_50atoms.csv")
smiles_list = quinolines_df['smiles'].tolist()
full_quinolines_smiles = set(smiles_list)

In [None]:
class GANLogger(tf.keras.callbacks.Callback):
    """Log GAN training statistics and generated-molecule metrics to TensorBoard.

    At the end of every epoch the callback writes the generator/discriminator
    losses (plus epoch time and CPU usage when they are present in ``logs``)
    and a histogram of every trainable variable. Every ``log_every`` epochs it
    also samples ``sample_size`` graphs from ``self.model.generator``, decodes
    them with ``graph_to_molecule`` and logs:

    * ``validity`` / ``connected_validity``: share of samples that decode to a
      non-empty molecule / to a single-fragment molecule;
    * ``uniqueness``: distinct canonical SMILES among the valid molecules;
    * ``novelty`` / ``absolute_novelty``: share of unique SMILES absent from
      ``original_dataset`` / ``full_dataset``;
    * ``quinoline_percentage``: share of valid molecules containing quinoline;
    * ``average_tanimoto_similarity``: mean Morgan (radius 2) Tanimoto
      similarity between each molecule's Bemis-Murcko scaffold and quinoline;
    * ``none_molecule_percentage``: share of samples that failed to decode.

    ``quinoline_percentage`` is also written back into ``logs``. Valid
    molecules are saved as PNG + MOL files under ``output_molecules/epoch_<n>``.

    NOTE(review): novelty compares non-isomeric canonical SMILES against the
    given SMILES sets as-is; this assumes those sets use the same canonical
    form. Confirm against how the datasets were written.

    Args:
        tensorboard_logdir: Directory for the TensorBoard event files.
        original_dataset: Iterable of SMILES used for ``novelty``.
        full_dataset: Iterable of SMILES used for ``absolute_novelty``.
        num_samples: Number of pre-rendered graph images to log per epoch.
        log_every: Interval, in epochs, between sample-metric computations
            (default 1 = every epoch).
        sample_size: Number of molecules sampled per evaluation (default 100).

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """

    # SMILES of the quinoline ring system used for substructure/similarity checks.
    _QUINOLINE_SMILES = "n1cccc2ccccc12"
    # Morgan fingerprint radius used for the scaffold similarity.
    _MORGAN_RADIUS = 2

    def __init__(self, tensorboard_logdir, original_dataset, full_dataset, num_samples,
                 log_every=1, sample_size=100):
        super().__init__()
        if log_every < 1:
            raise ValueError("log_every must be >= 1")
        self.tensorboard_logdir = tensorboard_logdir
        self.writer = tf.summary.create_file_writer(tensorboard_logdir)
        self.original_smiles = set(original_dataset)
        self.full_quinolines_smiles = set(full_dataset)
        self.num_samples = num_samples
        self.log_every = log_every
        self.sample_size = sample_size
        self.quinoline_scaffold = Chem.MolFromSmiles(self._QUINOLINE_SMILES)
        # The reference fingerprint never changes: compute it once, not per molecule.
        self._quinoline_fp = AllChem.GetMorganFingerprint(self.quinoline_scaffold, self._MORGAN_RADIUS)

    def mol_sample(self, generator, batch_size):
        """Sample ``batch_size`` graphs from ``generator`` and decode them to molecules.

        Returns:
            ``(molecules, none_counter)``: a list of length ``batch_size`` holding
            ``None`` for every graph that could not be decoded (either because
            ``graph_to_molecule`` raised ``AtomValenceException`` or returned
            ``None``), and the number of such ``None`` entries.
        """
        z = tf.random.normal((batch_size, LATENT_DIM))
        graph = generator.predict(z)
        # Axis 1 of graph[0] is the bond-type channel: keep only the most likely
        # bond type per atom pair, re-encoded one-hot.
        adjacency = tf.argmax(graph[0], axis=1)
        adjacency = tf.one_hot(adjacency, depth=BOND_DIM, axis=1)
        # Zero the diagonal of the innermost matrices to drop self-bonds.
        adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
        # Axis 2 of graph[1] is the atom-type channel: one-hot the argmax.
        features = tf.argmax(graph[1], axis=2)
        features = tf.one_hot(features, depth=ATOM_DIM, axis=2)

        molecules = []
        none_counter = 0
        for i in tqdm(range(batch_size), desc="Generating molecules"):
            try:
                mol = graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
            except AtomValenceException:
                mol = None
            # Count decoding failures whether they raised or came back as None.
            if mol is None:
                none_counter += 1
            molecules.append(mol)
        return molecules, none_counter

    def on_epoch_begin(self, epoch, logs=None):
        """Start graph/profiler tracing at epoch 0; it is exported in on_epoch_end."""
        if epoch == 0:  # Just trace the graph once, at the beginning
            tf.summary.trace_on(graph=True, profiler=True)

    def on_epoch_end(self, epoch, logs=None):
        """Write training scalars and, every ``log_every`` epochs, sample metrics."""
        # Keep the caller's dict (even when empty) so the quinoline_percentage
        # written at the end is visible to Keras and later callbacks.
        if logs is None:
            logs = {}
        self._log_training_scalars(epoch, logs)

        if epoch == 0:  # Export the trace started in on_epoch_begin(0).
            with self.writer.as_default():
                tf.summary.trace_export(name="model_trace", step=epoch, profiler_outdir=self.tensorboard_logdir)

        if epoch % self.log_every != 0:
            return

        self._log_sample_images(epoch)

        molecules, none_counter = self.mol_sample(self.model.generator, batch_size=self.sample_size)
        valid_molecules = [m for m in molecules if m is not None and m.GetNumAtoms() > 0]
        metrics = self._compute_metrics(molecules, valid_molecules, none_counter)

        if valid_molecules:
            self._save_molecules(valid_molecules, epoch)

        with self.writer.as_default():
            for name, value in metrics.items():
                tf.summary.scalar(name, value, step=epoch)
            self.writer.flush()

        logs["quinoline_percentage"] = metrics["quinoline_percentage"]

    def _log_training_scalars(self, epoch, logs):
        """Write losses, epoch time, CPU usage and weight histograms for ``epoch``."""
        with self.writer.as_default():
            for tag, key in (
                ("generator_loss", "loss_gen"),
                ("discriminator_loss", "loss_dis"),
                ("time_per_epoch", "time"),
                ("cpu_usage", "cpu_usage"),
            ):
                value = logs.get(key)
                # Skip values the train step did not report; tf.summary.scalar
                # cannot convert None to a tensor.
                if value is not None:
                    tf.summary.scalar(tag, value, step=epoch)

            # Histograms of trainable variables
            for var in self.model.trainable_variables:
                tf.summary.histogram(var.name, var, step=epoch)

            self.writer.flush()

    def _log_sample_images(self, epoch):
        """Log pre-rendered graph images for ``epoch`` that exist on disk."""
        for i in range(self.num_samples):
            img_path = f"output_molecules_model2_zinc15_ii/molecule_epoch_{epoch}_image_{i}.png"
            if not os.path.exists(img_path):
                continue
            img = tf.keras.preprocessing.image.load_img(img_path)
            img_array = tf.keras.preprocessing.image.img_to_array(img)
            # Add a batch dimension and scale pixel values to [0, 1].
            img_array = np.expand_dims(img_array, axis=0) / 255.0
            with self.writer.as_default():
                tf.summary.image(f"generated_graph_{i}", img_array, step=epoch)

    def _average_scaffold_similarity(self, molecules):
        """Mean Tanimoto similarity of each molecule's Murcko scaffold to quinoline (0 if none)."""
        similarities = []
        for mol in molecules:
            try:
                scaffold = MurckoScaffold.GetScaffoldForMol(mol)
                scaffold_fp = AllChem.GetMorganFingerprint(scaffold, self._MORGAN_RADIUS)
            except Chem.rdchem.AtomValenceException:
                print("Invalid molecule skipped due to AtomValenceException")
                continue
            similarities.append(DataStructs.TanimotoSimilarity(self._quinoline_fp, scaffold_fp))
        return sum(similarities) / len(similarities) if similarities else 0

    def _compute_metrics(self, molecules, valid_molecules, none_counter):
        """Return the sample-quality metrics, keyed by their TensorBoard tag."""
        total = len(molecules)
        none_percentage = none_counter / total if total else 0
        if not valid_molecules:
            return {
                "validity": 0,
                "connected_validity": 0,
                "uniqueness": 0,
                "novelty": 0,
                "absolute_novelty": 0,
                "quinoline_percentage": 0,
                "average_tanimoto_similarity": 0,
                "none_molecule_percentage": none_percentage,
            }

        n_valid = len(valid_molecules)
        n_connected = sum(len(GetMolFrags(m)) == 1 for m in valid_molecules)
        # Deduplicate by canonical SMILES: RDKit Mol objects hash by identity,
        # so a set of Mol objects never merges chemically identical molecules.
        unique_smiles = {
            Chem.MolToSmiles(m, isomericSmiles=False, allBondsExplicit=False)
            for m in valid_molecules
        }
        n_unique = len(unique_smiles)
        n_novel = sum(s not in self.original_smiles for s in unique_smiles)
        n_absolute_novel = sum(s not in self.full_quinolines_smiles for s in unique_smiles)
        n_quinoline = sum(m.HasSubstructMatch(self.quinoline_scaffold) for m in valid_molecules)

        return {
            "validity": n_valid / total,
            "connected_validity": n_connected / total,
            "uniqueness": n_unique / n_valid,
            "novelty": n_novel / n_unique,
            "absolute_novelty": n_absolute_novel / n_unique,
            "quinoline_percentage": n_quinoline / n_valid,
            "average_tanimoto_similarity": self._average_scaffold_similarity(valid_molecules),
            "none_molecule_percentage": none_percentage,
        }

    def _save_molecules(self, molecules, epoch):
        """Save a 300x300 PNG and a MOL file for every molecule under output_molecules/epoch_<epoch>."""
        epoch_dir = os.path.join("output_molecules", f"epoch_{epoch}")
        os.makedirs(epoch_dir, exist_ok=True)
        for i, mol in enumerate(molecules):
            img_path = os.path.join(epoch_dir, f"molecule_image_{i}.png")
            Draw.MolToFile(mol, img_path, size=(300, 300))
            mol_path = os.path.join(epoch_dir, f"molecule_{i}.mol")
            rdmolfiles.MolToMolFile(mol, mol_path)


## Train the model

In the following code, the WGAN is constructed and trained. The PlotSamplesCallback class is a custom Keras callback that plots generated samples at the end of each training epoch, which helps monitor how the generated structures evolve during training. So that training can resume after an interruption, a checkpoint is saved at the end of every epoch and the latest checkpoint is restored before training starts. The data_generator function yields training batches from the adjacency and feature tensors, cycling over the dataset indefinitely so that Keras can draw batches for every epoch. The GAN is then trained with the fit method, using the checkpoint, sample-plotting, and logging callbacks. After training completes, the model is saved so it can be reused as a pre-trained model for future tasks.

In [None]:
class PlotSamplesCallback(keras.callbacks.Callback):
    """Keras callback that plots generated graphs at the end of selected epochs.

    Args:
        wgan: Object exposing ``plot_generated_graph(num_samples, epoch)``.
        num_samples: Number of graphs plotted per call.
        plot_every: Plot on every epoch that is a multiple of ``plot_every``
            (default 1 = every epoch).

    Raises:
        ValueError: if ``plot_every`` is smaller than 1.
    """

    def __init__(self, wgan, num_samples=1, plot_every=1):
        super().__init__()
        if plot_every < 1:
            raise ValueError("plot_every must be >= 1")
        self.wgan = wgan
        self.num_samples = num_samples
        self.plot_every = plot_every

    def on_epoch_end(self, epoch, logs=None):
        """Plot ``num_samples`` generated graphs if ``epoch`` is on the plotting interval."""
        if epoch % self.plot_every == 0:
            print("\nGenerating and plotting samples at epoch", epoch)
            self.wgan.plot_generated_graph(self.num_samples, epoch)

In [None]:
# Molecule parsed from an empty SMILES string (Chem.MolFromSmiles is the same
# function as the directly imported MolFromSmiles).
# NOTE(review): RDKit typically returns an empty Mol (0 atoms) here rather
# than None — confirm; the name is not used elsewhere in this part of the file.
invalid_mol = Chem.MolFromSmiles("")

In [None]:
# TensorBoard log directory for this model/subset; passed to GANLogger below.
tensorboard_logdir = "logs_model2_zinc15_ii/WGAN"

In [None]:
# Create the GraphWGAN instance and configure it
# (one discriminator update per generator update, RMSprop with lr 1e-4 for both).
wgan = GraphWGAN(generator, discriminator, discriminator_steps=1)

wgan.compile(
    optimizer_generator=keras.optimizers.RMSprop(1e-4),
    optimizer_discriminator=keras.optimizers.RMSprop(1e-4)
)

# Checkpoint configurations
checkpoint_dir = 'training_checkpoints_model2_zinc15_ii/WGAN'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(checkpoint_dir, exist_ok=True)

# Track the epoch counter together with both networks and their optimizers so
# that a restored checkpoint resumes training with the full optimizer state.
checkpoint = tf.train.Checkpoint(
    epoch=tf.Variable(0),
    generator=wgan.generator,
    discriminator=wgan.discriminator,
    optimizer_generator=wgan.optimizer_generator,
    optimizer_discriminator=wgan.optimizer_discriminator,
)

def load_latest_checkpoint(checkpoint, directory=None):
    """Restore ``checkpoint`` from the newest checkpoint found in ``directory``.

    Args:
        checkpoint: ``tf.train.Checkpoint`` that tracks an ``epoch`` variable.
        directory: Directory to search. Defaults to the module-level
            ``checkpoint_dir`` (looked up at call time).

    Returns:
        The epoch stored in the restored checkpoint, or 0 when no checkpoint
        exists (training starts from scratch).
    """
    if directory is None:
        directory = checkpoint_dir
    latest_checkpoint = tf.train.latest_checkpoint(directory)
    if not latest_checkpoint:
        print("No checkpoint found. Training from scratch.")
        return 0
    print(f"Loading checkpoint from {latest_checkpoint}")
    checkpoint.restore(latest_checkpoint)
    return int(checkpoint.epoch.numpy())

# Load the latest checkpoint and get the starting epoch
# (0 when no checkpoint exists, i.e. training from scratch).
starting_epoch = load_latest_checkpoint(checkpoint)

# Create a PlotSamplesCallback instance that plots one generated graph per call
plot_samples_callback = PlotSamplesCallback(wgan, num_samples=1)

class CustomModelCheckpoint(keras.callbacks.Callback):
    """Save the module-level ``checkpoint`` to ``checkpoint_prefix`` after every epoch.

    The stored epoch is ``epoch + 1``, i.e. the next epoch to run, so that
    ``load_latest_checkpoint`` can hand it to ``fit(initial_epoch=...)``.
    """

    def on_epoch_end(self, epoch, logs=None):
        next_epoch = epoch + 1
        checkpoint.epoch.assign(next_epoch)
        checkpoint.save(file_prefix=checkpoint_prefix)
        print(f"Saving checkpoint for epoch {next_epoch}")

# Callback that stores a checkpoint at the end of every epoch.
checkpoint_callback = CustomModelCheckpoint()

# Instantiate the GANLogger. Generated SMILES are compared with data['smiles']
# for "novelty" and with full_quinolines_smiles for "absolute novelty"; one
# pre-rendered graph image per epoch is logged.
# NOTE(review): `data` is presumably the training subset loaded earlier in the
# notebook — confirm.
gan_logger = GANLogger(
    tensorboard_logdir, 
    data['smiles'].tolist(), 
    full_quinolines_smiles, 
    num_samples=1)

In [None]:
# Show which GPUs TensorFlow can see before starting training.
print(tf.config.list_physical_devices(device_type='GPU'))

In [None]:
def data_generator(adjacency_tensor, feature_tensor, batch_size, shuffle=False, seed=None):
    """Yield ``[adjacency_batch, feature_batch]`` forever, cycling over the dataset.

    Both tensors are indexed along their first axis with the same index array,
    so row ``k`` of the adjacency batch corresponds to row ``k`` of the feature
    batch. The final batch of each pass may be smaller than ``batch_size``.

    Args:
        adjacency_tensor: Array supporting integer-array indexing on axis 0.
        feature_tensor: Array indexed the same way as ``adjacency_tensor``.
        batch_size: Maximum number of samples per batch (must be positive).
        shuffle: If True, reshuffle the sample order at the start of every
            pass. Defaults to False, which keeps the sequential order.
        seed: Optional seed for the shuffling RNG (only used when shuffling).

    Raises:
        ValueError: on the first ``next()``, if ``batch_size`` is not positive
            or the dataset is empty. An empty dataset would otherwise make the
            generator loop forever without yielding anything.
    """
    dataset_size = len(adjacency_tensor)
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    if dataset_size == 0:
        raise ValueError("cannot draw batches from an empty dataset")

    indices = np.arange(dataset_size)
    rng = np.random.default_rng(seed) if shuffle else None
    while True:
        if rng is not None:
            # Reshuffle at the start of every pass over the data.
            rng.shuffle(indices)
        for start in range(0, dataset_size, batch_size):
            # Slicing clamps at the end, so the last batch may be partial.
            batch_indices = indices[start:start + batch_size]
            yield [adjacency_tensor[batch_indices], feature_tensor[batch_indices]]

batch_size = 128  # Set batch size
data_gen = data_generator(adjacency_tensor, feature_tensor, batch_size)

# Batches per epoch = ceil(len / batch_size), so the final partial batch is
# included and each epoch consumes exactly one pass of data_generator.
steps_per_epoch = len(adjacency_tensor) // batch_size
if len(adjacency_tensor) % batch_size != 0:
    steps_per_epoch += 1

# Train the model, resuming at starting_epoch (restored from the latest
# checkpoint) and running up to epoch 300.
wgan.fit(
    data_gen,
    initial_epoch=starting_epoch,
    epochs=300,
    steps_per_epoch=steps_per_epoch,
    callbacks=[
        checkpoint_callback, 
        plot_samples_callback, 
        gan_logger, 
        ],
)

# Save the trained model
wgan.save_model()