In [4]:
import json

In [3]:
def _sae_entry(*, hook_module, k, geom_dec_bias, data_path, chpt_path):
    """Return one registry entry for an SAE checkpoint.

    Every entry uses the same CLIP checkpoint, hook layer (10) and expansion
    factor (64); only the keyword arguments differ between entries. The key
    order is fixed here so the serialized JSON is stable.
    """
    return {
        "clip_model": "openai/clip-vit-base-patch16",
        "hook_layer": 10,
        "hook_module": hook_module,
        "expansion_factor": 64,
        "k": k,
        "geom_dec_bias": geom_dec_bias,
        "data_path": data_path,
        "chpt_path": chpt_path,
    }


# Registry of the available SAE variants, keyed by display name.
models = {
    "Vanilla_(ReLU)_v1": _sae_entry(
        hook_module="self_attn",
        k=None,
        geom_dec_bias=False,
        data_path="data/vanilla",
        chpt_path="clip-vit-base-patch16_sae.pt",
    ),
    "TopK_v1": _sae_entry(
        hook_module="self_attn",
        k=32,
        geom_dec_bias=False,
        data_path="data/top32",
        chpt_path="clip-vit-base-patch16_sae-top32.pt",
    ),
    "Vanilla_(ReLU)_v2": _sae_entry(
        hook_module="mlp",
        k=None,
        geom_dec_bias=False,
        data_path="data/vanilla-v2",
        chpt_path="clip-vit-base-patch16_sae-v2.pt",
    ),
    "TopK_v2": _sae_entry(
        hook_module="mlp",
        k=32,
        geom_dec_bias=False,
        data_path="data/top32-v2",
        chpt_path="clip-vit-base-patch16_sae-top32-v2.pt",
    ),
    "Vanilla_(ReLU)_v3": _sae_entry(
        hook_module="mlp",
        k=None,
        geom_dec_bias=True,
        data_path="data/vanilla-v3",
        chpt_path="clip-vit-base-patch16_sae-v3.pt",
    ),
}

# Persist the registry so it can be reloaded from models.json.
with open("models.json", "w") as ouf:
    json.dump(models, ouf)

In [5]:
from model import *

In [6]:
def fetch_model(**kwargs):
    """Build an ``SAEonCLIP`` and load its SAE weights from a checkpoint file.

    Args:
        **kwargs: Must contain ``chpt_path``, the path handed to
            ``torch.load``. Every other keyword is forwarded unchanged to
            ``SAEonCLIP``.

    Returns:
        The constructed model, switched to eval mode, with the checkpoint's
        state dict loaded into ``model.sae``.

    Raises:
        KeyError: If ``chpt_path`` is not supplied.
    """
    chpt_path = kwargs.pop("chpt_path")
    model = SAEonCLIP(**kwargs)
    # Deserialize onto the CPU so a checkpoint saved from a GPU run can still
    # be opened on a machine without that device; load_state_dict then copies
    # the values into the SAE's existing parameters.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(chpt_path, map_location="cpu")
    model.sae.load_state_dict(state_dict)
    model.eval()
    return model

In [7]:
# Reload the registry from models.json, replacing the in-memory `models`.
with open("models.json", "r") as inf:
    models = json.loads(inf.read())

In [8]:
# Take a shallow copy of the registry entry so that later edits to `params`
# do not alter the `models` registry itself.
params = dict(models["Vanilla_(ReLU)_v3"])
params

{'clip_model': 'openai/clip-vit-base-patch16',
 'hook_layer': 10,
 'hook_module': 'mlp',
 'expansion_factor': 64,
 'k': None,
 'geom_dec_bias': True,
 'data_path': 'data/vanilla-v3',
 'chpt_path': 'clip-vit-base-patch16_sae-v3.pt'}

In [9]:
# `data_path` is removed before the remaining settings are forwarded to
# fetch_model. pop() with a default keeps this cell re-runnable, whereas
# `del` raises KeyError once the key is already gone.
params.pop("data_path", None)
model = fetch_model(**params, device="cpu")
model

SAEonCLIP(
  (clip): CLIPModel(
    (text_model): CLIPTextTransformer(
      (embeddings): CLIPTextEmbeddings(
        (token_embedding): Embedding(49408, 512)
        (position_embedding): Embedding(77, 512)
      )
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-11): 12 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=512, out_features=2048, bias=True)
              (fc2): Linear(in_features=2048, out_features=512, bias=True)
      