# Plot model performance

Visualize the performance of different models on the `2024-04-18` dataset, using runs fetched from the wandb API.

In [None]:
import pathlib
import itertools
import sys
sys.path.append(str(pathlib.Path().absolute().parent))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from scipy.stats import wilcoxon
import seaborn as sns
import wandb

from utils import get_runs_as_list
from src.util.definitions import DATA_ROOT

In [None]:
# settings
# Figure styling shared by every plot in this notebook: seaborn "paper" theme, only the bottom
# spine, no grid, bold axis labels and 6 pt black text everywhere.
sns.set_theme(
    context="paper",
    style="white",
    font_scale=1,  # alternative: 0.7 for more compact figures
    rc={
        "savefig.transparent": True,
        "axes.grid": False,
        # keep only the bottom spine
        "axes.spines.bottom": True,
        "axes.spines.left": False,
        "axes.spines.right": False,
        "axes.spines.top": False,
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica", "Arial"],
        # no padding between ticks and tick labels (x/y axis, major/minor ticks)
        **{f"{axis}tick.{which}.pad": 0.0 for axis in "xy" for which in ("major", "minor")},
        "axes.labelweight": "bold",
        "axes.labelpad": 2.5,  # matplotlib's default is 4.0
        "axes.xmargin": .05,
    },
)

params = {
    # 6 pt for every text element
    **dict.fromkeys(
        ["axes.labelsize", "axes.titlesize", "xtick.labelsize", "ytick.labelsize", "legend.fontsize", "font.size"],
        6,
    ),
    "svg.fonttype": "none",  # necessary to have editable text in SVGs
    # all text and ticks in black
    **dict.fromkeys(["text.color", "axes.labelcolor", "xtick.color", "ytick.color"], "black"),
}

plt.rcParams.update(params)


# more settings for all plots
errorbar = "se"  # standard error of the mean
errwidth = .9  # line width of the error bars
errcolor = "black"  # colour of the error bars
capsize = .1  # size of the end of the errorbar
linewidth = 1.  # width of the outline of barplot

### Get data from wandb API

In [None]:
# Fetch all "training" and "hparam_best" runs created after 2024-03-30 from W&B.
summary_list, config_list, tag_list, name_list = get_runs_as_list(
    filters={
        "$and": [
            {"$or": [{"jobType": "training"}, {"jobType": "hparam_best"}]},
            {"createdAt": {"$gt": "2024-03-30"}},
        ],
    },
    timeout=59,
)
# One row per run: flattened config columns next to flattened summary (metric) columns,
# joined on the position of the run in the returned lists.
df_all = pd.json_normalize(config_list).merge(
    pd.json_normalize(summary_list), left_index=True, right_index=True
)
df_all["tags"] = tag_list
df_all["run_id"] = name_list
# the run group is the run name without its last "_"-separated part
df_all["run_group"] = df_all["run_id"].map(lambda run_name: run_name.rsplit("_", maxsplit=1)[0])
# label such as "XGB/FP+RDKit"; a "None" feature entry is shown as "CGR"
feature_labels = df_all["decoder.global_features"].str.join("+").str.replace("None", "CGR")
df_all["Model+Features"] = df_all["name"] + "/" + feature_labels
# fold index = last character of the run name (assumes single-digit folds)
df_all["fold"] = df_all["run_id"].str.slice(start=-1).astype(int)

In [None]:
# all experiment ids in the queried data
# Runs without an experiment_id (NaN) are skipped: mixing NaN with strings would make both
# sorted() and str.join() raise (the filter in the next cell guards against such ids as well).
", ".join(sorted(df_all["experiment_id"].dropna().astype(str).unique()))

In [None]:
# filter for relevant data (JG1527 onwards for 2024-04-18 models)
# n.b. just a doublecheck, the date filter in the query should take care of this
# Ids that are missing (non-string, e.g. NaN) or not of the form "JG<digits>" are dropped instead
# of making int() raise and abort the whole filter.
df_all = df_all.loc[
    df_all["experiment_id"].apply(
        lambda x: isinstance(x, str) and x.startswith("JG") and x[2:].isdecimal() and int(x[2:]) > 1526
    )
]

In [None]:
# check available experiments by split: print each tag with the set of experiment ids run on it
for tag, experiment_set in df_all.groupby("tags")["experiment_id"].agg(set).items():
    print(tag, "-->", experiment_set)

In [None]:
# set dir where we will save plots; create it if needed, otherwise every fig.savefig below fails
analysis_dir = pathlib.Path("results")
analysis_dir.mkdir(parents=True, exist_ok=True)
# set date (of the dataset) for saving
datadate = "2024-04-18"

In [None]:
# Load the per-split statistics table for this dataset version (the same file the helper
# functions below read chance-level AP, train sample and building-block counts from).
split_stats = pd.read_csv(DATA_ROOT / "splits" / f"split_statistics_{datadate}.csv")

def get_chance_ap(split_name, fold=None, set_type="test"):
    """Look up the chance-level macro average precision of a split.

    Reads ``split_statistics_{datadate}.csv`` from ``DATA_ROOT / "splits"``.

    Args:
        split_name: value of the ``split_name`` column, e.g. ``"0D_80"`` or ``"1D"``.
        fold: fold index; when ``None`` the value is averaged over all rows of the split.
        set_type: subset the chance level refers to, e.g. ``"test"`` or ``"val"``.

    Returns:
        The single matching value for the given fold (``.item()`` raises ``ValueError`` unless
        exactly one row matches), otherwise the mean over the split's rows.
    """
    stats = pd.read_csv(DATA_ROOT / "splits" / f"split_statistics_{datadate}.csv")
    column = f"Chance level average precision macro on {set_type} set"
    in_split = stats["split_name"] == split_name
    if fold is None:
        return stats.loc[in_split, column].mean()
    return stats.loc[in_split & (stats["fold"] == fold), column].item()
    
def get_sample_count(split_name, fold=None):
    """Return the number of training samples of a split.

    The ``"Train samples"`` cell is parsed by taking its first whitespace-separated token as an
    integer (presumably the count followed by extra text — TODO confirm the CSV format).

    Args:
        split_name: value of the ``split_name`` column.
        fold: fold index; when ``None`` the counts of all rows of the split are averaged and the
            mean is truncated to ``int``.
    """
    stats = pd.read_csv(DATA_ROOT / "splits" / f"split_statistics_{datadate}.csv")
    in_split = stats["split_name"] == split_name
    if fold is None:
        leading_tokens = stats.loc[in_split, "Train samples"].str.split(expand=True)[0]
        return int(leading_tokens.astype("int").mean())
    cell = stats.loc[in_split & (stats["fold"] == fold), "Train samples"].item()
    return int(cell.split()[0])
    
        
def get_buildingblock_count(split_name, fold=None, bb_types=("initiators", "monomers", "terminators")):
    """Return the mean number of training building blocks of the given types for a split.

    Args:
        split_name: value of the ``split_name`` column.
        fold: fold index; when ``None`` all rows of the split are included.
        bb_types: iterable of ``"initiators"``, ``"monomers"`` and/or ``"terminators"`` (a single
            type may also be passed as a plain string); the matching ``"Train <type>"`` columns
            are averaged together.

    Returns:
        ``df.loc[rows, columns].mean(axis=None)``. NOTE(review): recent pandas reduces over both
        axes (a scalar) for ``axis=None``; older versions returned a per-column Series — confirm
        the installed version.

    Raises:
        ValueError: if an entry of ``bb_types`` is not a known building block type.
    """
    valid_types = ("initiators", "monomers", "terminators")
    if isinstance(bb_types, str):  # a single type, not an iterable of characters
        bb_types = (bb_types,)
    # validate before reading the file so that typos fail fast with a clear message
    for s in bb_types:
        if s not in valid_types:  # catch incorrect options for building blocks
            raise ValueError(f"This building block type does not exist: {s!r}")
    df = pd.read_csv(DATA_ROOT / "splits" / f"split_statistics_{datadate}.csv")
    columns = [f"Train {s}" for s in bb_types]
    mask = df["split_name"] == split_name
    if fold is not None:
        mask = mask & (df["fold"] == fold)
    return df.loc[mask, columns].mean(axis=None)
    

In [None]:
# check that we have the expected number of experiments (catch duplicated or unfinished runs)
# Runs are counted per experiment id in one pass (instead of one full-column comparison per id).
# sort=False keeps the ids in order of first appearance; runs without an id (NaN) are not grouped,
# so they no longer produce a bogus "found 0" warning.
for exp_id, n_exp in df_all.groupby("experiment_id", sort=False).size().items():
    if n_exp != 9:
        print(f"Warning: Expected 9 experiments for {exp_id}, found {n_exp}")

## 0D split

In [None]:
# choose the metric: macro average precision on the validation set (shown as AUPRC)
metric = "val/avgPrecision_macro"
# choose the split (0D_80, 1D, 2D or 3D)
tag = "0D_80"
# choose the colorscheme
fill_color = "#4b4c68"  # blue
stroke_color = '0.1'  # almost black

# choose the data to plot; the bars appear in this order
experiment_ids = [
    "JG1697",  # LogReg/FP
    "JG1700",  # LogReg/RDKit
    "JG1698",  # AttentiveFP/CGR
    "JG1699",  # GraphSAGE/CGR
    "JG1530",  # XGB/FP+RDKit
    "JG1529",  # XGB/FP
    "JG1531",  # DMPNN/CGR+RDKit
    "JG1527",  # DMPNN/CGR
    "JG1528",  # FFN/OHE
]

# filter the data: runs of the chosen experiments that carry the split tag
df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# sort the runs into the order of `experiment_ids` (stable sort keeps the order within an experiment)
sort_dict = dict(zip(experiment_ids, itertools.count()))
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict))
ticklabels = df_plot["Model+Features"].unique().tolist()
# set plot
fig, ax = plt.subplots(figsize=(2.8,2))
sns.barplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    errorbar=errorbar,
    # errwidth=/errcolor= are deprecated since seaborn 0.13; err_kws is the replacement
    err_kws={"linewidth": errwidth, "color": errcolor},
    capsize=capsize,
    color=fill_color,
    edgecolor=stroke_color,
    linewidth=linewidth,
    alpha=.8,
    width=.7,
)

# individual runs on top of the bars
sns.stripplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    dodge=True,
    edgecolor=stroke_color,
    color=fill_color,
    linewidth=.5,
    legend=False,
    marker="o",
    size=2.5,
    alpha=.7,
)

ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)  # val chance level of this split

ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.xaxis.set_tick_params(labelrotation=90)
ax.set_ylim((0, 1))
ax.xaxis.set_ticklabels(df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg"))

fig.tight_layout()

fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_barplot.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_barplot.png", dpi=300)

In [None]:
# the same data as a boxplot
# metric: macro average precision on the validation set, labelled AUPRC
metric = "val/avgPrecision_macro"
# split whose runs are shown
tag = "0D_80"
# colours: blue fill, almost-black outline
fill_color = "#4b4c68"
stroke_color = '0.1'

# experiments to show, in x-axis order
experiment_ids = [
    "JG1697",  # LogReg/FP
    "JG1700",  # LogReg/RDKit
    "JG1698",  # AttentiveFP/CGR
    "JG1699",  # GraphSAGE/CGR
    "JG1530",  # XGB/FP+RDKit
    "JG1529",  # XGB/FP
    "JG1531",  # DMPNN/CGR+RDKit
    "JG1527",  # DMPNN/CGR
    "JG1528",  # FFN/OHE
]

# runs of these experiments tagged with the split
is_selected = df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda run_tags: tag in run_tags)
df_plot = df_all.loc[is_selected]

# stable sort of the runs into the order of `experiment_ids`
sort_dict = {exp_id: position for position, exp_id in enumerate(experiment_ids)}
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict))
ticklabels = list(df_plot["Model+Features"].unique())

fig, ax = plt.subplots(figsize=(5.6, 4))
sns.boxplot(
    data=df_plot,
    x="Model+Features",
    y=metric,
    ax=ax,
    color=fill_color,
    saturation=.7,
    width=.7,
    linewidth=linewidth,
    fliersize=2.5,
    boxprops={"edgecolor": stroke_color},
    medianprops={"color": stroke_color},
)

ax.set_ylabel("AUPRC")
ax.set_xlabel(None)
ax.tick_params(axis="x", labelrotation=90)
ax.set_ylim(0.875, 0.98)
short_labels = df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg")
ax.xaxis.set_ticklabels(short_labels)

fig.tight_layout()

file_stem = f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_boxplot"
fig.savefig(analysis_dir / f"{file_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{file_stem}.png", dpi=300)

In [None]:
# the same data as individual points only (strip plot)
# metric: macro average precision on the validation set, labelled AUPRC
metric = "val/avgPrecision_macro"
# split whose runs are shown
tag = "0D_80"
# colours: blue fill, almost-black outline
fill_color = "#4b4c68"
stroke_color = '0.1'

# experiments to show, in x-axis order
experiment_ids = [
    "JG1697",  # LogReg/FP
    "JG1700",  # LogReg/RDKit
    "JG1698",  # AttentiveFP/CGR
    "JG1699",  # GraphSAGE/CGR
    "JG1530",  # XGB/FP+RDKit
    "JG1529",  # XGB/FP
    "JG1531",  # DMPNN/CGR+RDKit
    "JG1527",  # DMPNN/CGR
    "JG1528",  # FFN/OHE
]

# runs of these experiments tagged with the split
is_selected = df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda run_tags: tag in run_tags)
df_plot = df_all.loc[is_selected]

# stable sort of the runs into the order of `experiment_ids`
sort_dict = {exp_id: position for position, exp_id in enumerate(experiment_ids)}
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict))
ticklabels = list(df_plot["Model+Features"].unique())

fig, ax = plt.subplots(figsize=(5.6, 4))

sns.stripplot(
    data=df_plot,
    x="Model+Features",
    y=metric,
    ax=ax,
    color=fill_color,
    edgecolor=stroke_color,
    linewidth=.5,
    marker="o",
    size=3.5,
    alpha=.7,
    jitter=.15,
    dodge=True,
    legend=False,
)

ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.tick_params(axis="x", labelrotation=90)
ax.set_ylim(0.875, 0.98)
short_labels = df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg")
ax.xaxis.set_ticklabels(short_labels)

fig.tight_layout()

file_stem = f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_scatter"
fig.savefig(analysis_dir / f"{file_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{file_stem}.png", dpi=300)

In [None]:
# mean and sample standard deviation (ddof=1) of the validation AUPRC per model, best first.
# String aggregations: passing np.mean/np.std to .agg is deprecated since pandas 2.1, and np.std
# would later be called directly with numpy semantics (ddof=0).
df_plot.groupby("Model+Features")["val/avgPrecision_macro"].agg(["mean", "std"]).sort_values(by="mean", ascending=False)

In [None]:
# is the best model significantly different from the rest?
# Paired two-sided Wilcoxon signed-rank test. Runs are paired explicitly by fold via a
# fold x model table (pivot raises if a model has duplicate runs for a fold) instead of relying
# on two independently sorted run_id lists lining up; folds missing for a model are skipped.
best = "FFN/OHE"
per_fold = df_plot.pivot(index="fold", columns="Model+Features", values="val/avgPrecision_macro")
if best not in per_fold.columns:
    raise KeyError(f"{best!r} not found; available: {list(per_fold.columns)}")
for i in df_plot["Model+Features"].drop_duplicates():
    if i != best:
        paired = per_fold[[best, i]].dropna()
        if len(paired) < len(per_fold):
            print(f"Note: {i} compared on {len(paired)} of {len(per_fold)} folds")
        wilcoxon_result = wilcoxon(paired[best], paired[i], alternative="two-sided")
        print(i, ":\t", wilcoxon_result, wilcoxon_result.pvalue < 0.05)

## 1D split

In [None]:
# choose the metric: macro average precision on the validation set (shown as AUPRC)
metric = "val/avgPrecision_macro"
# choose the split (0D_80, 1D, 2D or 3D)
tag = "1D"
# choose the colorscheme
fill_color = (175/256, 87/256, 38/256)
stroke_color = (132/256, 64/256, 30/256)

# choose the data to plot; the bars appear in this order
experiment_ids = [
    "JG1701",  # LogReg/FP
    "JG1702",  # LogReg/RDKit
    "JG1708",  # AttentiveFP/CGR
    "JG1709",  # GraphSAGE/CGR
    "JG1704",  # XGB/FP+RDKit
    "JG1703",  # XGB/FP
    "JG1707",  # DMPNN/CGR+RDKit
    "JG1706",  # DMPNN/CGR
    "JG1705",  # FFN/OHE
]

# filter the data: runs of the chosen experiments that carry the split tag
df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# sort the runs into the order of `experiment_ids` (stable sort keeps the order within an experiment)
sort_dict = dict(zip(experiment_ids, itertools.count()))
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict))
ticklabels = df_plot["Model+Features"].unique().tolist()
# set plot
fig, ax = plt.subplots(figsize=(2.8,2))
sns.barplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    errorbar=errorbar,
    # errwidth=/errcolor= are deprecated since seaborn 0.13; err_kws is the replacement
    err_kws={"linewidth": errwidth, "color": errcolor},
    capsize=capsize,
    color=fill_color,
    edgecolor=stroke_color,
    linewidth=linewidth,
    alpha=.8,
    width=.7,
)

# individual runs on top of the bars
sns.stripplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    dodge=True,
    edgecolor=stroke_color,
    color=fill_color,
    linewidth=.5,
    legend=False,
    marker="o",
    size=2.5,
    alpha=.8,
)

# chance level uses `tag` so it always matches the plotted split
ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)

ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.xaxis.set_tick_params(labelrotation=90)
ax.set_ylim((0, 1))
ax.xaxis.set_ticklabels(df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg"))

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_barplot.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_barplot.png", dpi=300)

In [None]:
# boxplot of the 1D split
# metric: macro average precision on the validation set, labelled AUPRC
metric = "val/avgPrecision_macro"
# split whose runs are shown
tag = "1D"
# colours: fill and a darker outline
fill_color = (175/256, 87/256, 38/256)
stroke_color = (132/256, 64/256, 30/256)

# experiments to show, in x-axis order
experiment_ids = [
    "JG1701",  # LogReg/FP
    "JG1702",  # LogReg/RDKit
    "JG1708",  # AttentiveFP/CGR
    "JG1709",  # GraphSAGE/CGR
    "JG1704",  # XGB/FP+RDKit
    "JG1703",  # XGB/FP
    "JG1707",  # DMPNN/CGR+RDKit
    "JG1706",  # DMPNN/CGR
    "JG1705",  # FFN/OHE
]

# runs of these experiments tagged with the split
is_selected = df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda run_tags: tag in run_tags)
df_plot = df_all.loc[is_selected]

# stable sort of the runs into the order of `experiment_ids`
sort_dict = {exp_id: position for position, exp_id in enumerate(experiment_ids)}
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict))
ticklabels = list(df_plot["Model+Features"].unique())

fig, ax = plt.subplots(figsize=(5.6, 4))
sns.boxplot(
    data=df_plot,
    x="Model+Features",
    y=metric,
    ax=ax,
    color=fill_color,
    saturation=.8,
    width=.8,
    linewidth=linewidth,
    fliersize=2.5,
    boxprops={"edgecolor": stroke_color},
    medianprops={"color": stroke_color},
)

ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)  # val chance level of this split
ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.tick_params(axis="x", labelrotation=90)
ax.set_ylim(0.65, 0.95)
short_labels = df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg")
ax.xaxis.set_ticklabels(short_labels)

fig.tight_layout()
file_stem = f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_boxplot"
fig.savefig(analysis_dir / f"{file_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{file_stem}.png", dpi=300)

In [None]:
# individual runs of the 1D split (strip plot)
# metric: macro average precision on the validation set, labelled AUPRC
metric = "val/avgPrecision_macro"
# split whose runs are shown
tag = "1D"
# colours: fill and a darker outline
fill_color = (175/256, 87/256, 38/256)
stroke_color = (132/256, 64/256, 30/256)

# experiments to show, in x-axis order
experiment_ids = [
    "JG1701",  # LogReg/FP
    "JG1702",  # LogReg/RDKit
    "JG1708",  # AttentiveFP/CGR
    "JG1709",  # GraphSAGE/CGR
    "JG1704",  # XGB/FP+RDKit
    "JG1703",  # XGB/FP
    "JG1707",  # DMPNN/CGR+RDKit
    "JG1706",  # DMPNN/CGR
    "JG1705",  # FFN/OHE
]

# runs of these experiments tagged with the split
is_selected = df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda run_tags: tag in run_tags)
df_plot = df_all.loc[is_selected]

# stable sort of the runs into the order of `experiment_ids`
sort_dict = {exp_id: position for position, exp_id in enumerate(experiment_ids)}
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict))
ticklabels = list(df_plot["Model+Features"].unique())

fig, ax = plt.subplots(figsize=(5.6, 4))
sns.stripplot(
    data=df_plot,
    x="Model+Features",
    y=metric,
    ax=ax,
    color=fill_color,
    edgecolor=stroke_color,
    linewidth=.5,
    marker="o",
    size=3.5,
    alpha=.8,
    jitter=.15,
    dodge=True,
    legend=False,
)

ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)  # val chance level of this split

ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.tick_params(axis="x", labelrotation=90)
ax.set_ylim(0.65, 0.95)
short_labels = df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg")
ax.xaxis.set_ticklabels(short_labels)

fig.tight_layout()
file_stem = f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_scatter"
fig.savefig(analysis_dir / f"{file_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{file_stem}.png", dpi=300)

In [None]:
# mean and sample standard deviation (ddof=1) of the validation AUPRC per model, best first.
# String aggregations: passing np.mean/np.std to .agg is deprecated since pandas 2.1, and np.std
# would later be called directly with numpy semantics (ddof=0).
df_plot.groupby("Model+Features")["val/avgPrecision_macro"].agg(["mean", "std"]).sort_values(by="mean", ascending=False)

In [None]:
# is the best model significantly different from the rest?
# Paired two-sided Wilcoxon signed-rank test. Runs are paired explicitly by fold via a
# fold x model table (pivot raises if a model has duplicate runs for a fold) instead of relying
# on two independently sorted run_id lists lining up; folds missing for a model are skipped.
best = "XGB/FP+RDKit"
per_fold = df_plot.pivot(index="fold", columns="Model+Features", values="val/avgPrecision_macro")
if best not in per_fold.columns:
    raise KeyError(f"{best!r} not found; available: {list(per_fold.columns)}")
for i in df_plot["Model+Features"].drop_duplicates():
    if i != best:
        paired = per_fold[[best, i]].dropna()
        if len(paired) < len(per_fold):
            print(f"Note: {i} compared on {len(paired)} of {len(per_fold)} folds")
        wilcoxon_result = wilcoxon(paired[best], paired[i], alternative="two-sided")
        print(i, ":\t", wilcoxon_result, wilcoxon_result.pvalue < 0.05)

## 2D split

In [None]:
# choose the metric: macro average precision on the validation set (shown as AUPRC)
metric = "val/avgPrecision_macro"
# choose the split (0D_80, 1D, 2D or 3D)
tag = "2D"
# choose the colorscheme
fill_color = (249/256, 158/256, 35/256)
stroke_color = (186/256, 115/256, 28/256)

# choose the data to plot; the bars appear in this order
experiment_ids = [
    "JG1710",  # LogReg/FP
    "JG1711",  # LogReg/RDKit
    "JG1717",  # AttentiveFP/CGR
    # was "JG1709", which is the 1D GraphSAGE experiment and is removed by the 2D tag filter;
    # the 2D experiments are JG1710-JG1718 and JG1718 was otherwise unused — TODO confirm in wandb
    "JG1718",  # GraphSAGE/CGR
    "JG1713",  # XGB/FP+RDKit
    "JG1712",  # XGB/FP
    "JG1716",  # DMPNN/CGR+RDKit
    "JG1715",  # DMPNN/CGR
    "JG1714",  # FFN/OHE
]

# filter the data: runs of the chosen experiments that carry the split tag
df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# sort the runs into the order of `experiment_ids` (stable sort keeps the order within an experiment)
sort_dict = dict(zip(experiment_ids, itertools.count()))
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict))
ticklabels = df_plot["Model+Features"].unique().tolist()

# set plot
fig, ax = plt.subplots(figsize=(2.8,2))
sns.barplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    errorbar=errorbar,
    # errwidth=/errcolor= are deprecated since seaborn 0.13; err_kws is the replacement
    err_kws={"linewidth": errwidth, "color": errcolor},
    capsize=capsize,
    color=fill_color,
    edgecolor=stroke_color,
    linewidth=linewidth,
    alpha=.8,
    width=.7,
)

# individual runs on top of the bars
sns.stripplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    dodge=True,
    edgecolor=stroke_color,
    color=fill_color,
    linewidth=.5,
    legend=False,
    marker="o",
    size=2.5,
    alpha=.8,
)

# chance level uses `tag` so it always matches the plotted split
ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)

ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.xaxis.set_tick_params(labelrotation=90)
ax.set_ylim((0, 1))
ax.xaxis.set_ticklabels(df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg"))

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_barplot.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_barplot.png", dpi=300)

In [None]:
# make boxplot
# choose the metric: macro average precision on the validation set (shown as AUPRC)
metric = "val/avgPrecision_macro"
# choose the split (0D_80, 1D, 2D or 3D)
tag = "2D"
# choose the colorscheme
fill_color = (249/256, 158/256, 35/256)
stroke_color = (186/256, 115/256, 28/256)

# choose the data to plot; the boxes appear in this order
experiment_ids = [
    "JG1710",  # LogReg/FP
    "JG1711",  # LogReg/RDKit
    "JG1717",  # AttentiveFP/CGR
    # was "JG1709", which is the 1D GraphSAGE experiment and is removed by the 2D tag filter;
    # the 2D experiments are JG1710-JG1718 and JG1718 was otherwise unused — TODO confirm in wandb
    "JG1718",  # GraphSAGE/CGR
    "JG1713",  # XGB/FP+RDKit
    "JG1712",  # XGB/FP
    "JG1716",  # DMPNN/CGR+RDKit
    "JG1715",  # DMPNN/CGR
    "JG1714",  # FFN/OHE
]

# filter the data: runs of the chosen experiments that carry the split tag
df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# sort the runs into the order of `experiment_ids` (stable sort keeps the order within an experiment)
sort_dict = dict(zip(experiment_ids, itertools.count()))
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict))
ticklabels = df_plot["Model+Features"].unique().tolist()

# set plot
fig, ax = plt.subplots(figsize=(5.6, 4))
sns.boxplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    color=fill_color,
    boxprops={"edgecolor": stroke_color},
    medianprops={"color": stroke_color},
    linewidth=linewidth,
    fliersize=2.5,
    saturation=.8,
    width=.7,
)

# chance level uses `tag` so it always matches the plotted split
ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)
ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.xaxis.set_tick_params(labelrotation=90)
ax.set_ylim((0.45, 1))
ax.xaxis.set_ticklabels(df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg"))

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_boxplot.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_boxplot.png", dpi=300)

In [None]:
# scatter (strip) plot of the individual runs
# choose the metric: macro average precision on the validation set (shown as AUPRC)
metric = "val/avgPrecision_macro"
# choose the split (0D_80, 1D, 2D or 3D)
tag = "2D"
# choose the colorscheme
fill_color = (249/256, 158/256, 35/256)
stroke_color = (186/256, 115/256, 28/256)

# choose the data to plot; the columns appear in this order
experiment_ids = [
    "JG1710",  # LogReg/FP
    "JG1711",  # LogReg/RDKit
    "JG1717",  # AttentiveFP/CGR
    # was "JG1709", which is the 1D GraphSAGE experiment and is removed by the 2D tag filter;
    # the 2D experiments are JG1710-JG1718 and JG1718 was otherwise unused — TODO confirm in wandb
    "JG1718",  # GraphSAGE/CGR
    "JG1713",  # XGB/FP+RDKit
    "JG1712",  # XGB/FP
    "JG1716",  # DMPNN/CGR+RDKit
    "JG1715",  # DMPNN/CGR
    "JG1714",  # FFN/OHE
]

# filter the data: runs of the chosen experiments that carry the split tag
df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# sort the runs into the order of `experiment_ids` (stable sort keeps the order within an experiment)
sort_dict = dict(zip(experiment_ids, itertools.count()))
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict))
ticklabels = df_plot["Model+Features"].unique().tolist()

# set plot
fig, ax = plt.subplots(figsize=(5.6, 4))
sns.stripplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    dodge=True,
    jitter=.15,
    edgecolor=stroke_color,
    color=fill_color,
    linewidth=.5,
    legend=False,
    marker="o",
    size=3.5,
    alpha=.8,
)

# chance level uses `tag` so it always matches the plotted split
ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)

ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.xaxis.set_tick_params(labelrotation=90)
ax.set_ylim((0.45, 1))
ax.xaxis.set_ticklabels(df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg"))

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_scatter.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_scatter.png", dpi=300)

In [None]:
# mean and sample standard deviation (ddof=1) of the validation AUPRC per model, best first.
# String aggregations: passing np.mean/np.std to .agg is deprecated since pandas 2.1, and np.std
# would later be called directly with numpy semantics (ddof=0).
df_plot.groupby("Model+Features")["val/avgPrecision_macro"].agg(["mean", "std"]).sort_values(by="mean", ascending=False)

In [None]:
# is the best model significantly different from the rest?
# Paired two-sided Wilcoxon signed-rank test. Runs are paired explicitly by fold via a
# fold x model table (pivot raises if a model has duplicate runs for a fold) instead of relying
# on two independently sorted run_id lists lining up; folds missing for a model are skipped.
best = "D-MPNN/CGR"
per_fold = df_plot.pivot(index="fold", columns="Model+Features", values="val/avgPrecision_macro")
if best not in per_fold.columns:
    raise KeyError(f"{best!r} not found; available: {list(per_fold.columns)}")
for i in df_plot["Model+Features"].drop_duplicates():
    if i != best:
        paired = per_fold[[best, i]].dropna()
        if len(paired) < len(per_fold):
            print(f"Note: {i} compared on {len(paired)} of {len(per_fold)} folds")
        wilcoxon_result = wilcoxon(paired[best], paired[i], alternative="two-sided")
        print(i, ":\t", wilcoxon_result, wilcoxon_result.pvalue < 0.05)

## 3D split

In [None]:
# barplot
# choose the metric: macro average precision on the validation set (shown as AUPRC)
metric = "val/avgPrecision_macro"
# choose the split (0D_80, 1D, 2D or 3D)
tag = "3D"
# choose the colorscheme
fill_color = (123/256, 154/256, 207/256)
stroke_color = (99/256, 126/256, 165/256)

# choose the data to plot; the bars appear in this order
experiment_ids = [
    "JG1719",  # LogReg/FP
    "JG1720",  # LogReg/RDKit
    "JG1726",  # AttentiveFP/CGR
    "JG1727",  # GraphSAGE/CGR
    "JG1722",  # XGB/FP+RDKit
    "JG1721",  # XGB/FP
    "JG1725",  # DMPNN/CGR+RDKit
    "JG1724",  # DMPNN/CGR
    "JG1723",  # FFN/OHE
]

# filter the data: runs of the chosen experiments that carry the split tag
df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# sort the runs into the order of `experiment_ids` (stable sort keeps the order within an experiment)
sort_dict = dict(zip(experiment_ids, itertools.count()))
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict))
ticklabels = df_plot["Model+Features"].unique().tolist()
# set plot
fig, ax = plt.subplots(figsize=(2.8,2))
sns.barplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    errorbar=errorbar,
    # errwidth=/errcolor= are deprecated since seaborn 0.13; err_kws is the replacement
    err_kws={"linewidth": errwidth, "color": errcolor},
    capsize=capsize,
    color=fill_color,
    edgecolor=stroke_color,
    linewidth=linewidth,
    alpha=.8,
    width=.7,
)

# individual runs on top of the bars
sns.stripplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    dodge=True,
    edgecolor=stroke_color,
    color=fill_color,
    linewidth=.5,
    legend=False,
    marker="o",
    size=2.5,
    alpha=.8,
)

# chance level uses `tag` so it always matches the plotted split
ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)
ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.xaxis.set_tick_params(labelrotation=90)
ax.set_ylim((0, 1))
ax.xaxis.set_ticklabels(df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg"))

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_barplot.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_barplot.png", dpi=300)

In [None]:
# boxplot of the 3D split
# metric: macro average precision on the validation set, labelled AUPRC
metric = "val/avgPrecision_macro"
# split whose runs are shown
tag = "3D"
# colours: fill and a darker outline
fill_color = (123/256, 154/256, 207/256)
stroke_color = (99/256, 126/256, 165/256)

# experiments to show, in x-axis order
experiment_ids = [
    "JG1719",  # LogReg/FP
    "JG1720",  # LogReg/RDKit
    "JG1726",  # AttentiveFP/CGR
    "JG1727",  # GraphSAGE/CGR
    "JG1722",  # XGB/FP+RDKit
    "JG1721",  # XGB/FP
    "JG1725",  # DMPNN/CGR+RDKit
    "JG1724",  # DMPNN/CGR
    "JG1723",  # FFN/OHE
]

# runs of these experiments tagged with the split
is_selected = df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda run_tags: tag in run_tags)
df_plot = df_all.loc[is_selected]

# stable sort of the runs into the order of `experiment_ids`
sort_dict = {exp_id: position for position, exp_id in enumerate(experiment_ids)}
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict))
ticklabels = list(df_plot["Model+Features"].unique())

fig, ax = plt.subplots(figsize=(5.6, 4))
sns.boxplot(
    data=df_plot,
    x="Model+Features",
    y=metric,
    ax=ax,
    color=fill_color,
    saturation=.8,
    width=.7,
    linewidth=linewidth,
    fliersize=2.5,
    boxprops={"edgecolor": stroke_color},
    medianprops={"color": stroke_color},
)

ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)  # val chance level of this split

ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.tick_params(axis="x", labelrotation=90)
ax.set_ylim(0.45, 0.95)
short_labels = df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg")
ax.xaxis.set_ticklabels(short_labels)

fig.tight_layout()
file_stem = f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_boxplot"
fig.savefig(analysis_dir / f"{file_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{file_stem}.png", dpi=300)

In [None]:
# individual runs of the 3D split (strip plot)
# metric: macro average precision on the validation set, labelled AUPRC
metric = "val/avgPrecision_macro"
# split whose runs are shown
tag = "3D"
# colours: fill and a darker outline
fill_color = (123/256, 154/256, 207/256)
stroke_color = (99/256, 126/256, 165/256)

# experiments to show, in x-axis order
experiment_ids = [
    "JG1719",  # LogReg/FP
    "JG1720",  # LogReg/RDKit
    "JG1726",  # AttentiveFP/CGR
    "JG1727",  # GraphSAGE/CGR
    "JG1722",  # XGB/FP+RDKit
    "JG1721",  # XGB/FP
    "JG1725",  # DMPNN/CGR+RDKit
    "JG1724",  # DMPNN/CGR
    "JG1723",  # FFN/OHE
]

# runs of these experiments tagged with the split
is_selected = df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda run_tags: tag in run_tags)
df_plot = df_all.loc[is_selected]

# stable sort of the runs into the order of `experiment_ids`
sort_dict = {exp_id: position for position, exp_id in enumerate(experiment_ids)}
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict))
ticklabels = list(df_plot["Model+Features"].unique())

fig, ax = plt.subplots(figsize=(5.6, 4))
sns.stripplot(
    data=df_plot,
    x="Model+Features",
    y=metric,
    ax=ax,
    color=fill_color,
    edgecolor=stroke_color,
    linewidth=.5,
    marker="o",
    size=3.5,
    alpha=.8,
    jitter=.15,
    dodge=True,
    legend=False,
)

ax.axhline(get_chance_ap(tag, set_type="val"), ls="--", color="black", linewidth=.7)  # val chance level of this split
ax.set_xlabel(None)
ax.set_ylabel("AUPRC")
ax.tick_params(axis="x", labelrotation=90)
ax.set_ylim(0.45, 0.95)
short_labels = df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg")
ax.xaxis.set_ticklabels(short_labels)

fig.tight_layout()
file_stem = f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_scatter"
fig.savefig(analysis_dir / f"{file_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{file_stem}.png", dpi=300)

In [None]:
# try a lineplot: one line per fold connecting that fold's score across models
# choose the metric: macro average precision on the validation set (shown as AUPRC)
metric = "val/avgPrecision_macro"
# choose the split (0D_80, 1D, 2D or 3D)
tag = "3D"
# choose the colorscheme (not used by this plot: lines are coloured per fold by the palette below)
fill_color = (198/256, 193/256, 80/256)

stroke_color = (142/256, 136/256, 58/256)

# choose the data to plot; the x positions follow this order
experiment_ids = [
    "JG1719",  # LogReg/FP
    "JG1720",  # LogReg/RDKit
    "JG1726",  # AttentiveFP/CGR
    "JG1727",  # GraphSAGE/CGR
    "JG1722",  # XGB/FP+RDKit
    "JG1721",  # XGB/FP
    "JG1725",  # DMPNN/CGR+RDKit
    "JG1724",  # DMPNN/CGR
    "JG1723",  # FFN/OHE
]

# filter the data: runs of the chosen experiments that carry the split tag
df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# sort the runs into the order of `experiment_ids` (stable sort keeps the order within an experiment)
sort_dict = dict(zip(experiment_ids, itertools.count()))
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict))
ticklabels = df_plot["Model+Features"].unique().tolist()
# set plot
fig, ax = plt.subplots(figsize=(5.6, 4))
sns.lineplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    hue="fold",
    palette=sns.color_palette("husl", 9),  # one colour per fold
)

ax.set_xlabel(None)
# same y label as the other figures (seaborn would otherwise show the raw metric key)
ax.set_ylabel("AUPRC")
ax.xaxis.set_tick_params(labelrotation=90)
ax.set_ylim((0.45, 0.95))
ax.xaxis.set_ticklabels(df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg"))

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_line.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_line.png", dpi=300)

In [None]:
# same, but only best models
# choose the metric: macro average precision on the validation set (shown as AUPRC)
metric = "val/avgPrecision_macro"
# choose the split (0D_80, 1D, 2D or 3D)
tag = "3D"
# choose the colorscheme (not used by this plot: lines are coloured per fold by the palette below)
fill_color = (198/256, 193/256, 80/256)

stroke_color = (142/256, 136/256, 58/256)

# choose the data to plot; the x positions follow this order
experiment_ids = [
    "JG1719",  # LogReg/FP
    "JG1721",  # XGB/FP
    "JG1725",  # DMPNN/CGR+RDKit
    "JG1724",  # DMPNN/CGR
]

# filter the data: runs of the chosen experiments that carry the split tag
df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# sort the runs into the order of `experiment_ids` (stable sort keeps the order within an experiment)
sort_dict = dict(zip(experiment_ids, itertools.count()))
df_plot = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict))
ticklabels = df_plot["Model+Features"].unique().tolist()
# set plot
fig, ax = plt.subplots(figsize=(5.6, 4))
sns.lineplot(
    ax=ax,
    data=df_plot,
    x="Model+Features",
    y=metric,
    hue="fold",
    palette=sns.color_palette("husl", 9),  # one colour per fold
)

ax.set_xlabel(None)
# same y label as the other figures (seaborn would otherwise show the raw metric key)
ax.set_ylabel("AUPRC")
ax.xaxis.set_tick_params(labelrotation=90)
ax.set_ylim((0.45, 0.95))
ax.xaxis.set_ticklabels(df_plot["Model+Features"].drop_duplicates().str.replace("LogisticRegression", "LogReg"))

fig.tight_layout()
# "_line_best" suffix: previously this cell wrote to the same "_line" files as the all-models
# lineplot above and silently overwrote that figure
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_line_best.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_{tag}_models_{metric.replace('/', '_')}_line_best.png", dpi=300)

In [None]:
# Summary table for all 3D models. df_plot is rebuilt because the previous cell narrowed it to
# the best models; the split is set explicitly instead of relying on `tag` from an earlier cell.
tag = "3D"
experiment_ids = [
    "JG1719",  # LogReg/FP
    "JG1720",  # LogReg/RDKit
    "JG1726",  # AttentiveFP/CGR
    "JG1727",  # GraphSAGE/CGR
    "JG1722",  # XGB/FP+RDKit
    "JG1721",  # XGB/FP
    "JG1725",  # DMPNN/CGR+RDKit
    "JG1724",  # DMPNN/CGR
    "JG1723",  # FFN/OHE
]

df_plot = df_all.loc[df_all['experiment_id'].isin(experiment_ids) & df_all['tags'].apply(lambda x: tag in x)]

# mean and sample standard deviation (ddof=1) per model, best first. String aggregations: passing
# np.mean/np.std to .agg is deprecated since pandas 2.1, and np.std would later be called directly
# with numpy semantics (ddof=0).
df_plot.groupby("Model+Features")["val/avgPrecision_macro"].agg(["mean", "std"]).sort_values(by="mean", ascending=False)

In [None]:
# is the best model significantly different from the rest?
# Paired two-sided Wilcoxon signed-rank test. Runs are paired explicitly by fold via a
# fold x model table (pivot raises if a model has duplicate runs for a fold) instead of relying
# on two independently sorted run_id lists lining up; folds missing for a model are skipped.
best = "LogisticRegression/FP"
per_fold = df_plot.pivot(index="fold", columns="Model+Features", values="val/avgPrecision_macro")
if best not in per_fold.columns:
    raise KeyError(f"{best!r} not found; available: {list(per_fold.columns)}")
for i in df_plot["Model+Features"].drop_duplicates():
    if i != best:
        paired = per_fold[[best, i]].dropna()
        if len(paired) < len(per_fold):
            print(f"Note: {i} compared on {len(paired)} of {len(per_fold)} folds")
        wilcoxon_result = wilcoxon(paired[best], paired[i], alternative="two-sided")
        print(i, ":\t", wilcoxon_result, wilcoxon_result.pvalue < 0.05)

## 0D restricted data splits

In [None]:
# shared settings for the restricted-data (0D subset) plots that follow

# model/feature combinations in the order used for hue and style
order = ["FFN/OHE", "XGB/FP", "XGB/FP+RDKit", "D-MPNN/CGR", "D-MPNN/RDKit"]

# line style per entry of `order`: short (1, 1) dashes for the RDKit variants, (3, 3) otherwise
dashes = [(1, 1) if "RDKit" in model else (3, 3) for model in order]

# colour-blind friendly palette; each "+RDKit"/"RDKit" variant reuses its base model's colour in a darker shade
palette = sns.color_palette(["#5790fc", "#f89c20", "#a66611", "#e42536", "#a1212c"])
# alternative (also colour-blind friendly) with purple instead of red for D-MPNN:
# palette = sns.color_palette(["#5790fc", "#f89c20", "#a66611", "#a0228d", "#783b6e"])
palette

In [None]:
# AUPRC ("val/avgPrecision_macro") vs. training-set size for the 0D restricted-data splits
metric = "val/avgPrecision_macro"

# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1562", "JG1563", "JG1564", "JG1565", "JG1566"],  # _0.625
    ["JG1557", "JG1558", "JG1559", "JG1560", "JG1561"],  # _1.25
    ["JG1552", "JG1553", "JG1554", "JG1555", "JG1556"],  # _2.5
    ["JG1547", "JG1548", "JG1549", "JG1550", "JG1551"],  # _5
    ["JG1542", "JG1543", "JG1544", "JG1545", "JG1546"],  # _10
    ["JG1537", "JG1538", "JG1539", "JG1540", "JG1541"],  # _20
    ["JG1532", "JG1533", "JG1534", "JG1535", "JG1536"],  # _40
    ["JG1528", "JG1529", "JG1530", "JG1527", "JG1531"],  # _80 (= full)
]

# mean number of training samples per split, keyed by split tag
sample_counts = {split: get_sample_count(split) for split in [
    "0D_0.625", "0D_1.25", "0D_2.5", "0D_5", "0D_10", "0D_20", "0D_40", "0D_80",
]}

# chance-level AUPRC per split, in the same order as sample_counts
chance_level = [get_chance_ap(split, set_type="val") for split in sample_counts]

# keep only the runs of the listed experiments
exps = list(itertools.chain.from_iterable(experiment_ids))
df_plot = df_all.loc[df_all["experiment_id"].isin(exps)]

# stable-sort the runs into the listed experiment order, then attach the x coordinate
# (training-set size, looked up from the run's first tag)
sort_dict = {exp_id: rank for rank, exp_id in enumerate(exps)}
df_plot_x = df_plot.sort_values(
    by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict)
).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda tags: sample_counts[tags[0]])

fig, ax = plt.subplots(figsize=(4.75, 4))
sns.lineplot(
    data=df_plot_x,
    x="x",
    y=metric,
    hue="Model+Features",
    style="Model+Features",
    hue_order=order,
    style_order=order,
    palette=palette,
    dashes=dashes,
    markers=["o", "^", "s", "<", "p"],
    linewidth=linewidth,
    errorbar=errorbar,
    err_style="bars",
    ax=ax,
)
# chance level is not drawn because it lies far below the plotted y-range
#sns.lineplot(x=sample_counts.values(), y=chance_level, c="black", label="Chance level")

ax.set(xlabel="Training data size", ylabel="AUPRC")
ax.set_xscale("log")
ax.tick_params(axis="x", labelrotation=0)
ax.yaxis.set_major_locator(ticker.MultipleLocator(0.05))
ax.set_ylim((0.75, 1))
# powers-of-two ticks from 250 to 32000 samples
ax.set_xticks([250 * 2**n for n in range(8)], [str(250 * 2**n) for n in range(8)])
ax.legend(loc="lower right", title=None)

fig.tight_layout()
fig_stem = f"metrics_{datadate}_0D_restricted-data_models_{metric.replace('/', '_')}"
fig.savefig(analysis_dir / f"{fig_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{fig_stem}.png", dpi=300)

In [None]:
# AUPRC ("test/avgPrecision_macro") vs. training-set size for the 0D restricted-data splits
metric = "test/avgPrecision_macro"

# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1562", "JG1563", "JG1564", "JG1565", "JG1566"],  # _0.625
    ["JG1557", "JG1558", "JG1559", "JG1560", "JG1561"],  # _1.25
    ["JG1552", "JG1553", "JG1554", "JG1555", "JG1556"],  # _2.5
    ["JG1547", "JG1548", "JG1549", "JG1550", "JG1551"],  # _5
    ["JG1542", "JG1543", "JG1544", "JG1545", "JG1546"],  # _10
    ["JG1537", "JG1538", "JG1539", "JG1540", "JG1541"],  # _20
    ["JG1532", "JG1533", "JG1534", "JG1535", "JG1536"],  # _40
    ["JG1528", "JG1529", "JG1530", "JG1527", "JG1531"],  # _80 (= full)
]

# mean number of training samples per split, keyed by split tag
sample_counts = {split: get_sample_count(split) for split in [
    "0D_0.625", "0D_1.25", "0D_2.5", "0D_5", "0D_10", "0D_20", "0D_40", "0D_80",
]}

# chance-level AUPRC per split, in the same order as sample_counts
chance_level = [get_chance_ap(split, set_type="test") for split in sample_counts]

# keep only the runs of the listed experiments
exps = list(itertools.chain.from_iterable(experiment_ids))
df_plot = df_all.loc[df_all["experiment_id"].isin(exps)]

# stable-sort the runs into the listed experiment order, then attach the x coordinate
# (training-set size, looked up from the run's first tag)
sort_dict = {exp_id: rank for rank, exp_id in enumerate(exps)}
df_plot_x = df_plot.sort_values(
    by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict)
).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda tags: sample_counts[tags[0]])

fig, ax = plt.subplots(figsize=(4.75, 4))
sns.lineplot(
    data=df_plot_x,
    x="x",
    y=metric,
    hue="Model+Features",
    style="Model+Features",
    hue_order=order,
    style_order=order,
    palette=palette,
    dashes=dashes,
    markers=["o", "^", "s", "<", "p"],
    linewidth=linewidth,
    errorbar=errorbar,
    err_style="bars",
    ax=ax,
)

# chance level is not drawn because it lies far below the plotted y-range
#sns.lineplot(x=sample_counts.values(), y=chance_level, c="black", label="Chance level")

ax.set(xlabel="Training data size", ylabel="AUPRC")
ax.set_xscale("log")
ax.tick_params(axis="x", labelrotation=0)
ax.yaxis.set_major_locator(ticker.MultipleLocator(0.05))
ax.set_ylim((0.75, 1))
# powers-of-two ticks from 250 to 32000 samples
ax.set_xticks([250 * 2**n for n in range(8)], [str(250 * 2**n) for n in range(8)])
ax.legend(loc="lower right", title=None)

fig.tight_layout()
fig_stem = f"metrics_{datadate}_0D_restricted-data_models_{metric.replace('/', '_')}"
fig.savefig(analysis_dir / f"{fig_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{fig_stem}.png", dpi=300)

In [None]:
# Plot AUPRC rescaled against the chance level for better comparability across training-set sizes:
#   scaled = (AUPRC - chance) / (1 - chance)
# so chance-level performance maps to 0 while a perfect score stays at 1.
metric = "test/avgPrecision_macro"


# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1562", "JG1563", "JG1564", "JG1565", "JG1566"],  # _0.625
    ["JG1557", "JG1558", "JG1559", "JG1560", "JG1561"],  # _1.25
    ["JG1552", "JG1553", "JG1554", "JG1555", "JG1556"],  # _2.5
    ["JG1547", "JG1548", "JG1549", "JG1550", "JG1551"],  # _5
    ["JG1542", "JG1543", "JG1544", "JG1545", "JG1546"],  # _10
    ["JG1537", "JG1538", "JG1539", "JG1540", "JG1541"],  # _20
    ["JG1532", "JG1533", "JG1534", "JG1535", "JG1536"],  # _40
    ["JG1528", "JG1529", "JG1530", "JG1527", "JG1531"],  # _80 (= full)
]


sample_counts = {  # mean number of training samples for each split
    k: get_sample_count(k)
    for k in ["0D_0.625", "0D_1.25", "0D_2.5", "0D_5", "0D_10", "0D_20", "0D_40", "0D_80"]
}

exps = [i for exp in experiment_ids for i in exp]
# filter the data
df_plot = df_all.loc[df_all['experiment_id'].isin(exps)]

# sort the values into the listed experiment order and attach the x coordinate (training-set size)
sort_dict = dict(zip(exps, itertools.count()))
df_plot_x = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict)).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda x: sample_counts[x[0]])

# Chance level is looked up per individual run (split tag + fold) so the SEM still makes sense
# after scaling. Look it up once per row and reuse it in numerator and denominator.
chance_ap = df_plot_x.apply(lambda row: get_chance_ap(row["tags"][0], row["fold"], "test"), axis=1)
df_plot_x[f"{metric}_scaled"] = (df_plot_x[metric] - chance_ap) / (1 - chance_ap)

# set plot
fig, ax = plt.subplots(figsize=(3.625, 3))
sns.lineplot(
    ax=ax,
    data=df_plot_x,
    x="x",
    y=f"{metric}_scaled",
    palette=palette,
    hue_order=order,
    style_order=order,
    style="Model+Features",
    hue="Model+Features",
    errorbar=errorbar,
    err_style="bars",
    dashes=dashes,
    linewidth=linewidth,
    markers=["o", "^", "s", "<", "p"],
)

ax.set_xlabel("Training data size")
ax.set_ylabel("AUPRC (relative improvement over chance)")
ax.set_xscale("log")
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0, 1))
ax.set_xticks(
    [250, 500, 1000, 2000, 4000, 8000, 16000, 32000],
    ['250', '500', '1000', '2000', '4000', '8000', '16000', '32000']
)
ax.legend(loc="lower right", title=None)

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_0D_restricted-data_models_{metric.replace('/', '_')}_relative.svg", format="svg", transparent=True)
fig.savefig(analysis_dir / f"metrics_{datadate}_0D_restricted-data_models_{metric.replace('/', '_')}_relative.png", dpi=300)

## 1D restricted data splits

In [None]:
# AUPRC ("val/avgPrecision_macro") vs. training-set size for the 1D restricted-data splits
metric = "val/avgPrecision_macro"

# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1592", "JG1593", "JG1594", "JG1595", "JG1596"],  # _2.5
    ["JG1587", "JG1588", "JG1589", "JG1590", "JG1591"],  # _5
    ["JG1582", "JG1583", "JG1584", "JG1585", "JG1586"],  # _10
    ["JG1577", "JG1578", "JG1579", "JG1580", "JG1581"],  # _20
    ["JG1572", "JG1573", "JG1574", "JG1575", "JG1576"],  # _40
    ["JG1567", "JG1568", "JG1569", "JG1570", "JG1571"],  # _80
]

# mean number of training samples per split, keyed by split tag
sample_counts = {split: get_sample_count(split) for split in [
    "1D_2.5", "1D_5", "1D_10", "1D_20", "1D_40", "1D_80",
]}

# chance-level AUPRC per split, in the same order as sample_counts
chance_level = [get_chance_ap(split, set_type="val") for split in sample_counts]

# keep only the runs of the listed experiments
exps = list(itertools.chain.from_iterable(experiment_ids))
df_plot = df_all.loc[df_all["experiment_id"].isin(exps)]

# stable-sort the runs into the listed experiment order, then attach the x coordinate
# (training-set size, looked up from the run's first tag)
sort_dict = {exp_id: rank for rank, exp_id in enumerate(exps)}
df_plot_x = df_plot.sort_values(
    by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict)
).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda tags: sample_counts[tags[0]])

fig, ax = plt.subplots(figsize=(4.75, 4))
sns.lineplot(
    data=df_plot_x,
    x="x",
    y=metric,
    hue="Model+Features",
    style="Model+Features",
    hue_order=order,
    style_order=order,
    palette=palette,
    dashes=dashes,
    markers=["o", "^", "s", "<", "p"],
    linewidth=linewidth,
    errorbar=errorbar,
    err_style="bars",
    ax=ax,
)

# chance-level reference line (disabled)
#sns.lineplot(x=sample_counts.values(), y=chance_level, c="black", label="Chance level")

ax.set(xlabel="Training data size", ylabel="AUPRC")
ax.set_xscale("log")
ax.tick_params(axis="x", labelrotation=0)
ax.set_ylim((0.70, 0.95))
# powers-of-two ticks from 500 to 32000 samples
ax.set_xticks([500 * 2**n for n in range(7)], [str(500 * 2**n) for n in range(7)])
ax.legend(loc="upper left", title=None)

fig.tight_layout()
fig_stem = f"metrics_{datadate}_1D_restricted-data_models_{metric.replace('/', '_')}"
fig.savefig(analysis_dir / f"{fig_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{fig_stem}.png", dpi=300)

In [None]:
# AUPRC ("test/avgPrecision_macro") vs. training-set size for the 1D restricted-data splits
metric = "test/avgPrecision_macro"

# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1592", "JG1593", "JG1594", "JG1595", "JG1596"],  # _2.5
    ["JG1587", "JG1588", "JG1589", "JG1590", "JG1591"],  # _5
    ["JG1582", "JG1583", "JG1584", "JG1585", "JG1586"],  # _10
    ["JG1577", "JG1578", "JG1579", "JG1580", "JG1581"],  # _20
    ["JG1572", "JG1573", "JG1574", "JG1575", "JG1576"],  # _40
    ["JG1567", "JG1568", "JG1569", "JG1570", "JG1571"],  # _80
]

# mean number of training samples per split, keyed by split tag
sample_counts = {split: get_sample_count(split) for split in [
    "1D_2.5", "1D_5", "1D_10", "1D_20", "1D_40", "1D_80",
]}

# chance-level AUPRC per split, in the same order as sample_counts
chance_level = [get_chance_ap(split, set_type="test") for split in sample_counts]

# keep only the runs of the listed experiments
exps = list(itertools.chain.from_iterable(experiment_ids))
df_plot = df_all.loc[df_all["experiment_id"].isin(exps)]

# stable-sort the runs into the listed experiment order, then attach the x coordinate
# (training-set size, looked up from the run's first tag)
sort_dict = {exp_id: rank for rank, exp_id in enumerate(exps)}
df_plot_x = df_plot.sort_values(
    by="experiment_id", kind="mergesort", key=lambda ids: ids.map(sort_dict)
).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda tags: sample_counts[tags[0]])

fig, ax = plt.subplots(figsize=(4.75, 4))
sns.lineplot(
    data=df_plot_x,
    x="x",
    y=metric,
    hue="Model+Features",
    style="Model+Features",
    hue_order=order,
    style_order=order,
    palette=palette,
    dashes=dashes,
    markers=["o", "^", "s", "<", "p"],
    linewidth=linewidth,
    errorbar=errorbar,
    err_style="bars",
    ax=ax,
)

# chance-level reference line (disabled)
#sns.lineplot(x=sample_counts.values(), y=chance_level, c="black", label="Chance level")

ax.set(xlabel="Training data size", ylabel="AUPRC")
ax.set_xscale("log")
ax.tick_params(axis="x", labelrotation=0)
ax.set_ylim((0.70, 0.95))
# powers-of-two ticks from 500 to 32000 samples
ax.set_xticks([500 * 2**n for n in range(7)], [str(500 * 2**n) for n in range(7)])

ax.legend(loc="upper left", title=None)

fig.tight_layout()
fig_stem = f"metrics_{datadate}_1D_restricted-data_models_{metric.replace('/', '_')}"
fig.savefig(analysis_dir / f"{fig_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{fig_stem}.png", dpi=300)

In [None]:
# Plot AUPRC rescaled against the chance level for better comparability across training-set sizes:
#   scaled = (AUPRC - chance) / (1 - chance)
# so chance-level performance maps to 0 while a perfect score stays at 1.
metric = "test/avgPrecision_macro"

# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1592", "JG1593", "JG1594", "JG1595", "JG1596"],  # _2.5
    ["JG1587", "JG1588", "JG1589", "JG1590", "JG1591"],  # _5
    ["JG1582", "JG1583", "JG1584", "JG1585", "JG1586"],  # _10
    ["JG1577", "JG1578", "JG1579", "JG1580", "JG1581"],  # _20
    ["JG1572", "JG1573", "JG1574", "JG1575", "JG1576"],  # _40
    ["JG1567", "JG1568", "JG1569", "JG1570", "JG1571"],  # _80
]

sample_counts = {  # mean number of training samples for each split
    k: get_sample_count(k)
    for k in ["1D_2.5", "1D_5", "1D_10", "1D_20", "1D_40", "1D_80"]
}

exps = [i for exp in experiment_ids for i in exp]
# filter the data
df_plot = df_all.loc[df_all['experiment_id'].isin(exps)]

# sort the values into the listed experiment order and attach the x coordinate (training-set size)
sort_dict = dict(zip(exps, itertools.count()))
df_plot_x = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict)).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda x: sample_counts[x[0]])

# Chance level is looked up per individual run (split tag + fold) so the SEM still makes sense
# after scaling. Look it up once per row and reuse it in numerator and denominator.
chance_ap = df_plot_x.apply(lambda row: get_chance_ap(row["tags"][0], row["fold"], "test"), axis=1)
df_plot_x[f"{metric}_scaled"] = (df_plot_x[metric] - chance_ap) / (1 - chance_ap)

# set plot
fig, ax = plt.subplots(figsize=(3.625, 3))
sns.lineplot(
    ax=ax,
    data=df_plot_x,
    x="x",
    y=f"{metric}_scaled",
    palette=palette,
    hue_order=order,
    style_order=order,
    style="Model+Features",
    hue="Model+Features",
    errorbar=errorbar,
    err_style="bars",
    dashes=dashes,
    linewidth=linewidth,
    markers=["o", "^", "s", "<", "p"],
)

ax.set_xlabel("Training data size")
ax.set_ylabel("AUPRC (relative improvement over chance)")
ax.set_xscale("log")
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0, 1))
# Same tick range as the unscaled 1D plots. set_xticks expands the view limits to include every
# tick, so a 250 tick (copied from the 0D plots) would stretch the axis below the 1D data range.
ax.set_xticks(
    [500, 1000, 2000, 4000, 8000, 16000, 32000],
    ['500', '1000', '2000', '4000', '8000', '16000', '32000']
)
ax.legend(loc="lower right", title=None)

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_1D_restricted-data_models_{metric.replace('/', '_')}_relative.svg", format="svg", transparent=True)
fig.savefig(analysis_dir / f"metrics_{datadate}_1D_restricted-data_models_{metric.replace('/', '_')}_relative.png", dpi=300)

## 2D restricted data splits

In [None]:
# AUPRC ("val/avgPrecision_macro") vs. training-set size for the 2D restricted-data splits,
# with the chance level drawn as a reference line
metric = "val/avgPrecision_macro"

# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1637", "JG1638", "JG1639", "JG1640", "JG1641"],  # _5
    ["JG1632", "JG1633", "JG1634", "JG1635", "JG1636"],  # _7.5
    ["JG1627", "JG1628", "JG1629", "JG1630", "JG1631"],  # _10
    ["JG1622", "JG1623", "JG1624", "JG1625", "JG1626"],  # _15
    ["JG1617", "JG1618", "JG1619", "JG1620", "JG1621"],  # _20
    ["JG1612", "JG1613", "JG1614", "JG1615", "JG1616"],  # _30
    ["JG1607", "JG1608", "JG1609", "JG1610", "JG1611"],  # _40
    ["JG1602", "JG1603", "JG1604", "JG1605", "JG1606"],  # _60
    ["JG1597", "JG1598", "JG1599", "JG1600", "JG1601"],  # _80
]

sample_counts = {  # mean number of training samples for each split
    k: get_sample_count(k)
    for k in ["2D_5", "2D_7.5", "2D_10", "2D_15", "2D_20", "2D_30", "2D_40", "2D_60", "2D_80"]
}

chance_level = [get_chance_ap(k, set_type="val") for k in sample_counts.keys()]  # same order as sample_counts

exps = [i for exp in experiment_ids for i in exp]
# filter the data
df_plot = df_all.loc[df_all['experiment_id'].isin(exps)]

# sort the values into the listed experiment order and attach the x coordinate (training-set size)
sort_dict = dict(zip(exps, itertools.count()))
df_plot_x = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict)).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda x: sample_counts[x[0]])

# set plot
fig, ax = plt.subplots(figsize=(4.75, 4))
sns.lineplot(
    ax=ax,
    data=df_plot_x,
    x="x",
    y=metric,
    palette=palette,
    hue_order=order,
    style_order=order,
    style="Model+Features",
    hue="Model+Features",
    errorbar=errorbar,
    err_style="bars",
    dashes=dashes,
    linewidth=linewidth,
    markers=["o", "^", "s", "<", "p"],
)

# Chance-level reference line, drawn explicitly on `ax` rather than on matplotlib's implicit
# current axes; the dict view is materialised as a list so seaborn gets a plain sequence.
sns.lineplot(ax=ax, x=list(sample_counts.values()), y=chance_level, c="black", label="Chance level")

ax.set_xlabel("Training data size")
ax.set_ylabel("AUPRC")
ax.set_xscale("log")
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0.53, 0.9))
ax.set_xticks(
    [32*2**n for n in range(10)],
    [f"{32*2**n}" for n in range(10)],
)

ax.legend(loc="upper left", title=None)

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_2D_restricted-data_models_{metric.replace('/', '_')}.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_2D_restricted-data_models_{metric.replace('/', '_')}.png", dpi=300)

In [None]:
# AUPRC ("test/avgPrecision_macro") vs. training-set size for the 2D restricted-data splits,
# with the chance level drawn as a reference line
metric = "test/avgPrecision_macro"

# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1637", "JG1638", "JG1639", "JG1640", "JG1641"],  # _5
    ["JG1632", "JG1633", "JG1634", "JG1635", "JG1636"],  # _7.5
    ["JG1627", "JG1628", "JG1629", "JG1630", "JG1631"],  # _10
    ["JG1622", "JG1623", "JG1624", "JG1625", "JG1626"],  # _15
    ["JG1617", "JG1618", "JG1619", "JG1620", "JG1621"],  # _20
    ["JG1612", "JG1613", "JG1614", "JG1615", "JG1616"],  # _30
    ["JG1607", "JG1608", "JG1609", "JG1610", "JG1611"],  # _40
    ["JG1602", "JG1603", "JG1604", "JG1605", "JG1606"],  # _60
    ["JG1597", "JG1598", "JG1599", "JG1600", "JG1601"],  # _80
]

sample_counts = {  # mean number of training samples for each split
    k: get_sample_count(k)
    for k in ["2D_5", "2D_7.5", "2D_10", "2D_15", "2D_20", "2D_30", "2D_40", "2D_60", "2D_80"]
}

chance_level = [get_chance_ap(k, set_type="test") for k in sample_counts.keys()]  # same order as sample_counts

exps = [i for exp in experiment_ids for i in exp]
# filter the data
df_plot = df_all.loc[df_all['experiment_id'].isin(exps)]

# sort the values into the listed experiment order and attach the x coordinate (training-set size)
sort_dict = dict(zip(exps, itertools.count()))
df_plot_x = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict)).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda x: sample_counts[x[0]])

# set plot
fig, ax = plt.subplots(figsize=(4.75, 4))
sns.lineplot(
    ax=ax,
    data=df_plot_x,
    x="x",
    y=metric,
    palette=palette,
    hue_order=order,
    style_order=order,
    style="Model+Features",
    hue="Model+Features",
    errorbar=errorbar,
    err_style="bars",
    dashes=dashes,
    linewidth=linewidth,
    markers=["o", "^", "s", "<", "p"],
)

# Chance-level reference line, drawn explicitly on `ax` rather than on matplotlib's implicit
# current axes; the dict view is materialised as a list so seaborn gets a plain sequence.
sns.lineplot(ax=ax, x=list(sample_counts.values()), y=chance_level, c="black", label="Chance level")

ax.set_xlabel("Training data size")
ax.set_ylabel("AUPRC")
ax.set_xscale("log")
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0.5, 0.9))
ax.set_xticks(
    [32*2**n for n in range(10)],
    [f"{32*2**n}" for n in range(10)],
)
ax.legend(loc="upper left", title=None)

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_2D_restricted-data_models_{metric.replace('/', '_')}.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_2D_restricted-data_models_{metric.replace('/', '_')}.png", dpi=300)

In [None]:
# Plot AUPRC rescaled against the chance level for better comparability across training-set sizes:
#   scaled = (AUPRC - chance) / (1 - chance)
# so chance-level performance maps to 0 while a perfect score stays at 1.
metric = "test/avgPrecision_macro"

# experiment ids, one row per training-data fraction; columns are
#   FFN/OHE, XGB/FP, XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    [],  # _5: left out due to excess variance (JG1637 - JG1641)
    ["JG1632", "JG1633", "JG1634", "JG1635", "JG1636"],  # _7.5
    ["JG1627", "JG1628", "JG1629", "JG1630", "JG1631"],  # _10
    ["JG1622", "JG1623", "JG1624", "JG1625", "JG1626"],  # _15
    ["JG1617", "JG1618", "JG1619", "JG1620", "JG1621"],  # _20
    ["JG1612", "JG1613", "JG1614", "JG1615", "JG1616"],  # _30
    ["JG1607", "JG1608", "JG1609", "JG1610", "JG1611"],  # _40
    ["JG1602", "JG1603", "JG1604", "JG1605", "JG1606"],  # _60
    ["JG1597", "JG1598", "JG1599", "JG1600", "JG1601"],  # _80
]

sample_counts = {  # mean number of training samples for each split
    k: get_sample_count(k)
    for k in ["2D_5", "2D_7.5", "2D_10", "2D_15", "2D_20", "2D_30", "2D_40", "2D_60", "2D_80"]
}

exps = [i for exp in experiment_ids for i in exp]
# filter the data
df_plot = df_all.loc[df_all['experiment_id'].isin(exps)]

# sort the values into the listed experiment order and attach the x coordinate (training-set size)
sort_dict = dict(zip(exps, itertools.count()))
df_plot_x = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict)).copy()
df_plot_x["x"] = df_plot_x["tags"].apply(lambda x: sample_counts[x[0]])

# Chance level is looked up per individual run (split tag + fold) so the SEM still makes sense
# after scaling. Look it up once per row and reuse it in numerator and denominator.
chance_ap = df_plot_x.apply(lambda row: get_chance_ap(row["tags"][0], row["fold"], "test"), axis=1)
df_plot_x[f"{metric}_scaled"] = (df_plot_x[metric] - chance_ap) / (1 - chance_ap)

# set plot
fig, ax = plt.subplots(figsize=(3.625, 3))
sns.lineplot(
    ax=ax,
    data=df_plot_x,
    x="x",
    y=f"{metric}_scaled",
    palette=palette,
    hue_order=order,
    style_order=order,
    style="Model+Features",
    hue="Model+Features",
    errorbar=errorbar,
    err_style="bars",
    dashes=dashes,
    linewidth=linewidth,
    markers=["o", "^", "s", "<", "p"],
)

ax.set_xlabel("Training data size")
ax.set_ylabel("AUPRC (relative improvement over chance)")
ax.set_xscale("log")
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0., 1))
ax.set_xticks(
    [125*2**n for n in range(9)],
    [f"{125*2**n}" for n in range(9)],
)
legend = ax.legend(loc="upper left", title=None)

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_2D_restricted-data_models_{metric.replace('/', '_')}_relative.svg", format="svg", transparent=True)
fig.savefig(analysis_dir / f"metrics_{datadate}_2D_restricted-data_models_{metric.replace('/', '_')}_relative.png", dpi=300)

## 3D restricted data splits

In [None]:
# Metric to plot: macro-averaged average precision on the validation set (plotted as AUPRC).
metric = "val/avgPrecision_macro"

# Runs to plot: one row per 3D restricted split (trailing comment = split suffix).
# Every row lists the same five model/feature combinations in this column order:
#   FFN/OHE,  XGB/FP,   XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1692", "JG1693", "JG1694", "JG1695", "JG1696"],  # _10
    ["JG1687", "JG1688", "JG1689", "JG1690", "JG1691"],  # _15
    ["JG1682", "JG1683", "JG1684", "JG1685", "JG1686"],  # _20
    ["JG1677", "JG1678", "JG1679", "JG1680", "JG1681"],  # _25
    ["JG1672", "JG1673", "JG1674", "JG1675", "JG1676"],  # _30
    ["JG1667", "JG1668", "JG1669", "JG1670", "JG1671"],  # _34
    ["JG1662", "JG1663", "JG1664", "JG1665", "JG1666"],  # _40
    ["JG1657", "JG1658", "JG1659", "JG1660", "JG1661"],  # _50
    ["JG1652", "JG1653", "JG1654", "JG1655", "JG1656"],  # _60
    ["JG1647", "JG1648", "JG1649", "JG1650", "JG1651"],  # _70
    ["JG1642", "JG1643", "JG1644", "JG1645", "JG1646"],  # _80
]


# Training-set size of each 3D restricted split (mean number of training samples), keyed by split tag.
sample_counts = {
    split: get_sample_count(split)
    for split in ("3D_10", "3D_15", "3D_20", "3D_25", "3D_30", "3D_34", "3D_40", "3D_50", "3D_60", "3D_70", "3D_80")
}

# Validation-set chance-level AUPRC per split, in the same order as the keys of `sample_counts`.
chance_level = [get_chance_ap(split, set_type="val") for split in sample_counts]

# Flatten the groups into one ordered list of run IDs and keep only those runs.
exps = list(itertools.chain.from_iterable(experiment_ids))
df_plot = df_all.loc[df_all["experiment_id"].isin(exps)]

# Sort rows by the position of their experiment in `exps` (mergesort is stable, so rows of one
# experiment keep their order), then place each run at the training-set size of its split (first tag).
sort_dict = {exp_id: rank for rank, exp_id in enumerate(exps)}
df_plot_x = df_plot.sort_values(
    by="experiment_id",
    key=lambda col: col.map(sort_dict),
    kind="mergesort",
).copy()
df_plot_x["x"] = df_plot_x["tags"].map(lambda tags: sample_counts[tags[0]])

# One line per model/feature combination, with SEM error bars.
fig, ax = plt.subplots(figsize=(4.75, 4))
sns.lineplot(
    data=df_plot_x,
    x="x",
    y=metric,
    hue="Model+Features",
    hue_order=order,
    style="Model+Features",
    style_order=order,
    palette=palette,
    dashes=dashes,
    markers=["o", "^", "s", "<", "p"],
    linewidth=linewidth,
    errorbar=errorbar,
    err_style="bars",
    ax=ax,
)

# Chance-level reference line, drawn on the same axes.
sns.lineplot(x=list(sample_counts.values()), y=chance_level, c="black", label="Chance level", ax=ax)

ax.set_xlabel("Training data size")
ax.set_ylabel("AUPRC")
ax.set_xscale("log")
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0.5, 0.9))
tick_positions = [32 * 2**n for n in range(10)]  # powers of two starting at 32
ax.set_xticks(tick_positions, [str(pos) for pos in tick_positions])
ax.legend(loc="upper left", title=None)

fig.tight_layout()
out_stem = f"metrics_{datadate}_3D_restricted-data_models_{metric.replace('/', '_')}"
fig.savefig(analysis_dir / f"{out_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{out_stem}.png", dpi=300)

In [None]:
# Metric to plot: macro-averaged average precision on the test set (plotted as AUPRC).
metric = "test/avgPrecision_macro"

# Runs to plot: one row per 3D restricted split (trailing comment = split suffix).
# Every row lists the same five model/feature combinations in this column order:
#   FFN/OHE,  XGB/FP,   XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1692", "JG1693", "JG1694", "JG1695", "JG1696"],  # _10
    ["JG1687", "JG1688", "JG1689", "JG1690", "JG1691"],  # _15
    ["JG1682", "JG1683", "JG1684", "JG1685", "JG1686"],  # _20
    ["JG1677", "JG1678", "JG1679", "JG1680", "JG1681"],  # _25
    ["JG1672", "JG1673", "JG1674", "JG1675", "JG1676"],  # _30
    ["JG1667", "JG1668", "JG1669", "JG1670", "JG1671"],  # _34
    ["JG1662", "JG1663", "JG1664", "JG1665", "JG1666"],  # _40
    ["JG1657", "JG1658", "JG1659", "JG1660", "JG1661"],  # _50
    ["JG1652", "JG1653", "JG1654", "JG1655", "JG1656"],  # _60
    ["JG1647", "JG1648", "JG1649", "JG1650", "JG1651"],  # _70
    ["JG1642", "JG1643", "JG1644", "JG1645", "JG1646"],  # _80
]

# Training-set size of each 3D restricted split (mean number of training samples), keyed by split tag.
sample_counts = {
    split: get_sample_count(split)
    for split in ("3D_10", "3D_15", "3D_20", "3D_25", "3D_30", "3D_34", "3D_40", "3D_50", "3D_60", "3D_70", "3D_80")
}

# Test-set chance-level AUPRC per split, in the same order as the keys of `sample_counts`.
chance_level = [get_chance_ap(split, set_type="test") for split in sample_counts]

# Flatten the groups into one ordered list of run IDs and keep only those runs.
exps = list(itertools.chain.from_iterable(experiment_ids))
df_plot = df_all.loc[df_all["experiment_id"].isin(exps)]

# Sort rows by the position of their experiment in `exps` (mergesort is stable, so rows of one
# experiment keep their order), then place each run at the training-set size of its split (first tag).
sort_dict = {exp_id: rank for rank, exp_id in enumerate(exps)}
df_plot_x = df_plot.sort_values(
    by="experiment_id",
    key=lambda col: col.map(sort_dict),
    kind="mergesort",
).copy()
df_plot_x["x"] = df_plot_x["tags"].map(lambda tags: sample_counts[tags[0]])

# One line per model/feature combination, with SEM error bars.
fig, ax = plt.subplots(figsize=(4.75, 4))
sns.lineplot(
    data=df_plot_x,
    x="x",
    y=metric,
    hue="Model+Features",
    hue_order=order,
    style="Model+Features",
    style_order=order,
    palette=palette,
    dashes=dashes,
    markers=["o", "^", "s", "<", "p"],
    linewidth=linewidth,
    errorbar=errorbar,
    err_style="bars",
    ax=ax,
)

# Chance-level reference line, drawn on the same axes.
sns.lineplot(x=list(sample_counts.values()), y=chance_level, c="black", label="Chance level", ax=ax)

ax.set_xlabel("Training data size")
ax.set_ylabel("AUPRC")
ax.set_xscale("log")
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0.5, 0.85))
tick_positions = [32 * 2**n for n in range(10)]  # powers of two starting at 32
ax.set_xticks(tick_positions, [str(pos) for pos in tick_positions])
ax.legend(loc="upper left", title=None)

fig.tight_layout()
out_stem = f"metrics_{datadate}_3D_restricted-data_models_{metric.replace('/', '_')}"
fig.savefig(analysis_dir / f"{out_stem}.svg", format="svg")
fig.savefig(analysis_dir / f"{out_stem}.png", dpi=300)

In [None]:
# Same test-set AUPRC plot as the training-size figure, but with the number of unique
# building blocks seen in training on the x axis instead of the number of training samples.
# Metric to plot: macro-averaged average precision on the test set.
metric = "test/avgPrecision_macro"

# Runs to plot: one row per 3D restricted split (trailing comment = split suffix).
# Every row lists the same five model/feature combinations in this column order:
#   FFN/OHE,  XGB/FP,   XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
experiment_ids = [
    ["JG1692", "JG1693", "JG1694", "JG1695", "JG1696"],  # _10
    ["JG1687", "JG1688", "JG1689", "JG1690", "JG1691"],  # _15
    ["JG1682", "JG1683", "JG1684", "JG1685", "JG1686"],  # _20
    ["JG1677", "JG1678", "JG1679", "JG1680", "JG1681"],  # _25
    ["JG1672", "JG1673", "JG1674", "JG1675", "JG1676"],  # _30
    ["JG1667", "JG1668", "JG1669", "JG1670", "JG1671"],  # _34
    ["JG1662", "JG1663", "JG1664", "JG1665", "JG1666"],  # _40
    ["JG1657", "JG1658", "JG1659", "JG1660", "JG1661"],  # _50
    ["JG1652", "JG1653", "JG1654", "JG1655", "JG1656"],  # _60
    ["JG1647", "JG1648", "JG1649", "JG1650", "JG1651"],  # _70
    ["JG1642", "JG1643", "JG1644", "JG1645", "JG1646"],  # _80
]

# Number of unique building blocks in the training data of each 3D restricted split, keyed by split tag.
# (Kept under the name `sample_counts` so the rest of the cell mirrors the training-size plot.)
sample_counts = {
    k: get_buildingblock_count(k)
    for k in ["3D_10", "3D_15", "3D_20", "3D_25", "3D_30", "3D_34", "3D_40", "3D_50", "3D_60", "3D_70", "3D_80"]
}

# test-set chance-level AUPRC per split, in the same order as the keys of `sample_counts`
chance_level = [get_chance_ap(k, set_type="test") for k in sample_counts.keys()]

# flatten the per-split groups into one ordered list of run IDs
exps = [i for exp in experiment_ids for i in exp]
# keep only the runs of the selected experiments
df_plot = df_all.loc[df_all['experiment_id'].isin(exps)]

# order rows by the position of their experiment in `exps` (mergesort is stable,
# so rows of the same experiment keep their relative order)
sort_dict = dict(zip(exps, itertools.count()))
df_plot_x = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict)).copy()
# x position = building-block count of the split named by the run's first tag
df_plot_x["x"] = df_plot_x["tags"].apply(lambda x: sample_counts[x[0]])

# set plot
fig, ax = plt.subplots(figsize=(4.75,4))
sns.lineplot(
    ax=ax,
    data=df_plot_x, 
    x="x",
    y=metric,
    palette=palette,
    hue_order=order,
    style_order=order,
    style="Model+Features",
    hue="Model+Features",
    dashes=dashes,
    errorbar=errorbar,
    err_style="bars",
    linewidth=linewidth,
    markers=["o", "^", "s", "<", "p"],
)

# chance-level reference line, drawn explicitly on the same axes
sns.lineplot(x=list(sample_counts.values()), y=chance_level, c="black", label="Chance level", ax=ax)

ax.set_xlabel("Training data unique building blocks")
ax.set_ylabel("AUPRC")
# Linear x axis with default ticks here, unlike the training-size plots (log scale, power-of-two ticks).
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0.5, 0.85))
ax.legend(loc="center right", title=None)

fig.tight_layout()
# The "_buildingblocks" suffix keeps this figure from overwriting the training-size plot for the
# same metric, which is saved as metrics_{datadate}_3D_restricted-data_models_{metric}.svg/.png.
fig.savefig(analysis_dir / f"metrics_{datadate}_3D_restricted-data_models_{metric.replace('/', '_')}_buildingblocks.svg", format="svg")
fig.savefig(analysis_dir / f"metrics_{datadate}_3D_restricted-data_models_{metric.replace('/', '_')}_buildingblocks.png", dpi=300)

In [None]:
# Relative AUPRC for better comparability across split sizes: each score is rescaled against its
# chance level as (AP - chance) / (1 - chance), so chance performance maps to 0 and a perfect
# score stays at 1.
# Metric to plot: macro-averaged average precision on the test set.
metric = "test/avgPrecision_macro"

# Runs to plot: one row per 3D restricted split (trailing comment = split suffix).
# Every row lists the same five model/feature combinations in this column order:
#   FFN/OHE,  XGB/FP,   XGB/FP+RDKit, D-MPNN/CGR, D-MPNN/CGR+RDKit
# The _10 split (runs JG1692-JG1696) is intentionally left out because of its excessive variance.
experiment_ids = [
    ["JG1687", "JG1688", "JG1689", "JG1690", "JG1691"],  # _15
    ["JG1682", "JG1683", "JG1684", "JG1685", "JG1686"],  # _20
    ["JG1677", "JG1678", "JG1679", "JG1680", "JG1681"],  # _25
    ["JG1672", "JG1673", "JG1674", "JG1675", "JG1676"],  # _30
    ["JG1667", "JG1668", "JG1669", "JG1670", "JG1671"],  # _34
    ["JG1662", "JG1663", "JG1664", "JG1665", "JG1666"],  # _40
    ["JG1657", "JG1658", "JG1659", "JG1660", "JG1661"],  # _50
    ["JG1652", "JG1653", "JG1654", "JG1655", "JG1656"],  # _60
    ["JG1647", "JG1648", "JG1649", "JG1650", "JG1651"],  # _70
    ["JG1642", "JG1643", "JG1644", "JG1645", "JG1646"],  # _80
]

# Mean number of training samples for each 3D restricted split, keyed by split tag.
sample_counts = {
    k: get_sample_count(k)
    for k in ["3D_10", "3D_15", "3D_20", "3D_25", "3D_30", "3D_34", "3D_40", "3D_50", "3D_60", "3D_70", "3D_80"]
}

# flatten the per-split groups into one ordered list of run IDs
exps = [i for exp in experiment_ids for i in exp]
# keep only the runs of the selected experiments
df_plot = df_all.loc[df_all['experiment_id'].isin(exps)]

# order rows by the position of their experiment in `exps` (mergesort is stable,
# so rows of the same experiment keep their relative order)
sort_dict = dict(zip(exps, itertools.count()))
df_plot_x = df_plot.sort_values(by="experiment_id", kind="mergesort", key=lambda x: x.map(sort_dict)).copy()
# x position = training-sample count of the split named by the run's first tag
df_plot_x["x"] = df_plot_x["tags"].apply(lambda x: sample_counts[x[0]])

# Rescale to (AP - chance) / (1 - chance): chance performance maps to 0, a perfect score stays 1.
# The chance level is looked up per run (split tag + fold) so the SEM still makes sense after
# scaling. It is computed once and reused for numerator and denominator instead of running the
# row-wise lookup twice. The test-set chance level is used, matching a `test/...` metric.
chance_ap_per_run = df_plot_x.apply(lambda x: get_chance_ap(x["tags"][0], x["fold"], "test"), axis=1)
df_plot_x[f"{metric}_scaled"] = (df_plot_x[metric] - chance_ap_per_run) / (1 - chance_ap_per_run)

# set plot
fig, ax = plt.subplots(figsize=(3.625, 3))
sns.lineplot(
    ax=ax,
    data=df_plot_x, 
    x="x",
    y=f"{metric}_scaled",
    palette=palette,
    hue_order=order,
    style_order=order,
    style="Model+Features",
    hue="Model+Features",
    errorbar=errorbar,
    err_style="bars",
    dashes=dashes,
    linewidth=linewidth,
    markers=["o", "^", "s", "<", "p"],
)

ax.set_xlabel("Training data size")
ax.set_ylabel("AUPRC (relative improvement over chance)")
ax.set_xscale("log")
ax.xaxis.set_tick_params(labelrotation=0)
ax.set_ylim((0., 1.))
# power-of-two ticks starting at 125 samples
ax.set_xticks(
    [125*2**n for n in range(9)], 
    [f"{125*2**n}" for n in range(9)],
)
ax.legend(loc="upper left", title=None)

fig.tight_layout()
fig.savefig(analysis_dir / f"metrics_{datadate}_3D_restricted-data_models_{metric.replace('/', '_')}_relative.svg", format="svg", transparent=True)
fig.savefig(analysis_dir / f"metrics_{datadate}_3D_restricted-data_models_{metric.replace('/', '_')}_relative.png", dpi=300)