# Analyze predictions

The ML training loop did not calculate some of the metrics we are now interested in.
But since we saved the predictions, it should be easy to obtain these metrics after the fact.
We add them back to wandb for easier analysis there.
- Add AUROC, balanced accuracy, and (for multilabel tasks) some averages over all labels.

In [7]:
import pathlib
import sys
# "__file__" is a literal string here, not the __file__ variable, so the path
# is resolved against the current working directory: parents[1] is the
# directory above the cwd, which is appended to the import search path.
# NOTE(review): presumably this makes the `src` package importable when the
# notebook runs from a subdirectory of the repository - confirm.
sys.path.append(str(pathlib.Path("__file__").absolute().parents[1]))

from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix, accuracy_score, balanced_accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

from src.util.definitions import PRED_DIR, DATA_ROOT

In [8]:
# Ground-truth records; later cells join them to the predictions by row index.
dataset_csv = DATA_ROOT / "synferm_dataset_2023-09-05_40018records.csv"
df_true = pd.read_csv(dataset_csv)

In [2]:
# Open the wandb public API and fetch the runs of the prediction project.
project_path = "jugoetz/synferm-predictions"
api = wandb.Api()
runs = api.runs(project_path)

In [3]:
# Build three parallel lists over all runs: summary metrics,
# hyperparameters, and human-readable names.
summary_list = []
config_list = []
name_list = []
for run in runs:
    # Output keys/values such as accuracy; ._json_dict is used to omit
    # large files.
    summary_list.append(run.summary._json_dict)

    # Hyperparameters, minus the special entries whose key starts with "_".
    public_config = {}
    for key, value in run.config.items():
        if key.startswith('_'):
            continue
        public_config[key] = value
    config_list.append(public_config)

    # The name shown for the run.
    name_list.append(run.name)

In [4]:
# Position of one specific run in name_list, looked up by its name
# (list.index raises ValueError if no run carries this name).
name_list.index("2023-09-26-132606_155550_fold8")

10761

In [5]:
# Integer encoding of the class names used for the multiclass task below.
# LabelEncoder keeps its classes sorted, so the codes follow the order
# "A", "B", "C", "no_product".
# NOTE(review): assumes the pred_<i> columns of the prediction files use the
# same class order - confirm against the training code.
le = LabelEncoder()
le.fit(["A", "B", "C", "no_product"])

In [6]:
# For each selected run, recompute additional metrics from the saved
# validation/test predictions and write them into the run's wandb summary.
for run in runs[9:-1]:
    run_type = run.config["training"]["task"]
    run_target_names = run.config["target_names"]
    run_name = run.name

    # Prediction files per split; either of them may be missing.
    pred_files = {
        "val": PRED_DIR / run_name / "val_preds_last.csv",
        "test": PRED_DIR / run_name / "test_preds_last.csv",
    }

    metrics = {}
    for name, file in pred_files.items():
        if not file.is_file():
            print(f"{name} predictions not found for {run_name}")
            continue

        # Join the predictions with the ground truth on the record index.
        preds = pd.read_csv(file, index_col="idx")
        df = preds.merge(df_true, how="left", left_index=True, right_index=True)

        # Pull out probabilities, hard predictions, and true labels.
        if run_type == "multiclass":
            prob_columns = [f"pred_{i}" for i in range(len(le.classes_))]
            y_prob = df[prob_columns].to_numpy()
            y_hat = np.argmax(y_prob, axis=1)
            y_true = le.transform(df["major_A-C"].to_numpy())
        elif run_type in ("multilabel", "binary"):
            prob_columns = [f"pred_{i}" for i in range(len(run_target_names))]
            y_prob = df[prob_columns].to_numpy()
            # threshold the probabilities at 0.5 to get 0/1 predictions
            y_hat = (y_prob > 0.5).astype(np.int_)
            y_true = df[run_target_names].to_numpy()
        else:
            raise ValueError("Unexpected run_type")

        # Compute the additional metrics for this split.
        if run_type == "multilabel":
            auroc_avg = roc_auc_score(y_true, y_prob, average="macro")
            recall_avg = recall_score(y_true, y_hat, average="macro")
            precision_avg = precision_score(y_true, y_hat, average="macro")
            f1_avg = f1_score(y_true, y_hat, average="micro")

            # balanced accuracy is computed label by label, then averaged
            balanced_acc = [
                balanced_accuracy_score(y_true[:, col], y_hat[:, col], adjusted=False)
                for col, _ in enumerate(run_target_names)
            ]

            metrics[f"{name}/balanced_accuracy_macro"] = np.mean(balanced_acc)
            metrics[f"{name}/auroc_macro"] = auroc_avg
            metrics[f"{name}/recall_macro"] = recall_avg
            metrics[f"{name}/precision_macro"] = precision_avg
            metrics[f"{name}/f1_micro"] = f1_avg
            for t, v in zip(run_target_names, balanced_acc):
                metrics[f"{name}/balanced_accuracy_target_{t}"] = v
        elif run_type == "binary":
            auroc = roc_auc_score(y_true, y_prob, average=None)
            balanced_acc = balanced_accuracy_score(y_true, y_hat, adjusted=False)
            metrics[f"{name}/balanced_accuracy"] = balanced_acc
            metrics[f"{name}/auroc"] = auroc
        elif run_type == "multiclass":
            balanced_acc = balanced_accuracy_score(y_true, y_hat, adjusted=False)
            # one-v-one + macro average is insensitive to class imbalance
            auroc_macro = roc_auc_score(y_true, y_prob, average="macro", multi_class="ovo")
            metrics[f"{name}/balanced_accuracy"] = balanced_acc
            metrics[f"{name}/auroc_macro_ovo"] = auroc_macro

    # Copy the new metrics into the run summary, then push it to wandb.
    for k, v in metrics.items():
        run.summary[k] = v
    run.summary.update()

NameError: name 'df_true' is not defined

In [231]:
# Show the name of the run that `run` is currently bound to.
run.name

'2023-09-12-133900_140535_fold1'

In [3]:
# Select all cross-validation folds of one experiment by run name.
# The "$regex" value is a regular expression, not a glob: a trailing "fold*"
# would mean "fol" followed by any number of "d" and would select the folds
# only by accident. The pattern is therefore anchored and the fold number is
# spelled out as one or more digits.
runs = api.runs(
    "jugoetz/synferm-predictions",
    filters={"display_name": {"$regex": r"^2023-09-26-114651_772025_fold[0-9]+$"}},
)

# List the selected runs to confirm that every fold was found.
for run in runs:
    print(run.name)

2023-09-26-114651_772025_fold8
2023-09-26-114651_772025_fold7
2023-09-26-114651_772025_fold6
2023-09-26-114651_772025_fold5
2023-09-26-114651_772025_fold4
2023-09-26-114651_772025_fold3
2023-09-26-114651_772025_fold2
2023-09-26-114651_772025_fold1
2023-09-26-114651_772025_fold0


In [4]:
# For every summary metric of the first run whose name does not start with
# "_", gather the values of all selected runs into one array (one entry per
# run), then display the mapping.
first_summary = runs[0].summary
metrics = {
    metric_name: np.array([run.summary[metric_name] for run in runs])
    for metric_name in first_summary.keys()
    if not metric_name.startswith("_")
}
metrics

{'val/balanced_accuracy_target_binary_B': array([0.6236525 , 0.78227259, 0.62739291, 0.73003649, 0.74794857,
        0.78125   , 0.672659  , 0.57377049, 0.75538286]),
 'test/auroc_macro': array([0.73422945, 0.81741415, 0.88277051, 0.81239181, 0.7649315 ,
        0.80790682, 0.73461941, 0.74940904, 0.7357966 ]),
 'val/precision_macro': array([0.72602073, 0.72064617, 0.68837824, 0.78761346, 0.79005984,
        0.72912026, 0.73225959, 0.74548392, 0.78909503]),
 'train/f1_target_binary_A': array([0.93095142, 0.90456831, 0.8875131 , 0.91228336, 0.90294635,
        0.90363753, 0.90689605, 0.90413129, 0.91743642]),
 'val/accuracy_target_binary_C': array([0.8659004 , 0.79669029, 0.54020101, 0.86682242, 0.78698224,
        0.7977941 , 0.7523585 , 0.6557377 , 0.77746481]),
 'test/accuracy_target_binary_B': array([0.78532606, 0.79665738, 0.83552629, 0.79888266, 0.79567307,
        0.73674244, 0.76483518, 0.77220958, 0.74931878]),
 'train/loss': array([0.40467438, 0.37629181, 0.37927896, 0.3443371

In [8]:
# Metrics to report as mean ± standard deviation over the selected runs.
reported_metrics = (
    "val/balanced_accuracy_macro",
    "test/accuracy_target_binary_A",
    "test/accuracy_target_binary_B",
    "test/accuracy_target_binary_C",
    "test/balanced_accuracy_macro",
    "test/balanced_accuracy_target_binary_A",
    "test/balanced_accuracy_target_binary_B",
    "test/balanced_accuracy_target_binary_C",
    "test/precision_macro",
    "test/precision_target_binary_A",
    "test/precision_target_binary_B",
    "test/precision_target_binary_C",
    "test/recall_macro",
    "test/recall_target_binary_A",
    "test/recall_target_binary_B",
    "test/recall_target_binary_C",
)
for metric_name in reported_metrics:
    values = metrics[metric_name]
    print(f'{metric_name}: {values.mean():.4f}±{values.std():.3f}')

val/balanced_accuracy_macro: 0.6571±0.040
test/accuracy_target_binary_A: 0.7746±0.095
test/accuracy_target_binary_B: 0.7817±0.028
test/accuracy_target_binary_C: 0.7893±0.041
test/balanced_accuracy_macro: 0.6775±0.041
test/balanced_accuracy_target_binary_A: 0.5654±0.066
test/balanced_accuracy_target_binary_B: 0.7702±0.043
test/balanced_accuracy_target_binary_C: 0.6968±0.065
test/precision_macro: 0.7454±0.063
test/precision_target_binary_A: 0.8115±0.091
test/precision_target_binary_B: 0.7697±0.072
test/precision_target_binary_C: 0.6077±0.172
test/recall_macro: 0.7528±0.087
test/recall_target_binary_A: 0.9300±0.061
test/recall_target_binary_B: 0.8250±0.078
test/recall_target_binary_C: 0.4897±0.180
