# Train the SAE and Models for Meta-SAE on decoder WEIGHTS (not activations)

In [1]:
%load_ext autoreload
%autoreload 2

### Env Setting

In [2]:
import sys
sys.path.append('../../')

In [3]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from structs.models import CIFAR100Model

In [4]:
import os 
path = '/Volumes/Ayush_Drive/mnist/'

if os.path.exists(path):
    prefix = path
else:
    prefix = ''

In [5]:
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 10

### Dataset Loading

In [6]:
root=f'{prefix}/data'

# Load CIFAR-100 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load training data
train_dataset = datasets.CIFAR100(
    root=root,
    train=True,
    download=True,
    transform=transform
)

# Load test data
test_dataset = datasets.CIFAR100(
    root=root,
    train=False,
    download=True,
    transform=transform
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False
)

# Load Model

In [7]:

class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )
            
    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=100):
        super(ResNet, self).__init__()
        self.in_planes = 64
        
        # This is specifically modified for CIFAR: 
        # Smaller initial conv with 3x3 kernel instead of 7x7
        # No initial max pooling to preserve spatial information
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        self.linear = nn.Linear(512 * block.expansion, num_classes)
        
        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x, cache_activations=False):
        # If caching is enabled, we'll store activations
        activations = {}
        
        out = torch.relu(self.bn1(self.conv1(x)))
        if cache_activations:
            activations['conv1'] = out.detach().clone()
            
        out = self.layer1(out)
        if cache_activations:
            activations['layer1'] = out.detach().clone()
            
        out = self.layer2(out)
        if cache_activations:
            activations['layer2'] = out.detach().clone()
            
        out = self.layer3(out)
        if cache_activations:
            activations['layer3'] = out.detach().clone()
            
        out = self.layer4(out)
        if cache_activations:
            activations['layer4'] = out.detach().clone()
            
        out = torch.nn.functional.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        
        fc_features = out
        if cache_activations:
            activations['fc_features'] = fc_features.detach().clone()
            
        out = self.linear(out)
        if cache_activations:
            activations['output'] = out.detach().clone()
            
        if cache_activations:
            return out, activations
        return out

# Create ResNet18 model
def ResNet18(num_classes=100):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
#* This is the model for CIFAR100 dataset

In [8]:
# from structs.models import ResNet, ResNet18, BasicBlock

def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('mps'))
    
    # Load model weights
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Load optimizer state
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    # # Load scheduler state
    # scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get the epoch and accuracy information
    epoch = checkpoint['epoch']
    accuracy = checkpoint['accuracy']
    
    # Return important values that might be needed for resuming training
    return model, epoch, accuracy


# Initialize the model
model = ResNet18()
# load in model from checkpoint
checkpoint_path = f'{prefix}/embeddings/cifar100/cifar100_model.pth'
if os.path.exists(checkpoint_path):
    model, epoch, accuracy = load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None)
    print(f"Checkpoint loaded successfully from {checkpoint_path}")
    print(f"Model was trained for {epoch} epochs with accuracy: {accuracy}")
else:
    raise Exception(f"Checkpoint not found at {checkpoint_path}")

model.eval()

Checkpoint loaded successfully from /Volumes/Ayush_Drive/mnist//embeddings/cifar100/cifar100_model.pth
Model was trained for 25 epochs with accuracy: 58.85


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [9]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import collections

def extract_activations(model, dataset, layer_names=None, batch_size=64, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Extract activations from specified layers of a neural network.
    
    Args:
        model: The neural network model (torch.nn.Module)
        dataset: The dataset to extract activations from
        layer_names: List of layer names to extract activations from. If None, extracts from all layers with hooks
        batch_size: Batch size for data loading
        device: Device to run the model on ('cuda' or 'cpu')
        
    Returns:
        Dictionary mapping layer names to their activations (tensors)
    """
    # Move model to device and set to evaluation mode
    model = model.to(device)
    model.eval()
    
    # Create data loader
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    # Dictionary to store activations
    activations = collections.defaultdict(list)
    
    # If the model has a built-in activation cache mechanism (like your MNIST model)
    if hasattr(model, 'clear_cache') and hasattr(model, 'get_cached_activations'):
        model.clear_cache()
        
        # Run forward pass through the model to populate cache
        with torch.no_grad():
            for inputs, _ in tqdm(data_loader, desc="Extracting activations"):
                inputs = inputs.to(device)
                _ = model(inputs)
        
        # Get activations from cache
        if layer_names is None:
            # Get all available layer activations
            layer_names = list(model.activation_cache.keys())
        
        for layer_name in layer_names:
            if layer_name in model.activation_cache:
                layer_activations = model.get_cached_activations(layer_name)
                activations[layer_name] = layer_activations
    
    # For models without built-in caching (like ResNet), use hooks
    else:
        # Storage for hooks
        hooks = []
        
        # Set up forward hooks
        def get_activation(name):
            def hook(module, input, output):
                # For convolutional layers, flatten spatial dimensions
                if len(output.shape) == 4:  # [batch_size, channels, height, width]
                    # Keep batch dimension and flatten the rest
                    flattened = output.view(output.size(0), -1)
                    activations[name].append(flattened.cpu().detach())
                else:
                    activations[name].append(output.cpu().detach())
            return hook
        
        # Register hooks for specified layers
        if layer_names:
            for name, module in model.named_modules():
                if name in layer_names:
                    hooks.append(module.register_forward_hook(get_activation(name)))
        else:
            # If no specific layers are requested, hook into all possible layers
            for name, module in model.named_modules():
                if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.BatchNorm2d)) or "layer" in name:
                    hooks.append(module.register_forward_hook(get_activation(name)))
        
        # Run forward pass
        with torch.no_grad():
            for inputs, _ in tqdm(data_loader, desc="Extracting activations"):
                inputs = inputs.to(device)
                _ = model(inputs)
        
        # Concatenate batched activations and remove hooks
        for name in activations:
            activations[name] = torch.cat(activations[name], dim=0)
        
        for hook in hooks:
            hook.remove()
    
    # Convert from defaultdict to regular dict
    return dict(activations)

concat_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])
# Example usage:
all_activations = extract_activations(model, concat_dataset, layer_names=['layer3', 'layer4'])

Extracting activations: 100%|██████████| 938/938 [25:54<00:00,  1.66s/it]


In [10]:
layer3_activations = all_activations['layer3']
layer4_activations = all_activations['layer4']
# Check the shape of the activations
print(f"Layer 3 activations shape: {layer3_activations.shape}")
print(f"Layer 4 activations shape: {layer4_activations.shape}")

Layer 3 activations shape: torch.Size([60000, 16384])
Layer 4 activations shape: torch.Size([60000, 8192])


In [11]:
torch.save(layer3_activations, f'{prefix}/embeddings/cifar100/layer3_activations.pth')
torch.save(layer4_activations, f'{prefix}/embeddings/cifar100/layer4_activations.pth')

In [12]:
# add loading code only if needed
# Load the activations from file
if layer3_activations is None:
    layer3_activations = torch.load(f'{prefix}/embeddings/cifar100/layer3_activations.pth')
if layer4_activations is None:
    layer4_activations = torch.load(f'{prefix}/embeddings/cifar100/layer4_activations.pth')