## SAE feature extractor on DINOv2 patch activations

This notebook demonstrates a simple end-to-end pipeline:

- Load a DINOv2 ViT backbone **from a local `dinov2/` checkout** via `torch.hub.load(..., source="local")`
- Run a small batch of images through the model and extract **patch-token activations**
- Train a small **Top-k Sparse Autoencoder (SAE)** using `overcomplete`
- Use the SAE codes as a sparse feature representation

### Assumptions

- A `dinov2` checkout sits next to this repo, i.e. at `../dinov2` relative to this repo's root.
- You have installed this repo's requirements (see `requirements.txt`).

In [1]:
from __future__ import annotations

from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision import transforms

from overcomplete.sae import TopKSAE, train_sae


def get_projects_dir(start: Path | None = None) -> Path:
    """Locate the "projects" directory that holds this repo and the sibling ``dinov2`` checkout.

    Args:
        start: Directory to begin the search from (``str`` or ``Path``). Defaults to
            the current working directory, which is where the notebook kernel runs.

    Returns:
        The resolved projects directory.

    Raises:
        RuntimeError: If walking up from ``start`` to the filesystem root finds no
            directory matching either recognised layout.
    """
    here = (Path.cwd() if start is None else Path(start)).resolve()

    # Fast path: we are sitting in the repo root itself.
    if here.name == "dense_sparse_extractor":
        return here.parent

    # Fallback: walk upward until one of the two known layouts matches.
    for candidate in (here, *here.parents):
        # Both checkouts side by side -> `candidate` is the projects dir.
        if (candidate / "dense_sparse_extractor").is_dir() and (candidate / "dinov2").is_dir():
            return candidate
        # Looks like the repo root (pyproject + package dir) -> its parent is the projects dir.
        if (candidate / "pyproject.toml").is_file() and (candidate / "dense_sparse_extractor").is_dir():
            return candidate.parent

    raise RuntimeError(f"Could not infer projects dir from {here}. Run from the repo root.")


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's global RNG so the SAE's random initialization and the shuffled
# activation batches repeat across Restart & Run All. Nondeterministic CUDA
# kernels may still introduce small run-to-run differences.
SEED = 0
torch.manual_seed(SEED)

device

'cuda'

In [None]:
# Load DINOv2 from a local checkout via torch.hub.
# NOTE: in the original export this whole cell had collapsed onto one line starting
# with "#", so it ran as a single comment and defined nothing.
BACKBONE = "dinov2_vits14"  # small backbone for the demo
IMAGE_SIZE = 224
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

projects_dir = get_projects_dir()
dinov2_dir = projects_dir / "dinov2"
if not dinov2_dir.is_dir():
    raise FileNotFoundError(f"Expected dinov2 checkout at: {dinov2_dir}")

# If you don't have network access (or weights aren't cached), fall back to
# pretrained=False. Randomly initialised features still exercise the pipeline,
# but the resulting SAE features are not meaningful.
try:
    model = torch.hub.load(str(dinov2_dir), BACKBONE, source="local", pretrained=True)
except Exception as e:
    print("[warn] Failed to load pretrained weights; falling back to pretrained=False.")
    print("       Error:", repr(e))
    model = torch.hub.load(str(dinov2_dir), BACKBONE, source="local", pretrained=False)
model.eval()
model.to(device)

# Fake images (no downloads). Normalize like ImageNet.
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
dataset = FakeData(size=32, image_size=(3, IMAGE_SIZE, IMAGE_SIZE), num_classes=10, transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)

# One batch is enough for the toy demo; labels are unused downstream.
batch = next(iter(loader))
images, labels = batch
images = images.to(device)
images.shape

In [None]:
# Pull DINOv2's normalized patch tokens for the batch.
# `forward_features` returns a dict; its "x_norm_patchtokens" entry holds the
# (B, N, D) patch tokens. Autograd is disabled because we only read features.
with torch.no_grad():
    feats = model.forward_features(images)
patch_tokens = feats["x_norm_patchtokens"]  # (B, N, D)

# Merge the batch and token axes so every patch is one D-dim training sample.
B, N, D = patch_tokens.shape
activations = patch_tokens.flatten(start_dim=0, end_dim=1).contiguous()

patch_tokens.shape, activations.shape

In [None]:
# Train a small TopKSAE on these activations (toy demo settings).
# Hyperparameters are named here so they are easy to find and tune.
NB_CONCEPTS = 512      # dictionary size (number of SAE latents)
TOP_K = 16             # latents kept per sample by the Top-k encoder
SAE_BATCH_SIZE = 1024  # patch tokens per optimization step
SAE_LR = 5e-4
SAE_EPOCHS = 2

sae = TopKSAE(input_shape=D, nb_concepts=NB_CONCEPTS, top_k=TOP_K, device=device)

# `activations` is a (B*N, D) tensor; a DataLoader over it yields (batch, D) slices.
act_loader = DataLoader(activations, batch_size=SAE_BATCH_SIZE, shuffle=True)
opt = torch.optim.Adam(sae.parameters(), lr=SAE_LR)


def mse_criterion(x, x_hat, pre_codes, codes, dictionary):
    """Reconstruction loss: mean squared error between the input and the reconstruction.

    The extra arguments (pre_codes, codes, dictionary) are accepted because
    `train_sae` calls the criterion with that signature. This loss ignores them
    and adds no sparsity penalty, since the Top-k encoder already limits how
    many latents are active.
    """
    return (x - x_hat).pow(2).mean()


logs = train_sae(sae, act_loader, mse_criterion, opt, nb_epochs=SAE_EPOCHS, device=device)

# Use the SAE as a sparse feature extractor.
sae.eval()
with torch.no_grad():
    _, codes = sae.encode(activations)  # (B*N, NB_CONCEPTS); at most TOP_K non-zeros per row

# Top-k keeps at most TOP_K latents per row (some kept values may themselves be 0).
nonzeros_per_row = (codes != 0).float().sum(dim=-1)
assert nonzeros_per_row.max().item() <= TOP_K, "Top-k SAE produced more than TOP_K active latents"
codes.shape, nonzeros_per_row.mean().item()