In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from chemfunc import compute_top_similarities

In [None]:
# Setup: data_dir is the folder the generated-molecule CSVs are read from;
# plot_dir is the folder the PDF figures are written to.
data_dir = Path("../../revisions")
plot_dir = Path("../../revisions/plots")
# Create the output folder (and any missing parents); a no-op if it already exists.
plot_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Load data: one CSV of generated molecules per generative model, all read from data_dir.
rl_chemprop = pd.read_csv(data_dir.joinpath("rl_chemprop_generated.csv"))
rl_mlp = pd.read_csv(data_dir.joinpath("rl_mlp_generated.csv"))
gflownet = pd.read_csv(data_dir.joinpath("gflownet_generated.csv"))
reinvent = pd.read_csv(data_dir.joinpath("reinvent_generated.csv"))

In [None]:
# Plot novelty of each model as violin plots.
# For each reference set ("Train Hits", "ChEMBL") the per-molecule nearest-neighbor
# similarity column is pulled from every model's table, stacked into one long frame
# with a "model" label, and drawn as one violin per model. One PDF is saved per
# reference set.
novelty_sources = [
    ("RL-Chemprop", rl_chemprop),
    ("RL-MLP", rl_mlp),
    ("GFlowNet", gflownet),
    ("REINVENT 4", reinvent),
]

for novelty_name in ["Train Hits", "ChEMBL"]:
    # Build the column prefix outside the f-string: reusing double quotes inside a
    # double-quoted f-string is a SyntaxError on Python versions before 3.12.
    novelty_prefix = novelty_name.replace(" ", "_").lower()
    novelty_column = f"{novelty_prefix}_tversky_nearest_neighbor_similarity"

    # ignore_index avoids duplicate row labels in the stacked frame.
    plot_data = pd.concat(
        [
            model_data[[novelty_column]].assign(model=model_name)
            for model_name, model_data in novelty_sources
        ],
        ignore_index=True,
    )

    # Use an explicit figure/axes pair. The previous plt.clf() + plt.figure() pair
    # cleared whichever figure was current (wiping the plot from the previous
    # iteration) and left stray empty figures open.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.violinplot(x="model", y=novelty_column, data=plot_data, ax=ax)
    # The plotted column holds Tversky similarities, so the axis is labelled accordingly.
    ax.set_ylabel(f"{novelty_name} Nearest Neighbor Tversky Similarity")
    ax.set_xlabel("")
    fig.savefig(plot_dir / f"synthemol_vs_gflownet_vs_reinvent_{novelty_column}.pdf", bbox_inches="tight")
    # Render the figure in the notebook, then release it so figures do not accumulate.
    plt.show()
    plt.close(fig)

In [None]:
# Calculate diversity among generated molecules as maximum pairwise Tanimoto similarity.
# For each model, compute_top_similarities is called on that model's SMILES with
# similarity_type="tanimoto" and top_k=1; the results are flattened into two parallel
# lists (model label, similarity value) and then assembled into a long-format frame.
# NOTE(review): the result is presumably a NumPy array (it exposes .tolist()) with one
# entry per molecule - confirm against chemfunc.
models = []
similarities = []

model_datasets = [
    ("RL-Chemprop", rl_chemprop),
    ("RL-MLP", rl_mlp),
    ("GFlowNet", gflownet),
    ("REINVENT 4", reinvent),
]

for name, data in model_datasets:
    top_similarities = compute_top_similarities(
        similarity_type="tanimoto",
        mols=data["smiles"].tolist(),
        top_k=1,
    )
    # Repeat the model label once per returned similarity so the lists stay aligned.
    models.extend([name] * len(top_similarities))
    similarities.extend(top_similarities.tolist())

similarities_df = pd.DataFrame(dict(model=models, similarity=similarities))

In [None]:
# Plot diversity as violin plot: one violin per model showing the distribution of the
# "similarity" column of similarities_df, saved as a PDF in plot_dir.
# Use an explicit figure/axes pair. The previous plt.clf() + plt.figure() pair cleared
# (or created) an unrelated current figure and then drew on a second one, leaving a
# stray empty figure behind.
fig, ax = plt.subplots(figsize=(10, 6))
sns.violinplot(x="model", y="similarity", data=similarities_df, ax=ax)
ax.set_ylabel("Generated Nearest Neighbor Tanimoto Similarity")
ax.set_xlabel("")
fig.savefig(plot_dir / "synthemol_vs_gflownet_vs_reinvent_diversity.pdf", bbox_inches="tight")
# Render the figure in the notebook, then release it.
plt.show()
plt.close(fig)