In [None]:
import pandas as pd
import numpy as np
import torch
from torchmetrics.classification import (
    MultilabelAccuracy,
    MultilabelAUROC,
    MultilabelAveragePrecision,
    MultilabelExactMatch,
)

In [None]:
# Read the saved results table that the rest of this cell parses.
results_df = pd.read_csv(
    filepath_or_buffer="Saved/preds_targets_probs_40e_10f.csv",
)

# parse string arrays
def parse_array(s):
    """Parse a stringified 1-D numeric array such as ``"[0.1 0.9 0.3]"``.

    The value is coerced with ``str()`` first, so non-string cells do not
    fail on ``.strip``. Surrounding whitespace and the outer square brackets
    are removed, and both whitespace and commas are accepted as separators.

    Parameters
    ----------
    s : object
        Text representation of the array (any object is passed through
        ``str()``).

    Returns
    -------
    np.ndarray
        1-D ``float64`` array; empty when no numbers are present.

    Raises
    ------
    ValueError
        If any token cannot be converted to a float.
    """
    text = str(s).strip().strip("[]")
    # Tokenise explicitly: np.fromstring(..., sep=" ") stops at the first
    # unexpected character (e.g. a comma) and returns a truncated array.
    tokens = text.replace(",", " ").split()
    return np.array(tokens, dtype=np.float64)

# Parse every stringified row and stack the rows into one array per column;
# the targets are additionally cast to integers.
preds = np.stack([parse_array(cell) for cell in results_df["Predictions"]])
targets = np.stack([parse_array(cell) for cell in results_df["Targets"]]).astype(int)

In [None]:

# Location of the saved results table read by the run section below.
CSV_PATH = "Saved/preds_targets_probs_40e_10f.csv"  # change if needed

# Column names looked up in the results table.
FOLD_COL = "Fold"
PRED_COL = "Predictions"
TARG_COL = "Targets"

# Cut-off handed to the metric helpers as their `threshold` argument.
THRESHOLD = 0.5

# Display names for the labels, in column order. Passing None instead makes
# compute_metrics_by_fold generate "Label_<i>" names automatically.
TARGET_NAMES = [
    "Right Upper",
    "Left Upper",
    "Left Lower",
    "Right Lower",
]

# ----------------------------
# Utils
# ----------------------------
def parse_array(s: str) -> np.ndarray:
    """Parse a stringified 1-D numeric array such as ``"[0.1 0.9 0.3]"``.

    The value is coerced with ``str()``, surrounding whitespace and the outer
    square brackets are removed, and both whitespace and commas are accepted
    as separators.

    Parameters
    ----------
    s : str
        Text representation of the array (any object is passed through
        ``str()``).

    Returns
    -------
    np.ndarray
        1-D ``float64`` array; empty when no numbers are present.

    Raises
    ------
    ValueError
        If any token cannot be converted to a float.
    """
    text = str(s).strip().strip("[]")
    # Tokenise explicitly: np.fromstring(..., sep=" ") stops at the first
    # unexpected character (e.g. a comma) and returns a truncated array.
    tokens = text.replace(",", " ").split()
    return np.array(tokens, dtype=np.float64)

def _compute_fold_metrics(preds_np: np.ndarray,
                          targets_np: np.ndarray,
                          threshold: float = 0.5) -> dict:
    """Compute multilabel classification metrics for one group of samples.

    Parameters
    ----------
    preds_np : np.ndarray
        Per-label prediction scores, converted to a ``float32`` tensor.
        Axis 0 is reported as the sample count and axis 1 is used as the
        number of labels.
    targets_np : np.ndarray
        Ground-truth labels, converted to a ``torch.long`` tensor.
    threshold : float, default 0.5
        Cut-off forwarded to every threshold-based metric: exact match,
        macro accuracy and per-label accuracy.

    Returns
    -------
    dict
        ``n`` (int), ``subset_acc``, ``macro_acc``, ``macro_auroc`` and
        ``macro_auprc`` (floats), plus ``per_label_acc`` (NumPy array with
        one accuracy per label).
    """
    preds_t = torch.tensor(preds_np, dtype=torch.float32)
    targets_t = torch.tensor(targets_np, dtype=torch.long)
    num_labels = preds_t.size(1)

    # `threshold` must reach the exact-match metric too; previously it was
    # omitted here, so subset accuracy always used the library default while
    # the accuracy metrics honoured the caller's value.
    subset_acc = MultilabelExactMatch(num_labels=num_labels, threshold=threshold)
    macro_acc = MultilabelAccuracy(num_labels=num_labels, average="macro", threshold=threshold)
    per_label = MultilabelAccuracy(num_labels=num_labels, average=None, threshold=threshold)

    # NOTE(review): AUROC / average precision rank by score. If the column
    # fed in here holds already-thresholded 0/1 predictions rather than
    # probabilities, these two values carry little information - confirm.
    auroc_ma = MultilabelAUROC(num_labels=num_labels, average="macro")
    auprc_ma = MultilabelAveragePrecision(num_labels=num_labels, average="macro")

    out = {
        "n":             int(preds_t.size(0)),
        "subset_acc":    float(subset_acc(preds_t, targets_t)),
        "macro_acc":     float(macro_acc(preds_t, targets_t)),
        "macro_auroc":   float(auroc_ma(preds_t, targets_t)),
        "macro_auprc":   float(auprc_ma(preds_t, targets_t)),
        "per_label_acc": per_label(preds_t, targets_t).detach().cpu().numpy(),  # shape (C,)
    }
    return out

def compute_metrics_by_fold(results_df: pd.DataFrame,
                            target_names=None,
                            fold_col: str = "Fold",
                            pred_col: str = "Predictions",
                            targ_col: str = "Targets",
                            threshold: float = 0.5):
    """Evaluate every fold of ``results_df`` separately.

    Parameters
    ----------
    results_df : pd.DataFrame
        Table holding a fold column plus stringified prediction and target
        arrays (parsed with ``parse_array``).
    target_names : sequence of str, optional
        One display name per label. When ``None``, names ``Label_0``,
        ``Label_1``, ... are generated.
    fold_col, pred_col, targ_col : str
        Names of the fold, prediction and target columns.
    threshold : float
        Forwarded to ``_compute_fold_metrics``.

    Returns
    -------
    tuple
        ``(per_fold_df, per_label_acc_df, target_names)``: the per-fold
        summary metrics indexed by fold, the per-label accuracies indexed by
        fold with an extra ``"mean"`` row, and the label names actually used.

    Raises
    ------
    AssertionError
        If ``target_names`` is given but its length differs from the label
        count found in the first prediction row.
    """
    # The label count is read from the first prediction row only.
    num_labels = parse_array(results_df[pred_col].iloc[0]).shape[0]
    if target_names is not None:
        assert len(target_names) == num_labels, f"len(target_names)={len(target_names)} but C={num_labels}"
    else:
        target_names = [f"Label_{i}" for i in range(num_labels)]

    summary_keys = ("subset_acc", "macro_acc", "macro_auroc", "macro_auprc")
    fold_records = []
    label_records = []

    for fold_id in sorted(results_df[fold_col].unique()):
        fold_rows = results_df[results_df[fold_col] == fold_id]

        # Rebuild (N_f, C) matrices from the stringified rows of this fold.
        parsed_preds = fold_rows[pred_col].apply(parse_array).to_list()
        fold_preds = np.stack(parsed_preds).astype(np.float32)
        parsed_targets = fold_rows[targ_col].apply(parse_array).to_list()
        fold_targets = np.stack(parsed_targets).astype(np.int64)

        metrics = _compute_fold_metrics(fold_preds, fold_targets, threshold=threshold)

        fold_record = {"fold": fold_id, "n": metrics["n"]}
        fold_record.update((key, metrics[key]) for key in summary_keys)
        fold_records.append(fold_record)

        # One accuracy column per label for this fold.
        label_acc = metrics["per_label_acc"]
        label_record = {"fold": fold_id}
        for idx, label in enumerate(target_names):
            label_record[f"acc_{label}"] = float(label_acc[idx])
        label_records.append(label_record)

    per_fold_df = pd.DataFrame(fold_records)
    per_fold_df = per_fold_df.set_index("fold").sort_index()

    per_label_acc_df = pd.DataFrame(label_records)
    per_label_acc_df = per_label_acc_df.set_index("fold").sort_index()
    # Append the column-wise average over folds as a final "mean" row.
    per_label_acc_df.loc["mean"] = per_label_acc_df.mean(axis=0)

    return per_fold_df, per_label_acc_df, target_names

def summarize_across_folds(per_fold_df: pd.DataFrame):
    """Aggregate per-fold metrics into across-fold summaries.

    Parameters
    ----------
    per_fold_df : pd.DataFrame
        One row per fold, with a sample-count column ``n`` and one column
        per metric.

    Returns
    -------
    tuple of dict
        ``(unweighted_mean, totals, std_over_folds)``: the plain mean of
        each metric over folds, ``{"n": <sum of n>}``, and the sample
        standard deviation (``ddof=1``) of each metric over folds.

    Raises
    ------
    ValueError
        If ``per_fold_df`` has no ``n`` column.
    """
    columns = per_fold_df.columns
    if "n" not in columns:
        raise ValueError("per_fold_df must contain 'n'")

    sample_total = int(per_fold_df["n"].sum())

    # Everything except the sample count is treated as a metric.
    metrics_only = per_fold_df[[name for name in columns if name != "n"]]
    fold_mean = metrics_only.mean().to_dict()
    fold_std = metrics_only.std(ddof=1).to_dict()

    return fold_mean, {"n": sample_total}, fold_std

def pooled_metrics(results_df: pd.DataFrame,
                   pred_col="Predictions",
                   targ_col="Targets",
                   threshold=0.5) -> dict:
    """Compute the metrics once over all rows of ``results_df`` together.

    Parameters
    ----------
    results_df : pd.DataFrame
        Table with stringified prediction and target arrays.
    pred_col, targ_col : str
        Names of the prediction and target columns.
    threshold : float
        Forwarded to ``_compute_fold_metrics``.

    Returns
    -------
    dict
        The dictionary produced by ``_compute_fold_metrics`` for the full
        table, ignoring fold membership.
    """
    def _column_matrix(column, dtype):
        # Parse each stringified row and stack the rows into one matrix.
        parsed_rows = results_df[column].apply(parse_array).to_list()
        return np.stack(parsed_rows).astype(dtype)

    all_preds = _column_matrix(pred_col, np.float32)
    all_targets = _column_matrix(targ_col, np.int64)
    return _compute_fold_metrics(all_preds, all_targets, threshold=threshold)

# ----------------------------
# Run
# ----------------------------
if __name__ == "__main__":
    # Load the saved results table.
    results_df = pd.read_csv(CSV_PATH)

    # Resolve the column names once, falling back to the defaults when a
    # constant is empty/None. This replaces the walrus-in-conditional alias
    # that was embedded in the call arguments, and lets the per-fold and
    # pooled computations share the same resolved names (previously only the
    # per-fold call had the fallback).
    PREDCOL = PRED_COL or "Predictions"
    TARGCOL = TARG_COL or "Targets"

    # Per-fold metrics and per-label accuracies.
    per_fold_df, per_label_acc_df, target_names = compute_metrics_by_fold(
        results_df,
        target_names=TARGET_NAMES,
        fold_col=FOLD_COL,
        pred_col=PREDCOL,
        targ_col=TARGCOL,
        threshold=THRESHOLD,
    )

    # Mean / std of the per-fold metrics, plus the total sample count.
    unweighted_mean, totals, std_over_folds = summarize_across_folds(per_fold_df)

    # Metrics computed once over all rows, ignoring fold membership.
    pooled = pooled_metrics(results_df, pred_col=PREDCOL, targ_col=TARGCOL, threshold=THRESHOLD)

    # ----------------------------
    # Print results
    # ----------------------------
    # Widen the pandas text output so the metric tables are not wrapped.
    pd.set_option("display.width", 180)
    pd.set_option("display.max_columns", 200)

    print("\nPer-fold metrics:")
    print(per_fold_df.round(4))

    print("\nPer-label accuracy per fold:")
    print(per_label_acc_df.round(4))

    print(f"\nTotal samples (sum over folds): {totals['n']}")

    print("\nAverage over folds:")
    for k, v in unweighted_mean.items():
        print(f"{k}: {v:.4f}")

    print("\nStd over folds:")
    for k, v in std_over_folds.items():
        print(f"{k}: {v:.4f}")

    print("\nPooled over all folds:")
    for k, v in pooled.items():
        # The per-label array is not a scalar, so it is left out of this list.
        if k == "per_label_acc":
            continue
        if k == "n":
            print(f"{k}: {int(v)}")
        else:
            print(f"{k}: {v:.4f}")
