<a href="https://colab.research.google.com/github/giacomoarienti/data-intensive-lab/blob/master/sae-test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install torch transformers accelerate


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim: int, expansion_factor: float = 16):
        super().__init__()
        # n -> n * factor -> n
        self.input_dim = input_dim
        self.latent_dim = int(input_dim * expansion_factor)
        self.decoder = nn.Linear(self.latent_dim, input_dim, bias=True)
        self.encoder = nn.Linear(input_dim, self.latent_dim, bias=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = F.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return F.relu(self.encoder(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.decoder(x)

    @classmethod
    def from_pretrained(cls, path: str, input_dim: int, expansion_factor: float = 16, device: str = "cuda") -> "SparseAutoencoder":
        model = cls(input_dim=input_dim, expansion_factor=expansion_factor)
        # The checkpoint is a plain state dict, so restrict unpickling to tensors.
        state_dict = torch.load(path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        return model
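
As a quick sanity check of the layout above, here is a minimal sketch on random data. `TinySAE` and the toy dimensions are hypothetical stand-ins, not part of the released checkpoints; the class only mirrors the encoder/decoder structure of `SparseAutoencoder` so the shapes and state-dict keys can be inspected on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySAE(nn.Module):
    """Hypothetical stand-in with the same layout as SparseAutoencoder above."""
    def __init__(self, input_dim: int, expansion_factor: float = 16):
        super().__init__()
        self.latent_dim = int(input_dim * expansion_factor)
        self.encoder = nn.Linear(input_dim, self.latent_dim)
        self.decoder = nn.Linear(self.latent_dim, input_dim)

    def forward(self, x):
        encoded = F.relu(self.encoder(x))  # non-negative latent features
        return self.decoder(encoded), encoded

torch.manual_seed(0)
tiny_sae = TinySAE(input_dim=8, expansion_factor=4)
x = torch.randn(2, 5, 8)  # (batch, sequence, hidden)
decoded, encoded = tiny_sae(x)

decoded_shape = tuple(decoded.shape)
encoded_shape = tuple(encoded.shape)
state_keys = sorted(tiny_sae.state_dict().keys())
print(decoded_shape, encoded_shape)
print(state_keys)
```

A checkpoint passed to `from_pretrained` has to provide exactly these four keys with matching shapes, which is why `input_dim` and `expansion_factor` must agree with the values used when the SAE was trained.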

In [None]:
from huggingface_hub import hf_hub_download

In [None]:
# sae_name = "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
sae_name = "Llama-3.2-1B-Instruct-SAE-l9"
# sae_name = "DeepSeek-R1-Distill-Llama-70B-SAE-l48"

In [None]:
file_path = hf_hub_download(
    repo_id=f"qresearch/{sae_name}",
    filename=f"{sae_name}.pt",
    repo_type="model"
)

Llama-3.2-1B-Instruct-SAE-l9.pt:   0%|          | 0.00/537M [00:00<?, ?B/s]

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = ("deepseek-ai/DeepSeek-R1-Distill-Llama-8B" if sae_name == "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
              else "meta-llama/Llama-3.2-1B-Instruct" if sae_name == "Llama-3.2-1B-Instruct-SAE-l9"
              else "deepseek-ai/DeepSeek-R1-Distill-Llama-70B")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

expansion_factor = 8 if sae_name == "DeepSeek-R1-Distill-Llama-70B-SAE-l48" else 16
sae = SparseAutoencoder.from_pretrained(
    path=file_path,
    input_dim=model.config.hidden_size,
    expansion_factor=expansion_factor,
    device="cuda"
)


In [None]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    add_generation_prompt=True,
    return_tensors="pt",
).to("cuda")
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 13 Mar 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here and ready to help you with any questions or topics you'd like to discuss. How


In [None]:
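def gather_residual_activations(model, target_layer, inputs):
    """Run one forward pass and return the hidden states entering `target_layer`."""
    target_act = None

    def gather_target_act_hook(module, hook_inputs, hook_outputs):
        nonlocal target_act
        # The first positional argument of a decoder layer is the residual stream.
        target_act = hook_inputs[0]
        return hook_outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    try:
        with torch.no_grad():
            _ = model(inputs)
    finally:
        # Always detach the hook, even if the forward pass raises.
        handle.remove()
    return target_act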
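
To make the hook mechanics concrete, here is a minimal sketch on a toy `nn.Sequential` (the toy network and the names `toy`, `captured`, and `capture_input_hook` are illustrative assumptions, not part of this notebook's model). It shows that a forward hook receives the tuple of positional inputs of the hooked module, so `args[0]` is whatever the previous layer produced.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 2))
captured = {}

def capture_input_hook(module, args, output):
    # args is the tuple of positional arguments passed to module.forward
    captured["layer_input"] = args[0].detach()

x = torch.randn(3, 4)
handle = toy[2].register_forward_hook(capture_input_hook)
try:
    with torch.no_grad():
        toy(x)
finally:
    handle.remove()  # leave the module without hooks afterwards

# The input to the last layer is the output of the Tanh before it.
expected = torch.tanh(toy[0](x)).detach()
print(torch.allclose(captured["layer_input"], expected))
```

Because the hook returns `None`, the module's output is left untouched; returning a value from the hook would replace it.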

In [None]:
# show model architecture
print(model.model.layers)

ModuleList(
  (0-15): 16 x LlamaDecoderLayer(
    (self_attn): LlamaAttention(
      (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
      (k_proj): Linear(in_features=2048, out_features=512, bias=False)
      (v_proj): Linear(in_features=2048, out_features=512, bias=False)
      (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
      (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
      (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
    (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
  )
)


In [None]:
layer_id = (19 if sae_name == "DeepSeek-R1-Distill-Llama-8B-SAE-l19"
            else 9 if sae_name == "Llama-3.2-1B-Instruct-SAE-l9"
            else 48)

In [None]:
target_act = gather_residual_activations(model, layer_id, inputs)

In [None]:
def ensure_same_device(sae, target_act):
    """Ensure SAE and activations are on the same device"""
    model_device = target_act.device
    sae = sae.to(model_device)
    return sae, target_act.to(model_device)

sae, target_act = ensure_same_device(sae, target_act)
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

In [None]:
var_explained = 1 - torch.mean((recon - target_act.to(torch.float32)) ** 2) / torch.var(target_act.to(torch.float32))
print(f"Variance explained: {var_explained:.3f}")

Variance explained: 0.999
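
The metric above is one minus the mean squared reconstruction error divided by the variance of the activations. A minimal sketch on random data, with a hypothetical helper `variance_explained`, shows its two reference points: a perfect reconstruction scores 1, and predicting the global mean scores close to 0. It is not exactly 0 because `torch.var` uses the unbiased estimator by default while the numerator is a plain mean.

```python
import torch

def variance_explained(recon: torch.Tensor, target: torch.Tensor) -> float:
    """Hypothetical helper: 1 - MSE / Var, computed in float32."""
    target = target.to(torch.float32)
    mse = torch.mean((recon.to(torch.float32) - target) ** 2)
    return (1 - mse / torch.var(target)).item()

torch.manual_seed(0)
target = torch.randn(1, 6, 16)  # 96 values in total

# Reference point 1: the reconstruction equals the target.
perfect = variance_explained(target.clone(), target)

# Reference point 2: every value is replaced by the global mean.
mean_only = torch.full_like(target, target.mean().item())
baseline = variance_explained(mean_only, target)

print(f"perfect:  {perfect:.3f}")
print(f"baseline: {baseline:.3f}")
```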


In [None]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "The Golden Gate Bridge is a suspension bridge spanning the Golden Gate, the one-mile-wide strait connecting San Francisco Bay and the Pacific Ocean in California, United States."},
    ],
    return_tensors="pt",
).to("cuda")

In [None]:
import numpy as np

# Get activations
target_act = gather_residual_activations(model, layer_id, inputs)
sae_acts = sae.encode(target_act.to(torch.float32))

# Get token IDs and decode them for reference
token_ids = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(token_ids)

# Locate the user message. No generation prompt was added, so the last two
# special tokens are the user header's <|end_header_id|> and the closing <|eot_id|>.
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

message_start = special_positions[-2] + 1
message_tokens = slice(message_start, None)

# Get activation statistics over the user message
message_activations = sae_acts[0, message_tokens]
mean_activations = message_activations.mean(dim=0)

# Find the features with the highest mean activation over the message
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features in the user message:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
print("\nActivation patterns across tokens:")
for token, acts in zip(token_texts[message_tokens], message_activations):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.2:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.2:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")


Top activated features in the user message:
Feature 15546: 0.120
Feature 9942: 0.106
Feature 25317: 0.100
Feature 3077: 0.099
Feature 8400: 0.087
Feature 30078: 0.086
Feature 1973: 0.085
Feature 24720: 0.082
Feature 28655: 0.069
Feature 8103: 0.067
Feature 5991: 0.067
Feature 19713: 0.061
Feature 15299: 0.060
Feature 4976: 0.058
Feature 16117: 0.054
Feature 29693: 0.050
Feature 79: 0.047
Feature 4219: 0.046
Feature 30671: 0.045
Feature 28596: 0.045

Activation patterns across tokens:

Token: ĠGolden
  Feature 30078: 1.484

Token: ĠGate
  Feature 25317: 0.862
  Feature 8103: 0.744

Token: ĠBridge
  Feature 9942: 0.763
  Feature 25317: 1.244

Token: Ġis
  Feature 9942: 0.428
  Feature 25317: 0.242

Token: Ġsuspension
  Feature 9942: 0.433

Token: Ġbridge
  Feature 9942: 0.823
  Feature 8400: 0.257
  Feature 1973: 0.451
  Feature 15299: 0.239

Token: Ġspanning
  Feature 9942: 0.485
  Feature 8400: 0.259
  Feature 1973: 0.748
  Feature 24720: 0.230
  Feature 28655: 0.296
  Feature 15299: 

In [None]:
def generate_with_intervention(
    model,
    tokenizer,
    sae,
    messages: list[dict],
    feature_idx: list[int],
    intervention: list[float],
    target_layer: int = 9,
    max_new_tokens: int = 50
):
    modified_activations = None

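    def intervention_hook(module, inputs, outputs):
        nonlocal modified_activations
        activations = inputs[0]  # residual stream entering the layer

        # Split the activations into the SAE reconstruction plus a residual error
        # so that only the chosen features change.
        features = sae.encode(activations.to(torch.float32))
        reconstructed = sae.decode(features)
        error = activations.to(torch.float32) - reconstructed

        interv = torch.tensor(intervention, dtype=features.dtype, device=features.device)
        features[:, :, feature_idx] += interv

        modified = sae.decode(features) + error
        # Cast back to the model's working dtype (bfloat16 in this notebook).
        modified_activations = modified.to(activations.dtype)

        return outputs

    def output_hook(module, inputs, outputs):
        # Note: this replaces the layer's output with the steered layer *input*,
        # so the target layer's own computation is bypassed.
        if modified_activations is not None:
            if isinstance(outputs, tuple):
                return (modified_activations,) + outputs[1:]
            # In case the layer returns a bare tensor rather than a tuple.
            return modified_activations
        return outputs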

    handles = [
        model.model.layers[target_layer].register_forward_hook(intervention_hook),
        model.model.layers[target_layer].register_forward_hook(output_hook)
    ]

    try:
        input_tokens = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            input_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=False  # Use greedy decoding for consistency
        )

        generated_text = tokenizer.decode(outputs[0])

    finally:
        for handle in handles:
            handle.remove()

    return generated_text
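
The arithmetic inside `intervention_hook` can be checked in isolation. The following is a minimal sketch with a randomly initialised toy encoder and decoder (hypothetical stand-ins for the trained SAE, and `steer` is an illustrative helper name). It shows two properties of the error-term trick: a zero intervention returns the original activations, and a non-zero one shifts every position by a weighted sum of the corresponding decoder columns.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden, latent = 8, 32
encoder = nn.Linear(hidden, latent)  # toy, untrained stand-in for the SAE encoder
decoder = nn.Linear(latent, hidden)  # toy, untrained stand-in for the SAE decoder

@torch.no_grad()
def steer(activations, feature_idx, deltas):
    features = F.relu(encoder(activations))
    # Whatever the SAE fails to reconstruct is added back unchanged.
    error = activations - decoder(features)
    features[:, :, feature_idx] += torch.tensor(deltas, dtype=features.dtype)
    return decoder(features) + error

acts = torch.randn(1, 3, hidden)  # (batch, sequence, hidden)
unchanged = steer(acts, [5, 7], [0.0, 0.0])
steered = steer(acts, [5, 7], [1.0, 2.0])

with torch.no_grad():
    # Column j of the decoder weight is the direction written by feature j.
    expected_shift = 1.0 * decoder.weight[:, 5] + 2.0 * decoder.weight[:, 7]

print(torch.allclose(unchanged, acts, atol=1e-5))
print(torch.allclose(steered - acts, expected_shift.expand_as(acts), atol=1e-5))
```

Since the shift does not depend on the input, adding a constant to a feature is equivalent to adding a fixed steering vector to the residual stream at every token position.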

In [105]:
messages = [
    # {"role": "user", "content": "How are you doing?"}
    {"role": "user", "content": "What's your name?"}
]
"""
Token: Golden
  Feature 30078: 1.484

Token: Gate
  Feature 25317: 0.862
  Feature 8103: 0.744

Token: Bridge
  Feature 9942: 0.763
  Feature 25317: 1.244
"""
# features_to_modify = [25317, 30078, 28128, 9942, 5833]
features_to_modify = [30078, 25317, 8103, 9942]
intervation = [1, 1, 1, 1]

In [None]:
print("Original generation:")
input_tokens = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)
outputs = model.generate(input_tokens, max_new_tokens=1000, do_sample=False)
print(tokenizer.decode(outputs[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Original generation:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 13 Mar 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

How are you doing?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I'm doing well, thanks for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm always ready to help with any questions or topics you'd like to discuss. How about you? How's your day going?<|eot_id|>


In [104]:
print("\nGeneration with modified feature:")
modified_text = generate_with_intervention(
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    messages=messages,
    feature_idx=features_to_modify,
    intervention=intervation,
    target_layer=layer_id,
    max_new_tokens=100
)
print(modified_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Generation with modified feature:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 13 Mar 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I'm an artificial sculpture in the shape of an Etrian architecture, and I'm here to guide you through the iconic monument.<|eot_id|>
