# Interpretación de latentes 

In [1]:
import sys
import torch
from torch import nn
import jaxtyping
import datasets
from datasets import load_dataset
from sae import Step
import numpy as np

In [2]:
MINI_BATCH_SIZE = 1024

In [3]:
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="mech-interp-uam/llama3.2-1b-sae",
    filename = "sae.pth",
)
print(f"Checkpoint cached at -> {ckpt_path}")

Checkpoint cached at -> /home/admin/.cache/huggingface/hub/models--mech-interp-uam--llama3.2-1b-sae/snapshots/891afeb81f0a19d6a791b6b8a3110bc3d87c3f5f/sae.pth


In [4]:
ds = (
    load_dataset(
        "mech-interp-uam/llama-mlp8-outputs",
        split="train",
        # Para el desarrollo, podemos usar streaming=True o False depende de lo
        # que sea más cómodo, para los entrenamientos grandes, probablemente sea
        # conveniente # streaming=False
        # streaming=True,
    )
    .with_format("numpy")
    .batch(MINI_BATCH_SIZE)
)

In [5]:
state_dict = torch.load(ckpt_path, map_location="cpu")

# if next(iter(state_dict)).startswith("module."):
#     state_dict = {k.replace("module.", "", 1): v
#                   for k, v in state_dict.items()}

# missing, unexpected = sae.load_state_dict(state_dict, strict=False)
# if missing or unexpected:
#     print("Key mismatch:", missing, unexpected)

#sae.eval() 

In [6]:
# lógica de encoder directamente aquí, esto debería de hacerse con código en
# el módulo de sae, algo así: encoder = Encoder.from_sae_state_dict(state_dict)

# En el state dict, debería de haber un (k,v) que nos indique si usó pre_encoder
# bias, o fucionarla directamente antes de subirlo a hf


W = state_dict['enc.weight']
# Si se usó pre_encoder bias, entonces
# W(x - b_d) + b_e  = Wx + (b_e - Wb_d)
b = state_dict['enc.bias'] - W@state_dict['dec.bias']
threshold = state_dict['log_threshold'].exp()


In [None]:
# Compute expected norm
norm = (
    ds
    .shuffle()
    .take(10)
    .map(lambda row : {"norm" : np.linalg.norm(row['activations'], axis=1)})
    )
list(f for f in ds["norm"])

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Dataset({
    features: ['activations', 'norm'],
    num_rows: 10
})

In [26]:
# Get a placeholder for some sort of identifier that links output vectors of
# the mlp with their corresponding token/place in the text used to generate
# them

from jaxtyping import Array, Float, Int
from typing import Iterable

rows: Iterable[tuple[Int[Array, "b"], Float[Array, "b b_model"]]] = enumerate(row['activations'] for row in ds.take(8))

l0_moving_average = None
for i, row in rows:
    # Here we are scaling so that E[||x||] = 1, but this hardcoded value is ugly
    # Furtheremore, this 3.4 value is not uptodate, probably training sufered
    # because of this oof
    x = torch.from_numpy(row)
    x /= 3.4
    # print(x[0].norm())
    # print(x.norm(dim=1).mean())
    # print(x.norm(dim=1).size())
    x = x@W.T + b
    s = Step.apply(x, threshold)
    current_l0 = s.mean(0).sum()
    if i == 0:
        l0_moving_average = current_l0
    else:
        l0_moving_average = i/(i+1) * l0_moving_average + 1/(i+1) * current_l0

    x = s*x
print(l0_moving_average.item())
    


212.4140625


In [24]:
s.sum(1)

tensor([ 843., 1334., 1449.,  ..., 1597., 1748., 2206.])

In [None]:
from mech_interp.sae import Sae
sae = Sae(
    d_in = 2048, 
    d_sae=2048*8, 
    use_pre_enc_bias=True
)

In [16]:
in_dim = sae.enc.in_features
dummy = torch.randn(1, in_dim, device = next(sae.parameters()).device)
with torch.no_grad():
    recon_dict= sae(dummy)
print("Recon shape:", recon_dict["reconstruction"].shape)
print("Latent L0  :", recon_dict["l0"].item())

Recon shape: torch.Size([])
Latent L0  : 8294.0


In [15]:
with torch.no_grad():
    out = sae(dummy)
print(type(out))

<class 'dict'>


cargar datos con id's


In [17]:
import os
from openai import OpenAI

API_KEY = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=API_KEY) if API_KEY else None


In [None]:
"""""
def simulate_activating_examples_gpt(api_key, n=10) -> list[tuple[int, list[str]]]:
   
   # Simula ejemplos de activación máxima en frases relacionadas con comunicación.
    #Regresa el mismo formato que fetch_max_activating_examples().
    
    from openai import OpenAI
    client = OpenAI(api_key=api_key)

    system_prompt = (
    "You are a fluent, context-aware asssistan that wirtes natural, varied, and meaningful sentences"
    "related to the concept of commnication. Yout output will be used  to study neuron activations in a language model"
    "Each sentence must contain exactly one significant token related to communication, and that token must be wrapped with double angle brackets like this: <<talk>>"
    "Avoid generic language or filler. Ensure the token is conceptually central to the sentence's meaning"
    )
    user_prompt = (
        f"Generate {n} short sentences (10 to 15 words each) about communication. "
        "In each sentence, mark exactly one key word related to communication ussing double angle like this: <<talk>>" 
        "Return only the sentences, one per line, with no bullet poinsts, numbering, or explanations"
    )
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=300,
        temperature=0.7
    )

    raw_sentences = response.choices[0].message.content.strip().split("\n")
    raw_sentences = [s.strip().lstrip("0123456789. ") for s in raw_sentences if s.strip()]

    parsed = []
    for i, sentence in enumerate(raw_sentences):
        tokens = sentence.split()
        for j, tok in enumerate(tokens):
            if tok.startswith("<<") and tok.endswith(">>"):
                tokens[j] = f"<<{tok.strip('<>')}>>"
                break
        parsed.append((i, tokens))

    return parsed
"""

In [None]:
"""
def create_prompt_gpt4o_from_simulated(examples: list[tuple[int, list[str]]], use_chain_of_thought=True) -> dict[str, str]:
    
    #Construye un prompt de GPT-4o a partir de ejemplos simulados con tokens activadores marcados.
    
    formatted = "\n".join(f"{i+1}. {' '.join(tokens)}" for i, tokens in examples)

    system_prompt = (
        "You are analyzing a latent neuron in a transformer-based language model. "
        "Each sentence below contains one token that strongly activates this latent neuron, highlighted using << >>. "
        "Your task is to interpret what concept, theme, or category this neuron responds to. "
        "Be precise but not overly narrow. Use fewer than 20 words. "
        "Avoid punctuation, lists, formatting, or generic labels like 'words' or 'nouns'. "
        "Focus on shared semantic meaning across the highlighted tokens."
    )

    assistant_prompt = (
        "Start by mentally grouping the highlighted tokens into a conceptual category. "
        "Then, write a final interpretation in fewer than 20 words. "
        "This neuron activates on"
    )

    return {
        "system": system_prompt,
        "user": f"The activating examples are:\n\n{formatted}",
        "assistant": assistant_prompt,
    }
"""

In [20]:
def get_gpt4o_explanation_from_prompt(prompts: dict, n_completions=3, max_tokens=100) -> list[str]:
    """
    Llama a la API de GPT-4o con un prompt ya generado y devuelve las interpretaciones.
    """
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": prompts["system"]},
            {"role": "user", "content": prompts["user"]},
            {"role": "assistant", "content": prompts["assistant"]},
        ],
        n=n_completions,
        max_tokens=max_tokens,
        temperature=0.7,
    )
    return [choice.message.content.strip() for choice in response.choices]


In [None]:
if API_KEY:
    #simulated_examples = simulate_activating_examples_gpt(api_key=API_KEY, n=10)
    #prompt = create_prompt_gpt4o_from_simulated(simulated_examples, use_chain_of_thought=True)
    explanations = get_gpt4o_explanation_from_prompt(prompt, n_completions=3)

    for i, exp in enumerate(explanations):
        print(f"[Explicación {i+1}]: {exp}")
else:
    print("OPENAI_API_KEY no está configurada.")

OPENAI_API_KEY no está configurada.
