In [1]:
# -*- coding: utf-8 -*-
"""
Estimación ligera de τ_k y frecuencias de uso por feature SIN reservoirs.
- k ∈ {16, 32, 48, 64}
- τ_k por (SAE,k): running mean de umbrales por sub-batch (batch_topk).
- Frecuencia relativa por feature y k: count[z_f >= τ_k(sub-batch)] / tokens_procesados.
- Muestreo:
    * BLOCK_SKIP: procesa 1 de cada N bloques (bloque = BATCH_SIZE * BATCHES_PER_BLOCK filas).
    * ROW_STRIDE: dentro de cada sub-batch, usa 1 de cada ROW_STRIDE filas.

La barra de progreso (tqdm) se actualiza **por cada sub-batch efectivamente procesado**.
"""

import math
import pickle
from pathlib import Path

import numpy as np
import zarr
import mlx.core as mx
from tqdm.notebook import tqdm
from model import *

# -------------------------
# PATHS AND BASIC CONFIG
# -------------------------
ROOT = Path("/Users/josue/proyectos/tesis/low-rank-bilinear-sae").resolve()
INTERP_DIR = ROOT / "interpretability"
OUT_DIR = INTERP_DIR / "data"
# Create the output directory (and any missing parents) before scanning.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Checkpoints to scan; commented-out entries are excluded from this run.
# NOTE(review): these are relative paths, so they presumably resolve against
# the current working directory rather than ROOT — confirm against the loader.
sae_paths = [
#    'checkpoints/bsae_512_2048_512__Er0_Dr0_k16_ep16_20251129-040858.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503.pkl',
    'checkpoints/bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808.pkl',
    'checkpoints/bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024.pkl',
    'checkpoints/bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537.pkl',
    'checkpoints/bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124.pkl',
    'checkpoints/bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037.pkl',
#    'checkpoints/bsae_512_2048_512__Er2_Dr2_k16_ep16_20251130-113058.pkl'
]

# One label per checkpoint: the file name without its extension.
sae_labels = [Path(p).stem for p in sae_paths]

zarr_path = "/Users/josue/proyectos/tesis/pythia/data/Pythia70M-L3-res-wiki.zarr"

# Values of k to estimate
KSET = [4, 8, 16]

# -------------------------
# LIGHTWEIGHT SCAN PARAMETERS
# -------------------------
BATCH_SIZE = 4096
BATCHES_PER_BLOCK = 100

# Process only 1 of every N BLOCKS (N=1: process all of them)
BLOCK_SKIP = 4       # e.g. 4 -> ~25% of the blocks

# Within each sub-batch, use 1 of every ROW_STRIDE rows (1: use all rows)
ROW_STRIDE = 1       # 1 -> every row; e.g. 2 would keep ~50% of the rows

# -------------------------
# AUXILIARES
# -------------------------
class RunningMean:
    """Incremental arithmetic mean of a stream of scalars (used for τ_k).

    Attributes:
        n: number of samples folded in so far.
        mu: current mean of those samples (0.0 before the first update).
    """

    def __init__(self):
        self.mu = 0.0
        self.n = 0

    def update(self, x: float):
        """Fold one more sample ``x`` into the mean."""
        self.n += 1
        # Same update as mu += (x - mu) / n, spelled out in two steps.
        delta = x - self.mu
        self.mu = self.mu + delta / self.n

    def value(self) -> float:
        """Return the current mean, or 0.0 when no sample has been seen."""
        if self.n <= 0:
            return 0.0
        return float(self.mu)

def block_iterator(zarray, block_size):
    """Yield ``(start_row, block)`` pairs that cover ``zarray`` in fixed-size row blocks.

    Each block is the slice ``zarray[start:start + block_size]`` converted to a
    float32 MLX array; the last block may be shorter than ``block_size``.
    """
    n_rows = zarray.shape[0]
    for lo in range(0, n_rows, block_size):
        hi = lo + block_size
        chunk = zarray[lo:hi]
        # Expected to be [B_block, D] per the original note.
        yield lo, mx.array(chunk, dtype=mx.float32)

def batch_tau_from_topk(z_full_mx: mx.array, k_target: int, row_stride: int = 1) -> float:
    """
    Sub-batch threshold: the smallest of the top-(k_target * B_eff) values of the
    flattened, row-subsampled activations, where B_eff is the number of rows
    actually kept by ``z_full_mx[::row_stride]``.

    Args:
        z_full_mx: activations for one sub-batch; rows are indexed on axis 0.
        k_target: target number of active entries per kept row.
        row_stride: keep 1 of every ``row_stride`` rows (1 keeps all rows).

    Returns:
        The threshold as a Python float, or 0.0 when ``k_target <= 0`` or there
        is nothing to rank.

    Raises:
        ValueError: if ``row_stride`` is 0.
    """
    if k_target <= 0:
        return 0.0
    B = int(z_full_mx.shape[0])
    if B == 0:
        return 0.0
    if row_stride == 0:
        raise ValueError("row_stride must be non-zero")
    rows = z_full_mx[::row_stride]
    # Use the real number of kept rows: a strided slice keeps ceil(B / stride)
    # rows, so B // stride would under-count whenever B is not a multiple.
    B_eff = int(rows.shape[0])
    flat = mx.reshape(rows, (-1,))
    kk = min(k_target * B_eff, int(flat.shape[0]))
    if kk <= 0:
        return 0.0
    # topk does not guarantee any ordering of the returned values, so take the
    # minimum explicitly instead of reading the last element.
    vals = mx.topk(flat, kk)
    return float(mx.min(vals))

def detect_F_total(sae, default: int = 2048) -> int:
    """
    Best-effort detection of the SAE's total feature count (F_total).

    Probes, in order:
      1. if ``sae.decoder`` has a ``weight``: the first sub-layer's weight
         (``decoder.layers[0].weight``) when the decoder exposes ``layers``,
         otherwise the decoder's own ``weight``;
      2. ``sae.decoder.W``;
      3. ``sae.F``.
    For weights, axis 1 is read as the feature axis (assumption inherited from
    the original lookup — TODO confirm against the model definition).

    Args:
        sae: the autoencoder object to inspect.
        default: value returned when nothing can be detected.

    Returns:
        The detected feature count, or ``default`` (2048 unless overridden).
    """
    try:
        dec = getattr(sae, "decoder", None)
        if dec is not None:
            if hasattr(dec, "weight"):
                # Only dereference ``layers`` when it exists; a decoder that
                # carries just ``weight`` is read directly instead of failing.
                layers = getattr(dec, "layers", None)
                weight = layers[0].weight if layers is not None else dec.weight
                return int(weight.shape[1])
            if hasattr(dec, "W"):
                return int(dec.W.shape[1])
        if hasattr(sae, "F"):
            return int(sae.F)
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        # Unexpected layout (missing attribute, short shape, non-numeric F):
        # stay best-effort and fall through to the default.
        pass
    return default

# -------------------------
# CARGA DE DATOS/MODELOS
# -------------------------
# Open the activation store read-only; axis 0 is used as the row axis below.
zarr_arr = zarr.open(zarr_path, mode="r")
TOTAL_LEN = zarr_arr.shape[0]

# Load every checkpoint up front; the feature count is taken from the first one.
saes = list(map(BilinearSparseAutoencoder.load, sae_paths))
F_TOTAL = detect_F_total(saes[0])

# -------------------------
# PRE-COMPUTATION: number of SUB-BATCHES that will be processed (for tqdm)
# -------------------------
BLOCK_SIZE = BATCHES_PER_BLOCK * BATCH_SIZE
total_blocks = -(-TOTAL_LEN // BLOCK_SIZE)  # integer ceil division

# Actual length of every block (the last one may be shorter)
block_lengths = [
    min(BLOCK_SIZE, TOTAL_LEN - idx * BLOCK_SIZE) for idx in range(total_blocks)
]

# Indices of the blocks that will actually be processed (after BLOCK_SKIP)
if BLOCK_SKIP <= 1:
    processed_block_indices = list(range(total_blocks))
else:
    processed_block_indices = list(range(0, total_blocks, BLOCK_SKIP))

# Sub-batches contained in each processed block
batches_per_processed_block = [
    -(-block_lengths[idx] // BATCH_SIZE) for idx in processed_block_indices
]

total_batches_to_process = sum(batches_per_processed_block)

# -------------------------
# LOOP LIGERO (por BLOQUE real) – tqdm por SUB-BATCH
# -------------------------
for sae, label in zip(saes, sae_labels):

    # Running mean of the per-sub-batch threshold, one accumulator per k.
    tau_running = {k: RunningMean() for k in KSET}

    # Per-feature activation counts, one length-F_TOTAL vector per k.
    pos_counts_all = {k: np.zeros(F_TOTAL, dtype=np.int64) for k in KSET}

    tokens_seen = 0               # rows actually evaluated (after ROW_STRIDE)
    processed_batches = 0         # sub-batches actually processed

    desc = f"[{label}] Batches procesados"
    with tqdm(total=total_batches_to_process, desc=desc, leave=True) as pbar:

        # Read ONLY the blocks selected by BLOCK_SKIP. Going through a generator
        # over all blocks and skipping afterwards would still load and convert
        # every skipped block from disk.
        for bidx in processed_block_indices:
            row0 = bidx * BLOCK_SIZE
            xblock = mx.array(zarr_arr[row0:row0 + BLOCK_SIZE], dtype=mx.float32)

            # Walk the block in sub-batches of BATCH_SIZE rows.
            B_block = int(xblock.shape[0])
            for off in range(0, B_block, BATCH_SIZE):
                xbatch = xblock[off:off + BATCH_SIZE]
                if int(xbatch.shape[0]) == 0:
                    continue

                # Encode the sub-batch, then keep 1 of every ROW_STRIDE rows.
                z_full = sae.encoder(sae.clip_norms(xbatch))
                z_sub = z_full[::ROW_STRIDE]
                B_eff = int(z_sub.shape[0])   # >= 1 because the sub-batch is non-empty

                for k in KSET:
                    # The threshold is computed on exactly the rows that are
                    # counted below (z_sub is already strided, hence stride 1),
                    # so the top-k budget and the counts use the same row set.
                    thr = batch_tau_from_topk(z_sub, k_target=k, row_stride=1)
                    tau_running[k].update(thr)

                    # Per-feature count of entries at or above the threshold.
                    mask = (z_sub >= thr)
                    col_counts = np.array(mx.sum(mask, axis=0), dtype=np.int64)
                    pos_counts_all[k] += col_counts

                tokens_seen += B_eff
                processed_batches += 1
                pbar.update(1)

    # Final tau per k (mean over processed sub-batches).
    TAU_FINAL = {int(k): float(tau_running[k].value()) for k in KSET}

    # Relative frequency per feature and k. With zero tokens the counts are all
    # zero, so dividing by 1.0 yields the same all-zero result without a branch.
    denom = float(tokens_seen) if tokens_seen > 0 else 1.0
    freq_all = {
        int(k): (pos_counts_all[k].astype(np.float32) / denom).tolist()
        for k in KSET
    }

    # Persist the scan state for this SAE.
    state = {
        "F_total": int(F_TOTAL),
        "taus": TAU_FINAL,
        "pos_counts_all": {int(k): pos_counts_all[k].tolist() for k in KSET},
        "freq_all": freq_all,
        "tokens_seen": int(tokens_seen),
        "meta": {
            "note": "Estimación ligera sin reservoirs. Frecuencia por feature/k con umbral por sub-batch.",
            "block_skip": int(BLOCK_SKIP),
            "row_stride": int(ROW_STRIDE),
            "batch_size": int(BATCH_SIZE),
            "batches_per_block": int(BATCHES_PER_BLOCK),
            "kset": list(map(int, KSET)),
            "total_len": int(TOTAL_LEN),
            "total_blocks": int(total_blocks),
            "blocks_processed": int(len(processed_block_indices)),
            "total_batches_to_process": int(total_batches_to_process),
            "batches_processed": int(processed_batches),
        },
    }
    pkl_path = OUT_DIR / f"scan_state_{label}.pkl"
    with open(pkl_path, "wb") as pf:
        pickle.dump(state, pf)
    print(
        f"[OK] Guardado: {pkl_path} | tokens_seen={tokens_seen:,} | "
        f"batches={processed_batches}/{total_batches_to_process} "
        f"(blocks {len(processed_block_indices)}/{total_blocks}, skip={BLOCK_SKIP})"
    )

[bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503] Batches procesados:   0%|          | 0/700 [00:00<?, ?it…

[OK] Guardado: /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/scan_state_bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503.pkl | tokens_seen=2,867,200 | batches=700/700 (blocks 7/27, skip=4)


[bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808] Batches procesados:   0%|          | 0/700 [00:00<?, ?it…

[OK] Guardado: /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/scan_state_bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808.pkl | tokens_seen=2,867,200 | batches=700/700 (blocks 7/27, skip=4)


[bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654] Batches procesados:   0%|          | 0/700 [00:00<?, ?it…

[OK] Guardado: /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/scan_state_bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654.pkl | tokens_seen=2,867,200 | batches=700/700 (blocks 7/27, skip=4)


[bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024] Batches procesados:   0%|          | 0/700 [00:00<?, ?it…

[OK] Guardado: /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/scan_state_bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024.pkl | tokens_seen=2,867,200 | batches=700/700 (blocks 7/27, skip=4)


[bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537] Batches procesados:   0%|          | 0/700 [00:00<?, ?it…

[OK] Guardado: /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/scan_state_bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537.pkl | tokens_seen=2,867,200 | batches=700/700 (blocks 7/27, skip=4)


[bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124] Batches procesados:   0%|          | 0/700 [00:00<?, ?it…

[OK] Guardado: /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/scan_state_bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124.pkl | tokens_seen=2,867,200 | batches=700/700 (blocks 7/27, skip=4)


[bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037] Batches procesados:   0%|          | 0/700 [00:00<?, ?it…

[OK] Guardado: /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/scan_state_bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037.pkl | tokens_seen=2,867,200 | batches=700/700 (blocks 7/27, skip=4)
