In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Setup: where inputs are read from, where figures go, and the display names used in labels.
data_dir = Path("../../revisions")  # directory containing the generated-molecule score CSVs
plot_dir = data_dir / "plots"  # output directory for the saved PDF figures

# Property column prefix -> human-readable property name used in axis labels
prop_to_plot_name = dict(
    s_aureus_activity="S. aureus",
    solubility="Solubility",
)

# Score-model column suffix -> human-readable model name used in labels
model_to_plot_name = dict(
    chemprop_rdkit="Chemprop-RDKit",
    mlp_rdkit="MLP-RDKit",
)

In [None]:
# Load data
# Nested mapping: RL value model -> train score model -> DataFrame read from the matching CSV.
# The inline table below lists the CSV filename for each (value model, train score model) pair;
# files are read in the table's order.
value_model_to_train_model_to_scores = {
    value_model: {
        train_model: pd.read_csv(data_dir / csv_name)
        for train_model, csv_name in train_model_to_csv.items()
    }
    for value_model, train_model_to_csv in {
        "Chemprop-RDKit": {
            "Chemprop-RDKit": "rl_chemprop_generated.csv",
            "MLP-RDKit": "rl_chemprop_mlp_score_generated.csv",
        },
        "MLP-RDKit": {
            "Chemprop-RDKit": "rl_mlp_generated.csv",
            "MLP-RDKit": "rl_mlp_mlp_score_generated.csv",
        },
    }.items()
}


In [None]:
# Plot Chemprop vs MLP scores
# For every (eval score model, property) pair, draw one violin per
# (RL value model, train score model) combination, label each violin with its
# median score, and save the figure as a PDF in `plot_dir`.
plot_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories

for eval_model in ["chemprop_rdkit", "mlp_rdkit"]:
    for prop in ["s_aureus_activity", "solubility"]:
        score_column = f"{prop}_{eval_model}"
        prop_plot_name = f"{prop_to_plot_name[prop]} (Eval Score Model: {model_to_plot_name[eval_model]})"

        # Long-format data: one row per molecule, tagged with the method that generated it.
        # `order` records the category order so the x-position of each violin is known exactly.
        methods = []
        values = []
        order = []
        for value_model, train_model_to_scores in value_model_to_train_model_to_scores.items():
            for train_model, scores in train_model_to_scores.items():
                method = f"RL Value Model: {value_model}\nTrain Score Model: {train_model}"
                order.append(method)
                methods += [method] * len(scores)
                values += scores[score_column].tolist()

        plot_data = pd.DataFrame({"Method": methods, prop_plot_name: values})

        fig, ax = plt.subplots(figsize=(14, 6))
        sns.violinplot(
            data=plot_data,
            x="Method",
            y=prop_plot_name,
            hue="Method",
            order=order,
            hue_order=order,
            palette="tab20",
            ax=ax,
        )

        # Median score per method (NaN scores are ignored by the median)
        medians = plot_data.groupby("Method")[prop_plot_name].median()

        # Horizontal nudge so the label sits beside, not on top of, the violin centre
        xmin, xmax = ax.get_xlim()
        offset = 0.05 * (xmax - xmin)

        # With an explicit `order`, seaborn draws the i-th category at x = i, so no
        # tick-label text lookup is needed. Write a median text label for each violin.
        for xpos, method in enumerate(order):
            med = medians.get(method)
            if med is None or pd.isna(med):
                continue  # no data for this method -> nothing to annotate
            ax.text(xpos + offset, med, f"{med:.2f}", ha="center", va="center", fontsize=12, color="k", zorder=4)

        ax.set_xlabel("")
        fig.savefig(plot_dir / f"chemprop_vs_mlp_scores_{prop}_eval_{eval_model}.pdf", bbox_inches="tight")
        # Display now; with the inline backend this also releases the figure so
        # open figures do not accumulate across loop iterations.
        plt.show()