# In [None]:
# ==== FAST SAMPLER (no Kedro), hard-coded artifacts & decoder ====
from pathlib import Path
import os, sys, json, random
import numpy as np
import pandas as pd
import torch

# Optional: avoid slow CUDA probing for a quick CPU sampler
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
# os.environ.setdefault("PYTHONUNBUFFERED", "1")
# torch.set_num_threads(1)

# ---------------------------------------------------------------------------
# 0) Project root discovery (so this runs fine from notebooks/ too)
# ---------------------------------------------------------------------------
def find_project_root(start: "Path | None" = None) -> Path:
    """Return the nearest directory at or above *start* holding a pyproject.toml.

    Args:
        start: Directory to begin the upward search from. When omitted, the
            current working directory is used, looked up at call time.

    Returns:
        The first of ``start`` and its parents that contains
        ``pyproject.toml``.

    Raises:
        RuntimeError: If neither ``start`` nor any of its parents contains
            ``pyproject.toml``.
    """
    # Resolve the default here rather than in the signature: a default of
    # Path.cwd() is evaluated once, when the function is defined, and would
    # go stale after an os.chdir() or when the cell is re-run elsewhere.
    if start is None:
        start = Path.cwd()
    for candidate in (start, *start.parents):
        if (candidate / "pyproject.toml").exists():
            return candidate
    raise RuntimeError(f"Couldn't find pyproject.toml above {start}")

# Locate the project root once, then make sure <root>/src is listed on
# sys.path (appended only when it is not already present).
PROJECT_PATH = find_project_root()
src_path = PROJECT_PATH / "src"
_src_entry = str(src_path)
if _src_entry not in sys.path:
    sys.path.append(_src_entry)
print("Project root:", PROJECT_PATH)

# ---------------------------------------------------------------------------
# 1) Choose model + casing
# ---------------------------------------------------------------------------
# Set one of: "model_B21" (CASED) or "model_B22" (CAPS)
MODEL_ID = "model_B21"

# Casing override: leave as None to derive the mode from MODEL_ID, or set it
# to "CASED" / "CAPS" to force that mode.
CASE_MODE = None  # e.g., "CASED" or "CAPS" to override


def _infer_case_mode(model_id: str) -> str:
    """Return the casing mode ("CASED" or "CAPS") to use for *model_id*.

    A module-level CASE_MODE of "CASED" or "CAPS" wins outright. Any other
    value (including None) falls through to the name heuristic: ids ending
    in "B22" are CAPS, everything else is CASED.
    """
    override = CASE_MODE
    if override == "CASED" or override == "CAPS":
        return override
    if model_id.endswith("B22"):  # also matches the "_B22" suffix
        return "CAPS"
    return "CASED"


CASE_MODE = _infer_case_mode(MODEL_ID)
print(f"CASE_MODE: {CASE_MODE}")

# ---------------------------------------------------------------------------
# 2) Hard-code artifact locations (no recursive scans, no metrics lookup)
# ---------------------------------------------------------------------------
def _first_existing(candidates):
    """Return the first path in *candidates* that exists, or None if none do."""
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


CKPT_DIR = PROJECT_PATH / f"data/06_models/{MODEL_ID}"
TOK_DIR = CKPT_DIR  # prefer colocated tokenizer spec

# Checkpoint: a fixed, ordered shortlist inside CKPT_DIR; the first file that
# exists wins, and "state_dict.pt" is the last-resort fallback.
_ckpt_candidates = [
    CKPT_DIR / name
    for name in ("state_dict_best.pt", "best.pt", "model_best.pt", "state_dict.pt")
]
CKPT_PATH = _first_existing(_ckpt_candidates)
if CKPT_PATH is None:
    _tried = "\n".join(map(str, _ckpt_candidates))
    raise FileNotFoundError("No checkpoint found. Tried:\n" + _tried)
print("Using checkpoint:", CKPT_PATH)

# Tokenizer spec: the copy next to the checkpoint is preferred, then a single
# fixed location under data/02_intermediate.
_tok_candidates = [
    TOK_DIR / "tokenizer_spec.json",
    PROJECT_PATH / f"data/02_intermediate/{MODEL_ID}/tokenizer_spec.json",
]
TOK_PATH = _first_existing(_tok_candidates)
if TOK_PATH is None:
    _tried = "\n".join(map(str, _tok_candidates))
    raise FileNotFoundError("tokenizer_spec.json not found. Tried:\n" + _tried)
print("Using tokenizer spec:", TOK_PATH)

# ---------------------------------------------------------------------------
# 3) Load artifacts
# ---------------------------------------------------------------------------
print("Loading checkpoint…")
# The checkpoint is always mapped onto the CPU. weights_only=True is tried
# first (PyTorch 2.x); older PyTorch rejects the unknown keyword with a
# TypeError, in which case the load is retried without it.
# NOTE(review): the fallback load has no weights_only restriction — only point
# this at checkpoints you trust.
_load_kwargs = {"map_location": "cpu"}
try:
    state_dict = torch.load(CKPT_PATH, weights_only=True, **_load_kwargs)
except TypeError:
    state_dict = torch.load(CKPT_PATH, **_load_kwargs)
print("✓ checkpoint loaded")

print("Loading tokenizer spec…")
# The spec is a UTF-8 encoded JSON document.
tokenizer_spec = json.loads(TOK_PATH.read_text(encoding="utf-8"))
print("✓ tokenizer spec loaded")

# ---------------------------------------------------------------------------
# 4) Hard-code the decoder import
#    Change this line if your decode() lives somewhere else.
# ---------------------------------------------------------------------------
# NOTE(review): `main` is presumably a package under <project root>/src, which
# is appended to sys.path earlier in this file — confirm if this import fails.
from main.pipelines.model_common.inference import decode
print("Using decoder: main.pipelines.model_common.inference.decode()")

# ---------------------------------------------------------------------------
# 5) Sampler config + helpers
# ---------------------------------------------------------------------------
TOP_P, TOP_K, STOP = 0.90, 50, None
# Single source of truth for the generation length, so the value handed to
# decode() and the value recorded in each CSV row cannot drift apart.
MAX_NEW_TOKENS = 120


def _maybe_caps(s: str) -> str:
    """Return *s* upper-cased when CASE_MODE is "CAPS", otherwise unchanged.

    Non-string inputs are returned as-is.
    """
    return s.upper() if CASE_MODE == "CAPS" and isinstance(s, str) else s


def _seed_everything(seed: int) -> None:
    """Seed the `random`, NumPy and PyTorch global generators with *seed*."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _sample_row(kind: str, idx: int, prompt: str, temperature: float, seed: int) -> dict:
    """Decode one sample and return it as a report row.

    Args:
        kind: Label stored in the row's "kind" column ("random" or "prompted").
        idx: 1-based position of the sample within its group.
        prompt: Text handed to decode() as the prefix; also stored in the row.
        temperature: Sampling temperature for this sample.
        seed: Seed applied to the global RNGs and passed on to decode().

    Returns:
        A dict with the sampling settings and the generated "text", with keys
        in the column order used for the output CSVs.
    """
    # Re-seed before every call so each sample is generated from the same
    # global RNG state regardless of what ran before it.
    _seed_everything(seed)
    text = decode(
        state_dict, tokenizer_spec, prompt,
        max_new_tokens=MAX_NEW_TOKENS, top_p=TOP_P, top_k=TOP_K,
        temperature=float(temperature), seed=int(seed), stop=STOP,
    )
    return {
        "kind": kind, "idx": idx, "prompt": prompt,
        "temperature": temperature, "seed": seed,
        "top_p": TOP_P, "top_k": TOP_K, "max_new_tokens": MAX_NEW_TOKENS,
        "text": text,
    }


# ---------------------------------------------------------------------------
# 6) Generate 4 random samples
# ---------------------------------------------------------------------------
# (temperature, seed) pairs; every sample starts from the same empty prefix.
pairs = [(0.7, 111), (0.9, 222), (0.7, 333), (0.9, 444)]
prefix = _maybe_caps("")
rows = [
    _sample_row("random", i, prefix, t, s)
    for i, (t, s) in enumerate(pairs, 1)
]
df_random = pd.DataFrame(rows)

# ---------------------------------------------------------------------------
# 7) Generate 4 prompted samples
# ---------------------------------------------------------------------------
prompts = [
    "Bhala isiqendu esifutshane sichaza inkqubo yokuvota eMzantsi Afrika.",
    "Qalisa ibali: 'Kwizolo kusasa, ndiphume ndisiya eTaxi Rank...'",
    "Phendula nge-JSON enezitshixo `topic`, `bullets` (3) ngomxholo: 'Ukulungiselela udliwanondlebe lomsebenzi'.",
    "Nika uluhlu lwezixeko ezi-5 eMpuma Koloni ngesiXhosa, uze uchaze esinye ngesiNgesi kwisivakalisi esinye.",
]
prompts = [_maybe_caps(p) for p in prompts]

# All prompted samples share one temperature and one seed.
T, S = 0.7, 333
rows = [
    _sample_row("prompted", i, pr, T, S)
    for i, pr in enumerate(prompts, 1)
]
df_prompt = pd.DataFrame(rows)

# ---------------------------------------------------------------------------
# 8) Save CSVs where the reporting pipeline expects
# ---------------------------------------------------------------------------
# Both sample tables go to data/08_reporting under the project root; the
# directory is created on demand.
out_dir = PROJECT_PATH / "data/08_reporting"
out_dir.mkdir(parents=True, exist_ok=True)
rand_path = out_dir / "_random_samples.csv"
prpt_path = out_dir / "_prompted_samples.csv"
for _frame, _target in ((df_random, rand_path), (df_prompt, prpt_path)):
    _frame.to_csv(_target, index=False)  # no index column in the output

print("✅ Saved:")
for _target in (rand_path, prpt_path):
    print("  -", _target)
