In [1]:
import os
# SECURITY: never hardcode the API key in the notebook (it ends up in version
# control and in every shared export). Take it from the environment and, only
# if it is not set, prompt for it without echoing. Any key that was ever saved
# inside a notebook must be considered leaked and should be rotated.
if not os.environ.get("GEMINI_API_KEY"):
    from getpass import getpass

    os.environ["GEMINI_API_KEY"] = getpass("GEMINI_API_KEY: ")

In [2]:
# %% [markdown]
# Runner de interpretabilidad (solo API, sin parsear)
# - Lee features desde JSONL generados por el sampler:
#     líneas con: {"sae_label", "feature", "examples": [ "<marked_text>", ... ]}
# - Construye prompts SOLO con ejemplos positivos (sin near-miss)
# - Llama a Gemini 2.5 Flash (JSON estricto esperado; aquí no parseamos)
# - Guarda resultados en interpretability/data/llm-responses.jsonl con resume idempotente

# %%
import os
import re
import json
import time
import threading
from typing import List, Dict, Tuple
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm.notebook import tqdm
import google.generativeai as genai

# ==============================
# CONFIG
# ==============================
# Project root. Defaults to the original checkout location, but can be
# redirected with the BSAE_ROOT environment variable so the notebook runs on
# other machines without editing code.
ROOT = Path(
    os.environ.get("BSAE_ROOT", "/Users/josue/proyectos/tesis/low-rank-bilinear-sae")
).expanduser().resolve()
INTERP_DIR = ROOT / "interpretability"
DATA_DIR = INTERP_DIR / "data"

# SAE labels (checkpoint file stems) and the set of top-k values to process.
# Commented-out entries are excluded from this run.
SAE_LABELS = [
#    'bsae_512_2048_512__Er0_Dr0_k16_ep16_20251129-040858',
    'bsae_512_2048_512__Er1_Dr0_k16_ep16_20251129-152503',
    'bsae_512_2048_512__Er2_Dr0_k16_ep16_20251129-183808',
    'bsae_512_2048_512__Er0_Dr1_k16_ep16_20251129-204654',
    'bsae_512_2048_512__Er1_Dr1_k16_ep16_20251129-233024',
    'bsae_512_2048_512__Er2_Dr1_k16_ep16_20251130-025537',
    'bsae_512_2048_512__Er0_Dr2_k16_ep16_20251130-051124',
    'bsae_512_2048_512__Er1_Dr2_k16_ep16_20251130-080037',
#    'bsae_512_2048_512__Er2_Dr2_k16_ep16_20251130-113058'
]
KSET = [16]

# Inputs: one JSONL file per (label, k), named like:
#   samples_{label}__k{k}__F{...}_E{...}.jsonl
# Fails fast with FileNotFoundError if any (label, k) pair has no matching file.
SAE_FILES: List[Path] = []
for lbl in SAE_LABELS:
    for k in KSET:
        matches = sorted(DATA_DIR.glob(f"samples_{lbl}__k{k}__F*_E*.jsonl"))
        if not matches:
            raise FileNotFoundError(f"No se encontró samples para {lbl} k={k} en {DATA_DIR}")
        SAE_FILES.append(matches[-1])  # keeps the last match in lexicographic order — NOTE(review): that is not necessarily the newest file nor the numerically largest F/E (e.g. "F500" sorts after "F1000")

# Single JSONL file to which every LLM response record is appended.
OUTPUT_JSONL = DATA_DIR / "llm-responses.jsonl"

# Model identifier passed to the Gemini SDK.
GENERATION_MODEL = "models/gemini-2.5-flash"

# The API key is read ONLY from the environment; fail fast if it is missing.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("Falta GEMINI_API_KEY en el entorno. Exporta la variable antes de ejecutar.")

# Account limits fed to the rate limiter (adjust to your quota):
# "rpm" = requests per minute, "tpm" = tokens per minute.
LLM_LIMITS = {"rpm": 1000, "tpm": 1_000_000}
MAX_WORKERS = 64  # thread-pool size used when dispatching the LLM jobs

# ==============================
# I/O jobs y resume
# ==============================
def _parse_topk_from_stem(stem: str) -> int:
    # p.ej., "samples_...__k16__F1000_E10" → 16
    m = re.search(r"__k(\d+)", stem)
    return int(m.group(1)) if m else -1

def load_feature_lines(path: Path, sae_model_label: str, topk_from_file: int) -> List[Dict]:
    """Load one samples JSONL file (one feature per non-blank line).

    Input line format:
      {"sae_label": str, "feature": int, "examples": ["<marked_text>", ...]}
    Returned records (normalized for the runner):
      {"sae_model", "topk", "feature_id", "examples": [str, ...]}

    Args:
        path: JSONL file to read (UTF-8).
        sae_model_label: label stored as "sae_model" in every record (the
            "sae_label" field of the line is not used).
        topk_from_file: fallback "topk" when a line carries no "topk" key.
    """

    def _to_job(entry: Dict) -> Dict:
        # Field names of the runner differ from the ones written by the sampler.
        return {
            "sae_model": sae_model_label,
            "topk": int(entry.get("topk", topk_from_file)),
            "feature_id": int(entry.get("feature")),
            "examples": list(entry.get("examples", [])),
        }

    with open(path, "r", encoding="utf-8") as handle:
        # Single pass: strip each line, drop blanks, parse and normalize.
        return [_to_job(json.loads(row)) for row in map(str.strip, handle) if row]

def load_all_jobs(file_paths: List[Path]) -> List[Dict]:
    """Read every samples JSONL and return one flat list of jobs.

    Each job has the shape produced by load_feature_lines:
      {"sae_model", "topk", "feature_id", "examples": [str, ...]}

    Raises:
        FileNotFoundError: if any of the given paths does not exist.
    """
    all_jobs: List[Dict] = []
    for sample_path in tqdm(file_paths, desc="Cargando features (JSONL)"):
        if not sample_path.exists():
            raise FileNotFoundError(f"No existe {sample_path}")
        file_stem = sample_path.stem  # e.g. "samples_{lbl}__k16__F1000_E10"
        # Bare SAE label: everything before the first "__k", without the
        # leading "samples_" prefix when present.
        label = file_stem.partition("__k")[0].removeprefix("samples_")
        all_jobs += load_feature_lines(sample_path, label, _parse_topk_from_stem(file_stem))
    return all_jobs

def load_processed_set(output_path: Path) -> set[Tuple[str, int, int]]:
    """Return the (sae_model, topk, feature_id) keys already present in the output JSONL.

    Used for resume: jobs whose key is in the returned set are not re-run.
    Returns an empty set when the file does not exist.

    Lines that cannot be turned into a key are skipped instead of aborting the
    whole resume: blank lines, invalid JSON (e.g. a line truncated by an
    interrupted run), JSON values that are not objects, and records whose
    "topk"/"feature_id" are not integer-like. Missing "topk"/"feature_id"
    default to -1.
    """
    done: set[Tuple[str, int, int]] = set()
    if not output_path.exists():
        return done
    with open(output_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if not isinstance(rec, dict):
                # Valid JSON but not a record (list, number, string, null...).
                continue
            try:
                key = (rec.get("sae_model"),
                       int(rec.get("topk", -1)),
                       int(rec.get("feature_id", -1)))
                done.add(key)
            except (TypeError, ValueError):
                # null / non-numeric ids, or an unhashable "sae_model".
                continue
    return done

def filter_pending_jobs(jobs, processed):
    """Return, in their original order, the jobs whose
    (sae_model, topk, feature_id) key is not in `processed`."""

    def _key(job):
        return (job["sae_model"], int(job["topk"]), int(job["feature_id"]))

    pending = []
    for job in jobs:
        if _key(job) in processed:
            continue
        pending.append(job)
    return pending

# ==============================
# Rate limiting (RPM/TPM por minuto)
# ==============================
class RateLimiter:
    """Thread-safe sliding-window limiter for requests/minute and tokens/minute.

    Every admitted request is recorded with its timestamp and estimated token
    cost; entries older than 60 seconds are dropped before each admission check.

    NOTE: this cell defines a class named RateLimiter a second time further
    down; that later definition is the one in effect when `limiter` is created.
    Keep both in sync or remove one of them.

    Args:
        rpm: maximum number of requests admitted per 60-second window.
        tpm: maximum number of estimated tokens admitted per 60-second window.

    Raises:
        ValueError: if rpm or tpm is below 1 (acquire() could never succeed).
    """

    def __init__(self, rpm: int, tpm: int):
        if rpm < 1 or tpm < 1:
            raise ValueError("rpm and tpm must be >= 1")
        self.rpm = rpm
        self.tpm = tpm
        self.req_times = []  # timestamps of admitted requests
        self.tok_times = []  # (timestamp, tokens)
        self.lock = threading.Lock()

    @staticmethod
    def approx_tokens(s: str) -> int:
        """Rough token estimate: one token per 4 characters, never less than 1."""
        return max(1, len(s) // 4)

    def _prune(self, now):
        """Forget entries that left the 60-second window. Caller holds the lock."""
        self.req_times = [t for t in self.req_times if now - t < 60]
        self.tok_times = [(t, k) for (t, k) in self.tok_times if now - t < 60]

    def acquire(self, texts: List[str]):
        """Block until one request carrying `texts` fits both budgets, then record it."""
        # A single request estimated above the whole per-minute token budget
        # could never be admitted and would spin forever; cap its cost at `tpm`
        # so it is let through alone once the window has drained.
        need = min(sum(self.approx_tokens(x) for x in texts), self.tpm)
        while True:
            with self.lock:
                # Monotonic clock: interval math must not be affected by
                # wall-clock adjustments.
                now = time.monotonic()
                self._prune(now)
                used_rpm = len(self.req_times)
                used_tpm = sum(k for _, k in self.tok_times)
                ok_rpm = (used_rpm + 1) <= self.rpm
                ok_tpm = (used_tpm + need) <= self.tpm
                if ok_rpm and ok_tpm:
                    self.req_times.append(now)
                    self.tok_times.append((now, need))
                    return
                # Wait until the oldest entry of the exhausted budget expires.
                wait = 0.05
                if not ok_rpm and self.req_times:
                    wait = max(wait, 60 - (now - self.req_times[0]))
                if not ok_tpm and self.tok_times:
                    wait = max(wait, 60 - (now - self.tok_times[0][0]))
            # Sleep outside the lock, at most 2 s at a time, then re-check.
            time.sleep(min(2.0, wait))

def approx_tokens(s: str) -> int:
    """Module-level shortcut for RateLimiter.approx_tokens (rough token estimate of `s`)."""
    estimate = RateLimiter.approx_tokens(s)
    return estimate

# ==============================
# Normalización de marcadores
# ==============================
# Marked token followed by a numeric score, e.g. <<token>>(0.73).
# Group 1 = text between << and >>, group 2 = the score. DOTALL lets the
# marked text span newlines.
_marker_with_score = re.compile(r"<<(.*?)>>\((\d+(?:\.\d+)?)\)", re.DOTALL)
# Any <<...>> marker, regardless of whether a score follows it.
# Group 1 = text between << and >>.
_marker_no_score   = re.compile(r"<<(.*?)>>", re.DOTALL)

def _normalize_marker_inner(inner: str) -> str:
    s = inner
    s = s.replace("\\n", "[NL]").replace("\n", "[NL]")
    s = s.replace("\\t", "[TAB]").replace("\t", "[TAB]")
    if re.fullmatch(r"[ ]+", s):
        s = "[WSP]"
    return s

def normalize_markers_in_text(text: str) -> str:
    """Replace whitespace inside the <<...>> markers of `text` with placeholders.

    Two passes, in this order:
      1. markers followed by a "(score)" suffix — the score is kept verbatim;
      2. every <<...>> span of the result of pass 1, scored or not.
    """
    scored_pass = _marker_with_score.sub(
        lambda m: f"<<{_normalize_marker_inner(m.group(1))}>>({m.group(2)})",
        text,
    )
    return _marker_no_score.sub(
        lambda m: f"<<{_normalize_marker_inner(m.group(1))}>>",
        scored_pass,
    )

def needs_whitespace_handling(examples_texts: List[str]) -> bool:
    """Tell whether any marker in the examples contains a whitespace placeholder.

    The examples are normalized first, so real newlines/tabs/space-only tokens
    inside markers count as well.
    """
    placeholders = ("[NL]", "[NL2]", "[TAB]", "[WSP]")
    marker = re.compile(r"<<(.*?)>>", re.DOTALL)
    normalized = [normalize_markers_in_text(example) for example in examples_texts]
    return any(
        placeholder in inner
        for example in normalized
        for inner in marker.findall(example)
        for placeholder in placeholders
    )

# =========================
# Construcción del prompt
# =========================
def build_prompt_from_record(item: Dict) -> str:
    """Build the LLM prompt for one feature from its positive examples only.

    Args:
        item: {"sae_model", "topk", "feature_id", "examples": [str, ...]};
            only "topk" (default -1) and "examples" (default []) are read.

    Returns:
        The prompt text: task description, optional whitespace-placeholder
        note, guidance listing up to 40 marked tokens to avoid, strict output
        format, and the numbered examples.
    """
    topk = int(item.get("topk", -1))

    # Positive (marked) examples, with whitespace inside markers normalized.
    positives: List[str] = [normalize_markers_in_text(raw) for raw in item.get("examples", [])]

    # Only explain the placeholders when some marker actually contains one.
    whitespace_note = ""
    if needs_whitespace_handling(positives):
        whitespace_note = "".join([
            "IMPORTANT: If the trigger involves whitespace, do NOT use real newlines or tabs in your output.\n",
            "Use placeholders instead: [NL] for newline (\\n), [NL2] for two newlines (\\n\\n), ",
            "[TAB] for tab (\\t), [WSP] for a single space. Use them in activating_sentences when needed.\n\n",
        ])

    # Marked tokens the model is asked not to parrot: unique, non-empty,
    # first-seen order, capped at 40.
    marker_pat = re.compile(r"<<(.*?)>>(?:\(\d+(?:\.\d+)?\))?", re.DOTALL)
    stripped_inners = (
        inner.strip() for text in positives for inner in marker_pat.findall(text)
    )
    banned: List[str] = list(dict.fromkeys(tok for tok in stripped_inners if tok))[:40]

    guidance_parts = [
        "GUIDANCE:\n",
        "- Avoid reusing the exact same trigger tokens/phrases (or close variants) as in the examples.\n",
    ]
    if banned:
        guidance_parts.append(
            "- In particular, avoid: " + ", ".join(f'"{b}"' for b in banned) + ".\n"
        )
    guidance_parts.append(
        "- You may vary topic, wording, or structure as long as the sentences are likely to activate the feature.\n\n"
    )
    guidance = "".join(guidance_parts)

    # Numbered examples section (the only body section of the prompt).
    title = "POSITIVE EXAMPLES (tokens marked; score may be shown)"
    if positives:
        examples_block = "\n\n".join(
            f"{title} {idx}:\n{text.strip()}" for idx, text in enumerate(positives, 1)
        ) + "\n"
    else:
        examples_block = f"{title}:\n(none)\n"

    return "".join([
        "You are analyzing a single SAE (Sparse Autoencoder) feature in a language model.\n",
        f"Top-k context: k = {topk}.\n",
        "Activated tokens are marked as <<token>>(score). Ignore the markers when writing NEW sentences.\n\n",
        whitespace_note,
        "TASKS:\n",
        "1) Under \"explanation\": Briefly describe the common trigger/pattern.\n",
        "2) Under \"activating_sentences\": Provide EXACTLY 5 NEW sentences that SHOULD strongly activate this feature.\n\n",
        guidance,
        "OUTPUT (STRICT):\n",
        "- Return ONLY a valid JSON object with EXACTLY these keys: \"explanation\", \"activating_sentences\".\n",
        "- \"activating_sentences\": array of 5 strings.\n",
        "- Valid JSON only (double quotes, no trailing commas). No code fences, no comments, no extra text.\n\n",
        examples_block,
        "OUTPUT EXPECTATION:\n",
        "A single valid JSON object with the two keys above and nothing else.",
    ])

# ==============================
# Escritura thread-safe (raw completo)
# ==============================
# Serializes appends coming from the worker threads.
_write_lock = threading.Lock()


def append_jsonl(path: Path, obj: dict):
    """Append `obj` to `path` as one JSON line (UTF-8, non-ASCII kept as-is).

    The object is serialized before taking the lock; only the file append
    itself happens while holding it.
    """
    payload = f"{json.dumps(obj, ensure_ascii=False)}\n"
    with _write_lock, open(path, "a", encoding="utf-8") as sink:
        sink.write(payload)

# ==============================
# Cliente + wrapper (sin parsear)
# ==============================
# Configure the Gemini SDK with the key read from the environment and create
# the model handle that call_llm uses for every request.
genai.configure(api_key=GEMINI_API_KEY)
gemini_model = genai.GenerativeModel(GENERATION_MODEL)

class _Retry:
    @staticmethod
    def parse_retry_delay_seconds(err: Exception) -> int:
        m = re.search(r"retry_delay\s*\{\s*seconds:\s*(\d+)", str(err))
        return int(m.group(1)) if m else 0

class RateLimiter:
    """Thread-safe sliding-window limiter for requests/minute and tokens/minute.

    Every admitted request is recorded with its timestamp and estimated token
    cost; entries older than 60 seconds are dropped before each admission check.

    NOTE: a class with this same name is also defined earlier in this cell;
    this later definition shadows it and is the one in effect when `limiter`
    is created. Keep both in sync or remove one of them.

    Args:
        rpm: maximum number of requests admitted per 60-second window.
        tpm: maximum number of estimated tokens admitted per 60-second window.

    Raises:
        ValueError: if rpm or tpm is below 1 (acquire() could never succeed).
    """

    def __init__(self, rpm: int, tpm: int):
        if rpm < 1 or tpm < 1:
            raise ValueError("rpm and tpm must be >= 1")
        self.rpm = rpm
        self.tpm = tpm
        self.req_times = []  # timestamps of admitted requests
        self.tok_times = []  # (timestamp, tokens)
        self.lock = threading.Lock()

    @staticmethod
    def approx_tokens(s: str) -> int:
        """Rough token estimate: one token per 4 characters, never less than 1."""
        return max(1, len(s) // 4)

    def _prune(self, now):
        """Forget entries that left the 60-second window. Caller holds the lock."""
        self.req_times = [t for t in self.req_times if now - t < 60]
        self.tok_times = [(t, k) for (t, k) in self.tok_times if now - t < 60]

    def acquire(self, texts: List[str]):
        """Block until one request carrying `texts` fits both budgets, then record it."""
        # A single request estimated above the whole per-minute token budget
        # could never be admitted and would spin forever; cap its cost at `tpm`
        # so it is let through alone once the window has drained.
        need = min(sum(self.approx_tokens(x) for x in texts), self.tpm)
        while True:
            with self.lock:
                # Monotonic clock: interval math must not be affected by
                # wall-clock adjustments.
                now = time.monotonic()
                self._prune(now)
                used_rpm = len(self.req_times)
                used_tpm = sum(k for _, k in self.tok_times)
                ok_rpm = (used_rpm + 1) <= self.rpm
                ok_tpm = (used_tpm + need) <= self.tpm
                if ok_rpm and ok_tpm:
                    self.req_times.append(now)
                    self.tok_times.append((now, need))
                    return
                # Wait until the oldest entry of the exhausted budget expires.
                wait = 0.05
                if not ok_rpm and self.req_times:
                    wait = max(wait, 60 - (now - self.req_times[0]))
                if not ok_tpm and self.tok_times:
                    wait = max(wait, 60 - (now - self.tok_times[0][0]))
            # Sleep outside the lock, at most 2 s at a time, then re-check.
            time.sleep(min(2.0, wait))

# Shared limiter instance that call_llm consults before every request.
limiter = RateLimiter(rpm=LLM_LIMITS["rpm"], tpm=LLM_LIMITS["tpm"])

def call_llm(prompt: str, max_retries: int = 1, fallback_delay: float = 2.0) -> str:
    """Send `prompt` to the Gemini model and return the raw response text.

    Every attempt first waits on the shared rate limiter. When an attempt
    raises, the call sleeps for the delay advertised in the error text (see
    _Retry.parse_retry_delay_seconds) or `fallback_delay` seconds if none is
    found, then tries again.

    Args:
        prompt: full prompt text.
        max_retries: number of additional attempts after the first failure
            (default 1, i.e. at most two attempts in total).
        fallback_delay: seconds to sleep before a retry when the error carries
            no retry hint.

    Returns:
        The response text, or "" if the response text is empty/None.

    Raises:
        Exception: whatever the last attempt raised once retries are exhausted.
    """
    attempt = 0
    while True:
        limiter.acquire([prompt])
        try:
            resp = gemini_model.generate_content(prompt)
            return (resp.text or "")
        except Exception as e:
            if attempt >= max_retries:
                raise
            attempt += 1
            rd = _Retry.parse_retry_delay_seconds(e)
            time.sleep(rd if rd > 0 else fallback_delay)

In [3]:
# ==============================
# Main
# ==============================
# 1) Load every feature from the selected samples files as a job.
jobs = load_all_jobs(SAE_FILES)
print(f"[INFO] features cargados (total): {len(jobs)}")

# 2) Resume: drop the jobs whose key is already present in the output file.
#    `done` holds every key found in the output, including keys of SAE labels
#    that are not part of this run.
done = load_processed_set(OUTPUT_JSONL)
pending = filter_pending_jobs(jobs, done)
print(f"[INFO] pendientes: {len(pending)} (ya procesados: {len(done)})")

# 3) Create an empty output file if it does not exist yet.
if not OUTPUT_JSONL.exists():
    open(OUTPUT_JSONL, "a", encoding="utf-8").close()

def one_job_runner(item: Dict):
    """Run one feature job: build the prompt, query the LLM and append the
    unparsed answer, plus audit fields, to OUTPUT_JSONL.

    Returns 1 so the caller can sum the number of written records.
    """
    prompt = build_prompt_from_record(item)
    answer = call_llm(prompt)  # kept verbatim; no parsing here
    append_jsonl(
        OUTPUT_JSONL,
        {
            "llm_model": GENERATION_MODEL,
            "sae_model": item["sae_model"],  # label without the "__k" suffix
            "topk": int(item["topk"]),
            "feature_id": int(item["feature_id"]),
            # Echo of the marked positive examples, for auditing.
            "examples": item["examples"],
            # Raw LLM response (expected to contain "explanation" and "activating_sentences").
            "raw": answer,
            "approx_prompt_tokens": approx_tokens(prompt),  # for auditing
            "expected_output_keys": ["explanation", "activating_sentences"],
        },
    )
    return 1

# Number of records written during this run.
total_done = 0
# Submit every pending job to the thread pool and tally results as they finish.
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
    futures = [pool.submit(one_job_runner, it) for it in pending]
    for fut in tqdm(as_completed(futures), total=len(futures), desc="LLM jobs"):
        try:
            total_done += fut.result()
        except Exception as e:
            # A failed job is reported and skipped so the rest keep running; a
            # job that failed before its record was appended stays pending for
            # the next run.
            print("[JOB ERROR]", e)

print(f"[DONE] escritos={total_done} → {OUTPUT_JSONL.resolve()}")


Cargando features (JSONL):   0%|          | 0/7 [00:00<?, ?it/s]

[INFO] features cargados (total): 3500
[INFO] pendientes: 3500 (ya procesados: 1000)


LLM jobs:   0%|          | 0/3500 [00:00<?, ?it/s]

[DONE] escritos=3500 → /Users/josue/proyectos/tesis/low-rank-bilinear-sae/interpretability/data/llm-responses.jsonl
