In [1]:
import os
import numpy as np
import pandas as pd
import scanpy as sc

# ------------------------
# Configuration
# ------------------------

# Root directory holding one sub-folder per model family.
base_dir = "/mnt/projects/debruinz_project"

# Both benchmark CSVs land in the same folder underneath base_dir.
_bench_dir = os.path.join(base_dir, "bisholea", "capstone", "tsv2_benchmarks")
out_hw_csv = os.path.join(_bench_dir, "tsv2_k80_nmf_nnae_H_W_stats.csv")
out_diag_csv = os.path.join(_bench_dir, "tsv2_k80_nmf_diag_loadings.csv")

# Tissues with TSv2 k=80 runs; each name is substituted into a model's
# file-name pattern via its "{tissue}" placeholder.
tissues = [
    "Blood",
    "Bone_Marrow",
    "Lung",
    "Mammary",
    "Thymus",
]

# One tuple per model family, fields in this order:
#   model_code - short model label (consumed downstream in R)
#   folder     - sub-directory of base_dir containing the .h5ad files
#   pattern    - file-name template with a "{tissue}" placeholder
#   h_key      - adata.obsm key of the cell-level factor matrix H
#   w_key      - adata.varm key of the gene-level factor matrix W
models = [
    ("NMF", "Base NMF", "sklearn_nmf_k80_baseNMF_{tissue}.h5ad",
     "H_sklearn_nmf_k80", "W_sklearn_nmf_k80"),   # Base NMF
    ("AE", "AE NMF", "tied_nmf_k80_no_cond_{tissue}.h5ad",
     "H_shared", "W_tied"),                        # AE NMF (NNAE)
]

hw_rows = []    # per-factor sparsity / norm stats for H and W (all models)
diag_rows = []  # per-factor diagonal loadings (Base NMF only)


def _to_dense(mat):
    """Return a stored factor matrix as a dense numpy array.

    ``np.asarray`` on a scipy.sparse matrix does not densify it (it yields a
    0-d object array whose ``.shape`` cannot be unpacked), so anything that
    exposes ``toarray`` is densified first; ndarrays / DataFrames pass through.
    """
    if hasattr(mat, "toarray"):
        mat = mat.toarray()
    return np.asarray(mat)


def _column_stats(mat):
    """Return per-column (% exact zeros, L1 norm, L2 norm) of a 2-D array."""
    n_rows = mat.shape[0]
    pct_zero = (mat == 0).sum(axis=0) / float(n_rows) * 100.0
    l1 = np.abs(mat).sum(axis=0)
    l2 = np.sqrt((mat ** 2).sum(axis=0))
    return pct_zero, l1, l2


for model_code, folder, pattern, h_key, w_key in models:
    for tissue in tissues:
        h5ad_path = os.path.join(base_dir, folder, pattern.format(tissue=tissue))
        if not os.path.exists(h5ad_path):
            print(f"[WARN] File not found, skipping: {h5ad_path}")
            continue

        print(f"Loading {h5ad_path}")
        adata = sc.read_h5ad(h5ad_path)

        # --- H: an obsm entry, so rows align with cells (cells x factors) ---
        if h_key not in adata.obsm:
            raise KeyError(
                f"{h_key} not found in adata.obsm for {h5ad_path}. "
                f"Available obsm keys: {list(adata.obsm.keys())}"
            )
        H = _to_dense(adata.obsm[h_key])

        # --- W: a varm entry, so rows align with genes (genes x factors) ---
        if w_key not in adata.varm:
            raise KeyError(
                f"{w_key} not found in adata.varm for {h5ad_path}. "
                f"Available varm keys: {list(adata.varm.keys())}"
            )
        W = _to_dense(adata.varm[w_key])

        n_cells, k_H = H.shape
        n_genes, k_W = W.shape
        if k_H != k_W:
            raise ValueError(f"k mismatch for {h5ad_path}: H has {k_H}, W has {k_W}")
        k = k_H
        print(f"  H shape: {H.shape}, W shape: {W.shape}")

        # -------- sparsity and norms, one value per factor column --------
        pct_zero_H, l1_H, l2_H = _column_stats(H)
        pct_zero_W, l1_W, l2_W = _column_stats(W)

        # For each factor emit the H row first, then the W row.
        for f in range(k):
            for matrix, pct_zero, l1, l2 in (
                ("H", pct_zero_H, l1_H, l2_H),
                ("W", pct_zero_W, l1_W, l2_W),
            ):
                hw_rows.append({
                    "model": model_code,  # "AE" or "NMF"
                    "tissue": tissue,
                    "factor": f,          # 0..k-1
                    "matrix": matrix,
                    "pct_zero": float(pct_zero[f]),
                    "l1": float(l1[f]),
                    "l2": float(l2[f]),
                })

        # -------- NMF diagonal loadings (W/H scaled to unit L2) --------
        # Rescale every factor column to unit L2 norm:
        #   W_unit[:, f] = W[:, f] / ||W[:, f]||,  H_unit[:, f] = H[:, f] / ||H[:, f]||
        # Each rank-1 term of the factorisation (presumably X ~= H @ W.T,
        # cells x genes) then factors as
        #   H[:, f] W[:, f]^T = d[f] * H_unit[:, f] W_unit[:, f]^T
        # with d[f] = ||W[:, f]|| * ||H[:, f]||. So the diagonal loading is the
        # product of the norms, not its reciprocal. An all-zero factor
        # correctly gets d[f] = 0, so no epsilon guard is needed.
        if model_code == "NMF":
            diag_vals = l2_W * l2_H  # length k
            diag_rows.extend(
                {
                    "model": model_code,
                    "tissue": tissue,
                    "factor": f,
                    "diag_loading": float(diag_vals[f]),
                }
                for f in range(k)
            )

# ---- save H/W stats ----
# An explicit column list keeps a header row in the CSV even when no input
# files were found; an empty list of row dicts would otherwise produce a
# header-less file that downstream readers cannot parse.
hw_columns = ["model", "tissue", "factor", "matrix", "pct_zero", "l1", "l2"]
hw_df = pd.DataFrame(hw_rows, columns=hw_columns)
print(hw_df.head())
print("Total H/W rows:", len(hw_df))
os.makedirs(os.path.dirname(out_hw_csv), exist_ok=True)
hw_df.to_csv(out_hw_csv, index=False)
print("Wrote H/W stats to:", out_hw_csv)

# ---- save diagonal loadings ----
diag_columns = ["model", "tissue", "factor", "diag_loading"]
diag_df = pd.DataFrame(diag_rows, columns=diag_columns)
print(diag_df.head())
print("Total diagonal rows:", len(diag_df))
# Create this file's own directory rather than relying on it matching the
# directory of the H/W stats CSV.
os.makedirs(os.path.dirname(out_diag_csv), exist_ok=True)
diag_df.to_csv(out_diag_csv, index=False)
print("Wrote diagonal loadings to:", out_diag_csv)

Loading /mnt/projects/debruinz_project/Base NMF/sklearn_nmf_k80_baseNMF_Blood.h5ad
  H shape: (17802, 80), W shape: (27284, 80)
Loading /mnt/projects/debruinz_project/Base NMF/sklearn_nmf_k80_baseNMF_Bone_Marrow.h5ad
  H shape: (8045, 80), W shape: (26167, 80)
Loading /mnt/projects/debruinz_project/Base NMF/sklearn_nmf_k80_baseNMF_Lung.h5ad
  H shape: (11716, 80), W shape: (29584, 80)
Loading /mnt/projects/debruinz_project/Base NMF/sklearn_nmf_k80_baseNMF_Mammary.h5ad
  H shape: (18539, 80), W shape: (28389, 80)
Loading /mnt/projects/debruinz_project/Base NMF/sklearn_nmf_k80_baseNMF_Thymus.h5ad
  H shape: (9933, 80), W shape: (30453, 80)
Loading /mnt/projects/debruinz_project/AE NMF/tied_nmf_k80_no_cond_Blood.h5ad
  H shape: (17802, 80), W shape: (27284, 80)
Loading /mnt/projects/debruinz_project/AE NMF/tied_nmf_k80_no_cond_Bone_Marrow.h5ad
  H shape: (8045, 80), W shape: (26167, 80)
Loading /mnt/projects/debruinz_project/AE NMF/tied_nmf_k80_no_cond_Lung.h5ad
  H shape: (11716, 80), W 