In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class SparseAutoencoder(nn.Module):
    """Autoencoder with one linear encoder, a ReLU, and one linear decoder.

    The latent width is ``int(input_dim * expansion_factor)``.
    """

    def __init__(self, input_dim: int, expansion_factor: float = 16):
        """Create the encoder/decoder pair.

        Args:
            input_dim: Size of the last dimension of the tensors to encode.
            expansion_factor: Ratio of latent width to ``input_dim``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = int(input_dim * expansion_factor)
        # The decoder is created before the encoder; the order only matters for
        # random initialisation, and is kept as-is.
        self.decoder = nn.Linear(self.latent_dim, input_dim, bias=True)
        self.encoder = nn.Linear(input_dim, self.latent_dim, bias=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, latent)`` for ``x`` with gradients enabled."""
        encoded = F.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ReLU-activated latent features of ``x`` without tracking gradients."""
        with torch.no_grad():
            return F.relu(self.encoder(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map latent features ``x`` back to the input space without tracking gradients."""
        with torch.no_grad():
            return self.decoder(x)

    @classmethod
    def from_pretrained(cls, path: str, input_dim: int, expansion_factor: float = 16, device: str = "mps") -> "SparseAutoencoder":
        """Build a model, load a saved ``state_dict`` from ``path`` and switch to eval mode.

        Args:
            path: File containing a ``state_dict`` saved with ``torch.save``.
            input_dim: Must match the ``input_dim`` the checkpoint was created with.
            expansion_factor: Must match the checkpoint's latent/input ratio.
            device: Device the weights are mapped to and the model is moved to.

        Returns:
            The loaded model on ``device`` in eval mode.
        """
        model = cls(input_dim=input_dim, expansion_factor=expansion_factor)
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a downloaded checkpoint cannot execute arbitrary code on load.
        state_dict = torch.load(path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        return model

In [2]:
from huggingface_hub import hf_hub_download, notebook_login

In [3]:
# Name of the SAE checkpoint to analyse; later cells branch on this exact string
# to choose the base model, expansion factor and layer index.
# Swap in one of the commented alternatives to use a different checkpoint.
sae_name = "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
# sae_name = "Llama-3.2-1B-Instruct-SAE-l9"
# sae_name = "DeepSeek-R1-Distill-Llama-70B-SAE-l48"

In [4]:
# Only the Llama-3.2-1B configuration triggers the interactive Hugging Face login.
# NOTE(review): presumably because that model/repo requires authentication — confirm.
if sae_name == "Llama-3.2-1B-Instruct-SAE-l9":
    notebook_login()

In [5]:
# Fetch `<sae_name>.pt` from the `qresearch/<sae_name>` model repo on the Hugging Face Hub.
# The returned value is used later as the checkpoint path for the SAE.
file_path = hf_hub_download(
    repo_id=f"qresearch/{sae_name}",
    filename=f"{sae_name}.pt",
    repo_type="model"
)

In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pick the base model matching the selected SAE checkpoint. Any name other than
# the two listed explicitly falls through to the 70B model.
model_name = ("deepseek-ai/DeepSeek-R1-Distill-Llama-8B" if sae_name == "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
              else "meta-llama/Llama-3.2-1B-Instruct" if sae_name == "Llama-3.2-1B-Instruct-SAE-l9"
              else "deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The 70B SAE uses an 8x latent expansion; the other checkpoints use 16x.
expansion_factor = 8 if sae_name == "DeepSeek-R1-Distill-Llama-70B-SAE-l48" else 16
sae = SparseAutoencoder.from_pretrained(
    path=file_path,
    input_dim=model.config.hidden_size,
    expansion_factor=expansion_factor,
    # Follow wherever `device_map="auto"` placed the model instead of assuming
    # an Apple-silicon ("mps") machine.
    device=str(model.device),
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  state_dict = torch.load(path, map_location=device)


In [7]:
"Q: Each cat is a carnivore. Every carnivore is not herbivorous. Carnivores are mammals. All mammals are warm-blooded. Mammals are vertebrates. Every vertebrate is an animal. Animals are multicellular. Fae is a cat. True or false: Fae is not herbivorous.\nA: "

'Q: Each cat is a carnivore. Every carnivore is not herbivorous. Carnivores are mammals. All mammals are warm-blooded. Mammals are vertebrates. Every vertebrate is an animal. Animals are multicellular. Fae is a cat. True or false: Fae is not herbivorous.\nA: '

In [8]:
"Can penguins fly? Segment the thinking process into clear steps and indicate \"YES\" or \"NO\" once at the end; do not use \"Wait,\" in your think ."

'Can penguins fly? Segment the thinking process into clear steps and indicate "YES" or "NO" once at the end; do not use "Wait," in your think .'

In [9]:
# Tokenize a single-turn chat prompt and place it on the model's device
# (rather than a hard-coded "mps"), matching the later generation cells.
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user",
         "content": "Can penguins fly? Segment the thinking process into clear steps and indicate \"YES\" or \"NO\" once at the end and do not use \"Wait,\" in your think ."
        },
    ],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)
# A single unpadded sequence attends to every position, so the mask is all ones.
# Passing it together with an explicit pad token avoids generate() having to guess.
outputs = model.generate(
    input_ids=inputs,
    attention_mask=torch.ones_like(inputs),
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=1200,
)
print(tokenizer.decode(outputs[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


KeyboardInterrupt: 

In [10]:
# Shape of the generated ids; needs the generation cell above to have run to
# completion so that `outputs` exists.
outputs.shape

NameError: name 'outputs' is not defined

In [11]:
def gather_residual_activations(model, target_layer, inputs):
    """Run ``model`` once and return the first positional input of one decoder layer.

    Args:
        model: Module exposing ``model.model.layers`` (an indexable list of layers)
            and callable as ``model(inputs)``.
        target_layer: Index into ``model.model.layers`` of the layer to observe.
        inputs: Passed unchanged to ``model(...)``.

    Returns:
        The first positional argument the target layer received during the
        forward pass, or ``None`` if that layer was never called.
    """
    target_act = None

    def gather_target_act_hook(mod, hook_inputs, hook_outputs):
        nonlocal target_act
        target_act = hook_inputs[0]  # Get residual stream from layer input
        return hook_outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    try:
        with torch.no_grad():
            _ = model(inputs)
    finally:
        # Always detach the hook, even if the forward pass raises; otherwise it
        # would keep firing on every later call to the model.
        handle.remove()
    return target_act

In [12]:
# Layer index handed to the activation-gathering and steering helpers.
# Names not listed here resolve to 48.
layer_id = {
    "DeepSeek-R1-Distill-Llama-8B-SAE-l19": 19,
    "Llama-3.2-1B-Instruct-SAE-l9": 9,
}.get(sae_name, 48)

In [13]:
# Capture what decoder layer `layer_id` receives as input for the current `inputs`.
target_act = gather_residual_activations(model, layer_id, inputs)

In [14]:
def ensure_same_device(sae, target_act):
    """Move ``sae`` to the device of ``target_act`` and return ``(sae, target_act)``."""
    device = target_act.device
    return sae.to(device), target_act.to(device)

# Put the SAE on the activations' device, then run it in float32.
sae, target_act = ensure_same_device(sae, target_act)
sae_acts = sae.encode(target_act.to(torch.float32))  # non-negative latent features
recon = sae.decode(sae_acts)  # reconstruction of the activations from those features

In [15]:
# Fraction of variance explained: 1 - (mean squared reconstruction error) / (variance
# of the activations), both taken over all elements in float32.
var_explained = 1 - torch.mean((recon - target_act.to(torch.float32)) ** 2) / torch.var(target_act.to(torch.float32))
print(f"Variance explained: {var_explained:.3f}")

Variance explained: 0.991


In [16]:
# Tokenize a finished user/assistant exchange (no generation prompt is appended)
# and place it on the model's device rather than a hard-coded "mps".
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Roleplay as a pirate"},
        {"role": "assistant", "content": "Yarr, I'll be speakin' like a true seafarer from here on out! Got me sea legs ready and me vocabulary set to proper pirate speak. What can I help ye with, me hearty?"},
    ],
    return_tensors="pt",
).to(model.device)

In [17]:
# Define the path to the .pt file - you can modify this to any file in the outputs directory
pt_file_path = "outputs/penguin/raw_outputs/output_211_temp0_6.pt"

# Load the tokens from the .pt file.
# weights_only=True limits unpickling to tensors/plain containers, and
# map_location="cpu" makes loading independent of the device the file was saved from.
loaded_tokens = torch.load(pt_file_path, map_location="cpu", weights_only=True)

# Move to the model's device (instead of assuming "mps")
inputs = loaded_tokens.to(model.device)

# Print shape and first few tokens to verify
print(f"Loaded token tensor shape: {inputs.shape}")
print(f"First few tokens: {inputs[0, :10]}")

Loaded token tensor shape: torch.Size([1, 494])
First few tokens: tensor([128000, 128011,   6854,    281,  56458,  11722,     30,  38203,    279,
          7422], device='mps:0')


  loaded_tokens = torch.load(pt_file_path)


In [21]:
import numpy as np
# Run the model once and encode the captured layer input with the SAE
target_act = gather_residual_activations(model, layer_id, inputs)
sae_acts = sae.encode(target_act.to(torch.float32))

# Host-side token ids and their string form, used to label the printout
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Positions whose id lies in [128000, 128255]; the reply is taken to start
# right after the second-to-last such position
token_ids = inputs[0].cpu().numpy()
is_special = (128000 <= token_ids) & (token_ids <= 128255)
special_positions = np.flatnonzero(is_special)

assistant_start = special_positions[-2] + 1
assistant_tokens = slice(assistant_start, None)

# Disabled alternative: restrict the statistics to the assistant's reply
# assistant_activations = sae_acts[0, assistant_tokens]
# mean_activations = assistant_activations.mean(dim=0)

# Per-feature mean over every position of the first sequence in the batch
all_activations = sae_acts[0]
mean_activations = all_activations.mean(dim=0)

# Features with the highest mean activation over the loaded sequence
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during penguin speech:")
for feature, mean_value in zip(top_features.indices, top_features.values):
    print(f"Feature {feature}: {mean_value:.3f}")

# Per-token view: list only tokens where at least one top feature exceeds 0.2,
# and for those tokens only the features above that same threshold
print("\nActivation patterns across tokens:")
for token_text, token_acts in zip(token_texts, all_activations):
    top_acts = token_acts[top_features.indices]
    if not top_acts.max() > 0.2:
        continue
    print(f"\nToken: {token_text}")
    for feat_idx, act_val in zip(top_features.indices, top_acts):
        if act_val > 0.2:
            print(f"  Feature {feat_idx}: {act_val:.3f}")


Top activated features during penguin speech:
Feature 48761: 1.470
Feature 45162: 0.873
Feature 3528: 0.612
Feature 1421: 0.411
Feature 53595: 0.408
Feature 3831: 0.305
Feature 34876: 0.269
Feature 53323: 0.269
Feature 13155: 0.269
Feature 56582: 0.251
Feature 2398: 0.205
Feature 934: 0.200
Feature 61132: 0.187
Feature 60476: 0.187
Feature 44814: 0.185
Feature 9628: 0.180
Feature 55538: 0.179
Feature 47160: 0.176
Feature 41304: 0.175
Feature 5909: 0.171

Activation patterns across tokens:

Token: <｜begin▁of▁sentence｜>
  Feature 2398: 101.447
  Feature 934: 98.754
  Feature 61132: 92.354

Token: enguins
  Feature 60476: 0.627
  Feature 41304: 1.537

Token: Ġfly
  Feature 3528: 0.886
  Feature 53595: 0.834
  Feature 60476: 0.292
  Feature 55538: 0.347
  Feature 5909: 1.116

Token: ?
  Feature 3528: 0.425
  Feature 1421: 1.792
  Feature 60476: 0.973

Token: Ġinto
  Feature 60476: 0.668

Token: Ġclear
  Feature 60476: 0.320

Token: Ġsteps
  Feature 60476: 0.545

Token: Ġand
  Feature 60476

In [22]:
import numpy as np
# Run the model on the full sequence and encode the captured layer input
target_act = gather_residual_activations(model, layer_id, inputs)
sae_acts = sae.encode(target_act.to(torch.float32))

# Host-side token ids and their string form
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Positions whose id lies in [128000, 128255] (used by the fallback below)
token_ids = inputs[0].cpu().numpy()
is_special = (128000 <= token_ids) & (token_ids <= 128255)
special_positions = np.flatnonzero(is_special)

# Locate the span between a <think> token and the first </think> token
think_start_idx = None
think_end_idx = None

for position, text in enumerate(token_texts):
    if '<think>' in text:
        # Span begins on the token after the opening tag
        think_start_idx = position + 1
        continue
    if '</think>' in text:
        # Span ends just before the closing tag; stop scanning
        think_end_idx = position
        break

# Use the think span when both tags were found; otherwise fall back to
# everything after the second-to-last special token
found_think_span = think_start_idx is not None and think_end_idx is not None
if found_think_span:
    thinking_tokens = slice(think_start_idx, think_end_idx)
else:
    assistant_start = special_positions[-2] + 1
    thinking_tokens = slice(assistant_start, None)

# Slice the full-sequence activations down to the chosen span and average per feature
thinking_activations = sae_acts[0, thinking_tokens]
mean_activations = thinking_activations.mean(dim=0)

# Features with the highest mean activation inside the span
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during thinking:")
for feature, mean_value in zip(top_features.indices, top_features.values):
    print(f"Feature {feature}: {mean_value:.3f}")

# Per-token view: list only tokens where at least one top feature exceeds 0.2
print("\nActivation patterns across tokens:")
for token_text, token_acts in zip(token_texts[thinking_tokens], thinking_activations):
    top_acts = token_acts[top_features.indices]
    if not top_acts.max() > 0.2:
        continue
    print(f"\nToken: {token_text}")
    for feat_idx, act_val in zip(top_features.indices, top_acts):
        if act_val > 0.2:
            print(f"  Feature {feat_idx}: {act_val:.3f}")

Top activated features during thinking:
Feature 48761: 1.554
Feature 45162: 0.819
Feature 3528: 0.638
Feature 53595: 0.470
Feature 1421: 0.443
Feature 53323: 0.331
Feature 34876: 0.282
Feature 3831: 0.267
Feature 13155: 0.255
Feature 44814: 0.228
Feature 10498: 0.208
Feature 41304: 0.206
Feature 60476: 0.203
Feature 5909: 0.199
Feature 60322: 0.198
Feature 47160: 0.194
Feature 56582: 0.184
Feature 55538: 0.181
Feature 9628: 0.175
Feature 32009: 0.164

Activation patterns across tokens:

Token: Ċ
  Feature 48761: 0.509
  Feature 34876: 1.070
  Feature 60476: 1.474

Token: Okay
  Feature 48761: 1.082
  Feature 34876: 1.974
  Feature 60476: 4.595

Token: ,
  Feature 48761: 1.083
  Feature 1421: 0.944
  Feature 53323: 0.710
  Feature 32009: 2.067

Token: Ġso
  Feature 48761: 1.459
  Feature 1421: 0.581
  Feature 53323: 0.273
  Feature 3831: 0.246
  Feature 10498: 0.211
  Feature 60476: 0.392
  Feature 32009: 2.297

Token: ĠI
  Feature 48761: 1.353
  Feature 1421: 1.120
  Feature 53323: 0.4

In [23]:
import numpy as np

# Thinking-span analysis in which the model is run on the sliced token ids only
# (see the gather_residual_activations call below).

# Get token IDs and decode them for reference
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Find which tokens are part of the thinking response
# (ids in [128000, 128255] are treated as special tokens; used by the fallback)
token_ids = inputs[0].cpu().numpy()
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

# Look for <think> and </think> tags in the tokens
think_start_idx = None
think_end_idx = None

# Scanning stops at the first </think>; a later <think> before it overrides an earlier one
for i, token in enumerate(token_texts):
    if '<think>' in token:
        think_start_idx = i + 1  # Start after the <think> tag
    elif '</think>' in token:
        think_end_idx = i  # End before the </think> tag
        break

# If think tags not found, fall back to assistant response
# (everything after the second-to-last special token)
if think_start_idx is None or think_end_idx is None:
    assistant_start = special_positions[-2] + 1
    thinking_tokens = slice(assistant_start, None)
else:
    thinking_tokens = slice(think_start_idx, think_end_idx)

# Get activations only for the thinking part
# NOTE(review): the forward pass sees just the sliced tokens, so tokens before the
# slice (e.g. the prompt) are not part of the model's context here — confirm this
# is the intended comparison.
target_act = gather_residual_activations(model, layer_id, inputs[:, thinking_tokens])
sae_acts = sae.encode(target_act.to(torch.float32))

# Get activation statistics for thinking response
# (every position of the sliced sequence, averaged per feature)
thinking_activations = sae_acts[0]
mean_activations = thinking_activations.mean(dim=0)

# Find top activated features during thinking
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during thinking:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
# (this cell uses a 0.1 threshold for both the per-token and per-feature filters)
print("\nActivation patterns across tokens:")
for i, (token, acts) in enumerate(zip(token_texts[thinking_tokens], thinking_activations)):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.1:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.1:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")

Top activated features during thinking:
Feature 48761: 1.316
Feature 45162: 0.812
Feature 3528: 0.660
Feature 53595: 0.487
Feature 1421: 0.443
Feature 53323: 0.393
Feature 10498: 0.296
Feature 13155: 0.229
Feature 47160: 0.215
Feature 3831: 0.214
Feature 60476: 0.212
Feature 60322: 0.200
Feature 44814: 0.190
Feature 41304: 0.190
Feature 5909: 0.187
Feature 55538: 0.182
Feature 34347: 0.181
Feature 20420: 0.165
Feature 24944: 0.152
Feature 4677: 0.147

Activation patterns across tokens:

Token: Okay
  Feature 48761: 0.913
  Feature 60476: 0.270

Token: ,
  Feature 48761: 0.857
  Feature 60476: 0.538

Token: Ġso
  Feature 48761: 1.223
  Feature 60476: 1.686

Token: ĠI
  Feature 48761: 1.185
  Feature 60476: 1.289

Token: 'm
  Feature 48761: 1.273
  Feature 60476: 1.439

Token: Ġtrying
  Feature 48761: 1.289
  Feature 3831: 0.282
  Feature 60476: 1.169
  Feature 34347: 0.540

Token: Ġto
  Feature 48761: 0.974
  Feature 60476: 0.237

Token: Ġfigure
  Feature 48761: 1.229
  Feature 3831: 0.

In [19]:
def generate_with_intervention(
    model,
    tokenizer,
    sae,
    messages: list[dict],
    feature_idx: int,
    intervention: float = 3.0,
    target_layer: int = 9,
    max_new_tokens: int = 50
):
    """Greedy-generate a reply while adding a constant to one SAE feature.

    On every forward pass through ``model.model.layers[target_layer]`` the layer's
    first positional input is encoded with ``sae``, ``intervention`` is added to
    feature ``feature_idx`` at every position, and the result is decoded. The
    SAE's reconstruction error on the unmodified input is added back, so only
    the steered feature changes the activations. The steered tensor then
    replaces the layer's hidden-state output.

    Args:
        model: Causal LM exposing ``model.model.layers``, ``device`` and ``generate``.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``decode`` and
            ``eos_token_id``.
        sae: Object with ``encode``/``decode`` operating on float32 activations.
        messages: Chat messages passed to ``apply_chat_template``.
        feature_idx: Index of the SAE feature to shift.
        intervention: Amount added to that feature's activation.
        target_layer: Index of the decoder layer to hook.
        max_new_tokens: Generation length limit.

    Returns:
        The decoded text of the first generated sequence (prompt included).
    """

    def steering_hook(module, hook_inputs, hook_outputs):
        activations = hook_inputs[0]
        activations_fp32 = activations.to(torch.float32)

        features = sae.encode(activations_fp32)
        # Residual the SAE cannot reconstruct; added back after steering.
        error = activations_fp32 - sae.decode(features)

        features[:, :, feature_idx] += intervention

        # Cast back to the dtype the model is running in rather than assuming bfloat16.
        modified = (sae.decode(features) + error).to(activations.dtype)

        # NOTE(review): the steered tensor is derived from the layer's *input* but
        # replaces the layer's *output*, so this layer's own update to the hidden
        # state is discarded — confirm this is intended.
        if isinstance(hook_outputs, tuple):
            return (modified,) + hook_outputs[1:]
        # Layer returned a bare tensor rather than a tuple: replace it directly.
        return modified

    handle = model.model.layers[target_layer].register_forward_hook(steering_hook)

    try:
        input_tokens = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            input_tokens,
            # Single unpadded prompt: every position is attended to.
            attention_mask=torch.ones_like(input_tokens),
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=False  # Use greedy decoding for consistency
        )

        generated_text = tokenizer.decode(outputs[0])

    finally:
        # Always detach the hook so later generations run unmodified.
        handle.remove()

    return generated_text

# Prompt and SAE feature used for the baseline-vs-steered comparison below.
messages = [
    {"role": "user", "content": "How are you doing?"}
]
feature_to_modify = 7560

print("Original generation:")
input_tokens = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)
# Single unpadded prompt: pass an all-ones attention mask and an explicit pad
# token so generate() does not have to infer them.
outputs = model.generate(
    input_tokens,
    attention_mask=torch.ones_like(input_tokens),
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=1000,
    do_sample=False,
)
print(tokenizer.decode(outputs[0]))

print("\nGeneration with modified feature:")
modified_text = generate_with_intervention(
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    messages=messages,
    feature_idx=feature_to_modify,
    intervention=10,
    target_layer=layer_id,
    max_new_tokens=100
)
print(modified_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Original generation:


KeyboardInterrupt: 

In [17]:

# Steered generation on a second prompt. `feature_to_modify` and `layer_id` are
# not defined in this cell and must already exist from earlier cells.
print("\nGeneration with modified feature:")

messages = [
    {"role": "user", "content": "How many Rs in strawberry?"}
]

# Adds 8 to the chosen feature on every pass through the hooked layer.
modified_text = generate_with_intervention(
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    messages=messages,
    feature_idx=feature_to_modify,
    intervention=8,
    target_layer=layer_id,
    max_new_tokens=1000
)
print(modified_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



Generation with modified feature:


KeyboardInterrupt: 