### Imports & Preps

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
import numpy as np

In [2]:
class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim: int, expansion_factor: float = 16):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = int(input_dim * expansion_factor)
        self.decoder = nn.Linear(self.latent_dim, input_dim, bias=True)
        self.encoder = nn.Linear(input_dim, self.latent_dim, bias=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = F.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return F.relu(self.encoder(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.decoder(x)

    @classmethod
    def from_pretrained(cls, path: str, input_dim: int, expansion_factor: float = 16, device: str = "mps") -> "SparseAutoencoder":
        model = cls(input_dim=input_dim, expansion_factor=expansion_factor)
        state_dict = torch.load(path, map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        return model

In [3]:
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
sae_name = "DeepSeek-R1-Distill-Llama-8B-SAE-l19"

file_path = hf_hub_download(
    repo_id=f"qresearch/{sae_name}",
    filename=f"{sae_name}.pt",
    repo_type="model"
)

model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

expansion_factor = 16
sae = SparseAutoencoder.from_pretrained(
    path=file_path,
    input_dim=model.config.hidden_size,
    expansion_factor=expansion_factor,
    device="mps"
)

layer_id = 19

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  state_dict = torch.load(path, map_location=device)


In [5]:
def gather_residual_activations(model, target_layer, inputs):
    target_act = None
    def gather_target_act_hook(mod, inputs, outputs):
        nonlocal target_act
        target_act = inputs[0]  # Get residual stream from layer input
        return outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    with torch.no_grad():
        _ = model(inputs)
    handle.remove()
    return target_act

def ensure_same_device(sae, target_act):
    """Ensure SAE and activations are on the same device"""
    model_device = target_act.device
    sae = sae.to(model_device)
    return sae, target_act.to(model_device)

### Penguin SAE Analysis

In [6]:
# Define the path to the .pt file - you can modify this to any file in the outputs directory
pt_file_path = "outputs/penguin/raw_outputs/output_79_temp0_6.pt"

# Load the tokens from the .pt file
loaded_tokens = torch.load(pt_file_path)

# Move to the appropriate device
inputs = loaded_tokens.to("mps")

# Print shape and first few tokens to verify
print(f"Loaded token tensor shape: {inputs.shape}")
print(f"First few tokens: {inputs[0, :10]}")

Loaded token tensor shape: torch.Size([1, 724])
First few tokens: tensor([128000, 128011,   6854,    281,  56458,  11722,     30,  38203,    279,
          7422], device='mps:0')


  loaded_tokens = torch.load(pt_file_path)


#### Analyze model thinking part (get thinking acts)

In [7]:
# Get token IDs and decode them for reference
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Find which tokens are part of the thinking response
token_ids = inputs[0].cpu().numpy()
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

# Look for <think> and </think> tags in the tokens
think_start_idx = None
think_end_idx = None

for i, token in enumerate(token_texts):
    if '<think>' in token:
        think_start_idx = i + 1  # Start after the <think> tag
    elif '</think>' in token:
        think_end_idx = i  # End before the </think> tag
        break

# If think tags not found, fall back to assistant response
if think_start_idx is None or think_end_idx is None:
    assistant_start = special_positions[-2] + 1
    thinking_tokens = slice(assistant_start, None)
else:
    thinking_tokens = slice(think_start_idx, think_end_idx)

# Get activations only for the thinking part
target_act = gather_residual_activations(model, layer_id, inputs[:, thinking_tokens])
sae_acts = sae.encode(target_act.to(torch.float32))

# Get activation statistics for thinking response
thinking_activations = sae_acts[0]
mean_activations = thinking_activations.mean(dim=0)

# Find top activated features during thinking
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during thinking:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
print("\nActivation patterns across tokens:")
for i, (token, acts) in enumerate(zip(token_texts[thinking_tokens], thinking_activations)):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.1:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.1:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Top activated features during thinking:
Feature 48761: 1.361
Feature 45162: 0.841
Feature 3528: 0.672
Feature 1421: 0.390
Feature 53595: 0.386
Feature 53323: 0.279
Feature 13155: 0.276
Feature 55538: 0.247
Feature 5909: 0.229
Feature 47160: 0.228
Feature 44814: 0.226
Feature 13650: 0.215
Feature 60476: 0.212
Feature 3831: 0.194
Feature 46104: 0.194
Feature 34347: 0.193
Feature 41304: 0.192
Feature 26171: 0.179
Feature 4677: 0.175
Feature 34876: 0.170

Activation patterns across tokens:

Token: Okay
  Feature 48761: 0.913
  Feature 60476: 0.270

Token: ,
  Feature 48761: 0.857
  Feature 60476: 0.538

Token: Ġso
  Feature 48761: 1.223
  Feature 60476: 1.686

Token: ĠI
  Feature 48761: 1.185
  Feature 60476: 1.289

Token: 'm
  Feature 48761: 1.273
  Feature 60476: 1.439

Token: Ġtrying
  Feature 48761: 1.289
  Feature 60476: 1.169
  Feature 3831: 0.282
  Feature 34347: 0.540

Token: Ġto
  Feature 48761: 0.974
  Feature 60476: 0.237

Token: Ġfigure
  Feature 48761: 1.229
  Feature 60476: 1