In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Setup: input/output locations and the ordered list of filtering stages.
# Both paths are relative to the notebook's working directory.
data_dir = Path("../../revisions")
plot_dir = Path("../../revisions/plots")

# Stage names are used below both as CSV filename suffixes and as plot titles;
# their position in this list determines the numeric prefix of each saved plot.
stages = [
    "generated",
    "hits",
    "hits_novel",
    "hits_novel_diverse",
    "hits_novel_diverse_top",
    "available",
    "ordered",
    "synthesized",
]

In [None]:
# Load data
def _read_stage_tables(prefix, column_renames=None):
    """Read one CSV per filtering stage for a single method.

    Files are expected at ``data_dir / f"{prefix}_{stage}.csv"`` for every
    entry of ``stages``.

    Args:
        prefix: Filename prefix identifying the method (e.g. ``"mcts"``).
        column_renames: Optional ``{old_name: new_name}`` mapping. When given,
            any existing columns already called one of the new names are
            dropped first (missing ones are ignored), then the rename is
            applied.

    Returns:
        Dict mapping each stage name to the DataFrame loaded for that stage.
    """
    tables = {}
    for stage in stages:
        table = pd.read_csv(data_dir / f"{prefix}_{stage}.csv")
        if column_renames:
            # Drop before renaming so the renamed columns are the only ones
            # carrying the target labels.
            table = table.drop(
                columns=list(column_renames.values()), errors="ignore"
            ).rename(columns=column_renames)
        tables[stage] = table
    return tables


rl_chemprop_data = _read_stage_tables("rl_chemprop")

rl_mlp_data = _read_stage_tables("rl_mlp")

mcts_data = _read_stage_tables("mcts")

# The VS-Chemprop files are relabelled so their score columns use the same
# names ("S. aureus", "Solubility") that the plotting code reads for every method.
vs_chemprop_data = _read_stage_tables(
    "vs_chemprop",
    column_renames={"s_aureus_activity": "S. aureus", "solubility": "Solubility"},
)

In [None]:
# Plot distributions of scores at each stage of filtering.
# For every score column and every stage, draw one violin per method, annotate
# each violin with its median, and save the figure as
# plot_dir/<score column>/<stage number>_<stage>.pdf.

# (display name, per-stage tables) for every method; loop-invariant, so built once.
method_tables = [
    ("RL-Chemprop", rl_chemprop_data),
    ("RL-MLP", rl_mlp_data),
    ("MCTS", mcts_data),
    ("VS-Chemprop", vs_chemprop_data),
]

for value in ["S. aureus", "Solubility"]:
    plot_subdir = plot_dir / value
    plot_subdir.mkdir(exist_ok=True, parents=True)

    for i, stage in enumerate(stages):
        # Clear and reuse the current figure instead of opening a new one per plot
        plt.clf()

        # Build a long-format table: one row per score, tagged with its method
        methods = []
        values = []
        for name, data in method_tables:
            scores = data[stage][value]
            methods += [name] * len(scores)
            values += scores.tolist()
        plot_data = pd.DataFrame({"Method": methods, value: values})

        # Pin the category order explicitly (first appearance, which is also
        # seaborn's default for string categories) so that violin i sits at x = i.
        # Methods with no rows at this stage are absent, as in the default.
        order = list(dict.fromkeys(methods))

        # Alternative scaling option left from the original: density_norm="width"
        ax = sns.violinplot(data=plot_data, x="Method", y=value, order=order)

        # Compute medians by category
        medians = plot_data.groupby("Method")[value].median()

        # Map category names -> x positions on the axis. Derived from `order`
        # rather than from tick-label text, which matplotlib does not guarantee
        # to be populated before the figure has been drawn.
        xpos = {method: position for position, method in enumerate(order)}

        # Small horizontal offset (5% of the x-axis span) so the label sits
        # beside the violin's centre line rather than on top of it
        xlim = ax.get_xlim()
        offset = 0.05 * (xlim[1] - xlim[0])

        # Write the median value as a text label next to each violin
        for method, med in medians.items():
            ax.text(xpos[method] + offset, med, f"{med:.2f}", ha="center", va="center", fontsize=12, color="k", zorder=4)

        plt.title(stage)
        plt.xlabel("")
        # Filenames are prefixed with the 1-based stage index so they sort in pipeline order
        plt.savefig(plot_subdir / f'{i + 1}_{stage}.pdf', bbox_inches="tight")