In [None]:
import torch
import heapq
import pickle
import time
import json
import numpy as np
import pandas as pd
import os
import cProfile
import pstats
from pathlib import Path
from tqdm.auto import tqdm
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from psutil import cpu_count

SAE_PATH = Path('out/sae_65k_lambda26_ramp30/sae_final.pt')
TOP_N = 50
DUMP_DIR = Path('feature_dumps')
DUMP_DIR.mkdir(exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

# ---- Load SAE (replace with your own class/loader) ----
from model import SAE

# The checkpoint's values are unpacked positionally: first -> state dict,
# second -> config dict.
# NOTE(review): this relies on the saved dict holding exactly these two
# entries in this order — confirm against the training/saving script.
checkpoint = torch.load(SAE_PATH, map_location=device)
state_dict, config = checkpoint.values()

sae = SAE(config['input_size'], config['hidden_size'])
sae = sae.to(device).to(torch.bfloat16)

# Strip the "_orig_mod." prefix from the saved keys (presumably added because
# the SAE was saved from a torch.compile'd module) so the plain SAE accepts them.
fixed_state_dict = {}
for key, tensor in state_dict.items():
    fixed_state_dict[key.replace('_orig_mod.', '')] = tensor
sae.load_state_dict(fixed_state_dict)

# Prefer the encoder layer's output width; otherwise use the SAE's own attribute.
if hasattr(sae.encode, 'out_features'):
    n_features = sae.encode.out_features
else:
    n_features = sae.n_features
print(f'Loaded SAE with {n_features} features')

def count_dead_features(sample_iter, sample_tokens=10_000_000):
    """Returns a boolean tensor of shape (n_features,) where True == dead.

    A feature counts as dead if ``sae.encode`` never produced a value > 0 for it
    on any row of the sampled batches.

    Args:
        sample_iter: iterable of batches accepted by ``sae.encode``. Each batch's
            leading dimension (``size(0)``) counts as that many tokens. May be a
            one-shot iterator/generator: the first batch is peeked to size the
            progress bar and is still processed.
        sample_tokens: stop once at least this many rows have been encoded.

    Returns:
        Boolean CPU tensor of shape (n_features,), True == never fired.

    Raises:
        ValueError: if ``sample_iter`` yields no batches.
    """
    import itertools

    batches = iter(sample_iter)
    first = next(batches, None)
    if first is None:
        raise ValueError("sample_iter yielded no batches")
    # Peeking must not drop the first batch when sample_iter is a generator,
    # so put it back in front of the remaining batches.
    batches = itertools.chain([first], batches)
    est_batches = sample_tokens // max(len(first), 1)  # progress-bar estimate only

    fired = torch.zeros(n_features, dtype=torch.bool, device=device)
    seen = 0
    # Gradients are never needed here; avoid building autograd graphs through
    # the SAE parameters on every batch.
    with torch.no_grad():
        for toks in tqdm(batches, total=est_batches):
            toks = toks.to(device)
            acts = sae.encode(toks) > 0  # bool mask of activations
            fired |= acts.any(dim=0)
            seen += toks.size(0)
            if seen >= sample_tokens:
                break
    dead_mask = ~fired.cpu()
    print(f"Dead features: {dead_mask.sum().item()} / {n_features} ({dead_mask.float().mean()*100:.2f}%)")
    return dead_mask

# GPU optimizations: allow TF32 / reduced-precision float32 matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Constants
MODEL_NAME = "allenai/OLMo-2-1124-7B-Instruct"
BATCH_SIZE = 256    # tokens per forward pass (one sequence of this length)
LAYER_OFFSET = -1   # index into the model's `hidden_states` tuple
TOP_N = 50          # strongest (activation, token) pairs kept per feature
# Keep the same CPU fallback as the SAE-loading cell; a hard-coded "cuda"
# would crash on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
DUMP_DIR = Path("./results")
DUMP_DIR.mkdir(parents=True, exist_ok=True)

def main():
    """Mine top-activating tokens and dead features for the SAE, then save them.

    Pipeline: FineWeb text -> tokenizer -> OLMo-2 built with half of its decoder
    layers -> hidden_states[LAYER_OFFSET] -> sae.encode -> per-feature top-N
    (activation, token text) pairs plus a dead-feature mask. Results are written
    to DUMP_DIR as top_tokens_<N>.pkl/.json and dead_features_<N>.npy/.csv,
    where <N> is the token budget.
    """
    from transformers import AutoConfig

    # Load tokenizer + *half* model config (bf16, compiled).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Fetch only the config. Instantiating AutoModelForCausalLM just to read
    # `.config` would load every weight of the 7B model (default fp32) for nothing.
    cfg = AutoConfig.from_pretrained(MODEL_NAME)
    cfg.num_hidden_layers //= 2  # half-model: only the first half of the layers is built
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map=device,
    ).eval()
    model = torch.compile(model, mode="reduce-overhead")

    # Load dataset (materialized, not streamed) and shuffle it.
    dataset_iter = load_dataset(
        "HuggingFaceFW/fineweb",
        name="sample-10BT",
        split="train",
        streaming=False,
        num_proc=cpu_count(),
    ).shuffle()

    def residual_stream_iter(text_iter, batch=BATCH_SIZE):
        """Yield (token_ids, residuals) for consecutive `batch`-token windows.

        Documents are tokenized and appended to one running buffer, so a window
        may span document boundaries. A trailing partial window is never yielded.
        `residuals` is hidden_states[LAYER_OFFSET] for the window, shape (batch, d).
        """
        buf = []
        for record in text_iter:
            buf.extend(tokenizer(record["text"]).input_ids)
            while len(buf) >= batch:
                toks = torch.tensor(buf[:batch]).to(device)
                buf = buf[batch:]
                with torch.inference_mode():
                    outs = model(toks.unsqueeze(0), output_hidden_states=True)
                    resid = outs.hidden_states[LAYER_OFFSET].squeeze(0)  # (T, d)
                yield toks, resid  # feed straight to SAE

    def mine_top_tokens_and_dead(data_iter,
                                 top_n=TOP_N,
                                 target_tokens=100_000):
        """
        • Keeps the TOP-N strongest (activation, token) pairs per feature; each
          batch contributes at most one candidate per feature (its max token).
        • Tracks which features ever fire to flag the 'dead' ones.
        • Stops after `target_tokens` have been processed.

        Returns (decoded, dead_mask): decoded[f] is a list of
        (activation, token_text) sorted by descending activation, and dead_mask
        is a bool numpy array with True == feature never fired.
        """
        n_features = sae.encode.weight.shape[0]

        # Running top-N table kept on the device, one row per feature. -inf marks
        # slots that are still unfilled (fewer than top_n batches seen so far).
        top_vals = torch.full((n_features, top_n), float("-inf"),
                              dtype=torch.float32, device=device)
        top_ids = torch.zeros((n_features, top_n), dtype=torch.long, device=device)
        fired = torch.zeros(n_features, dtype=torch.bool, device=device)
        seen_toks = 0

        # Process batches with a light progress indicator.
        start_time = time.time()
        batch_count = 0

        for toks, resid in data_iter:
            batch_count += 1
            if batch_count % 10 == 0:
                elapsed = time.time() - start_time
                tokens_per_sec = seen_toks / elapsed if elapsed > 0 else 0
                print(f"\rProcessed {seen_toks} tokens ({tokens_per_sec:.1f} tok/s)", end="")

            with torch.inference_mode():
                acts = sae.encode(resid)
                fired |= (acts > 0).any(dim=0)

                # Strongest token per feature within this batch.
                values, idx = acts.topk(1, dim=0)
                batch_vals = values[0].to(torch.float32)
                batch_ids = toks[idx[0]].long()

                # Merge the batch candidates into the running top-N with one
                # vectorized topk. The previous per-feature Python heap update
                # ran n_features iterations for every batch.
                cand_vals = torch.cat([top_vals, batch_vals.unsqueeze(1)], dim=1)
                cand_ids = torch.cat([top_ids, batch_ids.unsqueeze(1)], dim=1)
                top_vals, keep = cand_vals.topk(top_n, dim=1)
                top_ids = cand_ids.gather(1, keep)

            seen_toks += toks.numel()
            if seen_toks >= target_tokens:
                break

        print(f"\nProcessed {seen_toks} tokens in {time.time() - start_time:.2f}s")

        # Post-process
        dead_mask = ~fired.cpu().numpy()  # bool numpy array, True == never fired

        # Single device->host copy of the table; drop still-unfilled (-inf) slots.
        neg_inf = float("-inf")
        rows = [
            [(v, t) for v, t in zip(val_row, id_row) if v != neg_inf]
            for val_row, id_row in zip(top_vals.cpu().tolist(), top_ids.cpu().tolist())
        ]

        # Decode each distinct token id only once.
        unique_token_list = list({t for row in rows for _, t in row})
        decoded_tokens = tokenizer.batch_decode([[t] for t in unique_token_list])
        token_id_to_text = dict(zip(unique_token_list, decoded_tokens))

        # topk already returns every row sorted by descending activation.
        decoded = [[(float(v), token_id_to_text[t]) for v, t in row] for row in rows]

        print(f"Dead features: {dead_mask.sum()} / {n_features} "
              f"({dead_mask.sum()/n_features*100:.2f}%)")

        return decoded, dead_mask

    data_iter = residual_stream_iter(dataset_iter)
    target_tokens = 50_000_000
    with torch.inference_mode():
        top_buckets, dead_mask = mine_top_tokens_and_dead(
            data_iter,
            top_n=TOP_N,
            # Same value used in the output filenames below, so they cannot disagree.
            target_tokens=target_tokens,
        )

    # Save results: pickle (fast, Python-only), JSON (tuples become lists),
    # and .npy/.csv for the dead-feature mask.
    # NOTE(review): files are named with the raw token count
    # (e.g. top_tokens_50000000.pkl), while the labeling cell reads
    # results/top_tokens_50m.pkl — rename or align before labeling.
    with open(DUMP_DIR / f"top_tokens_{target_tokens}.pkl", "wb") as f:
        pickle.dump(top_buckets, f)

    with open(DUMP_DIR / f"top_tokens_{target_tokens}.json", "w") as f:
        json.dump(top_buckets, f)

    np.save(DUMP_DIR / f"dead_features_{target_tokens}.npy", dead_mask)

    pd.Series(dead_mask).to_csv(DUMP_DIR / f"dead_features_{target_tokens}.csv", index=False)

# Run with profiling
if __name__ == "__main__":
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        main()
    finally:
        # Always stop the profiler and report. This way a failed or interrupted
        # (Ctrl-C) long mining run still leaves profile data behind.
        profiler.disable()

        # Save raw stats to a file; sort order does not affect the dump.
        stats = pstats.Stats(profiler)
        stats.dump_stats('profile_results.prof')

        print("\n\n--- Profiling Results ---")
        stats.sort_stats('cumtime').print_stats(20)

        print("\n\n--- Profiling Results by Function Calls ---")
        stats.sort_stats('calls').print_stats(20)

In [4]:
import asyncio, pickle, pandas as pd, instructor
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from tqdm.asyncio import tqdm
from dotenv import load_dotenv
import os
load_dotenv()  # load variables from a local .env file, if present, into os.environ

# OpenRouter exposes an OpenAI-compatible API, so the stock async client is
# pointed at its base URL with the OpenRouter key.
_openrouter = AsyncOpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ.get("OPENROUTER_API_KEY"),
)
# instructor wraps the client so chat.completions.create accepts response_model=...
client = instructor.from_openai(_openrouter)

class FeatureLabel(BaseModel):
    # Structured reply expected from the labeling LLM; passed as
    # `response_model` to client.chat.completions.create in label_one.
    # Field names, order and descriptions form the JSON schema the model is
    # given, so changing them changes what the LLM is asked to produce.
    label: str = Field(..., description="≤5-word noun phrase or 'uncertain'")
    chain_of_thought: str = Field(..., description="One-sentence rationale")


def make_prompt(tokens: list[tuple]) -> str:
    """Build the labeling prompt for one feature.

    Args:
        tokens: (activation, token) pairs for a single feature. Only the token
            part is shown to the model; activations are dropped.

    Returns:
        A prompt asking for a JSON object with "label" and "chain_of_thought".
    """
    listed = ", ".join(map(str, (tok for _act, tok in tokens)))

    prompt_lines = [
        "You will receive several tokens that all activate *the same* hidden feature.",
        "",
        f"Tokens: {listed}",
        "",
        "Return **only** the following JSON object (no additional text):",
        "{",
        '  "label": "<≤5-word noun phrase, or \'uncertain\', or \'punctuation\'>",',
        '  "chain_of_thought": "<one concise sentence explaining why>"',
        "}",
        "",
        '• Use "uncertain" if no coherent feature emerges.',
        '• Use "punctuation" if the tokens are mostly punctuation or formatting marks.',
        "• Do not output anything except that JSON object.",
    ]
    return "\n".join(prompt_lines)


# Per-feature lists of (activation, token_text) pairs from the mining cell.
# NOTE(review): the mining cell writes results/top_tokens_<target_tokens>.pkl
# (e.g. top_tokens_50000000.pkl); this path assumes the file was renamed — confirm.
# pickle is acceptable here only because the file is produced locally, not untrusted.
with open("results/top_tokens_50m.pkl", "rb") as fh:  # closes the handle after loading
    top_buckets = pickle.load(fh)
sem = asyncio.Semaphore(20)  # caps concurrent API requests in label_one at 20

async def label_one(idx, tokens):
    """Ask the LLM for a short label describing one SAE feature.

    Args:
        idx: feature index, stored as ``feature_id`` in the returned row.
        tokens: the feature's (activation, token_text) pairs, passed to make_prompt.

    Returns:
        A row dict with keys feature_id, label, chain_of_thought and tokens.
        If the request fails, ``label`` is None and ``chain_of_thought`` holds the
        error text. A single failed call therefore cannot raise out of the
        gather in main() and discard every other feature's result.
    """
    prompt = make_prompt(tokens)
    try:
        async with sem:  # bound the number of in-flight requests
            fl = await client.chat.completions.create(
                model="google/gemini-2.0-flash-001",
                response_model=FeatureLabel,
                temperature=0.6,
                messages=[{"role": "user", "content": prompt}],
            )
    except Exception as exc:  # per-task boundary: record the failure, keep the run going
        return {
            "feature_id": idx,
            "label": None,
            "chain_of_thought": f"ERROR: {type(exc).__name__}: {exc}",
            "tokens": tokens,
        }
    return {
        "feature_id": idx,
        "label": fl.label,
        "chain_of_thought": fl.chain_of_thought,
        "tokens": tokens,
    }

async def main():
    """Label every feature concurrently and return the rows in feature order."""
    # Schedule one task per feature up front; the semaphore inside label_one
    # limits how many requests actually run at once.
    pending = []
    for feature_id, bucket in enumerate(top_buckets):
        pending.append(asyncio.create_task(label_one(feature_id, bucket)))
    return await tqdm.gather(*pending)

rows = await main()

df = pd.DataFrame(rows).sort_values("feature_id")
df.to_csv("results/feature_labels.csv", index=False)
df.head()


100%|██████████| 65536/65536 [37:15<00:00, 29.31it/s]  


Unnamed: 0,feature_id,label,chain_of_thought,tokens
0,0,stop word,"The tokens are all common English stop words, ...","[(2.8125, the), (2.671875, ane), (2.59375, i..."
1,1,Worth,"The word 'worth' appears frequently, suggestin...","[(3.234375, can), (3.140625, Worth), (3.0312..."
2,2,punctuation,The tokens are all punctuation marks.,"[(4.5, .\n), (4.1875, \n), (4.125, \n), (3.828..."
3,3,uncertain,"The tokens appear to be a mix of common words,...","[(3.0625, _SCHEMA), (2.859375, _SCHEMA), (2.81..."
4,4,atheism,The tokens frequently co-occur with variations...,"[(2.09375, athe), (2.03125, Evans), (2.01562..."


In [1]:
# Import required libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "allenai/OLMo-2-1124-7B-Instruct"

# Weights are loaded in bfloat16, then the whole model is moved to `device`.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# System message chat_with_olmo uses unless the caller supplies another one.
default_system_prompt = (
    "You are OLMo 2, a helpful and harmless AI Assistant built by the Allen Institute for AI."
)

def chat_with_olmo(user_message, system_prompt=default_system_prompt, history=None,
                   max_new_tokens=512):
    """Send one user turn to the model and return its reply plus the history.

    Args:
        user_message: text of the new user turn.
        system_prompt: system message placed first; a falsy value omits it.
        history: prior turns as {"role": ..., "content": ...} dicts. The list is
            appended to in place. None starts a new conversation.
        max_new_tokens: generation budget for the reply.

    Returns:
        (assistant_response, history). ``history`` is the same list passed in
        (or a new one), with this user turn and the reply appended.
    """
    if history is None:
        history = []

    # System prompt (optional) + prior turns + the new user turn.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history)
    messages.append({"role": "user", "content": user_message})

    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize with an explicit attention mask.
    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation. Decode only the new tokens
    # instead of splitting the full transcript on "<|assistant|>"/"<|endoftext|>"
    # text. With a template lacking those literal markers, the split would
    # return the whole conversation as the "reply".
    prompt_len = inputs.input_ids.shape[1]
    assistant_response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": assistant_response})

    return assistant_response, history

# Example usage: a two-turn conversation. The second question relies on the
# history returned by the first call.
conversation_history = []
questions = [
    "What capabilities do you have as an AI assistant?",
    "Can you give me an example of how you might help with coding?",
]

for turn, user_input in enumerate(questions):
    if turn > 0:
        print("-" * 50)  # separator between exchanges
    response, conversation_history = chat_with_olmo(user_input, history=conversation_history)
    print(f"User: {user_input}")
    print(f"OLMo: {response}")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 49.44it/s]


User: What capabilities do you have as an AI assistant?
OLMo: As OLMo 2, I have a wide range of capabilities! Some of the key ones include understanding and generating human-like text, answering questions across many topics, and assisting with a variety of tasks such as scheduling, language translation, summarizing texts, and much more. I'm designed to continuously improve my responses based on the feedback I receive, so I'm here to learn and help you better over time. How can I assist you today?
--------------------------------------------------
User: Can you give me an example of how you might help with coding?
OLMo: Certainly! While I don't write or debug code myself, I can certainly help guide you through concepts, explain coding principles, or help you understand a particular problem you might be facing with coding. For example, if you're trying to solve a problem using Python and you're unsure about the syntax, I can provide you with the correct syntax or direct you to resources 

In [None]:
from transformers import Olmo2ForCausalLM

class SteeredOlmo(Olmo2ForCausalLM):