In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

# Install Dependencies

! pip install -U lightning

# Organize Imports

In [None]:
from pathlib import Path

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

# Organize Paths

In [None]:
# All artefacts live under one data root: trained models and the MNIST download cache.
PATH = Path('../data')
model_path = PATH / 'models' / '2_layer_128_64_sae_sigmoid'
MNIST_dir = PATH / 'mnist'

# Create both directories up front; exist_ok keeps this cell safe to re-run.
for directory in (model_path, MNIST_dir):
    directory.mkdir(parents=True, exist_ok=True)

# Initialize Device and Workers

In [None]:
import os

# os.cpu_count() returns None when the count cannot be determined; fall back to 1 so
# `workers` is always a positive int (usable e.g. as a DataLoader num_workers value).
workers = os.cpu_count() or 1
print("Number of CPUs in the system:", workers)

In [None]:
# Choose the Lightning accelerator name: CUDA GPU first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = 'gpu'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    # No trailing comma: `'cpu',` would bind the tuple ('cpu',), which Trainer(accelerator=...) rejects.
    device = 'cpu'

# Initialize the Model

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import lightning as L

class SparseAutoencoder(L.LightningModule):
    """
    Fully connected sparse autoencoder:
    input_dim -> hidden_dim1 -> hidden_dim2 (code) -> hidden_dim1 -> input_dim.

    The training loss is MSE reconstruction plus a KL-divergence sparsity penalty on the
    code layer. The logged validation metric ``val_loss`` is the reconstruction MSE alone.

    Args:
        input_dim: size of the flattened input (784 = 28 x 28 for MNIST).
        hidden_dim1: width of the outer hidden layers.
        hidden_dim2: width of the code layer that the sparsity penalty acts on.
        sparsity_target: target mean activation (rho) of each code unit.
        sparsity_lambda: weight of the sparsity penalty in the training loss.
        sparsity_eps: the batch-mean code activation is clamped to [eps, 1 - eps] before the
            KL term so the penalty stays finite.
    """

    def __init__(self, input_dim=784, hidden_dim1=128, hidden_dim2=64, sparsity_target=0.05,
                 sparsity_lambda=1e-3, sparsity_eps=1e-6):
        super().__init__()
        # Record the constructor arguments in self.hparams (and therefore in checkpoints) so a
        # model with non-default sizes can be rebuilt with SparseAutoencoder.load_from_checkpoint.
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, hidden_dim2),
            # NOTE(review): the model directory is named "..._sae_sigmoid", but the code layer uses
            # ReLU, whose outputs are not bounded to (0, 1). Confirm whether Sigmoid was intended.
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim2, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, input_dim),
            nn.Sigmoid()  # reconstructions lie in (0, 1)
        )
        self.sparsity_target = sparsity_target
        self.sparsity_lambda = sparsity_lambda
        self.sparsity_eps = sparsity_eps

    def forward(self, x):
        """Reconstruct ``x``; expects already-flattened input whose last dimension is input_dim."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    def compute_sparsity_penalty(self, hidden_representation):
        """
        KL-divergence sparsity penalty, summed over code units and scaled by sparsity_lambda:

            lambda * sum_j [ rho * log(rho / p_j) + (1 - rho) * log((1 - rho) / (1 - p_j)) ]

        where p_j is the mean activation of unit j over the batch (dim 0).

        The code layer ends in ReLU, so p_j is 0 for a dead unit (penalty becomes inf) and can
        exceed 1 (log of a negative number gives NaN). p_j is therefore clamped to
        [sparsity_eps, 1 - sparsity_eps] to keep the loss finite.
        """
        rho = self.sparsity_target
        p_hat = torch.mean(hidden_representation, dim=0).clamp(self.sparsity_eps, 1 - self.sparsity_eps)
        kl_divergence = rho * torch.log(rho / p_hat) + \
                        (1 - rho) * torch.log((1 - rho) / (1 - p_hat))
        return self.sparsity_lambda * torch.sum(kl_divergence)

    def training_step(self, batch, batch_idx):
        """Reconstruction MSE plus sparsity penalty on the code. Labels in the batch are ignored."""
        x, _ = batch
        x = x.view(x.size(0), -1)  # flatten images to (batch, input_dim)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        reconstruction_loss = F.mse_loss(decoded, x)
        sparsity_penalty = self.compute_sparsity_penalty(encoded)
        loss = reconstruction_loss + sparsity_penalty
        self.log('train_loss', loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction MSE only; this is the metric ModelCheckpoint monitors."""
        x, _ = batch
        x = x.view(x.size(0), -1)  # Flatten the input
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        val_loss = F.mse_loss(decoded, x)
        self.log('val_loss', val_loss, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """AdamW with cosine annealing and warm restarts (first cycle 10 epochs, then doubling)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

# Prepare Dataset

In [None]:
# Training-time augmentation: small random rotations (up to 10 degrees) and random shifts
# (up to 10% of width/height), then conversion to a tensor.
transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
    ]
)

In [None]:
class MNISTDataModule(L.LightningDataModule):
    """
    MNIST data for Lightning: a seeded train/validation split of the official training set,
    plus the official test set.

    Args:
        data_dir: directory MNIST is downloaded to and read from.
        batch_size: batch size for every dataloader.
        transform: transform for the *training* split only (e.g. augmentation). Defaults to
            ToTensor(). Validation and test always use plain ToTensor(), so their metrics are
            deterministic and comparable across epochs.
        val_size: number of training images held out for validation (5000 -> 55000/5000 split).
        seed: seed for the train/validation split, so the split is identical on every run.
        num_workers: DataLoader worker processes (0 = load in the main process).
    """

    def __init__(self, data_dir='./data', batch_size=64, transform=None,
                 val_size=5000, seed=42, num_workers=0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Without any transform MNIST yields PIL images, which the default collate cannot batch.
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.eval_transform = transforms.ToTensor()
        self.val_size = val_size
        self.seed = seed
        self.num_workers = num_workers

    def prepare_data(self):
        """Download both MNIST splits; no dataset objects are kept here."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` ('fit', 'validate', 'test' or None for all)."""
        if stage in ('fit', 'validate', None):
            # Two views of the same 60k training images: augmented for training, plain for
            # validation. Splitting a single dataset would make validation images augmented too.
            train_view = MNIST(self.data_dir, train=True, transform=self.transform)
            val_view = MNIST(self.data_dir, train=True, transform=self.eval_transform)
            generator = torch.Generator().manual_seed(self.seed)
            order = torch.randperm(len(train_view), generator=generator).tolist()
            n_train = len(order) - self.val_size
            self.mnist_train = torch.utils.data.Subset(train_view, order[:n_train])
            self.mnist_val = torch.utils.data.Subset(val_view, order[n_train:])
        if stage in ('test', None):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.eval_transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

# Checkpointing the Model

In [None]:
# Checkpoint locations inside model_path, given without the ".ckpt" extension.
best_checkpoint_path, last_checkpoint_path = (
    model_path / stem for stem in ('best-checkpoint', 'last-checkpoint')
)

In [None]:
# ModelCheckpoint resolves `filename` relative to `dirpath` (default: <log_dir>/checkpoints).
# Putting a directory in `filename` scatters files under lightning_logs/version_N/...,
# so the target directory goes in `dirpath` and only the stem in `filename`.
checkpoint_callback = ModelCheckpoint(
    dirpath=best_checkpoint_path.parent,
    filename=best_checkpoint_path.name,
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    verbose=True,
)

last_checkpoint_callback = ModelCheckpoint(
    dirpath=last_checkpoint_path.parent,
    filename=last_checkpoint_path.name,
    save_last=True,
    verbose=True,
)

# Initiate Training

In [None]:
import lightning as L

# Seed the Python, NumPy and torch RNGs (including dataloader workers). This makes weight
# initialisation, shuffling and augmentation repeatable across Restart & Run All.
L.seed_everything(42, workers=True)

data_module = MNISTDataModule(
    data_dir=MNIST_dir,
    transform=transform,
)
model = SparseAutoencoder()

trainer = L.Trainer(
    max_epochs=64,
    callbacks=[checkpoint_callback, last_checkpoint_callback],
    accelerator=device,
    devices=1,  # train on a single device of the selected accelerator type
)

trainer.fit(model, datamodule=data_module)

# Visualize Layer Weights

In [None]:
! ls {best_checkpoint_path}

In [None]:
! ls lightning_logs/version_9/data/models/2_layer_128_64_sae_sigmoid

In [None]:
# Directory holding the best checkpoint, taken from the callback rather than a hard-coded
# lightning_logs/version_N path that changes on every training run.
best_checkpoint_logs_path = Path(checkpoint_callback.best_model_path).parent

In [None]:
best_checkpoint_file = best_checkpoint_logs_path / 'best-checkpoint.ckpt'
# Check the file itself: the directory can exist while the checkpoint file does not.
if best_checkpoint_file.is_file():
    # map_location='cpu' lets a checkpoint written during GPU/MPS training load on any machine.
    model_pt = torch.load(best_checkpoint_file, map_location='cpu')
else:
    model_pt = None
# Show only the top-level keys; displaying the whole dict would dump every weight tensor.
model_pt.keys() if model_pt is not None else None

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn

def _subplot_grid(n):
    """Return (rows, cols) of a near-square subplot grid with at least ``n`` cells."""
    cols = max(1, int(np.ceil(np.sqrt(n))))
    return int(np.ceil(n / cols)), cols


def _plot_conv_filters(weights, layer_index, num_features):
    """Draw up to ``num_features`` Conv2d filters (weights indexed [out_channel, in_channel, kH, kW])."""
    n_shown = min(num_features, weights.shape[0])
    if n_shown <= 0:
        return
    num_rows, num_cols = _subplot_grid(n_shown)
    fig = plt.figure(figsize=(num_cols * 2, num_rows * 2))
    for i in range(n_shown):
        ax = fig.add_subplot(num_rows, num_cols, i + 1)
        kernel = weights[i]
        # Min-max normalise to [0, 1]; a constant kernel would otherwise divide by zero.
        span = kernel.max() - kernel.min()
        kernel = (kernel - kernel.min()) / span if span > 0 else np.zeros_like(kernel)
        if kernel.shape[0] in (3, 4):
            # RGB(A) kernels: imshow expects channels last.
            ax.imshow(np.transpose(kernel, (1, 2, 0)))
        else:
            # Single-channel kernels (mean == the only channel), or channel counts imshow
            # cannot render as colour: show the per-pixel mean over input channels.
            ax.imshow(kernel.mean(axis=0), cmap='gray')
        ax.axis('off')
    plt.suptitle(f'Layer {layer_index} - Conv2d Weights')
    plt.show()


def _plot_linear_neurons(weights, layer_index, num_features):
    """Draw up to ``num_features`` rows of a Linear weight matrix [out_features, in_features] as heatmaps."""
    n_shown = min(num_features, weights.shape[0])
    if n_shown <= 0:
        return
    num_rows, num_cols = _subplot_grid(n_shown)
    in_features = weights.shape[1]
    side = int(round(np.sqrt(in_features)))
    # Square fan-in (e.g. 784 = 28 x 28 MNIST pixels) reads far better as an image than as a
    # 1 x N strip; any other fan-in is drawn as a single row.
    is_square = side * side == in_features
    shape = (side, side) if is_square else (1, in_features)
    plt.figure(figsize=(num_cols * 4, num_rows * 4))
    for i in range(n_shown):
        plt.subplot(num_rows, num_cols, i + 1)
        plt.imshow(weights[i].reshape(shape), cmap='viridis', aspect='equal' if is_square else 'auto')
        plt.colorbar()
        plt.axis('off')
    plt.suptitle(f'Layer {layer_index} - Linear Weights')
    plt.show()


def visualize_layer_weights(model, layer_indices, num_features=16):
    """
    Visualizes the weights of specified layers in the model.

    Conv2d layers: each output filter is drawn as a min-max normalised image.
    Linear layers: each output neuron's incoming weights are drawn as a heatmap, reshaped to
    a square image when the fan-in is a perfect square.
    Other layer types with weights: a "not supported" message is printed.

    Parameters:
    - model: The neural network model; its direct children are the indexable layers.
    - layer_indices: Indices into list(model.children()) of the layers to visualize.
    - num_features: Maximum number of features (neurons or filters) to visualize per layer;
      any positive count works (the subplot grid is sized to fit).

    Raises:
    - IndexError: if any index is outside [0, number of children).
    - ValueError: if any selected layer has no (non-None) ``weight``.
    """
    layers = list(model.children())
    layer_indices = list(layer_indices)

    # Validate every requested index before drawing, so a bad index fails before any figure appears.
    for layer_index in layer_indices:
        if layer_index < 0 or layer_index >= len(layers):
            raise IndexError(f"Layer index {layer_index} is out of range. Model has {len(layers)} layers.")
        if getattr(layers[layer_index], 'weight', None) is None:
            raise ValueError(f"Layer at index {layer_index} does not have weights.")

    for layer_index in layer_indices:
        layer = layers[layer_index]
        weights = layer.weight.detach().cpu().numpy()

        if isinstance(layer, nn.Conv2d):
            _plot_conv_filters(weights, layer_index, num_features)
        elif isinstance(layer, nn.Linear):
            _plot_linear_neurons(weights, layer_index, num_features)
        else:
            print(f"Visualization for layer type {type(layer)} at index {layer_index} is not supported.")

In [None]:
# model.encoder is Linear(0) -> ReLU(1) -> Linear(2) -> ReLU(3); only the Linear children carry weights.
visualize_layer_weights(model.encoder, [0, 2])

In [None]:
# Display the encoder's children; the layer indices passed to the plotting helpers refer to these.
model.encoder

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn

def visualize_weights_as_heatmaps(model, layer_index, num_neurons=16):
    """
    Visualizes the weights of a specified layer in the model as heatmaps.

    Each of the first ``num_neurons`` output units is drawn as a 1 x fan-in strip of its
    (flattened) incoming weights, laid out on a near-square subplot grid.

    Parameters:
    - model: The neural network model; its direct children are the indexable layers.
    - layer_index: Index into list(model.children()) of the layer to visualize.
    - num_neurons: Maximum number of neurons to visualize (any positive count, not only squares).

    Raises:
    - IndexError: if ``layer_index`` is out of range for the model's children.
    - ValueError: if the selected layer has no (non-None) ``weight``.
    """
    # Extract the specified layer
    layer = list(model.children())[layer_index]

    weight = getattr(layer, 'weight', None)
    if weight is None:
        raise ValueError(f"Layer at index {layer_index} does not have weights.")

    # Get the weights and move them to CPU
    weights = weight.detach().cpu().numpy()

    n_shown = min(num_neurons, weights.shape[0])
    if n_shown <= 0:
        return
    # Size the grid from the number of plots actually drawn. A plain int(sqrt(n)) square grid
    # has fewer than n cells when n is not a perfect square, and plt.subplot would raise.
    num_cols = int(np.ceil(np.sqrt(n_shown)))
    num_rows = int(np.ceil(n_shown / num_cols))

    plt.figure(figsize=(12, 12))
    for i in range(n_shown):
        plt.subplot(num_rows, num_cols, i + 1)
        weight_matrix = weights[i].reshape(1, -1)  # one row: this neuron's incoming weights
        plt.imshow(weight_matrix, cmap='viridis', aspect='auto')
        plt.colorbar()
        plt.title(f'Neuron {i}')
        plt.axis('off')
    plt.suptitle(f'Layer {layer_index} - Weights Heatmaps')
    plt.show()

In [None]:
# Heatmaps of the first 16 neurons of the second Linear layer (128 -> 64).
visualize_weights_as_heatmaps(model.encoder, 2, 16)

In [None]:
def plot_weight_magnitudes(model, layer_index):
    """
    Plots the L2 norm of the incoming weights of each output unit in the specified layer.

    Works for weight tensors of any rank: each output unit's weights (row of a Linear matrix,
    filter of a Conv layer, single entry of a 1-D weight) are flattened before taking the norm.

    Parameters:
    - model: The neural network model; its direct children are the indexable layers.
    - layer_index: Index into list(model.children()) of the layer to visualize.

    Raises:
    - IndexError: if ``layer_index`` is out of range for the model's children.
    - ValueError: if the selected layer has no (non-None) ``weight``.
    """
    # Extract the specified layer
    layer = list(model.children())[layer_index]

    weight = getattr(layer, 'weight', None)
    if weight is None:
        raise ValueError(f"Layer at index {layer_index} does not have weights.")

    # One row per output unit; norm(axis=1) on the raw tensor is only correct for 2-D weights.
    weights = weight.detach().cpu().numpy()
    magnitudes = np.linalg.norm(weights.reshape(weights.shape[0], -1), axis=1)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(np.arange(len(magnitudes)), magnitudes)
    ax.set_xlabel('Neuron Index')
    ax.set_ylabel('Weight Magnitude')
    ax.set_title(f'Layer {layer_index} - Weight Magnitudes')
    plt.show()

In [None]:
# Per-neuron weight norms of the second Linear layer of the encoder.
plot_weight_magnitudes(model.encoder, 2)

In [None]:
def plot_weight_distribution(model, layer_index):
    """
    Draw a 30-bin histogram of every weight value in one child layer of ``model``.

    Parameters:
    - model: The neural network model; its direct children are the indexable layers.
    - layer_index: Position of the layer within list(model.children()).

    Raises:
    - IndexError: if ``layer_index`` is out of range for the model's children.
    - ValueError: if the selected layer has no ``weight`` attribute.
    """
    target = list(model.children())[layer_index]
    if not hasattr(target, 'weight'):
        raise ValueError(f"Layer at index {layer_index} does not have weights.")

    # All weights, regardless of the tensor's shape, as one flat array.
    values = target.weight.data.cpu().numpy().ravel()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(values, bins=30, edgecolor='black')
    ax.set_xlabel('Weight Value')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Layer {layer_index} - Weight Distribution')
    plt.show()

In [None]:
# Histogram of all weight values in the second Linear layer of the encoder.
plot_weight_distribution(model.encoder, 2)

# Activations by Digit Class

In [None]:
# Re-display the encoder layout: children 0 and 2 are the Linear layers whose outputs are collected below.
model.encoder

In [None]:
# Inference wrapper that exposes the encoder's intermediate Linear-layer outputs
class Encoder(object):
    """
    Thin inference wrapper around an ``nn.Sequential`` encoder.

    Calling an instance does three things:
    - flattens the input to (batch, -1);
    - runs it through every child of ``net`` in order;
    - returns ``(output, activations)``.

    ``activations`` maps ``'layer{i}'`` to the raw output of the Linear child at index ``i``
    (before any following activation function), moved to CPU. For the SparseAutoencoder
    encoder (Linear, ReLU, Linear, ReLU) the keys are ``'layer0'`` and ``'layer2'``.
    """

    def __init__(self, net):
        super().__init__()
        # The wrapper is only used for inference, so switch the network to eval mode once here.
        self.net = net.eval()

    def eval(self):
        """Put the wrapped network in eval mode and return self (mirrors nn.Module.eval)."""
        self.net.eval()
        return self

    @torch.inference_mode()
    def forward(self, x):
        """Run ``x`` through the whole network; see the class docstring for the return value."""
        activations = {}
        # Match the network's device (it may still be on GPU/MPS after training).
        first_param = next(self.net.parameters(), None)
        if first_param is not None:
            x = x.to(first_param.device)
        x = x.reshape(x.size(0), -1)  # Flatten the input
        for index, layer in enumerate(self.net):
            x = layer(x)
            if isinstance(layer, nn.Linear):
                # Stored on CPU so callers can call .numpy() whatever device the net is on.
                activations[f'layer{index}'] = x.cpu()
        return x, activations

    def __call__(self, x):
        return self.forward(x)

In [None]:
# Load the MNIST training set for activation analysis: ToTensor only, no augmentation.
# A dedicated name keeps the training-time `transform` defined earlier intact; rebinding it
# here would make a later re-run of the training cell silently train without augmentation.
eval_transform = transforms.Compose([transforms.ToTensor()])
mnist_dataset = datasets.MNIST(root=MNIST_dir, train=True, transform=eval_transform, download=True)
data_loader = DataLoader(mnist_dataset, batch_size=64, shuffle=True)

In [None]:
# Wrap the trained encoder so a forward pass also returns its intermediate activations.
encoder = Encoder(net=model.encoder)

In [None]:
# Confirm the wrapped network is the trained encoder (Linear -> ReLU -> Linear -> ReLU).
encoder.net

In [None]:
# Take a single batch of images and labels for quick sanity checks.
x, y = next(iter(data_loader))

In [None]:
# Shape of one image tensor as produced by the ToTensor() transform above.
x[0].shape

In [None]:
encoder

In [None]:
# Full wrapper pass: returns (final encoder output, dict of intermediate activations).
encoder(x)

In [None]:
# First Linear layer applied directly to the flattened batch; should match the wrapper's 'layer0' entry.
encoder.net[0](x.view(x.size(0), -1))

In [None]:
# Collect per-sample activations of each recorded layer, grouped by digit class
layer_names = ('layer0', 'layer2')
activations_by_digit = {digit: {layer: [] for layer in layer_names} for digit in range(10)}

# Forward pass through the dataset
encoder.eval()
with torch.no_grad():
    for images, labels in data_loader:
        outputs, activations = encoder(images)
        # Convert each layer's batch to NumPy once (on CPU) instead of once per sample.
        batch_acts = {layer: activations[layer].cpu().numpy() for layer in layer_names}
        for row, digit in enumerate(labels.tolist()):
            for layer in layer_names:
                activations_by_digit[digit][layer].append(batch_acts[layer][row])

# Compute average activations. np.stack avoids the very slow torch.tensor(list-of-ndarrays) path.
average_activations = {digit: {layer: None for layer in layer_names} for digit in range(10)}
for digit in range(10):
    for layer in layer_names:
        rows = activations_by_digit[digit][layer]
        if rows:  # a digit absent from the data keeps None instead of crashing on an empty mean
            average_activations[digit][layer] = torch.from_numpy(np.stack(rows)).mean(dim=0)

# Identify the top_n most and least active neurons per digit and layer
top_n = 24
top_neurons = {digit: {layer: {'most_active': None, 'least_active': None} for layer in layer_names}
               for digit in range(10)}
for digit in range(10):
    for layer in layer_names:
        avg_act = average_activations[digit][layer]
        if avg_act is None:
            continue
        k = min(top_n, avg_act.numel())  # torch.topk raises if k exceeds the number of neurons
        top_neurons[digit][layer]['most_active'] = sorted(torch.topk(avg_act, k).indices.tolist())
        top_neurons[digit][layer]['least_active'] = sorted(torch.topk(-avg_act, k).indices.tolist())

# Display results
for digit in range(10):
    print(f"Digit {digit}:")
    for layer in layer_names:
        print(f"  Layer {layer}:")
        print(f"    Most active neurons: {top_neurons[digit][layer]['most_active']}")
        print(f"    Least active neurons: {top_neurons[digit][layer]['least_active']}")