# Analyze decision threshold
We notice that some models always predict 1 for binary_A. We try modifying the decision threshold and see the impact on validation metrics.

In [None]:
import pathlib
import sys
sys.path.append(str(pathlib.Path().absolute().parent))

from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix, accuracy_score, balanced_accuracy_score, roc_auc_score, precision_score, recall_score, f1_score, fbeta_score, PrecisionRecallDisplay
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

from src.util.definitions import PRED_DIR, DATA_ROOT, LOG_DIR
from src.util.io import read_predictions
from ghost import optimize_threshold_from_predictions

In [None]:
# client for the WandB public API
api = wandb.Api(timeout=59)

# all runs belonging to the group of the best 0D model (JG1309)
runs = api.runs(
    "jugoetz/synferm-predictions",
    filters={"group": "2023-12-20-202602_330364"},
)

In [None]:
# collect config, summary and name of every run in a single pass
config_list, summary_list, name_list = [], [], []
for run in runs:
    config_list.append(run.config)
    summary_list.append(run.summary._json_dict)
    name_list.append(run.name)

name_list

In [None]:
# load the predictions on the training split, one entry per run
preds = [read_predictions(run_name, "train") for run_name in name_list]
preds[0]

In [None]:
# attach the ground-truth labels to the predictions
# (joined on the index; presumably both are indexed by dataset record - verify against read_predictions)
df_true = pd.read_csv(DATA_ROOT / "synferm_dataset_2023-12-20_39486records.csv")
df_labels = df_true[["binary_A", "binary_B", "binary_C"]]
comb = [pred.merge(df_labels, left_index=True, right_index=True) for pred in preds]
comb[0]

In [None]:
# calculate thresholded metrics over many thresholds
# and obtain best threshold with GHOST with Cohen's Kappa method
x = np.arange(0, 1.0001, 0.01)  # threshold grid 0.00, 0.01, ..., 1.00
n_folds = len(comb)  # size the result arrays from the data instead of hard-coding the number of folds
y_bal_acc = np.empty((len(x), n_folds))  # rows: thresholds, columns: folds
y_f_beta = np.empty((len(x), n_folds))
best_thresh = np.empty(n_folds)
for fold, df_fold in enumerate(comb):
    for i, threshold in enumerate(x):
        y_bal_acc[i, fold] = balanced_accuracy_score(df_fold["binary_A"], (df_fold["pred_0"] > threshold))
        y_f_beta[i, fold] = fbeta_score(df_fold["binary_A"], (df_fold["pred_0"] > threshold), beta=0.1)
    # NOTE: no random_seed is passed here (get_decision_thresholds further down passes random_seed=42),
    # so the selected thresholds may differ slightly between executions
    best_thresh[fold] = optimize_threshold_from_predictions(df_fold["binary_A"], df_fold["pred_0"], thresholds=x, ThOpt_metrics="Kappa")

In [None]:
# obtain best threshold with GHOST with AUROC method
# (one entry per fold in `comb`; the threshold grid `x` is reused from the scan above)
best_thresh_roc = np.empty(len(comb))
for fold, df_fold in enumerate(comb):
    best_thresh_roc[fold] = optimize_threshold_from_predictions(df_fold["binary_A"], df_fold["pred_0"], thresholds=x, ThOpt_metrics="ROC")

In [None]:
# show the thresholds selected by GHOST with Cohen's Kappa (one per fold)
best_thresh

In [None]:
# show the thresholds selected by GHOST with the ROC method (one per fold)
best_thresh_roc

In [None]:
# balanced accuracy as a function of the decision threshold, one curve per run;
# vertical lines mark the GHOST thresholds (dotted: Kappa, dashed: ROC)
fig, ax = plt.subplots(figsize=(6, 4))
for i, fold in enumerate(name_list):
    # the last character of the run name is used as legend label
    (curve,) = ax.plot(x, y_bal_acc[:, i], label=fold[-1])
    color = curve.get_color()
    for thresh, linestyle in ((best_thresh[i], ":"), (best_thresh_roc[i], "--")):
        # map the threshold back to its row on the 0.01-spaced grid so the line ends on the curve
        grid_idx = int((thresh * 100).round())
        ax.vlines(x=thresh, ymin=0, ymax=y_bal_acc[grid_idx, i], ls=linestyle, colors=color, lw=1.2)

ax.set_xlabel("Decision boundary")
ax.set_ylabel("Balanced accuracy")
ax.set_xlim(0, 1)
ax.set_ylim(0.4, 1)
ax.legend()

### Maximize balanced accuracy

A simple way to pick the decision threshold is to calculate the balanced accuracy over different thresholds and pick the one that maximizes it.

Note:
- Since this is average recall per class, for our problem, this heavily tilts the predictions to minimize false negatives. That also means, recall for the positive class will take a heavy hit
- Picking the maximum can be a bit unstable if the balanced_accuracy/threshold curve is not so smooth.

In [None]:
# apply the threshold that maximizes balanced accuracy and report training metrics per fold
for fold in range(9):
    # position of the maximum on the 0.01-spaced grid -> threshold value
    threshold = y_bal_acc[:, fold].argmax() / 100
    train_fold = comb[fold]
    y_true, y_prob = train_fold["binary_A"], train_fold["pred_0"]
    y_pred = y_prob > threshold
    print(f"balanced accuracy: {balanced_accuracy_score(y_true, y_pred):.2f}")
    print(f"recall: {recall_score(y_true, y_pred):.2f}")
    print(f"precision: {precision_score(y_true, y_pred):.2f}")
    print(f"f_0.5 score: {fbeta_score(y_true, y_pred, beta=0.5):.2f}")
    print()

# plot the PRC (y_true and y_prob still hold the data of the last fold of the loop)
PrecisionRecallDisplay.from_predictions(y_true, y_prob, pos_label=1, drop_intermediate=True, plot_chance_level=True)

### GHOST
The Riniker lab has published GHOST (1) for decision threshold picking.
On a high level this:
- Takes stratified samples from the training set (default: without replacement)
- Determines metric (default: Cohen's kappa) for all decision thresholds
- Calculates the median metric over stratified samples for all decision thresholds
- Returns the threshold with highest median metric

(1) Esposito, C.; Landrum, G. A.; Schneider, N.; Stiefl, N.; Riniker, S. GHOST: Adjusting the Decision Threshold to Handle Imbalanced Data in Machine Learning. Journal of Chemical Information and Modeling 2021, 61 (6), 2623–2640. https://doi.org/10.1021/acs.jcim.1c00160.

In [None]:
# report training metrics per fold when using the GHOST (Kappa) threshold
for i in range(9):
    threshold = best_thresh[i]
    train_fold = comb[i]
    y_true, y_prob = train_fold["binary_A"], train_fold["pred_0"]
    y_pred = y_prob > threshold
    print(f"balanced accuracy: {balanced_accuracy_score(y_true, y_pred):.2f}")
    print(f"recall: {recall_score(y_true, y_pred):.2f}")
    print(f"precision: {precision_score(y_true, y_pred):.2f}")
    print(f"f_0.5 score: {fbeta_score(y_true, y_pred, beta=0.5):.2f}")
    print()

In [None]:
# now look at the validation split:
# load the validation predictions, one entry per run
preds_val = [read_predictions(run_name, "val") for run_name in name_list]
preds_val[0]

In [None]:
# attach the ground-truth labels to the validation predictions (joined on the index)
df_labels = df_true[["binary_A", "binary_B", "binary_C"]]
comb_val = [pred.merge(df_labels, left_index=True, right_index=True) for pred in preds_val]
comb_val[0]

In [None]:
# validation metrics with the threshold obtained by GHOST with Kappa
acc, bal_acc, recall, precision, fbeta = ([] for _ in range(5))

for fold in range(9):
    # the threshold was determined on the training predictions of the same fold
    threshold = best_thresh[fold]
    val_fold = comb_val[fold]
    y_true, y_prob = val_fold["binary_A"], val_fold["pred_0"]
    y_pred = y_prob > threshold
    acc.append(accuracy_score(y_true, y_pred))
    bal_acc.append(balanced_accuracy_score(y_true, y_pred))
    recall.append(recall_score(y_true, y_pred))
    precision.append(precision_score(y_true, y_pred))
    fbeta.append(fbeta_score(y_true, y_pred, beta=0.5))

# mean ± standard deviation over the folds
for label, scores in (
    ("accuracy", acc),
    ("balanced accuracy", bal_acc),
    ("recall", recall),
    ("precision", precision),
    ("f_0.5 score", fbeta),
):
    scores = np.array(scores)
    print(f"{label}: {scores.mean():.2f}±{scores.std():.2f}")

In [None]:
# validation metrics with the threshold obtained by GHOST with ROC
acc, bal_acc, recall, precision, fbeta = ([] for _ in range(5))

for fold in range(9):
    # the threshold was determined on the training predictions of the same fold
    threshold = best_thresh_roc[fold]
    val_fold = comb_val[fold]
    y_true, y_prob = val_fold["binary_A"], val_fold["pred_0"]
    y_pred = y_prob > threshold
    acc.append(accuracy_score(y_true, y_pred))
    bal_acc.append(balanced_accuracy_score(y_true, y_pred))
    recall.append(recall_score(y_true, y_pred))
    precision.append(precision_score(y_true, y_pred))
    fbeta.append(fbeta_score(y_true, y_pred, beta=0.5))

# mean ± standard deviation over the folds
for label, scores in (
    ("accuracy", acc),
    ("balanced accuracy", bal_acc),
    ("recall", recall),
    ("precision", precision),
    ("f_0.5 score", fbeta),
):
    scores = np.array(scores)
    print(f"{label}: {scores.mean():.2f}±{scores.std():.2f}")

In [None]:
# per fold: threshold at which the balanced accuracy on the training data peaks
# (row index on the 0.01-spaced grid divided by 100)
best_tresh_bal_acc = y_bal_acc.argmax(axis=0) / 100
best_tresh_bal_acc

In [None]:
# validation metrics with the threshold obtained by maximizing balanced accuracy

acc, bal_acc, recall, precision, fbeta = ([] for _ in range(5))

for fold in range(9):
    # position of the training-set maximum on the 0.01-spaced grid -> threshold value
    threshold = y_bal_acc[:, fold].argmax() / 100
    val_fold = comb_val[fold]
    y_true, y_prob = val_fold["binary_A"], val_fold["pred_0"]
    y_pred = y_prob > threshold
    acc.append(accuracy_score(y_true, y_pred))
    bal_acc.append(balanced_accuracy_score(y_true, y_pred))
    recall.append(recall_score(y_true, y_pred))
    precision.append(precision_score(y_true, y_pred))
    fbeta.append(fbeta_score(y_true, y_pred, beta=0.5))

# mean ± standard deviation over the folds
for label, scores in (
    ("accuracy", acc),
    ("balanced accuracy", bal_acc),
    ("recall", recall),
    ("precision", precision),
    ("f_0.5 score", fbeta),
):
    scores = np.array(scores)
    print(f"{label}: {scores.mean():.2f}±{scores.std():.2f}")

In [None]:
# validation metrics with the default decision threshold of 0.5 (baseline)

acc, bal_acc, recall, precision, fbeta = ([] for _ in range(5))

for fold in range(9):
    threshold = 0.5
    val_fold = comb_val[fold]
    y_true, y_prob = val_fold["binary_A"], val_fold["pred_0"]
    y_pred = y_prob > threshold
    acc.append(accuracy_score(y_true, y_pred))
    bal_acc.append(balanced_accuracy_score(y_true, y_pred))
    recall.append(recall_score(y_true, y_pred))
    precision.append(precision_score(y_true, y_pred))
    fbeta.append(fbeta_score(y_true, y_pred, beta=0.5))

# mean ± standard deviation over the folds
for label, scores in (
    ("accuracy", acc),
    ("balanced accuracy", bal_acc),
    ("recall", recall),
    ("precision", precision),
    ("f_0.5 score", fbeta),
):
    scores = np.array(scores)
    print(f"{label}: {scores.mean():.2f}±{scores.std():.2f}")

In [None]:
# validation metrics with the threshold set to the positive fraction of binary_A in the training data

acc, bal_acc, recall, precision, fbeta = ([] for _ in range(5))

for fold in range(9):
    # mean of the binary training labels of this fold
    threshold = comb[fold]["binary_A"].mean()
    val_fold = comb_val[fold]
    y_true, y_prob = val_fold["binary_A"], val_fold["pred_0"]
    y_pred = y_prob > threshold
    acc.append(accuracy_score(y_true, y_pred))
    bal_acc.append(balanced_accuracy_score(y_true, y_pred))
    recall.append(recall_score(y_true, y_pred))
    precision.append(precision_score(y_true, y_pred))
    fbeta.append(fbeta_score(y_true, y_pred, beta=0.5))

# mean ± standard deviation over the folds
for label, scores in (
    ("accuracy", acc),
    ("balanced accuracy", bal_acc),
    ("recall", recall),
    ("precision", precision),
    ("f_0.5 score", fbeta),
):
    scores = np.array(scores)
    print(f"{label}: {scores.mean():.2f}±{scores.std():.2f}")

## Conclusion

Using GHOST with Cohen's Kappa gives the best results.
We will use the thresholds obtained this way for our final predictions

In [None]:
# save best thresholds for this model
# NOTE: this writes a single line (the binary_A threshold at full precision) per run.
# get_decision_thresholds() further down writes three lines (A, B, C) to the same path
# and will overwrite these files when called for the same run group.
thresh_dir = LOG_DIR / "thresholds"
thresh_dir.mkdir(parents=True, exist_ok=True)  # open(..., "w") does not create missing directories
for thresh, name in zip(best_thresh, name_list):
    with open(thresh_dir / f"{name}.txt", "w") as f:
        f.write(f"{thresh}\n")

## Faster way to pick decision threshold
Given a run_group, identify and save all the decision thresholds using GHOST with Cohen's kappa


In [None]:
def get_decision_thresholds(run_group: str) -> None:
    """
    Given a run group, identify and save all the decision thresholds using GHOST with Cohen's kappa.

    Thresholds are determined on the training-set predictions of each run.
    Best thresholds are saved to a txt file (LOG_DIR / "thresholds" / "<run name>.txt") with 3 lines.
    The 1st line is the threshold for target binary_A, the 2nd for binary_B, and the 3rd for binary_C.
    If the run_group contains multiple runs, one txt file is generated per run.
    Existing files for the same runs are overwritten; the output directory is created if missing.

    Args:
        run_group (str): Name of a run_group present in the wandb database.
    Returns:
        None
    Raises:
        ValueError: Raised if no runs are found in WandB for the run_group
        RuntimeError: Raised if the predictions for a run are not found on disk.
    """
    # get runs from wandb
    api = wandb.Api(timeout=59)
    runs = api.runs("jugoetz/synferm-predictions", filters={"group": run_group})
    name_list = [run.name for run in runs]
    if not name_list:
        raise ValueError(f"No runs found for run_group {run_group}.")

    # read train predictions
    try:
        preds = [read_predictions(n, "train") for n in name_list]
    except FileNotFoundError as e:
        # chain explicitly so the original traceback is reported as the direct cause
        raise RuntimeError(
            f"Did not find all predictions for runs {name_list}. Exception leading to this: {e}"
        ) from e

    # combine predictions with ground truth (joined on the index)
    df_true = pd.read_csv(DATA_ROOT / "synferm_dataset_2023-12-20_39486records.csv")
    df_labels = df_true[["binary_A", "binary_B", "binary_C"]]
    comb = [pred.merge(df_labels, left_index=True, right_index=True) for pred in preds]

    # obtain best threshold with GHOST with Cohen's Kappa method
    print(f"Running GHOST for {len(name_list)} runs. This may take a while...")
    best_thresh = np.empty((len(name_list), 3))  # rows: runs, columns: targets A, B, C
    thresholds = np.arange(0, 1.0001, 0.01)  # candidate thresholds 0.00, 0.01, ..., 1.00
    for run_idx, df_run in enumerate(comb):
        # column pred_i holds the predicted probability for the i-th target
        for i, target in enumerate("ABC"):
            best_thresh[run_idx, i] = optimize_threshold_from_predictions(
                df_run[f"binary_{target}"],
                df_run[f"pred_{i}"],
                thresholds=thresholds,
                ThOpt_metrics="Kappa",
                random_seed=42,  # we seed for reproducible results
            )

    # write thresholds to disk
    thresh_dir = LOG_DIR / "thresholds"
    thresh_dir.mkdir(parents=True, exist_ok=True)  # open(..., "w") does not create missing directories
    for threshs, name in zip(best_thresh, name_list):
        with open(thresh_dir / f"{name}.txt", "w") as f:
            f.writelines(f"{thresh:.2f}\n" for thresh in threshs)
    print(best_thresh)

In [None]:
def get_inverse_class_distribution(run_group: str) -> None:
    """
    An alternative is to set the decision boundary to the class distribution.

    For every run and target, the threshold is the mean of the binary training labels
    (i.e. the fraction of positive records in the training split of that run).
    Note that this is a quite extreme measure and is interesting mostly as a reference point.
    Here we only need to know the split, but for convenience we can get this from the run group (through train indices)
    Best thresholds are saved to a txt file (LOG_DIR / "thresholds" / "<run name>.txt") with 3 lines.
    The 1st line is the threshold for target binary_A, the 2nd for binary_B, and the 3rd for binary_C.
    If the run_group contains multiple runs, one txt file is generated per run.

    NOTE: the files are written to the same paths that get_decision_thresholds uses, so calling
    this function for a run group overwrites thresholds saved earlier for that group (and vice versa).

    Args:
        run_group (str): Name of a run_group present in the wandb database.
    Returns:
        None
    Raises:
        ValueError: Raised if no runs are found in WandB for the run_group
        RuntimeError: Raised if the predictions for a run are not found on disk.
    """
    # get runs from wandb
    api = wandb.Api(timeout=59)
    runs = api.runs("jugoetz/synferm-predictions", filters={"group": run_group})
    name_list = [run.name for run in runs]
    if not name_list:
        raise ValueError(f"No runs found for run_group {run_group}.")

    # read train predictions (only used to select the training records of each run)
    try:
        preds = [read_predictions(n, "train") for n in name_list]
    except FileNotFoundError as e:
        # chain explicitly so the original traceback is reported as the direct cause
        raise RuntimeError(
            f"Did not find all predictions for runs {name_list}. Exception leading to this: {e}"
        ) from e

    # combine predictions with ground truth (joined on the index)
    df_true = pd.read_csv(DATA_ROOT / "synferm_dataset_2023-12-20_39486records.csv")
    df_labels = df_true[["binary_A", "binary_B", "binary_C"]]
    comb = [pred.merge(df_labels, left_index=True, right_index=True) for pred in preds]

    # obtain class distributions
    best_thresh = np.empty((len(name_list), 3))  # rows: runs, columns: targets A, B, C
    for run_idx, df_run in enumerate(comb):
        for i, target in enumerate("ABC"):
            best_thresh[run_idx, i] = df_run[f"binary_{target}"].mean()

    # write thresholds to disk
    thresh_dir = LOG_DIR / "thresholds"
    thresh_dir.mkdir(parents=True, exist_ok=True)  # open(..., "w") does not create missing directories
    for threshs, name in zip(best_thresh, name_list):
        with open(thresh_dir / f"{name}.txt", "w") as f:
            f.writelines(f"{thresh:.2f}\n" for thresh in threshs)
    print(best_thresh)

In [None]:
# the best 0D model from validation (JG1309, FFN/OHE)
# writes one threshold file per run of this group
get_decision_thresholds("2023-12-20-202602_330364")

In [None]:
# the 0D production model (JG1349, FFN/OHE)
# writes one threshold file per run of this group
get_decision_thresholds("2024-01-30-112912_514212")

In [None]:
# the best XGB/FP models (JG1486, 1D)
# writes one threshold file per run of this group
get_decision_thresholds("2024-01-23-063840_864375")

In [None]:
# the XGB/FP model for Euan on 27 folds (JG1526, 1D)
# writes one threshold file per run of this group
get_decision_thresholds("2024-02-23-134822_777158")

In [None]:
# the best XGB/FP models (JG1495, 2D)
# writes one threshold file per run of this group
get_decision_thresholds("2024-01-25-192032_503662")

In [None]:
# the best XGB/FP models (JG1504, 3D)
# writes one threshold file per run of this group
get_decision_thresholds("2024-01-26-161936_145583")

### Thresholded metrics for test scores of best models
(i.e. we evaluate the models that were selected based on validation performance)

In [None]:
# collects one dict per (run, split, metric, target); filled by the evaluation cells below
# and converted to a DataFrame for plotting. Re-running an evaluation cell appends duplicates.
records = []

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# get the thresholded metrics for these models
# 0D
api = wandb.Api(timeout=59)
runs = api.runs("jugoetz/synferm-predictions", filters={"group": "2023-12-20-202602_330364"})

# get run info
config_list = [run.config for run in runs]
summary_list = [run.summary._json_dict for run in runs]
name_list = [run.name for run in runs]

# read test predictions
preds = [read_predictions(n, "test") for n in name_list]
# combine with ground truth
df_true = pd.read_csv(DATA_ROOT / "synferm_dataset_2023-12-20_39486records.csv")
comb = [pred.merge(df_true[["binary_A", "binary_B", "binary_C"]], left_index=True, right_index=True) for pred in preds]

# load decision thresholds (one file per run; one line per target A, B, C)
threshs = []
for name in name_list:
    with open(LOG_DIR / "thresholds" / f"{name}.txt", "r") as f:
        # skip blank lines so a trailing newline does not break float()
        threshs.append([float(line) for line in f if line.strip()])
threshs = np.array(threshs)

# get truth
fold_true = [fold_df[["binary_A", "binary_B", "binary_C"]].to_numpy() for fold_df in comb]

# get preds by applying the thresholds
# Probabilities are taken from the merged frame and addressed by name (pred_0/1/2 <-> A/B/C, as in
# get_decision_thresholds), so rows are guaranteed to line up with fold_true.
fold_preds = [
    np.stack([np.where(fold_df[f"pred_{i}"].to_numpy() > threshs[fold_i, i], 1, 0) for i in range(3)], axis=1)
    for fold_i, fold_df in enumerate(comb)
]

# evaluate (rows: runs, columns: targets A, B, C)
n_runs = len(name_list)
acc, bal_acc, recall, precision, fbeta = (np.zeros((n_runs, 3)) for _ in range(5))

for fold in range(n_runs):
    for target in range(3):
        y_true = fold_true[fold][:, target]
        y_pred = fold_preds[fold][:, target]
        acc[fold, target] = accuracy_score(y_true, y_pred)
        bal_acc[fold, target] = balanced_accuracy_score(y_true, y_pred)
        recall[fold, target] = recall_score(y_true, y_pred)
        precision[fold, target] = precision_score(y_true, y_pred)
        fbeta[fold, target] = fbeta_score(y_true, y_pred, beta=0.5)
        print(confusion_matrix(y_true, y_pred))

# "macro" is the unweighted mean over the three targets of one run
for i, name in enumerate(name_list):
    records.append({"name": name, "split": "0D", "metric": "accuracy", "target": "macro", "value": acc[i].mean()})
    records.append({"name": name, "split": "0D", "metric": "recall", "target": "macro", "value": recall[i].mean()})
    records.append({"name": name, "split": "0D", "metric": "precision", "target": "macro", "value": precision[i].mean()})
    records.append({"name": name, "split": "0D", "metric": "accuracy", "target": "A", "value": acc[i, 0]})
    records.append({"name": name, "split": "0D", "metric": "recall", "target": "A", "value": recall[i, 0]})
    records.append({"name": name, "split": "0D", "metric": "precision", "target": "A", "value": precision[i, 0]})


# the reported spread is the standard deviation of the per-run values
print("Mean macro accuracy:", f"{acc.mean():.2f}±{acc.mean(axis=1).std():.3f}")
print("Mean macro recall:", f"{recall.mean():.2f}±{recall.mean(axis=1).std():.3f}")
print("Mean macro precision:", f"{precision.mean():.2f}±{precision.mean(axis=1).std():.3f}")
#print("Mean macro balanced accuracy:", f"{bal_acc.mean():.2f}±{bal_acc.mean(axis=1).std():.2f}")
#print("Mean macro f_0.5 score:", f"{fbeta.mean():.2f}±{fbeta.mean(axis=1).std():.2f}")

print("Mean target_A accuracy:", f"{acc[:, 0].mean():.2f}±{acc[:, 0].std():.3f}")
print("Mean target_A recall:", f"{recall[:, 0].mean():.2f}±{recall[:, 0].std():.3f}")
print("Mean target_A precision:", f"{precision[:, 0].mean():.2f}±{precision[:, 0].std():.3f}")
#print("Mean target_A balanced accuracy:", f"{bal_acc[:, 0].mean():.2f}±{bal_acc[:, 0].std():.2f}")
#print("Mean target_A f_0.5 score:", f"{fbeta[:, 0].mean():.2f}±{fbeta[:, 0].std():.2f}")

In [None]:
# inspect the ground truth left over from the last loop iteration above (last run, third target)
y_true

In [None]:
# confusion matrix for the leftover y_true / y_pred of the last loop iteration above
confusion_matrix(y_true, y_pred)

In [None]:
# get the thresholded metrics for these models
# 1D
api = wandb.Api(timeout=59)  # same timeout as for the other wandb queries in this notebook
runs = api.runs("jugoetz/synferm-predictions", filters={"group": "2024-01-23-063840_864375"})

# get run info
config_list = [run.config for run in runs]
summary_list = [run.summary._json_dict for run in runs]
name_list = [run.name for run in runs]

# read test predictions
preds = [read_predictions(n, "test") for n in name_list]
# combine with ground truth
df_true = pd.read_csv(DATA_ROOT / "synferm_dataset_2023-12-20_39486records.csv")
comb = [pred.merge(df_true[["binary_A", "binary_B", "binary_C"]], left_index=True, right_index=True) for pred in preds]

# load decision thresholds (one file per run; one line per target A, B, C)
threshs = []
for name in name_list:
    with open(LOG_DIR / "thresholds" / f"{name}.txt", "r") as f:
        # skip blank lines so a trailing newline does not break float()
        threshs.append([float(line) for line in f if line.strip()])
threshs = np.array(threshs)

# get truth
fold_true = [fold_df[["binary_A", "binary_B", "binary_C"]].to_numpy() for fold_df in comb]

# get preds by applying the thresholds
# Probabilities are taken from the merged frame and addressed by name (pred_0/1/2 <-> A/B/C, as in
# get_decision_thresholds), so rows are guaranteed to line up with fold_true.
fold_preds = [
    np.stack([np.where(fold_df[f"pred_{i}"].to_numpy() > threshs[fold_i, i], 1, 0) for i in range(3)], axis=1)
    for fold_i, fold_df in enumerate(comb)
]

# evaluate (rows: runs, columns: targets A, B, C)
n_runs = len(name_list)
acc, bal_acc, recall, precision, fbeta = (np.zeros((n_runs, 3)) for _ in range(5))

for fold in range(n_runs):
    for target in range(3):
        y_true = fold_true[fold][:, target]
        y_pred = fold_preds[fold][:, target]
        acc[fold, target] = accuracy_score(y_true, y_pred)
        bal_acc[fold, target] = balanced_accuracy_score(y_true, y_pred)
        recall[fold, target] = recall_score(y_true, y_pred)
        precision[fold, target] = precision_score(y_true, y_pred)
        fbeta[fold, target] = fbeta_score(y_true, y_pred, beta=0.5)

# "macro" is the unweighted mean over the three targets of one run
for i, name in enumerate(name_list):
    records.append({"name": name, "split": "1D", "metric": "accuracy", "target": "macro", "value": acc[i].mean()})
    records.append({"name": name, "split": "1D", "metric": "recall", "target": "macro", "value": recall[i].mean()})
    records.append({"name": name, "split": "1D", "metric": "precision", "target": "macro", "value": precision[i].mean()})
    records.append({"name": name, "split": "1D", "metric": "accuracy", "target": "A", "value": acc[i, 0]})
    records.append({"name": name, "split": "1D", "metric": "recall", "target": "A", "value": recall[i, 0]})
    records.append({"name": name, "split": "1D", "metric": "precision", "target": "A", "value": precision[i, 0]})


# the reported spread is the standard deviation of the per-run values
print("Mean macro accuracy:", f"{acc.mean():.2f}±{acc.mean(axis=1).std():.3f}")
print("Mean macro recall:", f"{recall.mean():.2f}±{recall.mean(axis=1).std():.3f}")
print("Mean macro precision:", f"{precision.mean():.2f}±{precision.mean(axis=1).std():.3f}")
print("Mean macro balanced accuracy:", f"{bal_acc.mean():.2f}±{bal_acc.mean(axis=1).std():.2f}")
#print("Mean macro f_0.5 score:", f"{fbeta.mean():.2f}±{fbeta.mean(axis=1).std():.2f}")

print("Mean target_A accuracy:", f"{acc[:, 0].mean():.2f}±{acc[:, 0].std():.3f}")
print("Mean target_A recall:", f"{recall[:, 0].mean():.2f}±{recall[:, 0].std():.3f}")
print("Mean target_A precision:", f"{precision[:, 0].mean():.2f}±{precision[:, 0].std():.3f}")
print("Mean target_A balanced accuracy:", f"{bal_acc[:, 0].mean():.2f}±{bal_acc[:, 0].std():.2f}")
#print("Mean target_A f_0.5 score:", f"{fbeta[:, 0].mean():.2f}±{fbeta[:, 0].std():.2f}")

In [None]:
# get the thresholded metrics for these models
# 2D
api = wandb.Api(timeout=59)  # same timeout as for the other wandb queries in this notebook
runs = api.runs("jugoetz/synferm-predictions", filters={"group": "2024-01-25-192032_503662"})

# get run info
config_list = [run.config for run in runs]
summary_list = [run.summary._json_dict for run in runs]
name_list = [run.name for run in runs]

# read test predictions
preds = [read_predictions(n, "test") for n in name_list]
# combine with ground truth
df_true = pd.read_csv(DATA_ROOT / "synferm_dataset_2023-12-20_39486records.csv")
comb = [pred.merge(df_true[["binary_A", "binary_B", "binary_C"]], left_index=True, right_index=True) for pred in preds]

# load decision thresholds (one file per run; one line per target A, B, C)
threshs = []
for name in name_list:
    with open(LOG_DIR / "thresholds" / f"{name}.txt", "r") as f:
        # skip blank lines so a trailing newline does not break float()
        threshs.append([float(line) for line in f if line.strip()])
threshs = np.array(threshs)

# get truth
fold_true = [fold_df[["binary_A", "binary_B", "binary_C"]].to_numpy() for fold_df in comb]

# get preds by applying the thresholds
# Probabilities are taken from the merged frame and addressed by name (pred_0/1/2 <-> A/B/C, as in
# get_decision_thresholds), so rows are guaranteed to line up with fold_true.
fold_preds = [
    np.stack([np.where(fold_df[f"pred_{i}"].to_numpy() > threshs[fold_i, i], 1, 0) for i in range(3)], axis=1)
    for fold_i, fold_df in enumerate(comb)
]

# evaluate (rows: runs, columns: targets A, B, C)
n_runs = len(name_list)
acc, bal_acc, recall, precision, fbeta = (np.zeros((n_runs, 3)) for _ in range(5))

for fold in range(n_runs):
    for target in range(3):
        y_true = fold_true[fold][:, target]
        y_pred = fold_preds[fold][:, target]
        acc[fold, target] = accuracy_score(y_true, y_pred)
        bal_acc[fold, target] = balanced_accuracy_score(y_true, y_pred)
        recall[fold, target] = recall_score(y_true, y_pred)
        precision[fold, target] = precision_score(y_true, y_pred)
        fbeta[fold, target] = fbeta_score(y_true, y_pred, beta=0.5)

# "macro" is the unweighted mean over the three targets of one run
for i, name in enumerate(name_list):
    records.append({"name": name, "split": "2D", "metric": "accuracy", "target": "macro", "value": acc[i].mean()})
    records.append({"name": name, "split": "2D", "metric": "recall", "target": "macro", "value": recall[i].mean()})
    records.append({"name": name, "split": "2D", "metric": "precision", "target": "macro", "value": precision[i].mean()})
    records.append({"name": name, "split": "2D", "metric": "accuracy", "target": "A", "value": acc[i, 0]})
    records.append({"name": name, "split": "2D", "metric": "recall", "target": "A", "value": recall[i, 0]})
    records.append({"name": name, "split": "2D", "metric": "precision", "target": "A", "value": precision[i, 0]})


# the reported spread is the standard deviation of the per-run values
print("Mean macro accuracy:", f"{acc.mean():.2f}±{acc.mean(axis=1).std():.3f}")
print("Mean macro recall:", f"{recall.mean():.2f}±{recall.mean(axis=1).std():.3f}")
print("Mean macro precision:", f"{precision.mean():.2f}±{precision.mean(axis=1).std():.3f}")
print("Mean macro balanced accuracy:", f"{bal_acc.mean():.2f}±{bal_acc.mean(axis=1).std():.2f}")
#print("Mean macro f_0.5 score:", f"{fbeta.mean():.2f}±{fbeta.mean(axis=1).std():.2f}")

print("Mean target_A accuracy:", f"{acc[:, 0].mean():.2f}±{acc[:, 0].std():.3f}")
print("Mean target_A recall:", f"{recall[:, 0].mean():.2f}±{recall[:, 0].std():.3f}")
print("Mean target_A precision:", f"{precision[:, 0].mean():.2f}±{precision[:, 0].std():.3f}")
print("Mean target_A balanced accuracy:", f"{bal_acc[:, 0].mean():.2f}±{bal_acc[:, 0].std():.2f}")
#print("Mean target_A f_0.5 score:", f"{fbeta[:, 0].mean():.2f}±{fbeta[:, 0].std():.2f}")

In [None]:
# get the thresholded metrics for these models
# 3D
api = wandb.Api(timeout=59)  # same timeout as for the other wandb queries in this notebook
runs = api.runs("jugoetz/synferm-predictions", filters={"group": "2024-01-26-161936_145583"})

# get run info
config_list = [run.config for run in runs]
summary_list = [run.summary._json_dict for run in runs]
name_list = [run.name for run in runs]

# read test predictions
preds = [read_predictions(n, "test") for n in name_list]
# combine with ground truth
df_true = pd.read_csv(DATA_ROOT / "synferm_dataset_2023-12-20_39486records.csv")
comb = [pred.merge(df_true[["binary_A", "binary_B", "binary_C"]], left_index=True, right_index=True) for pred in preds]

# load decision thresholds (one file per run; one line per target A, B, C)
threshs = []
for name in name_list:
    with open(LOG_DIR / "thresholds" / f"{name}.txt", "r") as f:
        # skip blank lines so a trailing newline does not break float()
        threshs.append([float(line) for line in f if line.strip()])
threshs = np.array(threshs)

# get truth
fold_true = [fold_df[["binary_A", "binary_B", "binary_C"]].to_numpy() for fold_df in comb]

# get preds by applying the thresholds
# Probabilities are taken from the merged frame and addressed by name (pred_0/1/2 <-> A/B/C, as in
# get_decision_thresholds), so rows are guaranteed to line up with fold_true.
fold_preds = [
    np.stack([np.where(fold_df[f"pred_{i}"].to_numpy() > threshs[fold_i, i], 1, 0) for i in range(3)], axis=1)
    for fold_i, fold_df in enumerate(comb)
]

# evaluate (rows: runs, columns: targets A, B, C)
n_runs = len(name_list)
acc, bal_acc, recall, precision, fbeta = (np.zeros((n_runs, 3)) for _ in range(5))

for fold in range(n_runs):
    for target in range(3):
        y_true = fold_true[fold][:, target]
        y_pred = fold_preds[fold][:, target]
        acc[fold, target] = accuracy_score(y_true, y_pred)
        bal_acc[fold, target] = balanced_accuracy_score(y_true, y_pred)
        recall[fold, target] = recall_score(y_true, y_pred)
        precision[fold, target] = precision_score(y_true, y_pred)
        fbeta[fold, target] = fbeta_score(y_true, y_pred, beta=0.5)
        # only show the confusion matrix for target A
        if target == 0:
            print(confusion_matrix(y_true, y_pred))

# "macro" is the unweighted mean over the three targets of one run
for i, name in enumerate(name_list):
    records.append({"name": name, "split": "3D", "metric": "accuracy", "target": "macro", "value": acc[i].mean()})
    records.append({"name": name, "split": "3D", "metric": "recall", "target": "macro", "value": recall[i].mean()})
    records.append({"name": name, "split": "3D", "metric": "precision", "target": "macro", "value": precision[i].mean()})
    records.append({"name": name, "split": "3D", "metric": "accuracy", "target": "A", "value": acc[i, 0]})
    records.append({"name": name, "split": "3D", "metric": "recall", "target": "A", "value": recall[i, 0]})
    records.append({"name": name, "split": "3D", "metric": "precision", "target": "A", "value": precision[i, 0]})


# the reported spread is the standard deviation of the per-run values
print("Mean macro accuracy:", f"{acc.mean():.2f}±{acc.mean(axis=1).std():.3f}")
print("Mean macro recall:", f"{recall.mean():.2f}±{recall.mean(axis=1).std():.3f}")
print("Mean macro precision:", f"{precision.mean():.2f}±{precision.mean(axis=1).std():.3f}")
print("Mean macro balanced accuracy:", f"{bal_acc.mean():.2f}±{bal_acc.mean(axis=1).std():.2f}")
#print("Mean macro f_0.5 score:", f"{fbeta.mean():.2f}±{fbeta.mean(axis=1).std():.2f}")

print("Mean target_A accuracy:", f"{acc[:, 0].mean():.2f}±{acc[:, 0].std():.3f}")
print("Mean target_A recall:", f"{recall[:, 0].mean():.2f}±{recall[:, 0].std():.3f}")
print("Mean target_A precision:", f"{precision[:, 0].mean():.2f}±{precision[:, 0].std():.3f}")
print("Mean target_A balanced accuracy:", f"{bal_acc[:, 0].mean():.2f}±{bal_acc[:, 0].std():.2f}")
#print("Mean target_A f_0.5 score:", f"{fbeta[:, 0].mean():.2f}±{fbeta[:, 0].std():.2f}")

In [None]:
# long-format table of all collected test metrics; capitalize metric names for the plot legend
df = pd.DataFrame(records)
df["metric"] = df["metric"].str.capitalize()
df.head()

In [None]:
import seaborn as sns
import matplotlib
#matplotlib.use('svg')

In [None]:
# settings
# matplotlib rc overrides applied on top of the seaborn "white" style
rc_overrides = {
    "savefig.transparent": True,
    "axes.grid": False,
    "axes.spines.bottom": True,
    "axes.spines.left": False,
    "axes.spines.right": False,
    "axes.spines.top": False,
    "font.family": 'sans-serif',
    "font.sans-serif": ["Helvetica", "Arial"],
    "xtick.major.pad": 0.0,
    "xtick.minor.pad": 0.0,
    "ytick.major.pad": 0.0,
    "ytick.minor.pad": 0.0,
    "axes.labelweight": "bold",
    "axes.labelpad": 2.5,  # standard is 4.0
    "axes.xmargin": .05,
}
sns.set_theme(context="paper", style="white", font_scale=1, rc=rc_overrides)  # font_scale alternative: 0.7

# more settings for all plots
errorbar = "se"  # standard error of the mean
errwidth = .9
errcolor = "black"
capsize = .1  # size of the end of the errorbar
linewidth = 1.  # width of the outline of barplot

In [None]:
# four custom colors; the bar plots below pick the color by bar position (index modulo 4)
palette = sns.color_palette(
    [
        "#5760bb",
        "#bd57d5",
        "#87ba70",
        "#c6c150",
    ]
)
palette

In [None]:
# set dir where we will save plots
# directory (relative to the current working directory) where the plots are saved
analysis_dir = pathlib.Path("results")
# savefig() does not create missing directories, so make sure the target exists
analysis_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# bar plot of the macro-averaged test metrics per split
fig, ax = plt.subplots(figsize=(4.75, 4))
macro_scores = df.loc[df.target == "macro"]
sns.barplot(
    macro_scores,
    x="split",
    y="value",
    hue="metric",
    palette=palette,
    errorbar=errorbar,
    errcolor=errcolor,
    lw=linewidth,
    capsize=capsize / 4,
    errwidth=1.5,
    alpha=.99,
)

# Restyle every bar by its position in ax.patches: color cycles through the four palette entries,
# bars 4-7 and 8-11 get a hatch pattern.
# NOTE(review): this assumes ax.patches holds the bars in blocks of four (one block per hue level) - confirm
# for the installed seaborn version.
for i, bar in enumerate(ax.patches):
    bar.set_facecolor(palette[i % 4])
    if 4 <= i < 8:
        bar.set_hatch("//////")
    elif 8 <= i < 12:
        bar.set_hatch("\\\\\\\\\\\\")  # need more to escape backslashes

ax.set_xlabel(None)
ax.set_ylabel("Metric")
ax.set_ylim((0, 1))

ax.legend(loc="lower left", title=None)

fig.tight_layout()
fig.savefig(analysis_dir / "best-model_metrics_2023-12-20_test-macro.pdf")

In [None]:
# bar plot of the test metrics for target A per split
fig, ax = plt.subplots(figsize=(4.75, 4))
target_a_scores = df.loc[df.target == "A"]
sns.barplot(
    target_a_scores,
    x="split",
    y="value",
    hue="metric",
    palette=palette,
    errorbar=errorbar,
    errcolor=errcolor,
    lw=linewidth,
    capsize=capsize / 4,
    errwidth=1.5,
)

# Restyle every bar by its position in ax.patches: color cycles through the four palette entries,
# bars 4-7 and 8-11 get a hatch pattern.
# NOTE(review): this assumes ax.patches holds the bars in blocks of four (one block per hue level) - confirm
# for the installed seaborn version.
for i, bar in enumerate(ax.patches):
    bar.set_facecolor(palette[i % 4])
    if 4 <= i < 8:
        bar.set_hatch("//////")
    elif 8 <= i < 12:
        bar.set_hatch("\\\\\\\\\\\\")  # need more to escape backslashes

ax.set_xlabel(None)
ax.set_ylabel("Metric")
ax.set_ylim((0, 1))

ax.legend(loc="lower left", title=None)

fig.tight_layout()
fig.savefig(analysis_dir / "best-model_metrics_2023-12-20_test-A.pdf")