In [3]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas
import seaborn as sns

%matplotlib inline

In [10]:
def load_pic50(exp_path):
    """Return a ``{compound_id: pIC50}`` mapping read from a JSON file.

    The file is parsed as an iterable of records. Each record's
    ``"compound_id"`` is mapped to its ``["experimental_data"]["pIC50"]``
    value. A record missing either key raises ``KeyError``.
    """
    # Decode explicitly as UTF-8 instead of relying on the locale default.
    records = json.loads(Path(exp_path).read_text(encoding="utf-8"))
    return {rec["compound_id"]: rec["experimental_data"]["pIC50"] for rec in records}


sars_exp_path = Path("../ml_datasets/sars_exp_file.json")
mers_exp_path = Path("../ml_datasets/mers_exp_file.json")
sars_exp_data = load_pic50(sars_exp_path)
mers_exp_data = load_pic50(mers_exp_path)

# Mean experimental pIC50 per target, keyed by target name.
exp_mean_dict = {
    "SARS": np.mean(list(sars_exp_data.values())),
    "MERS": np.mean(list(mers_exp_data.values())),
}
exp_mean_dict

{'SARS': 5.042179602888086, 'MERS': 5.772885682574917}

In [5]:
# Load the saved prediction arrays, SARS first and MERS second.
sars_preds, mers_preds = (
    np.load(npy_file)
    for npy_file in ("../predictions/sars.npy", "../predictions/mers.npy")
)

In [6]:
# One row per prediction: SARS rows first, then MERS rows, each tagged with
# the target it belongs to.
target_labels = ["SARS"] * len(sars_preds) + ["MERS"] * len(mers_preds)
all_preds = np.concatenate((sars_preds, mers_preds))
preds_df = pandas.DataFrame(
    {"Target": target_labels, "Predicted pIC$_{50}$": all_preds}
)
preds_df

Unnamed: 0,Target,Predicted pIC$_{50}$
0,SARS,5.073753
1,SARS,6.185442
2,SARS,7.083118
3,SARS,5.102748
4,SARS,6.275742
...,...,...
589,MERS,5.445300
590,MERS,5.445300
591,MERS,5.494909
592,MERS,5.909341


In [None]:
# Histogram of the predicted pIC50 values, one hue per target. A dashed
# vertical line marks each target's entry in exp_mean_dict.
hue_order = ["SARS", "MERS"]
# Pin each target to a colour explicitly so the histogram bars and the
# reference lines cannot drift apart if the palette or hue order changes.
target_palette = dict(zip(hue_order, sns.color_palette()))

fig, ax = plt.subplots()
sns.histplot(
    preds_df,
    x="Predicted pIC$_{50}$",
    hue="Target",
    hue_order=hue_order,
    palette=target_palette,
    ax=ax,
)

for target in hue_order:
    ax.axvline(
        exp_mean_dict[target],
        ls="--",
        c=target_palette[target],
        alpha=0.5,
    )

ax.set_title("Test Set Predictions")

fig_path = Path("../figures/test_set_predictions_from_scratch.png")
# savefig does not create missing directories, so make sure the target exists.
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path, bbox_inches="tight", dpi=200)