In [41]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import glob

In [None]:
def as_list(x):
    """Return a Python list from possible types: list/dict/str/NaN."""
    if isinstance(x, list): 
        return x
    if isinstance(x, str):
        try:
            v = json.loads(x)
            return v if isinstance(v, list) else []
        except Exception:
            return []
    return []  # None/NaN or anything else

def task_from_item(it):
    """Extract 'task' from a dict item or return None."""
    if isinstance(it, dict) and 'task' in it:
        return it['task']
    return None

def items_from_labels_dict(d):
    """labels is a dict {task: 0/1/None}; make (task, y_true) pairs."""
    if isinstance(d, dict):
        return list(d.items())
    return []

Single Model

In [26]:
with open("../runs/dmpnn_tox21/metrics.json", "r") as f:
    metrics = json.load(f)

with open("../runs/dmpnn_tox21/test_metrics.json", "r") as f:
    test_metrics = json.load(f)

with open("../runs/dmpnn_tox21/val_metrics.json", "r") as f:
    val_metrics = json.load(f)

print("Available keys in metrics.json:", metrics.keys())
print("Available keys in test_metrics.json:", test_metrics.keys())


Available keys in metrics.json: dict_keys(['best_val_macro_auc', 'test_macro_auc', 'val_metrics_json', 'test_metrics_json', 'checkpoint', 'config'])
Available keys in test_metrics.json: dict_keys(['test_macro_auc', 'per_task_auc'])


In [27]:
print("Test macro AUC:", metrics.get("test_macro_auc"))

Test macro AUC: 0.8038765505995866


In [30]:
with open("../runs/dmpnn_tox21_ensemble/ensemble_metrics.json", "r") as f:
    ensemble = json.load(f)

In [31]:
ensemble

{'ensemble_size': 5,
 'split_seed': 7,
 'test_macro_auc': 0.8147003785188223,
 'test_macro_aupr': 0.44531724248492105,
 'per_task_auc': [0.8496153207505397,
  0.8733150111241984,
  0.8100632911392405,
  0.790174672489083,
  0.7447254150702426,
  0.7546333020840845,
  0.766221206259891,
  0.8077739912085741,
  0.8071764946764948,
  0.8008189529102078,
  0.9206748025843503,
  0.8512120819289604],
 'ckpts': ['runs/dmpnn_tox21_seed1/best.pt',
  'runs/dmpnn_tox21_seed2/best.pt',
  'runs/dmpnn_tox21_seed3/best.pt',
  'runs/dmpnn_tox21_seed4/best.pt',
  'runs/dmpnn_tox21_seed5/best.pt'],
 'data_csv': 'data/tox21_multitask.csv',
 'smiles_col': 'smiles'}

In [40]:
run_dirs = [
    "../runs/dmpnn_tox21_seed1",
    "../runs/dmpnn_tox21_seed2",
    "../runs/dmpnn_tox21_seed3",
    "../runs/dmpnn_tox21_seed4",
    "../runs/dmpnn_tox21_seed5"]

macro_auc, macro_aupr, per_task = [], [], []
for run in run_dirs:
    with open(f"{run}/test_metrics.json", "r") as f:
        metrics = json.load(f)
    macro_auc.append(metrics["test_macro_auc"])
    per_task.append(metrics["per_task_auc"])

macro_auc = np.array(macro_auc)
per_task = np.array(per_task)
print(f"Macro ROC–AUC: {macro_auc.mean():.3f} ± {macro_auc.std():.3f}")
for j, (m, s) in enumerate(zip(per_task.mean(axis=0), per_task.std(axis=0))):
    print(f"Task {j+1:2d}: {m:.3f} ± {s:.3f}")


Macro ROC–AUC: 0.801 ± 0.008
Task  1: 0.834 ± 0.012
Task  2: 0.867 ± 0.030
Task  3: 0.801 ± 0.005
Task  4: 0.776 ± 0.035
Task  5: 0.734 ± 0.014
Task  6: 0.743 ± 0.011
Task  7: 0.756 ± 0.015
Task  8: 0.783 ± 0.012
Task  9: 0.788 ± 0.022
Task 10: 0.788 ± 0.021
Task 11: 0.907 ± 0.007
Task 12: 0.830 ± 0.029


In [98]:
JSONL_PATH = "./data/explanations_verdicts_32.json"
JSONL_PATH_2 = "./data/explanations_verdicts_128.json"
JSONL_PATH_3 = "./data/explanations_verdicts64.json"
df_32 = pd.read_json(JSONL_PATH)
df_128 = pd.read_json(JSONL_PATH_2)

df_64 = pd.read_json(JSONL_PATH_3)


In [107]:
df = df_32.copy()
df['has_concept'] = df['detected_concepts'].apply(lambda x: isinstance(x, (list, tuple, set)) and len(x) > 0)
df['alert_overlap'] = pd.to_numeric(df['alert_overlap'], errors='coerce')

# ---------- Build labels_long (task, y_true, row_id) ----------
labels_long = df[['labels']].reset_index().rename(columns={'index':'row_id'}.assign(items=lambda x: x['labels'].apply(items_from_labels_dict).explode('items', ignore_index=True))
labels_long[['task','y_true']] = labels_long['items'].apply(pd.Series)
labels_long.drop(columns=['items','labels'], inplace=True)
# Preserve mapping back to df rows to join per-row fields later
labels_long = labels_long.merge(df[['has_concept','alert_overlap']].reset_index().rename(columns={'index':'row_id'}),on='row_id', how='left')

# ---------- Build predicted-positives long (task, row_id) ----------
# Use 'pred_pos_list' if you already created it earlier; otherwise derive from 'predicted_positive_tasks'
if 'pred_pos_list' not in df.columns:
    df['pred_pos_list'] = df['predicted_positive_tasks'].apply(as_list).apply(lambda L: [task_from_item(it) for it in L if task_from_item(it) is not None])

pred_long = df[['pred_pos_list']].reset_index().rename(,columns={'index':'row_id'}).explode('pred_pos_list', ignore_index=False).dropna().reset_index(drop=True).rename(columns={'pred_pos_list':'task'})
pred_long['pred_pos'] = 1

# ---------- Aggregates ----------
# Support and total n per task
support = labels_long.groupby('task', as_index=False).agg(P=('y_true', lambda s: int(np.nansum(s==1))), n=('y_true', 'size')))

# Predicted positives per task
pred_counts = pred_long.groupby('task', as_index=False).agg(Pred_plus=('pred_pos','sum'))

# Overall coverage/IG mass per task
overall_cov = labels_long.groupby('task', as_index=False).agg(coverage_overall=('has_concept','mean'),igmass_overall=('alert_overlap','mean'))

# Among predicted positives: join row_id back to per-row fields
pred_join = pred_long.merge(df[['has_concept','alert_overlap']].reset_index().rename(,columns={'index':'row_id'}),on='row_id', how='left')
cov_predpos = pred_join.groupby('task', as_index=False).agg(coverage_predpos=('has_concept','mean'),igmass_predpos=('alert_overlap','mean'))

# Among true positives (label==1)
gtpos = labels_long[labels_long['y_true']==1]
cov_gtpos = gtpos.groupby('task', as_index=False).agg(coverage_gtpos=('has_concept','mean'),igmass_gtpos=('alert_overlap','mean'))

# ---------- Final per-task table ----------
per_task = support.merge(pred_counts, on='task', how='left').merge(overall_cov, on='task', how='left').merge(cov_predpos, on='task', how='left').merge(cov_gtpos, on='task', how='left').fillna({'Pred_plus':0})

per_task = per_task[['task','n','P','Pred_plus',
                     'coverage_overall','igmass_overall',
                     'coverage_predpos','igmass_predpos',
                     'coverage_gtpos','igmass_gtpos']].sort_values('task')

per_task.head(12)


SyntaxError: invalid syntax. Perhaps you forgot a comma? (2125656734.py, line 6)

In [89]:
coverage_overall = df['has_concept'].mean()
igm_overall = df['alert_overlap'].mean()

mask_pred_any = df['pred_pos_list'].apply(lambda L: len(L)>0)
coverage_pred_any = df.loc[mask_pred_any, 'has_concept'].mean()
igm_pred_any = df.loc[mask_pred_any, 'alert_overlap'].mean()

is_gt_pos_any = df['labels'].apply(lambda d: any((v==1) for v in d.values()))
coverage_gt_any = df.loc[is_gt_pos_any, 'has_concept'].mean()
igm_gt_any = df.loc[is_gt_pos_any, 'alert_overlap'].mean()

{
    'coverage_overall': coverage_overall,
    'igm_overall': igm_overall,
    'coverage_pred_any': coverage_pred_any,
    'igm_pred_any': igm_pred_any,
    'coverage_gt_any': coverage_gt_any,
    'igm_gt_any': igm_gt_any
}


{'coverage_overall': 0.3827683615819209,
 'igm_overall': 0.3724257493285105,
 'coverage_pred_any': 0.4041916167664671,
 'igm_pred_any': 0.24808275131031624,
 'coverage_gt_any': 0.4984025559105431,
 'igm_gt_any': 0.2385212091410329}

In [96]:
df = df_128.copy()
df['has_concept'] = df['detected_concepts'].apply(lambda x: isinstance(x, (list, tuple, set)) and len(x) > 0)
df['alert_overlap'] = pd.to_numeric(df['alert_overlap'], errors='coerce')

# ---------- Build labels_long (task, y_true, row_id) ----------
labels_long = (
    df[['labels']].reset_index().rename(columns={'index':'row_id'})
    .assign(items=lambda x: x['labels'].apply(items_from_labels_dict))
    .explode('items', ignore_index=True)
)
labels_long[['task','y_true']] = labels_long['items'].apply(pd.Series)
labels_long.drop(columns=['items','labels'], inplace=True)
# Preserve mapping back to df rows to join per-row fields later
labels_long = labels_long.merge(
    df[['has_concept','alert_overlap']].reset_index().rename(columns={'index':'row_id'}),
    on='row_id', how='left'
)

# ---------- Build predicted-positives long (task, row_id) ----------
# Use 'pred_pos_list' if you already created it earlier; otherwise derive from 'predicted_positive_tasks'
if 'pred_pos_list' not in df.columns:
    df['pred_pos_list'] = df['predicted_positive_tasks'].apply(as_list).apply(
        lambda L: [task_from_item(it) for it in L if task_from_item(it) is not None]
    )

pred_long = (
    df[['pred_pos_list']].reset_index().rename(columns={'index':'row_id'})
    .explode('pred_pos_list', ignore_index=False).dropna().reset_index(drop=True)
    .rename(columns={'pred_pos_list':'task'})
)
pred_long['pred_pos'] = 1

# ---------- Aggregates ----------
# Support and total n per task
support = (
    labels_long.groupby('task', as_index=False)
    .agg(P=('y_true', lambda s: int(np.nansum(s==1))),
         n=('y_true', 'size'))
)

# Predicted positives per task
pred_counts = pred_long.groupby('task', as_index=False).agg(Pred_plus=('pred_pos','sum'))

# Overall coverage/IG mass per task
overall_cov = labels_long.groupby('task', as_index=False).agg(
    coverage_overall=('has_concept','mean'),
    igmass_overall=('alert_overlap','mean')
)

# Among predicted positives: join row_id back to per-row fields
pred_join = pred_long.merge(
    df[['has_concept','alert_overlap']].reset_index().rename(columns={'index':'row_id'}),
    on='row_id', how='left'
)
cov_predpos = pred_join.groupby('task', as_index=False).agg(
    coverage_predpos=('has_concept','mean'),
    igmass_predpos=('alert_overlap','mean')
)

# Among true positives (label==1)
gtpos = labels_long[labels_long['y_true']==1]
cov_gtpos = gtpos.groupby('task', as_index=False).agg(
    coverage_gtpos=('has_concept','mean'),
    igmass_gtpos=('alert_overlap','mean')
)

# ---------- Final per-task table ----------
per_task = (support
            .merge(pred_counts, on='task', how='left')
            .merge(overall_cov, on='task', how='left')
            .merge(cov_predpos, on='task', how='left')
            .merge(cov_gtpos, on='task', how='left')
            .fillna({'Pred_plus':0})
           )

per_task = per_task[['task','n','P','Pred_plus',
                     'coverage_overall','igmass_overall',
                     'coverage_predpos','igmass_predpos',
                     'coverage_gtpos','igmass_gtpos']].sort_values('task')

per_task.head(12)


Unnamed: 0,task,n,P,Pred_plus,coverage_overall,igmass_overall,coverage_predpos,igmass_predpos,coverage_gtpos,igmass_gtpos
0,NR-AR,708,29,50,0.382768,0.240741,0.38,0.196313,0.62069,0.152182
1,NR-AR-LBD,708,27,43,0.382768,0.240741,0.418605,0.313073,0.555556,0.223881
2,NR-AhR,708,100,8,0.382768,0.240741,0.5,0.667026,0.51,0.399
3,NR-Aromatase,708,30,173,0.382768,0.240741,0.462428,0.334002,0.466667,0.113027
4,NR-ER,708,87,25,0.382768,0.240741,0.2,0.550407,0.367816,0.249999
5,NR-ER-LBD,708,49,23,0.382768,0.240741,0.217391,0.456235,0.346939,0.128523
6,NR-PPAR-gamma,708,22,50,0.382768,0.240741,0.56,0.437769,0.409091,0.45753
7,SR-ARE,708,102,286,0.382768,0.240741,0.423077,0.213627,0.529412,0.197548
8,SR-ATAD5,708,33,9,0.382768,0.240741,0.555556,0.381369,0.454545,0.347582
9,SR-HSE,708,39,89,0.382768,0.240741,0.516854,0.324239,0.589744,0.197861


In [97]:
coverage_overall = df['has_concept'].mean()
igm_overall = df['alert_overlap'].mean()

mask_pred_any = df['pred_pos_list'].apply(lambda L: len(L)>0)
coverage_pred_any = df.loc[mask_pred_any, 'has_concept'].mean()
igm_pred_any = df.loc[mask_pred_any, 'alert_overlap'].mean()

is_gt_pos_any = df['labels'].apply(lambda d: any((v==1) for v in d.values()))
coverage_gt_any = df.loc[is_gt_pos_any, 'has_concept'].mean()
igm_gt_any = df.loc[is_gt_pos_any, 'alert_overlap'].mean()

{
    'coverage_overall': coverage_overall,
    'igm_overall': igm_overall,
    'coverage_pred_any': coverage_pred_any,
    'igm_pred_any': igm_pred_any,
    'coverage_gt_any': coverage_gt_any,
    'igm_gt_any': igm_gt_any
}


{'coverage_overall': 0.3827683615819209,
 'igm_overall': 0.24074053502490622,
 'coverage_pred_any': 0.4041916167664671,
 'igm_pred_any': 0.24703101536447372,
 'coverage_gt_any': 0.4984025559105431,
 'igm_gt_any': 0.23558890635271879}

In [99]:
# ---------- Ensure flat columns ----------
df = df_64.copy()
df['has_concept'] = df['detected_concepts'].apply(lambda x: isinstance(x, (list, tuple, set)) and len(x) > 0)
df['alert_overlap'] = pd.to_numeric(df['alert_overlap'], errors='coerce')

# ---------- Build labels_long (task, y_true, row_id) ----------
labels_long = (
    df[['labels']].reset_index().rename(columns={'index':'row_id'})
    .assign(items=lambda x: x['labels'].apply(items_from_labels_dict))
    .explode('items', ignore_index=True)
)
labels_long[['task','y_true']] = labels_long['items'].apply(pd.Series)
labels_long.drop(columns=['items','labels'], inplace=True)
# Preserve mapping back to df rows to join per-row fields later
labels_long = labels_long.merge(
    df[['has_concept','alert_overlap']].reset_index().rename(columns={'index':'row_id'}),
    on='row_id', how='left'
)

# ---------- Build predicted-positives long (task, row_id) ----------
# Use 'pred_pos_list' if you already created it earlier; otherwise derive from 'predicted_positive_tasks'
if 'pred_pos_list' not in df.columns:
    df['pred_pos_list'] = df['predicted_positive_tasks'].apply(as_list).apply(
        lambda L: [task_from_item(it) for it in L if task_from_item(it) is not None]
    )

pred_long = (
    df[['pred_pos_list']].reset_index().rename(columns={'index':'row_id'})
    .explode('pred_pos_list', ignore_index=False).dropna().reset_index(drop=True)
    .rename(columns={'pred_pos_list':'task'})
)
pred_long['pred_pos'] = 1

# ---------- Aggregates ----------
# Support and total n per task
support = (
    labels_long.groupby('task', as_index=False)
    .agg(P=('y_true', lambda s: int(np.nansum(s==1))),
         n=('y_true', 'size'))
)

# Predicted positives per task
pred_counts = pred_long.groupby('task', as_index=False).agg(Pred_plus=('pred_pos','sum'))

# Overall coverage/IG mass per task
overall_cov = labels_long.groupby('task', as_index=False).agg(
    coverage_overall=('has_concept','mean'),
    igmass_overall=('alert_overlap','mean')
)

# Among predicted positives: join row_id back to per-row fields
pred_join = pred_long.merge(
    df[['has_concept','alert_overlap']].reset_index().rename(columns={'index':'row_id'}),
    on='row_id', how='left'
)
cov_predpos = pred_join.groupby('task', as_index=False).agg(
    coverage_predpos=('has_concept','mean'),
    igmass_predpos=('alert_overlap','mean')
)

# Among true positives (label==1)
gtpos = labels_long[labels_long['y_true']==1]
cov_gtpos = gtpos.groupby('task', as_index=False).agg(
    coverage_gtpos=('has_concept','mean'),
    igmass_gtpos=('alert_overlap','mean')
)

# ---------- Final per-task table ----------
per_task = (support
            .merge(pred_counts, on='task', how='left')
            .merge(overall_cov, on='task', how='left')
            .merge(cov_predpos, on='task', how='left')
            .merge(cov_gtpos, on='task', how='left')
            .fillna({'Pred_plus':0})
           )

per_task = per_task[['task','n','P','Pred_plus',
                     'coverage_overall','igmass_overall',
                     'coverage_predpos','igmass_predpos',
                     'coverage_gtpos','igmass_gtpos']].sort_values('task')

per_task.head(12)


Unnamed: 0,task,n,P,Pred_plus,coverage_overall,igmass_overall,coverage_predpos,igmass_predpos,coverage_gtpos,igmass_gtpos
0,NR-AR,708,29,69,0.382768,0.236041,0.391304,0.225008,0.62069,0.14953
1,NR-AR-LBD,708,27,75,0.382768,0.236041,0.453333,0.262491,0.555556,0.219841
2,NR-AhR,708,100,199,0.382768,0.236041,0.39196,0.373487,0.51,0.398946
3,NR-Aromatase,708,30,119,0.382768,0.236041,0.512605,0.328018,0.466667,0.113513
4,NR-ER,708,87,174,0.382768,0.236041,0.298851,0.391284,0.367816,0.246178
5,NR-ER-LBD,708,49,63,0.382768,0.236041,0.31746,0.378844,0.346939,0.123053
6,NR-PPAR-gamma,708,22,111,0.382768,0.236041,0.495495,0.380563,0.409091,0.453037
7,SR-ARE,708,102,364,0.382768,0.236041,0.417582,0.193217,0.529412,0.196735
8,SR-ATAD5,708,33,11,0.382768,0.236041,0.545455,0.407477,0.454545,0.342061
9,SR-HSE,708,39,81,0.382768,0.236041,0.54321,0.328836,0.589744,0.196897


In [94]:
coverage_overall = df['has_concept'].mean()
igm_overall = df['alert_overlap'].mean()

mask_pred_any = df['pred_pos_list'].apply(lambda L: len(L)>0)
coverage_pred_any = df.loc[mask_pred_any, 'has_concept'].mean()
igm_pred_any = df.loc[mask_pred_any, 'alert_overlap'].mean()

is_gt_pos_any = df['labels'].apply(lambda d: any((v==1) for v in d.values()))
coverage_gt_any = df.loc[is_gt_pos_any, 'has_concept'].mean()
igm_gt_any = df.loc[is_gt_pos_any, 'alert_overlap'].mean()

{
    'coverage_overall': coverage_overall,
    'igm_overall': igm_overall,
    'coverage_pred_any': coverage_pred_any,
    'igm_pred_any': igm_pred_any,
    'coverage_gt_any': coverage_gt_any,
    'igm_gt_any': igm_gt_any
}


{'coverage_overall': 0.3827683615819209,
 'igm_overall': 0.24074053502490622,
 'coverage_pred_any': 0.4041916167664671,
 'igm_pred_any': 0.24703101536447372,
 'coverage_gt_any': 0.4984025559105431,
 'igm_gt_any': 0.23558890635271879}

In [102]:
# ==== Setup ====
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt

# Your dataframes (one per model). If you only have the ensemble, set df_baseline=None.
df_ensemble = df_64
# df_baseline = df_xgb  # if available; otherwise set to None

TASK_ORDER = [
    "NR-AR","NR-AR-LBD","NR-AhR","NR-Aromatase",
    "NR-ER","NR-ER-LBD","NR-PPAR-gamma",
    "SR-ARE","SR-ATAD5","SR-HSE","SR-MMP","SR-p53"
]

# Columns that *might* contain per-row dicts of scores {task: prob or logit}
CANDIDATE_SCORE_COLS_ENSEMBLE = ["probs", "pred_probs", "yhat_probs", "scores", "pred_scores", "logits"]
CANDIDATE_SCORE_COLS_BASELINE = ["xgb_probs", "baseline_probs", "xgb_scores", "xgb_logits"]

def pick_score_column(df, candidates):
    for c in candidates:
        if c in df.columns:
            # Heuristic: dict/JSON-like in the first non-null row
            v = df[c].dropna().iloc[0]
            if isinstance(v, dict):
                return c
    raise KeyError(f"None of {candidates} found as dict-like score columns.")

def sigmoid(x): return 1/(1+np.exp(-x))

def extract_y(df, task):
    """Return array of 0/1 labels for given task from df['labels'] dict column."""
    def get_label(d):
        if isinstance(d, dict):
            v = d.get(task, None)
            if v is None: return np.nan
            return int(v == 1)
        return np.nan
    y = df["labels"].apply(get_label).astype("float")
    return y.values  # may contain NaN if missing for some rows

def extract_scores(df, task, score_col):
    """Return probability scores for given task from a dict-like column.
       If column looks like logits, apply sigmoid; else pass through."""
    s = df[score_col].apply(lambda d: d.get(task) if isinstance(d, dict) else np.nan).astype("float").values
    # Heuristic to detect logits: if values often <0 or >1
    finite = np.isfinite(s)
    if finite.sum() > 0:
        frac_outside = np.mean((s[finite] < 0) | (s[finite] > 1))
        if frac_outside > 0.2:  # many outside [0,1] => likely logits
            s = sigmoid(s)
    return s

def pr_data(y, s):
    """Drop NaNs and compute PR arrays + AP."""
    m = np.isfinite(y) & np.isfinite(s)
    yv, sv = y[m].astype(int), s[m].astype(float)
    if yv.sum() == 0:
        return None  # no positives -> PR curve undefined (or a point at precision=prevalence)
    precision, recall, _ = precision_recall_curve(yv, sv)
    ap = average_precision_score(yv, sv)
    return recall, precision, ap

# ==== Locate score columns ====
score_col_ens = pick_score_column(df_ensemble, CANDIDATE_SCORE_COLS_ENSEMBLE)
df_baseline_exists = ('df_baseline' in globals()) and (df_baseline is not None)
if df_baseline_exists:
    score_col_base = pick_score_column(df_baseline, CANDIDATE_SCORE_COLS_BASELINE)

# ==== Plot grid ====
fig, axes = plt.subplots(3, 4, figsize=(12, 9), sharex=False, sharey=False)
axes = axes.ravel()

for i, task in enumerate(TASK_ORDER):
    ax = axes[i]

    # Ensemble
    y_e = extract_y(df_ensemble, task)
    s_e = extract_scores(df_ensemble, task, score_col_ens)
    rpe = pr_data(y_e, s_e)
    if rpe is not None:
        r_e, p_e, ap_e = rpe
        ax.plot(r_e, p_e, label=f"Ensemble (AP={ap_e:.3f})")

    # Baseline (optional)
    if df_baseline_exists:
        y_b = extract_y(df_baseline, task)
        s_b = extract_scores(df_baseline, task, score_col_base)
        rpb = pr_data(y_b, s_b)
        if rpb is not None:
            r_b, p_b, ap_b = rpb
            ax.plot(r_b, p_b, linestyle="--", label=f"ECFP+XGB (AP={ap_b:.3f})")

    ax.set_title(task.replace("-gamma", "-\\gamma"))
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.grid(True, linewidth=0.3)
    ax.legend(fontsize=8, loc="best")

# Hide any unused subplots if TASK_ORDER < 12
for j in range(len(TASK_ORDER), len(axes)):
    fig.delaxes(axes[j])

fig.tight_layout()
plt.savefig("figures/pr_curves_grid.png", dpi=300, bbox_inches="tight")
plt.close()
print("Saved: figures/pr_curves_grid.png")


KeyError: "None of ['probs', 'pred_probs', 'yhat_probs', 'scores', 'pred_scores', 'logits'] found as dict-like score columns."

In [103]:
df_ensemble

Unnamed: 0,index_in_test,smiles,focus,predicted_positive_tasks,top3_tasks,labels,detected_concepts,alert_overlap,why,mitigation,image
0,102,C/C(=N/NC(=O)c1ccncc1)C(=O)O,"{'task': 'SR-ARE', 'prob': 0.5606786012649531,...",[],"[{'task': 'SR-ARE', 'prob': 0.5606786012649531...","{'NR-AR': None, 'NR-AR-LBD': None, 'NR-AhR': N...",[],,"Predicted profile: SR-ARE p=0.56, SR-HSE p=0.3...",,runs/dmpnn_tox21_ensemble/explanations/exp_000...
1,531,CC(=NC#N)N(C)Cc1ccc(Cl)nc1,"{'task': 'SR-ARE', 'prob': 0.43759843707084606...",[],"[{'task': 'SR-ARE', 'prob': 0.4375984370708460...","{'NR-AR': 0, 'NR-AR-LBD': 0, 'NR-AhR': 0, 'NR-...",[haloarene],0.249293,haloarene is consistent with the predicted pro...,Reduce halogenation or swap to less lipophilic...,runs/dmpnn_tox21_ensemble/explanations/exp_000...
2,830,CC(=O)c1ccccn1,"{'task': 'SR-ARE', 'prob': 0.537938475608825, ...","[{'task': 'NR-AhR', 'prob': 0.4971012771129600...","[{'task': 'SR-ARE', 'prob': 0.537938475608825}...","{'NR-AR': 0, 'NR-AR-LBD': 0, 'NR-AhR': 0, 'NR-...",[],,"Predicted profile: SR-ARE p=0.54, NR-AhR p=0.5...",,runs/dmpnn_tox21_ensemble/explanations/exp_000...
3,831,CC(=O)c1cccnc1,"{'task': 'SR-ARE', 'prob': 0.5695862770080561,...","[{'task': 'NR-AhR', 'prob': 0.541131854057312,...","[{'task': 'SR-ARE', 'prob': 0.5695862770080561...","{'NR-AR': 0, 'NR-AR-LBD': 0, 'NR-AhR': 0, 'NR-...",[],,"Predicted profile: SR-ARE p=0.57, NR-AhR p=0.5...",,runs/dmpnn_tox21_ensemble/explanations/exp_000...
4,917,CC(C)(C)NCC(O)c1ccc(O)c(CO)n1,"{'task': 'SR-p53', 'prob': 0.6667737364768981,...","[{'task': 'SR-p53', 'prob': 0.6667737364768981...","[{'task': 'SR-p53', 'prob': 0.6667737364768981...","{'NR-AR': 0, 'NR-AR-LBD': 0, 'NR-AhR': 0, 'NR-...",[],,"Predicted profile: SR-p53 p=0.67, SR-ARE p=0.6...",,runs/dmpnn_tox21_ensemble/explanations/exp_000...
...,...,...,...,...,...,...,...,...,...,...,...
703,7769,c1ccc2c(c1)-c1cccc3c1c-2cc1ccccc13,"{'task': 'NR-Aromatase', 'prob': 0.85843360424...","[{'task': 'NR-Aromatase', 'prob': 0.8584336042...","[{'task': 'NR-Aromatase', 'prob': 0.8584336042...","{'NR-AR': 0, 'NR-AR-LBD': 1, 'NR-AhR': 1, 'NR-...",[polycyclic aromatic],1.186477,polycyclic aromatic is consistent with the pre...,,runs/dmpnn_tox21_ensemble/explanations/exp_070...
704,7785,c1ccc2c(c1)[nH]c1cnccc12,"{'task': 'NR-Aromatase', 'prob': 0.76600891351...","[{'task': 'NR-Aromatase', 'prob': 0.7660089135...","[{'task': 'NR-Aromatase', 'prob': 0.7660089135...","{'NR-AR': 0, 'NR-AR-LBD': 0, 'NR-AhR': 1, 'NR-...",[],,"Predicted profile: NR-Aromatase p=0.77, SR-HSE...",,runs/dmpnn_tox21_ensemble/explanations/exp_070...
705,7798,c1ccc2cc3c4ccccc4c4ccccc4c3cc2c1,"{'task': 'NR-Aromatase', 'prob': 0.88854598999...","[{'task': 'NR-Aromatase', 'prob': 0.8885459899...","[{'task': 'NR-Aromatase', 'prob': 0.8885459899...","{'NR-AR': 0, 'NR-AR-LBD': 0, 'NR-AhR': 1, 'NR-...",[],,"Predicted profile: NR-Aromatase p=0.89, SR-HSE...",,runs/dmpnn_tox21_ensemble/explanations/exp_070...
706,7819,c1cnc2c(n1)CCCC2,"{'task': 'NR-Aromatase', 'prob': 0.66798466444...","[{'task': 'NR-AhR', 'prob': 0.510206758975982,...","[{'task': 'NR-Aromatase', 'prob': 0.6679846644...","{'NR-AR': 0, 'NR-AR-LBD': 0, 'NR-AhR': 0, 'NR-...",[],,"Predicted profile: NR-Aromatase p=0.67, SR-HSE...",,runs/dmpnn_tox21_ensemble/explanations/exp_070...
