In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Setup: inputs are read from the revisions directory; figures are written to its plots/ subfolder.
data_dir = Path("../../revisions")
plot_dir = data_dir / "plots"

# Stage names, in order. Not used by the plotting cells in this section;
# presumably referenced elsewhere in the notebook.
stages = [
    "generated",
    "hits",
    "hits_novel",
    "hits_novel_diverse",
    "hits_novel_diverse_top",
    "available",
    "ordered",
    "synthesized",
]

In [None]:
# Load data: one CSV of generated molecules per RL variant. Both frames must
# provide the score columns read by the plotting cells that follow.
rl_chemprop, rl_mlp = (
    pd.read_csv(data_dir / filename)
    for filename in ("rl_chemprop_generated.csv", "rl_mlp_generated.csv")
)

In [None]:
# Plot Chemprop vs MLP scores for S. aureus
# Each violin is one (generator, scorer) pair: molecules from RL-Chemprop or
# RL-MLP, scored by the column named in the third tuple slot.
groups = [
    ("RL-Chemprop\nChemprop-RDKit Score", rl_chemprop, "S. aureus"),
    ("RL-MLP\nChemprop-RDKit Score", rl_mlp, "S. aureus"),
    ("RL-Chemprop\nMLP-RDKit Score", rl_chemprop, "s_aureus_activity_mlp_rdkit"),
    ("RL-MLP\nMLP-RDKit Score", rl_mlp, "s_aureus_activity_mlp_rdkit"),
]

# Long-format table: one row per (group label, score)
methods = []
values = []
for label, frame, column in groups:
    methods += [label] * len(frame)
    values += frame[column].tolist()

plot_data = pd.DataFrame({"Method": methods, "S. aureus": values})

# Pin the category order so violin i sits at x = i. This avoids reading tick
# label text back from the axis, which relies on matplotlib having populated
# it before the figure is drawn.
order = [label for label, _, _ in groups]
fig, ax = plt.subplots(figsize=(10, 6))
sns.violinplot(data=plot_data, x="Method", y="S. aureus", order=order, ax=ax)

# Median per category, aligned to the plotting order
medians = plot_data.groupby("Method")["S. aureus"].median().reindex(order)

# Shift each label right of its violin centre by 5% of the x-axis span
xlim = ax.get_xlim()
offset = 0.05 * (xlim[1] - xlim[0])

# Write the median value as a text label next to each violin
for xpos, method in enumerate(order):
    med = medians[method]
    if pd.isna(med):
        # No scores for this group, so there is nothing to annotate
        continue
    ax.text(xpos + offset, med, f"{med:.2f}", ha="center", va="center", fontsize=12, color="k", zorder=4)

ax.set_xlabel("")
# savefig does not create missing directories
plot_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_dir / "chemprop_vs_mlp_scores_s_aureus.pdf", bbox_inches="tight")

In [None]:
# Plot Chemprop vs MLP scores for solubility
# Each violin is one (generator, scorer) pair: molecules from RL-Chemprop or
# RL-MLP, scored by the column named in the third tuple slot.
groups = [
    ("RL-Chemprop\nChemprop-RDKit Score", rl_chemprop, "Solubility"),
    ("RL-MLP\nChemprop-RDKit Score", rl_mlp, "Solubility"),
    ("RL-Chemprop\nMLP-RDKit Score", rl_chemprop, "solubility_mlp_rdkit"),
    ("RL-MLP\nMLP-RDKit Score", rl_mlp, "solubility_mlp_rdkit"),
]

# Long-format table: one row per (group label, score)
methods = []
values = []
for label, frame, column in groups:
    methods += [label] * len(frame)
    values += frame[column].tolist()

plot_data = pd.DataFrame({"Method": methods, "Solubility": values})

# Pin the category order so violin i sits at x = i. This avoids reading tick
# label text back from the axis, which relies on matplotlib having populated
# it before the figure is drawn.
order = [label for label, _, _ in groups]
fig, ax = plt.subplots(figsize=(10, 6))
sns.violinplot(data=plot_data, x="Method", y="Solubility", order=order, ax=ax)

# Median per category, aligned to the plotting order
medians = plot_data.groupby("Method")["Solubility"].median().reindex(order)

# Shift each label right of its violin centre by 5% of the x-axis span
xlim = ax.get_xlim()
offset = 0.05 * (xlim[1] - xlim[0])

# Write the median value as a text label next to each violin
for xpos, method in enumerate(order):
    med = medians[method]
    if pd.isna(med):
        # No scores for this group, so there is nothing to annotate
        continue
    ax.text(xpos + offset, med, f"{med:.2f}", ha="center", va="center", fontsize=12, color="k", zorder=4)

ax.set_xlabel("")
# savefig does not create missing directories
plot_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_dir / "chemprop_vs_mlp_scores_solubility.pdf", bbox_inches="tight")
