<a href="https://colab.research.google.com/github/giacomoarienti/data-intensive-lab/blob/master/sae-test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [39]:
%pip install torch transformers accelerate



In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim: int, expansion_factor: float = 16):
        super().__init__()
        # n -> n * factor -> n
        self.input_dim = input_dim
        self.latent_dim = int(input_dim * expansion_factor)
        self.decoder = nn.Linear(self.latent_dim, input_dim, bias=True)
        self.encoder = nn.Linear(input_dim, self.latent_dim, bias=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = F.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return F.relu(self.encoder(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.decoder(x)

    @classmethod
    def from_pretrained(cls, path: str, input_dim: int, expansion_factor: float = 16, device: str = "cuda") -> "SparseAutoencoder":
        model = cls(input_dim=input_dim, expansion_factor=expansion_factor)
        state_dict = torch.load(path, map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        return model

In [41]:
from huggingface_hub import hf_hub_download

In [42]:
# sae_name = "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
sae_name = "Llama-3.2-1B-Instruct-SAE-l9"
# sae_name = "DeepSeek-R1-Distill-Llama-70B-SAE-l48"

In [43]:
file_path = hf_hub_download(
    repo_id=f"qresearch/{sae_name}",
    filename=f"{sae_name}.pt",
    repo_type="model"
)

In [44]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = ("deepseek-ai/DeepSeek-R1-Distill-Llama-8B" if sae_name == "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
              else "meta-llama/Llama-3.2-1B-Instruct" if sae_name == "Llama-3.2-1B-Instruct-SAE-l9"
              else "deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

expansion_factor = 8 if sae_name == "DeepSeek-R1-Distill-Llama-70B-SAE-l48" else 16
sae = SparseAutoencoder.from_pretrained(
    path=file_path,
    input_dim=model.config.hidden_size,
    expansion_factor=expansion_factor,
    device="cuda"
)

In [45]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    add_generation_prompt=True,
    return_tensors="pt",
).to("cuda")
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 13 Mar 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Hello! I'm doing well, thanks for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help you with any questions or topics you'd like to discuss. How about


In [46]:
def gather_residual_activations(model, target_layer, inputs):
    """Run one forward pass and return the residual stream entering `target_layer`."""
    target_act = None

    def gather_target_act_hook(module, hook_inputs, hook_outputs):
        nonlocal target_act
        target_act = hook_inputs[0]  # Get residual stream from layer input
        return hook_outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    try:
        with torch.no_grad():
            _ = model(inputs)
    finally:
        handle.remove()  # Always detach the hook, even if the forward pass fails
    return target_act

In [47]:
# show model architecture
print(model.model.layers)

ModuleList(
  (0-15): 16 x LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
      (k_proj): Linear(in_features=2048, out_features=512, bias=False)
      (v_proj): Linear(in_features=2048, out_features=512, bias=False)
      (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
      (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
      (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
    (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
  )
)


In [48]:
layer_id = (19 if sae_name == "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
            else 9 if sae_name == "Llama-3.2-1B-Instruct-SAE-l9"
            else 48)

In [49]:
target_act = gather_residual_activations(model, layer_id, inputs)

In [50]:
def ensure_same_device(sae, target_act):
    """Ensure SAE and activations are on the same device"""
    model_device = target_act.device
    sae = sae.to(model_device)
    return sae, target_act.to(model_device)

sae, target_act = ensure_same_device(sae, target_act)
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

In [51]:
var_explained = 1 - torch.mean((recon - target_act.to(torch.float32)) ** 2) / torch.var(target_act.to(torch.float32))
print(f"Variance explained: {var_explained:.3f}")

Variance explained: 0.999
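
The metric above is one minus the mean squared reconstruction error divided by the variance of the original activations. A minimal sketch of the same formula in plain Python, assuming flat lists of floats; the helper name `variance_explained` and the toy data are illustrative and not part of the notebook. `statistics.variance` uses the n - 1 denominator, as `torch.var` does by default, while the error term is averaged over n, so a reconstruction that only predicts the mean does not score exactly zero.

```python
from statistics import fmean, variance


def variance_explained(original, reconstruction):
    """Return 1 - MSE / Var for two equal-length lists of floats."""
    mse = fmean([(r - o) ** 2 for o, r in zip(original, reconstruction)])
    # Sample variance (n - 1 denominator), matching torch.var's default
    return 1.0 - mse / variance(original)


original = [1.0, 2.0, 3.0, 4.0]
perfect = list(original)   # exact reconstruction
mean_only = [2.5] * 4      # reconstruction that only predicts the mean

print(f"perfect:   {variance_explained(original, perfect):.3f}")
print(f"mean only: {variance_explained(original, mean_only):.3f}")
```

A perfect reconstruction scores 1.0; the mean-only reconstruction scores 0.25 on this toy input because of the n versus n - 1 mismatch, which becomes negligible for the large activation tensors used in the notebook.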


In [104]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Golden Gate Bridge"},
    ],
    return_tensors="pt",
).to("cuda")

In [106]:
import numpy as np
# Get activations
target_act = gather_residual_activations(model, layer_id, inputs)
sae_acts = sae.encode(target_act.to(torch.float32))

# Get token IDs and decode them for reference
token_ids = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(token_ids)

# Locate the user message: it starts right after the second-to-last special token
# (the <|end_header_id|> of the user header); the last special token closes the message
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

message_start = special_positions[-2] + 1
message_tokens = slice(message_start, None)

# Get activation statistics over the message tokens
message_activations = sae_acts[0, message_tokens]
mean_activations = message_activations.mean(dim=0)

# Find the top activated features for the prompt
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features for the prompt:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
print("\nActivation patterns across tokens:")
for token, acts in zip(token_texts[message_tokens], message_activations):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.2:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.2:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")


Top activated features for the prompt:
Feature 25317: 0.405
Feature 30078: 0.293
Feature 28128: 0.237
Feature 9942: 0.219
Feature 13916: 0.189
Feature 3663: 0.156
Feature 8103: 0.143
Feature 31638: 0.140
Feature 5833: 0.138
Feature 2052: 0.134
Feature 28447: 0.116
Feature 15786: 0.108
Feature 28539: 0.107
Feature 31021: 0.099
Feature 22263: 0.092
Feature 16742: 0.079
Feature 30504: 0.077
Feature 25757: 0.075
Feature 18508: 0.074
Feature 11775: 0.072

Activation patterns across tokens:

Token: ĊĊ
  Feature 28539: 0.536
  Feature 31021: 0.497
  Feature 11775: 0.362

Token: Golden
  Feature 30078: 1.465
  Feature 3663: 0.497
  Feature 16742: 0.396

Token: ĠGate
  Feature 25317: 0.596
  Feature 28128: 0.274
  Feature 8103: 0.713
  Feature 31638: 0.246
  Feature 5833: 0.688
  Feature 15786: 0.375
  Feature 18508: 0.370

Token: ĠBridge
  Feature 25317: 1.242
  Feature 28128: 0.730
  Feature 9942: 0.742
  Feature 13916: 0.500
  Feature 31638: 0.309
  Feature 2052: 0.354
  Feature 28447:
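
The span selection in the cell above relies on the chat template's special tokens: the message text begins one position after the second-to-last special token. A minimal sketch of that indexing in plain Python; the function name `last_message_start` is hypothetical, the special-token range is the one used in the cell, and the token ids in the example are made up for illustration rather than taken from the tokenizer.

```python
def last_message_start(token_ids, special_min=128000, special_max=128255):
    """Index of the first token after the second-to-last special token."""
    special_positions = [
        i for i, tok in enumerate(token_ids) if special_min <= tok <= special_max
    ]
    if len(special_positions) < 2:
        raise ValueError("expected at least two special tokens in the prompt")
    return special_positions[-2] + 1


# Illustrative ids only: header tokens, a role token, a separator,
# three content tokens, and a closing special token.
ids = [128000, 128006, 882, 128007, 271, 1001, 1002, 1003, 128009]
start = last_message_start(ids)
print(start, ids[start:])
```

Note that the slice runs to the end of the sequence, so it also includes the closing special token, just as `slice(start, None)` does in the notebook cell.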

In [92]:
def generate_with_intervention(
    model,
    tokenizer,
    sae,
    messages: list[dict],
    feature_idx: list[int],
    intervention: float = 3.0,
    target_layer: int = 9,
    max_new_tokens: int = 50
):
    modified_activations = None

    def intervention_hook(module, inputs, outputs):
        # Build a steered copy of the residual stream entering the layer
        nonlocal modified_activations
        activations = inputs[0]

        features = sae.encode(activations.to(torch.float32))
        reconstructed = sae.decode(features)
        # Keep the part of the activations the SAE fails to reconstruct
        error = activations.to(torch.float32) - reconstructed

        # Boost the selected features at every position
        features[:, :, feature_idx] += intervention

        modified = sae.decode(features) + error
        modified_activations = modified.to(activations.dtype)

        return outputs

    def output_hook(module, inputs, outputs):
        # Swap the layer's output hidden state for the steered activations
        nonlocal modified_activations
        if modified_activations is not None:
            return (modified_activations,) + outputs[1:] if len(outputs) > 1 else (modified_activations,)
        return outputs

    handles = [
        model.model.layers[target_layer].register_forward_hook(intervention_hook),
        model.model.layers[target_layer].register_forward_hook(output_hook)
    ]

    try:
        input_tokens = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            input_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=False  # Use greedy decoding for consistency
        )

        generated_text = tokenizer.decode(outputs[0])

    finally:
        for handle in handles:
            handle.remove()

    return generated_text
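
Because the decoder is an affine map, adding the reconstruction error back after boosting features is algebraically the same as adding `intervention` times the chosen decoder columns to the original activations. A minimal sketch that checks this identity on toy numbers in plain Python; the helper names `decode` and `steer`, the weight layout, and all values are assumptions for illustration, not the notebook's SAE.

```python
def decode(features, w_dec, b_dec):
    """Affine decoder: w_dec[j][i] maps feature i to output dimension j."""
    return [
        sum(w * f for w, f in zip(row, features)) + b
        for row, b in zip(w_dec, b_dec)
    ]


def steer(x, features, w_dec, b_dec, feature_idx, strength):
    """Boost selected features, decode, and add back the SAE's residual error."""
    error = [xi - ri for xi, ri in zip(x, decode(features, w_dec, b_dec))]
    boosted = list(features)
    for i in feature_idx:
        boosted[i] += strength
    return [m + e for m, e in zip(decode(boosted, w_dec, b_dec), error)]


x = [1.0, -2.0]                       # toy activation (2 dimensions)
features = [0.5, 0.0, 2.0]            # toy SAE features (3 latents)
w_dec = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
b_dec = [0.25, -0.5]

steered = steer(x, features, w_dec, b_dec, feature_idx=[0, 2], strength=2.0)
# Shortcut: original activation plus strength times the summed decoder columns
shortcut = [xi + 2.0 * (row[0] + row[2]) for xi, row in zip(x, w_dec)]
print(steered, shortcut)
```

Both lists come out as `[7.0, -4.0]`, so in this sketch the encoder only matters for choosing which features to inspect; the steering itself depends on the decoder directions and the strength.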

In [119]:
messages = [
    {"role": "user", "content": "How are you doing?"}
    # {"role": "user", "content": "Who are you?"}
    # {"role": "user", "content": "Roleplay as a pirate"}
]
features_to_modify = [25317, 30078, 28128]
intervention_strength = 1.5

print("Original generation:")
input_tokens = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)
outputs = model.generate(input_tokens, max_new_tokens=1000, do_sample=False)
print(tokenizer.decode(outputs[0]))

print("\nGeneration with modified feature:")
modified_text = generate_with_intervention(
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    messages=messages,
    feature_idx=features_to_modify,
    intervention=intervention_strength,
    target_layer=layer_id,
    max_new_tokens=100
)
print(modified_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Original generation:



<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 13 Mar 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

How are you doing?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I'm doing well, thanks for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm always ready to help with any questions or topics you'd like to discuss. How about you? How's your day going?<|eot_id|>

Generation with modified feature:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 13 Mar 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

How are you doing?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

This is a replica of the statue of the sun at the Eiffel Tower.<|eot_id|>


In [57]:
print("\nGeneration with modified feature:")

messages = [
    {"role": "user", "content": "How many Rs in strawberry?"}
]

modified_text = generate_with_intervention(
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    messages=messages,
    feature_idx=features_to_modify,
    intervention=8,
    target_layer=layer_id,
    max_new_tokens=1000
)
print(modified_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Generation with modified feature:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 13 Mar 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

How many Rs in strawberry?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The number of Rs in "Strawberry" is 2.<|eot_id|>
