In [1]:
import torch
from head import Query2SAE, get_hyperparams
from safetensors.torch import load_file

# load the same hyperparams you trained with
_, _, _, _, head_dim, _ = get_hyperparams()

# load your safetensors checkpoint
state_dict = load_file("checkpoint/model_epoch4.safetensors")
# determine SAE dimension from one of the weight tensors
sae_dim = state_dict["head.2.weight"].shape[0]

# instantiate & load weights
model = Query2SAE(head_hidden_dim=head_dim, sae_dim=sae_dim)
model.load_state_dict(state_dict)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Loaded Query2SAE → head_hidden_dim={head_dim}, sae_dim={sae_dim} on {device}")


Loaded Query2SAE → head_hidden_dim=128, sae_dim=24576 on cpu


In [2]:
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

# FIX: Add proper pad token instead of using eos_token
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", padding_side="right")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # Add new pad token

lm = GPT2LMHeadModel.from_pretrained("gpt2")

# Resize embeddings for new pad token
lm.resize_token_embeddings(len(tokenizer))

# Move to device and set pad token ID
lm = lm.to(device)
lm.config.pad_token_id = tokenizer.pad_token_id

print(f"Tokenizer setup complete. Pad token: '{tokenizer.pad_token}'")

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Tokenizer setup complete. Pad token: '[PAD]'


In [10]:
question = "How can I teleport using only kitchen appliances?"

# FIX: Use shorter max_length to avoid memory issues and proper attention mask handling
inputs = tokenizer(
    question,
    return_tensors="pt",
    padding="max_length",
    max_length=50,  # REDUCED from 100
    truncation=True
).to(device)

print(f"Input shape: {inputs['input_ids'].shape}")

try:
    with torch.no_grad():  # FIX: Add memory management
        gen_ids = lm.generate(
            inputs["input_ids"], 
            attention_mask=inputs["attention_mask"],
            max_length=512,  # REDUCED from 256
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,  # FIX: Add to reduce repetition
            no_repeat_ngram_size=2   # FIX: Avoid repeating phrases
        )
    
    answer = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    print("Q:", question)
    print("A:", answer)
    
    # Clean up memory
    del gen_ids
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        
except RuntimeError as e:
    if "out of memory" in str(e):
        print("GPU memory issue - trying CPU fallback...")
        # Move to CPU and retry
        lm_cpu = lm.to('cpu')
        inputs_cpu = {k: v.to('cpu') for k, v in inputs.items()}
        
        with torch.no_grad():
            gen_ids = lm_cpu.generate(
                inputs_cpu["input_ids"],
                attention_mask=inputs_cpu["attention_mask"],
                max_length=100,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                temperature=0.7
            )
        
        answer = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
        print("Q:", question)
        print("A (CPU):", answer)
    else:
        print(f"Generation error: {e}")
except Exception as e:
    print(f"Unexpected error: {e}")

Input shape: torch.Size([1, 50])
Q: How can I teleport using only kitchen appliances?
A: How can I teleport using only kitchen appliances?
You could use a regular electric stove to put your food into the microwave, or you might make it by putting an old-fashioned gas cooker in there. But if that doesn't work out for you, just try cooking with other ingredients—and sometimes even some vinegar and mustard!


## Actual LM output

In [12]:
answer.split(sep="\n")[1]

"You could use a regular electric stove to put your food into the microwave, or you might make it by putting an old-fashioned gas cooker in there. But if that doesn't work out for you, just try cooking with other ingredients—and sometimes even some vinegar and mustard!"

## Predicted Output

In [15]:
from transformers import GPT2TokenizerFast
from head import Query2SAE, get_hyperparams
from safetensors.torch import load_file
import torch

# 1) load checkpoint
# state_dict = load_file("checkpoints/model_epoch4.safetensors")
state_dit = load_file("./checkpoint/model_epoch4.safetensors")
sae_dim    = state_dict["head.2.weight"].shape[0]
_, _, _, _, head_dim, _ = get_hyperparams()

model = Query2SAE(head_hidden_dim=head_dim, sae_dim=sae_dim)
model.load_state_dict(state_dict)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 2) instantiate GPT-2 tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", padding_side="right")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# 3) tokenize *your question*, not your SAE features!
question = "How can I teleport using only kitchen appliances?"
inputs   = tokenizer(
    question,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=256
).to(device)

# 4) run through your Query2SAE
with torch.no_grad():
    pred_sae = model(
        inputs["input_ids"], 
        attention_mask=inputs["attention_mask"]
    )  # shape (1, sae_dim)


In [18]:
pred_sae.shape

torch.Size([1, 24576])

## albert

In [None]:
gpt = Variant(
    model_id="gpt2-small",
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.11.hook_resid_pre",
)
model, sae, cfg, tokenizer = gpt.get_components()
sae.eval()
features = sae.encode(cache[sae.cfg.hook_name])

In [28]:
# at the top of your notebook…
import os, sys
import torch

# 1) get the directory containing this notebook (…/hallucination-circuits/expectation_model)
wd = os.getcwd()
# 2) go up one level to …/hallucination-circuits
project_root = os.path.abspath(os.path.join(wd, ".."))
# 3) insert it at the front of sys.path
sys.path.insert(0, project_root)

# now you can import your Variant helper
from src.interfaces.lens_backend import Variant

def encode_text_to_sae(text: str, max_length: int = 256):
    # 1) Spin up the Variant and grab model / SAE / cfg / tokenizer
    gpt = Variant(
        model_id="gpt2-small",
        sae_release="gpt2-small-res-jb",
        sae_id="blocks.11.hook_resid_pre",
    )
    model, sae, cfg, tokenizer = gpt.get_components()
    device = gpt.device
    model.to(device)

    # 2) Tokenize your text (Lens’s tokenizer) – we only need input_ids
    toks = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    tokens = toks["input_ids"].to(device)  # shape [1, seq_len]

    # 3) Run the GPT-2 backbone (Lens will stash activations in model.cache)
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens)  # ← pass only the token IDs!

    # 4) Pull out the hidden‐state at exactly the layer you trained your SAE on
    #    cfg.hook_name is e.g. "blocks.11.hook_resid_pre"
    hidden_states: torch.Tensor = cache[cfg.hook_name]
    #   → shape [batch_size, seq_len, d_model]

    # 5) Pool however you trained (here: last token)
    last_hidden = hidden_states[:, -1, :]  # → [batch_size, d_model]

    # 6) Encode into the full SAE space
    with torch.no_grad():
        sae_vector = sae.encode(last_hidden)  # → [batch_size, sae_dim]

    return sae_vector  # tensor of shape [1, sae_dim]

if __name__ == "__main__":
    query = "Why do flamingos stand on one leg?"
    vec = encode_text_to_sae(query)
    print("SAE vector shape:", vec.shape)

    # e.g. find top-50 indices:
    top50 = torch.topk(vec, k=50, dim=-1).indices[0].tolist()
    print("Top-50 SAE feature indices:", top50)


using HookedSAETransformer
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cpu
SAE vector shape: torch.Size([1, 24576])
Top-50 SAE feature indices: [8100, 21800, 24098, 14364, 14132, 2049, 23635, 3565, 15055, 17438, 18258, 2910, 21927, 24151, 9576, 22944, 5921, 16477, 21182, 19197, 2227, 7678, 12484, 23028, 21508, 11072, 20176, 4527, 10615, 1652, 2024, 23511, 10089, 19111, 23128, 9788, 13296, 23585, 21065, 19744, 16770, 13837, 11106, 19678, 2426, 21716, 13442, 3130, 7071, 1876]


In [29]:
type(top50)

list