# Small molecule generation

Please make sure you've added your ANTHROPIC_API_KEY and MOLMIM_API_KEY to a .env file in the root of the project.

In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import Descriptors

from helpers.utils import (
    PALETTE,
    calculate_tanimoto_similarities,
    canonicalize_and_validate_smiles,
    compute_novelty_rates,
    make_umap,
    plot_boxenplot,
    plot_umap,
    process_rare_rings,
)

# Global configuration shared by the cells below: RNG seed, display order of
# the models, input CSVs of the original datasets, and output directories.
SEED = 42
ORDER = [
    "original",
    "molmim",
    "claude",
    "claude_scaffold",
    "reinvent",
    "crem",
]
DATASET_FILE = dict(
    a2a="data/adenosineA2A.csv",
    aryl="data/Aryl piperazine.csv",
    sirt2="data/SIRT2.csv",
)

# Output locations, created up front (this also creates "results/") so later
# cells can write CSVs and figures without checking for the directories.
FIGURES = os.path.join("results", "figures")
UMAP_RESULTS = os.path.join("results", "umap")
os.makedirs(UMAP_RESULTS, exist_ok=True)
os.makedirs(FIGURES, exist_ok=True)

# Seed NumPy's global RNG for reproducibility.
np.random.seed(SEED)

### Post processing and data cleaning
Combine all generated outputs and the original datasets into one DataFrame, then drop any SMILES strings that fail canonicalization/validation.

In [None]:
# Merge all generated outputs and the original datasets into a single DataFrame.
# - sorted() makes row order independent of the filesystem's listing order, so
#   later .sample(..., random_state=SEED) calls are reproducible.
# - Frames are collected and concatenated once; starting the list with the
#   generated outputs also avoids pd.concat([]) failing when output/ is empty.
# - ignore_index=True gives every row a unique label: each CSV's index starts at
#   0, and later cells assign columns back into `df` by label alignment, which
#   fails on duplicate labels.
frames = [pd.read_csv(f) for f in sorted(glob.glob("output/*.csv"))]
for dataset, file in DATASET_FILE.items():
    frames.append(
        pd.read_csv(file, usecols=["Smiles"]).assign(Dataset=dataset, Model="original")
    )
df = pd.concat(frames, ignore_index=True)

# Drop rows whose SMILES could not be canonicalized/validated (helper returns a null value for those).
df["canonical smiles"] = df["Smiles"].apply(canonicalize_and_validate_smiles)
# reset_index keeps the in-memory frame indexed 0..n-1, matching what a later pd.read_csv reload yields.
df = df.dropna(subset=["canonical smiles"]).reset_index(drop=True)

### Visual inspection of generated molecules
Draw a grid of five randomly sampled molecules per model from the aryl piperazine (`aryl`) dataset.

In [None]:
# `Draw` is a submodule; `from rdkit import Chem` alone does not guarantee it is loaded.
from rdkit.Chem import Draw

subset = df[df["Dataset"] == "aryl"]
models = ["original", "crem", "molmim", "claude", "claude_scaffold", "reinvent"]
n_samples = 5

samples = {model: subset[subset["Model"] == model].sample(n_samples, random_state=SEED)["Smiles"].values for model in models}
# zip(*...) interleaves the samples row by row, so each grid row holds one molecule per model.
molecules = [Chem.MolFromSmiles(smi) for row in zip(*samples.values()) for smi in row]
# "claude_scaffold" -> "Claude scaffold"; names without "_" are just capitalized.
legends = [model.replace("_", " ").capitalize() for model in models] * n_samples

# Display the grid (one column per model); this does not write a file.
Draw.MolsToGridImage(molecules, molsPerRow=len(models), subImgSize=(300, 300), legends=legends)

### Calculate Tanimoto Similarities
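For every molecule, compute its maximum Tanimoto similarity to the original molecules of the same dataset, together with its exact molecular weight and QED. The results are cached in `results/combined_with_tanimoto_scores.csv` and reloaded on subsequent runs.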

In [None]:
# Calculate MW, QED and Tanimoto similarities (~5 minutes); results are cached to CSV
# and reloaded on later runs.
tanimoto_path = os.path.join("results", "combined_with_tanimoto_scores.csv")
if not os.path.exists(tanimoto_path):
    # Parse each canonical SMILES once and reuse the Mol for both descriptors.
    mols = df["canonical smiles"].apply(Chem.MolFromSmiles)
    df["molwt"] = mols.apply(Descriptors.ExactMolWt)
    df["qed"] = mols.apply(Descriptors.qed)
    df["Tanimoto Score"] = np.nan
    for dataset in DATASET_FILE:
        in_dataset = df["Dataset"] == dataset
        original_smiles = df.loc[in_dataset & (df["Model"] == "original"), "canonical smiles"].values
        # Score = highest similarity to any original molecule of the same dataset
        # (original molecules therefore score 1.0 against themselves).
        scores = df.loc[in_dataset, "canonical smiles"].apply(
            lambda smiles: np.max(calculate_tanimoto_similarities(smiles, original_smiles))
        )
        # Assign positionally (.to_numpy()) so the write does not depend on index labels
        # being unique; label alignment raises on duplicate labels.
        df.loc[in_dataset, "Tanimoto Score"] = scores.to_numpy()
    df.to_csv(tanimoto_path, index=False)
else:
    df = pd.read_csv(tanimoto_path)

### Plot Tanimoto Similarity, Molecular Weight, and QED
Generate boxen plots for Tanimoto similarity scores, molecular weight (MW), and quantitative estimate of drug-likeness (QED) across different models.

In [None]:
# One boxen plot per metric (x = dataset, hue = model, in ORDER), each saved under FIGURES
# so all figures follow the same output-directory constant.
for column, title, slug in [
    ("Tanimoto Score", "Tanimoto Similarity Score", "tanimoto"),
    ("molwt", "MW", "mw"),
    ("qed", "QED", "qed"),
]:
    plot_boxenplot(df, x="Dataset", y=column, hue="Model", title=title, order=ORDER)
    plt.savefig(os.path.join(FIGURES, f"boxen_plot_{slug}.png"))

### Generate UMAP plots
Compute the UMAP results for each dataset (`n_neighbors=100`) and store them in `results/umap`; they are plotted in the next section.

In [None]:
# Run make_umap once per original dataset, on the raw SMILES labelled by model;
# UMAP_RESULTS is passed as its results directory.
for dataset_name in DATASET_FILE:
    rows = df.loc[df["Dataset"] == dataset_name]
    make_umap(
        rows["Smiles"],
        rows["Model"],
        n_neighbors=100,
        dataset=dataset_name,
        results_dir=UMAP_RESULTS,
    )

### Plot UMAP results
Visualize the UMAP results for the 'aryl' dataset.

In [None]:
# Plot the UMAP results for the aryl dataset from UMAP_RESULTS and save the figure as a PNG.
plot_umap("aryl", results_dir=UMAP_RESULTS)
umap_png = os.path.join(FIGURES, "umap_aryl.png")
plt.savefig(umap_png)

### Compute and plot novelty rates
Calculate and visualize the novelty rates for the generated molecules.

In [None]:
# Novelty rates are cached in results/novelty_rates.csv; recompute only when the cache is missing.
path = os.path.join("results", "novelty_rates.csv")
if not os.path.exists(path):
    novelty_rates_df = compute_novelty_rates(df)
    novelty_rates_df.to_csv(path, index=False)
else:
    novelty_rates_df = pd.read_csv(path)

In [None]:
# Side-by-side panels for scaffold and skeleton novelty: translucent bars show the
# per-model mean, and the swarm overlays one point per row of novelty_rates_df.
novelty_color = "#5359CC"
# Use the model order shared by the other figures (ORDER), keeping only models present
# and appending any that ORDER does not list, so no category is dropped or left empty.
present_models = list(novelty_rates_df["Model"].unique())
model_order = [m for m in ORDER if m in present_models] + [m for m in present_models if m not in ORDER]

fig, ax = plt.subplots(1, 2, figsize=(11, 4))
for axis, y_type in zip(ax, ["Scaffold novelty rate", "Skeleton novelty rate"]):
    sns.barplot(x="Model", y=y_type, data=novelty_rates_df, order=model_order, ax=axis, alpha=0.2, errorbar=None, color=novelty_color)
    sns.swarmplot(x="Model", y=y_type, data=novelty_rates_df, order=model_order, color=novelty_color, ax=axis)
    axis.set_title(y_type)
    axis.set_ylabel(y_type)
    axis.set_xlabel("Model")
    axis.set_ylim(0, 1.1)
    # Rotate the existing tick labels in place; re-setting them from get_xticklabels()
    # is fragile (labels may be unpopulated before a draw, and it fixes the formatter).
    axis.tick_params(axis="x", labelrotation=90)
plt.tight_layout()
plt.savefig(os.path.join(FIGURES, "novelty_rates.png"))

### Process and visualize rare rings
Analyze the presence of rare rings in the generated molecules and visualize the results.

In [None]:
# Annotate molecules via process_rare_rings (presumably adds the `rare_ring` column read
# by the next cell — verify), caching the annotated frame in results/rare_rings.csv.
path = os.path.join("results", "rare_rings.csv")
if os.path.exists(path):
    df = pd.read_csv(path)
else:
    df = process_rare_rings(df, smiles_column="canonical smiles")
    df.to_csv(path, index=False)

In [None]:
# Mean of `rare_ring` per (model, dataset) — a fraction if rare_ring is boolean/0-1 (assumed).
grouped = df.groupby(["Model", "Dataset"])["rare_ring"].mean().reset_index()
# Start a fresh figure so that, outside an inline notebook backend, the bars are not
# drawn onto the figure the previous cell left current.
plt.figure()
sns.barplot(data=grouped, x="Dataset", y="rare_ring", hue="Model", palette=PALETTE, hue_order=ORDER)
plt.savefig(os.path.join(FIGURES, "rare_rings.png"))