# Archetype Analysis on **BETA CELLS**
### Analyzing Linear AA stability across 5 runs

**Dataset:** Single-Cell RNA-seq (HFD Beta Cells)
**Pipeline:** Adapted from MNIST Analysis Pipeline
**Goal:** Calculate consistency and stability metrics across 5 runs of Linear AA

In [2]:
import numpy as np
import torch
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.spatial.distance import pdist, squareform
from scipy.stats import entropy
import umap
from matplotlib.colors import BoundaryNorm
import matplotlib.cm as cm
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- Configuration ---
N_ARCHETYPES = 4      # k=4 (Confirmed for Beta Cells)
N_RUNS = 4            # Number of runs in your results file

# --- Paths ---
LINEAR_AA_PATH = '/Users/joaomata/Desktop/DTU/DeepLearning/ProjectDL/LinearAA/Python/gaussian_betacells_5runs_magic/betacells_gaussian_aa_results_5runs_magic.pth'
ADATA_PATH = '/Users/joaomata/Desktop/DTU/DeepLearning/ProjectDL/LinearAA/Python/data/beta_cells_hfd.h5ad'
AANET_PATH = '/Users/joaomata/Desktop/DTU/DeepLearning/ProjectDL/AAnet/example_notebooks/results/beta_Cells/AAnet_singlecell_runs.pth'
# Output folder for plots
RESULTS_DIR = "analysis_results/beta_cells"
os.makedirs(RESULTS_DIR, exist_ok=True)

In [4]:
# --- Metric Functions ---

def calcMI(z1, z2):
    """Calculates Mutual Information (MI)."""
    eps = 1e-16
    P = z1 @ z2.T
    PXY = P / P.sum()
    PXPY = np.outer(PXY.sum(1), PXY.sum(0))
    MI = np.sum(PXY * np.log(eps + PXY / (eps + PXPY)))
    return MI

def calcNMI(z1, z2):
    """Calculates Normalized Mutual Information (NMI)."""
    NMI = (2 * calcMI(z1, z2)) / (calcMI(z1, z1) + calcMI(z2, z2))
    return NMI

def preprocess(X):
    """Preprocessing step for mean-centering and computing mSST."""
    meanX = np.mean(X, axis=0)
    X_centered = X - meanX
    mSST = np.sum(np.mean(X_centered**2, axis=0))
    return X_centered, mSST

def ArchetypeConsistency(XC1, XC2, mSST):
    """Calculates Archetype Consistency and ISI between two runs."""
    D = squareform(pdist(np.hstack((XC1, XC2)).T, 'euclidean'))**2
    D = D[:XC1.shape[1], XC1.shape[1]:]
    
    i, j, v = [], [], []
    K = XC1.shape[1]
    D_temp = D.copy()
    
    for k in range(K):
        min_index = np.unravel_index(np.argmin(D_temp, axis=None), D_temp.shape)
        i.append(min_index[0])
        j.append(min_index[1])
        v.append(D[i[-1], j[-1]])
        D_temp[i[-1], :] = np.inf
        D_temp[:, j[-1]] = np.inf
    
    consistency = 1 - np.mean(v) / mSST
    
    D2 = np.abs(np.corrcoef(np.hstack((XC1, XC2)).T))
    D2 = D2[:K, K:]
    ISI = 1 / (2 * K * (K - 1)) * (
        np.sum(D2 / np.max(D2, axis=1, keepdims=True) + 
               D2 / np.max(D2, axis=0, keepdims=True)) - 2 * K
    )
    
    return consistency, ISI

In [5]:
# --- Load Single-Cell Data ---
print(f"Loading single-cell data from {ADATA_PATH}...")

adata = sc.read_h5ad(ADATA_PATH)


# Dataset is too big, subsample for faster testing
adata = adata[:1000, :]
print(f"Subsampled dataset shape: {adata.X.shape}")

X_dense = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X.copy()
X_log_normalized = sc.pp.log1p(X_dense, copy=True)

N_SAMPLES = X_log_normalized.shape[0]
N_FEATURES = X_log_normalized.shape[1]
print(f"Data Loaded: {N_SAMPLES} cells x {N_FEATURES} genes")

# Prepare data
X_raw_data = X_log_normalized.astype(np.float64)
X_centered, mSST = preprocess(X_raw_data)

print(f"mSST: {mSST:.4f}")
print("Data preparation complete.")

Loading single-cell data from /Users/joaomata/Desktop/DTU/DeepLearning/ProjectDL/LinearAA/Python/data/beta_cells_hfd.h5ad...
Subsampled dataset shape: (1000, 16483)
Data Loaded: 1000 cells x 16483 genes
mSST: 636.6481
Data preparation complete.


In [6]:
# Use to_numpy from previous cell (already defined)
# --- Helper Functions ---

def to_numpy(tensor):
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return tensor

# --- Load Linear AA runs with dimension validation ---
print(f"Loading runs from {LINEAR_AA_PATH}...")
checkpoint = torch.load(LINEAR_AA_PATH, map_location='cpu')

if 'C' not in checkpoint or 'A' not in checkpoint:
    raise KeyError("The checkpoint file does not contain expected 'C' and 'A' keys.")

C_list = checkpoint['C']
A_list = checkpoint['A']

# Lists to store numpy arrays for all runs
S_all = []
C_all = []

# First pass: identify the common dimension
print("\n--- Checking dimensions across runs ---")
feature_dims = []
for run_idx in range(N_RUNS):
    C = to_numpy(C_list[run_idx])
    S = to_numpy(A_list[run_idx])
    
    # Transpose if needed
    if S.shape[0] != N_ARCHETYPES and S.shape[1] == N_ARCHETYPES:
        S = S.T
    if C.shape[1] != N_ARCHETYPES and C.shape[0] == N_ARCHETYPES:
        C = C.T
    
    feature_dims.append(C.shape[0])
    print(f"Run {run_idx+1}: S shape {S.shape}, C shape {C.shape}, Features: {C.shape[0]}")

# Find the most common dimension (or minimum to be safe)
from collections import Counter
dim_counts = Counter(feature_dims)
target_dim = min(feature_dims)  # Use minimum to avoid index errors

print(f"\n--- Using target dimension: {target_dim} features ---")
print(f"This matches the subsampled data dimension: {N_FEATURES}")

# Second pass: load and align all runs to target dimension
for run_idx in range(N_RUNS):
    C = to_numpy(C_list[run_idx])
    S = to_numpy(A_list[run_idx])
    
    # Transpose if needed
    if S.shape[0] != N_ARCHETYPES and S.shape[1] == N_ARCHETYPES:
        S = S.T
    if C.shape[1] != N_ARCHETYPES and C.shape[0] == N_ARCHETYPES:
        C = C.T
    
    # Truncate or select features to match target dimension
    if C.shape[0] > target_dim:
        print(f"  Run {run_idx+1}: Truncating from {C.shape[0]} to {target_dim} features")
        C = C[:target_dim, :]
    elif C.shape[0] < target_dim:
        raise ValueError(f"Run {run_idx+1} has fewer features ({C.shape[0]}) than target ({target_dim})")
    
    # Similarly check S dimension matches N_SAMPLES
    if S.shape[1] > N_SAMPLES:
        print(f"  Run {run_idx+1}: Truncating S from {S.shape[1]} to {N_SAMPLES} samples")
        S = S[:, :N_SAMPLES]
    
    S_all.append(S)
    C_all.append(C)
    
    print(f"Run {run_idx+1} final: S shape {S.shape}, C shape {C.shape}")

print(f"\n✓ Successfully loaded and aligned {N_RUNS} runs")
print(f"  All C matrices: ({target_dim}, {N_ARCHETYPES})")
print(f"  All S matrices: ({N_ARCHETYPES}, {N_SAMPLES})")

Loading runs from /Users/joaomata/Desktop/DTU/DeepLearning/ProjectDL/LinearAA/Python/gaussian_betacells_5runs_magic/betacells_gaussian_aa_results_5runs_magic.pth...

--- Checking dimensions across runs ---
Run 1: S shape (4, 1000), C shape (1000, 4), Features: 1000
Run 2: S shape (4, 1000), C shape (1000, 4), Features: 1000
Run 3: S shape (4, 1000), C shape (1000, 4), Features: 1000
Run 4: S shape (4, 1000), C shape (1000, 4), Features: 1000

--- Using target dimension: 1000 features ---
This matches the subsampled data dimension: 16483
Run 1 final: S shape (4, 1000), C shape (1000, 4)
Run 2 final: S shape (4, 1000), C shape (1000, 4)
Run 3 final: S shape (4, 1000), C shape (1000, 4)
Run 4 final: S shape (4, 1000), C shape (1000, 4)

✓ Successfully loaded and aligned 4 runs
  All C matrices: (1000, 4)
  All S matrices: (4, 1000)


In [7]:
# --- Load AAnet results ---
print("\nLoading AAnet results...")

try:
    aanet_checkpoint = torch.load(AANET_PATH, map_location='cpu', weights_only=False)
    
    # Extract C and S
    C_aanet_list = [to_numpy(c) for c in aanet_checkpoint['C_list']]
    S_aanet_list = [to_numpy(s) for s in aanet_checkpoint['S_list']]

    # Align dimensions to target_dim like Linear AA
    C_aanet_aligned = []
    S_aanet_aligned = []

    for C, S in zip(C_aanet_list, S_aanet_list):
        # Transpose if needed (features × archetypes)
        if C.shape[1] != N_ARCHETYPES and C.shape[0] == N_ARCHETYPES:
            C = C.T
        if S.shape[0] != N_ARCHETYPES and S.shape[1] == N_ARCHETYPES:
            S = S.T
        
        # Truncate C to target_dim
        if C.shape[0] > target_dim:
            C = C[:target_dim, :]
        
        # Truncate S to N_SAMPLES
        if S.shape[1] > N_SAMPLES:
            S = S[:, :N_SAMPLES]
        
        C_aanet_aligned.append(C)
        S_aanet_aligned.append(S)

    print(f"✓ Loaded {len(C_aanet_aligned)} AAnet runs")
    print(f"  Example C shape: {C_aanet_aligned[0].shape}")
    print(f"  Example S shape: {S_aanet_aligned[0].shape}")

    aanet_available = True

except Exception as e:
    print(f"⚠ Could not load AAnet results: {e}")
    aanet_available = False



Loading AAnet results...
✓ Loaded 5 AAnet runs
  Example C shape: (20, 4)
  Example S shape: (4, 1000)


In [8]:
# --- Calculate Pairwise Metrics Across All Runs ---
print("\n" + "="*60)
print("CALCULATING STABILITY METRICS ACROSS 4 RUNS")
print("="*60)

nmi_scores = []
consistency_scores = []
isi_scores = []

# Compare all pairs of runs
for i in range(N_RUNS):
    for j in range(i+1, N_RUNS):  # Fixed: was "range(, N_RUNS)"
        # NMI between runs
        nmi = calcNMI(S_all[i], S_all[j])
        nmi_scores.append(nmi)
        
        # Consistency and ISI between runs
        cons, isi = ArchetypeConsistency(C_all[i], C_all[j], mSST)
        consistency_scores.append(cons)
        isi_scores.append(isi)
        
        print(f"Run {i+1} vs Run {j+1}: NMI={nmi:.4f}, Consistency={cons:.4f}, ISI={isi:.4f}")

# Calculate average metrics
avg_nmi = np.mean(nmi_scores)
avg_consistency = np.mean(consistency_scores)
avg_isi = np.mean(isi_scores)

print("\n" + "="*60)
print("AVERAGE METRICS ACROSS ALL PAIRWISE COMPARISONS")
print("="*60)
print(f"Average NMI:         {avg_nmi:.4f} ± {np.std(nmi_scores):.4f}")
print(f"Average Consistency: {avg_consistency:.4f} ± {np.std(consistency_scores):.4f}")
print(f"Average ISI:         {avg_isi:.4f} ± {np.std(isi_scores):.4f}")
print("="*60)


CALCULATING STABILITY METRICS ACROSS 4 RUNS
Run 1 vs Run 2: NMI=0.9859, Consistency=0.9997, ISI=0.0759
Run 1 vs Run 3: NMI=0.9769, Consistency=0.9997, ISI=0.0392
Run 1 vs Run 4: NMI=0.9693, Consistency=0.9995, ISI=0.0669
Run 2 vs Run 3: NMI=0.9645, Consistency=0.9999, ISI=0.0085
Run 2 vs Run 4: NMI=0.9934, Consistency=1.0000, ISI=0.0094
Run 3 vs Run 4: NMI=0.9472, Consistency=0.9998, ISI=0.0097

AVERAGE METRICS ACROSS ALL PAIRWISE COMPARISONS
Average NMI:         0.9729 ± 0.0150
Average Consistency: 0.9998 ± 0.0002
Average ISI:         0.0349 ± 0.0280


In [9]:
# --- Calculate Pairwise Metrics Across All Runs (Linear AA & AAnet) ---
print("\n" + "="*60)
print("CALCULATING STABILITY METRICS FOR LINEAR AA AND AANET")
print("="*60)

def calculate_pairwise_metrics(C_list, S_list, mSST, method_name):
    nmi_scores = []
    consistency_scores = []
    isi_scores = []
    n_runs = len(C_list)
    for i in range(n_runs):
        for j in range(i+1, n_runs):
            nmi = calcNMI(S_list[i], S_list[j])
            nmi_scores.append(nmi)
            cons, isi = ArchetypeConsistency(C_list[i], C_list[j], mSST)
            consistency_scores.append(cons)
            isi_scores.append(isi)
            print(f"{method_name} Run {i+1} vs Run {j+1}: NMI={nmi:.4f}, Consistency={cons:.4f}, ISI={isi:.4f}")
    avg_nmi = np.mean(nmi_scores)
    avg_consistency = np.mean(consistency_scores)
    avg_isi = np.mean(isi_scores)
    print("\n" + "="*60)
    print(f"\n{method_name} AVERAGE METRICS ACROSS ALL PAIRWISE COMPARISONS")
    print("="*60)
    print(f"Average NMI:         {avg_nmi:.4f} ± {np.std(nmi_scores):.4f}")
    print(f"Average Consistency: {avg_consistency:.4f} ± {np.std(consistency_scores):.4f}")
    print(f"Average ISI:         {avg_isi:.4f} ± {np.std(isi_scores):.4f}")
    print("="*60)
    return nmi_scores, consistency_scores, isi_scores, avg_nmi, avg_consistency, avg_isi

# Linear AA metrics
nmi_scores, consistency_scores, isi_scores, avg_nmi, avg_consistency, avg_isi = calculate_pairwise_metrics(
    C_all, S_all, mSST, "Linear AA"
)

# AAnet metrics (if available)
if aanet_available:
    nmi_scores_aanet, consistency_scores_aanet, isi_scores_aanet, avg_nmi_aanet, avg_consistency_aanet, avg_isi_aanet = calculate_pairwise_metrics(
        C_aanet_aligned, S_aanet_aligned, mSST, "AAnet"
    )


CALCULATING STABILITY METRICS FOR LINEAR AA AND AANET
Linear AA Run 1 vs Run 2: NMI=0.9859, Consistency=0.9997, ISI=0.0759
Linear AA Run 1 vs Run 3: NMI=0.9769, Consistency=0.9997, ISI=0.0392
Linear AA Run 1 vs Run 4: NMI=0.9693, Consistency=0.9995, ISI=0.0669
Linear AA Run 2 vs Run 3: NMI=0.9645, Consistency=0.9999, ISI=0.0085
Linear AA Run 2 vs Run 4: NMI=0.9934, Consistency=1.0000, ISI=0.0094
Linear AA Run 3 vs Run 4: NMI=0.9472, Consistency=0.9998, ISI=0.0097


Linear AA AVERAGE METRICS ACROSS ALL PAIRWISE COMPARISONS
Average NMI:         0.9729 ± 0.0150
Average Consistency: 0.9998 ± 0.0002
Average ISI:         0.0349 ± 0.0280
AAnet Run 1 vs Run 2: NMI=1.0030, Consistency=1.0000, ISI=0.8781
AAnet Run 1 vs Run 3: NMI=1.0000, Consistency=1.0000, ISI=0.8679
AAnet Run 1 vs Run 4: NMI=0.9997, Consistency=1.0000, ISI=0.8771
AAnet Run 1 vs Run 5: NMI=0.9975, Consistency=1.0000, ISI=0.8733
AAnet Run 2 vs Run 3: NMI=1.0055, Consistency=1.0000, ISI=0.8702
AAnet Run 2 vs Run 4: NMI=1.0023, C

In [10]:
# --- Visualization Functions ---

def to_numpy(tensor):
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return tensor

def save_figure(fig, filename):
    full_path = os.path.join(RESULTS_DIR, filename)
    print(f"Saving: {full_path}")
    fig.savefig(full_path, bbox_inches='tight', dpi=150)
    plt.close(fig)

def plot_archetype_heatmap(C, title, save_name):
    """Heatmap of C matrix (Genes x Archetypes)."""
    C = to_numpy(C)
    fig, ax = plt.subplots(figsize=(8, 10))
    sns.heatmap(C, cmap="viridis", ax=ax, cbar_kws={'label': 'Expression'})
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Archetypes")
    ax.set_ylabel("Genes")
    ax.set_xticks(np.arange(C.shape[1]) + 0.5)
    ax.set_xticklabels([f"Arc {i+1}" for i in range(C.shape[1])])
    plt.tight_layout()
    save_figure(fig, save_name)

def plot_umap_assignment(X_umap, S, title, save_name):
    """UMAP colored by dominant archetype."""
    S = to_numpy(S)
    dominant_arc = np.argmax(S.T, axis=1)
    n_arc = S.shape[0]
    
    cmap_discrete = plt.cm.get_cmap('tab10', n_arc)
    bounds = np.arange(n_arc + 1) - 0.5
    norm = BoundaryNorm(bounds, cmap_discrete.N)
    
    fig = plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        X_umap[:, 0], X_umap[:, 1],
        c=dominant_arc, cmap=cmap_discrete, norm=norm,
        s=10, alpha=0.6
    )
    plt.title(title, fontsize=14)
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')
    cbar = plt.colorbar(scatter, ticks=np.arange(n_arc))
    cbar.set_ticklabels([f'Arc {i+1}' for i in range(n_arc)])
    save_figure(fig, save_name)

def plot_3d_simplex(S, title, save_name):
    """3D tetrahedron plot for k=4."""
    S = to_numpy(S)
    k = S.shape[0]
    
    if k != 4:
        print(f"Skipping 3D Simplex: requires k=4, got k={k}")
        return
    
    S_plot = S.T
    dominant_arc = np.argmax(S_plot, axis=1)
    
    from mpl_toolkits.mplot3d import Axes3D
    
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    
    cmap_discrete = plt.cm.get_cmap('tab10', k)
    bounds = np.arange(k + 1) - 0.5
    norm = BoundaryNorm(bounds, cmap_discrete.N)
    
    scatter = ax.scatter(
        S_plot[:, 0], S_plot[:, 1], S_plot[:, 2],
        c=dominant_arc, cmap=cmap_discrete, norm=norm,
        s=10, alpha=0.6
    )
    
    # Vertices
    ax.scatter([1, 0, 0], [0, 1, 0], [0, 0, 1], 
              c='k', marker='D', s=100, label='Archetype Vertices')
    
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Arc 1 Weight')
    ax.set_ylabel('Arc 2 Weight')
    ax.set_zlabel('Arc 3 Weight')
    ax.set_box_aspect([1, 1, 1])
    
    cbar = fig.colorbar(scatter, ticks=np.arange(k), pad=0.1)
    cbar.set_ticklabels([f'Arc {i+1}' for i in range(k)])
    
    save_figure(fig, save_name)

def plot_metric_distribution(scores, metric_name, save_name):
    """Box plot of metric scores across all pairwise comparisons."""
    fig, ax = plt.subplots(figsize=(8, 6))
    
    bp = ax.boxplot([scores], labels=[metric_name], patch_artist=True)
    bp['boxes'][0].set_facecolor('#4e79a7')
    
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title(f'{metric_name} Distribution Across Runs', fontsize=14)
    ax.grid(axis='y', alpha=0.3)
    
    # Add mean line
    mean_val = np.mean(scores)
    ax.axhline(mean_val, color='r', linestyle='--', label=f'Mean: {mean_val:.4f}')
    ax.legend()
    
    save_figure(fig, save_name)

In [11]:
# --- Generate Visualizations ---
print("\nGenerating visualizations...")

# Calculate UMAP (once)
print("Calculating UMAP...")
X_umap = umap.UMAP(n_components=2, random_state=42).fit_transform(X_centered)

# Use first run for detailed visualizations
S_ref = S_all[0]
C_ref = C_all[0]

# 1. Archetype heatmap
plot_archetype_heatmap(
    C_ref, 
    title="Linear AA Archetypes (Run 0)",
    save_name="1_linear_archetypes_heatmap.png"
)

# 2. UMAP assignment
plot_umap_assignment(
    X_umap, S_ref,
    title="Linear AA Sample Assignment (Run 0)",
    save_name="2_linear_umap_assignment.png"
)

# 3. 3D Simplex (if k=4)
if N_ARCHETYPES == 4:
    plot_3d_simplex(
        S_ref,
        title="Linear AA 3D Simplex (Run 0)",
        save_name="3_linear_3d_simplex.png"
    )

    # --- AAnet Visualizations ---
    if aanet_available:
        print("\nGenerating AAnet visualizations...")
        S_ref_aanet = S_aanet_aligned[0]
        C_ref_aanet = C_aanet_aligned[0]

        # 1. Archetype heatmap
        plot_archetype_heatmap(
            C_ref_aanet,
            title="AAnet Archetypes (Run 0)",
            save_name="1_aanet_archetypes_heatmap.png"
        )

        # 2. UMAP assignment
        plot_umap_assignment(
            X_umap, S_ref_aanet,
            title="AAnet Sample Assignment (Run 0)",
            save_name="2_aanet_umap_assignment.png"
        )

        # 3. 3D Simplex (if k=4)
        plot_3d_simplex(
            S_ref_aanet,
            title="AAnet 3D Simplex (Run 0)",
            save_name="3_aanet_3d_simplex.png"
        )

        # 4. Metric distributions
        plot_metric_distribution(nmi_scores_aanet, "NMI", "4_aanet_nmi_distribution.png")
        plot_metric_distribution(consistency_scores_aanet, "Consistency", "5_aanet_consistency_distribution.png")
        plot_metric_distribution(isi_scores_aanet, "ISI", "6_aanet_isi_distribution.png")

        # 5. Summary bar plot
        fig, ax = plt.subplots(figsize=(10, 6))
        metrics = ['NMI', 'Consistency', 'ISI']
        means_aanet = [avg_nmi_aanet, avg_consistency_aanet, avg_isi_aanet]
        stds_aanet = [np.std(nmi_scores_aanet), np.std(consistency_scores_aanet), np.std(isi_scores_aanet)]

        bars = ax.bar(metrics, means_aanet, yerr=stds_aanet, capsize=10,
                      color=['#4e79a7', '#f28e2b', '#e15759'], alpha=0.8)

        ax.set_ylabel('Score', fontsize=12)
        ax.set_title('AAnet Stability Metrics (Average ± Std)', fontsize=14, fontweight='bold')
        ax.grid(axis='y', alpha=0.3)
        ax.set_ylim([0, 1.05])

        for bar, mean in zip(bars,

SyntaxError: incomplete input (735876729.py, line 80)