# Basic Imports and Data + Model Loading

In [1]:
# Data loading
from glob import glob
from pathlib import Path
import joblib

# Key imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
from scipy.stats import uniform, randint

# Preprocessing
from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.feature_selection import SelectKBest, f_classif, mutual_info_classif, VarianceThreshold
from sklearn.model_selection import StratifiedKFold 
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.combine import SMOTEENN
from sklearn.base import clone
from sklearn.utils import resample
from scipy.stats import norm

# ML implementation
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.dummy import DummyClassifier

# Hyperparameter tuning
from sklearn.experimental import enable_halving_search_cv 
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, HalvingGridSearchCV, HalvingRandomSearchCV
from skopt import BayesSearchCV

# Model evaluation
import sklearn.metrics as skmetrics
import imblearn.metrics as imbmetrics

In [2]:
import warnings
import os
import sys
from sklearn.exceptions import ConvergenceWarning
from IPython.core.interactiveshell import InteractiveShell

# Suppress only the noisy warning categories expected while fitting models.
# A blanket warnings.filterwarnings('ignore') is deliberately NOT installed:
# it would make the category filters below redundant and would also hide
# DeprecationWarning / FutureWarning messages that announce upcoming breakage.
warnings.filterwarnings('ignore', category=ConvergenceWarning)
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=RuntimeWarning)

In [3]:
# Set configuration (the name is interpolated into every directory/file name below)
config = "config1"

# Set project root (where PreppedData and TunedModels live).
# This is the current working directory, so the notebook must be started from
# the folder that contains PreppedData_<config> and TunedModels_<config>.
project_root = Path.cwd()
print(f"Project root directory: {project_root}")

# Define data and model directories
prepped_data_dir = project_root / f"PreppedData_{config}"
model_dir = project_root / f"TunedModels_{config}"
# NOTE: unlike the two directories above, this one is not checked for existence below.
bootstrap_results_dir = project_root/f"BootstrapResults_{config}"

# Validate directories exist (fail early with a clear message)
if not prepped_data_dir.exists():
    raise FileNotFoundError(f"Prepped data directory not found at: {prepped_data_dir}")
if not model_dir.exists():
    raise FileNotFoundError(f"Model directory not found at: {model_dir}")

# Load data (.npz archive; the arrays 'X' and 'y' are read from it below)
data_file = prepped_data_dir / f'combined_log_transformed_{config}.npz'
data = np.load(data_file)
print(f"Loaded data from {data_file}")

# Presumably cross-validation settings (number of splits, shuffling, seed) for
# later cells; they are not referenced in this cell -- TODO confirm usage.
n_split = 5
shuffle = True
random_state = 42

# Feature matrix and class labels
X = data['X']
y = data['y']

# Display relevant summary stats
print('X:')
display(X)
print('\nClass counts:', Counter(y))
print('\nX shape:', X.shape, 'y shape:', y.shape)
print('\nUnique class labels:', np.unique(y))

Project root directory: /home/fs1620/MLBD_2024_25/Research_Project/LiaDataAnalysis/SampleSizePowerCalc
Loaded data from /home/fs1620/MLBD_2024_25/Research_Project/LiaDataAnalysis/SampleSizePowerCalc/PreppedData_config1/combined_log_transformed_config1.npz
X:


array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       ...,
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        7.47222030e-05, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        8.29809237e-05, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00]])


Class counts: Counter({np.int64(0): 1734, np.int64(1): 1331})

X shape: (3065, 2947) y shape: (3065,)

Unique class labels: [0 1]


In [4]:
import numpy as np
from scipy.stats import norm
from sklearn.utils import resample
from sklearn.metrics import (
    accuracy_score, recall_score, precision_score,
    f1_score, roc_auc_score, average_precision_score
)
from imblearn.metrics import geometric_mean_score

def evaluate_metric(y_true, y_pred, y_prob, metric='accuracy', pos_label=1):
    """
    Compute a single performance metric from predictions.

    Parameters:
    - y_true: array-like of true class labels
    - y_pred: array-like of predicted class labels
    - y_prob: array-like of predicted positive class probabilities; only used
      by 'roc_auc' and 'average_precision' and may be None for other metrics
    - metric: performance metric string, one of 'accuracy', 'sensitivity',
      'recall', 'precision', 'specificity', 'f1_score', 'roc_auc',
      'average_precision', 'gmean' ('sensitivity' and 'recall' are synonyms)
    - pos_label: label of positive class (default=1)

    Returns:
    - metric_value: float

    Raises:
    - AssertionError: if `metric` is not one of the accepted names
    - ValueError: if a probability-based metric is requested but y_prob is None
    """
    accepted_metrics = [
        'accuracy', 'sensitivity', 'recall', 'precision',
        'specificity', 'f1_score', 'roc_auc', 'average_precision', 'gmean'
    ]
    assert metric in accepted_metrics, f"metric must be one of {accepted_metrics}"

    # Coerce to arrays so the element-wise comparisons in the specificity
    # branch also work for plain Python lists (a list compared with a scalar
    # yields a single bool, which silently produced a wrong value).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if metric in ('roc_auc', 'average_precision') and y_prob is None:
        raise ValueError(
            f"metric '{metric}' requires predicted probabilities, but y_prob is None"
        )

    if metric == 'accuracy':
        return accuracy_score(y_true, y_pred)

    elif metric in ['sensitivity', 'recall']:
        return recall_score(y_true, y_pred, pos_label=pos_label)

    elif metric == 'precision':
        return precision_score(y_true, y_pred, pos_label=pos_label)

    elif metric == 'specificity':
        # True negatives / (true negatives + false positives)
        is_negative = y_true != pos_label
        tn = np.sum(is_negative & (y_pred != pos_label))
        fp = np.sum(is_negative & (y_pred == pos_label))
        # No negative samples at all: report 0.0 rather than dividing by zero
        return tn / (tn + fp) if (tn + fp) > 0 else 0.0

    elif metric == 'f1_score':
        return f1_score(y_true, y_pred, pos_label=pos_label)

    elif metric == 'roc_auc':
        # roc_auc_score has no pos_label argument; pos_label is not used here.
        return roc_auc_score(y_true, y_prob)

    elif metric == 'average_precision':
        # Forward pos_label so y_prob is scored against the same positive
        # class as the label-based metrics above.
        return average_precision_score(y_true, y_prob, pos_label=pos_label)

    elif metric == 'gmean':
        return geometric_mean_score(y_true, y_pred, pos_label=pos_label)

In [5]:
import joblib
import os
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
from joblib import Parallel, delayed
import numpy as np

def cv_bootstrap_scores(
    X, y, model, metrics,
    sample_fracs=np.linspace(0.1, 1.0, 10),
    n_splits=5, n_bootstrap=100,
    random_state=42, n_jobs=-1,
    replace=False,
    save_results=False,
    save_path="results.pkl"
):
    """
    Returns bootstrap distributions of CV scores per training fraction.
    results[frac][metric] = {'scores': [...], 'mean': float, 'std': float, 'ci95': (lo, hi)}

    For every fraction in `sample_fracs`, `n_bootstrap` subsamples of
    int(len(X) * frac) rows are drawn; each subsample is scored with a
    stratified `n_splits`-fold cross-validation of a fresh clone of `model`,
    and the per-fold scores are averaged into one value per metric.

    Parameters:
    - X, y: feature matrix and labels; both are indexed with integer index arrays
    - model: estimator passed to sklearn's clone(); fit/predict are called on
      the clone, and predict_proba is used when the clone has it
    - metrics: iterable of metric names understood by evaluate_metric
    - sample_fracs: fractions of len(X) to subsample
    - n_splits: number of StratifiedKFold splits
    - n_bootstrap: number of subsamples drawn per fraction
    - random_state: seed of the generator that produces per-bootstrap seeds
    - n_jobs: forwarded to joblib.Parallel
    - replace: whether subsamples are drawn with replacement
    - save_results: if True, dump the results dict to `save_path` with joblib
    - save_path: output file; its parent directory is created when needed

    Returns:
    - results: dict keyed by fraction. 'mean', 'std' and 'ci95' are absent for
      a metric when no bootstrap produced a score for it.
    """

    rng = np.random.RandomState(random_state)
    results = {}

    def run_bootstrap(seed, n_samples_inner):
        """Score one subsample; returns {metric: mean CV score} or None."""
        rng_local = np.random.RandomState(seed)
        subsample_idx = rng_local.choice(len(X), n_samples_inner, replace=replace)
        X_sub, y_sub = X[subsample_idx], y[subsample_idx]

        # Must have both classes for stratification
        if len(np.unique(y_sub)) < 2:
            return None

        fold_scores = {m: [] for m in metrics}
        # int(1e6) keeps the seed bound an integer, matching the seed draw below
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True,
                              random_state=rng_local.randint(int(1e6)))

        for train_idx, test_idx in skf.split(X_sub, y_sub):
            model_clone = clone(model)
            X_train, X_test = X_sub[train_idx], X_sub[test_idx]
            y_train, y_test = y_sub[train_idx], y_sub[test_idx]

            model_clone.fit(X_train, y_train)
            y_pred = model_clone.predict(X_test)
            y_prob = None

            if hasattr(model_clone, "predict_proba"):
                # Column 1 is taken as the positive-class probability; this
                # presumably assumes binary labels with the positive class
                # second in the estimator's class order -- TODO confirm.
                y_prob = model_clone.predict_proba(X_test)[:, 1]

            for metric in metrics:
                s = evaluate_metric(y_test, y_pred, y_prob, metric=metric)
                fold_scores[metric].append(s)

        # Average over folds, ignoring folds whose score is NaN
        out = {}
        for metric in metrics:
            s = np.array(fold_scores[metric], dtype=float)
            s = s[~np.isnan(s)]
            if s.size:
                out[metric] = float(np.mean(s))
        return out

    for frac in sample_fracs:
        n_samples = int(len(X) * frac)
        print(f"Training fraction = {frac:.2f} ≈ {n_samples} samples")

        frac_res = {metric: {'scores': []} for metric in metrics}
        # One independent seed per bootstrap so parallel workers are reproducible
        seeds = rng.randint(0, int(1e6), size=n_bootstrap)

        bootstraps = Parallel(n_jobs=n_jobs)(
            delayed(run_bootstrap)(s, n_samples) for s in seeds
        )

        for b in bootstraps:
            if b is None:
                continue
            for metric, v in b.items():
                frac_res[metric]['scores'].append(v)

        for metric in metrics:
            s = np.array(frac_res[metric]['scores'], dtype=float)
            if s.size == 0:
                continue
            lo, hi = np.percentile(s, [2.5, 97.5])
            frac_res[metric]['mean'] = float(np.mean(s))
            frac_res[metric]['std'] = float(np.std(s))
            frac_res[metric]['ci95'] = (float(lo), float(hi))

        results[frac] = frac_res

    # ---- SAVE RESULTS ----
    if save_results:
        # A bare file name such as the default "results.pkl" has an empty
        # dirname, and os.makedirs("") raises FileNotFoundError.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        joblib.dump(results, save_path)
        print(f"Results saved to {save_path}")

    return results

In [6]:
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

def plot_learning_curves(
    results,
    core_metrics=None,
    combined_metrics=None,
    plot_ci=True,
    ci_type='std',             # 'std' or 'ci95'
    figsize=(14, 5),
    title_prefix="Learning Curves",
    legend_ncol=2,
    save_plot=False,
    save_folder='LearningCurves',
    save_name='LearningCurves_model.png',
    # ---- NEW ----
    overlay_extrapolation=False,
    extrapolation_df=None,     # DataFrame from plan_required_samples
    pred_fns=None,              # dict returned from plan_required_samples
    total_n_current=None,
    show_ci=True
):
    """
    Plot learning curves with confidence bands (std or ci95),
    and optionally overlay extrapolated curves and target lines from plan_required_samples.

    Parameters:
    - results: dict keyed by training fraction; results[frac][metric] is a dict
      that may hold 'mean', 'std', 'ci95' and 'scores'
    - core_metrics / combined_metrics: metric names for the first / second
      panel (aliases such as 'sensitivity' are mapped to a canonical name);
      default to recall/precision/specificity and f1_score/gmean
    - plot_ci: draw a shaded band around each observed curve
    - ci_type: 'std' (mean +/- std) or 'ci95' (stored interval)
    - figsize, title_prefix, legend_ncol: figure cosmetics
    - save_plot, save_folder, save_name: when save_plot is True the figure is
      written to save_folder/save_name (folder is created if missing)
    - overlay_extrapolation, extrapolation_df, pred_fns: when all are given,
      draw per metric the fitted curve (pred_fns[metric]), the 'target' line,
      and the 'n_required' line read from extrapolation_df
    - total_n_current: if given, x-values are fraction * total_n_current;
      otherwise the raw fractions are plotted. The overlay grid starts at
      n = 1, so it presumably expects sample-size units -- TODO confirm.
    - show_ci: shade the ['ci95_n_lo', 'ci95_n_hi'] range of n_required

    Returns:
    - (fig, axs): the matplotlib figure and a sequence of its axes

    Raises:
    - ValueError: if none of the requested metrics is present in `results`,
      or if ci_type is not 'std' / 'ci95' while plot_ci is True
    """

    # ---- Canonical metric mapping ----
    canon_map = {
        'recall': 'recall', 'sensitivity': 'recall', 'tpr': 'recall',
        'precision': 'precision',
        'specificity': 'specificity', 'tnr': 'specificity',
        'specificity_score': 'specificity', 'true_negative_rate': 'specificity',
        'f1': 'f1_score', 'f1_score': 'f1_score',
        'gmean': 'gmean', 'g_mean': 'gmean', 'g-mean': 'gmean',
        'geometric_mean': 'gmean', 'geometric_mean_score': 'gmean',
        'roc_auc': 'roc_auc',
        'average_precision': 'average_precision', 'ap': 'average_precision',
        'accuracy': 'accuracy'
    }

    pretty_labels = {
        'recall': 'Sensitivity (Recall)',
        'precision': 'Precision',
        'specificity': 'Specificity',
        'f1_score': 'F1 Score',
        'gmean': 'G-mean',
        'roc_auc': 'ROC-AUC',
        'average_precision': 'Average Precision',
        'accuracy': 'Accuracy'
    }

    # Defaults
    if core_metrics is None:
        core_metrics = ['recall', 'precision', 'specificity']
    if combined_metrics is None:
        combined_metrics = ['f1_score', 'gmean']

    fracs = sorted(results.keys())

    # Canonicalize results
    canon_results = {}
    for f in fracs:
        cf = {}
        for k, v in results[f].items():
            c = canon_map.get(k, k)
            if c in cf:
                # merge duplicates (later entry wins, score lists are concatenated)
                old = cf[c]
                merged = {**old, **v}
                if isinstance(old.get('scores', []), list) and isinstance(v.get('scores', []), list):
                    merged['scores'] = old.get('scores', []) + v.get('scores', [])
                cf[c] = merged
            else:
                cf[c] = v
        canon_results[f] = cf

    # Helper for filtering metrics
    def available(metric):
        return any(metric in canon_results[f] for f in fracs)

    def canonical_and_filter(metric_list):
        seen, keep = set(), []
        for m in metric_list:
            cm = canon_map.get(m, m)
            if available(cm) and cm not in seen:
                keep.append(cm)
                seen.add(cm)
        return keep

    core_metrics = canonical_and_filter(core_metrics)
    combined_metrics = canonical_and_filter(combined_metrics)

    sets = []
    if core_metrics:
        sets.append(('Core metrics', core_metrics))
    if combined_metrics:
        sets.append(('Combined metrics', combined_metrics))

    if not sets:
        raise ValueError("No requested metrics were found in results.")

    # ---- Helper: build arrays ----
    def series_for_metric(metric):
        means, lows, highs = [], [], []
        for f in fracs:
            if metric not in canon_results[f]:
                means.append(np.nan)
                lows.append(np.nan)
                highs.append(np.nan)
                continue

            d = canon_results[f][metric]
            mean = d.get('mean', np.nan)

            if plot_ci:
                if ci_type == 'ci95':
                    ci = d.get('ci95', None)
                    if ci and isinstance(ci, (list, tuple)) and len(ci) == 2 and np.all(np.isfinite(ci)):
                        lo, hi = ci
                    else:
                        lo, hi = np.nan, np.nan
                elif ci_type == 'std':
                    std = d.get('std', np.nan)
                    if np.isfinite(std):
                        lo, hi = mean - std, mean + std
                    else:
                        lo, hi = np.nan, np.nan
                else:
                    raise ValueError("ci_type must be 'std' or 'ci95'")
            else:
                lo, hi = np.nan, np.nan

            means.append(mean)
            lows.append(lo)
            highs.append(hi)

        return np.array(means, float), np.array(lows, float), np.array(highs, float)

    # ---- Helper: tolerant numeric read of a table cell ----
    def as_float(value):
        # None or non-numeric cells become NaN so np.isfinite can be applied
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    # ---- Overlay lookups ----
    # Panels use canonical metric names, but the extrapolation table and
    # pred_fns may be keyed by an alias (e.g. 'sensitivity'); prepare
    # canonical fallbacks so those still match.
    overlay_active = bool(overlay_extrapolation) and extrapolation_df is not None and pred_fns is not None
    if overlay_active:
        canon_metric_col = extrapolation_df["metric"].map(lambda m: canon_map.get(m, m))
        canon_pred_fns = {canon_map.get(k, k): fn for k, fn in pred_fns.items()}

    # ---- Plot ----
    ncols = len(sets)
    fig, axs = plt.subplots(1, ncols, figsize=figsize, sharex=True)
    if ncols == 1:
        axs = [axs]

    # Observed sample sizes
    n_obs = np.array(fracs) * (total_n_current if total_n_current else 1)

    for ax, (title, metric_list) in zip(axs, sets):
        any_plotted = False
        for metric in metric_list:
            means, lows, highs = series_for_metric(metric)
            if not np.isfinite(means).any():
                continue

            # Plot observed means
            (line,) = ax.plot(n_obs, means, marker='o', label=pretty_labels.get(metric, metric))
            any_plotted = True

            # Plot CI band
            if plot_ci:
                mask = np.isfinite(lows) & np.isfinite(highs)
                if mask.any():
                    ax.fill_between(n_obs[mask], lows[mask], highs[mask],
                                    alpha=0.2, color=line.get_color())

            # ---- Overlay extrapolated curves and targets ----
            if overlay_active:
                # Exact name first (original behaviour), canonical alias second
                row = extrapolation_df[extrapolation_df["metric"] == metric]
                if not len(row):
                    row = extrapolation_df[canon_metric_col == metric]
                if len(row):
                    row = row.iloc[0]

                    # n_required may be missing / None / NaN when the target
                    # was not reachable; never let that poison the x-grid.
                    n_req = as_float(row.get("n_required", np.nan))
                    has_n_req = bool(np.isfinite(n_req))

                    pred_fn = pred_fns.get(metric, canon_pred_fns.get(metric))
                    if pred_fn:
                        # Extrapolate to 1.2x the larger of n_required and the
                        # largest observed x-value
                        x_end = max(n_req, n_obs[-1]) if has_n_req else n_obs[-1]
                        max_n = x_end * 1.2
                        n_grid = np.linspace(1, max_n, 300)
                        ax.plot(n_grid, pred_fn(n_grid), ls="--", lw=1.2,
                                color=line.get_color(), alpha=0.8)

                    # Target horizontal line
                    ax.axhline(row["target"], ls="--", color="red", lw=1.3, alpha=0.7)

                    # Required sample vertical line + CI band
                    if has_n_req:
                        ax.axvline(n_req, ls=":", color="blue", lw=1.3, alpha=0.8,
                                   label=f"N required ≈ {int(n_req)}")
                    if show_ci and "ci95_n_lo" in row and "ci95_n_hi" in row:
                        n_req_lo = as_float(row["ci95_n_lo"])
                        n_req_hi = as_float(row["ci95_n_hi"])
                        if np.isfinite(n_req_lo) and np.isfinite(n_req_hi):
                            ax.axvspan(n_req_lo, n_req_hi, color="blue", alpha=0.1)

        ax.set_title(title)
        ax.set_xlabel("Sample size (N)" if total_n_current else "Training Fraction")
        ax.set_ylabel("Score")
        ax.grid(True, alpha=0.3)

        if any_plotted:
            handles, labels = ax.get_legend_handles_labels()
            if plot_ci:
                # Single grey proxy patch explaining all shaded bands
                ci_label = "±1 Std. Dev." if ci_type == "std" else "95% Bootstrap CI"
                handles.append(Patch(facecolor='gray', alpha=0.2, label=ci_label))
                labels.append(ci_label)
            ax.legend(handles, labels, ncol=legend_ncol, frameon=False)

    fig.suptitle(f"{title_prefix}", fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if save_plot:
        save_path = Path(save_folder)
        save_path.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path / save_name, dpi=300)

    return fig, axs

In [7]:
import numpy as np
import pandas as pd
from typing import Callable, Dict, List, Tuple, Optional
from dataclasses import dataclass
from scipy.optimize import curve_fit, brentq
from scipy.special import expit, logit

# ------ Models for learning curves ------
def inv_power(n, a, b, c):
    """
    Inverse power-law learning curve: y = a - b * n^(-c), with asymptote a.

    n is clamped below at 1 and the exponent c below at 1e-6 before evaluation.
    """
    sizes = np.maximum(np.asarray(n, float), 1.0)
    decay_exponent = -np.maximum(c, 1e-6)
    return a - b * np.power(sizes, decay_exponent)

def michaelis_menten(n, a, b, c):
    """
    Saturating (Michaelis-Menten) learning curve: y = c + a * n / (b + n).

    The curve starts at c for n = 0 and approaches c + a; b is clamped below
    at 1e-6 before evaluation.
    """
    sizes = np.asarray(n, float)
    half_point = np.maximum(b, 1e-6)
    saturation = sizes / (half_point + sizes)
    return c + a * saturation

# Preferred this model, from experimentation
def logit_logn(n, a, b):
    """
    Logistic-in-log(n) learning curve: logit(y) = a + b * log(n).

    The sigmoid keeps predictions inside (0, 1); n is clamped below at 1, so
    log(n) is never negative.
    """
    sizes = np.asarray(n, float)
    log_sizes = np.log(np.maximum(sizes, 1.0))
    linear_term = a + b * log_sizes
    return expit(linear_term)

def _p0_inv_power(n, y):
    """Initial (a, b, c): asymptote at the best score, drop equal to the observed range."""
    best = np.nanmax(y)
    return (best, max(best - np.nanmin(y), 1e-3), 0.5)


def _p0_michaelis_menten(n, y):
    """Initial (a, b, c): gain equal to the observed range, b at median n, offset at the worst score."""
    worst = np.nanmin(y)
    return (max(np.nanmax(y) - worst, 1e-3), np.nanmedian(n), worst)


def _p0_logit_logn(n, y):
    """Initial (a, b): intercept at the logit of the median score (clipped away from 0 and 1)."""
    typical = np.clip(np.nanmedian(y), 1e-6, 1-1e-6)
    return (logit(typical), 0.2)


# Registry of learning-curve models. Each entry provides:
#   "func"       - the curve function f(n, *params)
#   "p0"         - callable (n, y) -> initial parameter guess
#   "bounds"     - (lower, upper) parameter bounds, in parameter order
#   "bounded_01" - boolean flag stored with the model
MODEL_SPECS = {
    "inv_power": dict(
        func=inv_power,
        p0=_p0_inv_power,
        bounds=([0.0, 0.0, 1e-3], [1.0, 1.0, 5.0]),   # (a, b, c)
        bounded_01=True,
    ),
    "mm": dict(
        func=michaelis_menten,
        p0=_p0_michaelis_menten,
        bounds=([0.0, 1e-6, 0.0], [1.0, 1e9, 1.0]),   # (a, b, c)
        bounded_01=True,
    ),
    "logit_logn": dict(
        func=logit_logn,
        p0=_p0_logit_logn,
        bounds=([-20.0, -5.0], [20.0, 5.0]),          # (a, b)
        bounded_01=True,
    ),
}

In [8]:
# ------ Extract series from results ------
def extract_series(results: Dict, metric: str):
    """
    Collect one metric's summary statistics across all training fractions.

    Missing entries are filled with NaN (or an empty list for the raw scores),
    so every returned sequence has one element per fraction.

    Returns:
        fracs: sorted array of training fractions
        means: mean score per fraction
        stds: standard deviation per fraction
        ci_lo, ci_hi: 95% confidence interval bounds per fraction
        score_lists: list of raw bootstrap scores per fraction
    """
    fracs = np.array(sorted(results.keys()), float)

    def summarise(frac):
        # One row per fraction: (mean, std, ci_lo, ci_hi, scores)
        entry = results[frac].get(metric, {})
        mean_val = entry.get("mean", np.nan)
        std_val = entry.get("std", np.nan)
        interval = entry.get("ci95")
        # A missing or None interval becomes a NaN pair
        lower, upper = (np.nan, np.nan) if interval is None else interval
        return mean_val, std_val, lower, upper, entry.get("scores", [])

    rows = [summarise(frac) for frac in fracs]

    def column(idx):
        return np.array([row[idx] for row in rows], float)

    score_lists = [row[4] for row in rows]
    return fracs, column(0), column(1), column(2), column(3), score_lists

In [9]:
@dataclass
class FitResult:
    """Result of fitting one learning-curve model (the return type of fit_curve)."""
    # Name of the fitted model; fit_curve stores the key it looked up in MODEL_SPECS.
    model: str
    # Best-fit parameter values; fit_curve stores curve_fit's optimum here.
    params: np.ndarray
    # Predictor: maps sample size(s) n to the fitted curve's value(s).
    pred_fn: Callable[[np.ndarray], np.ndarray]

def fit_curve(n_obs: np.ndarray,
              y_obs: np.ndarray,
              model: str = "inv_power",
              sigma: Optional[np.ndarray] = None) -> FitResult:
    """
    Fit a specified learning-curve model to observed data.

    Args:
        n_obs: array-like of training sample sizes (x-values)
        y_obs: array-like of observed metric values (y-values)
        model: name of model in MODEL_SPECS
        sigma: optional per-point standard deviations (used as weights).
            Entries that are non-finite or (near) zero are replaced by the
            median of the usable entries; if no entry is usable the fit is
            unweighted.

    Returns:
        FitResult with best-fit parameters and a callable predictor.

    Raises:
        KeyError: if `model` is not a key of MODEL_SPECS.
        ValueError: if there are fewer finite points than model parameters.
    """
    # Look up the model's function, initial guess, and bounds
    spec = MODEL_SPECS[model]
    func = spec["func"]

    # Accept lists as well as arrays: boolean-mask indexing below needs ndarrays
    n_obs = np.asarray(n_obs, float)
    y_obs = np.asarray(y_obs, float)

    # Keep only finite points (drop NaN or inf)
    mask = np.isfinite(n_obs) & np.isfinite(y_obs)
    n_fit, y_fit = n_obs[mask], y_obs[mask]

    # Initial guess (computed once; it also tells us the parameter count)
    p0 = spec["p0"](n_fit, y_fit)

    # Ensure we have enough points to estimate all parameters
    n_params = len(np.atleast_1d(p0))
    if n_fit.size < n_params:
        raise ValueError("Not enough finite points to fit the model.")

    # Clean sigma (if provided)
    if sigma is not None:
        sigma = np.asarray(sigma, float)[mask]
        usable = np.isfinite(sigma) & (sigma > 1e-12)
        if not usable.any():
            sigma = None
        else:
            # Replace NaNs/zeros with the median of the usable sigmas. The
            # median must exclude zeros, otherwise the replacement value can
            # itself be 0 and the weighted residuals become infinite.
            sigma = np.where(usable, sigma, np.median(sigma[usable]))

    # Fit the model using nonlinear least squares
    popt, _ = curve_fit(
        func,
        n_fit,
        y_fit,
        p0=p0,                          # initial guess
        bounds=spec["bounds"],          # parameter bounds
        sigma=sigma,                    # weights (if available)
        absolute_sigma=bool(sigma is not None),
        maxfev=20000
    )

    # Return fitted parameters and a predictor function
    return FitResult(model=model, params=popt, pred_fn=lambda n: func(n, *popt))

In [10]:
from typing import Callable, Optional, Tuple
from scipy.optimize import brentq
import numpy as np

def solve_required_n(pred_fn: Callable[[np.ndarray], np.ndarray],
                     target: float,
                     n_lower: float,
                     n_upper: Optional[float] = None,
                     max_expand: int = 14,
                     debug: bool = False) -> Tuple[Optional[float], str]:
    """
    Find the smallest n >= n_lower such that pred_fn(n) >= target.

    Expands an upper bracket exponentially (x2) and then uses brentq root-finding on:
        g(n) = pred_fn(n) - target

    The "smallest n" guarantee holds when pred_fn is increasing in n; for a
    non-monotonic pred_fn the result is a crossing inside the first bracket
    whose upper end reaches the target.

    Args:
        pred_fn: callable returning predicted metric for an array of n
        target: target metric value
        n_lower: starting point (minimum allowable n)
        n_upper: optional starting upper bound (if None, guesses one)
        max_expand: maximum number of bracket expansions before stopping
        debug: if True, prints debug info on bracket expansion

    Returns:
        (n_required, status) where:
            n_required: smallest n meeting target, or None if unreachable
            status: description of how the solution was found
    """
    def predict_at(n: float) -> float:
        # pred_fn returns an array; take its single element explicitly.
        # Calling float() on a 1-element array is deprecated in NumPy >= 1.25.
        return float(np.asarray(pred_fn([n]), dtype=float).reshape(-1)[0])

    f_lower = predict_at(n_lower)
    if f_lower >= target:
        return float(n_lower), 'Already meets or exceeds target'

    if n_upper is None:
        # Double n_lower, but move by at least 1 so the bracket is never empty
        n_upper = max(n_lower * 2.0, n_lower + 1.0)

    for i in range(max_expand):
        f_upper = predict_at(n_upper)
        if debug:
            print(f"[debug] Bracket expansion iteration {i}: n_upper={n_upper}, pred={f_upper}")
        if f_upper >= target:
            # We have bracketed the target: solve for root of g(n) = pred_fn(n) - target
            g = lambda n: predict_at(n) - target
            try:
                n_required = float(brentq(g, n_lower, n_upper, maxiter=200))
                return n_required, 'brentq root found'
            except ValueError:
                # Root finder failed (flat or non-monotonic region): return conservative upper bound
                return float(n_upper), 'brentq failed, using conservative upper bound'
        # Expand bracket if target not yet reached
        n_lower, n_upper = n_upper, n_upper * 2.0

    # Could not reach the target even after expanding bracket
    return None, "Exhausted bracket expansions, no roots found"

In [11]:
import numpy as np
from typing import Tuple, Union

def compute_target(means: np.ndarray,
                   improvement: Tuple[str, float],
                   benchmark: Union[float, str] = "max_observed",
                   clip01: bool = True) -> float:
    """
    Compute a target metric value based on a baseline benchmark and a specified improvement.

    Args:
        means: array of observed mean metric values (may contain NaNs).
        improvement: tuple ('add' or 'mult', value) specifying additive or multiplicative improvement over baseline.
        benchmark: baseline to improve upon:
            - 'max_observed': NaN-ignoring maximum of means,
            - 'first_observed': first element of means, rounded to 3 decimals,
            - a numeric value (Python or NumPy scalar): used as-is.
        clip01: whether to clip the target value between 0 and 1.

    Returns:
        A float target value representing the improved metric.

    Raises:
        ValueError: if benchmark or the improvement kind is not recognised.
    """
    benchmark_error = "benchmark must be 'max_observed', 'first_observed', or a numeric value"

    # Determine baseline
    if isinstance(benchmark, str):
        if benchmark == "max_observed":
            base = float(np.nanmax(means))
        elif benchmark == 'first_observed':
            base = float(np.round(means[0], 3))
        else:
            raise ValueError(benchmark_error)
    elif isinstance(benchmark, (int, float, np.integer, np.floating)):
        # np.integer / np.floating cover NumPy scalars such as np.float32 or
        # np.int64, which are not instances of the built-in int / float.
        base = float(benchmark)
    else:
        raise ValueError(benchmark_error)

    kind, val = improvement
    val = float(val)
    if kind == "add":
        target = base + val
    elif kind == "mult":
        target = base * val
    else:
        raise ValueError("improvement must be ('add', Δ) or ('mult', r)")

    return float(np.clip(target, 0.0, 1.0) if clip01 else target)

In [12]:
def plan_required_samples(
    results: Dict,
    metrics: List[str],
    total_n_current: int,
    improvement: Tuple[str, float] | Dict[str, Tuple[str, float]],
    benchmark: float | str = "max_observed",
    model: str | Dict[str, str] = "inv_power",
    ci_mode: str = "bootstrap",
    ci_band_type: str = "ci95",
    n_bootstrap_fits: int = 400,
    random_state: int = 42,
    use_sigma_from_std: bool = True,
    save_path: str | None = None,
    expected_annotated_per_sample: int = 183
) -> Tuple[pd.DataFrame, Dict[str, Callable]]:
    """
    Estimate the number of additional annotated samples required to achieve
    specified improvements in model performance metrics by fitting and extrapolating
    learning curves.

    Parameters
    ----------
    results : Dict
        Dictionary containing learning curve data (fracs, means, stds, etc.) per metric.
    metrics : List[str]
        List of performance metric names to evaluate (e.g., ["sensitivity", "F1"]).
    total_n_current : int
        Current total number of annotated data points (e.g., pixels).
    improvement : Tuple[str, float] or Dict[str, Tuple[str, float]]
        Improvement targets. Each can be:
            - ("add", 0.02): additive increase over baseline
            - ("mult", 1.02): multiplicative increase over baseline
        Can be a single global value or specified per metric.
    benchmark : str or float, default="max_observed"
        Baseline to improve upon. Options:
            - "max_observed": best observed score
            - "first_observed": score at smallest training size
            - float: fixed baseline value
    model : str or Dict[str, str], default="inv_power"
        Curve fitting model to use, e.g., "inv_power", "log", "exp".
        Can be a single model or a dict specifying one per metric.
    ci_mode : str, default="bootstrap"
        Method for uncertainty estimation. Currently only "bootstrap" is supported.
    ci_band_type : str, default="ci95"
        Type of confidence interval band to use (e.g., "ci95").
    n_bootstrap_fits : int, default=400
        Number of bootstrap resamples to estimate uncertainty in n_required.
    random_state : int, default=42
        Random seed for reproducibility.
    use_sigma_from_std : bool, default=True
        Whether to use observed standard deviations when fitting curves.
    save_path : str or None, default=None
        If given, saves the resulting DataFrame as a CSV at the specified path.
    expected_annotated_per_sample : int, default=183
        Median number of annotated pixels per new sample, used to translate pixel
        requirements into sample counts.

    Returns
    -------
    Tuple[pd.DataFrame, Dict[str, Callable]]
        - DataFrame summarizing required additional samples per metric, with bounds.
        - Dictionary of fitted prediction functions per metric.
    """
    rng = np.random.RandomState(random_state)
    rows = []
    pred_fns = {}
    baseline_metrics = {}

    for metric in metrics:
        fracs, means, stds, ci_lo, ci_hi, score_lists = extract_series(results, metric)
        if not np.isfinite(means).any():
            rows.append(dict(metric=metric, fit_status="no_data"))
            continue

        # ---- Baseline ----
        if benchmark == "max_observed":
            baseline_metrics[metric] = float(np.round(np.nanmax(means), 3))
        elif benchmark == "first_observed":
            baseline_metrics[metric] = float(np.round(means[0], 3))
        elif isinstance(benchmark, (int, float)):
            baseline_metrics[metric] = float(np.round(benchmark, 3))
        else:
            raise ValueError("benchmark must be 'max_observed', 'first_observed', or a number")

        # ---- Improvement ----
        metric_improvement = improvement[metric] if isinstance(improvement, dict) else improvement
        improve_type, improve_value = metric_improvement

        # ---- Map fracs -> absolute Ns ----
        n_obs = np.maximum(1.0, np.round(fracs * total_n_current))
        n_current_max = int(np.round(np.nanmax(n_obs)))

        # ---- Target value ----
        target = compute_target(means, metric_improvement, benchmark=benchmark, clip01=True)

        # ---- Fit mean curve ----
        sigma = stds if (use_sigma_from_std and np.isfinite(stds).any()) else None
        try:
            fitting_model = model[metric] if isinstance(model, dict) else model
            fit_mean = fit_curve(n_obs, means, model=fitting_model, sigma=sigma)
            pred_fns[metric] = fit_mean.pred_fn
            n_req, root_status = solve_required_n(fit_mean.pred_fn, target, n_lower=float(n_current_max))
        except Exception as e:
            rows.append(dict(metric=metric, baseline=baseline_metrics[metric],
                             fit_status=f"fit_error: {e}"))
            continue

        # ---- Bootstrap CIs ----
        if ci_mode == "bootstrap":
            n_reqs_boot = []
            new_samples_boot = []

            for _ in range(n_bootstrap_fits):
                boot_means = []
                for s in score_lists:
                    s = np.asarray(s, float)
                    boot_means.append(rng.choice(s, size=s.size, replace=True).mean()
                                      if s.size and np.isfinite(s).any() else np.nan)
                boot_means = np.asarray(boot_means, float)

                try:
                    fit_b = fit_curve(n_obs, boot_means, model=fitting_model, sigma=None)
                    n_req_b, _ = solve_required_n(fit_b.pred_fn, target, n_lower=float(n_current_max))
                    if n_req_b and np.isfinite(n_req_b):
                        n_reqs_boot.append(float(n_req_b))
                        new_samples_boot.append(
                            (n_req_b - n_current_max) / expected_annotated_per_sample
                        )
                except Exception:
                    pass

            if n_req is None:
                row = dict(metric=metric, 
                           baseline=baseline_metrics[metric],
                           target=target,
                           improve_specs = metric_improvement,
                           n_current=n_current_max,
                           fitting_model=fitting_model,
                           fit_status="unreachable_under_model",
                           root_status=root_status,
                           n_required=None, delta_n=None,
                           new_samples_req=None,
                           new_samples_req_lo=None,
                           new_samples_req_hi=None,
                           ci95_n_lo=None, ci95_n_hi=None)
            else:
                # Compute point estimate + conservative bounds
                samples_req = np.ceil((n_req - n_current_max) / expected_annotated_per_sample)
                lo_samples = np.percentile(new_samples_boot, 2.5) if new_samples_boot else samples_req
                hi_samples = np.percentile(new_samples_boot, 97.5) if new_samples_boot else samples_req
                lo_samples = min(samples_req, lo_samples)
                hi_samples = max(samples_req, hi_samples)

                ci_lo_n = np.percentile(n_reqs_boot, 2.5) if n_reqs_boot else n_req
                ci_hi_n = np.percentile(n_reqs_boot, 97.5) if n_reqs_boot else n_req

                row = dict(
                    metric=metric,
                    baseline=baseline_metrics[metric],
                    target=target,
                    improve_specs = metric_improvement,
                    n_current=n_current_max,
                    fitting_model=fitting_model,
                    fit_status="ok",
                    root_status=root_status,
                    n_required=int(round(n_req)),
                    delta_n=int(round(max(0.0, n_req - n_current_max))),
                    ci95_n_lo=int(np.round(ci_lo_n)),
                    ci95_n_hi=int(np.round(ci_hi_n)),
                    new_samples_req=int(samples_req),
                    new_samples_req_lo=int(np.ceil(lo_samples)),
                    new_samples_req_hi=int(np.ceil(hi_samples))
                )

            rows.append(row)
        else:
            raise NotImplementedError("Only bootstrap CI mode is currently supported.")

    # ---- Final DataFrame ----
    df = pd.DataFrame(rows)

    # Move sample-related columns to the end
    sample_cols = [
        "n_required", "delta_n",
        "ci95_n_lo", "ci95_n_hi", "new_samples_req",
        "new_samples_req_lo", "new_samples_req_hi"
    ]
    all_cols = [col for col in df.columns if col not in sample_cols] + sample_cols
    df = df[all_cols]

    if save_path is not None:
        df.to_csv(save_path, index=False)
        print(f"Dataframe saved as .csv at: {save_path}")

    return df, pred_fns

In [None]:
# Load the fitted SVM pipeline saved for the active configuration
model_file = model_dir / f'svm_v4_pipeline_{config}.pkl'
model = joblib.load(model_file)
print(f"Loaded model from {model_file}\n")
model_name = 'SVM'

overall_metrics = ['sensitivity', 'precision', 'specificity', 
                   'f1_score', 'gmean']
core_metrics = ['sensitivity', 'precision', 'specificity']
combined_metrics = ['f1_score', 'gmean']

# One path object for both writing and re-reading the bootstrap results, so the two
# locations cannot drift apart when `config` or the results directory changes.
svm_results_file = bootstrap_results_dir / f'learning_curves_svm_{config}.pkl'
# Output folder for figures follows the active configuration as well.
learning_curves_folder = f'LearningCurves_{config}'

results_svm = cv_bootstrap_scores(
    X=X, y=y,
    model=model,
    metrics=overall_metrics,
    sample_fracs=np.linspace(0.05, 1, 20),
    n_bootstrap=50,
    replace=True,
    n_splits=5,
    n_jobs=-1,
    save_results=True,
    save_path=str(svm_results_file)
)

# Re-read the file that was just written so later cells work from the saved artefact.
results_svm = joblib.load(svm_results_file)

fig_std, axs_std = plot_learning_curves(results=results_svm, 
                                        core_metrics=core_metrics, combined_metrics=combined_metrics,
                                        plot_ci=True, ci_type='std',
                                        save_plot=True, save_folder=learning_curves_folder,
                                        title_prefix=f"Learning Curves ±1 Std. Dev., Model: {model_name}",
                                        save_name=f'LearningCurves_{model_name}_std.png')

fig_95, axs_95 = plot_learning_curves(results=results_svm, 
                                      core_metrics=core_metrics, combined_metrics=combined_metrics,
                                      plot_ci=True, ci_type='ci95',
                                      save_plot=True, save_folder=learning_curves_folder,
                                      title_prefix=f"Learning Curves 95% Bootstrap CI, Model: {model_name}",
                                      save_name=f'LearningCurves_{model_name}_ci_95.png')

Loaded model from /home/fs1620/MLBD_2024_25/Research_Project/LiaDataAnalysis/SampleSizePowerCalc/TunedModels_config1/svm_v4_pipeline_config1.pkl

Training fraction = 0.05 ≈ 153 samples
Training fraction = 0.10 ≈ 306 samples
Training fraction = 0.15 ≈ 459 samples
Training fraction = 0.20 ≈ 613 samples
Training fraction = 0.25 ≈ 766 samples
Training fraction = 0.30 ≈ 919 samples
Training fraction = 0.35 ≈ 1072 samples
Training fraction = 0.40 ≈ 1226 samples
Training fraction = 0.45 ≈ 1379 samples
Training fraction = 0.50 ≈ 1532 samples
Training fraction = 0.55 ≈ 1685 samples
Training fraction = 0.60 ≈ 1839 samples
Training fraction = 0.65 ≈ 1992 samples


In [None]:
# The ROI mask generation step gave an expected 183 annotated pixels per sample.
# It is rounded down to 180 here so the derived sample counts err on the high
# (conservative) side; hardcoded for simplicity.
expected_annotated_per_sample = 180

total_n_current = len(X)
metrics_to_plan = ["precision", "sensitivity", "specificity", "gmean", "f1_score"]

# Multiplicative increase over the current baseline (max observed value of the metric):
# 1.02 for every metric except f1_score, which uses 1.01.
improvements_dict = {"precision": ("mult", 1.02), 
                     "sensitivity": ("mult", 1.02), 
                     "specificity": ("mult", 1.02), 
                     "gmean": ("mult", 1.02), 
                     "f1_score": ("mult", 1.01)}

# Use the same "logit_logn" fitting model for all metrics
model_save_name = "logit_logn"
fitting_models_dict = dict.fromkeys(metrics_to_plan, model_save_name)

df_plan, pred_fns = plan_required_samples(
    results=results_svm,
    metrics=metrics_to_plan,
    total_n_current=total_n_current,
    improvement=improvements_dict,
    model=fitting_models_dict,
    ci_mode="bootstrap",
    expected_annotated_per_sample=expected_annotated_per_sample,
    save_path=f"LearningCurves_{config}/LCAnalysis_{model_save_name}_svm.csv",
)

display(df_plan)

# plot_learning_curves(
#     results,
#     core_metrics=core_metrics,
#     combined_metrics=combined_metrics,
#     overlay_extrapolation=True,
#     extrapolation_df=df_plan,
#     pred_fns=pred_fns,
#     total_n_current=len(X)
# )