# SAE implementations

In [161]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """Baseline top-k sparse autoencoder.

    A linear encoder followed by ReLU produces the hidden code; only the
    ``k_top`` largest entries of each sample survive, and a linear decoder
    maps the sparse code back to input space.

    Args:
        input_size: Dimension of a flattened input sample.
        hidden_size: Number of hidden features.
        k_top: Number of hidden features kept per sample.
    """

    def __init__(self, input_size=784, hidden_size=64, k_top=20):
        super().__init__()
        self.training = True
        self.name = "Default Sparse Autoencoder"
        self.input_size, self.hidden_size, self.k_top = input_size, hidden_size, k_top

        # The encoder is created before the decoder, so it draws its initial
        # weights from the RNG first.
        self.encoder = nn.Linear(input_size, hidden_size)  # input -> hidden
        self.decoder = nn.Linear(hidden_size, input_size)  # hidden -> input

    def _topk_mask(self, activations: torch.Tensor) -> torch.Tensor:
        """Return a 0/1 mask shaped like ``activations`` (batch, hidden) that
        marks the ``k_top`` largest entries of every row."""
        # k can never exceed the hidden width and is floored at zero.
        n_keep = min(self.k_top, activations.size(1))
        if n_keep < 0:
            n_keep = 0
        top_idx = torch.topk(activations, n_keep, dim=1).indices
        return torch.zeros_like(activations).scatter_(1, top_idx, 1.0)

    def forward(self, x):
        """Encode ``x`` and decode the top-k sparse code.

        Returns:
            Tuple ``(h, x_hat)``: the sparse hidden code and the reconstruction.
        """
        rectified = F.relu(self.encoder(x))
        h = rectified * self._topk_mask(rectified)
        return h, self.decoder(h)

    def compute_loss(self, x, h, x_hat):
        """Squared reconstruction error averaged over batch and input dimension.

        ``h`` is accepted for interface compatibility and is not used here.
        """
        n_elements = x.size(0) * self.input_size
        return (x - x_hat).pow(2).sum() / n_elements

In [None]:

class SparseAutoencoderInit(SparseAutoencoder):
    """Baseline SAE whose decoder starts out as the transpose of the encoder.

    The encoder weight is re-drawn with Kaiming-uniform initialization and the
    decoder weight is then overwritten with its transpose; biases keep the
    values assigned by the base class.
    """

    def __init__(self, input_size=784, hidden_size=64, k_top=20):
        super().__init__(input_size, hidden_size, k_top)
        self.name = "Sparse Autoencoder with just weight initialization"

        enc_weight = self.encoder.weight  # (hidden_size, input_size)
        # Fresh random directions for the encoder rows.
        nn.init.kaiming_uniform_(enc_weight, a=math.sqrt(5))
        # The decoder weight is (input_size, hidden_size), i.e. exactly the
        # transposed encoder shape; copy outside of autograd.
        with torch.no_grad():
            self.decoder.weight.copy_(enc_weight.t())


In [None]:

class SparseAutoencoderJumpReLU(SparseAutoencoder):
    """Top-k SAE that replaces ReLU with a JumpReLU threshold.

    The top-k selection runs on the raw (signed) encoder output; afterwards
    every kept value that does not exceed ``jump_value`` is set to zero.
    """

    def __init__(self, input_size=784, hidden_size=64, k_top=20, jump_value=0.1):
        super().__init__(input_size, hidden_size, k_top)
        self.jump_value = jump_value
        self.name = "Sparse Autoencoder with Jump ReLU"

    def forward(self, x: torch.Tensor):
        """Return ``(h, x_hat)`` where ``h`` is the thresholded top-k code."""
        pre_acts = self.encoder(x)
        kept = pre_acts * self._topk_mask(pre_acts)
        # JumpReLU: pass values strictly above the threshold, zero the rest.
        above_jump = kept > self.jump_value
        h = torch.where(above_jump, kept, torch.zeros_like(kept))
        return h, self.decoder(h)

In [None]:

class SparseAutoencoderInitJumpReLU(SparseAutoencoder):
    """Top-k SAE combining transposed-decoder initialization with JumpReLU.

    At construction the encoder weight is re-drawn (Kaiming uniform) and the
    decoder weight is set to its transpose. In the forward pass, top-k is
    applied to the raw encoder output and kept values that do not exceed
    ``jump_value`` are zeroed.
    """

    def __init__(self, input_size=784, hidden_size=64, k_top=20, jump_value=0.1):
        super().__init__(input_size, hidden_size, k_top)
        self.jump_value = jump_value
        self.name = "Sparse Autoencoder with Initialization and Jump ReLU"

        enc_weight = self.encoder.weight  # (hidden_size, input_size)
        # Fresh random directions for the encoder rows.
        nn.init.kaiming_uniform_(enc_weight, a=math.sqrt(5))
        # Start the decoder as the encoder's transpose, outside of autograd.
        with torch.no_grad():
            self.decoder.weight.copy_(enc_weight.t())

    def forward(self, x: torch.Tensor):
        """Return ``(h, x_hat)`` where ``h`` is the thresholded top-k code."""
        pre_acts = self.encoder(x)
        kept = pre_acts * self._topk_mask(pre_acts)
        # JumpReLU: pass values strictly above the threshold, zero the rest.
        above_jump = kept > self.jump_value
        h = torch.where(above_jump, kept, torch.zeros_like(kept))
        return h, self.decoder(h)

SAE with an auxiliary loss for dead features

In [136]:

class SparseAutoencoderAuxLoss(SparseAutoencoder):
    """Top-k SAE with an auxiliary loss aimed at dead features.

    A feature counts as dead once it has been inactive for more than
    ``dead_feature_threshold`` consecutive training forward passes. During
    training, the auxiliary loss asks the ``k_aux`` strongest dead features of
    each sample to reconstruct the residual ``x - x_hat``.

    Args:
        input_size: Dimension of a flattened input sample.
        hidden_size: Number of hidden features.
        k_top: Number of features kept per sample in the main forward pass.
        k_aux: Number of dead features used by the auxiliary loss; ``None``
            selects ``2 * k_top``.
        k_aux_param: Weight of the auxiliary loss in the total loss.
        dead_feature_threshold: Inactive-step count above which a feature is
            treated as dead.
    """

    def __init__(self, input_size, hidden_size, k_top, k_aux, k_aux_param, dead_feature_threshold):
        super(SparseAutoencoderAuxLoss, self).__init__(input_size, hidden_size, k_top)
        self.name = "Sparse Autoencoder with Auxiliary Loss"
        # k_aux is typically 2*k or more to revive dead features
        self.k_aux = k_aux if k_aux is not None else 2 * k_top
        self.k_aux_param = k_aux_param
        # Per-feature count of training forward passes since the feature was
        # last active; registered as a buffer so it follows the module's device.
        self.register_buffer('steps_since_active', torch.zeros(hidden_size))
        self.dead_feature_threshold = dead_feature_threshold

    def _update_dead_features(self, h: torch.Tensor):
        """Advance the inactivity counters using the sparse code ``h`` (batch, hidden)."""
        # A feature is active if ANY sample in the batch activates it.
        active_mask = (h.abs() > 1e-8).any(dim=0)

        # Increment every counter, then reset the ones that fired.
        self.steps_since_active += 1
        self.steps_since_active[active_mask] = 0

    def _get_dead_feature_mask(self) -> torch.Tensor:
        """Return boolean mask of dead features"""
        return self.steps_since_active > self.dead_feature_threshold

    def forward(self, x: torch.Tensor):
        """Return ``(h, x_hat)``; in training mode also updates the dead-feature counters.

        Top-k is applied to the raw (signed) encoder output; no ReLU is used.
        """
        h_raw = self.encoder(x)
        mask = self._topk_mask(h_raw)
        h = h_raw * mask
        x_hat = self.decoder(h)

        # Track dead features during training
        if self.training:
            self._update_dead_features(h)

        return h, x_hat

    def compute_loss(self, x, h, x_hat):
        """Compute the reconstruction loss plus the weighted auxiliary loss.

        Returns:
            Tuple ``(total_loss, recon_loss, aux_loss)``. ``aux_loss`` is zero
            outside training mode or when no feature is currently dead.
        """
        # Squared error summed over the whole batch, divided by the input
        # dimension only (not by the batch size).
        recon_error = torch.sum((x - x_hat) ** 2)
        recon_loss = recon_error / self.input_size

        aux_loss = torch.tensor(0.0, device=x.device)

        if self.training:
            dead_mask = self._get_dead_feature_mask()  # (hidden_size,)
            n_dead = int(dead_mask.sum().item())

            if n_dead > 0:
                # Residual the dead features are asked to explain: e = x - x_hat
                recon_error_vec = x - x_hat  # (batch, input_size)

                # Recompute the pre-activations WITH gradient tracking: the
                # auxiliary loss must reach the encoder weights of dead
                # features, otherwise it cannot change when they fire.
                h_raw = self.encoder(x)

                # Live features are pushed to -inf so the top-k can only pick
                # dead ones, even when dead pre-activations are negative.
                # k_aux_features <= n_dead guarantees enough finite entries.
                dead_scores = h_raw.detach().masked_fill(~dead_mask.unsqueeze(0), float('-inf'))
                k_aux_features = min(self.k_aux, n_dead)
                _, idx_aux = torch.topk(dead_scores, k_aux_features, dim=1)
                mask_aux = torch.zeros_like(h_raw).scatter_(1, idx_aux, 1.0)

                # Sparse code built from the selected dead features only
                z_aux = h_raw * mask_aux  # (batch, hidden_size)

                # Reconstruct the residual from the dead features
                e_hat = self.decoder(z_aux)  # (batch, input_size)

                # Auxiliary loss: ||e - e_hat||^2, scaled like recon_loss
                aux_loss = torch.sum((recon_error_vec - e_hat) ** 2) / self.input_size

        total_loss = recon_loss + self.k_aux_param * aux_loss

        return total_loss, recon_loss, aux_loss

Complete SAE: combines JumpReLU, weight initialization, and the auxiliary loss.

In [137]:

class SparseAutoencoderComplete(SparseAutoencoder):
    """Top-k SAE with transposed-decoder init, JumpReLU and an auxiliary loss.

    A feature counts as dead once it has been inactive for more than
    ``dead_feature_threshold`` consecutive training forward passes. During
    training, the auxiliary loss asks the ``k_aux`` strongest dead features of
    each sample to reconstruct the residual ``x - x_hat``.

    Args:
        input_size: Dimension of a flattened input sample.
        hidden_size: Number of hidden features.
        k_top: Number of features kept per sample in the main forward pass.
        k_aux: Number of dead features used by the auxiliary loss; ``None``
            selects ``2 * k_top``.
        k_aux_param: Weight of the auxiliary loss in the total loss.
        dead_feature_threshold: Inactive-step count above which a feature is
            treated as dead.
        jump_value: JumpReLU threshold; kept values not exceeding it are zeroed.
    """

    def __init__(self, input_size, hidden_size, k_top, k_aux, k_aux_param, dead_feature_threshold, jump_value):
        super(SparseAutoencoderComplete, self).__init__(input_size, hidden_size, k_top)
        self.name = "Sparse Autoencoder with weight init., JumpReLU and Auxiliary Loss"
        self.jump_value = jump_value

        # k_aux is typically 2*k or more to revive dead features
        self.k_aux = k_aux if k_aux is not None else 2 * k_top
        self.k_aux_param = k_aux_param
        # Per-feature count of training forward passes since the feature was
        # last active; registered as a buffer so it follows the module's device.
        self.register_buffer('steps_since_active', torch.zeros(hidden_size))
        self.dead_feature_threshold = dead_feature_threshold

        # Re-draw the encoder weight with fresh random directions
        nn.init.kaiming_uniform_(self.encoder.weight, a=math.sqrt(5))
        # Start the decoder as the transpose of the encoder weight
        with torch.no_grad():
            self.decoder.weight.copy_(self.encoder.weight.t())

    def _update_dead_features(self, h: torch.Tensor):
        """Advance the inactivity counters using the sparse code ``h`` (batch, hidden)."""
        # A feature is active if ANY sample in the batch activates it.
        active_mask = (h.abs() > 1e-8).any(dim=0)

        # Increment every counter, then reset the ones that fired.
        self.steps_since_active += 1
        self.steps_since_active[active_mask] = 0

    def _get_dead_feature_mask(self) -> torch.Tensor:
        """Return boolean mask of dead features"""
        return self.steps_since_active > self.dead_feature_threshold

    def forward(self, x: torch.Tensor):
        """Return ``(h, x_hat)``; in training mode also updates the dead-feature counters.

        Top-k is applied to the raw (signed) encoder output, then JumpReLU
        zeroes every kept value that does not exceed ``jump_value``.
        """
        h_raw = self.encoder(x)
        mask = self._topk_mask(h_raw)
        h = h_raw * mask
        # Apply JumpReLU
        h = torch.where(h > self.jump_value, h, torch.zeros_like(h))
        x_hat = self.decoder(h)

        # Track dead features during training
        if self.training:
            self._update_dead_features(h)

        return h, x_hat

    def compute_loss(self, x, h, x_hat):
        """Compute the reconstruction loss plus the weighted auxiliary loss.

        Returns:
            Tuple ``(total_loss, recon_loss, aux_loss)``. ``aux_loss`` is zero
            outside training mode or when no feature is currently dead.
        """
        # Squared error summed over the whole batch, divided by the input
        # dimension only (not by the batch size).
        recon_error = torch.sum((x - x_hat) ** 2)
        recon_loss = recon_error / self.input_size

        aux_loss = torch.tensor(0.0, device=x.device)

        if self.training:
            dead_mask = self._get_dead_feature_mask()  # (hidden_size,)
            n_dead = int(dead_mask.sum().item())

            if n_dead > 0:
                # Residual the dead features are asked to explain: e = x - x_hat
                recon_error_vec = x - x_hat  # (batch, input_size)

                # Recompute the pre-activations WITH gradient tracking: the
                # auxiliary loss must reach the encoder weights of dead
                # features, otherwise it cannot change when they fire.
                h_raw = self.encoder(x)

                # Live features are pushed to -inf so the top-k can only pick
                # dead ones, even when dead pre-activations are negative.
                # k_aux_features <= n_dead guarantees enough finite entries.
                dead_scores = h_raw.detach().masked_fill(~dead_mask.unsqueeze(0), float('-inf'))
                k_aux_features = min(self.k_aux, n_dead)
                _, idx_aux = torch.topk(dead_scores, k_aux_features, dim=1)
                mask_aux = torch.zeros_like(h_raw).scatter_(1, idx_aux, 1.0)

                # Sparse code built from the selected dead features only
                z_aux = h_raw * mask_aux  # (batch, hidden_size)

                # Reconstruct the residual from the dead features
                e_hat = self.decoder(z_aux)  # (batch, input_size)

                # Auxiliary loss: ||e - e_hat||^2, scaled like recon_loss
                aux_loss = torch.sum((recon_error_vec - e_hat) ** 2) / self.input_size

        total_loss = recon_loss + self.k_aux_param * aux_loss

        return total_loss, recon_loss, aux_loss

# Data Loading and Preprocessing

In [138]:
from torch import optim
import torchvision
from torch.utils.data import TensorDataset, Subset
from sklearn.datasets import fetch_olivetti_faces
import torchvision.transforms as transforms

def load_mnist_data(batch_size=256):
    """Build MNIST loaders of mean-centred, unit-norm, flattened images.

    The per-pixel mean is computed over the full training split (pixels scaled
    to [0, 1]); each sample is flattened to 784 values, has that mean
    subtracted and is divided by its L2 norm (plus 1e-8).

    Args:
        batch_size: Batch size of the returned loaders.

    Returns:
        ``(train_loader, test_loader, dataset_mean)`` where ``dataset_mean``
        has shape (784,). Only the training loader shuffles.
    """
    mnist_root = './data'

    # Pass 1: raw [0, 1] pixels, pulled in a single batch to get the mean.
    to_unit_range = transforms.Compose([
        transforms.ToTensor(),
    ])
    raw_train = torchvision.datasets.MNIST(root=mnist_root, train=True,
                                           download=True, transform=to_unit_range)
    whole_set_loader = DataLoader(raw_train, batch_size=len(raw_train), shuffle=False)
    images, _ = next(iter(whole_set_loader))
    dataset_mean = images.view(images.size(0), -1).mean(dim=0)  # (784,)

    def preprocess(x):
        # (1, 28, 28) -> (784,), centre on the training mean, scale to unit norm
        centered = x.view(-1) - dataset_mean
        return centered / (torch.norm(centered) + 1e-8)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(preprocess)
    ])

    # Pass 2: the same files, now with the preprocessing transform attached.
    loaders = []
    for is_train in (True, False):
        split = torchvision.datasets.MNIST(root=mnist_root, train=is_train,
                                           download=True, transform=transform)
        loaders.append(DataLoader(split, batch_size=batch_size, shuffle=is_train))
    train_loader, test_loader = loaders

    return train_loader, test_loader, dataset_mean

In [139]:

def load_olivetti_data(batch_size=32, train_split=0.8):
    """Build loaders for the Olivetti Faces dataset (400 images, 64x64 grayscale).

    Every flattened image (4096 values) has the dataset mean subtracted and is
    scaled to unit L2 norm. The mean is taken over all images, before the
    train/test split.

    Args:
        batch_size: Batch size of the returned loaders.
        train_split: Fraction of samples assigned to the training set.

    Returns:
        ``(train_loader, test_loader, dataset_mean)``; each batch is a
        one-element tuple ``(inputs,)`` and ``dataset_mean`` has shape (4096,).
    """
    # Download via sklearn; pixels already lie in [0, 1], shape (400, 4096)
    faces = fetch_olivetti_faces(shuffle=True, random_state=42)
    data_tensor = torch.FloatTensor(faces.data)
    dataset_mean = data_tensor.mean(dim=0)  # (4096,)

    # Centre and normalise one sample at a time.
    processed_rows = []
    for sample in data_tensor:
        centered = sample - dataset_mean
        processed_rows.append(centered / (torch.norm(centered) + 1e-8))

    # No labels are stored: an autoencoder only needs the inputs.
    dataset = TensorDataset(torch.stack(processed_rows))

    # Fixed-seed split so the partition is reproducible.
    n_train = int(train_split * len(dataset))
    n_test = len(dataset) - n_train
    split_rng = torch.Generator().manual_seed(42)
    trainset, testset = random_split(dataset, [n_train, n_test], generator=split_rng)

    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, dataset_mean

In [140]:
import os

def load_imagenet_subset(batch_size=128, subset_size=10000, img_size=64,
                         data_root='./data/imagenet'):
    """
    Load a subset of ImageNet with preprocessing

    Images are resized, centre-cropped to ``img_size``, converted to
    grayscale, flattened, mean-centred and scaled to unit L2 norm. The mean is
    computed over the (possibly subsampled) training images. If
    ``torchvision.datasets.ImageNet`` cannot be constructed, the function
    falls back to ``ImageFolder`` on ``<data_root>/train`` and ``<data_root>/val``.

    Args:
        batch_size: Batch size for DataLoader
        subset_size: Number of training images to use (None for full dataset)
        img_size: Resize images to (img_size, img_size)
        data_root: Path to ImageNet data directory

    Returns:
        train_loader, test_loader, dataset_mean
    """

    # Raw transform for computing mean
    raw_transform = transforms.Compose([
        transforms.Resize(img_size + 8),  # Slightly larger for center crop
        transforms.CenterCrop(img_size),
        transforms.Grayscale(),  # Convert to grayscale for consistency
        transforms.ToTensor(),  # Converts to [0,1]
    ])

    # Load training set
    try:
        trainset_raw = torchvision.datasets.ImageNet(
            root=data_root,
            split='train',
            transform=raw_transform
        )
    except Exception:
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not swallowed.
        # Alternative: Use ImageFolder if you have custom subset
        trainset_raw = torchvision.datasets.ImageFolder(
            root=os.path.join(data_root, 'train'),
            transform=raw_transform
        )

    # Draw the subset once; the same indices are reused further down so the
    # mean and the training loader cover the same images.
    # The permutation is not seeded here, so the subset differs between runs.
    subset_indices = None
    if subset_size is not None and subset_size < len(trainset_raw):
        subset_indices = torch.randperm(len(trainset_raw))[:subset_size].tolist()
        trainset_raw = Subset(trainset_raw, subset_indices)

    # Compute mean over subset (use smaller batch for memory efficiency)
    print("Computing dataset mean...")
    temp_loader = DataLoader(trainset_raw, batch_size=min(1000, len(trainset_raw)),
                            shuffle=False, num_workers=4)

    # Accumulate per-pixel sums batch by batch
    mean_accumulator = None
    count = 0
    for batch_data, _ in temp_loader:
        batch_flat = batch_data.view(batch_data.size(0), -1)  # (B, img_size²)
        if mean_accumulator is None:
            mean_accumulator = batch_flat.sum(dim=0)
        else:
            mean_accumulator += batch_flat.sum(dim=0)
        count += batch_data.size(0)

    dataset_mean = mean_accumulator / count  # Shape: (img_size²,)
    print(f"Mean computed over {count} images")

    # NOTE(review): `preprocess` is a local closure used with num_workers=4;
    # worker processes started with the spawn method must pickle it — confirm
    # this works on the target platform.
    def preprocess(x):
        x_flat = x.view(-1)  # Flatten
        x_centered = x_flat - dataset_mean  # Subtract mean
        x_norm = x_centered / (torch.norm(x_centered) + 1e-8)  # Unit norm
        return x_norm

    # Final transform with preprocessing
    transform = transforms.Compose([
        transforms.Resize(img_size + 8),
        transforms.CenterCrop(img_size),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Lambda(preprocess)
    ])

    # Load datasets with proper preprocessing
    try:
        trainset = torchvision.datasets.ImageNet(
            root=data_root,
            split='train',
            transform=transform
        )
        testset = torchvision.datasets.ImageNet(
            root=data_root,
            split='val',  # Use validation set as test
            transform=transform
        )
    except Exception:
        trainset = torchvision.datasets.ImageFolder(
            root=os.path.join(data_root, 'train'),
            transform=transform
        )
        testset = torchvision.datasets.ImageFolder(
            root=os.path.join(data_root, 'val'),
            transform=transform
        )

    # Apply the subset drawn above to the preprocessed training data
    if subset_indices is not None:
        trainset = Subset(trainset, subset_indices)

    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=4, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)

    return train_loader, test_loader, dataset_mean

In [141]:
from sklearn.datasets import fetch_lfw_people
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image

def load_lfw_data(batch_size=128, img_size=64, min_faces_per_person=20):
    """
    Load Labeled Faces in the Wild dataset with proper resizing

    Each image is resized to ``img_size`` x ``img_size``, flattened, has the
    dataset mean subtracted and is scaled to unit L2 norm. The mean is taken
    over all images, before the 80/20 train/test split.

    Args:
        batch_size: Batch size
        img_size: Resize to (img_size, img_size) - actual pixels
        min_faces_per_person: Filter people with fewer images

    Returns:
        ``(train_loader, test_loader, dataset_mean)``; each batch is a
        one-element tuple ``(inputs,)``.
    """
    # Download LFW with original size
    lfw_people = fetch_lfw_people(
        min_faces_per_person=min_faces_per_person,
        resize=1.0,  # Keep original size, we'll resize manually
        color=False
    )

    print(f"Original LFW shape: {lfw_people.images.shape}")

    # Manually resize to exact dimensions
    # NOTE(review): the `* 255` assumes the pixel values are in [0, 1] —
    # confirm against the installed scikit-learn version.
    resized_images = []
    for img in lfw_people.images:
        # Convert to PIL Image for proper resizing
        pil_img = Image.fromarray((img * 255).astype(np.uint8))
        # Resize to exact target size
        pil_img = pil_img.resize((img_size, img_size), Image.LANCZOS)
        # Back to normalized array
        resized = np.array(pil_img).astype(np.float32) / 255.0
        resized_images.append(resized.flatten())

    data_flat = np.array(resized_images)
    print(f"Resized LFW shape: {data_flat.shape}")  # Should be (n_samples, img_size²)

    # Convert to torch
    data_tensor = torch.FloatTensor(data_flat)

    # Compute mean
    dataset_mean = data_tensor.mean(dim=0)

    # Preprocess: centre on the dataset mean, then scale to unit norm
    def preprocess(x):
        x_centered = x - dataset_mean
        x_norm = x_centered / (torch.norm(x_centered) + 1e-8)
        return x_norm

    preprocessed_data = torch.stack([preprocess(x) for x in data_tensor])

    # TensorDataset yields one-element tuples `(sample,)`, the same item
    # format load_olivetti_data produces. Unlike a class defined inside this
    # function, it can be pickled for the worker processes used below.
    dataset = TensorDataset(preprocessed_data)

    # Split 80/20 with a fixed seed so the partition is reproducible
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    trainset, testset = random_split(dataset, [train_size, test_size],
                                     generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    print(f"LFW Dataset loaded: {len(trainset)} train, {len(testset)} test")
    print(f"Image size: {img_size}×{img_size}, Input dimension: {img_size**2}")

    return train_loader, test_loader, dataset_mean

In [159]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.datasets import fetch_olivetti_faces, fetch_lfw_people
from PIL import Image
import numpy as np
import math

def load_data(dataset_name, batch_size=128, img_size=64, **kwargs):
    """
    Dispatch to the loader that matches ``dataset_name`` (case-insensitive).

    Args:
        dataset_name: One of ['mnist', 'olivetti', 'lfw', 'imagenet']
        batch_size: Batch size for DataLoader
        img_size: Image side length, forwarded to the LFW and ImageNet loaders
        **kwargs: Optional dataset-specific settings:
            'train_split' (olivetti, default 0.8),
            'min_faces_per_person' (lfw, default 20),
            'subset_size' (imagenet, default 10000),
            'data_root' (imagenet, default './data/imagenet')

    Returns:
        train_loader: DataLoader for training
        test_loader: DataLoader for testing
        dataset_mean: Mean vector used for preprocessing

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    key = dataset_name.lower()

    if key == 'mnist':
        return load_mnist_data(batch_size)

    if key == 'olivetti':
        return load_olivetti_data(batch_size, kwargs.get('train_split', 0.8))

    if key == 'lfw':
        return load_lfw_data(batch_size, img_size,
                             kwargs.get('min_faces_per_person', 20))

    if key == 'imagenet':
        return load_imagenet_subset(batch_size,
                                    kwargs.get('subset_size', 10000),
                                    img_size,
                                    kwargs.get('data_root', './data/imagenet'))

    # The message reports the name exactly as the caller passed it.
    raise ValueError(f"Unknown dataset: {dataset_name}. Choose from ['mnist', 'olivetti', 'lfw', 'imagenet']")


# Training function

In [142]:
#
def train_sparse_autoencoder(train_loader, num_epochs=50, learning_rate=0.001,
                            input_size=784, hidden_size=64, k_top=20,
                            JumpReLU=0.1, k_aux=None, k_aux_param=1/32,
                            dead_feature_threshold=1000, modelType="SAE",
                            dataset_type="mnist"):
    """
    Train a sparse autoencoder variant with Adam and return the trained model.

    After every optimizer step the decoder weights are clamped to be
    non-negative. The mean batch loss is printed once per epoch.

    Args:
        train_loader: Iterable of training batches
        num_epochs: Number of passes over ``train_loader``
        learning_rate: Adam learning rate
        input_size: Dimension of a flattened input sample
        hidden_size: Number of hidden features
        k_top: Number of features kept per sample
        JumpReLU: Threshold passed as ``jump_value`` to the JumpReLU variants
        k_aux, k_aux_param, dead_feature_threshold: Auxiliary-loss settings,
            used by the "SAE_AuxLoss" and "Complete" variants only
        modelType: One of "SAE", "SAE_Init", "SAE_JumpReLU",
            "SAE_Init_JumpReLU", "SAE_AuxLoss", "Complete"
        dataset_type: 'olivetti' or 'lfw' for batches of the form ``(inputs,)``;
            'mnist' or 'imagenet' for batches of the form ``(inputs, labels)``

    Returns:
        The trained model, on CUDA when available and otherwise on CPU.

    Raises:
        ValueError: For an unknown ``modelType``, or (on the first batch) an
            unknown ``dataset_type``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Constructor arguments shared by every variant / by the aux-loss variants.
    base_args = dict(input_size=input_size, hidden_size=hidden_size, k_top=k_top)
    aux_args = dict(k_aux=k_aux, k_aux_param=k_aux_param,
                    dead_feature_threshold=dead_feature_threshold)

    if modelType == "SAE":
        model = SparseAutoencoder(**base_args)
    elif modelType == "SAE_Init_JumpReLU":
        model = SparseAutoencoderInitJumpReLU(jump_value=JumpReLU, **base_args)
    elif modelType == "SAE_JumpReLU":
        model = SparseAutoencoderJumpReLU(jump_value=JumpReLU, **base_args)
    elif modelType == "SAE_Init":
        model = SparseAutoencoderInit(**base_args)
    elif modelType == "SAE_AuxLoss":
        model = SparseAutoencoderAuxLoss(**base_args, **aux_args)
    elif modelType == "Complete":
        model = SparseAutoencoderComplete(jump_value=JumpReLU, **base_args, **aux_args)
    else:
        raise ValueError("Invalid modelType specified.")
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        running_loss = 0.0
        n_batches = 0
        for data in train_loader:
            # Handle different data loader formats
            if dataset_type in ('olivetti', 'lfw'):
                # Olivetti and LFW loaders yield single-element tuples: (inputs,)
                inputs, = data
            elif dataset_type in ('mnist', 'imagenet'):
                # MNIST and ImageNet yield (inputs, labels)
                inputs, _ = data
            else:
                raise ValueError(f"Unknown dataset_type: {dataset_type}")
            # Inputs are expected to be preprocessed and flattened already.
            inputs = inputs.to(device)

            optimizer.zero_grad()
            h, outputs = model(inputs)

            # Aux-loss variants return (total, recon, aux); the others return
            # a single tensor. Only the total is optimized.
            loss_out = model.compute_loss(inputs, h, outputs)
            loss = loss_out[0] if isinstance(loss_out, tuple) else loss_out

            loss.backward()
            optimizer.step()

            # Clamp weights to enforce non-negativity
            with torch.no_grad():
                model.decoder.weight.clamp_(0.0)

            running_loss += loss.item()
            n_batches += 1

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/max(n_batches, 1):.4f}')

    print('Finished Training')
    return model

# Visualization functions

In [143]:
import matplotlib.pyplot as plt
import math
import numpy as np

def visualize_weights_decoder(model, num_features=64):
    """
    Plot decoder weight vectors as images, one tile per hidden feature.

    The image shape is derived from the decoder's output dimension: a square
    when that dimension is a perfect square, otherwise the factor pair
    (h, w) with the largest h not exceeding the square root.

    Args:
        model: Model exposing ``decoder.weight`` of shape (input_size, hidden_size)
        num_features: Maximum number of features to draw
    """
    input_size = model.decoder.weight.shape[0]
    print(f"Auto-detected input_size: {input_size}")

    # Square if possible (covers 784 -> 28x28 and 4096 -> 64x64), else the
    # most balanced factorisation.
    root = int(np.sqrt(input_size))
    if root * root == input_size:
        img_shape = (root, root)
    else:
        height = next(h for h in range(root, 0, -1) if input_size % h == 0)
        img_shape = (height, input_size // height)

    print(f"Using image shape: {img_shape}")

    # Transpose so that each row holds one feature's decoder direction.
    feature_rows = model.decoder.weight.data.cpu().numpy().T
    num_features = min(num_features, feature_rows.shape[0])

    # Near-square grid of tiles
    n_cols = int(math.ceil(math.sqrt(num_features)))
    n_rows = int(math.ceil(num_features / n_cols))

    fig = plt.figure(figsize=(n_cols * 2, n_rows * 2))
    model_name = getattr(model, 'name', 'SAE')
    fig.suptitle(f'{model_name} Decoder Weights ({img_shape[0]}×{img_shape[1]})',
                 fontsize=14, y=0.995)

    for idx in range(num_features):
        ax = fig.add_subplot(n_rows, n_cols, idx + 1)
        tile = feature_rows[idx].reshape(img_shape)

        # Min-max scale each tile independently; the epsilon guards flat tiles.
        lo, hi = tile.min(), tile.max()
        tile = (tile - lo) / (hi - lo + 1e-8)

        ax.imshow(tile, cmap='gray', interpolation='nearest')
        ax.axis('off')
        ax.set_title(f'F{idx}', fontsize=8)

    fig.tight_layout()
    plt.show()


def visualize_weights_encoder(model, num_features=64):
    """
    Plot encoder weight vectors as images, one tile per hidden feature.

    The image shape is derived from the encoder's input dimension: a square
    when that dimension is a perfect square, otherwise the factor pair
    (h, w) with the largest h not exceeding the square root.

    Args:
        model: Model whose ``encoder`` is either a layer with a ``weight`` or
            an ``nn.Sequential`` whose first layer has one
        num_features: Maximum number of features to draw

    Raises:
        ValueError: If the encoder matches neither structure.
    """
    encoder = model.encoder
    if hasattr(encoder, 'weight'):
        first_layer = encoder
    elif isinstance(encoder, torch.nn.Sequential):
        first_layer = encoder[0]
    else:
        raise ValueError("Unknown encoder structure")
    weights = first_layer.weight.data.cpu().numpy()  # one row per feature

    input_size = weights.shape[1]
    print(f"Auto-detected input_size: {input_size}")

    # Square if possible (covers 784 -> 28x28 and 4096 -> 64x64), else the
    # most balanced factorisation.
    root = int(np.sqrt(input_size))
    if root * root == input_size:
        img_shape = (root, root)
    else:
        height = next(h for h in range(root, 0, -1) if input_size % h == 0)
        img_shape = (height, input_size // height)

    print(f"Using image shape: {img_shape}")

    num_features = min(num_features, weights.shape[0])

    # Near-square grid of tiles
    n_cols = int(math.ceil(math.sqrt(num_features)))
    n_rows = int(math.ceil(num_features / n_cols))

    fig = plt.figure(figsize=(n_cols * 2, n_rows * 2))
    model_name = getattr(model, 'name', 'SAE')
    fig.suptitle(f'{model_name} Encoder Weights ({img_shape[0]}×{img_shape[1]})',
                 fontsize=14, y=0.995)

    for idx in range(num_features):
        ax = fig.add_subplot(n_rows, n_cols, idx + 1)
        tile = weights[idx].reshape(img_shape)
        # Min-max scale each tile independently; the epsilon guards flat tiles.
        lo, hi = tile.min(), tile.max()
        tile = (tile - lo) / (hi - lo + 1e-8)
        ax.imshow(tile, cmap='gray', interpolation='nearest')
        ax.axis('off')
        ax.set_title(f'F{idx}', fontsize=8)

    fig.tight_layout()
    plt.show()


def visualize_reconstructions(model, data_loader, num_samples=10, dataset_type='olivetti'):
    """
    Plot inputs from the first batch of ``data_loader`` above their reconstructions.

    The model is switched to eval mode. The image shape is derived from the
    flattened input dimension (square when possible, otherwise the most
    balanced factor pair).

    Args:
        model: Model whose forward pass returns ``(h, x_hat)``
        data_loader: Loader yielding a tensor, ``(inputs,)`` or ``(inputs, labels)``
        num_samples: Maximum number of samples to show; capped at the batch size
        dataset_type: Kept for backward compatibility; the batch layout is
            detected from the batch itself

    Raises:
        ValueError: If there are no samples to show.
    """
    model.eval()
    device = next(model.parameters()).device

    # Detect the layout from the batch instead of trusting `dataset_type`, so
    # a wrong or defaulted label cannot break unpacking.
    batch = next(iter(data_loader))
    inputs = batch if torch.is_tensor(batch) else batch[0]

    inputs = inputs[:num_samples].to(device)
    # The batch may hold fewer samples than requested.
    n_shown = inputs.shape[0]
    if n_shown == 0:
        raise ValueError("No samples available to visualize.")

    # Auto-detect dimensions
    input_size = inputs.shape[1]
    if input_size == 784:
        img_shape = (28, 28)
    elif input_size == 4096:
        img_shape = (64, 64)
    else:
        side = int(np.sqrt(input_size))
        if side * side == input_size:
            img_shape = (side, side)
        else:
            for h in range(int(np.sqrt(input_size)), 0, -1):
                if input_size % h == 0:
                    w = input_size // h
                    img_shape = (h, w)
                    break

    # Get reconstructions
    with torch.no_grad():
        _, reconstructions = model(inputs)

    inputs = inputs.cpu().numpy()
    reconstructions = reconstructions.cpu().numpy()

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, n_shown, figsize=(n_shown * 2, 4), squeeze=False)
    model_name = getattr(model, 'name', 'SAE')
    plt.suptitle(f'{model_name} Reconstructions ({img_shape[0]}×{img_shape[1]})', fontsize=14)

    for i in range(n_shown):
        axes[0, i].imshow(inputs[i].reshape(img_shape), cmap='gray', interpolation='nearest')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=10)

        axes[1, i].imshow(reconstructions[i].reshape(img_shape), cmap='gray', interpolation='nearest')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=10)

    plt.tight_layout()
    plt.show()

Functions to count dead neurons and compute the test loss on a given dataset.

In [144]:

def count_dead_neurons(model, data_loader, dataset_type='mnist'):
    """
    Count dead neurons (features that never activate)

    A feature is considered alive as soon as it takes a strictly positive
    value for at least one sample seen in ``data_loader``.

    Args:
        model: Trained SAE model exposing ``hidden_size`` and returning
            ``(hidden, reconstruction)`` from its forward pass
        data_loader: Iterable of batches. A batch may be a bare tensor, a
            single-element ``(inputs,)`` tuple, or an ``(inputs, labels, ...)`` tuple
        dataset_type: Kept for backward compatibility; the batch layout is
            detected from the batch itself, so any value is accepted

    Returns:
        num_dead: Number of dead neurons
    """
    # Side effect: the model is moved to the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # Every feature starts out dead and is cleared once it fires
    dead_neurons = torch.ones(model.hidden_size, dtype=torch.bool, device=device)

    with torch.no_grad():
        for data in data_loader:
            # Detect the batch layout instead of trusting dataset_type
            inputs = data[0] if isinstance(data, (list, tuple)) else data

            inputs = inputs.to(device)  # Already preprocessed, no reshape needed
            h, _ = model(inputs)

            # A neuron is alive if it activates (h > 0) for any sample.
            # Testing the sum would let opposite-signed values cancel out.
            dead_neurons &= ~(h > 0).any(dim=0)

    num_dead = dead_neurons.sum().item()
    model_name = getattr(model, 'name', 'SAE')
    print(f'Number of dead neurons in {model_name}: {num_dead} out of {model.hidden_size} '
          f'({100*num_dead/model.hidden_size:.2f}%)')
    return num_dead


def test_loss(model, data_loader, dataset_type='mnist'):
    """
    Compute average test loss

    Args:
        model: Trained SAE model
        data_loader: DataLoader with test data
        dataset_type: 'mnist', 'olivetti', or 'imagenet' for proper unpacking

    Returns:
        avg_loss: Average loss over test set
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data in data_loader:
            # Handle different data loader formats
            if dataset_type == 'olivetti':
                inputs, = data  # Single-element tuple
            else:  # mnist or imagenet
                inputs, _ = data  # (inputs, labels) tuple

            inputs = inputs.to(device)  # Already preprocessed, no reshape needed
            h, outputs = model(inputs)

            # Handle different loss outputs
            loss_output = model.compute_loss(inputs, h, outputs)
            if isinstance(loss_output, tuple):
                loss, *_ = loss_output  # Unpack if tuple (e.g., with aux loss)
            else:
                loss = loss_output

            total_loss += loss.item()
            num_batches += 1

    avg_loss = total_loss / num_batches
    model_name = getattr(model, 'name', 'SAE')
    print(f'Test Loss for {model_name}: {avg_loss:.6f}')
    return avg_loss


def get_activation_statistics(model, data_loader, dataset_type='mnist'):
    """
    Get comprehensive statistics about feature activations

    Args:
        model: Trained SAE model exposing ``hidden_size`` and returning
            ``(hidden, reconstruction)`` from its forward pass
        data_loader: Iterable of batches. A batch may be a bare tensor, a
            single-element ``(inputs,)`` tuple, or an ``(inputs, labels, ...)`` tuple
        dataset_type: Kept for backward compatibility; the batch layout is
            detected from the batch itself, so any value is accepted

    Returns:
        stats: Dictionary with activation statistics:
            - ``total_features``: ``model.hidden_size``
            - ``dead_features`` / ``active_features``: features that fired
              (``h > 0``) on zero / at least one sample
            - ``mean_activation_frequency`` / ``median_activation_frequency``:
              mean / median over features of the fraction of samples they fire on
            - ``mean_activation_strength``: mean over active features of their
              per-sample average activation (NaN when no feature is active)
            - ``activation_frequencies`` / ``activation_strengths``: per-feature arrays
    """
    # Side effect: the model is moved to the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    activation_counts = torch.zeros(model.hidden_size, device=device)
    activation_sums = torch.zeros(model.hidden_size, device=device)
    total_samples = 0

    with torch.no_grad():
        for data in data_loader:
            # Detect the batch layout instead of trusting dataset_type
            inputs = data[0] if isinstance(data, (list, tuple)) else data

            inputs = inputs.to(device)
            h, _ = model(inputs)

            # Count how many times each feature activates (h > 0)
            activation_counts += (h > 0).sum(dim=0).float()
            activation_sums += h.sum(dim=0)
            total_samples += inputs.size(0)

    # Move to CPU for analysis
    activation_counts = activation_counts.cpu().numpy()
    activation_sums = activation_sums.cpu().numpy()

    # With no samples every count is zero; dividing by 1 reports 0 instead of 0/0
    denom = max(total_samples, 1)
    activation_freq = activation_counts / denom  # Fraction of samples each feature activates on
    mean_activation = activation_sums / denom    # Average activation strength

    is_active = activation_counts > 0
    active_strengths = mean_activation[is_active]

    stats = {
        'total_features': model.hidden_size,
        'dead_features': np.sum(~is_active),
        'active_features': np.sum(is_active),
        'mean_activation_frequency': np.mean(activation_freq),
        'median_activation_frequency': np.median(activation_freq),
        # Among active features; NaN (without a numpy warning) when none is active
        'mean_activation_strength': np.mean(active_strengths) if active_strengths.size else float('nan'),
        'activation_frequencies': activation_freq,
        'activation_strengths': mean_activation
    }

    # Print summary
    model_name = getattr(model, 'name', 'SAE')
    print(f"\n=== Activation Statistics for {model_name} ===")
    print(f"Total features: {stats['total_features']}")
    print(f"Dead features: {stats['dead_features']} ({100*stats['dead_features']/stats['total_features']:.2f}%)")
    print(f"Active features: {stats['active_features']} ({100*stats['active_features']/stats['total_features']:.2f}%)")
    print(f"Mean activation frequency: {stats['mean_activation_frequency']:.4f}")
    print(f"Median activation frequency: {stats['median_activation_frequency']:.4f}")
    print(f"Mean activation strength (active features): {stats['mean_activation_strength']:.4f}")

    return stats


def plot_activation_histogram(model, data_loader, dataset_type='mnist'):
    """
    Draw two side-by-side histograms of per-feature activation statistics.

    The left panel shows how often each feature fires (fraction of samples);
    the right panel shows the mean activation strength of the features whose
    strength is positive. Each panel marks its mean with a dashed red line.

    Args:
        model: Trained SAE model
        data_loader: DataLoader with data
        dataset_type: 'mnist', 'olivetti', or 'imagenet'
    """
    stats = get_activation_statistics(model, data_loader, dataset_type)

    _, axes = plt.subplots(1, 2, figsize=(14, 5))
    title_prefix = getattr(model, 'name', 'SAE')

    strengths = stats['activation_strengths']

    # One entry per panel: (values, extra hist kwargs, x-label, title, mean marker)
    panels = (
        (
            stats['activation_frequencies'],
            {},
            'Activation Frequency (fraction of samples)',
            f'{title_prefix}: Feature Activation Frequencies',
            stats['mean_activation_frequency'],
        ),
        (
            strengths[strengths > 0],  # dead features are left out of this panel
            {'color': 'green'},
            'Mean Activation Strength',
            f'{title_prefix}: Feature Activation Strengths (Active Features Only)',
            stats['mean_activation_strength'],
        ),
    )

    for ax, (values, hist_extra, x_label, title, mean_value) in zip(axes, panels):
        ax.hist(values, bins=50, edgecolor='black', alpha=0.7, **hist_extra)
        ax.set_xlabel(x_label)
        ax.set_ylabel('Number of Features')
        ax.set_title(title)
        ax.axvline(mean_value, color='r', linestyle='--',
                   label=f'Mean: {mean_value:.4f}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


Initializing the data loaders

In [145]:
# Build the train/test loaders used by the cells below.
# Active choice: LFW resized to 64x64 (flattened input dimension 4096), batch size 128.
# The Olivetti loader is kept as a commented-out alternative.
# NOTE(review): `mean` is presumably the dataset mean used for centering — confirm against the loader.
# train_loader, test_loader, mean = load_olivetti_data(batch_size=32)
train_loader, test_loader, mean = load_lfw_data(batch_size=128, img_size=64)

Original LFW shape: (3023, 125, 94)
Resized LFW shape: (3023, 4096)
LFW Dataset loaded: 2418 train, 605 test
Image size: 64×64, Input dimension: 4096


# Base usage

In [155]:
# train_loader, test_loader, mean = load_lfw_data(batch_size=128, img_size=64)
#
# modelBase = train_sparse_autoencoder(
#     train_loader,
#     num_epochs=50,
#     learning_rate=0.001,
#     input_size=4096,
#     hidden_size=1024,        # ← LARGER dictionary for more parts
#     k_top=128,               # ← HIGHER k (12.5% sparsity) for combining parts
#     k_aux=512,               # ← Use auxiliary loss
#     k_aux_param=1/32,
#     modelType="SAE_AuxLoss", # ← Use AuxLoss to prevent dead neurons!
#     dataset_type="olivetti"
# )


Original LFW shape: (3023, 125, 94)
Resized LFW shape: (3023, 4096)
LFW Dataset loaded: 2418 train, 605 test
Image size: 64×64, Input dimension: 4096
Epoch [1/50], Loss: 0.0662
Epoch [2/50], Loss: 0.0246
Epoch [3/50], Loss: 0.0168
Epoch [4/50], Loss: 0.0134
Epoch [5/50], Loss: 0.0115
Epoch [6/50], Loss: 0.0103
Epoch [7/50], Loss: 0.0094
Epoch [8/50], Loss: 0.0088
Epoch [9/50], Loss: 0.0082
Epoch [10/50], Loss: 0.0078
Epoch [11/50], Loss: 0.0075
Epoch [12/50], Loss: 0.0073
Epoch [13/50], Loss: 0.0070
Epoch [14/50], Loss: 0.0066
Epoch [15/50], Loss: 0.0063
Epoch [16/50], Loss: 0.0061
Epoch [17/50], Loss: 0.0060
Epoch [18/50], Loss: 0.0059
Epoch [19/50], Loss: 0.0059
Epoch [20/50], Loss: 0.0061
Epoch [21/50], Loss: 0.0060
Epoch [22/50], Loss: 0.0058
Epoch [23/50], Loss: 0.0056
Epoch [24/50], Loss: 0.0055
Epoch [25/50], Loss: 0.0053
Epoch [26/50], Loss: 0.0053
Epoch [27/50], Loss: 0.0054
Epoch [28/50], Loss: 0.0055
Epoch [29/50], Loss: 0.0055
Epoch [30/50], Loss: 0.0054
Epoch [31/50], Loss

In [None]:
# visualize_weights_decoder(modelBase, num_features=64)
# count_dead_neurons(modelBase, train_loader, dataset_type='olivetti')
# test_loss(modelBase, test_loader, dataset_type='olivetti')
# plot_activation_histogram(modelBase, train_loader, dataset_type='olivetti')

# TopK Sparsity Analysis

In [160]:
# K sensitivity sweep: for every dataset, train one plain TopK SAE per value of k
# and record reconstruction error plus dead/active feature counts.
dataset_configs = [
    dict(name='mnist', input_size=784, batch_size=256),
    dict(name='olivetti', input_size=4096, batch_size=32),
    dict(name='lfw', input_size=4096, batch_size=128),
]

k_values = [5, 10, 20, 30, 40, 50, 64, 100, 128]
# Dictionary size is held fixed across the sweep. It is smaller than both
# input dimensions above (784 and 4096), so the representation is not overcomplete.
hidden_size = 256

results = []
for dataset_config in dataset_configs:
    dataset_name = dataset_config['name']
    train_loader, test_loader, mean = load_data(dataset_name, dataset_config['batch_size'])

    for k in k_values:
        model = train_sparse_autoencoder(
            train_loader,
            num_epochs=50,
            input_size=dataset_config['input_size'],
            hidden_size=hidden_size,
            k_top=k,
            modelType="SAE",
            dataset_type=dataset_name,
        )

        # Metrics: loss on held-out data, feature usage on the training data
        test_mse = test_loss(model, test_loader, dataset_name)
        dead_neurons = count_dead_neurons(model, train_loader, dataset_name)
        stats = get_activation_statistics(model, train_loader, dataset_name)

        record = dict(
            dataset=dataset_name,
            k=k,
            sparsity_ratio=k / hidden_size,
            test_mse=test_mse,
            dead_neurons=dead_neurons,
            active_features=stats['active_features'],
            mean_activation_freq=stats['mean_activation_frequency'],
        )
        results.append(record)


Epoch [1/50], Loss: 0.2255
Epoch [2/50], Loss: 0.1463
Epoch [3/50], Loss: 0.1416
Epoch [4/50], Loss: 0.1397
Epoch [5/50], Loss: 0.1386
Epoch [6/50], Loss: 0.1377
Epoch [7/50], Loss: 0.1371
Epoch [8/50], Loss: 0.1368
Epoch [9/50], Loss: 0.1361
Epoch [10/50], Loss: 0.1357
Epoch [11/50], Loss: 0.1355
Epoch [12/50], Loss: 0.1351
Epoch [13/50], Loss: 0.1349
Epoch [14/50], Loss: 0.1346
Epoch [15/50], Loss: 0.1338
Epoch [16/50], Loss: 0.1328
Epoch [17/50], Loss: 0.1323
Epoch [18/50], Loss: 0.1318
Epoch [19/50], Loss: 0.1315
Epoch [20/50], Loss: 0.1309
Epoch [21/50], Loss: 0.1305
Epoch [22/50], Loss: 0.1303
Epoch [23/50], Loss: 0.1302
Epoch [24/50], Loss: 0.1299
Epoch [25/50], Loss: 0.1297
Epoch [26/50], Loss: 0.1295
Epoch [27/50], Loss: 0.1293
Epoch [28/50], Loss: 0.1289
Epoch [29/50], Loss: 0.1285
Epoch [30/50], Loss: 0.1283
Epoch [31/50], Loss: 0.1282
Epoch [32/50], Loss: 0.1281
Epoch [33/50], Loss: 0.1279
Epoch [34/50], Loss: 0.1275
Epoch [35/50], Loss: 0.1272
Epoch [36/50], Loss: 0.1270
E

KeyboardInterrupt: 