In [1]:
import os
import time
import numpy as np
from xgboost import XGBClassifier
import pandas as pd
import math
from sklearn.model_selection import train_test_split, StratifiedKFold, ParameterSampler
from sklearn.utils.class_weight import compute_class_weight
from sklearn.linear_model import LogisticRegression
from sklearn.utils import resample
import umap
from sklearn.model_selection import StratifiedKFold


from sklearn.metrics import (
    accuracy_score,
    classification_report,
    balanced_accuracy_score,
    roc_auc_score,
    f1_score,
    precision_recall_curve,
    auc,
    average_precision_score,
    PrecisionRecallDisplay,
    roc_curve
)
from sklearn.calibration import CalibratedClassifierCV
from sklearn.frozen import FrozenEstimator
from sklearn.preprocessing import label_binarize
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from itertools import cycle
from sklearn.preprocessing import label_binarize
from tqdm import tqdm  # progress bar

# Cell-type name -> integer class label. The insertion order of this dict is
# the class order the evaluation helpers use when storing per-class metrics.
CELL_TYPE_MAPPING = {
    'B': 0,
    'BnT': 1,
    'CD4': 2,
    'CD8': 3,
    'DC': 4,
    'HLADR': 5,
    'MacCD163': 6,
    'Mural': 7,
    'NK': 8,
    'Neutrophil': 9,
    'Treg': 10,
    'Tumor': 11,
    'pDC': 12,
    'plasma': 13,
}
# Inverse lookup: integer class label -> cell-type name.
CELL_TYPE_MAPPING_REVERSE = dict(zip(CELL_TYPE_MAPPING.values(), CELL_TYPE_MAPPING.keys()))

  from .autonotebook import tqdm as notebook_tqdm


## Helper functions

In [2]:
def bootstrap_evaluation(bootstrap_runs, df_X, df_y, path_to_save=None):
    """Score a class-balanced logistic regression over repeated bootstrap fits.

    The data is split once into a stratified 80/20 train/validation partition
    (``random_state=42``). For each run ``i`` a stratified sample of
    ``int(2/3 * len(X_train))`` training rows is drawn with replacement
    (seeded with ``i``), a ``LogisticRegression`` is fitted on it, and the
    model is scored on the shared validation split.

    Parameters
    ----------
    bootstrap_runs : int
        Number of bootstrap iterations.
    df_X : array-like or DataFrame
        Feature matrix.
    df_y : array-like or Series
        Integer class labels; expected to use the codes of
        ``CELL_TYPE_MAPPING``.
    path_to_save : str or path-like, optional
        If truthy, the result dict is written with ``np.save`` (pickled).

    Returns
    -------
    dict
        ``class_names`` / ``class_labels`` give the fixed class order.
        ``auc_per_class``, ``f1_per_class_ovr`` and
        ``f1_per_class_multiclass`` are lists (one entry per run) of per-class
        lists in that order. ``accuracy``, ``balanced_accuracy``, ``macro_f1``
        and ``weighted_f1`` are lists with one scalar per run. A class the
        fitted model never saw gets ``np.nan`` as its AUC for that run.
    """
    # One fixed hold-out split: every bootstrap model is scored on the same rows.
    X_train, X_val, y_train, y_val = train_test_split(
        df_X,
        df_y,
        test_size=0.2,
        random_state=42,
        stratify=df_y
    )

    # Fix an explicit class order for consistent storage
    class_names = list(CELL_TYPE_MAPPING.keys())
    class_labels = [CELL_TYPE_MAPPING[name] for name in class_names]

    aucs = []
    f1s_per_class_ovr = []         # list of lists (one-vs-rest, AUC-style)
    f1s_per_class_multiclass = []  # list of lists (average=None)
    accuracies = []
    balanced_accuracies = []
    macro_f1s = []
    weighted_f1s = []

    for i in range(bootstrap_runs):
        X_train_bootstrap, y_train_bootstrap = resample(
            X_train, y_train,
            replace=True,
            random_state=i,
            stratify=y_train,
            n_samples=int(2/3 * len(X_train))
        )

        clf = LogisticRegression(max_iter=5000, class_weight="balanced", random_state=42)
        clf.fit(X_train_bootstrap, y_train_bootstrap)

        y_val_pred = clf.predict(X_val)

        # predict_proba columns follow clf.classes_, not CELL_TYPE_MAPPING order.
        proba = clf.predict_proba(X_val)
        class_to_col = {c: j for j, c in enumerate(clf.classes_)}

        # ---- AUC per class + one-vs-rest F1 per class ----
        auc_this = []
        f1_ovr_this = []

        for label in class_labels:
            y_true_binary = (y_val == label).astype(int)
            y_pred_binary = (y_val_pred == label).astype(int)

            f1_ovr_this.append(f1_score(y_true_binary, y_pred_binary, zero_division=0))

            # A class absent from the bootstrap sample (or from the data
            # altogether) has no probability column; record NaN instead of
            # failing with a KeyError.
            if label not in class_to_col:
                auc_this.append(np.nan)
                continue

            y_score = proba[:, class_to_col[label]]
            fpr, tpr, _ = roc_curve(y_true_binary, y_score)
            auc_this.append(auc(fpr, tpr))

        aucs.append(auc_this)
        f1s_per_class_ovr.append(f1_ovr_this)

        # ---- Multiclass per-class F1 (average=None) ----
        # labels= forces the same class order as CELL_TYPE_MAPPING.
        f1_multi_this = f1_score(
            y_val,
            y_val_pred,
            average=None,
            labels=class_labels,
            zero_division=0
        )
        f1s_per_class_multiclass.append(f1_multi_this.tolist())

        # ---- Global metrics ----
        accuracies.append(accuracy_score(y_val, y_val_pred))
        balanced_accuracies.append(balanced_accuracy_score(y_val, y_val_pred))
        macro_f1s.append(f1_score(y_val, y_val_pred, average="macro", zero_division=0))
        weighted_f1s.append(f1_score(y_val, y_val_pred, average="weighted", zero_division=0))

    # ---- Collect everything ----
    results = {
        "class_names": class_names,
        "class_labels": class_labels,
        "auc_per_class": aucs,
        "f1_per_class_ovr": f1s_per_class_ovr,
        "f1_per_class_multiclass": f1s_per_class_multiclass,
        "accuracy": accuracies,
        "balanced_accuracy": balanced_accuracies,
        "macro_f1": macro_f1s,
        "weighted_f1": weighted_f1s,
    }
    if path_to_save:
        # The dict is stored as a pickled object array.
        np.save(path_to_save, results, allow_pickle=True)
    return results

In [3]:
def _select_rows(data, positions):
    """Return the rows of ``data`` at the given integer positions."""
    # pandas objects: plain [] is label/column based, so positional row
    # selection must go through .iloc.
    if hasattr(data, "iloc"):
        return data.iloc[positions]
    # Python sequences cannot be indexed with an index array directly.
    if isinstance(data, (list, tuple)):
        return np.asarray(data)[positions]
    return data[positions]


def crossval_evaluation(df_X, df_y, path_to_save=None, n_splits=10):
    """Score a class-balanced logistic regression with stratified K-fold CV.

    The data is split with ``StratifiedKFold(n_splits, shuffle=True,
    random_state=42)``. In every fold a ``LogisticRegression`` is fitted on
    the training part and scored on the held-out part.

    Parameters
    ----------
    df_X : DataFrame, ndarray or sequence
        Feature matrix; rows are selected positionally.
    df_y : Series, ndarray or sequence
        Integer class labels; expected to use the codes of
        ``CELL_TYPE_MAPPING``.
    path_to_save : str or path-like, optional
        If truthy, the result dict is written with ``np.save`` (pickled).
    n_splits : int, default 10
        Number of cross-validation folds.

    Returns
    -------
    dict
        ``class_names`` / ``class_labels`` give the fixed class order.
        ``auc_per_class``, ``f1_per_class_ovr`` and
        ``f1_per_class_multiclass`` are lists (one entry per fold) of
        per-class lists in that order. ``accuracy``, ``balanced_accuracy``,
        ``macro_f1`` and ``weighted_f1`` are lists with one scalar per fold.
        A class missing from a training fold gets ``np.nan`` as its AUC for
        that fold.
    """
    # Fix an explicit class order for consistent storage
    class_names = list(CELL_TYPE_MAPPING.keys())
    class_labels = [CELL_TYPE_MAPPING[name] for name in class_names]

    aucs = []
    f1s_per_class_ovr = []         # list of lists (one-vs-rest, AUC-style)
    f1s_per_class_multiclass = []  # list of lists (average=None)
    accuracies = []
    balanced_accuracies = []
    macro_f1s = []
    weighted_f1s = []

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    for train_idx, val_idx in skf.split(df_X, df_y):
        # split() yields positional indices, so select rows by position
        # regardless of any pandas index carried by the inputs.
        X_train = _select_rows(df_X, train_idx)
        X_val = _select_rows(df_X, val_idx)
        y_train = _select_rows(df_y, train_idx)
        y_val = _select_rows(df_y, val_idx)

        clf = LogisticRegression(max_iter=5000, class_weight="balanced", random_state=42)
        clf.fit(X_train, y_train)

        y_val_pred = clf.predict(X_val)

        # predict_proba columns follow clf.classes_, not CELL_TYPE_MAPPING order.
        proba = clf.predict_proba(X_val)
        class_to_col = {c: j for j, c in enumerate(clf.classes_)}

        # ---- AUC per class + one-vs-rest F1 per class ----
        auc_this = []
        f1_ovr_this = []

        for label in class_labels:
            y_true_binary = (y_val == label).astype(int)
            y_pred_binary = (y_val_pred == label).astype(int)

            f1_ovr_this.append(f1_score(y_true_binary, y_pred_binary, zero_division=0))

            # If a class is missing from the training fold, clf.classes_ does
            # not contain it and there is no probability column to score.
            if label not in class_to_col:
                auc_this.append(np.nan)
                continue

            y_score = proba[:, class_to_col[label]]
            fpr, tpr, _ = roc_curve(y_true_binary, y_score)
            auc_this.append(auc(fpr, tpr))

        aucs.append(auc_this)
        f1s_per_class_ovr.append(f1_ovr_this)

        # ---- Multiclass per-class F1 (average=None) ----
        # labels= forces the same class order as CELL_TYPE_MAPPING.
        f1_multi_this = f1_score(
            y_val,
            y_val_pred,
            average=None,
            labels=class_labels,
            zero_division=0
        )
        f1s_per_class_multiclass.append(f1_multi_this.tolist())

        # ---- Global metrics ----
        accuracies.append(accuracy_score(y_val, y_val_pred))
        balanced_accuracies.append(balanced_accuracy_score(y_val, y_val_pred))
        macro_f1s.append(f1_score(y_val, y_val_pred, average="macro", zero_division=0))
        weighted_f1s.append(f1_score(y_val, y_val_pred, average="weighted", zero_division=0))

    # ---- Collect everything ----
    results = {
        "class_names": class_names,
        "class_labels": class_labels,
        "auc_per_class": aucs,
        "f1_per_class_ovr": f1s_per_class_ovr,
        "f1_per_class_multiclass": f1s_per_class_multiclass,
        "accuracy": accuracies,
        "balanced_accuracy": balanced_accuracies,
        "macro_f1": macro_f1s,
        "weighted_f1": weighted_f1s,
    }

    if path_to_save:
        # The dict is stored as a pickled object array.
        np.save(path_to_save, results, allow_pickle=True)

    return results