In [1]:
import torch
import heapq
import pickle
import time
import json
import numpy as np
import pandas as pd
import os
import cProfile
import pstats
from pathlib import Path
from tqdm.auto import tqdm
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from psutil import cpu_count

SAE_PATH      = Path('out/sae_65k_lambda_ramp_80to25/sae_final.pt')
TOP_N         = 50
DUMP_DIR      = Path('feature_dumps')
DUMP_DIR.mkdir(exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)
# ---- Load SAE (replace with your own class/loader) ----
from model import SAE
state_dict, config = torch.load(SAE_PATH, map_location=device).values()
sae = SAE(config['input_size'],config['hidden_size']).to(device).to(torch.bfloat16)
# Fix for "_orig_mod" prefix in state dict keys
fixed_state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
sae.load_state_dict(fixed_state_dict)

n_features = sae.encode.out_features if hasattr(sae.encode,'out_features') else sae.n_features
print(f'Loaded SAE with {n_features} features')

def count_dead_features(sample_iter, sample_tokens=10_000_000):
    """Returns a boolean tensor of shape (n_features,) where True == dead."""
    fired = torch.zeros(n_features, dtype=torch.bool, device=device)
    seen  = 0
    for toks in tqdm(sample_iter, total=sample_tokens//len(next(iter(sample_iter)))):
        toks = toks.to(device)
        acts = sae.encode(toks) > 0  # bool mask of activations
        fired |= acts.any(dim=0)
        seen  += toks.size(0)
        if seen >= sample_tokens:
            break
    dead_mask = ~fired.cpu()
    print(f"Dead features: {dead_mask.sum().item()} / {n_features} ({dead_mask.float().mean()*100:.2f}%)")
    return dead_mask

# GPU optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Constants
MODEL_NAME = "allenai/OLMo-2-1124-7B-Instruct"
BATCH_SIZE = 256
LAYER_OFFSET = -1
TOP_N = 50
device = "cuda"
DUMP_DIR = Path("./results")
os.makedirs(DUMP_DIR, exist_ok=True)

def main():
    # Load tokenizer + *half* model config (bf16, compiled)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    cfg = AutoModelForCausalLM.from_pretrained(MODEL_NAME).config
    cfg.num_hidden_layers //= 2  # half-model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map=device
    ).eval()
    model = torch.compile(model, mode="reduce-overhead")

    # Load dataset
    dataset_iter = load_dataset(
        "HuggingFaceFW/fineweb",
        name="sample-10BT",
        split="train",
        streaming=False,
        num_proc=cpu_count(),
    ).shuffle()

    def residual_stream_iter(text_iter, batch=BATCH_SIZE):
        buf = []
        for record in text_iter:
            buf.extend(tokenizer(record["text"]).input_ids)
            while len(buf) >= batch:
                toks = torch.tensor(buf[:batch]).to(device)
                buf = buf[batch:]
                with torch.inference_mode():
                    outs = model(toks.unsqueeze(0), output_hidden_states=True)
                    resid = outs.hidden_states[LAYER_OFFSET].squeeze(0)  # (T, d)
                yield toks, resid  # feed straight to SAE

    def mine_top_tokens_and_dead(data_iter,
                              top_n=TOP_N,
                              target_tokens=100_000):
        """
        • Keeps the TOP-N strongest (activation, token) pairs per feature.
        • Tracks which features ever fire to flag the 'dead' ones.
        • Stops after `target_tokens` have been processed.
        """
        n_features = sae.encode.weight.shape[0]
        
        # Pre-allocate all buckets with empty heaps
        buckets = [[] for _ in range(n_features)]
        fired = torch.zeros(n_features, dtype=torch.bool, device=device)
        seen_toks = 0
        
        # Process batches with lighter progress indicator 
        start_time = time.time()
        batch_count = 0
        
        for toks, resid in data_iter:
            batch_count += 1
            if batch_count % 10 == 0:
                elapsed = time.time() - start_time
                tokens_per_sec = seen_toks / elapsed if elapsed > 0 else 0
                print(f"\rProcessed {seen_toks} tokens ({tokens_per_sec:.1f} tok/s)", end="")
                
            with torch.inference_mode():
                # Compute activations
                acts = sae.encode(resid)
                fired |= (acts > 0).any(dim=0)
                
                # Get top values and indices
                values, idx = acts.topk(1, dim=0)
                
                # Transfer to CPU in one batch  
                values_cpu = values[0].to(torch.float32).detach().cpu().numpy()
                indices_cpu = idx[0].detach().cpu().numpy()
                token_ids = toks[indices_cpu].cpu().numpy()
                
                # Process features in chunks for better performance
                for f in range(n_features):
                    val, tok_id = float(values_cpu[f]), int(token_ids[f])
                    heap = buckets[f]
                    if len(heap) < top_n:
                        heapq.heappush(heap, (val, tok_id))
                    elif val > heap[0][0]:
                        heapq.heapreplace(heap, (val, tok_id))
            
            seen_toks += toks.numel()
            if seen_toks >= target_tokens:
                break
        
        print(f"\nProcessed {seen_toks} tokens in {time.time() - start_time:.2f}s")
        
        # Post-process
        dead_mask = ~fired.cpu().numpy()  # Convert directly to numpy
        
        # Get unique token IDs for batch decoding
        unique_token_ids = set()
        for heap in buckets:
            for _, tok_id in heap:
                unique_token_ids.add(tok_id)
        
        # Convert set to list for batch decoding
        unique_token_list = list(unique_token_ids)
        decoded_tokens = tokenizer.batch_decode([[t] for t in unique_token_list])
        
        # Create mapping of token ID to decoded text
        token_id_to_text = {unique_token_list[i]: decoded_tokens[i] for i in range(len(unique_token_list))}
        
        # Create the final result with native Python types
        decoded = []
        for heap in buckets:
            feature_results = []
            for val, tok_id in sorted(heap, key=lambda x: -x[0]):  # Sort directly here
                feature_results.append((float(val), token_id_to_text[tok_id]))
            decoded.append(feature_results)  # Already sorted
        
        print(f"Dead features: {dead_mask.sum()} / {n_features} "
              f"({dead_mask.sum()/n_features*100:.2f}%)")
        
        return decoded, dead_mask

    data_iter = residual_stream_iter(dataset_iter)
    with torch.inference_mode():
        top_buckets, dead_mask = mine_top_tokens_and_dead(
            data_iter,
            top_n=50,
            target_tokens=40_000_000
        )

    # Save results - using faster formats where possible
    # Use pickle for faster serialization
    with open(DUMP_DIR / "top_tokens.pkl", "wb") as f:
        pickle.dump(top_buckets, f)
    
    with open(DUMP_DIR / "top_tokens.json", "w") as f:
        json.dump(top_buckets, f)
    
    np.save(DUMP_DIR / "dead_features.npy", dead_mask)
    
    pd.Series(dead_mask).to_csv(DUMP_DIR / "dead_features.csv", index=False)

# Run with profiling
if __name__ == "__main__":
    profiler = cProfile.Profile()
    profiler.enable()
    
    main()
    
    profiler.disable()
    
    # Save stats to a file
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumtime')
    stats.dump_stats('profile_results.prof')
    
    print("\n\n--- Profiling Results ---")
    stats.sort_stats('cumtime').print_stats(20)
    
    print("\n\n--- Profiling Results by Function Calls ---")
    stats.sort_stats('calls').print_stats(20)

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


  state_dict, config = torch.load(SAE_PATH, map_location=device).values()


Loaded SAE with 65536 features


Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.04s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.70it/s]
Some weights of the model checkpoint at allenai/OLMo-2-1124-7B-Instruct were not used when initializing Olmo2ForCausalLM: ['model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.post_feedforward_layernorm.weight', 'model.layers.16.self_attn.k_norm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_norm.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.post_feedforward_layernorm.weight', 'model.layers.17.self_attn.k_

Processed 39999744 tokens (2228.3 tok/s)
Processed 40000000 tokens in 17951.12s
Dead features: 21834 / 65536 (33.32%)


--- Profiling Results ---
         11274669237 function calls (11232668199 primitive calls) in 17826.901 seconds

   Ordered by: cumulative time
   List reduced from 14472 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    3.715    3.715 18038.575 18038.575 /tmp/ipykernel_2205077/2717389824.py:66(main)
        1 9835.364 9835.364 17958.598 17958.598 /tmp/ipykernel_2205077/2717389824.py:101(mine_top_tokens_and_dead)
   468751 5528.595    0.012 5528.595    0.012 {method 'cpu' of 'torch._C.TensorBase' objects}
   156251   11.797    0.000 1946.550    0.012 /tmp/ipykernel_2205077/2717389824.py:89(residual_stream_iter)
6094293/312500   11.743    0.000 1512.591    0.005 /home/henry/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1549(_wrapped_call_impl)
6094293/312500  144.188    0.000 1511.623  

In [None]:
from openai import OpenAI

def auto_label_feature(token_list):
    """Return a one‑line label via OpenAI LLM (adjust model & key)."""
    prompt = (
        'Tokens that all activate the same hidden feature:\n' +
        ', '.join(token_list) +
        '\n\nGive a concise, literal description of what these tokens have in common. '
        'If unsure, reply "uncertain".'
    )
    client = OpenAI()
    resp = client.chat.completions.create(model='gpt-4o-mini', messages=[{'role':'user','content':prompt}])
    return resp.choices[0].message.content.strip()

# Example:
# labels = [auto_label_feature([tok for _, tok in bucket]) for bucket in top_buckets]
# pd.Series(labels).to_csv(DUMP_DIR/'auto_labels.csv')