# Analyze predictions

Initially, the ML training loop did not compute some metrics we are now interested in.
Because we saved the predictions, we can compute these metrics after the fact and write them back to wandb for easier analysis there.
We add:
- AUROC, balanced accuracy, and (for multilabel tasks) averages over all labels.

In [None]:
import pathlib
import sys
sys.path.append(str(pathlib.Path().absolute().parent))

from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix, accuracy_score, balanced_accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

from src.util.definitions import PRED_DIR, DATA_ROOT

In [None]:
# Ground-truth dataset. It keeps pandas' default RangeIndex, which is what the
# "idx" column of the saved prediction files is joined against further down.
dataset_csv = DATA_ROOT / "synferm_dataset_2023-09-05_40018records.csv"
df_true = pd.read_csv(dataset_csv)

In [None]:
# Handle to the wandb public API and all runs of the prediction project
# (path given as "entity/project").
api = wandb.Api()
WANDB_PROJECT = "jugoetz/synferm-predictions"
runs = api.runs(WANDB_PROJECT)

In [None]:
# Collect, per run (in the iteration order of `runs`):
#   summary_list: the summary as a plain dict (._json_dict skips large file objects)
#   config_list:  the hyperparameters, minus wandb-internal keys prefixed with "_"
#   name_list:    the human-readable display name
summary_list, config_list, name_list = [], [], []
for run in runs:
    summary_list.append(run.summary._json_dict)

    public_config = {key: val for key, val in run.config.items() if key[:1] != "_"}
    config_list.append(public_config)

    name_list.append(run.name)

In [None]:
# Position of this run in `name_list` (same order as iterating `runs`).
# NOTE(review): presumably used to choose the start of the hard-coded `runs[9:-1]`
# slice in the back-fill cell below — confirm. Raises ValueError if the run is absent.
name_list.index("2023-09-26-132606_155550_fold8")

In [None]:
# Encoder for the multiclass target "major_A-C". LabelEncoder stores its classes
# sorted, so the codes are A->0, B->1, C->2, no_product->3.
# NOTE(review): assumes pred_0..pred_3 in multiclass prediction files follow this same order — confirm.
le = LabelEncoder().fit(["A", "B", "C", "no_product"])

In [None]:
# Runs (by position in `runs`) whose metrics are back-filled.
# NOTE(review): hard-coded — this skips the first 9 runs and the last run,
# presumably a resume point found with the name_list.index lookup above.
# Confirm this is still the intended selection before re-running.
RUN_SELECTION = slice(9, -1)


def _load_predictions(pred_path, run_type, target_names):
    """Load one saved prediction file and return ``(y_true, y_prob, y_hat)``.

    The CSV is indexed by its ``idx`` column and left-joined onto ``df_true``'s
    index to obtain the ground truth. Predicted scores are read from the columns
    ``pred_0 .. pred_{k-1}``.

    Args:
        pred_path: path to a ``*_preds_last.csv`` file.
        run_type: "multilabel", "binary" or "multiclass".
        target_names: target column names in ``df_true`` (used for
            multilabel/binary runs).

    Returns:
        y_true, y_prob, y_hat as numpy arrays. For multilabel/binary, y_hat is
        y_prob thresholded at 0.5. For multiclass, y_hat is the argmax class and
        y_true is encoded with the global ``le``.

    Raises:
        ValueError: if ``run_type`` is not one of the three supported tasks.
    """
    df = pd.read_csv(pred_path, index_col="idx").merge(
        df_true, how="left", left_index=True, right_index=True
    )
    if run_type in ("multilabel", "binary"):
        y_prob = df[[f"pred_{i}" for i in range(len(target_names))]].to_numpy()
        y_hat = (y_prob > 0.5).astype(np.int_)  # fixed 0.5 decision threshold
        y_true = df[target_names].to_numpy()
    elif run_type == "multiclass":
        y_prob = df[[f"pred_{i}" for i in range(len(le.classes_))]].to_numpy()
        y_hat = np.argmax(y_prob, axis=1)
        y_true = le.transform(df["major_A-C"].to_numpy())
    else:
        raise ValueError("Unexpected run_type")
    return y_true, y_prob, y_hat


def _additional_metrics(split, run_type, target_names, y_true, y_prob, y_hat):
    """Compute the metrics missing from the original training loop for one split.

    Returns a dict keyed ``"<split>/<metric>"``, ready to be written to a wandb summary.
    """
    if run_type == "multilabel":
        # Balanced accuracy per label, then an unweighted mean over labels.
        balanced_acc = [
            balanced_accuracy_score(y_true[:, i], y_hat[:, i], adjusted=False)
            for i in range(len(target_names))
        ]
        metrics = {
            f"{split}/balanced_accuracy_macro": np.mean(balanced_acc),
            f"{split}/auroc_macro": roc_auc_score(y_true, y_prob, average="macro"),
            f"{split}/recall_macro": recall_score(y_true, y_hat, average="macro"),
            f"{split}/precision_macro": precision_score(y_true, y_hat, average="macro"),
            # NOTE: F1 is micro-averaged, unlike the macro averages above.
            f"{split}/f1_micro": f1_score(y_true, y_hat, average="micro"),
        }
        metrics.update(
            {
                f"{split}/balanced_accuracy_target_{target}": value
                for target, value in zip(target_names, balanced_acc)
            }
        )
        return metrics
    if run_type == "binary":
        return {
            f"{split}/balanced_accuracy": balanced_accuracy_score(y_true, y_hat, adjusted=False),
            f"{split}/auroc": roc_auc_score(y_true, y_prob, average=None),
        }
    if run_type == "multiclass":
        return {
            f"{split}/balanced_accuracy": balanced_accuracy_score(y_true, y_hat, adjusted=False),
            # one-vs-one + macro average is insensitive to class imbalance
            f"{split}/auroc_macro_ovo": roc_auc_score(
                y_true, y_prob, average="macro", multi_class="ovo"
            ),
        }
    raise ValueError("Unexpected run_type")


for run in runs[RUN_SELECTION]:
    run_type = run.config["training"]["task"]
    run_target_names = run.config["target_names"]
    run_name = run.name

    metrics = {}
    for split in ("val", "test"):
        pred_path = PRED_DIR / run_name / f"{split}_preds_last.csv"
        if not pred_path.is_file():
            print(f"{split} predictions not found for {run_name}")
            continue
        y_true, y_prob, y_hat = _load_predictions(pred_path, run_type, run_target_names)
        metrics.update(
            _additional_metrics(split, run_type, run_target_names, y_true, y_prob, y_hat)
        )

    if not metrics:
        # No prediction files for this run: skip the pointless API write.
        continue

    # Write the new metrics into the run summary and push them to wandb.
    for key, value in metrics.items():
        run.summary[key] = value
    run.summary.update()

In [None]:
# Display the name of the run currently bound to `run` (after the back-fill loop
# above, the last run it processed).
run.name

In [None]:
# Select all cross-validation folds of one experiment by display name.
# The filter value is a *regular expression*, not a glob: the previous
# "..._fold*" meant "..._fol" followed by zero or more "d" and was unanchored.
# Anchor the pattern and require a fold number ([0-9] rather than \d for
# portability across regex engines).
runs = api.runs(
    "jugoetz/synferm-predictions",
    filters={"display_name": {"$regex": r"^2023-09-26-114651_772025_fold[0-9]+$"}},
)

for run in runs:
    print(run.name)

In [None]:
# Stack each public summary metric (keys not starting with "_") across the
# selected runs into one numpy array per metric. Keys come from the first run;
# every other run must report the same keys (KeyError otherwise).
metrics = {
    key: np.array([r.summary[key] for r in runs])
    for key in runs[0].summary.keys()
    if key[:1] != "_"
}
metrics

In [None]:
# Print mean ± standard deviation across the selected runs for the metrics of interest.
# NOTE: ndarray.std() uses ddof=0 (population standard deviation).
reported_metrics = [
    "val/balanced_accuracy_macro",
    "test/accuracy_target_binary_A",
    "test/accuracy_target_binary_B",
    "test/accuracy_target_binary_C",
    "test/balanced_accuracy_macro",
    "test/balanced_accuracy_target_binary_A",
    "test/balanced_accuracy_target_binary_B",
    "test/balanced_accuracy_target_binary_C",
    "test/precision_macro",
    "test/precision_target_binary_A",
    "test/precision_target_binary_B",
    "test/precision_target_binary_C",
    "test/recall_macro",
    "test/recall_target_binary_A",
    "test/recall_target_binary_B",
    "test/recall_target_binary_C",
]
for key in reported_metrics:
    values = metrics[key]
    print(f"{key}: {values.mean():.4f}±{values.std():.3f}")