In [1]:
# -*- coding: utf-8 -*-
"""
Construcción de reservoirs por (SAE, k, feature) usando τ_k precomputados en los scan_state_*.pkl.
- Recorre TODO el dataset (sin saltos).
- Muestreo aleatorio uniforme de activaciones z_f >= τ_k (reservoir sampling).
- Guarda (gidx, score) por feature y por k en reservoirs_{label}.pkl.

Estructura de salida (pickle):
{
  "taus": {k: float},
  "reservoir_cap": int,
  "target_per_feature": int,
  "rng_seed": int,
  "tokens_seen": int,
  "counts": {k: [F] int64},                       # nº activaciones elegibles vistas por feature
  "reservoir": {
      k: {
         "gidx":  [ [int,...] for f in range(F) ],  # hasta reservoir_cap por feature
         "score": [ [float,...] for f in range(F) ]
      },
      ...
  },
  "meta": { ... }
}
"""

import pickle
from pathlib import Path

import numpy as np
import zarr
import mlx.core as mx
from tqdm.notebook import tqdm
from model import *

# -------------------------
# PATHS AND CONFIG
# -------------------------
# Project root; every derived directory below hangs off this absolute path.
ROOT = Path("/Users/josue/proyectos/tesis/low-rank-bilinear-sae").resolve()
INTERP_DIR = ROOT / "interpretability"
DATA_DIR = INTERP_DIR / "data"
# Side effect at import/run time: creates the data directory if it is missing.
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Zarr store with residual-stream activations (the SAE inputs)
ZARR_PATH =  "/Users/josue/proyectos/tesis/pythia/data/Pythia70M-L3-res-wiki.zarr"

# Checkpoints to process; entries commented out below are skipped in this run.
SAE_PATHS = [
#    'checkpoints/bsae_512_2048_512__Er0_Dr0_k16_ep16_20251129-040858.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503.pkl',
    'checkpoints/bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808.pkl',
    'checkpoints/bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024.pkl',
    'checkpoints/bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537.pkl',
    'checkpoints/bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037.pkl',
#    'checkpoints/bsae_512_2048_512__Er2_Dr2_k16_ep16_20251130-113058.pkl'
]
# One label per checkpoint (file name without extension); used to name the
# scan_state_{label}.pkl inputs and reservoirs_{label}.pkl outputs.
SAE_LABELS = [Path(p).stem for p in SAE_PATHS]

# Values of k to process (each must have a tau in the scan_state file)
KSET = [4, 8, 16]

# Batching (contiguous blocks, nothing skipped)
BATCH_SIZE = 4096

# Reservoirs
TARGET_PER_FEATURE = 10          # final target per feature (still needs per-sentence deduplication later)
RESERVOIR_CAP = 64               # > TARGET_PER_FEATURE to leave room for the later deduplication
GLOBAL_RNG_SEED = 42

# -------------------------
# AUXILIARES
# -------------------------
def detect_F_total(sae, default: int = 2048) -> int:
    """Best-effort detection of the number of SAE features (F_total).

    Candidates are probed in this order and the first one present wins:

    1. ``sae.decoder.weight.shape[1]``
    2. ``sae.decoder.W.shape[1]``
    3. ``sae.F``

    Args:
        sae: Any object; it is only inspected through attribute access.
        default: Value returned when no candidate is present or the selected
            candidate cannot be read as an integer. Defaults to 2048.

    Returns:
        The detected feature count, or ``default``.
    """
    try:
        decoder = getattr(sae, "decoder", None)
        for attr_name in ("weight", "W"):
            if hasattr(decoder, attr_name):
                return int(getattr(decoder, attr_name).shape[1])
        if hasattr(sae, "F"):
            return int(sae.F)
    except (AttributeError, IndexError, TypeError, ValueError):
        # Deliberate best-effort: a malformed candidate (no .shape, fewer than
        # two dimensions, non-numeric F) falls through to the default instead
        # of aborting the run. Unrelated errors are no longer swallowed.
        pass
    return default

def batch_iterator(zarray, batch_size):
    """Yield ``(offset, batch)`` pairs covering ``zarray`` front to back.

    Batches are contiguous row slices of at most ``batch_size`` rows, taken in
    order with no gaps. ``offset`` is the index of the batch's first row in
    ``zarray``; ``batch`` is that slice converted to a float32 MLX array.
    """
    n_rows = zarray.shape[0]
    offsets = range(0, n_rows, batch_size)
    for offset in offsets:
        stop = offset + batch_size
        chunk = zarray[offset:stop]
        yield offset, mx.array(chunk, dtype=mx.float32)

def make_reservoir_struct(F: int, kset):
    """Allocate empty per-k, per-feature reservoir containers.

    Returns a 3-tuple ``(gidx_lists, score_lists, counts)`` where, for every
    ``k`` in ``kset``:

    - ``gidx_lists[k]`` is a list of ``F`` independent empty lists,
    - ``score_lists[k]`` is a list of ``F`` independent empty lists,
    - ``counts[k]`` is a zero-filled int64 NumPy array of length ``F``.
    """

    def _empty_slots():
        # A fresh list object per feature so slots never alias each other.
        return [[] for _ in range(F)]

    per_k_gidx = {k: _empty_slots() for k in kset}
    per_k_score = {k: _empty_slots() for k in kset}
    per_k_seen = {k: np.zeros(F, dtype=np.int64) for k in kset}
    return per_k_gidx, per_k_score, per_k_seen

def reservoir_update(counts_arr, g_lists, s_lists, feat_idx: int, score: float, gidx: int, cap: int, rng):
    """Offer one eligible activation to a feature's reservoir (Algorithm R).

    ``counts_arr[feat_idx]`` holds how many eligible items were offered for
    this feature before the call; it is incremented in place. While the
    feature's reservoir holds fewer than ``cap`` entries the item is appended.
    Afterwards a slot is drawn with ``rng.integers(0, n)``, where ``n`` is the
    updated count, and the item overwrites that slot only when the draw lands
    below ``cap``.

    ``g_lists[feat_idx]`` and ``s_lists[feat_idx]`` are mutated in place and
    stay aligned: position ``i`` of one describes the same item as position
    ``i`` of the other. Nothing is returned.
    """
    seen = int(counts_arr[feat_idx]) + 1
    counts_arr[feat_idx] = seen

    bucket_g = g_lists[feat_idx]
    bucket_s = s_lists[feat_idx]

    # Fill phase: every item is kept until the reservoir reaches capacity.
    if len(bucket_g) < cap:
        bucket_g.append(int(gidx))
        bucket_s.append(float(score))
        return

    # Replacement phase: draw a slot in [0, seen); only draws below cap evict.
    slot = rng.integers(0, seen)
    if slot >= cap:
        return
    bucket_g[slot] = int(gidx)
    bucket_s[slot] = float(score)

# -------------------------
# CARGA ZARR Y LOOP POR SAE
# -------------------------
# Open the activation store read-only; a row index in it is the global token index (gidx).
zarr_arr = zarr.open(ZARR_PATH, mode="r")
TOTAL_LEN = zarr_arr.shape[0]

# A single generator is shared by every SAE and every k, so the sampled
# reservoirs depend on GLOBAL_RNG_SEED and on the order of SAE_PATHS and KSET.
rng = np.random.default_rng(GLOBAL_RNG_SEED)

for sae_path, label in zip(SAE_PATHS, SAE_LABELS):
    # Load the SAE and its feature count
    sae = BilinearSparseAutoencoder.load(sae_path)
    F_TOTAL = detect_F_total(sae)

    # Load the precomputed tau_k thresholds from scan_state_{label}.pkl.
    # NOTE: pickle.load can execute arbitrary code; only read files produced
    # locally by the scan step, never files from an untrusted source.
    scan_pkl = DATA_DIR / f"scan_state_{label}.pkl"
    with open(scan_pkl, "rb") as pf:
        scan_state = pickle.load(pf)

    # Validate BEFORE reading the thresholds so that a missing k is reported
    # with the descriptive error below rather than a bare KeyError.
    scan_taus = scan_state["taus"]
    missing = [k for k in KSET if int(k) not in scan_taus]
    if missing:
        raise ValueError(f"[{label}] Faltan taus para k={missing} en {scan_pkl}")
    taus = {int(k): float(scan_taus[int(k)]) for k in KSET}

    # Empty reservoirs and per-feature eligible counters for every k
    gidx_lists, score_lists, counts = make_reservoir_struct(F_TOTAL, KSET)

    tokens_seen = 0

    # One progress-bar tick per batch; the label lists the k values actually processed.
    total_batches = (TOTAL_LEN + BATCH_SIZE - 1) // BATCH_SIZE
    k_tag = "/".join(str(int(k)) for k in KSET)
    desc = f"[{label}] Construyendo reservoirs (k{k_tag})"
    with tqdm(total=total_batches, desc=desc, leave=True) as pbar:
        for gstart, xb in batch_iterator(zarr_arr, BATCH_SIZE):
            # Encode the whole batch, then move to NumPy for boolean masking/indexing
            z_full = sae.encoder(sae.clip_norms(xb))
            z_np = np.array(z_full, dtype=np.float32)
            B = z_np.shape[0]
            if B == 0:
                pbar.update(1)
                continue

            # For every k, an activation is eligible when it reaches the global tau_k
            for k in KSET:
                k_key = int(k)
                rows, cols = np.nonzero(z_np >= taus[k_key])
                if rows.size == 0:
                    continue

                # Hoist the per-k containers and convert indices/scores to
                # plain Python values in bulk; the visit order (row-major, as
                # returned by np.nonzero) and the RNG draws are unchanged.
                k_counts = counts[k_key]
                k_gidx = gidx_lists[k_key]
                k_score = score_lists[k_key]
                scores = z_np[rows, cols].tolist()

                for r, c, score in zip(rows.tolist(), cols.tolist(), scores):
                    reservoir_update(
                        counts_arr=k_counts,
                        g_lists=k_gidx,
                        s_lists=k_score,
                        feat_idx=c,
                        score=score,
                        gidx=gstart + r,  # global token index = batch offset + local row
                        cap=RESERVOIR_CAP,
                        rng=rng,
                    )

            tokens_seen += B
            pbar.update(1)

    # Persist the reservoirs together with the settings needed to interpret them
    out_pkl = DATA_DIR / f"reservoirs_{label}.pkl"
    out_state = {
        "taus": taus,
        "reservoir_cap": int(RESERVOIR_CAP),
        "target_per_feature": int(TARGET_PER_FEATURE),
        "rng_seed": int(GLOBAL_RNG_SEED),
        "tokens_seen": int(tokens_seen),
        "counts": {int(k): counts[int(k)].tolist() for k in KSET},
        "reservoir": {
            int(k): {
                "gidx": gidx_lists[int(k)],
                "score": score_lists[int(k)],
            } for k in KSET
        },
        "meta": {
            "note": "Reservoirs por (k, feature) con muestreo uniforme sobre activaciones z_f >= tau_k.",
            "zarr_path": ZARR_PATH,
            "sae_path": sae_path,
            "F_total": int(F_TOTAL),
            "batch_size": int(BATCH_SIZE),
            "kset": list(map(int, KSET)),
            "gidx_semantics": "gidx es índice global del token en el zarr (fila absoluta).",
        },
    }
    with open(out_pkl, "wb") as pf:
        pickle.dump(out_state, pf)

    print(
        f"[OK] {label}: reservoirs guardados en {out_pkl} | "
        f"tokens_seen={tokens_seen:,} | cap={RESERVOIR_CAP} | target={TARGET_PER_FEATURE}"
    )

[bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503] Construyendo reservoirs (k16/32/48/64):   0%|          |…

[OK] bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503: reservoirs guardados en /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/reservoirs_bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503.pkl | tokens_seen=11,000,000 | cap=64 | target=10


[bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808] Construyendo reservoirs (k16/32/48/64):   0%|          |…

[OK] bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808: reservoirs guardados en /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/reservoirs_bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808.pkl | tokens_seen=11,000,000 | cap=64 | target=10


[bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654] Construyendo reservoirs (k16/32/48/64):   0%|          |…

[OK] bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654: reservoirs guardados en /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/reservoirs_bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654.pkl | tokens_seen=11,000,000 | cap=64 | target=10


[bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024] Construyendo reservoirs (k16/32/48/64):   0%|          |…

[OK] bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024: reservoirs guardados en /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/reservoirs_bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024.pkl | tokens_seen=11,000,000 | cap=64 | target=10


[bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537] Construyendo reservoirs (k16/32/48/64):   0%|          |…

[OK] bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537: reservoirs guardados en /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/reservoirs_bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537.pkl | tokens_seen=11,000,000 | cap=64 | target=10


[bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124] Construyendo reservoirs (k16/32/48/64):   0%|          |…

[OK] bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124: reservoirs guardados en /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/reservoirs_bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124.pkl | tokens_seen=11,000,000 | cap=64 | target=10


[bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037] Construyendo reservoirs (k16/32/48/64):   0%|          |…

[OK] bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037: reservoirs guardados en /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/reservoirs_bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037.pkl | tokens_seen=11,000,000 | cap=64 | target=10
