# Archetype Analysis on **BETA CELLS**
### Comparing and analysing the archetypes found by AAnet and Linear AA

- **Dataset:** Single-cell RNA-seq (HFD beta cells)
- **Pipeline:** Adapted from the MNIST analysis pipeline

In [69]:
import numpy as np
import torch
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.spatial.distance import pdist, squareform
from scipy.stats import entropy
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import umap
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import BoundaryNorm
import matplotlib.cm as cm
import gseapy as gp
from sklearn.metrics import mean_squared_error
from scipy.stats import entropy
import pandas as pd

In [70]:


# --- Configuration ---
# Number of archetypes (k) at which the two models are compared.
N_ARCHETYPES = 4      # k=4 (Confirmed for Beta Cells)
# Number of top-scoring cells averaged per archetype when rebuilding AAnet's C matrix.
N_PROTOTYPES = 20     # Used for AAnet C reconstruction

# --- Paths (relative to the directory the notebook is run from) ---
# Linear AA checkpoint; get_linear_matrices reads its 'S' and 'C' entries.
LINEAR_AA_PATH = 'LinearAA/Python/betacells_gaussian_aa_results.pth'
# AAnet results archive; get_aanet_matrices reads its 'latent_coords' array.
AANET_PATH = 'AAnet/example_notebooks/results/Beta_cell_AAnet_results.npz'
# AnnData (.h5ad) file holding the expression matrix.
ADATA_PATH = 'LinearAA/Python/data/beta_cells_hfd.h5ad'

# Output folder for plots; created here so later figure saves can write into it.
RESULTS_DIR = "analysis_results/beta_cells"
os.makedirs(RESULTS_DIR, exist_ok=True)

In [71]:
# --- Metric Functions (Identical to MNIST Pipeline) ---

def calcMI(z1,z2):
    """Mutual information of the joint distribution induced by two matrices.

    ``z1 @ z2.T`` is normalised to sum to one and treated as a joint
    distribution; its row and column sums serve as the marginals.
    """
    tiny = 10e-16  # keeps the division and the log away from exact zeros
    joint = z1 @ z2.T
    joint = joint / joint.sum()
    row_marginal = joint.sum(1)
    col_marginal = joint.sum(0)
    # np.outer flattens its inputs, so the marginals need no reshaping
    independent = np.outer(row_marginal, col_marginal)
    log_ratio = np.log(tiny + joint / (tiny + independent))
    return np.sum(joint * log_ratio)

def calcNMI(z1,z2):
    """Normalised mutual information: 2*MI(z1, z2) / (MI(z1, z1) + MI(z2, z2))."""
    cross_mi = calcMI(z1, z2)
    self_mi_total = calcMI(z1, z1) + calcMI(z2, z2)
    return (2 * cross_mi) / self_mi_total

def preprocess(X):
    """Centre every column of ``X`` and report the total mean squared deviation.

    Returns:
        (X_centered, mSST): the column-centred data, and the sum over columns
        of each column's mean squared centred value.
    """
    column_means = np.mean(X, axis=0)
    deviations = X - column_means
    per_column_msq = np.mean(np.square(deviations), axis=0)
    return deviations, np.sum(per_column_msq)

def ArchetypeConsistency(XC1, XC2, mSST):
    """Compare two archetype matrices whose columns are archetypes.

    Columns of ``XC1`` and ``XC2`` are paired greedily by smallest squared
    Euclidean distance (each row/column of the cross-distance matrix is used
    at most once).

    Returns:
        (consistency, ISI):
        consistency = 1 - mean(matched squared distances) / mSST;
        ISI is computed from the absolute cross-correlations between the two
        sets of columns, each normalised by its row maximum and its column
        maximum, summed, shifted by 2K and scaled by 1 / (2K(K-1)).
    """
    n_arch = XC1.shape[1]
    # One row per archetype: XC1's columns first, then XC2's
    stacked = np.hstack((XC1, XC2)).T

    # Squared distances between every XC1 archetype (rows) and XC2 archetype (cols)
    all_sq_dist = squareform(pdist(stacked, 'euclidean')) ** 2
    cross_sq_dist = all_sq_dist[:n_arch, n_arch:]

    # Greedy one-to-one matching: take the global minimum, then retire its row and column
    remaining = cross_sq_dist.copy()
    matched_sq_dist = []
    for _ in range(n_arch):
        row, col = np.unravel_index(np.argmin(remaining, axis=None), remaining.shape)
        matched_sq_dist.append(cross_sq_dist[row, col])
        remaining[row, :] = np.inf
        remaining[:, col] = np.inf

    consistency = 1 - np.mean(matched_sq_dist) / mSST

    abs_corr = np.abs(np.corrcoef(stacked))[:n_arch, n_arch:]
    by_row_max = abs_corr / np.max(abs_corr, axis=1, keepdims=True)
    by_col_max = abs_corr / np.max(abs_corr, axis=0, keepdims=True)
    ISI = 1 / (2 * n_arch * (n_arch - 1)) * (np.sum(by_row_max + by_col_max) - 2 * n_arch)

    return consistency, ISI


In [72]:
# --- Extraction Adapters ---

def get_aanet_matrices(npz_path, X_raw_data, n_archetypes=N_ARCHETYPES, n_prototypes=N_PROTOTYPES):
    """Load AAnet latent coordinates and estimate archetypes from prototype cells.

    For each archetype, the ``n_prototypes`` cells with the largest latent
    coordinate on that archetype are averaged in data space to form one column
    of C.

    Args:
        npz_path: Path to the ``.npz`` archive containing ``'latent_coords'``.
        X_raw_data: Data matrix, one row per cell (N, F).
        n_archetypes: Number of archetype columns to build.
        n_prototypes: Number of top-scoring cells averaged per archetype.

    Returns:
        (S_aanet, C_aanet, X_aanet): the transposed latent coordinates (k, N),
        the estimated archetypes (F, n_archetypes), and ``X_raw_data`` unchanged.

    Raises:
        ValueError: If the saved coordinates have fewer than ``n_archetypes``
            rows after transposing, or a different number of cells than
            ``X_raw_data``.
    """
    # np.load returns an NpzFile that holds the archive open; the context
    # manager closes it once the array has been read into memory.
    with np.load(npz_path) as saved_data:
        S_aanet = saved_data['latent_coords'].T # (k, N)
    X_aanet = X_raw_data # (N, F)

    if S_aanet.shape[0] < n_archetypes:
        raise ValueError(
            f"latent_coords provides {S_aanet.shape[0]} archetype dimensions, "
            f"but n_archetypes={n_archetypes} was requested."
        )
    # Prototype indices taken from S are used to index rows of X, so both must
    # describe the same cells; otherwise the averages are silently misaligned.
    if S_aanet.shape[1] != X_aanet.shape[0]:
        raise ValueError(
            f"latent_coords covers {S_aanet.shape[1]} cells, "
            f"but X_raw_data has {X_aanet.shape[0]} rows."
        )

    n_features = X_aanet.shape[1]
    C_aanet = np.zeros((n_features, n_archetypes))

    for i in range(n_archetypes):
        # Use top N prototypes to estimate the Archetype center
        top_indices = np.argsort(S_aanet[i, :])[::-1][:n_prototypes]
        C_aanet[:, i] = X_aanet[top_indices].mean(axis=0)

    return S_aanet, C_aanet, X_aanet

def get_linear_matrices(pth_path, data_tensor):
    """Load Linear AA results and compute the archetypes as ``X @ checkpoint['C']``.

    Args:
        pth_path: Path to a checkpoint readable by ``torch.load`` containing
            the entries ``'S'`` and ``'C'``.
        data_tensor: Data laid out as (F, N); a torch tensor or anything
            ``np.asarray`` accepts.

    Returns:
        (S_linear, C_linear, X_linear_out): ``checkpoint['S']``, the archetype
        matrix ``X_in @ checkpoint['C']`` (F, k), and the data transposed to
        (N, F). Tensors are returned as NumPy arrays.
    """
    def _as_numpy(value):
        # Only tensors need converting; anything else is passed through
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; everything is converted to NumPy below anyway.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(pth_path, map_location='cpu')
    S_linear = _as_numpy(checkpoint['S'])
    # The checkpoint's 'C' entry is the (samples x archetypes) coefficient
    # matrix, not the archetypes themselves.
    A_matrix = _as_numpy(checkpoint['C'])

    X_in = np.asarray(_as_numpy(data_tensor)) # (F, N)

    # C_linear = (Features, Samples) @ (Samples, Archetypes) = (Features, Archetypes)
    C_linear = X_in @ A_matrix

    X_linear_out = X_in.T # (N, F)

    return S_linear, C_linear, X_linear_out

In [73]:
# --- Load Data (SCANPY / SC-RNA-SEQ) ---
print(f"Loading single-cell data from {ADATA_PATH}...")

try:
    adata = sc.read_h5ad(ADATA_PATH)
    # Convert to dense if sparse
    X_dense = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X.copy()
    # Log1p Transformation (Standard scRNA-seq preprocessing)
    X_log_normalized = sc.pp.log1p(X_dense, copy=True) 
    
    # Set Dimensions
    N_SAMPLES = X_log_normalized.shape[0]
    N_FEATURES = X_log_normalized.shape[1]
    print(f"Data Loaded: {N_SAMPLES} cells x {N_FEATURES} genes")

    # --- Prepare Inputs for Extractors ---
    
    # 1. X_raw_data (N, F) - float64 for precision
    X_raw_data = X_log_normalized.astype(np.float64)

    # 2. X_linear_in (F, N) - Transposed for Linear AA
    X_linear_in = torch.from_numpy(X_raw_data).t().double()

    print("Data preparation complete.")

except FileNotFoundError:
    print(f"ERROR: Could not find data file at {ADATA_PATH}. Please check path.")
    # Create dummy data for testing pipeline structure if file missing
    print("Creating DUMMY data for pipeline verification...")
    X_raw_data = np.random.rand(1000, 3887)
    X_linear_in = torch.from_numpy(X_raw_data).t().double()


Loading single-cell data from LinearAA/Python/data/beta_cells_hfd.h5ad...


Data Loaded: 3887 cells x 16483 genes
Data preparation complete.


In [74]:
# --- Extract Matrices ---
print("Extracting matrices from saved files...")
try:
    # AAnet (S1, C1, X1)
    S1, C1, X1 = get_aanet_matrices(AANET_PATH, X_raw_data, n_archetypes=N_ARCHETYPES)

    # Linear AA (S2, C2, X2)
    S2, C2, X2 = get_linear_matrices(LINEAR_AA_PATH, X_linear_in)

    print("Extraction successful! Running analysis...")

    # --- Preprocess & Metrics ---
    # Preprocess X for mSST (Total Variance)
    X_centered, mSST_val = preprocess(X1) 

    print("\n--- Running Archetypal Comparison (Beta Cells) ---")
    print("-" * 40)

    # 1. NMI Analysis (S matrices)
    nmi_score = calcNMI(S1, S2)
    print(f"1. NMI Score (Clustering Similarity): {nmi_score:.4f}")

    # 2. Archetype Consistency Analysis (C matrices)
    consistency, isi = ArchetypeConsistency(C1, C2, mSST_val)
    print(f"2. Archetype Consistency Score: {consistency:.4f}")
    print(f"3. In-Sample Instability (ISI) Score: {isi:.4f}")
    print("-" * 40)

except Exception as e:
    print(f"An error occurred during extraction/analysis: {e}")
 

Extracting matrices from saved files...
Extraction successful! Running analysis...

--- Running Archetypal Comparison (Beta Cells) ---
----------------------------------------
1. NMI Score (Clustering Similarity): 0.5284
2. Archetype Consistency Score: 0.8981
3. In-Sample Instability (ISI) Score: 0.9792
----------------------------------------


In [75]:
# --- Visualization Functions (ADAPTED FOR SINGLE CELL) ---

def to_numpy(tensor):
    """Return a NumPy array for torch tensors; hand back anything else untouched."""
    if not isinstance(tensor, torch.Tensor):
        return tensor
    # detach() drops autograd tracking and cpu() moves the data to host memory
    return tensor.detach().cpu().numpy()

def save_figure(fig, filename):
    """Write ``fig`` to ``RESULTS_DIR/filename`` with a tight bounding box, then close it."""
    destination = os.path.join(RESULTS_DIR, filename)
    print(f"Saving figure to: {destination}")
    fig.savefig(destination, bbox_inches='tight')
    # The figure is closed once it has been written to disk
    plt.close(fig)

def plot_archetype_heatmap(C, title="Archetype Gene Profiles", save_name="archetype_heatmap.png"):
    """
    Draws the whole C matrix (rows = genes, columns = archetypes) as a heatmap and saves it.

    Every row of C is passed to seaborn as-is; no gene selection is applied here.
    """
    profiles = to_numpy(C)
    n_columns = profiles.shape[1]

    fig, ax = plt.subplots(figsize=(8, 10))
    sns.heatmap(profiles, cmap="viridis", cbar_kws={'label': 'Expression'}, ax=ax)

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Archetypes")
    ax.set_ylabel("Genes (Feature Index)")

    # The +0.5 offset places each tick at the centre of its heatmap column
    ax.set_xticks(np.arange(n_columns) + 0.5)
    ax.set_xticklabels([f"Arc {number}" for number in range(1, n_columns + 1)])

    fig.tight_layout()
    save_figure(fig, save_name)

def plot_umap_assignment(X_umap, S, title="Sample Assignment", s=10, alpha=0.6, save_name="umap_assignment.png"):
    """
    Plots a 2-D embedding coloured by each sample's dominant archetype.

    Args:
        X_umap: Embedding coordinates; columns 0 and 1 are plotted.
        S: Archetype weights with one row per archetype and one column per sample.
        title: Figure title.
        s: Marker size passed to the scatter plot.
        alpha: Marker transparency.
        save_name: File name handed to ``save_figure``.
    """
    S = to_numpy(S)
    n_arc = S.shape[0]
    dominant_arc = np.argmax(S, axis=0) # Indices 0, 1, ..., k-1 (one per sample)

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9; it resamples 'tab10' to n_arc colours.
    cmap_discrete = plt.get_cmap('tab10', n_arc)

    # One colour bin per archetype index,
    # e.g., for n_arc=4, boundaries are -0.5, 0.5, 1.5, 2.5, 3.5
    bounds = np.arange(n_arc + 1) - 0.5
    norm = BoundaryNorm(bounds, cmap_discrete.N)

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(
        X_umap[:, 0],
        X_umap[:, 1],
        c=dominant_arc,
        cmap=cmap_discrete,
        norm=norm,
        s=s,
        alpha=alpha
    )

    ax.set_title(title, fontsize=14)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')

    # Discrete colour bar with one labelled tick per archetype
    cbar = fig.colorbar(scatter, ax=ax, ticks=np.arange(n_arc))
    cbar.set_ticklabels([f'Arc {i+1}' for i in range(n_arc)])

    save_figure(fig, save_name)

def plot_reconstruction_scatter(X, C, S, n_samples=3, save_name="reconstruction_scatter.png", seed=None):
    """
    Scatter-plots original vs reconstructed (``C @ S``) expression for randomly chosen cells.

    Args:
        X: Original data, one row per cell.
        C: Archetype matrix with one column per archetype.
        S: Archetype weights with one column per cell.
        n_samples: Number of cells to draw without replacement (one panel each).
        save_name: File name handed to ``save_figure``.
        seed: Optional seed for a reproducible choice of cells. When None, the
            global NumPy random state is used, as before.
    """
    X = to_numpy(X)
    C = to_numpy(C)
    S = to_numpy(S)

    if seed is None:
        indices = np.random.choice(X.shape[0], n_samples, replace=False)
    else:
        indices = np.random.default_rng(seed).choice(X.shape[0], n_samples, replace=False)

    # Reconstruct only the sampled cells. Multiplying out the full C @ S would
    # allocate a cells x genes matrix just to read n_samples rows from it.
    X_rec = (C @ S[:, indices]).T # (n_samples, F)

    # squeeze=False always yields a 2-D axes array, so n_samples == 1 needs no special case
    fig, axes = plt.subplots(1, n_samples, figsize=(5 * n_samples, 5), squeeze=False)

    for ax, idx, rec in zip(axes[0], indices, X_rec):
        orig = X[idx]

        ax.scatter(orig, rec, alpha=0.3, s=5)

        # Ideal line (perfect reconstruction lies on y = x)
        lims = [min(orig.min(), rec.min()), max(orig.max(), rec.max())]
        ax.plot(lims, lims, 'r--', alpha=0.75, label='Ideal')

        ax.set_title(f"Cell {idx} Reconstruction")
        ax.set_xlabel("Original Expression")
        ax.set_ylabel("Reconstructed Expression")
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    save_figure(fig, save_name)

def plot_metric_scores(metrics, title="Model Comparison Scores", save_name="metric_scores.png"):
    """Bar chart of the ``metrics`` mapping (name -> score), bars ordered by sorted name.

    Each bar is annotated with its height formatted to three decimals, and the
    figure is passed to ``save_figure``.
    """
    metric_names = sorted(metrics.keys())
    values = [metrics[name] for name in metric_names]
    positions = np.arange(len(metric_names))

    fig, ax = plt.subplots(figsize=(8, 6))
    bars = ax.bar(positions, values, 0.5, color='#4e79a7')

    ax.set_ylabel('Score')
    ax.set_title(title)
    ax.set_xticks(positions)
    ax.set_xticklabels(metric_names, fontsize=11, fontweight='bold')
    ax.grid(axis='y', linestyle='--', alpha=0.3)

    # Write each value just above the top centre of its bar
    for bar in bars:
        bar_height = bar.get_height()
        bar_centre = bar.get_x() + bar.get_width() / 2
        ax.annotate(
            f'{bar_height:.3f}',
            xy=(bar_centre, bar_height),
            xytext=(0, 3),
            textcoords="offset points",
            ha='center',
            va='bottom',
        )

    save_figure(fig, save_name)

def plot_3d_simplex(S, title="3D Simplex Plot (Tetrahedron)", save_name="5_3d_simplex_k4.png"):
    """
    Plots the weights of the first three archetypes (k=4) in 3D, coloured by
    each sample's dominant archetype.

    Args:
        S: Archetype weights with one row per archetype (k, N). The plot is
            skipped with a message unless k == 4.
        title: Figure title.
        save_name: File name handed to ``save_figure``.

    The four tetrahedron corners are drawn at the three unit points and the
    origin. The origin stands for archetype 4, on the assumption that each
    sample's weights sum to one (so zero weight on archetypes 1-3 means full
    weight on archetype 4).
    """
    S = to_numpy(S) # (k, N)
    k = S.shape[0]

    if k != 4:
        print(f"Skipping 3D Simplex Plot: Requires k=4, but current k={k}.")
        return

    S_plot = S.T # (N, 4)
    dominant_arc = np.argmax(S_plot, axis=1) # The color indices (0, 1, 2, 3)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9; it resamples 'tab10' to k colours.
    cmap_discrete = plt.get_cmap('tab10', k)

    # One colour bin per archetype index,
    # e.g., for k=4, boundaries are -0.5, 0.5, 1.5, 2.5, 3.5
    bounds = np.arange(k + 1) - 0.5
    norm = BoundaryNorm(bounds, cmap_discrete.N)

    scatter = ax.scatter(
        S_plot[:, 0], # Archetype 1 weight
        S_plot[:, 1], # Archetype 2 weight
        S_plot[:, 2], # Archetype 3 weight
        c=dominant_arc,
        cmap=cmap_discrete,
        norm=norm,
        s=10,
        alpha=0.6,
        marker='o'
    )

    # All four vertices of the tetrahedron: unit points for archetypes 1-3 and
    # the origin for archetype 4 (previously missing from the plot).
    ax.scatter([1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], c='k', marker='D', s=100, label='Archetype Vertices')

    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Archetype 1 Weight')
    ax.set_ylabel('Archetype 2 Weight')
    ax.set_zlabel('Archetype 3 Weight')
    ax.set_box_aspect([1, 1, 1])
    # Without a legend the vertex label above is never displayed
    ax.legend()

    # Discrete colour bar with one labelled tick per archetype
    cbar = fig.colorbar(scatter, ticks=np.arange(k), pad=0.1)
    cbar.set_ticklabels([f'Arc {i+1}' for i in range(k)])

    save_figure(fig, save_name)
    

In [76]:
# --- Execute Visualizations ---

# 1. Compute the 2-D UMAP embedding of the centred expression matrix
#    (recomputed every time this cell runs; random_state fixes the layout)
print("Calculating UMAP for visualization...")
X_umap = umap.UMAP(n_components=2, random_state=42).fit_transform(X_centered)

# (display label, file-name suffix, C, S) for each model, in plotting order
model_results = [
    ("Linear AA", "linear_aa", C2, S2),
    ("AAnet", "aanet", C1, S1),
]

# 2. Plot Archetypes (Heatmaps)
for label, suffix, C_model, _ in model_results:
    plot_archetype_heatmap(C_model, title=f"{label} Archetypes (Gene Profiles)", save_name=f"1_archetypes_{suffix}.png")

# 3. Plot Assignments (UMAP)
for label, suffix, _, S_model in model_results:
    plot_umap_assignment(X_umap, S_model, title=f"{label} Sample Assignment", save_name=f"2_umap_assignment_{suffix}.png")

# 4. Reconstruction Check
for label, suffix, C_model, S_model in model_results:
    plot_reconstruction_scatter(X_raw_data, C_model, S_model, save_name=f"3_reconstruction_{suffix}.png")

# 5. Simplex (only the k=4 tetrahedron plot is implemented in this notebook)
if N_ARCHETYPES == 4:
    print(f"Plotting 3D Simplex (Tetrahedron) for k={N_ARCHETYPES}...")
    for label, suffix, _, S_model in model_results:
        plot_3d_simplex(S_model, title=f"{label} 3D Simplex Assignment", save_name=f"5_3d_simplex_{suffix}.png")
elif N_ARCHETYPES == 3:
    # The old message said "Requires k=3" in the very branch where k is 3;
    # the real reason for skipping is that no 2D simplex plot exists here.
    print("Skipping 2D Simplex Plot: no k=3 plotting function is implemented in this notebook.")
else:
    print(f"Skipping Simplex Plot: Only k=4 (3D) is supported here. Current k={N_ARCHETYPES}.")

# 6. Comparison Metrics
metrics = {
    'NMI': nmi_score,
    'Archetype Consistency': consistency,
    'In-Sample Instability (ISI)': isi
}
plot_metric_scores(metrics, title="AAnet vs Linear AA Metrics (Beta Cells)", save_name="4_comparison_metrics.png")

# Report the configured directory rather than a hard-coded copy of it
print(f"\nAll plots saved to '{RESULTS_DIR}'.")

Calculating UMAP for visualization...


  warn(


Saving figure to: analysis_results/beta_cells/1_archetypes_linear_aa.png
Saving figure to: analysis_results/beta_cells/1_archetypes_aanet.png


  cmap_discrete = cm.get_cmap('tab10', n_arc)
  cmap_discrete = cm.get_cmap('tab10', n_arc)


Saving figure to: analysis_results/beta_cells/2_umap_assignment_linear_aa.png
Saving figure to: analysis_results/beta_cells/2_umap_assignment_aanet.png
Saving figure to: analysis_results/beta_cells/3_reconstruction_linear_aa.png
Saving figure to: analysis_results/beta_cells/3_reconstruction_aanet.png
Plotting 3D Simplex (Tetrahedron) for k=4...
Saving figure to: analysis_results/beta_cells/5_3d_simplex_linear_aa.png


  cmap_discrete = cm.get_cmap('tab10', k)
  cmap_discrete = cm.get_cmap('tab10', k)


Saving figure to: analysis_results/beta_cells/5_3d_simplex_aanet.png
Saving figure to: analysis_results/beta_cells/4_comparison_metrics.png

All plots saved to 'analysis_results/beta_cells'.
