In [1]:
import torch

# Evaluate vision SAE

In [2]:
# Load SAE

# Download a pretrained SAE from the Hugging Face Hub and move it to the selected device
from huggingface_hub import hf_hub_download, list_repo_files
from vit_prisma.sae import SparseAutoencoder

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

def load_sae(repo_id, file_name, config_name):
    """Download a pretrained SAE from the Hugging Face Hub and load it.

    Args:
        repo_id: Hugging Face Hub repository that hosts the SAE.
        file_name: Name of the weights file inside that repository.
        config_name: Name of the config file inside that repository.

    Returns:
        The object returned by ``SparseAutoencoder.load_from_pretrained`` for
        the downloaded weights file.

    NOTE(review): the output saved with this notebook shows the load step
    raising ``OSError`` on a machine without CUDA (the checkpoint holds CUDA
    tensors). Confirm whether the loader offers a way to remap to CPU.
    """
    # Fetch the weights first and remember where they landed locally.
    sae_path = hf_hub_download(repo_id=repo_id, filename=file_name)
    # The config's local path is not used directly; it is downloaded so the file
    # is present locally (the loader is expected to pick up config.json from the
    # weights' folder and build a VisionSAERunnerConfig from it).
    hf_hub_download(repo_id=repo_id, filename=config_name)

    print(f"Loading SAE from {sae_path}...")
    return SparseAutoencoder.load_from_pretrained(sae_path)

# Hugging Face Hub repository of the SAE to evaluate; change repo_id to switch SAEs.
# See /docs for a list of available SAEs.
repo_id = "Prisma-Multimodal/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-1e-05" 
file_name = "weights.pt"  # Usually weights.pt, but the name can vary slightly; check the chosen HF repo for the exact file name.
config_name = "config.json"  # Config file fetched from the same repository as the weights.

# Download and instantiate the SAE, then move it to DEVICE.
sae = load_sae(repo_id, file_name, config_name)
sae.to(DEVICE)  # As the cell's last expression, the result is also displayed in the notebook.

  from .autonotebook import tqdm as notebook_tqdm




This means that static image generation (e.g. `fig.write_image()`) will not work.

Please upgrade Plotly to version 6.1.1 or greater, or downgrade Kaleido to version 0.2.1.


2026-02-16 21:31:19 DEBUG:httpcore.connection: connect_tcp.started host='huggingface.co' port=443 local_address=None timeout=10 socket_options=None
2026-02-16 21:31:19 DEBUG:httpcore.connection: connect_tcp.complete return_value=<httpcore._backends.sync.SyncStream object at 0x724b55db2cf0>
2026-02-16 21:31:19 DEBUG:httpcore.connection: start_tls.started ssl_context=<ssl.SSLContext object at 0x724bc97ded50> server_hostname='huggingface.co' timeout=10
2026-02-16 21:31:19 DEBUG:httpcore.connection: start_tls.complete return_value=<httpcore._backends.sync.SyncStream object at 0x724b55e0c410>
2026-02-16 21:31:19 DEBUG:httpcore.http11: send_request_headers.started request=<Request [b'HEAD']>
2026-02-16 21:31:19 DEBUG:httpcore.http11: send_request_headers.complete
202

Loading SAE from /mnt/nw/home/m.yu/.cache/huggingface/hub/models--Prisma-Multimodal--sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-1e-05/snapshots/1e0dcc98c63aed3ce28f0387027cf8df6784fef1/weights.pt...


OSError: Error loading the state dictionary from .pt file: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [3]:
# Load model
from vit_prisma.models.model_loader import load_hooked_model


# Read the name of the base model from the SAE's config.
model_name = sae.cfg.model_name
# Build the hooked model for that name.
model = load_hooked_model(model_name)
model.to(DEVICE) # Move to device; as the cell's last expression, the result is also displayed.

2025-06-01 23:27:15 INFO:root: Model 'open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K' is supported and passes tests.
2025-06-01 23:27:15 INFO:root: model_id download_pretrained_from_hf: laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
2025-06-01 23:27:15 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/resolve/main/open_clip_config.json HTTP/1.1" 200 0
2025-06-01 23:27:15 INFO:root: model_id download_pretrained_from_hf: laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
2025-06-01 23:27:15 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
2025-06-01 23:27:16 INFO:root: visual projection shape: torch.Size([768, 512])
2025-06-01 23:27:16 INFO:root: Loaded pretrained model open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K into HookedTransformer


HookedViT(
  (embed): PatchEmbedding(
    (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
  )
  (hook_embed): HookPoint()
  (pos_embed): PosEmbedding()
  (hook_pos_embed): HookPoint()
  (hook_full_embed): HookPoint()
  (ln_pre): LayerNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (hook_ln_pre): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): Ho

In [5]:
# Evaluate vision SAE
from vit_prisma.sae import SparsecoderEval
from vit_prisma.transforms import get_clip_val_transforms

import torchvision

import os

# Root of the ImageNet validation split (ImageFolder expects one sub-directory
# per class). Set the IMAGENET_VAL_DIR environment variable to point at your own
# copy; the original cluster path is kept as the fallback so existing setups
# keep working unchanged.
imagenet_validation_path = os.environ.get(
    "IMAGENET_VAL_DIR",
    '/network/scratch/s/sonia.joseph/datasets/kaggle_datasets/ILSVRC/Data/CLS-LOC/val',
)
# Fail early with an actionable message instead of a bare scandir error.
if not os.path.isdir(imagenet_validation_path):
    raise FileNotFoundError(
        f"ImageNet validation directory not found: {imagenet_validation_path!r}. "
        "Set the IMAGENET_VAL_DIR environment variable to the folder that contains "
        "the per-class validation sub-directories."
    )

# Value passed as `max_images` to the evaluator; named so it is easy to tune.
MAX_EVAL_IMAGES = 1000

# CLIP validation preprocessing applied to every image as it is loaded.
data_transforms = get_clip_val_transforms()
eval_dataset = torchvision.datasets.ImageFolder(imagenet_validation_path, transform=data_transforms)

# Evaluate the SAE against the hooked model on the dataset above.
eval_runner = SparsecoderEval(sae, model, eval_dataset, max_images=MAX_EVAL_IMAGES)

# is_clip=True because the base model loaded above is a CLIP vision model.
metrics = eval_runner.run_eval(is_clip=True)

2025-06-01 23:28:04 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
2025-06-01 23:28:04 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/resolve/main/open_clip_config.json HTTP/1.1" 200 0
2025-06-01 23:28:04 INFO:root: Loaded hf-hub:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K model config.
2025-06-01 23:28:05 INFO:root: Loading pretrained hf-hub:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K weights (/home/mila/s/sonia.joseph/.cache/huggingface/hub/models--laion--CLIP-ViT-B-32-DataComp.XL-s13B-b90K/snapshots/f0e2ffa09cbadab3db6a261ec1ec56407ce42912/open_clip_pytorch_model.bin).
2025-06-01 23:28:06 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/resolve/main/open_clip_config.json HTTP/1.1" 200 0
2025-06-01 23:28:13 INFO:vit_prisma.sae.evals.model_eval: Starting Sparse