In [1]:
# %% [markdown]
# # Archetype Analysis on MNIST
# Comparing archetypes from AAnet, MIDA, and Linear AA.

# %% 
import os
import numpy as np
import torch
from scipy.spatial.distance import pdist, squareform
from scipy.stats import entropy
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from torchvision import datasets, transforms
import umap

# --- Configuration ---
# Problem sizes. N_ARCHETYPES and N_PROTOTYPES are the defaults used by
# get_aanet_matrices; N_SAMPLES caps how many digit images are selected.
N_ARCHETYPES, N_PROTOTYPES = 3, 20
N_SAMPLES = 5842
IMG_SHAPE = (28, 28)  # the plotting helpers default to this same shape

# Paths (all relative to the notebook's working directory)
AANET_PATH = "AAnet/example_notebooks/results/AAnet_MNIST_digit4_results.npz"
LINEAR_AA_PATH = "LinearAA/Python/mnist_gaussian_aa_results.pth"
MIDDATA_PATH = "Midaa/midaa_core_matrices.pth"

# Figures are written below this directory; create it up front.
SAVE_DIR = "analysis_results/mnist"
os.makedirs(SAVE_DIR, exist_ok=True)


  Referenced from: <253997FD-685F-34A9-B3D7-4AF6DAE96CDF> /opt/anaconda3/envs/sae_env/lib/python3.11/site-packages/torchvision/image.so
  warn(
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Helper Functions ---

def to_numpy(tensor):
    """Return a NumPy array for a torch.Tensor input; return anything else unchanged."""
    if not isinstance(tensor, torch.Tensor):
        return tensor
    # Drop the autograd graph and move to host memory before converting.
    return tensor.detach().cpu().numpy()

def preprocess(X):
    """
    Center X along axis 0 and compute its mean total sum of squares.

    Args:
        X: array whose axis 0 is averaged over (rows are treated as samples).

    Returns:
        (X_centered, mSST): X minus its per-column mean, and the sum over
        columns of the mean squared centered value.
    """
    column_means = np.mean(X, axis=0)
    X_centered = X - column_means
    per_column_ms = np.mean(X_centered**2, axis=0)
    return X_centered, np.sum(per_column_ms)

def calcMI(z1, z2):
    """
    Mutual information between two coefficient matrices.

    The product ``z1 @ z2.T`` is normalised to sum to one and treated as a
    joint distribution; its row and column sums are the marginals.

    Args:
        z1, z2: 2-D arrays with the same number of columns.

    Returns:
        Scalar sum of ``joint * log(joint / (marginal_rows x marginal_cols))``,
        with a small epsilon guarding the division and the logarithm.
    """
    eps = 1e-16
    co_occurrence = z1 @ z2.T
    joint = co_occurrence / co_occurrence.sum()
    marginal_rows = joint.sum(1)
    marginal_cols = joint.sum(0)
    independent = np.outer(marginal_rows, marginal_cols)
    log_ratio = np.log(eps + joint / (eps + independent))
    return np.sum(joint * log_ratio)

def calcNMI(z1, z2):
    """
    Normalised mutual information between two coefficient matrices.

    Returns twice the cross mutual information divided by the sum of the two
    self mutual informations, all computed with ``calcMI``.
    """
    cross_mi = calcMI(z1, z2)
    self_mi_total = calcMI(z1, z1) + calcMI(z2, z2)
    return 2 * cross_mi / self_mi_total

def ArchetypeConsistency(XC1, XC2, mSST):
    """
    Compare two archetype matrices column by column.

    Args:
        XC1, XC2: arrays with one archetype per column and the same number of rows.
        mSST: scalar used to normalise the mean matched squared distance.

    Returns:
        (consistency, ISI):
        consistency is ``1 - mean(matched squared distances) / mSST``, where
        XC1 and XC2 columns are paired greedily, closest pair first, each
        column used at most once.
        ISI is computed from the absolute cross-correlations between XC1 and
        XC2 columns, each scaled by its row maximum and by its column maximum.
    """
    n_arch = XC1.shape[1]
    # One archetype per row: XC1's archetypes first, then XC2's.
    stacked = np.hstack((XC1, XC2)).T

    # Squared Euclidean distance from every XC1 archetype (rows) to every
    # XC2 archetype (columns).
    sq_dist = (squareform(pdist(stacked, 'euclidean')) ** 2)[:n_arch, n_arch:]

    # Greedy matching: take the globally closest remaining pair, then
    # retire its row and column by setting them to infinity.
    remaining = sq_dist.copy()
    matched = []
    for _ in range(n_arch):
        row, col = np.unravel_index(np.argmin(remaining, axis=None), remaining.shape)
        matched.append(sq_dist[row, col])
        remaining[row, :] = np.inf
        remaining[:, col] = np.inf

    consistency = 1 - np.mean(matched) / mSST

    abs_corr = np.abs(np.corrcoef(stacked))[:n_arch, n_arch:]
    row_scaled = abs_corr / np.max(abs_corr, axis=1, keepdims=True)
    col_scaled = abs_corr / np.max(abs_corr, axis=0, keepdims=True)
    ISI = 1 / (2 * n_arch * (n_arch - 1)) * \
          (np.sum(row_scaled + col_scaled) - 2 * n_arch)
    return consistency, ISI


In [3]:
def get_aanet_matrices(npz_path, X_raw, n_archetypes=N_ARCHETYPES, n_prototypes=N_PROTOTYPES):
    """
    Build the coefficient and archetype matrices for AAnet from saved results.

    Args:
        npz_path: path to an ``.npz`` file containing a ``latent_coords`` array.
        X_raw: data matrix indexed by sample along axis 0, (n_samples, n_features).
        n_archetypes: number of archetypes (rows of S) to build columns of C for.
        n_prototypes: number of top-scoring samples averaged per archetype.

    Returns:
        (S, C, X_raw):
        S is ``latent_coords`` transposed (rows are indexed by archetype below,
        so presumably (n_archetypes, n_samples) -- confirm against the file).
        C has shape (n_features, n_archetypes); column i is the mean of the
        ``n_prototypes`` samples with the largest S[i, :].
        X_raw is returned unchanged.
    """
    # Use the NpzFile as a context manager so the underlying file handle is
    # closed; the array is fully read on item access.
    with np.load(npz_path) as data:
        S = data['latent_coords'].T

    C = np.zeros((X_raw.shape[1], n_archetypes))
    for i in range(n_archetypes):
        # Sample indices sorted by descending coefficient for archetype i.
        top_idx = np.argsort(S[i, :])[::-1][:n_prototypes]
        C[:, i] = X_raw[top_idx].mean(axis=0)

    return S, C, X_raw

def get_linear_matrices(pth_path, X_tensor):
    """
    Load the Linear AA checkpoint and derive its archetype matrix.

    Args:
        pth_path: path to a torch checkpoint holding entries ``'S'`` and ``'C'``.
        X_tensor: data matrix (tensor or array) multiplied on the right by the
            checkpoint's ``'C'`` entry; the caller in this notebook passes it
            as (n_features, n_samples).

    Returns:
        (S, C, X.T): the checkpoint's ``'S'`` as NumPy, the archetypes
        ``C = X @ checkpoint['C']``, and the transposed data matrix
        (used for the mSST computation).
    """
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; every tensor is converted to NumPy on the host right after.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source (consider weights_only=True if the checkpoint
    # contains nothing but tensors).
    checkpoint = torch.load(pth_path, map_location="cpu")
    S = to_numpy(checkpoint['S'])
    A = to_numpy(checkpoint['C'])
    X = to_numpy(X_tensor)
    C = X @ A
    return S, C, X.T  # X.T for mSST computation


In [17]:
import matplotlib.pyplot as plt
import numpy as np

def to_numpy(tensor):
    """Convert a torch.Tensor to a NumPy array; pass any other object through untouched."""
    # NOTE(review): this re-defines (and shadows) the identical helper from
    # the "Helper Functions" cell earlier in the notebook.
    is_torch_tensor = isinstance(tensor, torch.Tensor)
    return tensor.detach().cpu().numpy() if is_torch_tensor else tensor

def plot_archetypes(C, img_shape=(28, 28), title="Archetypes", cmap='gray_r', save_name=None):
    """
    Plots the archetypes (columns of C) as images in a single row.

    Args:
        C: np.array or torch.Tensor, shape (n_features, n_archetypes)
        img_shape: tuple each column is reshaped to, e.g., (28, 28)
        title: figure title
        cmap: colormap for images
        save_name: filename to save the figure (optional). When given, the
            figure is saved and closed; otherwise it is shown.
    """
    C = to_numpy(C)
    n_archetypes = C.shape[1]

    # squeeze=False keeps `axes` 2-D even for a single archetype, so no
    # special case is needed.
    fig, axes = plt.subplots(1, n_archetypes, figsize=(2 * n_archetypes, 3), squeeze=False)

    for i, ax in enumerate(axes[0]):
        img = C[:, i].reshape(img_shape)
        # Fixed [0, 1] intensity range so all panels share one scale.
        ax.imshow(img, cmap=cmap, vmin=0, vmax=1)
        ax.axis('off')
        ax.set_title(f'Arc {i+1}')

    fig.suptitle(title, fontsize=14)
    # Leave headroom at the top for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.85])

    if save_name is not None:
        out_dir = os.path.dirname(save_name)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_name, bbox_inches='tight')
        print(f"Saved figure to: {save_name}")
        plt.close(fig)
    else:
        plt.show()

import matplotlib.pyplot as plt
import numpy as np

def to_numpy(tensor):
    """Give back a NumPy view of a torch.Tensor's data; leave non-tensors as they are."""
    # NOTE(review): the same helper is already defined earlier in this
    # notebook; this copy silently replaces it.
    if isinstance(tensor, torch.Tensor):
        host_tensor = tensor.detach().cpu()
        return host_tensor.numpy()
    else:
        return tensor

def plot_umap_assignment(X_umap, S, title="Sample Assignment", s=5, alpha=0.6, save_name=None):
    """
    Plots UMAP embedding colored by dominant archetype assignment.

    Args:
        X_umap: np.array, shape (N_samples, 2) -> UMAP 2D embedding
        S: np.array or torch.Tensor, shape (k, N_samples) -> Archetype coefficients
        title: figure title
        s: marker size
        alpha: marker transparency
        save_name: optional filename to save figure. When given, the figure
            is saved and closed; otherwise it is shown.
    """
    S = to_numpy(S)
    # Index of the largest coefficient for every sample.
    dominant_arc = np.argmax(S.T, axis=1)
    n_arc = S.shape[0]

    fig, ax = plt.subplots(figsize=(8, 8))
    scatter = ax.scatter(
        X_umap[:, 0],
        X_umap[:, 1],
        c=dominant_arc,
        cmap='viridis',
        # Pin the colour scale to the full label range. Without this the
        # scale autoscales to the labels actually present, so the colorbar
        # labels are wrong whenever some archetype is never dominant.
        vmin=0,
        vmax=max(n_arc - 1, 1),
        s=s,
        alpha=alpha
    )

    ax.set_title(title, fontsize=14)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')

    # One discrete colorbar segment per archetype, centred on its integer label.
    cbar = plt.colorbar(scatter, ax=ax, ticks=np.arange(n_arc), boundaries=np.arange(n_arc + 1) - 0.5)
    cbar.set_ticklabels([f'Arc {i+1}' for i in range(n_arc)])

    if save_name is not None:
        out_dir = os.path.dirname(save_name)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_name, bbox_inches='tight')
        print(f"Saved figure to: {save_name}")
        plt.close(fig)
    else:
        plt.show()
        
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from scipy.stats import entropy
import numpy as np
import os

def plot_simplex(S, C, img_shape=(28,28), title="Simplex Analysis", cmap='gray_r', save_name=None):
    """
    Plots samples in a Barycentric (triangle) projection with archetypes at vertices.
    Only works for k=3; for any other k a message is printed and nothing is drawn.

    The left panel shows the samples inside the triangle, coloured by mixing
    strength (entropy of each sample's coefficients divided by log(k), with
    non-finite values replaced by 0). The right panel is a histogram of the
    same mixing strength.

    Args:
        S: (k, N_samples) coefficient matrix
        C: (n_features, k) archetype matrix
        img_shape: tuple for reshaping archetype images
        title: figure title
        cmap: colormap for images
        save_name: optional path to save figure. When given, the figure is
            saved and closed; otherwise it is shown.
    """
    S = to_numpy(S)
    C = to_numpy(C)
    k, _n_samples = S.shape
    if k != 3:
        print(f"⚠️ Simplex plot requires k=3, got k={k}. Skipping plot.")
        return

    # Convert to 2D simplex coordinates: archetype 0 maps to (0, 0),
    # archetype 1 to (1, 0) and archetype 2 to (0.5, sqrt(3)/2).
    S_cells = S.T
    x = S_cells[:, 1] + 0.5 * S_cells[:, 2]
    y = (np.sqrt(3)/2) * S_cells[:, 2]
    X_simplex = np.column_stack([x, y])

    # Vertex positions (same order as the archetypes)
    vertices = np.array([[0,0], [1,0], [0.5, np.sqrt(3)/2]])

    # Mixing strength (entropy), computed for all samples in one call.
    mix_strength = entropy(S_cells, axis=1)
    mix_strength = np.where(np.isfinite(mix_strength), mix_strength, 0)
    mix_strength /= np.log(k)  # log(k) is the maximum entropy over k outcomes

    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 7))
    fig.suptitle(title, fontsize=16, fontweight='bold')

    # Left: simplex scatter
    scatter = ax1.scatter(X_simplex[:,0], X_simplex[:,1], c=mix_strength, cmap='RdYlGn_r', s=10, alpha=0.6, edgecolors='none', vmin=0, vmax=1)
    triangle = plt.Polygon(vertices, fill=False, edgecolor='black', linewidth=2, zorder=0)
    ax1.add_patch(triangle)

    # Add archetype images at vertices
    for i in range(k):
        img = C[:,i].reshape(img_shape)
        im_box = OffsetImage(img, zoom=2.0, cmap=cmap)
        ab = AnnotationBbox(im_box, vertices[i], frameon=True, bboxprops=dict(edgecolor='red', linewidth=2))
        ax1.add_artist(ab)

    # Limits extend past the triangle so the vertex images are not clipped.
    ax1.set_xlim(-0.3, 1.3)
    ax1.set_ylim(-0.3, 1.1)
    ax1.axis('off')
    ax1.set_aspect('equal')
    cbar = plt.colorbar(scatter, ax=ax1, orientation='vertical', fraction=0.03, pad=0.04)
    cbar.set_label('Mixing Strength (Normalized Entropy)')

    # Right: histogram of mixing strength
    ax2.hist(mix_strength, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
    mean_mix = mix_strength.mean()
    ax2.axvline(mean_mix, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_mix:.2f}')
    ax2.set_title('Distribution of Mixing Strength')
    ax2.set_xlabel('Mixing Strength (0=Pure, 1=Mixed)')
    ax2.set_ylabel('Count')
    ax2.legend()
    ax2.grid(alpha=0.3)

    fig.tight_layout(rect=[0,0,1,0.95])

    if save_name is not None:
        out_dir = os.path.dirname(save_name)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_name, bbox_inches='tight')
        print(f"Saved figure to: {save_name}")
        plt.close(fig)
    else:
        plt.show()

def plot_reconstruction(X, C, S, n_samples=5, img_shape=(28,28), title="Reconstruction Quality Check", save_name=None):
    """
    Visualizes random samples side-by-side with their reconstructions.

    The top row shows original samples, the bottom row shows ``(C @ S).T``
    for the same sample indices, clipped to [0, 1]. Samples are drawn
    without replacement from NumPy's global random state.

    Args:
        X: (N_samples, n_features) original data
        C: (n_features, k) archetype matrix
        S: (k, N_samples) coefficient matrix
        n_samples: number of random samples to show (capped at N_samples)
        img_shape: reshape for visualization
        title: figure title
        save_name: optional path to save figure. When given, the figure is
            saved and closed; otherwise it is shown.
    """
    X = to_numpy(X)
    C = to_numpy(C)
    S = to_numpy(S)

    X_rec = (C @ S).T
    # Sampling without replacement raises if asked for more than is available.
    n_show = min(n_samples, X.shape[0])
    indices = np.random.choice(X.shape[0], n_show, replace=False)

    # squeeze=False keeps `axes` 2-D so a single column still indexes as axes[row, col].
    fig, axes = plt.subplots(2, n_show, figsize=(2*n_show, 4), squeeze=False)

    for i, idx in enumerate(indices):
        # Both rows use the same fixed [0, 1] scale so intensities are
        # directly comparable between original and reconstruction.
        axes[0,i].imshow(X[idx].reshape(img_shape), cmap='gray_r', vmin=0, vmax=1)
        axes[0,i].axis('off')
        if i==0: axes[0,i].set_title("Original", x=-0.5, ha='right')

        rec = np.clip(X_rec[idx].reshape(img_shape), 0, 1)
        axes[1,i].imshow(rec, cmap='gray_r', vmin=0, vmax=1)
        axes[1,i].axis('off')
        if i==0: axes[1,i].set_title("Reconstructed", x=-0.5, ha='right')

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    if save_name is not None:
        out_dir = os.path.dirname(save_name)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_name, bbox_inches='tight')
        print(f"Saved figure to: {save_name}")
        plt.close(fig)
    else:
        plt.show()




In [18]:
# Load the MNIST training split (downloaded to ./data if missing).
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())

# Keep at most N_SAMPLES images whose label is 4.
mask = mnist.targets.eq(4)
indices = mask.nonzero(as_tuple=True)[0][:N_SAMPLES]

# X_raw: (n_selected, 784) NumPy array scaled by 1/255.
# X_tensor: the same data transposed to (784, n_selected), in double precision.
X_raw = (mnist.data[indices].float() / 255).flatten(1).numpy()
X_tensor = torch.from_numpy(X_raw).t().double()


In [19]:
# Coefficient (S) and archetype (C) matrices for each method:
# index 1 = AAnet, index 2 = Linear AA.
S1, C1, X1 = get_aanet_matrices(npz_path=AANET_PATH, X_raw=X_raw)
S2, C2, X2 = get_linear_matrices(pth_path=LINEAR_AA_PATH, X_tensor=X_tensor)

# Centre the data; mSST_val normalises the consistency score below.
X_centered, mSST_val = preprocess(X1)

# Agreement between the two methods: coefficients (NMI) and archetypes
# (consistency, ISI).
nmi_score = calcNMI(S1, S2)
consistency, isi = ArchetypeConsistency(C1, C2, mSST_val)

print(f"NMI: {nmi_score:.4f}, Consistency: {consistency:.4f}, ISI: {isi:.4f}")


NMI: 0.6580, Consistency: 0.1512, ISI: 0.6402


  checkpoint = torch.load(pth_path)


In [None]:
# (title prefix, filename prefix, coefficient matrix S, archetype matrix C)
# for each method; every figure type below is drawn for AAnet first, then
# Linear AA.
_methods = [
    ("AAnet", "aanet", S1, C1),
    ("Linear AA", "linear_aa", S2, C2),
]

# --- Archetypes ---
for _label, _prefix, _S, _C in _methods:
    plot_archetypes(
        _C, img_shape=(28, 28),
        title=f"{_label} Archetypes",
        save_name=os.path.join(SAVE_DIR, f"{_prefix}_archetypes.png")
    )

# --- Sample assignment (UMAP) ---
# One shared embedding of the centred data, so both methods are coloured on
# the same 2-D layout.
X_umap = umap.UMAP(n_components=2, random_state=42).fit_transform(X_centered)

for _label, _prefix, _S, _C in _methods:
    plot_umap_assignment(
        X_umap, _S,
        title=f"{_label} Sample Assignment",
        save_name=os.path.join(SAVE_DIR, f"{_prefix}_umap.png")
    )

# --- Reconstruction quality ---
for _label, _prefix, _S, _C in _methods:
    plot_reconstruction(
        X_raw, _C, _S,
        title=f"{_label} Reconstruction",
        save_name=os.path.join(SAVE_DIR, f"{_prefix}_reconstruction.png")
    )

# --- Simplex visualization (k=3 only) ---
for _label, _prefix, _S, _C in _methods:
    plot_simplex(
        _S, _C,
        title=f"{_label} Simplex",
        save_name=os.path.join(SAVE_DIR, f"{_prefix}_simplex.png")
    )

# --- Metrics comparison ---
# NOTE(review): plot_metric_scores is not defined in any cell visible in this
# notebook export -- confirm it exists, otherwise this call raises NameError
# on a fresh kernel.
metrics = {'NMI': nmi_score, 'Consistency': consistency, 'ISI': isi}
plot_metric_scores(
    metrics,
    title="AAnet vs Linear AA Comparison Metrics",
    save_name=os.path.join(SAVE_DIR, "comparison_metrics.png")
)


Saved figure to: analysis_results/mnist/aanet_archetypes.png
Saved figure to: analysis_results/mnist/linear_aa_archetypes.png


  warn(
