# Supervised Autoencoder on CIFAR-10 — Notebook Template

**Goal:** Train a supervised autoencoder (reconstruction + classification) on CIFAR-10, with clear verification at each step.

> Fill in the TODOs step-by-step. Each section contains a brief checklist and sanity checks.

_Generated: 2025-10-21T11:48:40.685883Z_

## 0) Environment & Reproducibility
- [ ] Select device (CPU/GPU)
- [ ] Set random seeds
- [ ] (Optional) Enable cudnn benchmark

**Verify:** print device; run a tiny tensor op.

```python
# TODO: imports
import os
import math
import random
from pathlib import Path
import numpy as np
import torch

# TODO: set device and seeds
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
```

Output:

```text
Device: cuda
```
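The checklist also asks for a tiny tensor op and, optionally, cuDNN benchmark mode. A minimal sketch follows (it re-derives `device` so it runs on its own). Enabling `torch.backends.cudnn.benchmark` is an optional speed-up that trades away run-to-run bit-exactness, so leave it off if strict reproducibility matters more than speed.

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Optional: let cuDNN autotune conv algorithms for the fixed 32x32 input size.
    # Faster, but results may differ slightly between runs.
    torch.backends.cudnn.benchmark = True

# Tiny sanity op on the selected device: (2x3 ones) @ (3x2 ones)
a = torch.ones(2, 3, device=device)
b = torch.ones(3, 2, device=device)
c = a @ b                       # each entry is a sum of three 1s
print(c.device, c.tolist())     # values: [[3.0, 3.0], [3.0, 3.0]]
```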


## 1) Config
- [ ] Define a simple config dict (dataset paths, batch size, latent_dim, λ, lr)
- [ ] Print config to confirm

**Tip:** Start simple; you can move this to YAML later.

```python
# TODO: create a minimal config
CONFIG = {
    'data_root': '../data_02',
    'batch_size': 128,
    'num_workers': 4,
    'img_size': 32,
    'latent_dim': 64,          # try 32/128 later
    'lambda_recon': 0.25,      # weight for reconstruction loss
    'lr': 1e-3,
    'weight_decay': 1e-4,
    'epochs': 50,
}
from pprint import pprint
pprint(CONFIG)
```

Output:

```text
{'batch_size': 128,
 'data_root': '../data_02',
 'epochs': 50,
 'img_size': 32,
 'lambda_recon': 0.25,
 'latent_dim': 64,
 'lr': 0.001,
 'num_workers': 4,
 'weight_decay': 0.0001}
```


## 2) Data Pipeline (CIFAR-10)
- [ ] Define train/val/test transforms
- [ ] Build DataLoaders
- [ ] Print batch shapes and pixel ranges

**Verify:** images have shape `(B, 3, 32, 32)`, labels have shape `(B,)`, and pixel values lie in [0, 1] (with `ToTensor` and no normalization).

```python
# TODO: implement get_dataloaders() (get_transforms below is a working starting point)
from typing import Tuple
import torchvision.transforms as T

def get_transforms(augment: bool = True):
    # Keep images at 32x32 so they match the (B, 3, 32, 32) decoder output.
    # RandomCrop with padding + horizontal flip is a common CIFAR-10 augmentation.
    aug = [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()] if augment else []
    # ToTensor scales pixels to [0, 1]; no Normalize, so a sigmoid decoder can match the targets.
    train_tfms = T.Compose(aug + [T.ToTensor()])
    test_tfms = T.Compose([T.ToTensor()])
    return train_tfms, test_tfms

def get_dataloaders(cfg) -> Tuple[object, object, object]:
    # TODO: use torchvision.datasets.CIFAR10 and DataLoader
    # return train_loader, val_loader, test_loader
    raise NotImplementedError('Create CIFAR-10 DataLoaders here')

# TODO: after implementing, run a quick sanity batch
# Example:
# train_loader, val_loader, test_loader = get_dataloaders(CONFIG)
# images, labels = next(iter(train_loader))
# print(images.shape, labels.shape, images.min().item(), images.max().item())
```
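One way to fill in `get_dataloaders`, sketched under assumptions: a validation split is carved out of the 50,000 CIFAR-10 training images (`val_fraction` is a hypothetical parameter not present in `CONFIG`), and validation reads a non-augmented view of the same images. The helper names `split_train_val` and `build_loaders` are illustrative, not part of torchvision.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset

def split_train_val(train_aug_ds: Dataset, train_plain_ds: Dataset,
                    val_fraction: float = 0.1, seed: int = 42):
    """Split one training set into disjoint train/val subsets.

    Two views of the same images are passed in: one with augmentation (train)
    and one without (val), so validation images are never cropped or flipped.
    Both views must have the same length and ordering.
    """
    assert len(train_aug_ds) == len(train_plain_ds)
    n = len(train_aug_ds)
    n_val = int(round(n * val_fraction))
    g = torch.Generator().manual_seed(seed)          # reproducible split
    perm = torch.randperm(n, generator=g).tolist()
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    return Subset(train_aug_ds, train_idx), Subset(train_plain_ds, val_idx)

def build_loaders(train_ds, val_ds, test_ds, cfg):
    common = dict(batch_size=cfg['batch_size'], num_workers=cfg['num_workers'],
                  pin_memory=torch.cuda.is_available())
    train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **common)
    val_loader = DataLoader(val_ds, shuffle=False, **common)
    test_loader = DataLoader(test_ds, shuffle=False, **common)
    return train_loader, val_loader, test_loader

# Usage with CIFAR-10 (assumes get_transforms returns (train_tfms, test_tfms); downloads on first run):
# from torchvision.datasets import CIFAR10
# train_tfms, test_tfms = get_transforms(augment=True)
# root = CONFIG['data_root']
# train_aug = CIFAR10(root, train=True, download=True, transform=train_tfms)
# train_plain = CIFAR10(root, train=True, download=True, transform=test_tfms)
# test_ds = CIFAR10(root, train=False, download=True, transform=test_tfms)
# train_ds, val_ds = split_train_val(train_aug, train_plain, val_fraction=0.1, seed=SEED)
# train_loader, val_loader, test_loader = build_loaders(train_ds, val_ds, test_ds, CONFIG)
```

Splitting indices, rather than calling `random_split` on a single augmented dataset, keeps augmentation out of validation. `drop_last=True` on the training loader only avoids a very small final batch whose BatchNorm statistics would be noisy.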

## 3) Model — Encoder, Decoder, Classifier Head, SupervisedAE
- [ ] Implement `Encoder` → `z`
- [ ] Implement `Decoder` ← `z`
- [ ] Implement `ClassifierHead` (MLP on `z` → 10 logits)
- [ ] Implement `SupervisedAE.forward(x) → (z, x_hat, logits)`

**Verify:** Check shapes for a dummy batch.

```python
# TODO: define the model classes (use small CNN blocks suitable for CIFAR-10)
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        # TODO: conv → bn → relu blocks with downsampling to small spatial map
        # self.feature_extractor = ...
        # self.to_latent = nn.Linear(flat_dim, latent_dim)
        raise NotImplementedError('Build the Encoder')
    def forward(self, x):
        # TODO: return z of shape (B, latent_dim)
        raise NotImplementedError

class Decoder(nn.Module):
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        # TODO: linear → reshape → convtranspose blocks back to (B,3,32,32)
        raise NotImplementedError('Build the Decoder')
    def forward(self, z):
        # TODO: return x_hat in [0,1] if using BCE, or unbounded for MSE
        raise NotImplementedError

class ClassifierHead(nn.Module):
    def __init__(self, latent_dim: int = 64, num_classes: int = 10):
        super().__init__()
        # TODO: MLP on z → logits
        raise NotImplementedError('Build the ClassifierHead')
    def forward(self, z):
        raise NotImplementedError

class SupervisedAE(nn.Module):
    def __init__(self, latent_dim: int = 64, num_classes: int = 10):
        super().__init__()
        # TODO: compose encoder, decoder, head
        raise NotImplementedError('Assemble SupervisedAE')
    def forward(self, x):
        # TODO: return z, x_hat, logits
        raise NotImplementedError

# TODO: sanity check shapes with a dummy input once implemented
# x = torch.randn(8, 3, CONFIG['img_size'], CONFIG['img_size']).to(device)
# model = SupervisedAE(CONFIG['latent_dim']).to(device)
# with torch.no_grad():
#     z, x_hat, logits = model(x)
# print(z.shape, x_hat.shape, logits.shape)
```
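A possible way to fill in the four classes, as a sketch only. The architecture choices are assumptions, not the template's prescription: three stride-2 conv blocks (32 to 16 to 8 to 4), channel widths 32/64/128, a mirrored transposed-conv decoder with a sigmoid output (so targets must stay in [0, 1], i.e. no `Normalize`), and a 128-unit MLP head with dropout. `conv_block`, `deconv_block`, `hidden`, and `p_drop` are hypothetical helpers/parameters.

```python
import torch
import torch.nn as nn

def conv_block(c_in, c_out):
    # 3x3 conv, stride 2, padding 1 halves the spatial size: 32 -> 16 -> 8 -> 4
    return nn.Sequential(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1),
                         nn.BatchNorm2d(c_out), nn.ReLU(inplace=True))

def deconv_block(c_in, c_out):
    # 4x4 transposed conv, stride 2, padding 1 doubles the spatial size
    return nn.Sequential(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                         nn.BatchNorm2d(c_out), nn.ReLU(inplace=True))

class Encoder(nn.Module):
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            conv_block(3, 32), conv_block(32, 64), conv_block(64, 128))
        self.to_latent = nn.Linear(128 * 4 * 4, latent_dim)

    def forward(self, x):
        h = self.feature_extractor(x)        # (B, 128, 4, 4)
        return self.to_latent(h.flatten(1))  # (B, latent_dim)

class Decoder(nn.Module):
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        self.from_latent = nn.Linear(latent_dim, 128 * 4 * 4)
        self.net = nn.Sequential(
            deconv_block(128, 64),                              # 4 -> 8
            deconv_block(64, 32),                               # 8 -> 16
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 16 -> 32
            nn.Sigmoid(),                                       # pixels in [0, 1]
        )

    def forward(self, z):
        h = self.from_latent(z).view(-1, 128, 4, 4)
        return self.net(h)                   # (B, 3, 32, 32)

class ClassifierHead(nn.Module):
    def __init__(self, latent_dim: int = 64, num_classes: int = 10,
                 hidden: int = 128, p_drop: float = 0.2):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ReLU(inplace=True),
                                 nn.Dropout(p_drop), nn.Linear(hidden, num_classes))

    def forward(self, z):
        return self.mlp(z)                   # raw logits; cross-entropy applies log-softmax

class SupervisedAE(nn.Module):
    def __init__(self, latent_dim: int = 64, num_classes: int = 10):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.head = ClassifierHead(latent_dim, num_classes)

    def forward(self, x):
        z = self.encoder(x)
        return z, self.decoder(z), self.head(z)
```

With `latent_dim=64` and a batch of 8, the shape check in the template should print `(8, 64)`, `(8, 3, 32, 32)` and `(8, 10)`. Both classification and reconstruction gradients flow into the shared encoder, which is what makes the latent space "supervised".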

## 4) Losses
- [ ] Reconstruction loss (MSE or BCE)
- [ ] Classification loss (CrossEntropy)
- [ ] Total loss = CE + λ * Recon (+ optional L2 on z)

**Verify:** scalar outputs; grads flow to encoder/decoder/head.

```python
# TODO: implement loss functions
def reconstruction_loss(x_hat, x, loss_type='mse'):
    # TODO: return mse or bce
    raise NotImplementedError

def classification_loss(logits, y):
    # TODO: return cross entropy
    raise NotImplementedError

def total_loss(logits, x_hat, y, x, lambda_recon: float = 0.25, loss_type='mse'):
    # TODO: combine
    raise NotImplementedError
```
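A minimal sketch of the three losses using `torch.nn.functional`, both mean-reduced. One assumption differs from the stub: `total_loss` here returns a tuple `(total, ce, recon)` with the two components detached, so the training loop can log them without extra passes. The optional `label_smoothing` argument (supported by `F.cross_entropy`) is an addition tied to the "label smoothing" idea in Section 9.

```python
import torch
import torch.nn.functional as F

def reconstruction_loss(x_hat, x, loss_type='mse'):
    if loss_type == 'mse':
        return F.mse_loss(x_hat, x)
    if loss_type == 'bce':
        # Requires both x_hat and x in [0, 1] (sigmoid decoder, no Normalize).
        return F.binary_cross_entropy(x_hat, x)
    raise ValueError(f'unknown loss_type: {loss_type}')

def classification_loss(logits, y, label_smoothing: float = 0.0):
    # logits: (B, num_classes) raw scores; y: (B,) integer class labels
    return F.cross_entropy(logits, y, label_smoothing=label_smoothing)

def total_loss(logits, x_hat, y, x, lambda_recon: float = 0.25, loss_type='mse'):
    ce = classification_loss(logits, y)
    rec = reconstruction_loss(x_hat, x, loss_type)
    total = ce + lambda_recon * rec          # L = CE + lambda * Recon
    return total, ce.detach(), rec.detach()  # components detached for logging
```

Because the two terms can differ in magnitude (per-pixel MSE on [0, 1] images versus cross-entropy over 10 classes), λ also acts as a scale balance, which is one reason the λ sweep in Section 7 is worth running.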

## 5) Training Loop
- [ ] Build optimizer + (optional) scheduler
- [ ] Train for N epochs, log running losses & accuracy
- [ ] Validate each epoch; keep best checkpoint

**Verify:** both CE and recon losses decrease; accuracy is clearly above chance (10% for 10 balanced classes).

```python
# TODO: implement train_one_epoch, evaluate, fit
def train_one_epoch(model, loader, optimizer, cfg):
    # TODO: loop over batches; compute losses; backprop; return logs
    raise NotImplementedError

def evaluate(model, loader, cfg):
    # TODO: compute val accuracy and recon loss
    raise NotImplementedError

def fit(model, train_loader, val_loader, optimizer, cfg):
    # TODO: manage epochs, call train_one_epoch/evaluate, checkpoint best model
    raise NotImplementedError

# TODO: run training once everything above is ready
# model = SupervisedAE(CONFIG['latent_dim']).to(device)
# opt = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])
# history = fit(model, train_loader, val_loader, opt, CONFIG)
```
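A compact sketch of the loop, assuming the model returns `(z, x_hat, logits)` as in Section 3, MSE reconstruction, and that `fit` receives the optimizer explicitly (so the optimizer built in the usage lines is actually used). `_run_epoch` and `ckpt_path` are hypothetical names; the best checkpoint is chosen by validation accuracy, as the checklist asks.

```python
import torch
import torch.nn.functional as F

def _run_epoch(model, loader, cfg, optimizer=None):
    """One pass over `loader`; trains if an optimizer is given, else evaluates."""
    training = optimizer is not None
    model.train(training)
    device = next(model.parameters()).device
    tot = {'ce': 0.0, 'recon': 0.0, 'correct': 0, 'n': 0}
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            _, x_hat, logits = model(x)
            ce = F.cross_entropy(logits, y)
            rec = F.mse_loss(x_hat, x)
            if training:
                optimizer.zero_grad(set_to_none=True)
                (ce + cfg['lambda_recon'] * rec).backward()
                optimizer.step()
            b = y.size(0)
            tot['ce'] += ce.item() * b            # sample-weighted running sums
            tot['recon'] += rec.item() * b
            tot['correct'] += (logits.argmax(dim=1) == y).sum().item()
            tot['n'] += b
    n = max(tot['n'], 1)
    return {'ce': tot['ce'] / n, 'recon': tot['recon'] / n, 'acc': tot['correct'] / n}

def train_one_epoch(model, loader, optimizer, cfg):
    return _run_epoch(model, loader, cfg, optimizer)

def evaluate(model, loader, cfg):
    return _run_epoch(model, loader, cfg, optimizer=None)

def fit(model, train_loader, val_loader, optimizer, cfg, ckpt_path='best.pt'):
    history, best_acc = [], -1.0
    for epoch in range(1, cfg['epochs'] + 1):
        tr = train_one_epoch(model, train_loader, optimizer, cfg)
        va = evaluate(model, val_loader, cfg)
        history.append({'epoch': epoch,
                        **{f'train_{k}': v for k, v in tr.items()},
                        **{f'val_{k}': v for k, v in va.items()}})
        if va['acc'] > best_acc:                  # keep the best model by val accuracy
            best_acc = va['acc']
            torch.save({'epoch': epoch, 'model': model.state_dict(),
                        'val_acc': best_acc}, ckpt_path)
        print(f"epoch {epoch:3d} | train ce {tr['ce']:.3f} recon {tr['recon']:.4f} "
              f"acc {tr['acc']:.3f} | val acc {va['acc']:.3f}")
    return history
```

Sharing one `_run_epoch` between training and evaluation keeps the metric definitions identical for both; `torch.set_grad_enabled(False)` plus `model.train(False)` gives the usual no-grad, eval-mode behaviour (dropout off, BatchNorm running stats) for validation.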

## 6) Evaluation & Visual Checks
- [ ] Test accuracy (top-1, optionally top-5)
- [ ] Reconstruction quality (grid of originals vs reconstructions)
- [ ] Extract latents `z` and visualize (PCA/UMAP/t-SNE)

**Verify:** expect the supervised AE to show more clearly separated class clusters than the AE-only baseline.

```python
# TODO: implement helper functions
def test_accuracy(model, loader):
    # TODO: compute top-1 accuracy
    raise NotImplementedError

def show_reconstructions(model, loader, n=8):
    # TODO: create and display/save a grid of reconstructions
    raise NotImplementedError

def extract_latents(model, loader):
    # TODO: concatenate z and labels
    raise NotImplementedError

def plot_latent_2d(Z, y, method='pca'):
    # TODO: reduce to 2D and plot (matplotlib)
    raise NotImplementedError
```

## 7) Ablations & Baselines
- [ ] Classifier-only baseline (λ=0; no decoder)
- [ ] AE-only baseline (train AE, then MLP on frozen z)
- [ ] Supervised AE (main), sweep λ ∈ {0.1, 0.25, 0.5, 1.0}, latent_dim ∈ {32,64,128}

**Record:** accuracy, recon MSE, and latent plots per setting.
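The sweep and the two baselines can be expressed as config variants. This sketch assumes the `CONFIG` dict from Section 1; `make_sweep`, `classifier_only_cfg`, `freeze`, and the `run_name` key are hypothetical. Note that with `lambda_recon=0.0` the decoder still runs but receives no gradient; skipping the decoder entirely matches the "no decoder" baseline more strictly and saves compute.

```python
import copy
import itertools
import torch.nn as nn

def make_sweep(base_cfg, lambdas=(0.1, 0.25, 0.5, 1.0), latent_dims=(32, 64, 128)):
    """One config per (lambda, latent_dim) pair; the base config is not mutated."""
    runs = []
    for lam, d in itertools.product(lambdas, latent_dims):
        cfg = copy.deepcopy(base_cfg)
        cfg.update(lambda_recon=lam, latent_dim=d, run_name=f'sae_lam{lam}_z{d}')
        runs.append(cfg)
    return runs

def classifier_only_cfg(base_cfg):
    # Baseline: no reconstruction term (lambda = 0).
    cfg = copy.deepcopy(base_cfg)
    cfg.update(lambda_recon=0.0, run_name=f"clf_only_z{cfg['latent_dim']}")
    return cfg

def freeze(module: nn.Module) -> nn.Module:
    # AE-only baseline: after training the autoencoder, freeze the encoder and
    # train only an MLP head on its z. eval() fixes BatchNorm statistics; calling
    # model.train() later flips it back, so re-apply eval() to the frozen part.
    for p in module.parameters():
        p.requires_grad_(False)
    module.eval()
    return module
```

For the AE-only run, pass only the head's parameters to the optimizer, e.g. `torch.optim.Adam(model.head.parameters(), lr=...)`, so the frozen encoder is never updated.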

## 8) Logging & Checkpoints
- [ ] Save per-epoch metrics (CSV/JSON)
- [ ] Save best model by val accuracy
- [ ] (Optional) TensorBoard

**Verify:** resume training from checkpoint works.
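A minimal sketch of per-epoch CSV logging and resumable checkpoints with `torch.save`/`torch.load`. The function names are hypothetical. The checkpoint stores the optimizer (and optional scheduler) state alongside the model weights, since resuming with a fresh optimizer would reset Adam's moment estimates.

```python
import csv
import os
import torch

def append_metrics_csv(path, row: dict):
    # Append one row per epoch; write the header only when the file is new.
    new_file = not os.path.exists(path)
    with open(path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if new_file:
            writer.writeheader()
        writer.writerow(row)

def save_checkpoint(path, model, optimizer, epoch, best_acc, scheduler=None):
    state = {'epoch': epoch, 'best_acc': best_acc,
             'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
    if scheduler is not None:
        state['scheduler'] = scheduler.state_dict()
    torch.save(state, path)

def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location='cpu'):
    state = torch.load(path, map_location=map_location)
    model.load_state_dict(state['model'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer'])
    if scheduler is not None and 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])
    return state['epoch'] + 1, state['best_acc']   # resume at the next epoch
```

To check that resuming works, save after epoch k, reload into a freshly constructed model and optimizer, and confirm the next validation metrics continue from where the log left off.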

## 9) Notes & Next Steps
- Try label smoothing, dropout, or weight decay tweaks
- Try a different reconstruction loss (BCE vs MSE)
- Try training with data augmentation on and off
- Try an OOD score based on distance to class centroids in latent space
- Consider VAE version (KL term) once supervised AE is stable
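A sketch of the centroid-distance OOD idea from the notes above, assuming latents and labels are NumPy arrays (for example from an `extract_latents`-style helper run on the training set). Euclidean distance to the nearest class mean is the simplest variant; a Mahalanobis distance with a shared covariance is a common refinement (not shown). `class_centroids` and `ood_score` are hypothetical names.

```python
import numpy as np

def class_centroids(Z, y, num_classes):
    # Mean latent vector per class: (num_classes, latent_dim)
    return np.stack([Z[y == c].mean(axis=0) for c in range(num_classes)])

def ood_score(Z, centroids):
    # Distance from each latent to its nearest class centroid; a larger score
    # means farther from every training class, i.e. more likely out-of-distribution.
    d = np.linalg.norm(Z[:, None, :] - centroids[None, :, :], axis=-1)  # (N, C)
    return d.min(axis=1)
```

A threshold on this score would typically be chosen on held-out in-distribution data (for example, the 95th percentile of validation scores), then applied to the OOD set.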