In [1]:
# -*- coding: utf-8 -*-
"""
Muestreo de N_FEATURES features por SAE (500 en la configuración actual) y N_EXAMPLES
ejemplos por feature (secuencias distintas), marcando TODOS los tokens de la secuencia
donde z_f >= tau_k, usando token_ids desde JSONL.

Salida por SAE (JSONL):
- interpretability/data/samples_{label}__k{K}__F{n}_E{N_EXAMPLES}.jsonl
  (n = nº de features efectivamente muestreadas)
"""

import json
import pickle
from pathlib import Path

import numpy as np
import zarr
import mlx.core as mx
from tqdm.notebook import tqdm
from mlx_lm.utils import load_tokenizer

# ============================ CONFIGURATION ============================

# --- Output locations (under the project root) ---
ROOT = Path("/Users/josue/proyectos/tesis/low-rank-bilinear-sae").resolve()
INTERP_DIR = ROOT / "interpretability"
DATA_DIR = INTERP_DIR.joinpath("data")
DATA_DIR.mkdir(parents=True, exist_ok=True)  # created eagerly so later writes never fail on a missing dir

# --- Inputs ---
# Zarr store of residual-stream activations (the SAE inputs).
ZARR_ACT_PATH = "/Users/josue/proyectos/tesis/pythia/data/Pythia70M-L3-res-wiki.zarr"
# JSONL file: each line is one document given as a list of integer token ids.
TOKEN_IDS_JSONL = "/Users/josue/proyectos/tesis/pythia/data/Pythia70M-L3-res-wiki-token-ids.jsonl"
# Tokenizer: local path or a name that mlx_lm.utils.load_tokenizer accepts.
TOKENIZER_PATH = Path("EleutherAI/pythia-70m-deduped")

# --- SAE checkpoints to process (commented entries are disabled runs) ---
SAE_PATHS = [
    # 'checkpoints/bsae_512_2048_512__Er0_Dr0_k16_ep16_20251129-040858.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503.pkl',
    'checkpoints/bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808.pkl',
    'checkpoints/bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024.pkl',
    'checkpoints/bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537.pkl',
    'checkpoints/bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037.pkl',
    # 'checkpoints/bsae_512_2048_512__Er2_Dr2_k16_ep16_20251130-113058.pkl',
]
# One label per checkpoint: the file name without directory or extension.
SAE_LABELS = [Path(ckpt).stem for ckpt in SAE_PATHS]

# --- Sampling parameters ---
K = 16                      # selects tau_k and the reservoir / eligibility entries
N_FEATURES = 500            # features sampled per SAE
N_EXAMPLES = 10             # examples (distinct sequences) per feature
GLOBAL_RNG_SEED = 20251127  # seed of the single Generator shared across all SAEs

# ---------- UTILIDADES ----------
from model import *  # Debe exponer .load(), .encoder(), .clip_norms()

def detect_F_total(sae, default: int = 2048) -> int:
    """Return the number of SAE features (dictionary size), best-effort.

    Probes, in order:
      1. ``sae.decoder.weight.shape[1]``
      2. ``sae.decoder.W.shape[1]``
      3. ``sae.F``
    and returns the first one that can be read and converted to ``int``.
    Each probe is tried independently, so a failure in one (missing
    attribute, wrong rank, non-numeric value) falls through to the next
    instead of jumping straight to the default.

    Args:
        sae: the autoencoder object to inspect.
        default: value returned when no probe succeeds (2048 keeps the
            previous behaviour).

    Returns:
        The detected feature count, or ``default``.
    """
    probes = (
        lambda: sae.decoder.weight.shape[1],
        lambda: sae.decoder.W.shape[1],
        lambda: sae.F,
    )
    for probe in probes:
        try:
            return int(probe())
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            # Only the lookup/conversion failures we expect from a missing or
            # differently-shaped attribute; anything else is a real bug.
            continue
    return int(default)

def search_seq_id(cumsum: np.ndarray, gidx: int) -> int:
    """Return the index of the document that contains global token ``gidx``.

    ``cumsum`` is the running total of document lengths, so ``cumsum[i]`` is
    the exclusive end offset of document ``i``. With ``side="right"`` an index
    equal to ``cumsum[i]`` (the first token after document ``i``) maps to
    ``i + 1``; an index at or beyond ``cumsum[-1]`` maps to ``len(cumsum)``.
    """
    target = int(gidx)
    doc_index = np.searchsorted(cumsum, target, side="right")
    return int(doc_index)

# Decode a sequence token by token, highlighting tokens whose activation is >= thr.
def mark_sequence(seq_token_ids: np.ndarray, z_feat_seq: np.ndarray, thr: float, *, decoder=None) -> str:
    """Render a token sequence as text, marking tokens whose activation reaches ``thr``.

    Each token id is decoded on its own and the pieces are concatenated.
    Tokens with activation ``>= thr`` are emitted as ``<<token>>(act)`` with
    ``act`` printed to two decimals; all other tokens are emitted verbatim.

    NOTE(review): decoding one token at a time can render characters that a
    byte-level BPE splits across tokens as replacement characters.

    Args:
        seq_token_ids: token ids of the sequence.
        z_feat_seq: activation of one feature per token, aligned with
            ``seq_token_ids``; iteration stops at the shorter of the two.
        thr: inclusive activation threshold.
        decoder: object exposing ``decode(int) -> str``. Defaults to the
            module-level ``tokenizer`` so existing callers are unaffected.

    Returns:
        The marked-up text.
    """
    # Resolve the global lazily so that passing `decoder` never touches it.
    dec = tokenizer if decoder is None else decoder
    threshold = float(thr)
    pieces = []
    for tid, score in zip(seq_token_ids, z_feat_seq):
        tok = dec.decode(int(tid))
        act = float(score)
        pieces.append(f"<<{tok}>>({act:.2f})" if act >= threshold else tok)
    return "".join(pieces)

# ---------- LOAD ACTIVATIONS AND TOKEN IDS ----------
act_zarr = zarr.open(ZARR_ACT_PATH, mode="r")  # SAE inputs; first axis is the flat token index
TOTAL_TOKENS = int(act_zarr.shape[0])


def _align_docs_to_total(docs, total):
    """Trim ``docs`` so that their concatenated length is exactly ``total``.

    Keeps every document up to and including the first one whose cumulative
    end offset reaches ``total``, and cuts that last document so the running
    total lands exactly on ``total``. This is the document that contains the
    last activation row.

    Args:
        docs: list of documents, each a list of token ids.
        total: number of tokens covered by the activation store.

    Returns:
        ``(aligned_docs, cumsum)`` where ``cumsum[i]`` is the exclusive end
        offset of ``aligned_docs[i]`` and ``cumsum[-1] == total``.

    Raises:
        ValueError: if the documents hold fewer than ``total`` tokens.
    """
    total = int(total)
    if total <= 0:
        return [], np.zeros(0, dtype=np.int64)
    cumsum = np.cumsum([len(doc) for doc in docs], dtype=np.int64)
    available = int(cumsum[-1]) if cumsum.size else 0
    if available < total:
        raise ValueError(
            f"Desalineación: los documentos suman {available} tokens < TOTAL_TOKENS={total}."
        )
    # side="left": first document whose end offset is >= total. Unlike picking the
    # *closest* end offset, this never selects a document that ends before `total`,
    # and when the end offset equals `total` the excess is 0 (no `[:-0]` emptying).
    last = int(np.searchsorted(cumsum, total, side="left"))
    excess = int(cumsum[last]) - total  # tokens of that document beyond the activation range
    aligned = list(docs[: last + 1])
    if excess > 0:
        aligned[last] = aligned[last][: len(aligned[last]) - excess]
    aligned_cumsum = np.cumsum([len(doc) for doc in aligned], dtype=np.int64)
    return aligned, aligned_cumsum


# Load the token-id documents (one JSON list of ints per line) and align them to TOTAL_TOKENS.
with open(TOKEN_IDS_JSONL, "r", encoding="utf-8") as f:
    token_ids_docs = [json.loads(line) for line in f]

token_ids_docs, token_ids_cumsum = _align_docs_to_total(token_ids_docs, TOTAL_TOKENS)
last_seq = len(token_ids_docs) - 1  # index of the final (possibly trimmed) document

# Invariant relied on by the sampling loop: every activation row maps to a document token.
if TOTAL_TOKENS > 0 and int(token_ids_cumsum[-1]) != TOTAL_TOKENS:
    raise RuntimeError("Desalineación: TOTAL_TOKENS != suma de token_ids_docs")

tokenizer = load_tokenizer(TOKENIZER_PATH)
rng = np.random.default_rng(GLOBAL_RNG_SEED)

# ---------- PER-SAE LOOP ----------


def _choose_features(pos_counts, f_total: int, n_features: int, rng: np.random.Generator) -> np.ndarray:
    """Sample distinct feature indices with probability proportional to ``pos_counts``.

    Draws without replacement from ``range(f_total)``. Sampling without
    replacement cannot return more items than there are non-zero weights, so
    the size is capped at that count. The result may therefore be shorter
    than ``n_features``. When the cap does not apply, the RNG call is the same
    as a plain ``rng.choice(f_total, size=n_features, replace=False, p=...)``.

    Raises:
        ValueError: if ``pos_counts`` does not have ``f_total`` entries or sums to 0.
    """
    counts = np.asarray(pos_counts, dtype=np.float64)
    if counts.shape != (f_total,):
        raise ValueError(f"pos_counts tiene forma {counts.shape}; se esperaba ({f_total},).")
    total = counts.sum()
    if total <= 0:
        raise ValueError("pos_counts suma 0: no hay features muestreables.")
    size = min(int(n_features), int(np.count_nonzero(counts)))
    return rng.choice(f_total, size=size, replace=False, p=counts / total)


def _collect_examples(sae, fidx: int, cand_g: np.ndarray, tau_k: float, n_examples: int) -> list:
    """Return up to ``n_examples`` marked texts from distinct sequences where feature ``fidx`` fires.

    ``cand_g`` holds global token indices, visited in the given order (the
    caller shuffles them). For each index the containing sequence is located,
    encoded with the SAE, and kept only if some token's activation of ``fidx``
    is ``>= tau_k``. Each sequence is encoded at most once per feature.

    Reads the module-level ``token_ids_cumsum``, ``token_ids_docs``,
    ``act_zarr`` and ``TOTAL_TOKENS``.
    """
    examples = []
    visited = set()  # sequences already inspected for this feature, hit or miss
    n_docs = len(token_ids_cumsum)
    for g in cand_g:
        sid = search_seq_id(token_ids_cumsum, int(g))
        if sid in visited:
            continue
        visited.add(sid)
        if sid >= n_docs:
            # gidx lies beyond the last aligned document: no tokens to show.
            continue

        # Sequence boundaries in the flat token stream.
        seq_end = int(token_ids_cumsum[sid])
        seq_start = int(token_ids_cumsum[sid - 1]) if sid > 0 else 0
        if not (0 <= seq_start < seq_end <= TOTAL_TOKENS):
            continue

        # Activations and token ids for the sequence (token ids trimmed to the row count).
        seq_len = seq_end - seq_start
        x_seq = mx.array(act_zarr[seq_start:seq_end], dtype=mx.float32)
        tok_seq = np.array(token_ids_docs[sid], dtype=np.int64)[:seq_len]

        # Encode, then select the feature column in NumPy (avoids MLX indexing).
        z_full = sae.encoder(sae.clip_norms(x_seq))
        z_feat_seq = np.array(z_full, dtype=np.float32)[:, fidx]

        # Keep the sequence only if the feature actually reaches tau_k somewhere in it.
        if not np.any(z_feat_seq >= tau_k):
            continue

        examples.append(mark_sequence(tok_seq, z_feat_seq, thr=tau_k))
        if len(examples) >= n_examples:
            break
    return examples


for sae_path, label in zip(SAE_PATHS, SAE_LABELS):
    print(f"[{label}] Cargando SAE / estados…")
    sae = BilinearSparseAutoencoder.load(sae_path)
    F_TOTAL = detect_F_total(sae)

    # tau_k for this K, read from the scan_state pickle.
    # NOTE: pickle.load can execute arbitrary code; only load locally produced files.
    scan_pkl = DATA_DIR / f"scan_state_{label}.pkl"
    with open(scan_pkl, "rb") as pf:
        scan_state = pickle.load(pf)
    if int(K) not in scan_state["taus"]:
        raise ValueError(f"[{label}] τ_k ausente en scan_state para k={K}.")
    tau_k = float(scan_state["taus"][int(K)])

    # Per-feature reservoirs of global token indices for this K.
    res_pkl = DATA_DIR / f"reservoirs_{label}.pkl"
    with open(res_pkl, "rb") as pf:
        reservoirs = pickle.load(pf)
    res_k = reservoirs["reservoir"][int(K)]
    gidx_lists = res_k["gidx"]  # one list of gidx per feature
    if len(gidx_lists) != F_TOTAL:
        raise ValueError(
            f"[{label}] Reservoir no coincide con F_TOTAL ({len(gidx_lists)} != {F_TOTAL})."
        )

    # Features with >= 1 reservoir entry (not used for the sampling below).
    eligible_feats = [f for f in range(F_TOTAL) if len(gidx_lists[f]) > 0]

    # Weighted sample of features, using F_TOTAL rather than a hard-coded size.
    chosen_feats = _choose_features(scan_state["pos_counts_all"][int(K)], F_TOTAL, N_FEATURES, rng)
    if len(chosen_feats) < N_FEATURES:
        print(f"[{label}] Aviso: solo {len(chosen_feats)} features con pos_counts > 0; se usan todas.")

    out_jsonl = DATA_DIR / f"samples_{label}__k{K}__F{len(chosen_feats)}_E{N_EXAMPLES}.jsonl"
    print(f"[{label}] Muestreando {len(chosen_feats)} features; {N_EXAMPLES} ejemplos por feature…")

    n_feats_written = 0
    with open(out_jsonl, "w", encoding="utf-8") as jf, tqdm(total=len(chosen_feats), desc=f"[{label}] features") as pbar:
        for fidx in chosen_feats:
            fidx = int(fidx)
            cand_g = np.array(gidx_lists[fidx], dtype=np.int64)
            if cand_g.size > 0:
                rng.shuffle(cand_g)  # random visiting order -> random choice of example sequences
                examples = _collect_examples(sae, fidx, cand_g, tau_k, N_EXAMPLES)
                # Emit a record only when at least one example was found.
                if examples:
                    rec = {
                        "sae_label": label,
                        "feature": fidx,
                        "examples": examples,
                    }
                    jf.write(json.dumps(rec, ensure_ascii=False) + "\n")
                    n_feats_written += 1
            pbar.update(1)

    print(f"[OK] {label}: {n_feats_written} features escritos → {out_jsonl}")

[bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503] Cargando SAE / estados…
[bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503] Muestreando 500 features; 10 ejemplos por feature…


[bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503] features:   0%|          | 0/500 [00:00<?, ?it/s]

[OK] bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503: 500 features escritos → /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/samples_bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503__k16__F500_E10.jsonl
[bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808] Cargando SAE / estados…
[bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808] Muestreando 500 features; 10 ejemplos por feature…


[bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808] features:   0%|          | 0/500 [00:00<?, ?it/s]

[OK] bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808: 500 features escritos → /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/samples_bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808__k16__F500_E10.jsonl
[bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654] Cargando SAE / estados…
[bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654] Muestreando 500 features; 10 ejemplos por feature…


[bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654] features:   0%|          | 0/500 [00:00<?, ?it/s]

[OK] bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654: 500 features escritos → /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/samples_bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654__k16__F500_E10.jsonl
[bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024] Cargando SAE / estados…
[bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024] Muestreando 500 features; 10 ejemplos por feature…


[bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024] features:   0%|          | 0/500 [00:00<?, ?it/s]

[OK] bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024: 500 features escritos → /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/samples_bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024__k16__F500_E10.jsonl
[bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537] Cargando SAE / estados…
[bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537] Muestreando 500 features; 10 ejemplos por feature…


[bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537] features:   0%|          | 0/500 [00:00<?, ?it/s]

[OK] bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537: 500 features escritos → /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/samples_bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537__k16__F500_E10.jsonl
[bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124] Cargando SAE / estados…
[bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124] Muestreando 500 features; 10 ejemplos por feature…


[bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124] features:   0%|          | 0/500 [00:00<?, ?it/s]

[OK] bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124: 500 features escritos → /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/samples_bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124__k16__F500_E10.jsonl
[bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037] Cargando SAE / estados…
[bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037] Muestreando 500 features; 10 ejemplos por feature…


[bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037] features:   0%|          | 0/500 [00:00<?, ?it/s]

[OK] bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037: 500 features escritos → /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/samples_bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037__k16__F500_E10.jsonl
