In [None]:
# ============================================
# PLS-BASED AGE MODEL  (point+errorbar line plots, distinct colors)
# ============================================

import os, warnings, gc
warnings.filterwarnings("ignore")

import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from scipy import stats

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import Ridge

from statsmodels.stats.multitest import multipletests
from scipy.stats import kendalltau, spearmanr

# -------------------------
# Config
# -------------------------
# Input: folder scanned for per-tissue `<tissue>.h5ad` files. Output: folder for every CSV,
# figure and QC note written below (created here if it does not exist).
H5AD_FOLDER   = "/mnt/data/melhajjar/tabula_muris/all_tissues/facs_h5ad"
OUTPUT_FOLDER = "/mnt/data/melhajjar/tabula_muris/all_tissues/scanvi_age_features_regression_model_facs_datasets"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Only cells whose `cell_ontology_class` is in this list are modelled (one model per
# tissue x cell type); every other class is skipped.
CELLTYPES_OF_INTEREST = [
    "alveolar macrophage", "macrophage", "monocyte", "classical monocyte",
    "intermediate monocyte", "non-classical monocyte", "Kupffer cell", "lung macrophage",
    "macrophage dendritic cell progenitor", "promonocyte"
]

# Knobs
N_TOP_HVG          = 2000  # highly variable genes kept per cell type (the model features)
N_COMPONENTS_MAX   = 20    # upper bound of the PLS n_components scan
TOP_AGE_COMPS      = 5     # components reported as "top age components" (by |Spearman| with age)
N_PLOT_GENES       = 50    # length of the age-driver gene list used in CSVs and heatmaps
TOP_LINES          = 10    # genes drawn in the mean +/- CI line plot
FIG_DPI            = 300   # dpi for figures saved with FIG_DPI (some plots hard-code their own)
RANDOM_SEED        = 0     # seed for CV fold shuffling, the permutation test and `rng` below

# Stability selection: refit PLS on random cell subsamples; in each run a gene is "selected"
# when it is in the top of both |PLS weight| (best age component) and |Spearman(age)|.
STAB_N_RUNS        = 25    # number of subsampling runs
STAB_SAMPLE_FRAC   = 0.8   # fraction of cells drawn without replacement per run (minimum 20)
STAB_MIN_FREQ      = 0.3   # genes selected in fewer than this fraction of runs are dropped
W_QUANTILE         = 0.75  # per-run quantile cut on |PLS weight| (>= cut -> roughly top 25%)
S_QUANTILE         = 0.50  # per-run quantile cut on |Spearman(age)| (>= cut -> roughly top 50%)

# Default marker area (matplotlib scatter `s`, points^2) for the weight-vs-correlation scatter.
SCATTER_POINT_SIZE = 56

# Module-level random state used to draw the stability-selection subsamples. It is never
# re-seeded, so the subsamples for a cell type depend on how many draws happened before it.
rng = np.random.RandomState(RANDOM_SEED)

# -------------------------
# Small helpers
# -------------------------
def order_ages_numeric(age_series: pd.Series):
    """Return the distinct age labels of `age_series` (as strings), numerically ordered if possible.

    Labels of the form "<digits>m" or "<digits>" (case/whitespace-insensitive) are parsed to
    integers. If every label parses, the labels are sorted by that number (stable for ties).
    Otherwise they are returned in order of first appearance.
    """
    def _parse(label):
        text = str(label).strip().lower()
        digits = text[:-1] if text.endswith("m") else text
        return int(digits) if digits.isdigit() else np.nan

    distinct = list(pd.unique(age_series.astype(str)))
    keyed = [(_parse(label), label) for label in distinct]
    if any(np.isnan(value) for value, _ in keyed):
        return distinct
    return [label for _, label in sorted(keyed, key=lambda pair: pair[0])]

def numeric_age_array(age_labels):
    """Convert age labels such as "3m" or "18" into a float array; unparseable labels become NaN.

    A trailing "m" (presumably months) is stripped; matching is case- and whitespace-insensitive.
    """
    def _parse(label):
        text = str(label).strip().lower()
        digits = text[:-1] if text.endswith("m") else text
        return int(digits) if digits.isdigit() else np.nan

    return np.array(list(map(_parse, age_labels)), dtype=float)

def cohen_d(a, b):
    """Return Cohen's d = (mean(a) - mean(b)) / pooled SD (sample SDs, ddof=1).

    NaNs are dropped from each group before anything is computed, so group sizes, means and
    SDs all refer to the same observations. Returns NaN when either group has fewer than two
    non-NaN values, or when the pooled SD is zero or undefined.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # Drop NaNs up front: counting them in n while nanstd/nanmean ignore them biases s_p.
    a = a[~np.isnan(a)]
    b = b[~np.isnan(b)]
    na, nb = a.size, b.size
    if na < 2 or nb < 2:
        return np.nan
    var_a = np.var(a, ddof=1)
    var_b = np.var(b, ddof=1)
    s_p = np.sqrt(((na - 1) * var_a + (nb - 1) * var_b) / (na + nb - 2))
    if s_p == 0 or not np.isfinite(s_p):
        return np.nan
    return float((np.mean(a) - np.mean(b)) / s_p)

# -------------------------
# Plotting helpers
# -------------------------
def plot_component_heatmap(CxG: pd.DataFrame, title: str, out_png: str, vmin=None, vmax=None):
    """Save a heatmap of PLS component weights (rows = components, columns = genes).

    A diverging "coolwarm" map is centred at 0 via TwoSlopeNorm. `vmin`/`vmax` set the colour
    limits; None lets matplotlib autoscale that side from the data. Limits that would make
    TwoSlopeNorm invalid (non-finite, vmin >= 0 or vmax <= 0, e.g. from an all-zero weight
    row) are discarded rather than raising. An empty `CxG` produces no file.
    """
    if CxG.size == 0:
        return

    # TwoSlopeNorm raises unless vmin < vcenter (0) < vmax, so drop unusable explicit limits.
    if vmin is not None and not (np.isfinite(vmin) and vmin < 0.0):
        vmin = None
    if vmax is not None and not (np.isfinite(vmax) and vmax > 0.0):
        vmax = None
    values = CxG.to_numpy(dtype=float)
    if vmin is None and vmax is None and not np.any(np.isfinite(values) & (values != 0.0)):
        # No non-zero data to autoscale from; any symmetric range keeps the norm well-defined.
        vmin, vmax = -1.0, 1.0

    n_c, n_g = CxG.shape
    fig_w = max(6, min(20, n_g * 0.4))
    fig_h = max(4, min(18, n_c * 0.5))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    im = ax.imshow(
        values, aspect="auto", interpolation="nearest",
        cmap="coolwarm", norm=TwoSlopeNorm(vcenter=0.0, vmin=vmin, vmax=vmax)
    )
    ax.set_xticks(np.arange(n_g)); ax.set_xticklabels(CxG.columns, rotation=90, fontsize=7)
    ax.set_yticks(np.arange(n_c)); ax.set_yticklabels(CxG.index, fontsize=9)
    ax.set_xlabel("Genes"); ax.set_ylabel("PLS component"); ax.set_title(title, fontsize=12)
    cbar = fig.colorbar(im, ax=ax); cbar.set_label("Component weight", rotation=270, labelpad=12)
    # Minor ticks on cell edges draw a faint grid between cells.
    ax.set_xticks(np.arange(-.5, n_g, 1), minor=True); ax.set_yticks(np.arange(-.5, n_c, 1), minor=True)
    ax.grid(which="minor", linestyle="-", linewidth=0.25, alpha=0.25)
    ax.tick_params(which="minor", bottom=False, left=False)
    fig.tight_layout()
    fig.savefig(out_png, dpi=FIG_DPI, bbox_inches="tight", pad_inches=0.15)
    plt.close(fig)

def heatmap(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **imshow_kwargs):
    """Draw the 2-D array `data` on `ax` (default: current axes) with a labelled colorbar.

    Column labels go on top (rotated -30 degrees), row labels on the left, frame spines are
    hidden and a white grid separates the cells. This appears to follow matplotlib's gallery
    "annotated heatmap" helper. Extra keyword arguments are forwarded to `imshow`.

    Returns
    -------
    (image, colorbar)
    """
    target_ax = plt.gca() if ax is None else ax
    colorbar_opts = {} if cbar_kw is None else cbar_kw

    image = target_ax.imshow(data, **imshow_kwargs)
    cbar = target_ax.figure.colorbar(image, ax=target_ax, **colorbar_opts)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]
    target_ax.set_xticks(range(n_cols), labels=col_labels, rotation=-30, ha="right", rotation_mode="anchor")
    target_ax.set_yticks(range(n_rows), labels=row_labels)
    # Show column labels above the matrix rather than below it.
    target_ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
    target_ax.spines[:].set_visible(False)

    # Minor ticks sit on cell boundaries; gridding them yields the white separators.
    target_ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    target_ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    target_ax.grid(which="minor", color="w", linestyle='-', linewidth=2)
    target_ax.tick_params(which="minor", bottom=False, left=False)
    return image, cbar

def plot_expr_heatmap_helper(expr_df, title, out_png, zscore=True,
                             cmap="YlGn", title_fs=18, label_fs=16, tick_fs=11, dpi=None):
    """Save a genes x ages heatmap of mean expression.

    Parameters
    ----------
    expr_df : DataFrame
        Rows are genes, columns are age labels, values are mean expression.
    title, out_png : str
        Figure title and output path.
    zscore : bool
        If True, each row is standardised (population SD); rows with zero SD are only centred.
    cmap, title_fs, label_fs, tick_fs
        Colour map and font sizes.
    dpi : int or None
        Output resolution; None uses FIG_DPI (300, the value previously hard-coded).

    An empty table (no genes or no ages) produces no file, matching
    `plot_expr_points_with_errorbars`, instead of failing inside matplotlib.
    """
    if expr_df.shape[0] == 0 or expr_df.shape[1] == 0:
        return

    M = expr_df.to_numpy(dtype=float, copy=True)
    if zscore:
        mu = np.nanmean(M, axis=1, keepdims=True)
        sd = np.nanstd(M, axis=1, keepdims=True)
        sd[sd == 0] = 1.0  # constant rows: avoid division by zero, leave them centred at 0
        M = (M - mu) / sd

    n_genes, n_ages = M.shape
    fig_w = max(6, min(18, n_ages * 0.9))
    fig_h = max(6, min(24, n_genes * 0.45))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    im, cbar = heatmap(
        M, row_labels=list(expr_df.index), col_labels=list(expr_df.columns),
        ax=ax, cmap=cmap, cbarlabel=("Row Z-score" if zscore else "Mean log1p expr"),
        aspect="auto", interpolation="nearest"
    )
    ax.set_xlabel("Age", fontsize=label_fs, labelpad=12)
    ax.set_ylabel("Genes", fontsize=label_fs, labelpad=12)
    ax.set_title(title, fontsize=title_fs)
    ax.tick_params(axis="x", labelsize=tick_fs)
    ax.tick_params(axis="y", labelsize=max(8, tick_fs - 1))
    fig.tight_layout()
    fig.savefig(out_png, dpi=FIG_DPI if dpi is None else dpi, bbox_inches="tight", pad_inches=0.15)
    plt.close(fig)

def plot_expr_points_with_errorbars(expr_df: pd.DataFrame,
                                    sem_df: pd.DataFrame | None,
                                    title: str,
                                    out_png: str,
                                    top_n: int = 8,
                                    ci_mult: float = 1.96,
                                    marker_size: int = 7,
                                    line_width: float = 2.0,
                                    cap_size: float = 4.0):
    """
    Plot mean ± error bars at each age as connected points (no ribbons) and save the figure.

    Each row of `expr_df` (genes x ages) becomes one line with its own, distinct colour.
    Error bars are `ci_mult * SEM` taken from `sem_df` (same gene/age layout). When `sem_df`
    is None, each bar is instead the SD of that gene's per-age means, repeated at every age.
    Only the `top_n` genes with the largest variance across ages are drawn.

    Ages that all parse as "<n>m" or "<n>" are placed at their numeric value on the x axis.
    Otherwise ages sit at positions 0..k-1. Legend labels carry Spearman's rho between the
    per-age means and x. An empty `expr_df` produces no file.
    """
    if expr_df.shape[0] == 0:
        return

    # Keep the top_n most variable genes (across ages) to avoid clutter.
    if expr_df.shape[0] > top_n:
        keep = expr_df.var(axis=1).sort_values(ascending=False).head(top_n).index
        expr_df = expr_df.loc[keep]
        if sem_df is not None:
            sem_df = sem_df.loc[keep]

    ages = list(expr_df.columns)

    def to_num(a):
        s = str(a).strip().lower()
        if s.endswith("m") and s[:-1].isdigit(): return int(s[:-1])
        if s.isdigit(): return int(s)
        return np.nan
    xnum = [to_num(a) for a in ages]
    # Real ages give true spacing; fall back to evenly spaced positions if any label is non-numeric.
    x = np.arange(len(ages)) if any(np.isnan(v) for v in xnum) else np.array(xnum)

    from itertools import chain
    from matplotlib.colors import to_hex
    extras = (
        "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
        "#4c72b0", "#dd8452", "#55a868", "#c44e52", "#8172b3", "#937860"
    )
    # `matplotlib.cm.get_cmap` (plt.cm.get_cmap) was removed in Matplotlib 3.9; pyplot.get_cmap
    # remains. tab20's even entries and the first five extras repeat tab10, so de-duplicate
    # (order-preserving) to keep every line's colour distinct.
    palette = list(dict.fromkeys(
        to_hex(c) for c in chain(plt.get_cmap("tab10").colors,
                                 plt.get_cmap("tab20").colors,
                                 extras)
    ))

    # Scoped with rc_context so these settings do not leak into every later figure.
    style = {
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": False,
        "grid.linestyle": ":",
        "grid.alpha": 0.35,
    }
    with plt.rc_context(style):
        fig, ax = plt.subplots(figsize=(10, 6))
        for i, (gene, row) in enumerate(expr_df.iterrows()):
            y = row.values.astype(float)
            if sem_df is not None:
                yerr = ci_mult * sem_df.loc[gene, ages].values.astype(float)
            else:
                yerr = np.full_like(y, fill_value=np.nanstd(y, ddof=1))

            color = palette[i % len(palette)]
            ax.errorbar(
                x, y, yerr=yerr,
                fmt='-o', markersize=marker_size, linewidth=line_width,
                elinewidth=1.2, capsize=cap_size, capthick=1.2,
                color=color, ecolor=color, alpha=0.98
            )

            # Legend label with Spearman over means-per-age (compact trend indicator).
            try:
                rho, _ = spearmanr(x, y, nan_policy="omit")
                label = f"{gene} (ρ={rho:.2f})"
            except Exception:  # cosmetic only: fall back to the bare gene name
                label = gene
            ax.plot([], [], color=color, label=label)

        ax.set_xticks(x); ax.set_xticklabels(ages, rotation=0)
        ax.set_xlabel("Age")
        ax.set_ylabel("Mean log1p expression per age (± CI)")
        ax.set_title(title)
        ax.grid(True, axis="y")
        ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), fontsize=9)

        fig.tight_layout()
        fig.savefig(out_png, dpi=FIG_DPI, bbox_inches="tight", pad_inches=0.15)
    plt.close(fig)

def plot_loading_vs_corr_scatter(
    gene_perf_df,
    x_col="component_maxabs_loading",
    y_col="spearman_age",
    title="Genes: |PLS weight| vs |Spearman(age)|",
    out_path="gene_scatter_weight_vs_corr.pdf",
    top_labels=12,
    annotate_by="xyprod",
    point_alpha=0.9,
    point_size=SCATTER_POINT_SIZE,
    dpi=300
):
    """Scatter |x_col| against |y_col| per gene, coloured by the sign of y_col, and save it.

    Expects a "gene" column plus `x_col` and `y_col`. The plot adds:
      * an OLS line with Pearson r/p, when at least 3 genes have finite values and x varies;
      * quartile lines of |y|, when at least 10 genes have finite values;
      * text labels for the `top_labels` best genes by `annotate_by`:
        "xyprod" = |x|*|y|, "x" = |x|, anything else = |y|.

    Genes with a non-finite x or y are neither labelled nor used in the fit. Nothing is saved
    for an empty table or when no gene has both values finite, since axis limits would be NaN.
    """
    if gene_perf_df.empty:
        return

    x_raw = gene_perf_df[x_col].astype(float).abs().to_numpy()
    y_signed = gene_perf_df[y_col].astype(float).to_numpy()
    y_abs = np.abs(y_signed)
    genes = gene_perf_df["gene"].astype(str).to_numpy()

    m = np.isfinite(x_raw) & np.isfinite(y_abs)
    if not m.any():
        return

    pos_mask = y_signed >= 0
    colors = np.where(pos_mask, "#2875d8", "#d84545")

    fig = plt.figure(figsize=(8.6, 6.2), dpi=dpi, constrained_layout=True)
    ax = fig.add_subplot(111)
    ax.scatter(x_raw, y_abs, s=point_size, alpha=point_alpha,
               linewidths=0.6, edgecolors="white", c=colors)
    ax.scatter([], [], c="#2875d8", s=point_size, label="↑ with age (Spearman ≥ 0)")
    ax.scatter([], [], c="#d84545", s=point_size, label="↓ with age (Spearman < 0)")
    ax.legend(frameon=False, loc="lower right", fontsize=9)

    def _q(a):
        """Limits spanning the 1st-99th percentile of the finite values, padded by 4%."""
        a = a[np.isfinite(a)]
        lo, hi = np.percentile(a, [1, 99])
        span = hi - lo if hi > lo else np.max(a) - np.min(a)
        if span <= 0:
            # All values identical: use a unit-scale pad so the limits are not degenerate.
            span = max(abs(lo), 1.0)
        pad = 0.04 * span
        return lo - pad, hi + pad

    xlo, xhi = _q(x_raw); ylo, yhi = _q(y_abs)
    ax.set_xlim(xlo, xhi); ax.set_ylim(ylo, yhi)

    # linregress raises when all x values are identical, so require some spread in x.
    if m.sum() >= 3 and np.ptp(x_raw[m]) > 0:
        slope, intercept, r, p, _ = stats.linregress(x_raw[m], y_abs[m])
        xx = np.linspace(xlo, xhi, 200)
        ax.plot(xx, slope*xx + intercept, linewidth=2.0, alpha=0.95, color="#444")
        ax.text(0.02, 0.98, f"$r$ = {r:.2f}, $p$ = {p:.1g}",
                transform=ax.transAxes, ha="left", va="top", fontsize=11)

    if m.sum() >= 10:
        for q, ls in zip(np.quantile(y_abs[m], [0.25, 0.5, 0.75]), [":", "--", ":"]):
            if ylo < q < yhi:
                ax.axhline(q, color="gray", lw=1.0, ls=ls, alpha=0.6)

    if top_labels and top_labels > 0:
        if annotate_by == "xyprod":
            raw_score = x_raw * y_abs
        elif annotate_by == "x":
            raw_score = x_raw
        else:
            raw_score = y_abs
        # argsort places NaNs last, so reversing would rank them first; push them to -inf instead.
        score = np.where(m, raw_score, -np.inf)
        n_lab = min(int(top_labels), int(m.sum()))
        order = np.argsort(score)[::-1][:n_lab]
        for k, i in enumerate(order):
            # Stagger label offsets so neighbouring labels overlap less.
            dx = 0.01 * (1 + (k % 3)) * (xhi - xlo)
            dy = 0.01 * (1 + ((k+1) % 3)) * (yhi - ylo)
            tx = np.clip(x_raw[i] + dx, xlo + 0.005*(xhi-xlo), xhi - 0.005*(xhi-xlo))
            ty = np.clip(y_abs[i] + dy, ylo + 0.005*(yhi-ylo), yhi - 0.005*(yhi-ylo))
            ax.text(tx, ty, genes[i], fontsize=9, ha="left", va="bottom",
                    bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.9),
                    clip_on=False)

    ax.set_xlabel("|PLS weight (best age component)|", fontsize=12, labelpad=8)
    ax.set_ylabel("|Spearman corr with continuous age| (ALL)", fontsize=12, labelpad=8)
    ax.set_title(title, fontsize=13, pad=10)
    ax.spines["top"].set_visible(False); ax.spines["right"].set_visible(False)
    ax.grid(True, which="major", linestyle=":", linewidth=0.8, alpha=0.5)
    ax.tick_params(axis="both", labelsize=11)
    fig.savefig(out_path, dpi=dpi, bbox_inches="tight", pad_inches=0.15)
    plt.close(fig)

# -------------------------
# Nicer age distribution plots (PDF)
# -------------------------
def _save_pdf_and_png(base_path_no_ext):
    """Save the current pyplot figure twice: `<base>.pdf` at 300 dpi, then `<base>.png` at 200 dpi."""
    outputs = (
        (".pdf", 300, 0.15),
        (".png", 200, 0.10),
    )
    for suffix, resolution, padding in outputs:
        plt.savefig(base_path_no_ext + suffix, dpi=resolution,
                    bbox_inches="tight", pad_inches=padding)

def plot_age_counts_raw(age_strs, base_out_no_ext, title=None):
    """Bar chart of cells per raw age label, saved as `<base>.pdf` and `<base>.png`.

    Bars follow `order_ages_numeric` (numeric order when every label parses). Each bar is
    annotated with its count and its percentage of all cells. `age_strs` must support
    `.astype(str)`, e.g. a numpy array.
    """
    labels, counts = np.unique(age_strs.astype(str), return_counts=True)
    slot = {label: pos for pos, label in enumerate(labels.tolist())}
    ordering = [slot[label] for label in order_ages_numeric(pd.Series(labels))]
    labels = labels[ordering]
    counts = counts[ordering]
    total = counts.sum()

    fig, ax = plt.subplots(figsize=(7.2, 4.0))
    bars = ax.bar(labels, counts, width=0.7, color="#4C78A8", edgecolor="white", linewidth=0.6)
    ax.set_ylabel("Number of Cells", fontsize=11)
    if title:
        ax.set_title(title, fontsize=12)
    ax.set_axisbelow(True)
    ax.yaxis.grid(True, linestyle=":", alpha=0.6)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    lift = max(total * 0.01, 1)  # small gap between bar top and its annotation
    for bar, n in zip(bars, counts):
        share = 100.0 * n / total if total > 0 else 0.0
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + lift,
                f"{n}\n({share:.1f}%)", ha="center", va="bottom", fontsize=9)

    plt.xticks(labels, labels, rotation=0)
    plt.tight_layout()
    _save_pdf_and_png(base_out_no_ext)
    plt.close()

def plot_age_counts_bins(num_age, base_out_no_ext, title):
    """Bar chart of cells per age tertile ("young" / "middle" / "old"), saved as PDF + PNG.

    The cut points are the 1/3 and 2/3 quantiles of `num_age`. A value equal to a cut point
    falls in the lower group, and values failing both comparisons (e.g. NaN) count as "old".
    """
    ages = np.asarray(num_age)
    cut_lo, cut_hi = np.quantile(ages, [1/3, 2/3])
    groups = ["young", "middle", "old"]
    tertile = np.where(ages <= cut_lo, 0, np.where(ages <= cut_hi, 1, 2))
    counts = np.array([np.count_nonzero(tertile == g) for g in range(len(groups))], dtype=int)
    total = counts.sum()

    fig, ax = plt.subplots(figsize=(5.6, 3.8))
    bars = ax.bar(groups, counts, color="#72B7B2", edgecolor="white", linewidth=0.6)
    ax.set_ylabel("# cells", fontsize=11)
    ax.set_title(title, fontsize=12)
    ax.set_axisbelow(True)
    ax.yaxis.grid(True, linestyle=":", alpha=0.6)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    lift = max(total * 0.01, 1)
    for bar, n in zip(bars, counts):
        share = 100.0 * n / total if total > 0 else 0.0
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + lift,
                f"{n}\n({share:.1f}%)", ha="center", va="bottom", fontsize=9)

    plt.tight_layout()
    _save_pdf_and_png(base_out_no_ext)
    plt.close()

# -------------------------
# CV helpers, Kendall+FDR, etc.
# -------------------------
def stratified_folds_by_age(num_age, n_splits=5, seed=0):
    """Build shuffled StratifiedKFold splits stratified on age tertiles.

    Parameters
    ----------
    num_age : array-like of float
        Numeric age per sample.
    n_splits : int
        Requested number of folds. It is reduced to the size of the smallest tertile (never
        below 2), so that every stratum can appear in every fold.
    seed : int
        `random_state` for the shuffling.

    Returns
    -------
    (folds, bins)
        `folds` is a list of (train_idx, test_idx) pairs; `bins` holds the tertile code per
        sample (0 = <= 1/3 quantile, 1 = <= 2/3 quantile, 2 = above).
    """
    num_age = np.asarray(num_age, dtype=float)
    qs = np.quantile(num_age, [1/3, 2/3])
    bins = np.digitize(num_age, bins=qs, right=True)  # 0,1,2
    _, bin_sizes = np.unique(bins, return_counts=True)
    # The old cap, min(n_splits, number of bins), limited CV to <= 3 folds. The real limit is
    # the smallest stratum's size.
    k = max(2, min(n_splits, int(bin_sizes.min())))
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
    return list(skf.split(np.zeros_like(num_age), bins)), bins

def cv_r2_pls(Xz, y, folds, n_comp):
    """Mean out-of-fold R^2 of an unscaled PLS regression with `n_comp` components.

    Each fold's R^2 uses the training-fold mean of `y` as the reference. Folds whose test
    targets all equal that mean give NaN and are ignored by the final `nanmean`.
    """
    model = PLSRegression(n_components=n_comp, scale=False)
    fold_r2 = []
    for train_idx, test_idx in folds:
        model.fit(Xz[train_idx], y[train_idx])
        y_test = y[test_idx]
        y_hat = model.predict(Xz[test_idx]).ravel()
        ss_res = np.sum((y_test - y_hat) ** 2)
        ss_tot = np.sum((y_test - np.mean(y[train_idx])) ** 2)
        if ss_tot > 0:
            fold_r2.append(1 - ss_res / ss_tot)
        else:
            fold_r2.append(np.nan)
    return float(np.nanmean(fold_r2))

def perm_test_r2(Xz, y, folds, n_comp, n_perm=200, seed=0):
    """Permutation test for the cross-validated PLS R^2 from `cv_r2_pls`.

    Recomputes the CV R^2 on `n_perm` random permutations of `y`, using the same folds and a
    RandomState seeded with `seed`.

    Returns
    -------
    (observed_r2, p)
        p = (1 + #{null >= observed}) / (1 + n_perm). If the observed R^2 is NaN, p is NaN and
        no permutations are run. Without that guard every comparison would be False and the
        minimum p-value 1/(n_perm+1) would be reported.
    """
    rng = np.random.RandomState(seed)
    obs = cv_r2_pls(Xz, y, folds, n_comp)
    if not np.isfinite(obs):
        return obs, float("nan")
    null = np.array([cv_r2_pls(Xz, rng.permutation(y), folds, n_comp) for _ in range(n_perm)])
    p = (1 + np.sum(null >= obs)) / (1 + len(null))
    return obs, float(p)

def ridge_cv_r2(Xz, y, folds, alpha=1.0):
    """Mean out-of-fold R^2 of a ridge-regression baseline evaluated on the given folds.

    R^2 per fold uses the training-fold mean of `y` as the reference. Folds whose test targets
    all equal that mean contribute NaN and are ignored by the final `nanmean`.
    """
    def _fold_r2(train_idx, test_idx):
        model = Ridge(alpha=alpha)
        model.fit(Xz[train_idx], y[train_idx])
        y_test = y[test_idx]
        ss_res = np.sum((y_test - model.predict(Xz[test_idx])) ** 2)
        ss_tot = np.sum((y_test - np.mean(y[train_idx])) ** 2)
        return 1 - ss_res / ss_tot if ss_tot > 0 else np.nan

    return float(np.nanmean([_fold_r2(tr, te) for tr, te in folds]))

def gene_age_kendall_fdr(X_mat, y, genes):
    """Kendall's tau between age `y` and each column of `X_mat`, with Benjamini-Hochberg FDR.

    Parameters
    ----------
    X_mat : 2-D array, samples x genes (dense).
    y : 1-D array of numeric ages, one per sample.
    genes : names of the columns of `X_mat`.

    Returns
    -------
    DataFrame indexed by gene with columns `kendall_tau`, `p_adj` (BH-adjusted) and `pass_fdr`
    (rejected at alpha = 0.05).

    Genes whose p-value is NaN (e.g. constant expression) are left out of the correction, get
    NaN `p_adj` and `pass_fdr=False`. Passed through BH, a single NaN would propagate through
    the cumulative minimum and make every adjusted p-value NaN.
    """
    n_genes = X_mat.shape[1]
    taus = np.full(n_genes, np.nan)
    pvals = np.full(n_genes, np.nan)
    for j in range(n_genes):
        t, p = kendalltau(y, X_mat[:, j], nan_policy="omit")
        taus[j], pvals[j] = t, p

    p_adj = np.full(n_genes, np.nan)
    passed = np.zeros(n_genes, dtype=bool)
    testable = np.isfinite(pvals)
    if testable.any():
        ok, adj, _, _ = multipletests(pvals[testable], alpha=0.05, method="fdr_bh")
        p_adj[testable] = adj
        passed[testable] = ok

    out = pd.DataFrame({"gene": genes, "kendall_tau": taus, "p_adj": p_adj, "pass_fdr": passed})
    return out.set_index("gene")

# -------------------------
# Main
# -------------------------
def _abs_spearman_columns(X_mat, y):
    """Return |Spearman rho| between every column of `X_mat` and `y` as a 1-D array.

    This is a vectorised replacement for calling `pd.Series.corr(method="spearman")` once per
    column: average ranks (`scipy.stats.rankdata`), then a Pearson correlation of the ranks.
    Constant columns (or a constant `y`) give NaN, as the per-column call does. Assumes
    `X_mat` is a dense 2-D array (cells x genes) without NaNs.
    """
    rx = stats.rankdata(X_mat, axis=0)
    ry = stats.rankdata(y)
    rx = rx - rx.mean(axis=0)
    ry = ry - ry.mean()
    denom = np.sqrt((rx ** 2).sum(axis=0) * (ry ** 2).sum())
    with np.errstate(divide="ignore", invalid="ignore"):
        rho = (ry @ rx) / denom
    return np.abs(rho)


def _argsort_desc_nan_last(values):
    """Indices sorting `values` in descending order with NaNs last.

    A plain `np.argsort(v)[::-1]` puts NaNs first, because argsort places them at the end.
    """
    v = np.asarray(values, dtype=float)
    return np.argsort(np.where(np.isnan(v), -np.inf, v))[::-1]


def _write_gene_level_outputs(adata_ct, tissue, celltype, age_driver_genes, age_loading_sub, num_age):
    """Write all per-gene outputs for one tissue/cell type.

    Outputs: the weight heatmap, per-age mean/SEM tables and their plots, and the gene-level
    statistics table with its scatter plot. `age_driver_genes` must be non-empty genes of
    `adata_ct`. `age_loading_sub` is a genes x 1 frame holding the best age component's PLS
    weights.
    """
    prefix = os.path.join(OUTPUT_FOLDER, f"{tissue}_{celltype}")

    # ---------- Heatmap of weights ----------
    CxG = age_loading_sub.loc[age_driver_genes].T
    vmax = float(np.nanpercentile(np.abs(CxG.to_numpy(dtype=float)), 99))
    if not (np.isfinite(vmax) and vmax > 0):
        # TwoSlopeNorm needs vmin < 0 < vmax; all-zero/NaN weights fall back to a unit range.
        vmax = 1.0
    plot_component_heatmap(
        CxG,
        title=f"{tissue} – {celltype} – PLS weights (best age comp; filtered top {N_PLOT_GENES})",
        out_png=f"{prefix}_PLS_WEIGHTS_bestComp_AGEtop{N_PLOT_GENES}_heatmap.pdf",
        vmin=-vmax, vmax=vmax
    )

    # ---------- Expression vs age (means + SEM -> CI caps on points) ----------
    ages_present = set(adata_ct.obs["age"].unique())
    ages_order = [a for a in order_ages_numeric(pd.Series(adata_ct.obs["age"])) if a in ages_present]
    selected_genes = [g for g in age_driver_genes if g in adata_ct.var_names]

    expr_means, expr_sems = [], []
    for age in ages_order:
        idx_age = (adata_ct.obs["age"] == str(age))
        subX = adata_ct[idx_age, selected_genes].X
        if hasattr(subX, "toarray"):
            subX = subX.toarray()
        subX = np.asarray(subX)
        n = subX.shape[0]
        means = subX.mean(axis=0)
        sems = subX.std(axis=0, ddof=1) / np.sqrt(n) if n > 1 else np.zeros_like(means)
        expr_means.append(pd.Series(means, index=selected_genes, name=str(age)))
        expr_sems.append(pd.Series(sems, index=selected_genes, name=str(age)))

    expr_df = pd.DataFrame(expr_means).T
    expr_sem_df = pd.DataFrame(expr_sems).T.loc[expr_df.index]
    expr_df.to_csv(f"{prefix}_ALL_mean_expression_AGEtop{N_PLOT_GENES}_genes_x_age.csv")
    expr_sem_df.to_csv(f"{prefix}_ALL_SEM_expression_AGEtop{N_PLOT_GENES}_genes_x_age.csv")

    plot_expr_heatmap_helper(
        expr_df.loc[selected_genes, ages_order],
        title=f"{tissue} – {celltype} – Mean expression across ages (filtered top {N_PLOT_GENES})",
        out_png=f"{prefix}_ALL_expr_heatmap_FILTERED_AGEtop{N_PLOT_GENES}_genes_x_age.pdf",
        zscore=True, cmap="YlGn", title_fs=18, label_fs=16, tick_fs=11
    )
    plot_expr_points_with_errorbars(
        expr_df.loc[selected_genes, ages_order],
        expr_sem_df.loc[selected_genes, ages_order],
        title=f"{tissue} – {celltype} – Expression vs age (mean ± 95% CI at each time point; top {min(TOP_LINES, len(selected_genes))})",
        out_png=f"{prefix}_ALL_expr_points_FILTERED_AGEtop{TOP_LINES}_withCIs.pdf",
        top_n=TOP_LINES,
        ci_mult=1.96,
        marker_size=7,
        line_width=2.0,
        cap_size=4.0
    )

    # ---------- Gene-level stats + scatter ----------
    age_labels = adata_ct.obs["age"].astype(str).values
    uniq_ages = pd.Index(np.unique(age_labels)).astype(str).tolist()
    uniq_nums = numeric_age_array(uniq_ages)
    all_numeric = np.all(~np.isnan(uniq_nums))
    youngest = uniq_ages[int(np.nanargmin(uniq_nums))] if all_numeric else uniq_ages[0]
    oldest = uniq_ages[int(np.nanargmax(uniq_nums))] if all_numeric else uniq_ages[-1]

    Xall = adata_ct[:, selected_genes].X
    if hasattr(Xall, "toarray"):
        Xall = Xall.toarray()
    maxabs_weight = age_loading_sub.abs().iloc[:, 0].to_dict()
    age_series = pd.Series(num_age)
    gene_rows = []
    for j, g in enumerate(selected_genes):
        gx = np.asarray(Xall[:, j], dtype=float)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            sp = pd.Series(gx).corr(age_series, method="spearman")
            pe = pd.Series(gx).corr(age_series, method="pearson")
        d = cohen_d(gx[age_labels == youngest], gx[age_labels == oldest])
        gene_rows.append({
            "gene": g,
            "spearman_age": sp,
            "pearson_age": pe,
            "cohen_d_young_vs_old": d,
            "component_maxabs_loading": float(maxabs_weight.get(g, np.nan))
        })
    gene_perf_df = pd.DataFrame(gene_rows).sort_values("spearman_age", key=lambda s: s.abs(), ascending=False)
    gene_perf_df.to_csv(f"{prefix}_gene_performance_ALL.csv", index=False)

    plot_loading_vs_corr_scatter(
        gene_perf_df,
        title=f"{tissue} – {celltype} – Genes: |PLS weight (best comp)| vs age-correlation",
        out_path=f"{prefix}_gene_scatter_weightBEST_vs_corr.pdf",
    )


rows = []

# os.listdir order is arbitrary. Sort it so tissues, and therefore the draws they take from the
# shared `rng`, are processed in a reproducible order.
for file in sorted(os.listdir(H5AD_FOLDER)):
    if not file.endswith(".h5ad"):
        continue
    tissue = file.replace(".h5ad", "")
    if tissue.lower() == "trachea":
        continue

    print(f"\n🧬 Processing tissue: {tissue}")
    adata = sc.read_h5ad(os.path.join(H5AD_FOLDER, file))

    # Drop cells labelled "female"; cells with any other (or missing) sex label are kept.
    if "sex" in adata.obs.columns:
        adata = adata[adata.obs["sex"] != "female"].copy()

    if "cell_ontology_class" not in adata.obs.columns:
        print("    ⚠️ Missing 'cell_ontology_class' in .obs, skipping file.")
        continue
    if "age" not in adata.obs.columns:
        print("    ⚠️ Missing 'age' in .obs, skipping file.")
        continue

    for celltype in adata.obs["cell_ontology_class"].unique():
        if celltype not in CELLTYPES_OF_INTEREST:
            continue

        print(f"  🔍 Cell type: {celltype}")
        prefix = os.path.join(OUTPUT_FOLDER, f"{tissue}_{celltype}")
        adata_ct = adata[adata.obs["cell_ontology_class"] == celltype].copy()
        if adata_ct.n_obs < 50 or len(adata_ct.obs["age"].unique()) < 2:
            print("    ⚠️ Skipping due to low cell count or one age group.")
            continue

        # ---------- Preprocess ----------
        sc.pp.filter_genes(adata_ct, min_counts=3)
        # flavor="seurat_v3" models raw counts, so HVGs are flagged BEFORE normalisation/log1p.
        # The call only annotates .var (no subsetting), so library-size normalisation below
        # still uses every gene; the matrix is subset to the flagged genes afterwards.
        sc.pp.highly_variable_genes(adata_ct, n_top_genes=N_TOP_HVG, flavor="seurat_v3")
        sc.pp.normalize_total(adata_ct, target_sum=1e4)
        sc.pp.log1p(adata_ct)
        if "highly_variable" in adata_ct.var and adata_ct.var["highly_variable"].sum() > 0:
            adata_ct = adata_ct[:, adata_ct.var["highly_variable"]].copy()
        else:
            print("    ⚠️ No HVGs found — proceeding with all genes.")

        # ---------- Age labels ----------
        adata_ct.obs["age"] = adata_ct.obs["age"].astype(str)
        num_age = numeric_age_array(adata_ct.obs["age"].values)
        valid_idx = ~np.isnan(num_age)
        if valid_idx.sum() < 10 or np.unique(num_age[valid_idx]).size < 2:
            print("    ⚠️ Not enough numeric age signal; skipping.")
            continue
        adata_ct = adata_ct[valid_idx].copy()
        num_age = num_age[valid_idx]

        # --- Age distribution PDFs ---
        plot_age_counts_raw(adata_ct.obs["age"].values, f"{prefix}_age_counts_raw")
        plot_age_counts_bins(
            num_age,
            f"{prefix}_age_counts_bins",
            title=f"{tissue} – {celltype}: cells per tertile group"
        )

        # ---------- Matrix + standardization ----------
        X = adata_ct.X
        if hasattr(X, "toarray"):
            X = X.toarray()
        X = np.asarray(X, dtype=float)
        Xz = StandardScaler(with_mean=True, with_std=True).fit_transform(X)

        # ---------- Choose n_components via stratified CV ----------
        folds, age_bins = stratified_folds_by_age(num_age, n_splits=5, seed=RANDOM_SEED)
        # Also bound by the smallest CV training set so every fold can support the components.
        min_train = min(len(tr) for tr, _ in folds)
        k_max = min(N_COMPONENTS_MAX, Xz.shape[1], Xz.shape[0] - 1, min_train - 1)
        cand = list(range(2, max(2, k_max) + 1))
        cv_scan = pd.Series({k: cv_r2_pls(Xz, num_age, folds, k) for k in cand})
        cv_scan.to_csv(f"{prefix}_cv_ncomp_scan.csv")
        # idxmax picks the first (smallest) k among ties; NaN scores can never win.
        n_comp = int(cv_scan.fillna(-np.inf).idxmax())

        # ---------- Final PLS fit ----------
        pls = PLSRegression(n_components=n_comp, scale=False)
        pls.fit(Xz, num_age)
        scores = pls.transform(Xz)
        pred_age = pls.predict(Xz).ravel()
        spearman_r = float(pd.Series(pred_age).corr(pd.Series(num_age), method="spearman"))

        comp_spearman = np.array([
            abs(pd.Series(scores[:, i]).corr(pd.Series(num_age), method="spearman"))
            for i in range(scores.shape[1])
        ])
        # NaN correlations (constant scores) must not be ranked as the top components.
        idx_top = _argsort_desc_nan_last(comp_spearman)[:TOP_AGE_COMPS]
        age_components = [f"comp_{i}" for i in idx_top]
        (pd.Series(comp_spearman, index=[f"comp_{i}" for i in range(scores.shape[1])])
           .sort_values(ascending=False)
           .to_csv(f"{prefix}_component_age_importance.csv"))

        # Silhouette using all scores on tertile labels (best effort; NaN if undefined).
        try:
            le = LabelEncoder()
            y_binned = le.fit_transform(age_bins)
            from sklearn.metrics import silhouette_score
            sil = silhouette_score(scores, y_binned)
        except Exception:
            sil = np.nan

        # ---------- Stability selection ----------
        genes = adata_ct.var_names.to_list()
        appear_counts = pd.Series(0, index=genes, dtype=int)
        mean_weight = pd.Series(0.0, index=genes, dtype=float)
        mean_spear = pd.Series(0.0, index=genes, dtype=float)

        n_cells = Xz.shape[0]
        # Sampling is without replacement, so never request more cells than exist.
        n_sample = min(n_cells, max(20, int(STAB_SAMPLE_FRAC * n_cells)))

        for _ in range(STAB_N_RUNS):
            idx_cells = rng.choice(n_cells, size=n_sample, replace=False)
            Xr, yr = Xz[idx_cells], num_age[idx_cells]
            k_max_r = min(N_COMPONENTS_MAX, Xr.shape[1], Xr.shape[0] - 1)
            k_r = max(2, min(n_comp, k_max_r))
            pls_r = PLSRegression(n_components=k_r, scale=False)
            pls_r.fit(Xr, yr)
            scores_r = pls_r.transform(Xr)

            comp_r = np.array([abs(pd.Series(scores_r[:, i]).corr(pd.Series(yr), method="spearman"))
                               for i in range(scores_r.shape[1])])
            best_idx_r = int(np.nanargmax(comp_r))

            w = pd.Series(np.abs(pls_r.x_weights_[:, best_idx_r]), index=genes)
            # |Spearman| of each gene's log expression with age on this subsample (NaN -> 0).
            s = pd.Series(_abs_spearman_columns(X[idx_cells], yr), index=genes).fillna(0.0)

            sel = (w >= w.quantile(W_QUANTILE)) & (s >= s.quantile(S_QUANTILE))
            appear_counts[sel] += 1
            mean_weight += w
            mean_spear += s

            del pls_r, scores_r, Xr, yr, w, s
            gc.collect()

        mean_weight /= STAB_N_RUNS
        mean_spear /= STAB_N_RUNS
        freq = appear_counts / STAB_N_RUNS

        gene_strengths = pd.DataFrame({
            "stability_freq": freq,
            "mean_abs_pls_weight": mean_weight,
            "mean_abs_spearman_age": mean_spear
        })
        gene_strengths["combined_score"] = gene_strengths["mean_abs_pls_weight"] * gene_strengths["mean_abs_spearman_age"]
        gene_strengths.sort_values(["stability_freq", "combined_score"], ascending=[False, False], inplace=True)
        gene_strengths.to_csv(f"{prefix}_PLS_gene_strengths_full.csv")

        # ---------- Gate by monotonicity (Kendall τ) + FDR ----------
        # X already holds adata_ct's dense expression matrix (same genes, same order).
        ktab = gene_age_kendall_fdr(X, num_age, genes)

        stable = gene_strengths.join(ktab, how="left")
        # enforce stability first
        stable = stable[stable["stability_freq"] >= STAB_MIN_FREQ]
        # then statistical monotonicity (fallback to all stable genes if none pass)
        stable_sig = stable[(stable["pass_fdr"] == True) & stable["kendall_tau"].notna()]
        if not stable_sig.empty:
            stable = stable_sig

        age_driver_genes = stable.sort_values(
            ["combined_score", "kendall_tau"], ascending=False
        ).head(N_PLOT_GENES).index.tolist()

        pd.Series(age_driver_genes, name="gene").to_csv(
            f"{prefix}_age_driver_genes_top{N_PLOT_GENES}.csv", index=False
        )

        best_idx_global = int(np.nanargmax(comp_spearman))
        weights_df_full = pd.DataFrame(pls.x_weights_, index=adata_ct.var_names,
                                       columns=[f"comp_{i}" for i in range(pls.x_weights_.shape[1])])
        age_loading_sub = weights_df_full.iloc[:, [best_idx_global]]

        if age_driver_genes:
            _write_gene_level_outputs(adata_ct, tissue, celltype, age_driver_genes,
                                      age_loading_sub, num_age)
        else:
            # Empty gene sets would crash the heatmaps/scatter; keep the model-level outputs.
            print("    ⚠️ No genes passed stability selection; skipping gene-level outputs.")

        # ---------- Ridge baseline & permutation test ----------
        ridge_r2 = ridge_cv_r2(Xz, num_age, folds, alpha=1.0)
        r2_mean_cv, r2_perm_p = perm_test_r2(Xz, num_age, folds, n_comp, n_perm=200, seed=RANDOM_SEED)

        # ---------- QC markdown ----------
        with open(f"{prefix}_QC.md", "w", encoding="utf-8") as f:
            f.write(f"# {tissue} – {celltype}\n")
            f.write(f"- n_cells: {adata_ct.n_obs}\n")
            f.write(f"- n_genes_used: {adata_ct.n_vars}\n")
            f.write(f"- best_n_components: {n_comp}\n")
            f.write(f"- CV R2 (PLS): {r2_mean_cv:.3f} (perm p={r2_perm_p:.3g})\n")
            f.write(f"- CV R2 (Ridge baseline): {ridge_r2:.3f}\n")
            f.write(f"- Spearman in-sample (PLS fit on all): {spearman_r:.3f}\n")
            f.write(f"- Silhouette (age tertiles on scores): {sil:.3f}\n")
            f.write(f"- Top age components (by |Spearman|): {', '.join(age_components)}\n")
            f.write(f"- Figures saved: "
                    f"age_counts_raw.pdf/png, age_counts_bins.pdf/png, "
                    f"PLS_WEIGHTS_bestComp_AGEtop{N_PLOT_GENES}_heatmap.pdf, "
                    f"ALL_expr_heatmap_FILTERED_AGEtop{N_PLOT_GENES}_genes_x_age.pdf, "
                    f"ALL_expr_points_FILTERED_AGEtop{TOP_LINES}_withCIs.pdf, "
                    f"gene_scatter_weightBEST_vs_corr.pdf\n")
            f.write(f"- Top genes (head): {', '.join(age_driver_genes[:10])}\n")

        rows.append({
            "tissue": tissue,
            "celltype": celltype,
            "n_cells": adata_ct.n_obs,
            "n_genes_used": adata_ct.n_vars,
            "best_n_components": n_comp,
            "cv_r2_pls": r2_mean_cv,
            "r2_perm_p": r2_perm_p,
            "ridge_r2_cv": ridge_r2,
            "spearman_age_in_sample": spearman_r,
            "silhouette_binned_age": sil,
            "top_age_components": ",".join(age_components)
        })

        del X, Xz, scores
        gc.collect()

    # Release the tissue-level AnnData before loading the next file.
    del adata
    gc.collect()

# ---------- Save summary ----------
summary_df = pd.DataFrame(rows)
csv_path = os.path.join(OUTPUT_FOLDER, "PLS_age_feature_summary.csv")
summary_df.to_csv(csv_path, index=False)
print(f"\n✅ Saved PLS age-feature summary: {csv_path}\n")