<a href="https://colab.research.google.com/github/neelsoumya/llm_projects/blob/main/SAE_linear_probes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stacked autoencoder (SAE) and linear probe

## Installation

In [1]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install tqdm

Looking in indexes: https://download.pytorch.org/whl/cu121


## Code

In [2]:
"""
sae_and_linear_probe.py

PyTorch example:
 - Train a stacked autoencoder (SAE) on MNIST
 - Freeze the encoder and train a linear probe on the encoder embeddings

Requirements:
  pip install torch torchvision tqdm
"""

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import os

# -------------------------
# Model: Stacked Autoencoder
# -------------------------
class StackedAutoencoder(nn.Module):
    """Fully connected autoencoder with a mirrored encoder and decoder.

    The encoder is a stack of ``Linear -> ReLU`` pairs followed by a final
    ``Linear`` projection into the bottleneck (no activation). The decoder
    uses the same hidden widths in reverse order and ends with a ``Linear``
    layer that emits raw logits (intended for ``BCEWithLogitsLoss``).
    """

    def __init__(self, input_dim=28*28, hidden_dims=(1024, 512, 256), bottleneck_dim=64):
        """
        input_dim: int, size of each flattened input vector
        hidden_dims: iterable of ints, e.g. [1024, 512, 256]; encoder hidden
            widths in order (the decoder uses them reversed)
        bottleneck_dim: int, dimensionality of the embedding
        """
        super().__init__()
        # Snapshot into a tuple: the default is immutable (no shared mutable
        # default argument) and any iterable, not just a list, is accepted.
        hidden_dims = tuple(hidden_dims)

        # Encoder: input -> hidden widths -> bottleneck (no activation at the end)
        self.encoder = nn.Sequential(
            *self._mlp_layers(input_dim, hidden_dims, bottleneck_dim)
        )

        # Decoder (mirror): bottleneck -> reversed hidden widths -> input.
        # We'll output logits and use BCEWithLogitsLoss for stability
        self.decoder = nn.Sequential(
            *self._mlp_layers(bottleneck_dim, hidden_dims[::-1], input_dim)
        )

    @staticmethod
    def _mlp_layers(in_dim, hidden_dims, out_dim):
        """Return ``Linear -> ReLU`` pairs for each hidden width, then a final ``Linear``."""
        layers = []
        for h in hidden_dims:
            layers.append(nn.Linear(in_dim, h))
            layers.append(nn.ReLU(inplace=True))
            in_dim = h
        layers.append(nn.Linear(in_dim, out_dim))
        return layers

    def forward(self, x):
        """Return ``(recon_logits, z)`` for a flattened batch ``x`` of shape (B, input_dim)."""
        # x: (B, input_dim) assumed flattened 0..1
        z = self.encoder(x)
        recon_logits = self.decoder(z)
        return recon_logits, z

# -------------------------
# Utility: training loops
# -------------------------
def train_autoencoder(model, dataloader, device, epochs=10, lr=1e-3, save_path=None):
    """Train ``model`` in place to reconstruct its flattened inputs.

    Uses Adam and ``BCEWithLogitsLoss``; the flattened input batch is also the
    reconstruction target.

    model: module whose forward returns ``(recon_logits, embedding)``
    dataloader: yields ``(imgs, labels)`` batches; labels are ignored
    device: device the model and each batch are moved to
    epochs: number of passes over ``dataloader``
    lr: Adam learning rate
    save_path: if truthy, the model's state_dict is written there after every
        epoch (each epoch overwrites the previous file)

    Returns None.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()  # expects logits
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        seen = 0  # samples actually processed this epoch
        pbar = tqdm(dataloader, desc=f"AE Train Epoch {epoch}/{epochs}")
        for imgs, _ in pbar:
            imgs = imgs.to(device)
            imgs = imgs.view(imgs.size(0), -1)  # flatten
            optimizer.zero_grad()
            recon_logits, _ = model(imgs)
            loss = criterion(recon_logits, imgs)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            n_batch = imgs.size(0)
            # Weight by batch size so a smaller final batch is averaged correctly.
            total_loss += batch_loss * n_batch
            seen += n_batch
            pbar.set_postfix(loss=batch_loss)

        # Divide by the samples seen rather than len(dataloader.dataset): the two
        # differ when the loader drops the last batch or draws through a sampler,
        # and an empty loader must not raise ZeroDivisionError.
        avg_loss = total_loss / seen if seen else float("nan")
        print(f"[AE] Epoch {epoch} average loss: {avg_loss:.6f}")
        if save_path:
            torch.save(model.state_dict(), save_path)

def extract_embeddings(model, dataloader, device):
    """Collect the model's embeddings and the matching labels for a whole dataloader.

    The model is moved to ``device`` and put in eval mode; each batch is
    flattened to (B, -1) and passed through it with gradients disabled.
    Returns ``(embeddings, labels)``, each concatenated along dim 0, with the
    embeddings moved to the CPU.
    """
    model.to(device)
    model.eval()

    z_chunks, y_chunks = [], []
    with torch.no_grad():
        for batch, targets in tqdm(dataloader, desc="Extract embeddings"):
            batch = batch.to(device)
            flat = batch.view(batch.size(0), -1)
            # The model returns (reconstruction logits, embedding); only the
            # embedding is kept.
            _recon, code = model(flat)
            z_chunks.append(code.cpu())
            y_chunks.append(targets)

    return torch.cat(z_chunks, dim=0), torch.cat(y_chunks, dim=0)

def train_linear_probe(encoder, train_loader, val_loader, device, embed_dim, num_classes=10,
                       epochs=20, lr=1e-3):
    """
    Fit a single linear classifier on top of a frozen encoder and return it.

    encoder: nn.Module mapping a flattened batch to its embedding tensor; its
        parameters get requires_grad=False and it is switched to eval mode
    train_loader / val_loader: yield (imgs, labels) batches
    embed_dim: width of the encoder output (input size of the probe)
    num_classes: number of output classes of the probe
    epochs, lr: Adam training schedule for the probe

    After every epoch the probe is scored on val_loader and a summary line is printed.
    """
    # Freeze the encoder: only the probe's weights are optimised.
    encoder.to(device)
    encoder.requires_grad_(False)
    encoder.eval()

    # The probe is one linear layer trained with cross-entropy.
    probe = nn.Linear(embed_dim, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(probe.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        # evaluate_probe leaves the probe in eval mode, so re-enable training each epoch.
        probe.train()
        loss_total, hits, seen = 0.0, 0, 0
        pbar = tqdm(train_loader, desc=f"Probe Train Epoch {epoch}/{epochs}")
        for batch, targets in pbar:
            batch = batch.to(device)
            targets = targets.to(device)
            n_batch = batch.size(0)
            flat = batch.view(n_batch, -1)

            # Embeddings come from the frozen encoder, so no graph is needed for them.
            with torch.no_grad():
                z = encoder(flat)

            optimizer.zero_grad()
            logits = probe(z)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            loss_total += batch_loss * n_batch
            hits += (logits.argmax(dim=1) == targets).sum().item()
            seen += n_batch
            pbar.set_postfix(loss=batch_loss, acc=hits / seen)

        train_loss = loss_total / seen
        train_acc = hits / seen
        val_loss, val_acc = evaluate_probe(encoder, probe, val_loader, device, criterion)
        print(f"[Probe] Epoch {epoch} train_loss={train_loss:.4f} train_acc={train_acc:.4f}  "
              f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}")

    return probe

def evaluate_probe(encoder, probe, dataloader, device, criterion=None):
    """Score ``probe`` on top of ``encoder`` over every batch of ``dataloader``.

    encoder: module mapping a flattened batch to embeddings
    probe: module mapping embeddings to class logits
    dataloader: iterable of ``(imgs, labels)`` batches
    device: device each batch is moved to (the modules are not moved here)
    criterion: optional loss taking ``(logits, labels)``; its per-batch value
        is weighted by batch size when averaging

    Returns ``(avg_loss, accuracy)``. ``avg_loss`` is 0.0 when ``criterion`` is
    None. If the dataloader yields no samples, the undefined metrics are
    returned as NaN instead of raising ZeroDivisionError.

    Both modules are left in eval mode.
    """
    encoder.eval()
    probe.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            n_batch = imgs.size(0)
            imgs = imgs.view(n_batch, -1)  # flatten
            logits = probe(encoder(imgs))
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += n_batch
            if criterion is not None:
                loss_sum += criterion(logits, labels).item() * n_batch

    if total == 0:
        # Nothing was evaluated: accuracy (and the loss, if one was requested)
        # is undefined, so report NaN rather than dividing by zero.
        undefined = float("nan")
        return (0.0 if criterion is None else undefined), undefined

    avg_loss = loss_sum / total if criterion is not None else 0.0
    return avg_loss, correct / total

# -------------------------
# Small helper: wrapper to use encoder only
# -------------------------
class EncoderOnly(nn.Module):
    """Thin module exposing just the encoder of a StackedAutoencoder.

    The encoder is shared with the autoencoder it came from (not copied), so
    calling this module yields that autoencoder's embedding for a flattened batch.
    """

    def __init__(self, full_ae: StackedAutoencoder):
        super().__init__()
        # Keep a reference to the autoencoder's own encoder stack.
        self.encoder = full_ae.encoder

    def forward(self, x):
        """Return the embedding for ``x`` (already flattened)."""
        embedding = self.encoder(x)
        return embedding

# -------------------------
# Main example: MNIST
# -------------------------
def main():
    """Run the MNIST example end to end.

    Trains the autoencoder on the full MNIST training set, fits a linear probe
    on its frozen encoder (with 5000 training images held out for validation),
    reports test accuracy, and saves both state_dicts under ``models/``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Using device:", device)

    # Training hyperparameters
    batch_size = 256
    ae_epochs, ae_lr = 10, 1e-3
    probe_epochs, probe_lr = 15, 1e-3

    # Architecture
    input_dim = 28*28
    hidden_dims = [1024, 512, 256]
    bottleneck_dim = 64

    # Datasets (pixel values in [0,1])
    to_tensor = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=to_tensor)
    mnist_test = datasets.MNIST(root="./data", train=False, download=True, transform=to_tensor)

    # Hold out part of the training set as the probe's validation split.
    n_val = 5000
    n_train = len(mnist_train) - n_val
    probe_train_ds, probe_val_ds = torch.utils.data.random_split(mnist_train, [n_train, n_val])

    # The autoencoder sees the whole training set; the probe uses the splits.
    ae_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True,
                           num_workers=2, pin_memory=True)
    probe_loader_opts = dict(batch_size=batch_size, num_workers=2)
    probe_train_loader = DataLoader(probe_train_ds, shuffle=True, **probe_loader_opts)
    probe_val_loader = DataLoader(probe_val_ds, shuffle=False, **probe_loader_opts)
    probe_test_loader = DataLoader(mnist_test, shuffle=False, **probe_loader_opts)

    # Stage 1: train the autoencoder
    ae = StackedAutoencoder(input_dim=input_dim, hidden_dims=hidden_dims,
                            bottleneck_dim=bottleneck_dim)
    print("Training autoencoder...")
    train_autoencoder(ae, ae_loader, device, epochs=ae_epochs, lr=ae_lr, save_path=None)

    # Stage 2: linear probe on the encoder alone (train_linear_probe freezes it)
    encoder_only = EncoderOnly(ae)
    encoder_only.to(device)
    print("Training linear probe (encoder frozen)...")
    probe = train_linear_probe(
        encoder_only,
        probe_train_loader,
        probe_val_loader,
        device,
        embed_dim=bottleneck_dim,
        num_classes=10,
        epochs=probe_epochs,
        lr=probe_lr,
    )

    # Stage 3: score the probe on the test set
    test_criterion = nn.CrossEntropyLoss()
    test_loss, test_acc = evaluate_probe(encoder_only, probe, probe_test_loader, device,
                                         criterion=test_criterion)
    print(f"Final probe test accuracy: {test_acc:.4f}  test loss: {test_loss:.4f}")

    # Persist both sets of weights
    os.makedirs("models", exist_ok=True)
    for module, path in ((ae, "models/sae.pth"), (probe, "models/linear_probe.pth")):
        torch.save(module.state_dict(), path)
    print("Saved models to models/")

# Run the example only when this file is executed as a script (or notebook
# cell), not when it is imported as a module.
if __name__ == "__main__":
    main()


Using device: cpu


100%|██████████| 9.91M/9.91M [00:01<00:00, 4.99MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 132kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.23MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.64MB/s]


Training autoencoder...


AE Train Epoch 1/10: 100%|██████████| 235/235 [00:31<00:00,  7.36it/s, loss=0.183]


[AE] Epoch 1 average loss: 0.239433


AE Train Epoch 2/10: 100%|██████████| 235/235 [00:34<00:00,  6.74it/s, loss=0.132]


[AE] Epoch 2 average loss: 0.154134


AE Train Epoch 3/10: 100%|██████████| 235/235 [00:33<00:00,  7.04it/s, loss=0.126]


[AE] Epoch 3 average loss: 0.130474


AE Train Epoch 4/10: 100%|██████████| 235/235 [00:38<00:00,  6.11it/s, loss=0.118]


[AE] Epoch 4 average loss: 0.120472


AE Train Epoch 5/10: 100%|██████████| 235/235 [00:38<00:00,  6.14it/s, loss=0.12]


[AE] Epoch 5 average loss: 0.113897


AE Train Epoch 6/10: 100%|██████████| 235/235 [00:37<00:00,  6.27it/s, loss=0.105]


[AE] Epoch 6 average loss: 0.108816


AE Train Epoch 7/10: 100%|██████████| 235/235 [00:37<00:00,  6.26it/s, loss=0.102]


[AE] Epoch 7 average loss: 0.104211


AE Train Epoch 8/10: 100%|██████████| 235/235 [00:37<00:00,  6.31it/s, loss=0.0976]


[AE] Epoch 8 average loss: 0.100317


AE Train Epoch 9/10: 100%|██████████| 235/235 [00:37<00:00,  6.25it/s, loss=0.0932]


[AE] Epoch 9 average loss: 0.097454


AE Train Epoch 10/10: 100%|██████████| 235/235 [00:37<00:00,  6.24it/s, loss=0.092]


[AE] Epoch 10 average loss: 0.094845
Training linear probe (encoder frozen)...


Probe Train Epoch 1/15: 100%|██████████| 215/215 [00:10<00:00, 20.16it/s, acc=0.75, loss=0.409]


[Probe] Epoch 1 train_loss=0.9461 train_acc=0.7500  val_loss=0.4379 val_acc=0.9018


Probe Train Epoch 2/15: 100%|██████████| 215/215 [00:11<00:00, 19.36it/s, acc=0.908, loss=0.28]


[Probe] Epoch 2 train_loss=0.3682 train_acc=0.9081  val_loss=0.3213 val_acc=0.9148


Probe Train Epoch 3/15: 100%|██████████| 215/215 [00:13<00:00, 15.92it/s, acc=0.917, loss=0.305]


[Probe] Epoch 3 train_loss=0.3003 train_acc=0.9172  val_loss=0.2825 val_acc=0.9202


Probe Train Epoch 4/15: 100%|██████████| 215/215 [00:10<00:00, 19.94it/s, acc=0.921, loss=0.26]


[Probe] Epoch 4 train_loss=0.2720 train_acc=0.9213  val_loss=0.2635 val_acc=0.9232


Probe Train Epoch 5/15: 100%|██████████| 215/215 [00:10<00:00, 20.50it/s, acc=0.924, loss=0.241]


[Probe] Epoch 5 train_loss=0.2564 train_acc=0.9239  val_loss=0.2522 val_acc=0.9246


Probe Train Epoch 6/15: 100%|██████████| 215/215 [00:09<00:00, 21.82it/s, acc=0.926, loss=0.246]


[Probe] Epoch 6 train_loss=0.2465 train_acc=0.9258  val_loss=0.2442 val_acc=0.9256


Probe Train Epoch 7/15: 100%|██████████| 215/215 [00:10<00:00, 21.37it/s, acc=0.927, loss=0.167]


[Probe] Epoch 7 train_loss=0.2400 train_acc=0.9273  val_loss=0.2399 val_acc=0.9266


Probe Train Epoch 8/15: 100%|██████████| 215/215 [00:10<00:00, 20.31it/s, acc=0.928, loss=0.223]


[Probe] Epoch 8 train_loss=0.2349 train_acc=0.9283  val_loss=0.2366 val_acc=0.9272


Probe Train Epoch 9/15: 100%|██████████| 215/215 [00:10<00:00, 19.74it/s, acc=0.929, loss=0.271]


[Probe] Epoch 9 train_loss=0.2311 train_acc=0.9291  val_loss=0.2327 val_acc=0.9282


Probe Train Epoch 10/15: 100%|██████████| 215/215 [00:10<00:00, 19.83it/s, acc=0.93, loss=0.221]


[Probe] Epoch 10 train_loss=0.2281 train_acc=0.9303  val_loss=0.2299 val_acc=0.9286


Probe Train Epoch 11/15: 100%|██████████| 215/215 [00:10<00:00, 19.87it/s, acc=0.93, loss=0.215]


[Probe] Epoch 11 train_loss=0.2255 train_acc=0.9305  val_loss=0.2283 val_acc=0.9300


Probe Train Epoch 12/15: 100%|██████████| 215/215 [00:10<00:00, 19.97it/s, acc=0.932, loss=0.254]


[Probe] Epoch 12 train_loss=0.2234 train_acc=0.9321  val_loss=0.2280 val_acc=0.9288


Probe Train Epoch 13/15: 100%|██████████| 215/215 [00:10<00:00, 19.86it/s, acc=0.932, loss=0.142]


[Probe] Epoch 13 train_loss=0.2217 train_acc=0.9318  val_loss=0.2234 val_acc=0.9312


Probe Train Epoch 14/15: 100%|██████████| 215/215 [00:10<00:00, 21.46it/s, acc=0.933, loss=0.226]


[Probe] Epoch 14 train_loss=0.2200 train_acc=0.9328  val_loss=0.2235 val_acc=0.9312


Probe Train Epoch 15/15: 100%|██████████| 215/215 [00:09<00:00, 21.83it/s, acc=0.934, loss=0.323]


[Probe] Epoch 15 train_loss=0.2184 train_acc=0.9335  val_loss=0.2229 val_acc=0.9302
Final probe test accuracy: 0.9350  test loss: 0.2223
Saved models to models/
