In [13]:
import numpy as np
import pandas as pd
from pathlib import Path

import os, sys

# Add the project root (the folder that contains the `src/` package, which in
# turn holds snf_lite.py) to sys.path so `from src.snf_lite import ...` resolves.
# abspath("..") is resolved against the kernel's working directory, so this only
# points at the project root when the notebook is run from notebooks/.
PROJECT_ROOT = os.path.abspath("..")  # if notebook is in notebooks/
sys.path.append(PROJECT_ROOT)

from src.snf_lite import (
    load_view, align_views,
    gower_affinity, rbf_affinity,
    snf_fuse, eigengap_k, pam_kmedoids_best_of_n
)

from sklearn.manifold import spectral_embedding
from sklearn.metrics import pairwise_distances


In [17]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

def animate_stability(K_vals: np.ndarray,
                      S_norm: np.ndarray,
                      outfile: str = "stability_convergence.gif",
                      fps: int = 12):
    """
    Animate bootstrap stability scores as a heatmap overlaid with jittering dots.

    One frame is produced per bootstrap replicate: frame ``t`` reveals columns
    ``0..t`` of ``S_norm`` in the heatmap, and every tile carries one dot whose
    random-walk step size shrinks both over time and as the running mean score
    of its row approaches 1.

    K_vals : array of K values (length = n_K), used as y tick labels
    S_norm : (n_K, B_max) array of stability scores in [0,1] (NaNs allowed)
    outfile: path of the GIF to write (the file is always written with
             PillowWriter, whatever the extension)
    fps    : frames per second

    Raises ValueError if ``S_norm`` is not a non-empty 2-D array.
    """
    S_norm = np.asarray(S_norm, dtype=float)
    if S_norm.ndim != 2 or S_norm.size == 0:
        raise ValueError("S_norm must be a non-empty 2-D (n_K, B_max) array")
    K, B = S_norm.shape

    # --- running mean over bootstraps (for de-flicker + jitter control) ---
    # NaN entries contribute nothing to either the sum or the count; rows with
    # no observation yet get a running mean of 0.
    observed = ~np.isnan(S_norm)
    cumul = np.cumsum(np.where(observed, S_norm, 0.0), axis=1)
    counts = np.cumsum(observed, axis=1)
    R = np.divide(cumul, counts, out=np.zeros_like(cumul), where=counts > 0)

    # --- initial positions: one dot per tile (K x B) ---
    rng = np.random.default_rng(42)
    # each tile spans x in [b, b+1], y in [k, k+1]
    x = np.add.outer(np.arange(B), rng.uniform(0.15, 0.85, size=K)).T  # K×B
    y = np.add.outer(np.arange(K), rng.uniform(0.15, 0.85, size=B))    # K×B

    def reflect_clip(val, lo, hi):
        """Fold ``val`` back into [lo, hi] by reflecting at the boundaries."""
        span = hi - lo
        z = (val - lo) % (2 * span)
        z = np.where(z > span, 2 * span - z, z)
        return lo + z

    # --- figure setup ---
    fig, ax = plt.subplots(figsize=(8, 5), dpi=120)

    # start by showing only first column
    M0 = np.full((K, B), np.nan)
    M0[:, :1] = S_norm[:, :1]
    # The extent is fixed once so tiles line up with the dot coordinates, and
    # the colour scale is pinned to the documented [0, 1] range instead of
    # being autoscaled from whatever the first column happens to contain.
    im = ax.imshow(M0, origin="lower", aspect="auto", interpolation="nearest",
                   extent=(0, B, 0, K), vmin=0.0, vmax=1.0)

    scat = ax.scatter(x.flatten(), y.flatten(), s=6, alpha=0.8)

    ax.set_xlabel("Bootstrap replicate (B)")
    ax.set_ylabel("Number of classes (K)")
    ax.set_xlim(-0.1, B + 0.1)
    ax.set_ylim(-0.1, K + 0.1)

    # center y-ticks in tiles and label with actual K values
    ax.set_yticks(np.arange(K) + 0.5)
    ax.set_yticklabels(K_vals)

    ax.set_xticks(
        np.linspace(0, B, num=min(B, 10), endpoint=False, dtype=int)
    )
    ax.grid(which="both", linestyle=":", linewidth=0.5, alpha=0.4)

    # global annealing + local (score-based) jitter schedule
    sigma0 = 0.35          # initial jitter amplitude in tile units
    tau = B / 2.0          # global decay scale

    def update(frame):
        # reveal scores up to current frame
        M = np.full((K, B), np.nan)
        M[:, :frame+1] = S_norm[:, :frame+1]
        im.set_data(M)

        # jitter: decays over time AND with high running-mean R
        global_scale = sigma0 * np.exp(-frame / tau)
        # Clipped so out-of-range scores can never yield a negative standard
        # deviation, which rng.normal rejects.
        local_scale = np.clip(1.0 - R[:, frame], 0.0, 1.0)  # high R -> less jitter
        step_std = global_scale * (0.15 + 0.85 * local_scale)
        step_std_matrix = np.repeat(step_std[:, None], B, axis=1)

        dx = rng.normal(0.0, step_std_matrix, size=(K, B))
        dy = rng.normal(0.0, step_std_matrix, size=(K, B))

        # keep every dot inside its own tile
        x[:] = reflect_clip(
            x + dx,
            np.arange(B)[None, :],
            (np.arange(B) + 1)[None, :]
        )
        y[:] = reflect_clip(
            y + dy,
            np.arange(K)[:, None],
            (np.arange(K)[:, None] + 1)
        )

        scat.set_offsets(np.c_[x.flatten(), y.flatten()])

        ax.set_title(
            f"Bootstrap stability convergence (frame {frame+1}/{B})"
        )
        return im, scat

    anim = FuncAnimation(fig, update, frames=B, interval=1000/fps, blit=False)
    writer = PillowWriter(fps=fps)
    anim.save(outfile, writer=writer)
    plt.close(fig)
    print(f"Saved animation to {outfile}")


In [8]:
# Input tables for the three views (C, P, S) read by load_view below.
# NOTE(review): these are absolute paths on one machine; the notebook will not
# run elsewhere without editing them.
C_PATH = "/Users/harisreedeth/Desktop/D/personal/ProjectMAIP/data/01_processed/C_view.csv"
P_PATH = "/Users/harisreedeth/Desktop/D/personal/ProjectMAIP/data/01_processed/P_view_scaled.csv"
S_PATH = "/Users/harisreedeth/Desktop/D/personal/ProjectMAIP/data/01_processed/S_view.csv"
ID_COL = "eid"  # identifier column passed to load_view for every view

KNN = 25            # passed to snf_fuse as `k`
ITERS = 10          # passed to snf_fuse as `iters`
ALPHA = 0.5         # passed to snf_fuse as `alpha`
K_MIN, K_MAX = 2, 8  # inclusive range of cluster counts scanned in the bootstrap loop
LAP_KNN = 25        # NOTE(review): not referenced by any cell visible in this notebook
SEED = 42           # random_state for spectral_embedding


In [9]:
# Load each view as an (ids, data) pair keyed on ID_COL.
ids_C, X_C = load_view(C_PATH, ID_COL)
ids_P, X_P = load_view(P_PATH, ID_COL)
ids_S, X_S = load_view(S_PATH, ID_COL)

# Align the three views; align_views returns one id vector plus the three
# aligned data objects, in the same order they were passed in.
# presumably this restricts to ids present in every view — verify in snf_lite.
ids, [C_aln, P_aln, S_aln] = align_views(
    [ids_C, ids_P, ids_S],
    [X_C, X_P, X_S]
)

# Sanity check: all aligned views should report the same number of rows as ids.
len(ids), C_aln.shape, P_aln.shape, S_aln.shape


(9105, (9105, 12), (9105, 19), (9105, 12))

In [10]:
# One affinity matrix per view: the C and S views go through gower_affinity,
# the P view (loaded from the *_scaled file) goes through rbf_affinity.
A_C = gower_affinity(C_aln)
A_S = gower_affinity(S_aln)
A_P = rbf_affinity(P_aln)

# Order handed to snf_fuse: C, P, S.
affinities = [A_C, A_P, A_S]

In [11]:
# Fuse the per-view affinities with the settings from the config cell.
fused = snf_fuse(
    affinities=affinities,
    k=KNN,
    iters=ITERS,
    alpha=ALPHA
)
# Persist the fused matrix to the current working directory for reuse.
np.save("snf_fused_test.npy", fused)


In [12]:
# Spectral embedding of the fused affinity matrix. The number of components is
# at least 8 and grows with K_MAX, so the same embedding can serve every K.
emb = spectral_embedding(
    adjacency=fused,
    n_components=max(K_MAX, 8),
    random_state=SEED,
    eigen_solver="arpack",
    drop_first=True,
)
emb.shape


(9105, 8)

In [16]:
def pam_labels_from_D(D, K, seed=42, n_init=20):
    """
    Cluster a precomputed distance matrix with the project's PAM routine and
    return only the label vector.

    Parameters
    ----------
    D : (n, n) array-like
        Precomputed pairwise distance matrix.
    K : int
        Number of clusters requested.
    seed : int
        Seed forwarded to pam_kmedoids_best_of_n.
    n_init : int
        Number of initialisations forwarded to pam_kmedoids_best_of_n.

    Returns
    -------
    labels
        First element of the pair returned by pam_kmedoids_best_of_n; the
        second element is discarded.
    """
    assignments, _discarded = pam_kmedoids_best_of_n(D, K, n_init=n_init, seed=seed)
    return assignments

from sklearn.metrics import adjusted_rand_score

def bootstrap_ari_full(D,
                       K,
                       B=200,
                       frac=0.75,
                       seed=42,
                       n_init=20):
    """
    Resampling-based ARI stability for a given K using pam_labels_from_D.

    A reference clustering is computed on the full matrix. Each replicate then
    draws ``int(frac * n)`` rows without replacement, re-clusters that
    sub-matrix, and scores agreement with the reference labels of the same
    rows via the adjusted Rand index.

    D      : (n, n) distance matrix (indexed with np.ix_)
    K      : number of clusters
    B      : number of replicates
    frac   : fraction of rows drawn per replicate
    seed   : base random seed; replicate b clusters with seed + b
    n_init : forwarded to pam_labels_from_D

    Returns (df, summ): a long table with columns K, B, ari, and a summary
    Series with entries mean, sd, n and hw95 (1.96 * sd / sqrt(n)).
    """
    sampler = np.random.default_rng(seed)
    n_total = D.shape[0]

    # Reference partition on the complete matrix.
    reference = pam_labels_from_D(D, K, seed=seed, n_init=n_init)

    def _score_replicate(rep):
        picked = sampler.choice(n_total, size=int(frac * n_total), replace=False)
        sub_labels = pam_labels_from_D(
            D[np.ix_(picked, picked)], K, seed=seed + rep, n_init=n_init
        )
        return {
            "K": K,
            "B": rep,
            "ari": adjusted_rand_score(reference[picked], sub_labels),
        }

    df = pd.DataFrame([_score_replicate(rep) for rep in range(1, B + 1)])

    ari_values = df["ari"]
    summ = ari_values.agg(mean="mean", sd="std", n="size")
    summ["hw95"] = 1.96 * summ["sd"] / np.sqrt(summ["n"])

    return df, summ

from sklearn.metrics import pairwise_distances

# Euclidean distances in the full spectral embedding (all components), shared
# by every K below.
D = pairwise_distances(emb, metric="euclidean")

boot_rows = []   # per-replicate tables, one per K
summaries = []   # one summary dict per K

# Scan K over the inclusive range from the config cell.
# NOTE(review): the seed is the literal 42 here rather than the SEED constant.
for K in range(K_MIN, K_MAX + 1):
    df_K, summ_K = bootstrap_ari_full(D, K, B=200, frac=0.75,
                                      seed=42, n_init=20)
    boot_rows.append(df_K)
    summaries.append({"K": K, **summ_K.to_dict()})

# Long-format table (K, B, ari) written next to the notebook; later cells read
# this file back to build the animations.
boot_ari_all = pd.concat(boot_rows, ignore_index=True)
boot_ari_all.to_csv("boot_ari_all_snf_pam.csv", index=False)

summary_df = pd.DataFrame(summaries).sort_values("K")
summary_df


Unnamed: 0,K,mean,sd,n,hw95
0,2,0.471487,0.354734,200.0,0.049164
1,3,0.224582,0.125652,200.0,0.017414
2,4,0.318687,0.127458,200.0,0.017665
3,5,0.364996,0.124762,200.0,0.017291
4,6,0.477985,0.127308,200.0,0.017644
5,7,0.444248,0.101795,200.0,0.014108
6,8,0.612593,0.099544,200.0,0.013796


In [18]:
import numpy as np
import pandas as pd

def load_boot_ari_matrix(csv_path: str):
    """
    Read a long-format bootstrap table with columns [K, B, ari] and return:
      - K_vals: sorted array of unique K values
      - S_norm: K x B_max matrix of ARI scores, column-wise 0–1 normalised

    Replicate indices in column B are taken to be 1-based: the score for
    (K, B) is stored at row ``index of K in K_vals``, column ``B - 1``. Cells
    with no row in the table (e.g. a K that stopped early, or a skipped
    replicate) stay NaN. If a (K, B) pair occurs more than once, the last
    occurrence in (K, B)-sorted order wins.

    Raises ValueError if the table is empty or contains a B below 1.
    """
    df = pd.read_csv(csv_path).sort_values(["K", "B"])
    if df.empty:
        raise ValueError(f"{csv_path} contains no bootstrap rows")

    K_vals = np.sort(df["K"].unique())
    rep_idx = df["B"].to_numpy().astype(int)
    if rep_idx.min() < 1:
        raise ValueError("column B must hold 1-based replicate indices")
    B_max = int(rep_idx.max())

    # K x B_max matrix, NaN where a (K, B) combination is absent.
    S = np.full((len(K_vals), B_max), np.nan)
    # Place every score by its own replicate index rather than by its position
    # within the K group, so a missing replicate cannot shift later ones left.
    row_idx = np.searchsorted(K_vals, df["K"].to_numpy())
    S[row_idx, rep_idx - 1] = df["ari"].to_numpy()

    # column-wise 0–1 normalisation (ignore NaNs). NaNs are swapped for ±inf
    # before reducing so an all-NaN column yields NaN output without the
    # RuntimeWarning that np.nanmin / np.nanmax emit for such columns.
    missing = np.isnan(S)
    S_min = np.min(np.where(missing, np.inf, S), axis=0, keepdims=True)
    S_max = np.max(np.where(missing, -np.inf, S), axis=0, keepdims=True)
    S_norm = (S - S_min) / np.clip(S_max - S_min, 1e-9, None)

    return K_vals, S_norm

# Load the per-replicate ARI table once (it was previously read twice in a
# row) and render the convergence animation from it.
K_vals, S_norm = load_boot_ari_matrix("boot_ari_all_snf_pam.csv")

animate_stability(K_vals, S_norm,
                  outfile="ukb_snf_lca_stability.gif",
                  fps=12)


Saved animation to ukb_snf_lca_stability.gif


In [20]:
import numpy as np
import pandas as pd

def load_boot_df(csv_path: str,
                 param_col: str = "K",
                 frame_col: str = "B",
                 value_col: str = "ari"):
    """
    Read a bootstrap results CSV and return a copy restricted to the three
    named columns (in the order param, frame, value), sorted by parameter and
    then by frame. The original row index is kept.
    """
    wanted = [param_col, frame_col, value_col]
    table = pd.read_csv(csv_path)
    trimmed = table[wanted].copy()
    return trimmed.sort_values(by=[param_col, frame_col])

# Long-format (K, B, ari) table for the animations below.
# NOTE(review): absolute, machine-specific path; later cells read the same
# file name relative to the working directory instead.
df_boot = load_boot_df("/Users/harisreedeth/Desktop/D/personal/ProjectMAIP/notebooks/boot_ari_all_snf_pam.csv",
                       param_col="K",
                       frame_col="B",
                       value_col="ari")


In [23]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

def animate_param_stability(df,
                            param_col: str = "K",
                            frame_col: str = "B",
                            value_col: str = "ari",
                            layout: np.ndarray | None = None,
                            outfile: str = "param_stability.gif",
                            fps: int = 4):
    """
    Animate the running mean of a stability metric per parameter value.

    Each frame b colours every grid cell by the mean of ``value_col`` over all
    rows of that parameter with ``frame_col <= b``, on a fixed 0–1 colour scale.

    df       : long-format table with [param_col, frame_col, value_col]
    param_col: parameter id (e.g. K, or a code for hyperparam combo)
    frame_col: bootstrap index (B)
    value_col: ARI (or other stability metric)
    layout   : optional 2D array (or nested list) telling where each parameter
               goes in the grid. If None, parameters are laid out in a single
               row. A 1D layout is treated as a single row.
    outfile  : GIF path
    fps      : frames per second
    """

    # Unique params and frames
    params = np.sort(df[param_col].unique())
    frames = np.sort(df[frame_col].unique())

    # Default layout: 1 x P row; otherwise accept any array-like grid
    if layout is None:
        layout = params.reshape(1, -1)
    else:
        layout = np.atleast_2d(np.asarray(layout))

    n_rows, n_cols = layout.shape

    # Precompute running means for each frame: mean over all replicates <= b
    running_means = {}
    for b in frames:
        sub = df[df[frame_col] <= b]
        means = sub.groupby(param_col)[value_col].mean()
        running_means[b] = means

    # Set up figure
    fig, ax = plt.subplots(figsize=(4 + n_cols, 3 + n_rows), dpi=120)

    # Initial grid is all NaN, so no cell carries a data colour until the
    # first frame fills it in.
    grid0 = np.full((n_rows, n_cols), np.nan)
    im = ax.imshow(grid0, origin="upper", vmin=0.0, vmax=1.0,
                   cmap="RdYlGn")  # red→yellow→green

    # Show param labels inside cells
    texts = []
    for i in range(n_rows):
        for j in range(n_cols):
            p = layout[i, j]
            txt = ax.text(j, i, str(p),
                          ha="center", va="center", color="black",
                          fontsize=9)
            texts.append(txt)

    ax.set_xticks([])
    ax.set_yticks([])
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(f"Running mean {value_col.upper()}")

    def update(frame_idx):
        b = frames[frame_idx]
        means = running_means[b]

        # Cells whose parameter has no data up to b stay NaN.
        grid = np.full((n_rows, n_cols), np.nan)
        for i in range(n_rows):
            for j in range(n_cols):
                p = layout[i, j]
                if p in means.index:
                    grid[i, j] = means.loc[p]

        im.set_data(grid)
        ax.set_title(f"Stability over bootstraps (up to B = {b})")

        return (im, *texts)

    anim = FuncAnimation(fig, update,
                         frames=len(frames),
                         interval=1000 / fps,
                         blit=False)

    writer = PillowWriter(fps=fps)
    anim.save(outfile, writer=writer)
    plt.close(fig)
    print(f"Saved animation to {outfile}")

# Reload the long-format bootstrap table (relative to the working directory)
# and render one cell per K on a single row.
df_boot = load_boot_df("boot_ari_all_snf_pam.csv",
                       param_col="K",
                       frame_col="B",
                       value_col="ari")

animate_param_stability(df_boot,
                        param_col="K",
                        frame_col="B",
                        value_col="ari",
                        layout=None,  # 1×P row
                        outfile="ukb_lca_K_stability.gif",
                        fps=30)


Saved animation to ukb_lca_K_stability.gif


In [26]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

def animate_param_stability(df,
                            param_col: str = "K",
                            frame_col: str = "B",
                            value_col: str = "ari",
                            layout: np.ndarray | None = None,
                            outfile: str = "param_stability.gif",
                            fps: int = 4,
                            dynamic_clim: bool = True,
                            gamma: float | None = 0.7):
    """
    Animate the running mean of a stability metric per parameter value, with
    optional contrast enhancement.

    df           : long-format table with [param_col, frame_col, value_col]
    param_col    : parameter id (e.g. K, or hyperparameter combo id)
    frame_col    : bootstrap index (B)
    value_col    : stability metric (e.g. ARI)
    layout       : optional 2D array (or nested list) telling where each
                   parameter goes in grid; None means a single row
    outfile      : GIF path
    fps          : frames per second
    dynamic_clim : if True, rescale color limits per frame to maximise contrast
    gamma        : if not None, apply value -> value**gamma to the 0–1 scaled
                   running means. gamma < 1 (e.g. 0.7) stretches differences
                   at the low end of the scale and compresses the high end.
    """

    # Unique params and frames
    params = np.sort(df[param_col].unique())
    frames = np.sort(df[frame_col].unique())

    # Default layout: 1 x P row; otherwise accept any array-like grid
    if layout is None:
        layout = params.reshape(1, -1)
    else:
        layout = np.atleast_2d(np.asarray(layout))

    n_rows, n_cols = layout.shape

    # Precompute running means for each frame
    running_means = {}
    for b in frames:
        sub = df[df[frame_col] <= b]
        means = sub.groupby(param_col)[value_col].mean()
        running_means[b] = means

    # Global min/max across *all* running means (for reference)
    all_vals = np.concatenate([
        running_means[b].values for b in frames
    ])
    global_min = np.nanmin(all_vals)
    global_max = np.nanmax(all_vals)
    if global_max - global_min < 1e-9:
        global_max = global_min + 1e-3

    # Figure setup
    fig, ax = plt.subplots(figsize=(4 + n_cols, 3 + n_rows), dpi=120)

    grid0 = np.full((n_rows, n_cols), np.nan)
    im = ax.imshow(grid0, origin="upper",
                   vmin=0.0, vmax=1.0, cmap="RdYlGn")

    texts = []
    for i in range(n_rows):
        for j in range(n_cols):
            p = layout[i, j]
            txt = ax.text(j, i, str(p),
                          ha="center", va="center",
                          color="black", fontsize=9)
            texts.append(txt)

    ax.set_xticks([])
    ax.set_yticks([])
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(f"Running mean {value_col.upper()} (scaled)")

    def transform_vals(grid_raw):
        """Apply optional gamma + scaling to [0,1]."""
        # Start from global 0–1 scale so colors still roughly map to ARI
        norm = (grid_raw - global_min) / (global_max - global_min)
        norm = np.clip(norm, 0.0, 1.0)

        if gamma is not None:
            # x**gamma with gamma < 1 is steepest near 0, so it spreads out
            # low scores and squeezes scores close to 1 together.
            norm = norm ** gamma

        return norm

    def update(frame_idx):
        b = frames[frame_idx]
        means = running_means[b]

        # Fill grid with running means
        grid_raw = np.full((n_rows, n_cols), np.nan)
        for i in range(n_rows):
            for j in range(n_cols):
                p = layout[i, j]
                if p in means.index:
                    grid_raw[i, j] = means.loc[p]

        # Transform to 0–1 and apply gamma boost
        grid = transform_vals(grid_raw)

        im.set_data(grid)

        finite = grid[np.isfinite(grid)]
        if dynamic_clim and finite.size:
            # Per-frame contrast stretch to keep things visually alive
            vmin = finite.min()
            vmax = finite.max()
            if vmax - vmin < 1e-6:
                # tiny range -> add a bit of spread so colors don’t collapse
                pad = 0.02
                vmin = max(0.0, vmin - pad)
                vmax = min(1.0, vmax + pad)
            im.set_clim(vmin, vmax)
        else:
            # Fixed scale; also the fallback when the frame has no data at
            # all, which would otherwise produce NaN colour limits.
            im.set_clim(0.0, 1.0)

        ax.set_title(f"Stability over bootstraps (up to B = {b})")
        return (im, *texts)

    anim = FuncAnimation(fig, update,
                         frames=len(frames),
                         interval=1000 / fps,
                         blit=False)
    writer = PillowWriter(fps=fps)
    anim.save(outfile, writer=writer)
    plt.close(fig)
    print(f"Saved animation to {outfile}")


# This load used to be indented into the function body above, where it re-read
# the CSV on every call and threw the result away; it belongs at cell level.
df_boot = load_boot_df("boot_ari_all_snf_pam.csv",
                       param_col="K",
                       frame_col="B",
                       value_col="ari")

# Render the contrast-enhanced variant from the module-level df_boot table.
# NOTE(review): fps=60 asks for ~17 ms per frame; confirm the GIF viewer in use
# actually honours delays that short.
animate_param_stability(
    df_boot,
    param_col="K",
    frame_col="B",
    value_col="ari",
    layout=None,                     # or your custom 2D layout
    outfile="ukb_lca_K_stability_v2.gif",
    fps=60,                           # a bit faster now
    dynamic_clim=True,               # per-frame contrast
    gamma=0.7                        # exaggerate differences
)



Saved animation to ukb_lca_K_stability_v2.gif


In [29]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.manifold import spectral_embedding
from sklearn.metrics import pairwise_distances, adjusted_rand_score

# Make sure Python can see your project modules (the `src` package under ROOT).
# NOTE(review): absolute, machine-specific path — the notebook will not import
# `src.*` on another machine without editing it.
ROOT = Path("/Users/harisreedeth/Desktop/D/personal/ProjectMAIP")
sys.path.append(str(ROOT))

from src.snf_lite import (
    load_view,
    align_views,
    gower_affinity,
    rbf_affinity,
    snf_fuse,
)
from src.run_mmsp_phase1_pam import pam_kmedoids_best_of_n

def drop_low_signal(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` without uninformative columns.

    A column is dropped when it has at most one distinct non-missing value, or
    when its non-missing values are all within {0, 1} and its mean is below
    0.005 or above 0.995.
    """
    def _is_informative(col) -> bool:
        if col.nunique(dropna=True) <= 1:
            return False                      # constant (or entirely missing)
        if set(col.dropna().unique()) <= {0, 1}:
            prevalence = float(col.mean())
            if prevalence < 0.005 or prevalence > 0.995:
                return False                  # ultra-rare/common dummy
        return True

    retained = [name for name in df.columns if _is_informative(df[name])]
    return df[retained].copy()


In [30]:
from pathlib import Path
import pandas as pd
import numpy as np

# NOTE(review): absolute, machine-specific project root.
ROOT = Path("/Users/harisreedeth/Desktop/D/personal/ProjectMAIP")
PROC = ROOT / "data" / "01_processed"
CLUS = ROOT / "data" / "02_clusters"

# Views
# NOTE(review): C, P and S are not referenced again in the visible cells;
# build_fused_for_stratum re-reads the same files through load_view.
C = pd.read_csv(PROC / "C_view.csv")
P = pd.read_csv(PROC / "P_view_scaled.csv")
S = pd.read_csv(PROC / "S_view.csv")

# MMSP / SNF labels
L_mmsp = pd.read_csv(CLUS / "mmsp_clusters.csv")   # has 'eid', 'stratum', 'label'
L_snf  = pd.read_csv(CLUS / "snf_clusters_all.csv")

# Unique participant ids per stratum, keyed by stratum name.
strata_ids = {
    s: L_mmsp.loc[L_mmsp["stratum"] == s, "eid"].unique()
    for s in ["Low_MM", "Mid_MM", "High_MM"]
}

def drop_low_signal(df: pd.DataFrame) -> pd.DataFrame:
    """
    Copy ``df`` keeping only columns that carry signal: more than one distinct
    non-missing value, and — for columns whose non-missing values all lie in
    {0, 1} — a mean between 0.005 and 0.995 inclusive.
    """
    def _too_skewed(col) -> bool:
        # Only 0/1-valued columns are screened on prevalence.
        if not set(col.dropna().unique()) <= {0, 1}:
            return False
        share = float(col.mean())
        return share < 0.005 or share > 0.995

    survivors = []
    for name in df.columns:
        col = df[name]
        if col.nunique(dropna=True) > 1 and not _too_skewed(col):
            survivors.append(name)
    return df[survivors].copy()

def build_fused_for_stratum(stratum: str, knn: int = 25, iters: int = 10, alpha: float = 0.5):
    """
    Build the fused SNF affinity matrix for one stratum.

    The three views are loaded from PROC and aligned on "eid", restricted to
    the ids listed for ``stratum`` in the module-level ``strata_ids``, and
    low-signal columns are dropped from the C and S views. Affinities are then
    computed per view and fused with snf_fuse.

    Returns (fused, ids), where ``ids`` are the aligned ids kept for the stratum.
    """
    # Load the three views (on every call), then align them on their ids.
    view_files = ("C_view.csv", "P_view_scaled.csv", "S_view.csv")
    loaded = [load_view(PROC / fname, "eid") for fname in view_files]
    ids_all, (Xc, Xp, Xs) = align_views(
        [view_ids for view_ids, _ in loaded],
        [view_data for _, view_data in loaded],
    )

    # Boolean row mask selecting members of this stratum.
    members = np.intersect1d(ids_all, strata_ids[stratum])
    mask = np.isin(ids_all, members)
    ids = ids_all[mask]

    Xc = drop_low_signal(Xc.iloc[mask].reset_index(drop=True))
    Xp = Xp.iloc[mask].reset_index(drop=True)
    Xs = drop_low_signal(Xs.iloc[mask].reset_index(drop=True))

    # Per-view affinities, fused in the order C, P, S.
    A_c = gower_affinity(Xc)
    A_s = gower_affinity(Xs)
    A_p = rbf_affinity(Xp, k_local=7)
    return snf_fuse([A_c, A_p, A_s], k=knn, iters=iters, alpha=alpha), ids


In [31]:
def bootstrap_ari_grid(D: np.ndarray, K: int, B: int = 50, seed: int = 42):
    """
    Bootstrap ARI stability of a K-cluster PAM solution.

    D : pairwise distance matrix (N x N)
    K : number of clusters
    B : number of bootstrap replicates
    seed : seeds both the resampling generator and the reference clustering

    The reference labels come from one run on the full matrix (n_init=20).
    Each replicate resamples N rows with replacement, re-clusters the
    resampled matrix (n_init=5, with a seed drawn from the resampling
    generator) and scores it against the reference labels of the same rows.

    Returns: list of ARIs (length B)
    """
    sampler = np.random.RandomState(seed)
    n_points = D.shape[0]

    # Reference clustering on full data
    reference, _ = pam_kmedoids_best_of_n(D, K, n_init=20, seed=seed)

    def _replicate_ari():
        draw = sampler.choice(n_points, size=n_points, replace=True)
        resampled = D[np.ix_(draw, draw)]
        # fewer inits is okay inside bootstrap
        child_seed = sampler.randint(1_000_000)
        labels, _ = pam_kmedoids_best_of_n(resampled, K, n_init=5, seed=child_seed)
        return adjusted_rand_score(reference[draw], labels)

    return [_replicate_ari() for _ in range(B)]


Ks = [3, 4, 5, 6, 7, 8]           # cluster counts to scan (same as range(3, 9))
B = 50                            # bootstrap replicates per (stratum, K)
strata_order = ["Low_MM", "Mid_MM", "High_MM"]

rows = []  # one record per (stratum, K, replicate)

for stratum in strata_order:
    print(f"=== Bootstrapping SNF-lite for {stratum} ===")
    fused, ids = build_fused_for_stratum(stratum)
    N = len(ids)  # NOTE(review): not used further in this cell

    # One spectral embedding, reuse for all K
    emb = spectral_embedding(
        adjacency=fused,
        n_components=max(max(Ks), 8),
        random_state=42,
        eigen_solver="arpack",
        drop_first=True,
    )

    for K in Ks:
        print(f"  K = {K}")
        # Cluster in the first K embedding coordinates only.
        Z = emb[:, :K]
        D = pairwise_distances(Z, metric="euclidean")

        aris = bootstrap_ari_grid(D, K=K, B=B, seed=42)
        # Replicates are numbered from 1 in the output table.
        for b, ari_b in enumerate(aris, start=1):
            rows.append(
                {
                    "stratum": stratum,
                    "K": K,
                    "B": b,
                    "ari": ari_b,
                }
            )

df_boot_all = pd.DataFrame(rows)
# NOTE(review): this is not the last expression of the cell, so its value is
# not displayed.
df_boot_all.head()


# Persist the long-format table, creating the output folder if needed.
out_csv = ROOT / "reports" / "tables" / "snf_lite_boot_ari_by_stratum.csv"
out_csv.parent.mkdir(parents=True, exist_ok=True)
df_boot_all.to_csv(out_csv, index=False)
print("Saved:", out_csv)


=== Bootstrapping SNF-lite for Low_MM ===
  K = 3
  K = 4
  K = 5
  K = 6
  K = 7
  K = 8
=== Bootstrapping SNF-lite for Mid_MM ===
  K = 3
  K = 4
  K = 5
  K = 6
  K = 7
  K = 8
=== Bootstrapping SNF-lite for High_MM ===
  K = 3
  K = 4
  K = 5
  K = 6
  K = 7
  K = 8
Saved: /Users/harisreedeth/Desktop/D/personal/ProjectMAIP/reports/tables/snf_lite_boot_ari_by_stratum.csv


In [32]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np

PORTF_BG = "#f7f1e8"

def animate_param_stability_by_stratum(
    df,
    stratum_col="stratum",
    param_col="K",
    frame_col="B",
    value_col="ari",
    strata_order=("Low_MM", "Mid_MM", "High_MM"),
    layout=None,
    outfile="snf_lite_K_stability_by_stratum.gif",
    fps=8,
    dynamic_clim=True,
    gamma=0.7,
):
    """
    Animate running-mean stability per parameter, one heatmap strip per stratum.

    df           : long-format table with [stratum_col, param_col, frame_col,
                   value_col]
    strata_order : strata to draw, top to bottom
    layout       : optional 2D array (or nested list) placing parameters in
                   each strip's grid; None means one row of sorted parameters
    outfile      : GIF path
    fps          : frames per second
    dynamic_clim : if True, stretch the colour limits on every frame. The
                   limits are computed across all strata together, so the one
                   shared colorbar is valid for every strip.
    gamma        : if not None, apply value -> value**gamma after scaling the
                   running means to [0, 1] with the global min/max
    """
    params = np.sort(df[param_col].unique())
    frames = np.sort(df[frame_col].unique())

    # One horizontal row of Ks by default
    if layout is None:
        layout = params.reshape(1, -1)
    else:
        layout = np.atleast_2d(np.asarray(layout))

    n_rows_param, n_cols = layout.shape
    n_strata = len(strata_order)

    # Precompute running means per (stratum, B)
    running_means = {}
    for s in strata_order:
        running_means[s] = {}
        df_s = df[df[stratum_col] == s]
        for b in frames:
            sub = df_s[df_s[frame_col] <= b]
            means = sub.groupby(param_col)[value_col].mean()
            running_means[s][b] = means

    # Global min/max for scaling
    all_vals = []
    for s in strata_order:
        for b in frames:
            all_vals.extend(running_means[s][b].values)
    all_vals = np.array(all_vals)
    global_min = np.nanmin(all_vals)
    global_max = np.nanmax(all_vals)
    if global_max - global_min < 1e-9:
        global_max = global_min + 1e-3

    def transform_vals(grid_raw):
        """Scale to [0, 1] with the global range, then apply optional gamma."""
        norm = (grid_raw - global_min) / (global_max - global_min)
        norm = np.clip(norm, 0.0, 1.0)
        if gamma is not None:
            norm = norm ** gamma
        return norm

    # Figure: one row per stratum
    fig, axes = plt.subplots(
        n_strata, 1,
        figsize=(4 + n_cols, 1.8 * n_strata),
        dpi=120,
        constrained_layout=True,
    )

    if n_strata == 1:
        axes = [axes]

    ims = []
    text_grid = []

    for ax, s in zip(axes, strata_order):
        ax.set_facecolor(PORTF_BG)
        grid0 = np.full((n_rows_param, n_cols), np.nan)
        im = ax.imshow(grid0, origin="upper", vmin=0.0, vmax=1.0, cmap="RdYlGn")
        ims.append(im)

        texts = []
        for i in range(n_rows_param):
            for j in range(n_cols):
                p = layout[i, j]
                txt = ax.text(
                    j, i, str(p),
                    ha="center", va="center",
                    color="black", fontsize=9,
                )
                texts.append(txt)
        text_grid.append(texts)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(s.replace("_", " "), fontsize=10)

    # One shared colorbar on the right (bound to the first strip's image)
    cbar = fig.colorbar(ims[0], ax=axes, fraction=0.035, pad=0.02)
    cbar.set_label(f"Running mean {value_col.upper()} (scaled)")

    def update(frame_idx):
        b = frames[frame_idx]

        # Build every strip's grid first so limits can be shared across strata.
        grids = []
        for s in strata_order:
            means = running_means[s][b]
            grid_raw = np.full((n_rows_param, n_cols), np.nan)
            for i in range(n_rows_param):
                for j in range(n_cols):
                    p = layout[i, j]
                    if p in means.index:
                        grid_raw[i, j] = means.loc[p]
            grids.append(transform_vals(grid_raw))

        vmin, vmax = 0.0, 1.0
        if dynamic_clim and grids:
            stacked = np.stack(grids)
            finite = stacked[np.isfinite(stacked)]
            # With no data at all in this frame, keep the fixed 0–1 scale
            # rather than passing NaN limits to matplotlib.
            if finite.size:
                vmin, vmax = finite.min(), finite.max()
                if vmax - vmin < 1e-6:
                    pad = 0.02
                    vmin = max(0.0, vmin - pad)
                    vmax = min(1.0, vmax + pad)

        # Same limits on every strip, so the shared colorbar describes all of
        # them and not just the first.
        for im, grid in zip(ims, grids):
            im.set_data(grid)
            im.set_clim(vmin, vmax)

        fig.suptitle(f"SNF-lite stability over bootstraps (up to B = {b})", fontsize=12)
        return tuple(ims) + tuple(t for texts in text_grid for t in texts)

    anim = FuncAnimation(fig, update, frames=len(frames), interval=1000 / fps, blit=False)
    writer = PillowWriter(fps=fps)
    anim.save(outfile, writer=writer)
    plt.close(fig)
    print("Saved animation to", outfile)


In [34]:
# Single-row layout holding the sorted K values found in the bootstrap table.
Ks = np.array(sorted(df_boot_all["K"].unique()))
layout = Ks.reshape(1, -1)

# Render one strip per stratum into the reports/figures folder under ROOT.
animate_param_stability_by_stratum(
    df_boot_all,
    stratum_col="stratum",
    param_col="K",
    frame_col="B",
    value_col="ari",
    strata_order=["Low_MM", "Mid_MM", "High_MM"],
    layout=layout,
    outfile=str(ROOT / "reports" / "figures" / "snf_lite_K_stability_by_stratum.gif"),
    fps=4,
    dynamic_clim=True,
    gamma=0.7,
)


Saved animation to /Users/harisreedeth/Desktop/D/personal/ProjectMAIP/reports/figures/snf_lite_K_stability_by_stratum.gif


In [44]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

PORTF_BG = "#f7f1e8"

def animate_param_stability_by_stratum(
    df,
    strata_order=("Low_MM", "Mid_MM", "High_MM"),
    param_col="K",
    frame_col="B",
    value_col="ari",
    outfile="snf_lite_K_stability_by_stratum.gif",
    fps=10,
    gamma=0.7,
    stratum_col="stratum",
):
    """
    Animate running-mean stability per parameter as vertically joined strips,
    one strip per stratum, on a fixed 0–1 colour scale.

    df           : long-format table with [stratum_col, param_col, frame_col,
                   value_col]
    strata_order : strata to draw, top to bottom
    param_col    : parameter column (one cell per unique value, sorted)
    frame_col    : bootstrap index; one frame per unique value
    value_col    : stability metric averaged over replicates <= current frame
    outfile      : GIF path
    fps          : frames per second
    gamma        : if not None, apply value -> value**gamma after scaling the
                   running means to [0, 1] with the global min/max
    stratum_col  : name of the stratum column (defaults to "stratum")
    """

    params = np.sort(df[param_col].unique())
    frames = np.sort(df[frame_col].unique())
    layout = params.reshape(1, -1)      # one row per strip
    n_cols = layout.shape[1]
    n_rows = len(strata_order)

    # --------- pre-compute running means ----------
    running_means = {}
    for s in strata_order:
        dfs = df[df[stratum_col] == s]
        for b in frames:
            sub = dfs[dfs[frame_col] <= b]
            means = sub.groupby(param_col)[value_col].mean()
            running_means[(s, b)] = means

    all_vals = np.concatenate(
        [running_means[(s, b)].values for s in strata_order for b in frames]
    )
    global_min = np.nanmin(all_vals)
    global_max = np.nanmax(all_vals)
    if global_max - global_min < 1e-9:
        global_max = global_min + 1e-3

    def transform_vals(raw):
        """Scale to [0, 1] with the global range, then apply optional gamma."""
        norm = (raw - global_min) / (global_max - global_min)
        norm = np.clip(norm, 0.0, 1.0)
        if gamma is not None:
            norm = norm ** gamma
        return norm

    # --------- figure & axes (joined strips) ----------
    fig, axes = plt.subplots(
        n_rows,
        1,
        sharex=True,
        figsize=(15, 1.4 * n_rows + 2),
        dpi=120,
    )
    if n_rows == 1:
        axes = [axes]

    fig.patch.set_facecolor(PORTF_BG)
    for ax in axes:
        ax.set_facecolor(PORTF_BG)

    # tighten vertical spacing so strips visually touch
    fig.subplots_adjust(left=0.18, right=0.88, top=0.93, bottom=0.12, hspace=0.03)

    ims = []
    text_grid = []

    for i, strat in enumerate(strata_order):
        ax = axes[i]

        # initial empty grid
        grid0 = np.full((1, n_cols), np.nan)
        im = ax.imshow(
            grid0,
            origin="upper",
            vmin=0.0,
            vmax=1.0,
            cmap="RdYlGn",
            aspect="auto",
        )
        ims.append(im)

        # static K labels (these never move)
        row_texts = []
        for j, k in enumerate(params):
            t = ax.text(
                j,
                0,
                f"{int(k)}",
                ha="center",
                va="center",
                fontsize=11,
                color="black",
            )
            row_texts.append(t)
        text_grid.append(row_texts)

        ax.set_xticks(np.arange(n_cols))
        ax.set_xticklabels([])   # we only show K inside the cells
        ax.set_yticks([])

        # static stratum label on the left
        ax.set_ylabel(
            strat.replace("_", " "),
            rotation=0,
            ha="right",
            va="center",
            labelpad=25,
            fontsize=11,
        )

        # remove axis frame
        for spine in ax.spines.values():
            spine.set_visible(False)

    # shared colorbar on the right; the label follows value_col instead of
    # assuming the metric is ARI
    cbar = fig.colorbar(ims[0], ax=axes, fraction=0.03, pad=0.02)
    cbar.set_label(f"Running mean {value_col.upper()} (scaled)")

    # single global title that we update, but it does NOT move the axes
    time_text = fig.text(
        0.5,
        0.97,
        "",
        ha="center",
        va="top",
        fontsize=13,
    )

    # --------- animation update ----------
    def update(frame_idx):
        b = frames[frame_idx]
        time_text.set_text(f"SNF-lite stability over bootstraps (up to B = {b})")

        for i, strat in enumerate(strata_order):
            means = running_means[(strat, b)]
            grid_raw = np.full((1, n_cols), np.nan)
            for j, k in enumerate(params):
                if k in means.index:
                    grid_raw[0, j] = means.loc[k]

            grid = transform_vals(grid_raw)
            ims[i].set_data(grid)
            # keep a fixed color scale so the colorbar is stable
            ims[i].set_clim(0.0, 1.0)

        # return all artists that change
        artists = ims + [time_text]
        for row in text_grid:
            artists.extend(row)
        return artists

    anim = FuncAnimation(
        fig,
        update,
        frames=len(frames),
        interval=1000 / fps,
        blit=False,
    )
    writer = PillowWriter(fps=fps)
    anim.save(outfile, writer=writer)
    plt.close(fig)
    print(f"Saved animation to {outfile}")


In [45]:
# Re-render with the joined-strip version; this writes to the same output path
# as the earlier by-stratum animation cell and therefore overwrites that file.
animate_param_stability_by_stratum(
    df_boot_all,
    strata_order=["Low_MM", "Mid_MM", "High_MM"],   # order of strips
    param_col="K",
    frame_col="B",
    value_col="ari",
    outfile=str(ROOT / "reports" / "figures" / "snf_lite_K_stability_by_stratum.gif"),
    fps=4,
    gamma=0.7,
)


Saved animation to /Users/harisreedeth/Desktop/D/personal/ProjectMAIP/reports/figures/snf_lite_K_stability_by_stratum.gif
