In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim: int, expansion_factor: float = 16):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = int(input_dim * expansion_factor)
        self.decoder = nn.Linear(self.latent_dim, input_dim, bias=True)
        self.encoder = nn.Linear(input_dim, self.latent_dim, bias=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = F.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return F.relu(self.encoder(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.decoder(x)

    @classmethod
    def from_pretrained(cls, path: str, input_dim: int, expansion_factor: float = 16, device: str = "mps") -> "SparseAutoencoder":
        model = cls(input_dim=input_dim, expansion_factor=expansion_factor)
        state_dict = torch.load(path, map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        return model

In [2]:
from huggingface_hub import hf_hub_download, notebook_login

In [3]:
sae_name = "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
# sae_name = "Llama-3.2-1B-Instruct-SAE-l9"
# sae_name = "DeepSeek-R1-Distill-Llama-70B-SAE-l48"

In [4]:
if sae_name == "Llama-3.2-1B-Instruct-SAE-l9":
    notebook_login()

In [5]:
file_path = hf_hub_download(
    repo_id=f"qresearch/{sae_name}",
    filename=f"{sae_name}.pt",
    repo_type="model"
)

In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = ("deepseek-ai/DeepSeek-R1-Distill-Llama-8B" if sae_name == "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
              else "meta-llama/Llama-3.2-1B-Instruct" if sae_name == "Llama-3.2-1B-Instruct-SAE-l9"
              else "deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

expansion_factor = 8 if sae_name == "DeepSeek-R1-Distill-Llama-70B-SAE-l48" else 16
sae = SparseAutoencoder.from_pretrained(
    path=file_path,
    input_dim=model.config.hidden_size,
    expansion_factor=expansion_factor,
    device="mps"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  state_dict = torch.load(path, map_location=device)


In [7]:
"Q: Each cat is a carnivore. Every carnivore is not herbivorous. Carnivores are mammals. All mammals are warm-blooded. Mammals are vertebrates. Every vertebrate is an animal. Animals are multicellular. Fae is a cat. True or false: Fae is not herbivorous.\nA: "

'Q: Each cat is a carnivore. Every carnivore is not herbivorous. Carnivores are mammals. All mammals are warm-blooded. Mammals are vertebrates. Every vertebrate is an animal. Animals are multicellular. Fae is a cat. True or false: Fae is not herbivorous.\nA: '

In [8]:
"Can penguins fly? Segment the thinking process into clear steps and indicate \"YES\" or \"NO\" once at the end; do not use \"Wait,\" in your think ."

'Can penguins fly? Segment the thinking process into clear steps and indicate "YES" or "NO" once at the end; do not use "Wait," in your think .'

In [9]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", 
         "content": "Can penguins fly? Segment the thinking process into clear steps and indicate \"YES\" or \"NO\" once at the end ."
        },
    ],
    add_generation_prompt=True,
    return_tensors="pt",
).to("mps")

# End CoT prematurely
think_end = torch.tensor(tokenizer.convert_tokens_to_ids('</think>')).reshape(1,1).to("mps")
inputs = torch.cat((inputs, think_end), dim=-1)

outputs = model.generate(input_ids=inputs, max_new_tokens=1200)
print(tokenizer.decode(outputs[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


<｜begin▁of▁sentence｜><｜User｜>Can penguins fly? Segment the thinking process into clear steps and indicate "YES" or "NO" once at the end .<｜Assistant｜><think>
</think>

1. **Understand the Question**: The question asks whether penguins can fly.
2. **Know Basic Facts**: Penguins are birds native to the Southern Hemisphere.
3. **Birds and Flying**: Most birds can fly, but not all.
4. **Penguins' Movement**: Penguins waddle on land and swim in water.
5. **Flight Mechanism**: Penguins have wings, but they are not built for flight like other birds.
6. **Conclusion**: Penguins cannot fly.

**NO**<｜end▁of▁sentence｜>


In [10]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", 
         "content": "Can penguins fly? Segment the thinking process into clear steps and indicate \"YES\" or \"NO\" once at the end ."
        },
    ],
    add_generation_prompt=True,
    return_tensors="pt",
).to("mps")

# End CoT prematurely
think_end = torch.tensor(tokenizer.convert_tokens_to_ids(['</think>', '\\n'])).reshape(1,2).to("mps")
inputs = torch.cat((inputs, think_end), dim=-1)

In [11]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", 
         "content": "Can penguins fly? Segment the thinking process into clear steps and indicate \"YES\" or \"NO\" once at the end ."
        },
    ],
    add_generation_prompt=True,
    return_tensors="pt",
).to("mps")
tokenizer.decode(inputs[0])

'<｜begin▁of▁sentence｜><｜User｜>Can penguins fly? Segment the thinking process into clear steps and indicate "YES" or "NO" once at the end .<｜Assistant｜><think>\n'

In [29]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList

class MultiInterventionLogitsProcessor(LogitsProcessor):
    def __init__(self, mapping):
        """
        mapping: dict mapping conditioning token id (int) to a list (sequence) of injection token ids (list of ints)
        For example, if you want that when the model outputs the token for "hello" it injects " world !", you might do:
            mapping = { cond_token_id: [injection_token_id1, injection_token_id2, injection_token_id3] }
        """
        self.mapping = mapping
        # Dictionary to hold per-beam injection queues: beam index -> list of injection token ids left to force
        self.injection_state = {}

    def __call__(self, input_ids, scores):
        # input_ids shape: (batch_size, sequence_length)
        batch_size = input_ids.size(0)
        # Iterate over each beam in the batch
        for i in range(batch_size):
            # If this beam is in injection mode, force the next token to be the next token in its injection sequence.
            if i in self.injection_state and self.injection_state[i]:
                next_injection = self.injection_state[i].pop(0)
                scores[i, :] = -float('inf')
                scores[i, next_injection] = 0.0
                # If injection sequence is exhausted, remove this beam from state.
                if not self.injection_state[i]:
                    del self.injection_state[i]
                continue

            # Otherwise, check if the last generated token in this beam is a conditioning token.
            last_token = input_ids[i, -1].item()
            if last_token in self.mapping:
                # Start the injection: copy the injection sequence (so we don't modify the original)
                self.injection_state[i] = self.mapping[last_token].copy()
                # Force the next token to be the first token of the injection sequence.
                next_injection = self.injection_state[i].pop(0)
                scores[i, :] = -float('inf')
                scores[i, next_injection] = 0.0
        return scores

# -------------------------------
# Example usage
# -------------------------------

# Load a model and tokenizer (here using GPT-2 for demonstration).
# model_name = "gpt2"
# model = AutoModelForCausalLM.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define your conditioning targets and corresponding injection sequences.
# For instance, if the model outputs "hello", inject " world !" (i.e. force those tokens in order),
# and if it outputs "foo", inject " bar baz".

# Convert conditioning string(s) to token ids.
conditioning_str1 = " wait"
injection_str1 = ", that seems right."
conditioning_str2 = " Wait"
injection_str2 = ", that seems right."
conditioning_str3 = "wait"
injection_str3 = ", that seems right."
conditioning_str4 = "Wait"
injection_str4 = ", that seems right."

# For conditioning tokens we take the first token (if using a simple tokenizer, that's usually enough).
cond_token_id1 = tokenizer.encode(conditioning_str1, add_special_tokens=False)[0]
cond_token_id2 = tokenizer.encode(conditioning_str2, add_special_tokens=False)[0]
cond_token_id3 = tokenizer.encode(conditioning_str3, add_special_tokens=False)[0]
cond_token_id4 = tokenizer.encode(conditioning_str4, add_special_tokens=False)[0]

# For injection sequences, get the full list of token IDs.
inject_tokens_ids1 = tokenizer.encode(injection_str1, add_special_tokens=False)
inject_tokens_ids2 = tokenizer.encode(injection_str2, add_special_tokens=False)
inject_tokens_ids3 = tokenizer.encode(injection_str3, add_special_tokens=False)
inject_tokens_ids4 = tokenizer.encode(injection_str4, add_special_tokens=False)

# Create the mapping dictionary.
mapping = {
    cond_token_id1: inject_tokens_ids1,
    cond_token_id2: inject_tokens_ids2,
    cond_token_id3: inject_tokens_ids3,
    cond_token_id4: inject_tokens_ids4,
}

# Create our custom logits processor with the mapping.
custom_processor = MultiInterventionLogitsProcessor(mapping)
logits_processor = LogitsProcessorList([custom_processor])

# Prepare a prompt that eventually produces one of the conditioning tokens.
# For instance, we include "hello" at the end of the prompt so that the custom processor is triggered.
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", 
         "content": "Can penguins fly? Segment the thinking process into clear steps and indicate \"YES\" or \"NO\" once at the end ."
        },
    ],
    add_generation_prompt=True,
    return_tensors="pt",
).to("mps")  # Explicitly move to MPS device

# Generate text using our custom logits processor (greedy decoding for clarity).
generated_ids = model.generate(
    inputs,
    max_new_tokens=1200,
    logits_processor=logits_processor,
    #do_sample=False
)

print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


<｜User｜>Can penguins fly? Segment the thinking process into clear steps and indicate "YES" or "NO" once at the end .<｜Assistant｜><think>
Okay, so I'm trying to figure out if penguins can fly. I remember seeing penguins in movies and documentaries, and they always seem to move on land or ice, but I've never seen them fly. Let me think through this step by step.

First, I know that penguins are birds. Birds have feathers, beaks, wings, and they can fly, right? But wait, that seems right. So, if they are birds, they should be able to fly. But I also remember that not all birds can fly. For example, some birds like ostriches and emus can't fly because their bodies are too heavy and their wings aren't adapted for flight. So maybe penguins are similar?

I should consider the structure of a penguin. Penguins have wings, but I've never seen them use them to fly. They move on their bellies, flippers, and sometimes slide on ice. Their wings are more like flippers, which help them swim underwater

: 

In [26]:
# Print each token and its corresponding ID
for token_id in generated_ids[0]:
    token = tokenizer.decode([token_id])
    print(f"Token ID: {token_id.item()}, Token: '{token}'")

Token ID: 128000, Token: '<｜begin▁of▁sentence｜>'
Token ID: 128011, Token: '<｜User｜>'
Token ID: 6854, Token: 'Can'
Token ID: 281, Token: ' p'
Token ID: 56458, Token: 'enguins'
Token ID: 11722, Token: ' fly'
Token ID: 30, Token: '?'
Token ID: 38203, Token: ' Segment'
Token ID: 279, Token: ' the'
Token ID: 7422, Token: ' thinking'
Token ID: 1920, Token: ' process'
Token ID: 1139, Token: ' into'
Token ID: 2867, Token: ' clear'
Token ID: 7504, Token: ' steps'
Token ID: 323, Token: ' and'
Token ID: 13519, Token: ' indicate'
Token ID: 330, Token: ' "'
Token ID: 14331, Token: 'YES'
Token ID: 1, Token: '"'
Token ID: 477, Token: ' or'
Token ID: 330, Token: ' "'
Token ID: 9173, Token: 'NO'
Token ID: 1, Token: '"'
Token ID: 3131, Token: ' once'
Token ID: 520, Token: ' at'
Token ID: 279, Token: ' the'
Token ID: 842, Token: ' end'
Token ID: 662, Token: ' .'
Token ID: 128012, Token: '<｜Assistant｜>'
Token ID: 128013, Token: '<think>'
Token ID: 198, Token: '
'
Token ID: 33413, Token: 'Okay'
Token ID: 1

In [32]:
def gather_residual_activations(model, target_layer, inputs):
    target_act = None
    def gather_target_act_hook(mod, inputs, outputs):
        nonlocal target_act
        target_act = inputs[0]  # Get residual stream from layer input
        return outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    with torch.no_grad():
        _ = model(inputs)
    handle.remove()
    return target_act

In [33]:
layer_id = (19 if sae_name == "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
            else 9 if sae_name == "Llama-3.2-1B-Instruct-SAE-l9"
            else 48)

In [34]:
target_act = gather_residual_activations(model, layer_id, inputs)

In [35]:
def ensure_same_device(sae, target_act):
    """Ensure SAE and activations are on the same device"""
    model_device = target_act.device
    sae = sae.to(model_device)
    return sae, target_act.to(model_device)

sae, target_act = ensure_same_device(sae, target_act)
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

In [36]:
var_explained = 1 - torch.mean((recon - target_act.to(torch.float32)) ** 2) / torch.var(target_act.to(torch.float32))
print(f"Variance explained: {var_explained:.3f}")

Variance explained: 0.994


In [37]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Roleplay as a pirate"},
        {"role": "assistant", "content": "Yarr, I'll be speakin' like a true seafarer from here on out! Got me sea legs ready and me vocabulary set to proper pirate speak. What can I help ye with, me hearty?"},
    ],
    return_tensors="pt",
).to("mps")

In [38]:
# Define the path to the .pt file - you can modify this to any file in the outputs directory
pt_file_path = "outputs/penguin/raw_outputs/output_79_temp0_6.pt"

# Load the tokens from the .pt file
loaded_tokens = torch.load(pt_file_path)

# Move to the appropriate device
inputs = loaded_tokens.to("mps")

# Print shape and first few tokens to verify
print(f"Loaded token tensor shape: {inputs.shape}")
print(f"First few tokens: {inputs[0, :10]}")

Loaded token tensor shape: torch.Size([1, 724])
First few tokens: tensor([128000, 128011,   6854,    281,  56458,  11722,     30,  38203,    279,
          7422], device='mps:0')


  loaded_tokens = torch.load(pt_file_path)


In [39]:
import numpy as np
# Get activations
target_act = gather_residual_activations(model, layer_id, inputs)
sae_acts = sae.encode(target_act.to(torch.float32))

# Get token IDs and decode them for reference
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Find which tokens are part of the assistant's response
token_ids = inputs[0].cpu().numpy()
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

assistant_start = special_positions[-2] + 1
assistant_tokens = slice(assistant_start, None)

# Get activation statistics for assistant's response
# assistant_activations = sae_acts[0, assistant_tokens]
# mean_activations = assistant_activations.mean(dim=0)

# Get activations for the entire sequence
all_activations = sae_acts[0]
mean_activations = all_activations.mean(dim=0)

# Find top activated features during pirate speech
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during penguin speech:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
print("\nActivation patterns across tokens:")
for i, (token, acts) in enumerate(zip(token_texts, all_activations)):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.2:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.2:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")


Top activated features during penguin speech:
Feature 48761: 1.515
Feature 45162: 0.874
Feature 3528: 0.639
Feature 1421: 0.369
Feature 53595: 0.347
Feature 34876: 0.303
Feature 13155: 0.299
Feature 3831: 0.261
Feature 13650: 0.239
Feature 55538: 0.231
Feature 44814: 0.222
Feature 5909: 0.221
Feature 47160: 0.220
Feature 53323: 0.204
Feature 41304: 0.197
Feature 60476: 0.188
Feature 26171: 0.183
Feature 4677: 0.178
Feature 46104: 0.171
Feature 45255: 0.163

Activation patterns across tokens:

Token: enguins
  Feature 41304: 1.537
  Feature 60476: 0.627

Token: Ġfly
  Feature 3528: 0.886
  Feature 53595: 0.834
  Feature 55538: 0.347
  Feature 5909: 1.116
  Feature 60476: 0.292
  Feature 4677: 4.147

Token: ?
  Feature 3528: 0.425
  Feature 1421: 1.792
  Feature 60476: 0.973

Token: Ġinto
  Feature 60476: 0.668

Token: Ġclear
  Feature 60476: 0.320

Token: Ġsteps
  Feature 60476: 0.545

Token: Ġand
  Feature 60476: 0.288

Token: Ġindicate
  Feature 60476: 0.552

Token: Ġor
  Feature 6047

In [40]:
import numpy as np
# Get activations
target_act = gather_residual_activations(model, layer_id, inputs)
sae_acts = sae.encode(target_act.to(torch.float32))

# Get token IDs and decode them for reference
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Find which tokens are part of the thinking response
token_ids = inputs[0].cpu().numpy()
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

# Look for <think> and </think> tags in the tokens
think_start_idx = None
think_end_idx = None

for i, token in enumerate(token_texts):
    if '<think>' in token:
        think_start_idx = i + 1  # Start after the <think> tag
    elif '</think>' in token:
        think_end_idx = i  # End before the </think> tag
        break

# If think tags not found, fall back to assistant response
if think_start_idx is None or think_end_idx is None:
    assistant_start = special_positions[-2] + 1
    thinking_tokens = slice(assistant_start, None)
else:
    thinking_tokens = slice(think_start_idx, think_end_idx)

# Get activation statistics for thinking response
thinking_activations = sae_acts[0, thinking_tokens]
mean_activations = thinking_activations.mean(dim=0)

# Find top activated features during thinking
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during thinking:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
print("\nActivation patterns across tokens:")
for i, (token, acts) in enumerate(zip(token_texts[thinking_tokens], thinking_activations)):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.2:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.2:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")

Top activated features during thinking:
Feature 48761: 1.568
Feature 45162: 0.841
Feature 3528: 0.660
Feature 1421: 0.397
Feature 53595: 0.371
Feature 34876: 0.314
Feature 13155: 0.292
Feature 44814: 0.258
Feature 3831: 0.237
Feature 53323: 0.237
Feature 5909: 0.235
Feature 55538: 0.234
Feature 47160: 0.230
Feature 13650: 0.226
Feature 41304: 0.217
Feature 60476: 0.200
Feature 46104: 0.190
Feature 26171: 0.188
Feature 45255: 0.175
Feature 4677: 0.174

Activation patterns across tokens:

Token: Ċ
  Feature 48761: 0.509
  Feature 34876: 1.070
  Feature 60476: 1.474

Token: Okay
  Feature 48761: 1.082
  Feature 34876: 1.974
  Feature 60476: 4.595

Token: ,
  Feature 48761: 1.083
  Feature 1421: 0.944
  Feature 53323: 0.710

Token: Ġso
  Feature 48761: 1.459
  Feature 1421: 0.581
  Feature 3831: 0.246
  Feature 53323: 0.273
  Feature 60476: 0.392

Token: ĠI
  Feature 48761: 1.353
  Feature 1421: 1.120
  Feature 53323: 0.463

Token: 'm
  Feature 48761: 1.637
  Feature 1421: 0.595
  Feature 

In [41]:
import numpy as np

# Get token IDs and decode them for reference
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Find which tokens are part of the thinking response
token_ids = inputs[0].cpu().numpy()
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

# Look for <think> and </think> tags in the tokens
think_start_idx = None
think_end_idx = None

for i, token in enumerate(token_texts):
    if '<think>' in token:
        think_start_idx = i + 1  # Start after the <think> tag
    elif '</think>' in token:
        think_end_idx = i  # End before the </think> tag
        break

# If think tags not found, fall back to assistant response
if think_start_idx is None or think_end_idx is None:
    assistant_start = special_positions[-2] + 1
    thinking_tokens = slice(assistant_start, None)
else:
    thinking_tokens = slice(think_start_idx, think_end_idx)

# Get activations only for the thinking part
target_act = gather_residual_activations(model, layer_id, inputs[:, thinking_tokens])
sae_acts = sae.encode(target_act.to(torch.float32))

# Get activation statistics for thinking response
thinking_activations = sae_acts[0]
mean_activations = thinking_activations.mean(dim=0)

# Find top activated features during thinking
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during thinking:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
print("\nActivation patterns across tokens:")
for i, (token, acts) in enumerate(zip(token_texts[thinking_tokens], thinking_activations)):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.1:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.1:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")

Top activated features during thinking:
Feature 48761: 1.361
Feature 45162: 0.841
Feature 3528: 0.672
Feature 1421: 0.390
Feature 53595: 0.386
Feature 53323: 0.279
Feature 13155: 0.276
Feature 55538: 0.247
Feature 5909: 0.229
Feature 47160: 0.228
Feature 44814: 0.226
Feature 13650: 0.215
Feature 60476: 0.212
Feature 3831: 0.194
Feature 46104: 0.194
Feature 34347: 0.193
Feature 41304: 0.192
Feature 26171: 0.179
Feature 4677: 0.175
Feature 34876: 0.170

Activation patterns across tokens:

Token: Okay
  Feature 48761: 0.913
  Feature 60476: 0.270

Token: ,
  Feature 48761: 0.857
  Feature 60476: 0.538

Token: Ġso
  Feature 48761: 1.223
  Feature 60476: 1.686

Token: ĠI
  Feature 48761: 1.185
  Feature 60476: 1.289

Token: 'm
  Feature 48761: 1.273
  Feature 60476: 1.439

Token: Ġtrying
  Feature 48761: 1.289
  Feature 60476: 1.169
  Feature 3831: 0.282
  Feature 34347: 0.540

Token: Ġto
  Feature 48761: 0.974
  Feature 60476: 0.237

Token: Ġfigure
  Feature 48761: 1.229
  Feature 60476: 1

In [65]:
def generate_with_intervention(
    model,
    tokenizer,
    sae,
    messages: list[dict],
    feature_indices: list[int],
    interventions: list[float],
    target_layer: int = 9,
    max_new_tokens: int = 50
):
    if len(feature_indices) != len(interventions):
        raise ValueError("feature_indices and interventions must have the same length.")

    modified_activations = None

    def intervention_hook(module, inputs, outputs):
        nonlocal modified_activations
        activations = inputs[0]

        features = sae.encode(activations.to(torch.float32))
        reconstructed = sae.decode(features)
        error = activations.to(torch.float32) - reconstructed

        for feature_idx, intervention in zip(feature_indices, interventions):
            features[:, :, feature_idx] += intervention

        modified = sae.decode(features) + error
        modified_activations = modified
        modified_activations = modified.to(torch.bfloat16)

        return outputs

    def output_hook(module, inputs, outputs):
        nonlocal modified_activations
        if modified_activations is not None:
            return (modified_activations,) + outputs[1:] if len(outputs) > 1 else (modified_activations,)
        return outputs

    handles = [
        model.model.layers[target_layer].register_forward_hook(intervention_hook),
        model.model.layers[target_layer].register_forward_hook(output_hook)
    ]

    try:
        input_tokens = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # End CoT prematurely
        think_end = torch.tensor(tokenizer.convert_tokens_to_ids('</think>')).reshape(1,1).to("mps")
        input_tokens = torch.cat((inputs, think_end), dim=-1)

        outputs = model.generate(
            input_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=False  # Use greedy decoding for consistency
        )

        generated_text = tokenizer.decode(outputs[0])

    finally:
        for handle in handles:
            handle.remove()

    return generated_text

In [None]:

print("\nGeneration with modified feature:")

feature_to_modify = [64704, 38425, 3831, 19748]

messages = [
    {"role": "user", 
         "content": "Can penguins fly? Segment the thinking process into clear steps and indicate \"YES\" or \"NO\" once at the end ."
        }
]

modified_text = generate_with_intervention(
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    messages=messages,
    feature_indices=feature_to_modify,
    interventions=[-5, -5, -5, -5],
    target_layer=layer_id,
    max_new_tokens=1000
)
print(modified_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



Generation with modified feature:
<｜begin▁of▁sentence｜><｜User｜>Can penguins fly? Segment the thinking process into clear steps and indicate "YES" or "NO" once at the end .<｜Assistant｜><think>
Okay, so I'm trying to figure out whether penguins can fly. At first glance, it seems pretty straightforward, but I want to make sure I don't miss anything. Let me start by thinking about what I know about penguins. They live in the cold regions of the world, like Antarctica, right? I've seen pictures of them waddling on the ice. They have this cute little waddle, and they can swim really fast. But can they fly?

Hmm, I remember that birds can fly, but penguins aren't exactly known for flying. I mean, I've never seen a penguin take off into the air. They mostly move on the ground or swim through the water. Maybe that's a clue. So, if they can't fly, why not? Do they have the necessary physical features for flying?

I should consider their body structure. Penguins have a streamlined body, which is

: 

In [44]:
def generate_with_intervention(
    model,
    tokenizer,
    sae,
    messages: list[dict],
    feature_idx: int,
    intervention: float = 3.0,
    target_layer: int = 9,
    max_new_tokens: int = 50
):
    modified_activations = None

    def intervention_hook(module, inputs, outputs):
        nonlocal modified_activations
        activations = inputs[0]

        features = sae.encode(activations.to(torch.float32))
        reconstructed = sae.decode(features)
        error = activations.to(torch.float32) - reconstructed

        features[:, :, feature_idx] += intervention

        modified = sae.decode(features) + error
        modified_activations = modified
        modified_activations = modified.to(torch.bfloat16)

        return outputs

    def output_hook(module, inputs, outputs):
        nonlocal modified_activations
        if modified_activations is not None:
            return (modified_activations,) + outputs[1:] if len(outputs) > 1 else (modified_activations,)
        return outputs

    handles = [
        model.model.layers[target_layer].register_forward_hook(intervention_hook),
        model.model.layers[target_layer].register_forward_hook(output_hook)
    ]

    try:
        input_tokens = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            input_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=False  # Use greedy decoding for consistency
        )

        generated_text = tokenizer.decode(outputs[0])

    finally:
        for handle in handles:
            handle.remove()

    return generated_text