# Archetype Analysis on the **MNIST DATASET**
### Comparing and analysing the archetypes obtained from AAnet, MIDAA, Linear AA and DeepAA

In [1]:

import numpy as np
import torch
from scipy.spatial.distance import pdist, squareform
from torchvision import datasets, transforms 

# --- Configuration (UPDATE THESE PATHS!) ---
N_ARCHETYPES = 3   # number of archetypes k; must match the k used to train every model
N_SAMPLES = 5842   # number of digit-4 training samples the models were fit on
N_PROTOTYPES = 20  # top-weighted samples averaged per archetype when rebuilding AAnet's C

# Result files produced by each method. Forward slashes work on every OS; a
# backslash inside a normal string literal starts an escape sequence (an invalid
# one triggers a SyntaxWarning) and is not a path separator on Linux/macOS.
LINEAR_AA_PATH = 'LinearAA/Python/mnist_gaussian_aa_results.pth'
AANET_PATH = 'AAnet/example_notebooks/results/AAnet_MNIST_digit4_results.npz'
Middata_PATH = 'Midaa/midaa_core_matrices.pth'



def calcMI(z1, z2):
    """Return the mutual information between two soft assignment matrices.

    ``z1`` and ``z2`` must share their second dimension (e.g. (k1, N) and
    (k2, N)). Their product ``z1 @ z2.T`` is normalised into a joint
    distribution and compared against the outer product of its marginals.
    A small epsilon keeps the divisions and the logarithm finite for empty cells.
    """
    eps = 10e-16
    joint = z1 @ z2.T
    joint = joint / joint.sum()
    # Independence baseline: outer product of the row and column marginals.
    independent = np.outer(joint.sum(axis=1), joint.sum(axis=0))
    ratio = joint / (independent + eps)
    return np.sum(joint * np.log(ratio + eps))

def calcNMI(z1, z2):
    """Return the normalised mutual information between two assignment matrices.

    NMI = 2 * MI(z1, z2) / (MI(z1, z1) + MI(z2, z2)); identical inputs score 1
    whenever their self-information is non-zero.
    """
    cross_information = calcMI(z1, z2)
    self_information = calcMI(z1, z1) + calcMI(z2, z2)
    return (2 * cross_information) / self_information

def preprocess(X):
    """Mean-centre ``X`` column-wise and compute its mean total sum of squares.

    Returns:
        (X_centered, mSST): the centred array, and the sum over columns of the
        per-column mean squared deviation (i.e. the total variance).
    """
    column_means = np.mean(X, axis=0)
    centred = X - column_means
    per_column_variance = np.mean(np.square(centred), axis=0)
    return centred, np.sum(per_column_variance)

def ArchetypeConsistency(XC1, XC2, mSST):
    """Compare two archetype matrices by matched distance and an Amari-type index.

    Args:
        XC1, XC2: archetype matrices of identical shape (n_features, K), one
            archetype per column.
        mSST: total variance of the data (see ``preprocess``), used to scale
            the matched squared distances.

    Returns:
        (consistency, ISI):
        - consistency = 1 - mean(matched squared distance) / mSST, where the
          archetypes are paired greedily, closest remaining pair first.
          1 means every matched pair coincides.
        - ISI, computed from the absolute correlations between the two sets of
          archetypes: 0 when each archetype correlates maximally with exactly
          one counterpart, 1 when all correlations are equal. NaN for K < 2,
          where the normalisation 2K(K-1) is zero.

    Raises:
        ValueError: if XC1 and XC2 are not 2-D arrays of the same shape.
    """
    from scipy.spatial.distance import cdist

    XC1 = np.asarray(XC1)
    XC2 = np.asarray(XC2)
    if XC1.ndim != 2 or XC1.shape != XC2.shape:
        raise ValueError(
            f"XC1 and XC2 must be 2-D with identical shapes, got {XC1.shape} and {XC2.shape}"
        )
    K = XC1.shape[1]

    # Squared Euclidean distances between archetypes (rows: XC1, columns: XC2).
    D = cdist(XC1.T, XC2.T, 'euclidean') ** 2

    # Greedy matching: take the closest remaining pair, then exclude its row and column.
    D_temp = D.copy()
    matched = []
    for _ in range(K):
        r, c = np.unravel_index(np.argmin(D_temp), D_temp.shape)
        matched.append(D[r, c])
        D_temp[r, :] = np.inf
        D_temp[:, c] = np.inf
    consistency = 1 - np.mean(matched) / mSST

    if K < 2:
        # The index is undefined for a single archetype (previously ZeroDivisionError).
        return consistency, np.nan

    corr = np.abs(np.corrcoef(np.hstack((XC1, XC2)).T))[:K, K:]
    row_scaled = corr / np.max(corr, axis=1, keepdims=True)
    col_scaled = corr / np.max(corr, axis=0, keepdims=True)
    ISI = (np.sum(row_scaled + col_scaled) - 2 * K) / (2 * K * (K - 1))
    return consistency, ISI

# --- Extraction Adapters ---

def get_aanet_matrices(npz_path, X_raw_data, n_archetypes=N_ARCHETYPES, n_prototypes=N_PROTOTYPES):
    """Load AAnet results and rebuild its archetype matrix C from prototypes.

    Only the per-sample coordinates are read from the file. Each archetype is
    approximated as the mean of the ``n_prototypes`` samples with the largest
    coefficient for it.

    Args:
        npz_path: path to an ``.npz`` file with a ``'latent_coords'`` array,
            transposed here so rows index archetypes (assumed saved as (N, k)
            — TODO confirm against the AAnet export code).
        X_raw_data: (N, F) array of the samples the model was trained on.
        n_archetypes: number of archetypes to rebuild (rows of S used).
        n_prototypes: number of top-weighted samples averaged per archetype.

    Returns:
        (S_aanet, C_aanet, X_aanet): coefficients (k, N), archetypes
        (F, n_archetypes) and the input data (N, F), returned unchanged.
    """
    # NpzFile keeps its underlying file open; the context manager closes it.
    # Indexing it loads the array into memory, so S_aanet stays valid afterwards.
    with np.load(npz_path) as saved_data:
        S_aanet = saved_data['latent_coords'].T  # (k, N)
    X_aanet = X_raw_data  # (N, F)

    n_features = X_aanet.shape[1]
    C_aanet = np.zeros((n_features, n_archetypes))
    for i in range(n_archetypes):
        # Samples with the largest coefficient for archetype i, in descending order.
        top_indices = np.argsort(S_aanet[i, :])[::-1][:n_prototypes]
        C_aanet[:, i] = X_aanet[top_indices].mean(axis=0)

    return S_aanet, C_aanet, X_aanet

def get_linear_matrices(pth_path, data_tensor):
    """Load Linear AA results and compute the archetypes as C = X @ A.

    The checkpoint's ``'C'`` entry is used as the (N_samples, k) weight matrix
    (``A_matrix`` here) that mixes data samples into archetypes. ``'S'`` holds
    the coefficient matrix.

    Args:
        pth_path: path to a ``torch.save``-d dict with keys ``'S'`` and ``'C'``.
        data_tensor: (F, N) torch tensor of the training data.

    Returns:
        (S_linear, C_linear, X_linear_out): S as a NumPy array, archetypes
        (F, k), and the data transposed to (N, F) for computing mSST.
    """
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(pth_path, map_location='cpu')
    S_linear = checkpoint['S']
    A_matrix = checkpoint['C']

    if isinstance(S_linear, torch.Tensor):
        S_linear = S_linear.detach().cpu().numpy()
    if isinstance(A_matrix, torch.Tensor):
        A_matrix = A_matrix.detach().cpu().numpy()

    X_in = data_tensor.detach().cpu().numpy()  # (F, N)

    # (Features, Samples) @ (Samples, Archetypes) -> (Features, Archetypes)
    C_linear = X_in @ A_matrix

    X_linear_out = X_in.T  # (N, F), used for mSST
    return S_linear, C_linear, X_linear_out



ModuleNotFoundError: No module named 'torchvision'

In [80]:

# --- Load Raw Data (Filter for digit 4 only!) ---
TARGET_DIGIT = 4  # must match the MNIST class the saved models were trained on
print(f"Loading digit {TARGET_DIGIT} data ({N_SAMPLES} samples)...")
mnist_data = datasets.MNIST(root='./data', train=True, download=True,
                            transform=transforms.ToTensor())

# Keep the first N_SAMPLES training images of the target digit, in dataset order.
digit_4_mask = mnist_data.targets == TARGET_DIGIT
digit_4_indices = torch.where(digit_4_mask)[0][:N_SAMPLES]

# 1. Samples for AAnet C reconstruction: (N_samples, F_features), scaled to [0, 1].
X_raw_data = (mnist_data.data[digit_4_indices].float() / 255.0).flatten(start_dim=1).numpy()

# 2. The same data as a float64 (F_features, N_samples) tensor for Linear AA.
X_linear_in = torch.from_numpy(X_raw_data).t().double()

# --- Extract Matrices ---
print("Extracting matrices from saved files...")

S1, C1, X1 = get_aanet_matrices(AANET_PATH, X_raw_data)       # AAnet: X1 is (N, F)
S2, C2, X2 = get_linear_matrices(LINEAR_AA_PATH, X_linear_in)  # Linear AA: X2 is (N, F)

# Fail fast if the saved results do not belong to the data loaded above;
# every metric below assumes these shapes line up.
n_loaded = X_raw_data.shape[0]
assert S1.shape[1] == n_loaded, f"AAnet S has {S1.shape[1]} samples, data has {n_loaded}"
assert S2.shape[1] == n_loaded, f"Linear AA S has {S2.shape[1]} samples, data has {n_loaded}"
assert C1.shape == C2.shape, f"archetype shapes differ: AAnet {C1.shape} vs Linear AA {C2.shape}"

print("Extraction successful! Running analysis...")

# --- EXECUTION ---

# Total variance (mSST) of the data; X_centered is also reused for the UMAP embedding.
X_centered, mSST_val = preprocess(X1)

separator = "-" * 40
print("\n--- Running Archetypal Comparison ---")
print(separator)

# 1. Similarity of the coefficient matrices S.
nmi_score = calcNMI(S1, S2)
print(f"1. NMI Score (Clustering Similarity): {nmi_score:.4f}")

# 2-3. Agreement of the archetype matrices C.
# NOTE(review): the ISI formula in ArchetypeConsistency is an Amari-type index
# (0 = perfect one-to-one match, 1 = no structure); confirm the label below.
consistency, isi = ArchetypeConsistency(C1, C2, mSST_val)
for label, value in (("2. Archetype Consistency Score", consistency),
                     ("3. In-Sample Instability (ISI) Score", isi)):
    print(f"{label}: {value:.4f}")

print(separator)


Loading digit 4 data (5842 samples)...
Extracting matrices from saved files...
Extraction successful! Running analysis...

--- Running Archetypal Comparison ---
----------------------------------------
1. NMI Score (Clustering Similarity): 0.6580
2. Archetype Consistency Score: 0.1512
3. In-Sample Instability (ISI) Score: 0.6402
----------------------------------------


## Attempting plots and visualization of results

In [85]:
# 2-D UMAP embedding of the centred digit images; it is the backdrop for the
# sample-assignment plots. random_state is fixed so the layout is reproducible.
import umap

umap_model = umap.UMAP(random_state=42, n_components=2)
X_umap = umap_model.fit_transform(X_centered)

  warn(


In [86]:
# Check whether the AAnet coefficients satisfy the simplex constraints:
# non-negative entries and each sample's coefficients summing to one.
S_cells = S1.T  # (N_samples, k)

# 1. Negative coefficients
has_negatives = np.any(S_cells < 0, axis=1)
n_negative = has_negatives.sum()
print(f"Samples with negative coefficients: {n_negative} ({100*has_negatives.mean():.1f}%)")
print(f"Min coefficient value: {S_cells.min():.4f}")
print(f"Max coefficient value: {S_cells.max():.4f}")

# 2. Sum-to-one constraint
row_sums = np.sum(S_cells, axis=1)
off_by_more = np.abs(row_sums - 1) > 0.01
print(f"\nRow sums - Min: {row_sums.min():.4f}, Max: {row_sums.max():.4f}")
print(f"Samples with sum ≠ 1 (±0.01): {off_by_more.sum()}")

# 3. A few offending samples
if n_negative > 0:
    print("\nExample problematic samples:")
    problem_indices = np.flatnonzero(has_negatives)[:5]
    for idx in problem_indices:
        print(f"  Sample {idx}: {S_cells[idx]} (sum={S_cells[idx].sum():.4f})")

Samples with negative coefficients: 1435 (24.6%)
Min coefficient value: -0.1921
Max coefficient value: 1.1175

Row sums - Min: 1.0000, Max: 1.0000
Samples with sum ≠ 1 (±0.01): 0

Example problematic samples:
  Sample 2: [-0.04563922  0.8944309   0.15120834] (sum=1.0000)
  Sample 4: [ 0.55799156 -0.1323047   0.57431316] (sum=1.0000)
  Sample 6: [ 0.15035586  0.91520286 -0.06555867] (sum=1.0000)
  Sample 7: [ 0.76657945 -0.10369822  0.33711874] (sum=1.0000)
  Sample 8: [ 0.3996544   0.62608373 -0.02573812] (sum=1.0000)


In [95]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import os # NEW IMPORT
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from scipy.stats import entropy


def to_numpy(tensor):
    """Convert a torch tensor to a detached CPU NumPy array; return anything else unchanged."""
    if not isinstance(tensor, torch.Tensor):
        return tensor
    return tensor.detach().cpu().numpy()

# --- NEW UTILITY FUNCTION FOR SAVING ---
def save_figure(fig, filename, folder_name="analysis_results/mnist"):
    """Write ``fig`` to ``folder_name/filename`` and then close it.

    Missing folders (including parents) are created on demand. Closing the
    figure afterwards keeps repeated plotting from accumulating open figures.
    """
    os.makedirs(folder_name, exist_ok=True)
    destination = os.path.join(folder_name, filename)
    print(f"Saving figure to: {destination}")
    fig.savefig(destination, bbox_inches='tight')
    plt.close(fig)

# --- UPDATED PLOTTING FUNCTIONS ---

# --- Simplex Visualization ---
def _simplex_coordinates(S_cells):
    """Map (N, 3) barycentric coefficients to 2-D points in an equilateral triangle."""
    x = S_cells[:, 1] + 0.5 * S_cells[:, 2]
    y = (np.sqrt(3) / 2) * S_cells[:, 2]
    return np.column_stack([x, y])


def _mixing_strength(S_cells):
    """Per-sample entropy of (N, k) coefficients normalised by log(k): 0 = pure, 1 = uniform.

    Negative coefficients are clipped to 0 first. scipy's entropy returns -inf
    for any row containing a negative entry. Previously those rows were mapped
    to 0, so samples outside the simplex were displayed as "pure".
    """
    k = S_cells.shape[1]
    weights = np.clip(S_cells, 0, None)
    # entropy() renormalises each row; all-zero rows give 0/0, silenced here.
    with np.errstate(invalid='ignore', divide='ignore'):
        strength = entropy(weights, axis=1) / np.log(k)
    # Rows with undefined entropy (no positive mass) are treated as 0, as before.
    return np.where(np.isfinite(strength), strength, 0)


def plot_simplex(S, C, img_shape=(28, 28), title="Simplex Analysis", cmap='gray_r', save_name="simplex_plot.png"):
    """
    Plot samples in a barycentric (triangle) projection with archetype images at the
    vertices, next to a histogram of per-sample mixing strength. ONLY WORKS FOR k=3.

    Args:
        S: (k, N_samples) coefficient matrix (NumPy array or torch tensor).
        C: (n_features, k) archetype matrix; each column is reshaped to ``img_shape``.
        img_shape: shape used to display each archetype image, e.g. (28, 28).
        title: figure title; " (k=3)" is appended.
        cmap: colormap for the archetype images.
        save_name: file name passed to ``save_figure``.

    Samples with negative coefficients are drawn outside the triangle. Their
    mixing strength is computed from the coefficients clipped at zero.
    """
    S = to_numpy(S)
    C = to_numpy(C)
    k = S.shape[0]

    if k != 3:
        print(f"⚠️ Simplex plot requires k=3, but got k={k}. Skipping plot.")
        return

    S_cells = S.T
    X_simplex = _simplex_coordinates(S_cells)
    vertices = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3) / 2]])
    mix_strength = _mixing_strength(S_cells)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 7))
    fig.suptitle(f"{title} (k={k})", fontsize=16, fontweight='bold')

    # Left: samples in the triangle, coloured by mixing strength.
    scatter = ax1.scatter(
        X_simplex[:, 0], X_simplex[:, 1],
        c=mix_strength, cmap='RdYlGn_r',
        s=10, alpha=0.6, edgecolors='none',
        vmin=0, vmax=1,
    )
    ax1.add_patch(plt.Polygon(vertices, fill=False, edgecolor='black', linewidth=2, zorder=0))

    # Archetype images, pushed outward from each vertex so they do not cover the points.
    centroid = vertices.mean(axis=0)
    for i in range(k):
        im_box = OffsetImage(C[:, i].reshape(img_shape), zoom=2.0, cmap=cmap)
        im_box.image.axes = ax1
        direction = vertices[i] - centroid
        offset_pos = vertices[i] + 0.15 * direction / np.linalg.norm(direction)
        ax1.add_artist(AnnotationBbox(im_box, offset_pos, frameon=True,
                                      bboxprops=dict(edgecolor='red', linewidth=2)))

    ax1.set_xlim(-0.3, 1.3)
    ax1.set_ylim(-0.3, 1.1)
    ax1.axis('off')
    ax1.set_aspect('equal')

    cbar = fig.colorbar(scatter, ax=ax1, orientation='vertical', fraction=0.03, pad=0.04)
    cbar.set_label('Mixing Strength (Normalized Entropy)')

    # Right: distribution of mixing strength.
    ax2.hist(mix_strength, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
    mean_mix = mix_strength.mean()
    ax2.axvline(mean_mix, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_mix:.2f}')
    ax2.set_title('Distribution of Mixing Strength')
    ax2.set_xlabel('Mixing Strength (0=Pure, 1=Mixed)')
    ax2.set_ylabel('Count')
    ax2.legend()
    ax2.grid(alpha=0.3)

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    save_figure(fig, save_name)

def plot_archetypes(C, img_shape=(28, 28), title="Archetypes", cmap='gray_r', save_name="archetypes_plot.png"):
    """
    Draw each archetype (a column of ``C``) as an image, in one row of panels,
    on a fixed [0, 1] intensity scale, then save the figure via ``save_figure``.
    """
    archetypes = to_numpy(C)
    n_arc = archetypes.shape[1]

    fig, axes = plt.subplots(1, n_arc, figsize=(2 * n_arc, 3))
    # plt.subplots returns a bare Axes (not an array) for a single panel.
    axes = [axes] if n_arc == 1 else axes

    for idx, (ax, column) in enumerate(zip(axes, archetypes.T)):
        ax.imshow(column.reshape(img_shape), cmap=cmap, vmin=0, vmax=1)
        ax.axis('off')
        ax.set_title(f'Arc {idx + 1}')

    fig.suptitle(title, fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.85])
    save_figure(fig, save_name)

def plot_umap_assignment(X_umap, S, title="Sample Assignment", s=5, alpha=0.6, save_name="umap_assignment.png"):
    """
    Scatter a 2-D embedding, coloured by each sample's dominant archetype, and save it.

    Args:
        X_umap: (N_samples, 2) embedding coordinates.
        S: (k, N_samples) coefficient matrix (NumPy array or torch tensor).
        title: plot title.
        s, alpha: marker size and opacity.
        save_name: file name passed to ``save_figure``.
    """
    S = to_numpy(S)
    n_arc = S.shape[0]
    dominant_arc = np.argmax(S, axis=0)  # index of the largest coefficient per sample

    fig, ax = plt.subplots(figsize=(8, 8))
    # A discrete colormap with fixed limits keeps each archetype's colour aligned
    # with its colorbar tick even when some archetype is never dominant.
    # Auto-scaling would stretch the colours over only the archetypes that occur.
    scatter = ax.scatter(
        X_umap[:, 0],
        X_umap[:, 1],
        c=dominant_arc,
        cmap=plt.get_cmap('viridis', n_arc),
        vmin=-0.5,
        vmax=n_arc - 0.5,
        s=s,
        alpha=alpha,
    )

    ax.set_title(title, fontsize=14)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')

    cbar = fig.colorbar(scatter, ax=ax, ticks=np.arange(n_arc))
    cbar.set_ticklabels([f'Arc {i+1}' for i in range(n_arc)])

    save_figure(fig, save_name)

def plot_reconstruction(X, C, S, n_samples=5, img_shape=(28, 28), title="Reconstruction Quality Check",
                        save_name="reconstruction_check.png", random_state=None):
    """
    Show randomly chosen samples (top row) above their reconstructions X_rec = C @ S
    (bottom row), then save the figure via ``save_figure``.

    Args:
        X: (N_samples, n_features) original data.
        C: (n_features, k) archetype matrix.
        S: (k, N_samples) coefficient matrix.
        n_samples: number of samples to display (capped at N_samples).
        img_shape: shape used to display each sample.
        title: figure title.
        save_name: file name passed to ``save_figure``.
        random_state: optional integer seed for choosing the samples. ``None`` (default)
            draws from the global ``np.random`` state. Pass the same seed to compare
            several models on the same samples.
    """
    X = to_numpy(X)
    C = to_numpy(C)
    S = to_numpy(S)

    n_show = min(n_samples, X.shape[0])
    if random_state is None:
        indices = np.random.choice(X.shape[0], n_show, replace=False)
    else:
        indices = np.random.default_rng(random_state).choice(X.shape[0], n_show, replace=False)

    # Reconstruct only the displayed samples (selected columns of C @ S), not all N.
    X_rec = (C @ S[:, indices]).T

    # squeeze=False keeps `axes` 2-D even when n_show == 1; otherwise axes[0, i] fails.
    fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)

    for i, idx in enumerate(indices):
        axes[0, i].imshow(X[idx].reshape(img_shape), cmap='gray_r')
        axes[0, i].axis('off')
        # Reconstructed values are clipped to [0, 1] before display.
        axes[1, i].imshow(np.clip(X_rec[i], 0, 1).reshape(img_shape), cmap='gray_r')
        axes[1, i].axis('off')

    # Row labels, drawn as titles to the left of the first column.
    axes[0, 0].set_title("Original", x=-0.5, ha='right')
    axes[1, 0].set_title("Reconstructed", x=-0.5, ha='right')

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    save_figure(fig, save_name)
    


def plot_metric_scores(metrics, title="Model Comparison Scores", color='#4e79a7', save_name="metric_scores.png"):
    """
    Bar-plot one set of metric scores and save it via ``save_figure``.

    There is one bar per key, in sorted key order, and each bar is annotated
    with its value.

    Args:
        metrics: mapping of metric name -> numeric score; must be non-empty.
        title: plot title.
        color: bar colour.
        save_name: file name passed to ``save_figure``.

    Raises:
        ValueError: if ``metrics`` is empty.
    """
    if not metrics:
        raise ValueError("metrics must contain at least one score")

    metric_names = sorted(metrics)
    values = [float(metrics[name]) for name in metric_names]
    x = np.arange(len(metric_names))
    width = 0.5

    fig, ax = plt.subplots(figsize=(8, 6))
    rects = ax.bar(x, values, width, label='Score', color=color)

    ax.set_ylabel('Score')
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(metric_names, fontsize=11, fontweight='bold')
    ax.grid(axis='y', linestyle='--', alpha=0.3)

    # Always include 0 and leave 15% headroom past the extreme values. A fixed lower
    # limit of 0 hid negative scores (consistency = 1 - mean/mSST can go below 0),
    # and gave an inverted or empty axis when no score was positive.
    low = min(0.0, min(values)) * 1.15
    high = max(0.0, max(values)) * 1.15
    if low == high:
        high = 1.0
    ax.set_ylim(low, high)

    for rect, value in zip(rects, values):
        # Label just beyond the end of the bar: above positive bars, below negative ones.
        above = value >= 0
        ax.annotate(f'{value:.3f}',
                    xy=(rect.get_x() + rect.get_width() / 2, value),
                    xytext=(0, 3 if above else -3),
                    textcoords="offset points",
                    ha='center', va='bottom' if above else 'top', fontsize=10)

    fig.tight_layout()
    save_figure(fig, save_name)


In [96]:
# --- 1. Plot Archetypes (C Matrix) ---
plot_archetypes(C2, title="Linear AA Archetypes", save_name="1_archetypes_linear_aa.png")  # Linear AA (Gaussian)
plot_archetypes(C1, title="AAnet Archetypes", save_name="1_archetypes_aanet.png")          # AAnet (Deep)

# --- 2. Plot Sample Assignments (S Matrix) ---
# Uses the X_umap embedding computed in the UMAP cell.
plot_umap_assignment(X_umap, S2, title="Linear AA Sample Assignment", save_name="2_umap_assignment_linear_aa.png")
plot_umap_assignment(X_umap, S1, title="AAnet Sample Assignment", save_name="2_umap_assignment_aanet.png")

# --- 3. Check Reconstruction Quality ---
# plot_reconstruction picks its samples with np.random.choice. Re-seeding the
# global RNG before each call shows both models the SAME digits, so the two
# figures can be compared side by side.
RECON_SEED = 0
np.random.seed(RECON_SEED)
plot_reconstruction(X_raw_data, C2, S2, save_name="3_reconstruction_linear_aa.png", title="Linear AA Reconstruction Quality Check")
np.random.seed(RECON_SEED)
plot_reconstruction(X_raw_data, C1, S1, save_name="3_reconstruction_aanet.png", title="AAnet Reconstruction Quality Check")

# --- 4. Simplex Visualization (MNIST, k=3) ---
plot_simplex(S2, C2, title="Linear AA Simplex", img_shape=(28, 28), save_name="4_simplex_linear_aa.png")  # Linear AA (Gaussian)
plot_simplex(S1, C1, title="AAnet Simplex", img_shape=(28, 28), save_name="4_simplex_aanet.png")          # AAnet (Deep)

# --- 5. Metrics Plot ---
metrics = {
    'NMI': nmi_score,
    'Archetype Consistency': consistency,
    'In-Sample Instability (ISI)': isi
}
plot_metric_scores(metrics, title="AAnet vs Linear AA Comparison Metrics", save_name="5_comparison_metrics.png")

Saving figure to: analysis_results/mnist/1_archetypes_linear_aa.png
Saving figure to: analysis_results/mnist/1_archetypes_aanet.png
Saving figure to: analysis_results/mnist/2_umap_assignment_linear_aa.png
Saving figure to: analysis_results/mnist/2_umap_assignment_aanet.png
Saving figure to: analysis_results/mnist/3_reconstruction_linear_aa.png
Saving figure to: analysis_results/mnist/3_reconstruction_aanet.png
Saving figure to: analysis_results/mnist/4_simplex_linear_aa.png
Saving figure to: analysis_results/mnist/4_simplex_aanet.png
Saving figure to: analysis_results/mnist/5_comparison_metrics.png


In [3]:
import torch

# Forward slashes work on every OS; a backslash would start an escape sequence.
Middata_PATH = r"Midaa/midaa_core_matrices.pth"

# map_location='cpu' lets matrices saved from GPU tensors load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
data = torch.load(Middata_PATH, map_location='cpu')

A = data["A"]
B = data["B"]
C = data["C"]

# NOTE(review): when this cell was last run it printed (2638, 3) (3, 2638) (3, 2).
# That does not match the MNIST digit-4 setup above (N_SAMPLES = 5842 samples,
# 28x28 images), so these MIDAA matrices seem to come from a different
# dataset/run. Confirm before comparing them with AAnet / Linear AA.
print(A.shape, B.shape, C.shape)




(2638, 3) (3, 2638) (3, 2)
