## Fine-tuning with MLX-LM
Step 1.	Run Manifest & Environment
	•	Inputs: desired Python version; tool choice (uv/poetry/pip); seed
	•	Outputs: run_manifest.yaml (python version, platform, chip, mlx-lm version, commit hashes), locked deps (requirements.lock or poetry.lock), RNG seeds set

In [None]:
# STEP 1 — Run Manifest & Environment (Apple Silicon / MLX)
# - Captures exact runtime info (OS, chip, Python, key libs)
# - Locks dependencies via `pip freeze` -> requirements.lock
# - Sets deterministic seeds (random, numpy; PYTHONHASHSEED)
# - Writes manifest to run_manifest.yaml (falls back to JSON if PyYAML missing)

import os, sys, platform, subprocess, json, time, hashlib, shlex
from pathlib import Path

# ---------- Configuration (edit as needed) ----------
from config_loader import load_config
CFG = load_config()  # pulls default.yaml, then local.yaml, then CFG_* env, then (optionally) pass a dict of overrides
#print("JIM",CFG)
OUT_DIR = Path(CFG.run.output_dir)                  # where to write outputs
LOCKFILE = OUT_DIR / "requirements.lock"
MANIFEST_YAML = OUT_DIR / "run_manifest.yaml"
MANIFEST_JSON = OUT_DIR / "run_manifest.json"
SEED = CFG.run.seed
# ----------------------------------------------------

# 1) Set seeds for determinism (Python & NumPy)
# Seeds the stdlib `random` generator and, when NumPy is importable, NumPy's
# legacy global generator with the configured SEED.
import random
random.seed(SEED)
# NOTE(review): PYTHONHASHSEED is read when an interpreter starts, so this
# assignment should only influence child processes launched afterwards, not
# the hashing of the already-running kernel — confirm this is the intent.
os.environ["PYTHONHASHSEED"] = str(SEED)
try:
    import numpy as np
    np.random.seed(SEED)
    numpy_ver = np.__version__
except Exception:
    # NumPy is optional here: record "unknown" rather than failing the cell.
    numpy_ver = None

# 2) Collect environment info
def _safe_import_version(pkg_name):
    """Return the installed version string of distribution ``pkg_name``.

    Best-effort lookup: any failure (distribution not installed,
    ``importlib.metadata`` unavailable, ...) yields ``None`` instead of raising.
    """
    try:
        from importlib import metadata
        version = metadata.version(pkg_name)
    except Exception:
        version = None
    return version

def _which(cmd):
    """Return the path of executable ``cmd`` found on PATH, or ``None``.

    Uses ``shutil.which`` rather than spawning the external ``which`` program,
    so the lookup also works where that binary is not available and costs no
    subprocess. Never raises; any failure is reported as ``None``.
    """
    import shutil  # local import: keeps this helper self-contained

    try:
        return shutil.which(cmd) or None
    except Exception:
        return None

def _run(cmd):
    """Run a command and return ``(returncode, stdout, stderr)``.

    ``cmd`` may be either:
      * a string  -> executed through the shell (unchanged legacy behaviour), or
      * a list/tuple of arguments -> executed directly without a shell, which
        avoids quoting problems (e.g. paths containing spaces).

    stdout/stderr are captured as text and stripped. The function never
    raises: if the process cannot be started, ``(1, "", <error text>)`` is
    returned.
    """
    # Only string commands need the shell; an argv list combined with
    # shell=True would hand everything after the first item to the shell itself.
    use_shell = isinstance(cmd, str)
    try:
        r = subprocess.run(cmd, shell=use_shell, capture_output=True, text=True)
    except Exception as e:
        return (1, "", str(e))
    return (r.returncode, r.stdout.strip(), r.stderr.strip())

py_ver = sys.version.split()[0]
platform_info = {
    "system": platform.system(),
    "release": platform.release(),
    "version": platform.version(),
    "machine": platform.machine(),
    "processor": platform.processor(),
    "python": py_ver,
}

# Apple chip details (best-effort)
chip_brand = None
if platform.system() == "Darwin":
    code, out, _ = _run("sysctl -n machdep.cpu.brand_string")
    chip_brand = out if code == 0 else None
    platform_info["mac_ver"] = platform.mac_ver()[0]
platform_info["chip_brand"] = chip_brand

# 3) Key package versions (MLX-focused)
mlx_lm_ver   = _safe_import_version("mlx-lm")
datasets_ver = _safe_import_version("datasets")
pandas_ver   = _safe_import_version("pandas")
tqdm_ver     = _safe_import_version("tqdm")

# 4) Lock dependencies with pip freeze
LOCKFILE.parent.mkdir(parents=True, exist_ok=True)
code, out, err = _run(f"{shlex.quote(sys.executable)} -m pip freeze")
if code == 0:
    LOCKFILE.write_text(out + "\n", encoding="utf-8")
else:
    print("[warn] pip freeze failed:", err)

# Hash the lock for quick integrity checks
lock_hash = None
if LOCKFILE.exists():
    lock_hash = hashlib.sha256(LOCKFILE.read_bytes()).hexdigest()

# 5) Build manifest object
manifest = {
    "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "seed": SEED,
    "platform": platform_info,
    "packages": {
        "mlx-lm": mlx_lm_ver,
        "datasets": datasets_ver,
        "pandas": pandas_ver,
        "tqdm": tqdm_ver,
        "numpy": numpy_ver,
    },
    "executables": {
        "python": sys.executable,
        "python_which": _which("python"),
        "pip_which": _which("pip"),
    },
    "artifacts": {
        "requirements_lock": str(LOCKFILE.resolve()) if LOCKFILE.exists() else None,
        "requirements_lock_sha256": lock_hash,
    },
    "notes": [
        "This manifest anchors the run. Keep it with any training outputs.",
        "If you change env/deps, regenerate this step to create a new lock."
    ],
}

# 6) Write manifest to YAML (fallback to JSON if PyYAML not installed)
def write_manifest_yaml(obj, path_yaml, path_json_fallback):
    """Write ``obj`` as YAML to ``path_yaml``, falling back to JSON.

    Returns a human-readable description of what was written:
      * ``str(path_yaml)`` when the YAML file was written, or
      * ``"<path_yaml> (<reason>) -> wrote JSON: <path_json_fallback>"`` when
        the JSON fallback was used.

    The fallback is taken both when PyYAML is not installed and when the YAML
    write itself fails; the reason in the message says which one happened
    (previously every failure was reported as "PyYAML missing").
    """
    try:
        import yaml  # type: ignore
    except ImportError:
        reason = "PyYAML missing"
    else:
        try:
            with open(path_yaml, "w", encoding="utf-8") as f:
                yaml.safe_dump(obj, f, sort_keys=False)
            return str(path_yaml)
        except Exception as e:
            # Keep the best-effort behaviour, but report the real cause.
            reason = f"YAML write failed: {e}"

    # Fallback JSON
    with open(path_json_fallback, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)
    return f"{path_yaml} ({reason}) -> wrote JSON: {path_json_fallback}"

out_path = write_manifest_yaml(manifest, MANIFEST_YAML, MANIFEST_JSON)

# 7) Print a compact summary
print("\n=== RUN MANIFEST SUMMARY ===")
print(f"Python:        {py_ver}")
print(f"OS/Chip:       {platform_info['system']} {platform_info.get('mac_ver') or platform_info['release']} | {platform_info.get('chip_brand') or platform_info['machine']}")
print(f"mlx-lm:        {mlx_lm_ver}")
print(f"datasets:      {datasets_ver}")
print(f"pandas:        {pandas_ver}")
print(f"tqdm:          {tqdm_ver}")
print(f"numpy:         {numpy_ver}")
print(f"Seed:          {SEED}")
print(f"Lockfile:      {LOCKFILE}  sha256={lock_hash[:12]+'…' if lock_hash else None}")
print(f"Manifest path: {out_path}")
print("============================\n")

In [None]:
# STEP 2′ — HF Dataset Import (parameterized) + write data_contract.json & data_catalog.json

from config_loader import load_config
cfg = load_config()

print("Dataset:", cfg.data.hf_dataset)
print("Subset:", cfg.data.subset)
print("Mode:", cfg.data.mode)
print("Valid fraction:", cfg.data.valid_fraction)
print("Seed:", cfg.run.seed)

# Use them in preprocessing
HF_DATASET  = cfg.data.hf_dataset
SUBSET      = cfg.data.subset
MODE        = cfg.data.mode
VALID_FRACT = cfg.data.valid_fraction
MIN_WORDS   = cfg.data.min_words
MAX_WORDS   = cfg.data.max_words
SEED        = cfg.run.seed

# Imports come first so this cell does not depend on `Path` having been
# imported by an earlier cell.
from datasets import load_dataset
from pathlib import Path
import json, random, hashlib, time

# Output locations for the materialized splits and their metadata.
out_dir = Path("data"); out_dir.mkdir(exist_ok=True)
CONTRACT    = out_dir / cfg.paths.contract
CATALOG     = out_dir / cfg.paths.catalog

random.seed(SEED)


print(f"Loading {HF_DATASET} subset={SUBSET} …")
ds = load_dataset(HF_DATASET, name=SUBSET, split="train")
print(ds)

def wc(s):
    """Return the number of whitespace-separated tokens in ``str(s)``."""
    tokens = str(s).split()
    return len(tokens)
def sha(s):
    """Return the SHA-256 hex digest of ``str(s)`` encoded as UTF-8.

    Characters that cannot be encoded are dropped (``errors="ignore"``).
    """
    payload = str(s).encode("utf-8", "ignore")
    return hashlib.sha256(payload).hexdigest()

rows = []
for r in ds:
    quote  = (r.get("quote") or "").strip()
    author = (r.get("author") or "").strip()
    if not quote:
        continue

    if MODE == "plain":
        text = quote
    else:
        instr = f"Write a short motivational quote in the style of {author}." if author else "Write a short motivational quote."
        text  = f"Instruction:\n{instr}\n\nResponse:\n{quote}"

    if not (MIN_WORDS <= wc(text) <= MAX_WORDS):
        continue
    rows.append(text)

# dedupe while preserving order
seen=set(); uniq=[]
for t in rows:
    h=sha(t)
    if h not in seen:
        seen.add(h); uniq.append(t)

# split: shuffle in place (seeded above), then carve the validation set off
# the head of the list.
random.shuffle(uniq)
# Aim for at least 100 validation rows, but never take more than half of the
# data: with a small corpus the bare floor of 100 would put every row into
# `valid` and leave `train` empty.
valid_n = min(max(100, int(len(uniq) * VALID_FRACT)), len(uniq) // 2)
valid = uniq[:valid_n]
train = uniq[valid_n:]

def write_jsonl(path: Path, texts):
    """Write ``texts`` to ``path`` as JSON Lines.

    Each item becomes one line of the form ``{"text": <item>}``; non-ASCII
    characters are written as-is (UTF-8). An existing file is overwritten.
    """
    with path.open("w", encoding="utf-8") as sink:
        for item in texts:
            record = json.dumps({"text": item}, ensure_ascii=False)
            sink.write(record + "\n")

train_path = out_dir / "train.jsonl"
valid_path = out_dir / "valid.jsonl"
write_jsonl(train_path, train)
write_jsonl(valid_path, valid)

print(f"Wrote {len(train)} train, {len(valid)} valid to {out_dir.resolve()}")

# --- Write data_contract.json and data_catalog.json ---
def count_lines_bytes(p: Path):
    """Return ``(line_count, size_in_bytes)`` for the file at ``p``.

    Lines are counted by iterating the file in binary mode; the size comes
    from the file's stat record.
    """
    with p.open("rb") as fh:
        line_total = sum(1 for _ in fh)
    return line_total, p.stat().st_size

def sha256_file(p: Path) -> str:
    """Return the SHA-256 hex digest of the file at ``p``.

    The file is read in 1 MiB chunks so large files are never held in memory
    all at once.
    """
    digest = hashlib.sha256()
    with p.open("rb") as fh:
        while True:
            block = fh.read(1024 * 1024)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()

created = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

# Contract (simple schema with detected string field = "text")
data_contract = {
    "created_utc": created,
    "data_dir": str(out_dir.resolve()),
    "filenames": {
        "train": {"chosen": train_path.name, "resolved": str(train_path.resolve())},
        "valid": {"chosen": valid_path.name, "resolved": str(valid_path.resolve())},
    },
    "schema": {"format": "jsonl", "fields": {"text": "string"}},
}
CONTRACT.write_text(json.dumps(data_contract, indent=2), encoding="utf-8")

# Catalog (write BOTH legacy 'entries' and simple 'files' views)
t_lines, t_bytes = count_lines_bytes(train_path)
v_lines, v_bytes = count_lines_bytes(valid_path)
t_sha = sha256_file(train_path)
v_sha = sha256_file(valid_path)

data_catalog = {
    "created_utc": created,
    "files": {
        "train": {"path": str(train_path.resolve()), "lines": t_lines, "bytes": t_bytes, "sha256": t_sha},
        "valid": {"path": str(valid_path.resolve()), "lines": v_lines, "bytes": v_bytes, "sha256": v_sha},
    },
    "entries": {
        "train": {"path": str(train_path.resolve()), "stats": {
            "num_valid_examples": t_lines, "num_bytes": t_bytes, "sha256": t_sha}},
        "valid": {"path": str(valid_path.resolve()), "stats": {
            "num_valid_examples": v_lines, "num_bytes": v_bytes, "sha256": v_sha}},
    },
}
CATALOG.write_text(json.dumps(data_catalog, indent=2), encoding="utf-8")

print("Wrote data_contract.json and data_catalog.json")

Step 3.	Data Validation & Stats
	•	Inputs: train/valid JSONL, contract
	•	Outputs: data_report.json (line counts, empty/dup checks, length histograms, charset, special-token frequencies), a few sampled rows

In [None]:
# STEP 3 — Data Validation & Stats
# Inputs:
#   - data_contract.json (from Step 2)
# Outputs:
#   - data_report.json   (per-split detailed stats & issues)
# Console:
#   - compact summary (counts, dupes, whitespace/control issues, length percentiles)
# `from __future__` imports must be the first statement of the cell/module;
# placing them after other statements raises a SyntaxError.
from __future__ import annotations
import json, re, unicodedata, statistics, hashlib
from pathlib import Path
from typing import Dict, Any, List, Tuple

from config_loader import load_config
cfg = load_config()

out_dir = Path("data"); out_dir.mkdir(exist_ok=True)
CONTRACT    = out_dir / cfg.paths.contract
REPORT      = out_dir / cfg.paths.report

# Heuristics: potential stop/EOS markers to scan for
EOS_MARKERS = [
    "</s>",         # common HF eos
    "###",          # section break in some templates
    "\n\n",         # blank-line stop
    "<|eot_id|>",   # chat-style separators
    "<|endoftext|>" # GPT-like
]

def load_contract(path: Path) -> Tuple[str, Dict[str, str], str]:
    """Read the data contract JSON and return ``(text_field, files, data_dir)``.

    * ``text_field`` — name of the first schema field whose declared type is
      "string" (case-insensitive); falls back to ``"text"`` when none is found.
    * ``files`` — mapping of split name to its resolved path, for every split
      that has a non-empty ``"resolved"`` entry.
    * ``data_dir`` — the contract's ``"data_dir"`` value, returned unchanged.
    """
    contract = json.loads(path.read_text(encoding="utf-8"))
    data_dir = contract["data_dir"]

    schema_fields = contract.get("schema", {}).get("fields", {})
    string_names = (
        name for name, kind in schema_fields.items() if str(kind).lower() == "string"
    )
    text_field = next(string_names, None)
    if not text_field:
        text_field = "text"  # fallback

    files = {}
    for split, info in contract["filenames"].items():
        if info.get("resolved"):
            files[split] = info["resolved"]
    return text_field, files, data_dir

def hash_text(s: str) -> str:
    """Return the SHA-256 hex digest of ``s`` (UTF-8, unencodable chars dropped)."""
    encoded = s.encode("utf-8", "ignore")
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()

def char_classes(s: str) -> Dict[str, int]:
    """Count control, whitespace and non-ASCII characters in ``s``.

    * ``control``    — characters in Unicode category Cc or Cf
    * ``whitespace`` — characters for which ``str.isspace()`` is true
    * ``non_ascii``  — characters with a code point above 127

    A character may contribute to more than one counter (a tab is both
    control and whitespace).
    """
    counts = {"control": 0, "whitespace": 0, "non_ascii": 0}
    for ch in s:
        if unicodedata.category(ch) in ("Cc", "Cf"):
            counts["control"] += 1
        if ch.isspace():
            counts["whitespace"] += 1
        if ord(ch) > 127:
            counts["non_ascii"] += 1
    return counts

def percentiles(values: List[int], q=(5, 25, 50, 75, 95)) -> Dict[str, int]:
    """Return nearest-rank style percentiles of ``values`` keyed ``"p<q>"``.

    For each percentile ``p`` the element at index ``round(p/100 * (n-1))`` of
    the sorted values is reported (index clamped to the valid range). An empty
    input yields 0 for every requested percentile.
    """
    if not values:
        return {f"p{p}": 0 for p in q}

    ordered = sorted(values)
    last = len(ordered) - 1
    result = {}
    for p in q:
        idx = int(round((p / 100) * last))
        idx = min(last, max(0, idx))
        result[f"p{p}"] = int(ordered[idx])
    return result

def scan_file(path: Path, field: str) -> Dict[str, Any]:
    """Scan one JSONL split and collect validation statistics for ``field``.

    Every line is parsed as JSON. A line is rejected (counted under
    ``errors`` and excluded from all other statistics) when it

      * is not valid JSON                          -> ``bad_json``
      * is not a JSON object, or lacks ``field``   -> ``missing_field``
      * has a non-string value for ``field``       -> ``non_string_field``

    Previously a line holding valid JSON that was not an object (a list,
    number, string or null) raised ``TypeError``; it is now counted as
    ``missing_field``.

    Returns a dict with the keys ``path``, ``lines``, ``valid_examples``,
    ``errors``, ``empties``, ``control_char_lines``, ``duplicates``,
    ``length_chars``, ``eos_markers_hits`` and ``samples``.
    """
    from collections import Counter

    n_lines = 0
    bad_json = 0
    missing_field = 0
    non_str = 0
    empty = 0
    whitespace_only = 0
    leading_ws = 0
    trailing_ws = 0
    ctrl_lines = 0

    lengths = []
    hashes = []
    eos_hits = {m: 0 for m in EOS_MARKERS}

    samples_good: List[str] = []
    samples_bad: List[str] = []

    def remember_bad(tag, snippet):
        # Keep at most three rejected examples, truncated to 160 characters.
        if len(samples_bad) < 3:
            samples_bad.append(f"[{tag}] {snippet[:160]}")

    with path.open("r", encoding="utf-8", errors="replace") as f:
        for line in f:
            n_lines += 1
            line = line.rstrip("\n")
            try:
                obj = json.loads(line)
            except Exception:
                bad_json += 1
                remember_bad("bad_json", line)
                continue

            # Valid JSON is not necessarily an object: only a dict can carry
            # the text field, anything else is treated as "field missing".
            if not isinstance(obj, dict) or field not in obj:
                missing_field += 1
                remember_bad("missing_field", line)
                continue
            val = obj[field]
            if not isinstance(val, str):
                non_str += 1
                remember_bad("non_string", str(val))
                continue

            # Note: an empty string is counted as both "empty" and
            # "whitespace_only" (its stripped form is empty as well).
            if val == "":
                empty += 1
            if val.strip() == "":
                whitespace_only += 1
            if val and val[0].isspace():
                leading_ws += 1
            if val and val[-1].isspace():
                trailing_ws += 1

            cc = char_classes(val)
            if cc["control"] > 0:
                ctrl_lines += 1

            lengths.append(len(val))
            hashes.append(hash_text(val))
            for m in EOS_MARKERS:
                if m in val:
                    eos_hits[m] += 1

            if len(samples_good) < 3:
                samples_good.append(val)

    # duplicates: every occurrence beyond the first of an identical text
    dup_count = 0
    dup_examples = []
    for h, cnt in Counter(hashes).items():
        if cnt > 1:
            dup_count += cnt - 1
            if len(dup_examples) < 3:
                dup_examples.append(h)

    # length stats (in characters); all zero when no valid example was found
    length_stats = {
        "count": len(lengths),
        "min": int(min(lengths)) if lengths else 0,
        "max": int(max(lengths)) if lengths else 0,
        "mean": float(statistics.mean(lengths)) if lengths else 0.0,
        "median": float(statistics.median(lengths)) if lengths else 0.0,
        "percentiles": percentiles(lengths),
    }

    return {
        "path": str(path),
        "lines": n_lines,
        "valid_examples": len(lengths),
        "errors": {
            "bad_json": bad_json,
            "missing_field": missing_field,
            "non_string_field": non_str,
        },
        "empties": {
            "empty_exact": empty,
            "whitespace_only": whitespace_only,
            "leading_whitespace": leading_ws,
            "trailing_whitespace": trailing_ws,
        },
        "control_char_lines": ctrl_lines,
        "duplicates": {
            "duplicate_example_count": dup_count,
            "sha256_examples": dup_examples,
        },
        "length_chars": length_stats,
        "eos_markers_hits": eos_hits,
        "samples": {
            "good_first3": samples_good,
            "bad_first3": samples_bad,
        },
    }

# Load contract and validate
import time

text_field, files, data_dir = load_contract(CONTRACT)
report: Dict[str, Any] = {
    "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "data_dir": data_dir,
    "text_field": text_field,
    "splits": {},
}

# One detailed report per split listed in the contract.
for split, p in files.items():
    rep = scan_file(Path(p), text_field)
    report["splits"][split] = rep

# Write full report
REPORT.write_text(json.dumps(report, indent=2), encoding="utf-8")

# Console summary
print("=== DATA VALIDATION SUMMARY ===")
for split, rep in report["splits"].items():
    errs = rep["errors"]
    empt = rep["empties"]
    lens = rep["length_chars"]
    eos  = rep["eos_markers_hits"]
    dup  = rep["duplicates"]["duplicate_example_count"]
    # Markers are shown with repr() so that e.g. the blank-line marker "\n\n"
    # does not break the one-line summary; only markers with hits are listed.
    eos_text = ", ".join(f"{k!r}:{v}" for k, v in eos.items() if v)
    print(f"- {split}: lines={rep['lines']} valid={rep['valid_examples']} "
          f"errors(bad/miss/nonstr)={errs['bad_json']}/{errs['missing_field']}/{errs['non_string_field']} "
          f"empties(exact/ws/lead/trail)={empt['empty_exact']}/{empt['whitespace_only']}/{empt['leading_whitespace']}/{empt['trailing_whitespace']} "
          f"dupes={dup} len[min/med/95/max]={lens['min']}/{int(lens['median'])}/{lens['percentiles']['p95']}/{lens['max']} "
          f"eos_hits={{{eos_text}}}")
print("Wrote:", REPORT)

Step 4.	Formatting Policy (Prompt Template)
	•	Inputs: selected template name (e.g., “plain_i/o”, “llama3_style”); tokenization preview
	•	Outputs: deterministic formatter function spec; example before/after table; stored as format_policy.json

In [None]:
# STEP 4 — Formatting Policy (Prompt Template)
# Goal:
#   - Pick a formatting policy for how prompts/responses should look during training & generation.
#   - DOES NOT MODIFY your existing JSONL files.
#   - Writes `format_policy.json` describing the chosen template + parameters.
#
# Inputs:
#   - data_contract.json  (from Step 2)
# Outputs:
#   - format_policy.json  (template choice & settings)
# Console:
#   - Before/After preview for a few examples

from __future__ import annotations
import json, textwrap
from pathlib import Path
from typing import Dict, Any, List
from config_loader import load_config
cfg = load_config()

out_dir = Path("data"); out_dir.mkdir(exist_ok=True)
CONTRACT    = out_dir / cfg.paths.contract
POLICY      = out_dir / cfg.paths.policy

# ------------------------
# 1) Choose a template
# ------------------------
# Options:
#   "plain_text_passthrough" : use the JSONL "text" as-is (good when your data already contains instruction/response)
#   "icl_minimal"            : a simple Q/A style wrapper (no special chat tokens)
#   "llama3_style"           : a friendly chat-like wrapper (ASCII tags only)
TEMPLATE_NAME = "plain_text_passthrough"

# Optional stop strings you intend to use during generation probes later.
# (These are just recorded here; not enforced yet.)
STOP_STRINGS = ["\n\n"]   # common “blank line” stop
USE_EOS_TOKEN = True      # whether to set eos_token_id in “default” runs later

# ------------------------
# 2) Load contract & sample a few rows for preview
# ------------------------
def load_contract(path: Path):
    """Read the data contract and return ``(data_dir, files, text_field)``.

    * ``data_dir`` — the contract's ``"data_dir"`` as a ``Path``.
    * ``files`` — split name -> resolved path, for splits with a non-empty
      ``"resolved"`` entry.
    * ``text_field`` — first schema field declared as "string"
      (case-insensitive), or ``"text"`` when the schema names none.
    """
    contract = json.loads(path.read_text(encoding="utf-8"))
    data_dir = Path(contract["data_dir"])

    files = {}
    for split, info in contract["filenames"].items():
        if info.get("resolved"):
            files[split] = info["resolved"]

    # detect the text field name (from schema) with fallback
    text_field = "text"
    for name, kind in contract.get("schema", {}).get("fields", {}).items():
        if str(kind).lower() == "string":
            text_field = name
            break
    return data_dir, files, text_field

data_dir, files, TEXT_FIELD = load_contract(CONTRACT)
train_path = Path(files["train"])

def read_first_n_texts(p: Path, n: int = 3, field: str = "text") -> List[str]:
    """Return up to ``n`` string values of ``field`` from the JSONL file ``p``.

    Lines are read in order. A line is skipped when it is not valid JSON, is
    not a JSON object, or its ``field`` value is missing or not a string.
    (Non-object lines such as ``[1, 2]`` used to raise ``AttributeError``.)
    """
    out: List[str] = []
    with p.open("r", encoding="utf-8") as f:
        for line in f:
            if len(out) >= n:
                break
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if not isinstance(obj, dict):
                continue
            val = obj.get(field)
            if isinstance(val, str):
                out.append(val)
    return out

SAMPLES = read_first_n_texts(train_path, n=3, field=TEXT_FIELD)

# ------------------------
# 3) Define template functions (no mutation)
# ------------------------
def fmt_plain(text: str) -> str:
    """Passthrough formatter: return ``text`` exactly as stored."""
    return text

def fmt_icl_minimal(text: str) -> str:
    """Wrap ``text`` as the response of a fixed instruction/response block.

    The text is stripped of surrounding whitespace and placed under a
    ``### Response`` header, preceded by a constant ``### Instruction`` section.
    """
    header = "### Instruction\nShare an important thought.\n\n### Response\n"
    body = text.strip()
    return header + body

def fmt_llama3_style(text: str) -> str:
    """Wrap ``text`` in a chat-like template built from plain ASCII tags.

    The stripped text is placed inside ``[RESPONSE]...[/RESPONSE]`` after a
    fixed ``[INSTRUCTION]`` section, all enclosed in ``<s>...</s>``. No
    tokenizer-specific special tokens are used.

    The instruction sentence was previously truncated ("Share an ."); it now
    matches the Step 5 formatter of the same name.
    """
    return (
        "<s>[INSTRUCTION]\n"
        "Share an important thought.\n"
        "[/INSTRUCTION]\n"
        "[RESPONSE]\n" + text.strip() + "\n[/RESPONSE]</s>"
    )

FORMATTERS = {
    "plain_text_passthrough": fmt_plain,
    "icl_minimal": fmt_icl_minimal,
    "llama3_style": fmt_llama3_style,
}

if TEMPLATE_NAME not in FORMATTERS:
    raise SystemExit(f"Unknown TEMPLATE_NAME: {TEMPLATE_NAME}")

formatter = FORMATTERS[TEMPLATE_NAME]

# ------------------------
# 4) Preview: before/after for a few rows
# ------------------------
print("=== FORMAT PREVIEW ===")
print(f"Template: {TEMPLATE_NAME}")
for i, txt in enumerate(SAMPLES, 1):
    # Newlines are shown as a visible " \n " marker and each preview is
    # shortened to 220 characters so every example stays on one line.
    print(f"\n--- Example {i}: BEFORE ---")
    print(textwrap.shorten(txt.replace("\n"," \\n "), width=220, placeholder="…"))
    # f-prefix added: without it the literal text "{i}" was printed.
    print(f"--- Example {i}: AFTER  ---")
    print(textwrap.shorten(formatter(txt).replace("\n"," \\n "), width=220, placeholder="…"))

# ------------------------
# 5) Persist policy (for downstream steps)
# ------------------------
policy: Dict[str, Any] = {
    "template_name": TEMPLATE_NAME,
    "text_field": TEXT_FIELD,
    "stop_strings": STOP_STRINGS,
    "use_eos_token": USE_EOS_TOKEN,
    "notes": [
        "This policy describes how to *format* examples when generating or when materializing new data.",
        "Your current JSONL will not be changed by this step.",
        "Downstream steps can choose to apply this formatter or keep passthrough depending on the experiment."
    ],
}

# Keep a tiny deterministic sample of BEFORE/AFTER in the policy for traceability
policy["preview"] = [
    {"before": SAMPLES[i], "after": formatter(SAMPLES[i])} for i in range(min(2, len(SAMPLES)))
]

POLICY.write_text(json.dumps(policy, indent=2), encoding="utf-8")
print(f"\nWrote {POLICY}")

Step 5.	Train/Valid Materialization (optional)
	•	Inputs: raw aphorism text or other source (skip if already JSONL)
	•	Outputs: train.jsonl, valid.jsonl that conform to the contract; lineage recorded in data_lineage.json

In [None]:
# STEP 5 — Train/Valid Materialization (optional)
# Behavior:
#   - If JSONL files already exist as per data_contract.json => NO-OP (prints summary).
#   - Else: build them from RAW_INPUT_PATH (one example per line), apply Step-4 format policy,
#           shuffle/split, and write train/valid according to the contract's chosen filenames.
#
# Inputs:
#   - data_contract.json   (from Step 2)
#   - format_policy.json   (from Step 4)
#   - RAW_INPUT_PATH       (used ONLY if JSONLs missing)
# Outputs:
#   - train.jsonl + valid/val.jsonl (only when materialization is needed)

from __future__ import annotations
import json, random, unicodedata
from pathlib import Path
from typing import List, Dict, Any

from config_loader import load_config
cfg = load_config()

out_dir = Path("data"); out_dir.mkdir(exist_ok=True)
CONTRACT    = out_dir / cfg.paths.contract
POLICY      = out_dir / cfg.paths.policy

# --------- Config (edit if you need to materialize) ---------
RAW_INPUT_PATH = Path("./input.txt")  # used ONLY if JSONL missing
MIN_LINE_WORDS = cfg.data.min_words                       # filter out very short lines
VAL_RATIO      = 0.10                    # 10% validation split
SHUFFLE_SEED   = cfg.run.seed
OVERWRITE      = False                   # set True to force rebuild even if files exist
NORMALIZE_NFC  = True                    # Unicode normalize to NFC
# ------------------------------------------------------------

def load_contract_and_policy(contract_p: Path, policy_p: Path):
    """Load the data contract and format policy JSON files.

    Returns ``(data_dir, files, chosen, text_field, policy)`` where

    * ``data_dir`` is the contract's ``"data_dir"`` as a ``Path``,
    * ``files`` maps split -> resolved path (non-empty ``"resolved"`` only),
    * ``chosen`` maps split -> chosen filename (non-empty ``"chosen"`` only),
    * ``text_field`` is the first schema field declared "string"
      (case-insensitive), defaulting to ``"text"``,
    * ``policy`` is the parsed policy document.
    """
    contract = json.loads(contract_p.read_text(encoding="utf-8"))
    policy   = json.loads(policy_p.read_text(encoding="utf-8"))
    data_dir = Path(contract["data_dir"])

    files = {}
    for split, info in contract["filenames"].items():
        if info.get("resolved"):
            files[split] = info["resolved"]
    chosen = {}
    for split, info in contract["filenames"].items():
        if info.get("chosen"):
            chosen[split] = info["chosen"]

    text_field = "text"
    for name, kind in contract.get("schema", {}).get("fields", {}).items():
        if str(kind).lower() == "string":
            text_field = name
            break
    return data_dir, files, chosen, text_field, policy

def exists_and_nonempty(p: Path) -> bool:
    """Return True when ``p`` exists and its size is greater than zero bytes."""
    if not p.exists():
        return False
    return p.stat().st_size > 0

def read_lines_raw(p: Path, min_words: int) -> List[str]:
    """Read usable raw examples (one per line) from the text file ``p``.

    Each line is stripped; blank lines and lines with fewer than ``min_words``
    whitespace-separated words are dropped. Kept lines are NFC-normalized when
    the module-level ``NORMALIZE_NFC`` flag is set.

    Raises:
        FileNotFoundError: ``p`` does not exist.
        SystemExit: no line survived the filters.
    """
    if not p.exists():
        raise FileNotFoundError(f"RAW_INPUT_PATH not found: {p}")

    kept: List[str] = []
    with p.open("r", encoding="utf-8", errors="replace") as src:
        for raw in src:
            candidate = raw.strip()
            if not candidate or len(candidate.split()) < min_words:
                continue
            if NORMALIZE_NFC:
                candidate = unicodedata.normalize("NFC", candidate)
            kept.append(candidate)

    if not kept:
        raise SystemExit(f"No usable lines found in {p}.")
    return kept

# simple formatters mirroring Step 4 names (we’ll only use the selected one)
def fmt_plain(text: str) -> str:
    """Passthrough formatter: hand ``text`` back unchanged."""
    return text

def fmt_icl_minimal(text: str) -> str:
    """Render ``text`` (stripped) as the response of a fixed instruction block."""
    parts = (
        "### Instruction\n",
        "Share an important thought.\n\n",
        "### Response\n",
        text.strip(),
    )
    return "".join(parts)

def fmt_llama3_style(text: str) -> str:
    """Render ``text`` (stripped) inside an ASCII-tag, chat-like template."""
    parts = (
        "<s>[INSTRUCTION]\n",
        "Share an important thought.\n",
        "[/INSTRUCTION]\n",
        "[RESPONSE]\n",
        text.strip(),
        "\n[/RESPONSE]</s>",
    )
    return "".join(parts)

FORMATTERS = {
    "plain_text_passthrough": fmt_plain,
    "icl_minimal": fmt_icl_minimal,
    "llama3_style": fmt_llama3_style,
}

def write_jsonl_text(items: List[str], out_path: Path, field: str = "text"):
    """Write ``items`` to ``out_path`` as JSON Lines, one ``{field: item}`` per line.

    Missing parent directories are created; an existing file is overwritten.
    Non-ASCII characters are written unescaped (UTF-8).
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as sink:
        for item in items:
            record = json.dumps({field: item}, ensure_ascii=False)
            sink.write(record + "\n")

# ---- Orchestration ----
data_dir, files_resolved, chosen_names, TEXT_FIELD, policy = load_contract_and_policy(CONTRACT, POLICY)
train_target = Path(files_resolved.get("train")) if "train" in files_resolved else data_dir / chosen_names.get("train","train.jsonl")
valid_key = "valid" if "valid" in files_resolved or "valid" in chosen_names else ("val" if "val" in files_resolved or "val" in chosen_names else "valid")
valid_target = Path(files_resolved.get(valid_key)) if valid_key in files_resolved else data_dir / chosen_names.get(valid_key, f"{valid_key}.jsonl")

already_present = exists_and_nonempty(train_target) and exists_and_nonempty(valid_target)

if already_present and not OVERWRITE:
    print("✅ JSONL already present. No materialization performed.")
    print(" - train:", train_target)
    print(f" - {valid_key}:", valid_target)
else:
    tmpl = policy.get("template_name", "plain_text_passthrough")
    if tmpl not in FORMATTERS:
        raise SystemExit(f"Unknown template_name in format_policy.json: {tmpl}")
    formatter = FORMATTERS[tmpl]

    raw_lines = read_lines_raw(RAW_INPUT_PATH, MIN_LINE_WORDS)
    # apply formatter
    examples = [formatter(x) for x in raw_lines]

    # shuffle & split
    random.seed(SHUFFLE_SEED)
    random.shuffle(examples)
    n = len(examples)
    n_val = max(1, int(round(VAL_RATIO * n)))
    val_items = examples[:n_val]
    train_items = examples[n_val:]

    # write
    write_jsonl_text(train_items, train_target, TEXT_FIELD)
    write_jsonl_text(val_items,   valid_target, TEXT_FIELD)

    print("📝 Materialized JSONL files:")
    print(" - train ->", train_target, f"(rows={len(train_items)})")
    print(f" - {valid_key} ->", valid_target, f"(rows={len(val_items)})")
    print("Template used:", tmpl)

In [None]:
# STEP 5.5 (REVISED) — Write catalog in BOTH schemas ("entries" + "files")
# Keeps Step 6 unchanged by emitting the legacy structure it expects.

from __future__ import annotations
import json, hashlib, time
from pathlib import Path

from config_loader import load_config
cfg = load_config()

out_dir = Path("data"); out_dir.mkdir(exist_ok=True)
CONTRACT    = out_dir / cfg.paths.contract
POLICY      = out_dir / cfg.paths.policy
# This cell writes the catalog further down, so define its path here instead
# of relying on a CATALOG variable left behind by an earlier cell.
CATALOG     = out_dir / cfg.paths.catalog

TRAIN = out_dir / "train.jsonl"
VALID = out_dir / "valid.jsonl"

def sha256_file(p: Path) -> str:
    """Return the SHA-256 hex digest of the file at ``p`` (read in 1 MiB chunks)."""
    hasher = hashlib.sha256()
    chunk_size = 1024 * 1024
    with p.open("rb") as fh:
        piece = fh.read(chunk_size)
        while piece != b"":
            hasher.update(piece)
            piece = fh.read(chunk_size)
    return hasher.hexdigest()

def count_lines_bytes(p: Path):
    """Return ``(number_of_lines, size_in_bytes)`` for the file at ``p``."""
    line_total = 0
    with p.open("rb") as fh:
        for line_total, _ in enumerate(fh, start=1):
            pass
    size = p.stat().st_size
    return line_total, size

def sniff_string_field(p: Path) -> str:
    """Guess the name of the text field in the JSONL file ``p``.

    Returns the first key with a string value found in the first line that
    parses as a JSON object containing such a key. Blank, unparsable or
    non-object lines are skipped. Falls back to ``"text"`` when nothing
    suitable is found.
    """
    with p.open("r", encoding="utf-8") as fh:
        for raw in fh:
            candidate = raw.strip()
            if not candidate:
                continue
            try:
                record = json.loads(candidate)
                string_keys = [k for k, v in record.items() if isinstance(v, str)]
            except Exception:
                # Not JSON, or JSON that is not an object: try the next line.
                continue
            if string_keys:
                return string_keys[0]
    return "text"

# sanity
if not TRAIN.exists() or not VALID.exists():
    raise SystemExit("STEP 5.5: Expected ./data/train.jsonl and ./data/valid.jsonl.")

# contract (unchanged behavior)
created = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
string_field = sniff_string_field(TRAIN)
contract = {
    "created_utc": created,
    "data_dir": str(out_dir.resolve()),
    "filenames": {
        "train": {"chosen": TRAIN.name, "resolved": str(TRAIN.resolve())},
        "valid": {"chosen": VALID.name, "resolved": str(VALID.resolve())},
    },
    "schema": {"format": "jsonl", "fields": {string_field: "string"}},
}
CONTRACT.write_text(json.dumps(contract, indent=2), encoding="utf-8")

# counts + hashes
t_lines, t_bytes = count_lines_bytes(TRAIN)
v_lines, v_bytes = count_lines_bytes(VALID)
t_sha = sha256_file(TRAIN)
v_sha = sha256_file(VALID)

# catalog — write BOTH formats:
catalog = {
    # Newer, simpler view:
    "files": {
        "train": {"path": str(TRAIN.resolve()), "lines": t_lines, "bytes": t_bytes, "sha256": t_sha},
        "valid": {"path": str(VALID.resolve()), "lines": v_lines, "bytes": v_bytes, "sha256": v_sha},
    },
    # Legacy view that Step 6 expects:
    "entries": {
        "train": {
            "path": str(TRAIN.resolve()),
            "stats": {
                "num_valid_examples": t_lines,
                "num_bytes": t_bytes,
                "sha256": t_sha,
            },
        },
        "valid": {
            "path": str(VALID.resolve()),
            "stats": {
                "num_valid_examples": v_lines,
                "num_bytes": v_bytes,
                "sha256": v_sha,
            },
        },
    },
    "created_utc": created,
}

CATALOG.write_text(json.dumps(catalog, indent=2), encoding="utf-8")

print("=== STEP 5.5 (revised) ===")
print(f"train: lines={t_lines} bytes={t_bytes:,}")
print(f"valid: lines={v_lines} bytes={v_bytes:,}")
print("Wrote:", CONTRACT.name, "and", CATALOG.name)
print("Next: run Step 6 as-is.")

In [None]:
# STEP 6 — Experiment Matrix
# Purpose:
#   - Define models + core hyperparams ONCE.
#   - Resolve dataset sizes from Step 2/3 outputs.
#   - Estimate MLX `--iters` (since mlx_lm.lora is iteration-based).
#   - Emit a clean experiments.csv (one row per model).
#
# Inputs:
#   - data_contract.json (Step 2)
#   - data_catalog.json  (Step 2)  [preferred for counts]
#   - data_report.json   (Step 3)  [fallback if catalog missing]
#
# Outputs:
#   - experiments.csv

from __future__ import annotations
import json, math, csv, time
from pathlib import Path
from typing import Dict, Any, Tuple, List

from config_loader import load_config
cfg = load_config()

out_dir = Path("data"); out_dir.mkdir(exist_ok=True)
CONTRACT    = out_dir / cfg.paths.contract
POLICY      = out_dir / cfg.paths.policy
REPORT      = out_dir / cfg.paths.report
# get_counts_from_catalog() reads CATALOG; define it here so this cell does
# not depend on a variable left behind by an earlier cell.
CATALOG     = out_dir / cfg.paths.catalog

RUN_DIR       = Path(cfg.run.output_dir)  # where per-model outputs will go
EXPERIMENTS_CSV = RUN_DIR / "experiments.csv"

# ---------- EDITABLE BLOCK ----------
# List your MLX-compatible base models here
BASE_MODELS = [
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    # "mlx-community/phi-2",
    # add more as needed
]

# Core hyperparameters (shared across all rows; you can later copy/edit specific rows)
EPOCHS          = 1               # convenient, we’ll convert to iters
BATCH_SIZE      = 1
GRAD_ACCUM      = 8
MAX_SEQ_LENGTH  = 512
LEARNING_RATE   = 2e-4
BF16            = True
# Optional: override `iters` directly (0 = auto from dataset & epochs)
ITERS_OVERRIDE  = 0
# -----------------------------------

def load_contract() -> Dict[str, Any]:
    """Read the data contract JSON at CONTRACT and return it as a dict."""
    raw_text = CONTRACT.read_text(encoding="utf-8")
    return json.loads(raw_text)

def get_counts_from_catalog() -> Tuple[int, int]:
    """Return (train, valid) example counts read from the catalog file.

    Returns ``(None, None)`` when CATALOG does not exist, which callers use as
    the signal to fall back to the report. The validation split may be stored
    under either ``"valid"`` or ``"val"``; when neither is present (or the
    entry is empty) the validation count is 0.
    """
    if not CATALOG.exists():
        return None, None  # signal fallback
    entries = json.loads(CATALOG.read_text(encoding="utf-8"))["entries"]
    n_train = entries["train"]["stats"]["num_valid_examples"]
    val_entry = entries.get("valid") or entries.get("val")
    if val_entry:
        n_valid = val_entry["stats"]["num_valid_examples"]
    else:
        n_valid = 0
    return int(n_train), int(n_valid)

def get_counts_from_report() -> Tuple[int, int]:
    """Return (train, valid) example counts read from the report file.

    The validation split is looked up under ``"valid"`` and then ``"val"``;
    if neither yields an entry the validation count is 0.
    """
    splits = json.loads(REPORT.read_text(encoding="utf-8"))["splits"]
    n_train = splits["train"]["valid_examples"]
    val_entry = splits.get("valid") or splits.get("val")
    if val_entry:
        n_valid = val_entry["valid_examples"]
    else:
        n_valid = 0
    return int(n_train), int(n_valid)

def resolve_files_from_contract(ct: Dict[str, Any]) -> Dict[str, str]:
    """Map split name -> resolved file path for every split that has one.

    Splits whose ``"resolved"`` value is missing or falsy are dropped. The
    validation split is additionally exposed under the key ``"validation"``,
    taken from ``"valid"`` if present, otherwise from ``"val"``.
    """
    files: Dict[str, str] = {}
    for split, info in ct["filenames"].items():
        resolved = info.get("resolved")
        if resolved:
            files[split] = resolved
    # Normalize the validation key; "valid" wins over "val" when both exist.
    for alias in ("valid", "val"):
        if alias in files:
            files["validation"] = files[alias]
            break
    return files

def estimate_iters(num_train: int, epochs: int, batch: int, accum: int) -> int:
    """Approximate the iteration count for a run.

    Computes ``ceil(epochs * num_train / (batch * accum))`` with both the
    example count and the divisor clamped to at least 1, then applies a floor
    of 100 so that very small datasets still get some training.
    """
    examples = max(1, num_train)
    examples_per_step = max(1, batch * accum)
    steps = math.ceil((epochs * examples) / examples_per_step)
    return max(100, 1, steps)

# 1) Read split file names from the contract, and example counts from the
#    catalog (falling back to the report when the catalog signals "missing").
ct = load_contract()
files = resolve_files_from_contract(ct)

train_count, valid_count = get_counts_from_catalog()
if train_count is None:
    train_count, valid_count = get_counts_from_report()

data_dir = Path(ct["data_dir"])

# 2) Build one experiment row per base model; all rows share one timestamp.
timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
rows: List[Dict[str, Any]] = []

for model_id in BASE_MODELS:
    # Slashes in the model id would create nested directories, so flatten them.
    out_root = RUN_DIR / model_id.replace("/", "--")

    # An explicit override wins; otherwise derive iters from dataset size.
    iters = ITERS_OVERRIDE or estimate_iters(
        num_train=train_count,
        epochs=EPOCHS,
        batch=BATCH_SIZE,
        accum=GRAD_ACCUM,
    )

    row: Dict[str, Any] = {
        "created_utc": timestamp,
        "model_id": model_id,
        "data_dir": str(data_dir),
        "train_file": files.get("train"),
        "valid_file": files.get("validation"),
        "train_examples": train_count,
        "valid_examples": valid_count,
        "epochs": EPOCHS,
        "iters": iters,
        "batch_size": BATCH_SIZE,
        "grad_accum": GRAD_ACCUM,
        "max_seq_length": MAX_SEQ_LENGTH,
        "learning_rate": LEARNING_RATE,
        "bf16": int(bool(BF16)),
        "adapter_path": str(out_root / "adapter"),
        "log_dir": str(out_root / "logs"),
        # Very rough upper bound: every step is assumed to fill max_seq_length.
        "est_tokens": MAX_SEQ_LENGTH * BATCH_SIZE * GRAD_ACCUM * iters,
    }
    rows.append(row)

# 3) Write experiments.csv (header taken from the first row; empty if no models).
EXPERIMENTS_CSV.parent.mkdir(parents=True, exist_ok=True)
fieldnames = list(rows[0].keys()) if rows else []
with EXPERIMENTS_CSV.open("w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)

# 4) Console summary
print("=== EXPERIMENT MATRIX ===")
print(f"Data dir: {data_dir}")
print(f"Counts: train={train_count} valid={valid_count}")
print(f"Wrote: {EXPERIMENTS_CSV}\n")
for r in rows:
    print(f"- {r['model_id']}")
    print(f"   iters={r['iters']}  bs={r['batch_size']}  accum={r['grad_accum']}  max_len={r['max_seq_length']}  lr={r['learning_rate']}  bf16={r['bf16']}")
    print(f"   est_tokens≈{r['est_tokens']:,}  adapter={r['adapter_path']}")

In [None]:
# STEP 7 (PATCH) — MLX LoRA command fix for 0.26.x
# - Invoke LoRA training in module form, `python -m mlx_lm.lora` (this is the command build_cmd emits below)
# - Remove unsupported flags: --gradient-accumulation, --log-dir, --bf16
# - Keep your computed `iters`, batch size, lr, max-seq-length
# - Optional: add reporting/eval knobs that lora *does* support

from __future__ import annotations
import csv, shlex, subprocess, sys
from pathlib import Path
from typing import Dict, Any, List, Optional
from config_loader import load_config
cfg = load_config()

# Make sure ./data exists (not otherwise used in this cell).
out_dir = Path("data")
out_dir.mkdir(exist_ok=True)

# The experiment matrix is read from <run output dir>/<configured file name>.
RUN_DIR = Path(cfg.run.output_dir)
EXPERIMENTS = RUN_DIR / cfg.run.experiments

# ---- Controls ----
DRY_RUN = False        # True: print the commands without executing them
ONLY_MODEL_ID = ""     # non-empty: run only rows whose model_id equals this string
ONLY_ROW = None        # integer: run only the row at this index (takes precedence)
# Optional reporting/eval flags; a value of 0 omits the flag from the command.
STEPS_PER_REPORT = 10
STEPS_PER_EVAL = 50
VAL_BATCHES = 1
# ------------------

def load_rows(path: Path) -> List[Dict[str, Any]]:
    """Read the experiments CSV and coerce the numeric columns.

    Integer-like columns are parsed via ``int(float(value))`` so values such
    as ``"100.0"`` are accepted; ``learning_rate`` is parsed as a float.
    Columns that are absent, or whose cell is an empty string, are left as-is.
    """
    int_columns = ("epochs", "iters", "batch_size", "grad_accum", "max_seq_length", "bf16")
    float_columns = ("learning_rate",)

    with path.open("r", encoding="utf-8") as f:
        records = [dict(rec) for rec in csv.DictReader(f)]

    for rec in records:
        for name in int_columns:
            if rec.get(name, "") != "":
                rec[name] = int(float(rec[name]))
        for name in float_columns:
            if rec.get(name, "") != "":
                rec[name] = float(rec[name])
    return records

def select_rows(rows: List[Dict[str, Any]], only_model: str, only_row_idx: Optional[int]) -> List[Dict[str, Any]]:
    """Pick which experiment rows to run.

    A non-None ``only_row_idx`` selects exactly that row and takes precedence;
    otherwise a non-empty ``only_model`` keeps rows whose ``model_id`` matches;
    otherwise the input list is returned unchanged.
    """
    if only_row_idx is None:
        if not only_model:
            return rows
        return [row for row in rows if row.get("model_id") == only_model]
    return [rows[only_row_idx]]

def ensure_dirs(row: Dict[str, Any]):
    """Create the row's adapter and log directories (including parents) if missing."""
    for column in ("adapter_path", "log_dir"):
        Path(row[column]).mkdir(parents=True, exist_ok=True)

def build_cmd(row: Dict[str, Any]) -> str:
    """Assemble the shell command line that trains one experiment row.

    String values (interpreter, model id, data dir, adapter path) are
    shell-quoted; numeric values are coerced with ``int``/``float`` before
    being formatted. The optional ``--val-batches``, ``--steps-per-report``
    and ``--steps-per-eval`` flags are appended only when the corresponding
    module-level control is truthy.
    """
    interpreter = shlex.quote(sys.executable)
    model_arg = shlex.quote(row["model_id"])
    data_arg = shlex.quote(row["data_dir"])
    n_iters = int(row["iters"])
    batch_size = int(row["batch_size"])
    max_seq_len = int(row["max_seq_length"])
    learning_rate = float(row["learning_rate"])
    adapter_arg = shlex.quote(row["adapter_path"])

    # grad_accum, bf16 and log_dir from the row are intentionally not passed.
    argv = [
        f"{interpreter} -m mlx_lm.lora",
        f"--model {model_arg}",
        f"--data {data_arg}",
        "--train",
        "--fine-tune-type lora",
        f"--batch-size {batch_size}",
        f"--iters {n_iters}",
        f"--learning-rate {learning_rate}",
        f"--max-seq-length {max_seq_len}",
        f"--adapter-path {adapter_arg}",
    ]
    if VAL_BATCHES:
        argv.append(f"--val-batches {int(VAL_BATCHES)}")
    if STEPS_PER_REPORT:
        argv.append(f"--steps-per-report {int(STEPS_PER_REPORT)}")
    if STEPS_PER_EVAL:
        argv.append(f"--steps-per-eval {int(STEPS_PER_EVAL)}")
    return " ".join(argv)

def run_cmd(cmd: str) -> int:
    """Echo ``cmd`` and run it through the shell, returning its exit code.

    When DRY_RUN is set the command is only printed and 0 is returned.
    """
    print("\n[MLX train]", cmd)
    if not DRY_RUN:
        completed = subprocess.run(cmd, shell=True)
        return completed.returncode
    print("DRY_RUN=True -> not executing.")
    return 0

# Load the matrix, narrow it with the controls above, and train row by row.
rows = load_rows(EXPERIMENTS)
todo = select_rows(rows, ONLY_MODEL_ID, ONLY_ROW)

print(f"Found {len(rows)} rows; running {len(todo)} row(s). DRY_RUN={DRY_RUN}")
for i, row in enumerate(todo):
    print(f"\n=== RUN {i+1}/{len(todo)} ===")
    ensure_dirs(row)
    command = build_cmd(row)
    rc = run_cmd(command)
    # Stop at the first failing run; remaining rows are not attempted.
    if rc != 0:
        print(f"❌ Training failed with returncode={rc}")
        break
    print("✅ Training launched.")

In [None]:
# STEP 8 — Artifact Registry
# Reads experiments.csv and registers produced artifacts (adapters, logs).
# - Computes SHA256 & sizes
# - Writes artifacts.json
# - Creates per-model symlinks: latest_adapter -> adapter , latest_logs -> logs

from __future__ import annotations
import json, hashlib, os, time
from pathlib import Path
from typing import Dict, Any, List
import csv

from config_loader import load_config
cfg = load_config()
# Make sure ./data exists (not otherwise used in this cell).
out_dir = Path("data")
out_dir.mkdir(exist_ok=True)
# Input matrix and output registry both live under the configured run dir.
RUN_DIR = Path(cfg.run.output_dir)
EXPERIMENTS = RUN_DIR / cfg.run.experiments
ARTIFACTS = RUN_DIR / cfg.run.artifacts


def sha256_file(p: Path) -> str:
    """Return the hex SHA-256 digest of the file at ``p``, read in 1 MiB chunks."""
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()
    with p.open("rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest()

def gather_dir_files(root: Path) -> List[Dict[str, Any]]:
    """Describe every regular file below ``root`` (recursively, sorted by path).

    Each record holds the resolved absolute path, the path relative to
    ``root``, the size in bytes, the SHA-256 digest and the modification time
    formatted as a UTC timestamp. A missing ``root`` yields an empty list.
    """
    if not root.exists():
        return []
    records = []
    for path in sorted(root.rglob("*")):
        if not path.is_file():
            continue
        info = path.stat()
        modified = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(info.st_mtime))
        records.append({
            "path": str(path.resolve()),
            "rel": str(path.relative_to(root)),
            "bytes": info.st_size,
            "sha256": sha256_file(path),
            "mtime_utc": modified,
        })
    return records

def load_rows(path: Path):
    """Read a CSV file and return its rows as plain dicts (all values stay strings)."""
    with path.open("r", encoding="utf-8") as f:
        return [dict(record) for record in csv.DictReader(f)]

# Check the path this cell actually reads. (The previous guard tested a
# variable that is only defined if the Step 6 cell ran in the same kernel.)
if not EXPERIMENTS.exists():
    raise SystemExit("experiments.csv not found (run Step 6).")

rows = load_rows(EXPERIMENTS)
registry: Dict[str, Any] = {
    "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "runs": []
}

for r in rows:
    model_id = r["model_id"]
    adapter_dir = Path(r["adapter_path"])
    logs_dir    = Path(r["log_dir"])
    # NOTE(review): this is two levels above the adapter dir. With the Step 6
    # layout (<run_dir>/<model_tag>/adapter) it resolves to <run_dir>, not
    # <run_dir>/<model_tag>. Step 9 derives fused/ and quantized/ from this
    # value and Step 10 defaults to "runs/fused/model" / "runs/quantized",
    # so the behaviour is kept as-is here -- confirm the intended layout.
    out_root = adapter_dir.parent.parent

    # Convenience symlinks (best effort): the link target is the bare directory
    # name, i.e. it is interpreted relative to out_root.
    for link_name, target in (("latest_adapter", adapter_dir), ("latest_logs", logs_dir)):
        link = out_root / link_name
        try:
            link.unlink(missing_ok=True)
            link.symlink_to(target.name)
        except (OSError, NotImplementedError) as exc:
            # Symlinks are optional; report the reason instead of hiding it.
            print(f"   (skipped symlink {link}: {exc})")

    entry = {
        "model_id": model_id,
        "output_root": str(out_root.resolve()),
        "adapter_dir": str(adapter_dir.resolve()),
        "logs_dir": str(logs_dir.resolve()),
        "files": {
            "adapter": gather_dir_files(adapter_dir),
            "logs": gather_dir_files(logs_dir),
        },
        # CSV values are strings; empty/missing cells fall back to 0.
        "training_params": {
            "iters": int(float(r.get("iters", 0) or 0)),
            "batch_size": int(float(r.get("batch_size", 0) or 0)),
            "max_seq_length": int(float(r.get("max_seq_length", 0) or 0)),
            "learning_rate": float(r.get("learning_rate", 0.0) or 0.0),
        }
    }
    registry["runs"].append(entry)

ARTIFACTS.write_text(json.dumps(registry, indent=2), encoding="utf-8")

# Console summary
print("=== ARTIFACT REGISTRY ===")
print("Wrote:", ARTIFACTS)
for run in registry["runs"]:
    adap_files = run["files"]["adapter"]
    n = len(adap_files)
    sizes = sum(f["bytes"] for f in adap_files)
    print(f"- {run['model_id']}")
    print("   adapter_dir:", run["adapter_dir"])
    print("   logs_dir:   ", run["logs_dir"])
    print(f"   adapter files: {n}  total bytes: {sizes:,}")
    if n:
        # Files are sorted by path, so "latest" is the last path in sort order.
        print("   latest:", adap_files[-1]["rel"], adap_files[-1]["bytes"], "bytes")
    print("   symlinks: latest_adapter, latest_logs")

In [None]:
# STEP 9 — Fuse & Quantize (final, clean/idempotent)
# - Reuses experiments.csv + artifacts.json from Steps 6–8
# - If needed, fuses adapter -> fused/model  (mlx_lm fuse)
# - Quantizes fused -> quantized/ (mlx_lm convert with explicit flags)
# - Removes any pre-existing quantized dir to avoid MLX "already exists" error
# - Updates artifacts.json

from __future__ import annotations
import json, hashlib, time, shlex, subprocess, sys, shutil
from pathlib import Path
from typing import Dict, Any, List

from config_loader import load_config
cfg = load_config()
# Make sure ./data exists (not otherwise used in this cell).
out_dir = Path("data")
out_dir.mkdir(exist_ok=True)
# Experiment matrix and artifact registry under the configured run dir.
RUN_DIR = Path(cfg.run.output_dir)
EXPERIMENTS = RUN_DIR / cfg.run.experiments
ARTIFACTS = RUN_DIR / cfg.run.artifacts

# ---- Controls ----
DO_FUSE = True            # False: never run the fuse command (reuse an existing fused dir)
Q_BITS = 4                # quantization bits, e.g. 4 or 8
Q_GROUP = 64              # quantization group size, e.g. 32, 64, 128
DTYPE = cfg.model.dtype   # float16 | bfloat16 | float32
DRY_RUN = False           # True: print commands instead of executing them
# -------------------

def run_cmd(cmd: str) -> int:
    """Echo ``cmd`` and run it through the shell, returning its exit code.

    When DRY_RUN is set the command is only printed and 0 is returned.
    """
    print("[MLX]", cmd)
    if not DRY_RUN:
        completed = subprocess.run(cmd, shell=True)
        return completed.returncode
    print("DRY_RUN=True -> not executing.")
    return 0

def sha256_file(p: Path) -> str:
    """Return the hex SHA-256 digest of the file at ``p``, read in 1 MiB chunks."""
    hasher = hashlib.sha256()
    with p.open("rb") as stream:
        while block := stream.read(1024 * 1024):
            hasher.update(block)
    return hasher.hexdigest()

def list_files(root: Path) -> List[Dict[str, Any]]:
    """Describe every regular file below ``root`` (recursively, sorted by path).

    Each record holds the resolved absolute path, the path relative to
    ``root``, the size in bytes, the SHA-256 digest and the modification time
    formatted as a UTC timestamp. A missing ``root`` yields an empty list.
    """
    if not root.exists():
        return []
    described = []
    for path in sorted(root.rglob("*")):
        if not path.is_file():
            continue
        info = path.stat()
        described.append({
            "path": str(path.resolve()),
            "rel": str(path.relative_to(root)),
            "bytes": info.st_size,
            "sha256": sha256_file(path),
            "mtime_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(info.st_mtime)),
        })
    return described

# Load the artifact registry written by Step 8; every run entry supplies the
# model id, output root and adapter directory used below.
if not ARTIFACTS.exists():
    raise SystemExit("artifacts.json not found. Run Steps 8/7/6 first.")

registry = json.loads(ARTIFACTS.read_text(encoding="utf-8"))
runs = registry.get("runs", [])
if not runs:
    raise SystemExit("No runs found in artifacts.json.")

py = shlex.quote(sys.executable)
updated = False  # becomes True once an entry gained fused/quantized results

for entry in runs:
    model_id   = entry["model_id"]
    output_dir = Path(entry["output_root"])
    adapter_dir = Path(entry["adapter_dir"])
    # Reuse a previously recorded fused dir, else default to <output_root>/fused/model.
    fused_dir   = Path(entry.get("fused_dir") or (output_dir / "fused" / "model"))

    # 1) Fuse (optional / idempotent): only when no fused dir exists yet.
    if DO_FUSE and not fused_dir.exists():
        if not DRY_RUN:
            fused_dir.parent.mkdir(parents=True, exist_ok=True)
        cmd_fuse = (
            f"{py} -m mlx_lm.fuse "
            f"--model {shlex.quote(model_id)} "
            f"--adapter-path {shlex.quote(str(adapter_dir))} "
            f"--save-path {shlex.quote(str(fused_dir))}"
        )
        print("\n=== FUSE ===")
        rc = run_cmd(cmd_fuse)
        if rc != 0:
            print(f"❌ Fuse failed for {model_id}")
            continue
        # A dry run produced nothing, so it must not be recorded in the registry.
        if not DRY_RUN:
            entry["fused_dir"] = str(fused_dir.resolve())
            entry.setdefault("files", {})["fused"] = list_files(fused_dir)
            updated = True
    elif fused_dir.exists():
        entry["fused_dir"] = str(fused_dir.resolve())
        entry.setdefault("files", {})["fused"] = list_files(fused_dir)

    if not fused_dir.exists():
        print(f"Skipping quantize for {model_id}: fused_dir missing.")
        continue

    # 2) Quantize (idempotent + clean): the target dir must not pre-exist.
    q_dir = output_dir / "quantized"
    if q_dir.exists():
        if DRY_RUN:
            # Never delete real artifacts while only previewing commands.
            print(f"DRY_RUN=True -> would remove pre-existing quantized dir: {q_dir}")
        else:
            print(f"Removing pre-existing quantized dir: {q_dir}")
            shutil.rmtree(q_dir)

    cmd_q = (
        f"{py} -m mlx_lm.convert "
        f"--hf-path {shlex.quote(str(fused_dir))} "
        f"--mlx-path {shlex.quote(str(q_dir))} "
        f"--q-bits {int(Q_BITS)} "
        f"--q-group-size {int(Q_GROUP)} "
        f"--dtype {shlex.quote(DTYPE)} "
        f"-q"
    )
    print("\n=== QUANTIZE ===")
    rc = run_cmd(cmd_q)
    if rc != 0:
        print(f"❌ Quantize failed for {model_id}")
        continue
    if DRY_RUN:
        # Nothing was quantized; leave the registry entry untouched.
        continue

    entry["quantized_dir"] = str(q_dir.resolve())
    entry["quantize_bits"] = int(Q_BITS)
    entry["q_group_size"]  = int(Q_GROUP)
    entry.setdefault("files", {})["quantized"] = list_files(q_dir)
    updated = True

# Save updated artifacts
if updated:
    registry["updated_utc"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    ARTIFACTS.write_text(json.dumps(registry, indent=2), encoding="utf-8")

# Summary
print("\n=== FUSE/QUANTIZE SUMMARY ===")
# Only claim a write when the registry file was actually rewritten.
print("Wrote:" if updated else "Unchanged:", ARTIFACTS)
for entry in registry.get("runs", []):
    print(f"- {entry['model_id']}")
    if "fused_dir" in entry:
        print("   fused_dir:    ", entry['fused_dir'], f"({len(entry.get('files',{}).get('fused',[]))} files)")
    if "quantized_dir" in entry:
        print("   quantized_dir:", entry['quantized_dir'],
              f"(q{entry.get('quantize_bits')}, group={entry.get('q_group_size')})",
              f"files={len(entry.get('files',{}).get('quantized',[]))}")

Step 10.1 – Policy-locked generation
	•	Now: Uses the policy you just picked so future generations are consistent and reproducible.
	•	Later: This becomes your “official” inference cell. If MLX changes decode flags again, your outputs won’t change because the prompts and artifact choice are fixed.
    	10.	Generation Harness (Deterministic)

	•	Inputs: prompts, decoding params; model path (+optional adapter)
	•	Outputs: raw generations with full provenance (prompt, seeds, params), generations.jsonl

In [None]:
# STEP 10 — Dynamic few-shot + anti-copy retry → WRITE eval_out/generations.{jsonl,csv}
# Adds fields: mode, generation (alias of output_text)
from __future__ import annotations
import os, json, random, hashlib, csv, time
from pathlib import Path
from typing import List, Optional
from mlx_lm import load as mlx_load, generate as mlx_generate

# --- Config ---
# Model / artifact locations. QUANT is preferred over FUSED, which is preferred
# over BASE + ADAPTER (see the artifact selection further down).
BASE      = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER   = "runs/TinyLlama--TinyLlama-1.1B-Chat-v1.0/adapter"
FUSED     = "runs/fused/model"        # optional
QUANT     = "runs/quantized"          # optional
MAX_NEW   = 128                       # max tokens requested per generation
N_SHOTS   = 3                         # few-shot examples per prompt
MIN_WORDS = 3                         # generations shorter than this are retried
RETRIES   = 2                         # extra attempts after a rejected generation
from config_loader import load_config
cfg = load_config()
OUT_DIR       = Path(cfg.data.output_dir); OUT_DIR.mkdir(exist_ok=True)
RUN_DIR       = Path(cfg.run.output_dir)  # where per-model outputs will go
EVAL_DIR      = Path(cfg.eval.output_dir); EVAL_DIR.mkdir(exist_ok=True)
EXPERIMENTS   = RUN_DIR / cfg.run.experiments
ARTIFACTS     = RUN_DIR / cfg.run.artifacts
CONTRACT      = OUT_DIR / cfg.paths.contract

# The seed comes from config only. (A hard-coded `SEED = 7` used to precede
# this line but was always overwritten, so it has been removed.)
SEED  = cfg.run.alt_seed

MODES = ["default_eos", "no_eos", "custom_stop"]
CUSTOM_STOP = "\n\n"  # client-side trim for 'custom_stop'

PROMPTS = [
    "Share a saying about time.",
    "Offer a short proverb on patience.",
    "Give a hopeful saying for travelers.",
]

# Output files: generations as JSONL and CSV, plus tokenizer metadata.
JSONL_PATH = EVAL_DIR / (cfg.paths.generations+".jsonl")
CSV_PATH   = EVAL_DIR / (cfg.paths.generations+".csv")
TOKMETA    = OUT_DIR / (cfg.paths.tokenizer+".json")

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Seeds Python's `random`, which drives few-shot example selection.
random.seed(SEED)

# --- Load corpus for few-shot + anti-copy ---
# The contract names the training file and the schema; the text field is the
# first schema field whose declared type is "string" (default: "text").
contract = json.loads(CONTRACT.read_text(encoding="utf-8"))
text_field = "text"
for _field_name, _field_type in contract["schema"]["fields"].items():
    if str(_field_type).lower() == "string":
        text_field = _field_name
        break
train_path = Path(contract["filenames"]["train"]["resolved"])

def sha(s: str) -> str:
    """Return the hex SHA-256 of ``s`` encoded as UTF-8 (unencodable characters are dropped)."""
    encoded = s.encode("utf-8", "ignore")
    return hashlib.sha256(encoded).hexdigest()
def wc(s: str) -> int:
    """Return the number of whitespace-separated words in ``s``."""
    words = s.split()
    return len(words)

# Read the training JSONL; lines that fail to parse (or are not objects) are
# skipped, as are records whose text field is not a non-blank string.
train_lines: List[str] = []
with train_path.open("r", encoding="utf-8") as f:
    for raw_line in f:
        try:
            value = json.loads(raw_line).get(text_field, "")
        except Exception:
            continue
        if isinstance(value, str) and value.strip():
            train_lines.append(value.strip())

# De-duplicate by content hash, keeping the first occurrence in file order.
_first_by_hash = {}
for text in train_lines:
    _first_by_hash.setdefault(sha(text), text)
seen = set(_first_by_hash)
unique = list(_first_by_hash.values())

# Length buckets by word count: <=4 short, 5..12 medium, >12 longer.
short, medium, longer = [], [], []
for text in unique:
    n_words = wc(text)
    if n_words <= 4:
        short.append(text)
    elif n_words <= 12:
        medium.append(text)
    else:
        longer.append(text)

def pick_diverse_shots(k: int) -> List[str]:
    """Pick up to ``k`` few-shot examples, spread across length buckets.

    One random example is drawn from each non-empty bucket (short, medium,
    longer, in that order); the remaining unique examples are shuffled and
    appended, and the first ``k`` items of the combined list are returned.
    """
    anchors = [random.choice(bucket) for bucket in (short, medium, longer) if bucket]
    leftovers = [text for text in unique if text not in anchors]
    random.shuffle(leftovers)
    combined = anchors + leftovers
    return combined[:k]

# Lookup structures for the anti-copy check: a set for exact matches and one
# joined string for substring matches.
train_set = set(unique)
train_blob = "\n\n".join(unique)

def format_fewshot(prompt: str, shots: List[str]) -> str:
    """Render the few-shot prompt: a bulleted list of ``shots``, a blank line,
    the ``prompt``, and an open bullet for the model to complete."""
    bullets = "\n- ".join(shots)
    header = "Some Proverbs:\n- "
    return header + bullets + f"\n\n{prompt}\n- "

def trim_on_custom_stop(text: str, stop: str) -> str:
    """Return ``text`` cut just before the first occurrence of ``stop``
    (unchanged when ``stop`` does not occur)."""
    cut_at = text.find(stop)
    if cut_at < 0:
        return text
    return text[:cut_at]

def is_bad(gen: str) -> bool:
    """Decide whether a generation should be rejected and retried.

    A generation is bad when, after stripping, it has fewer than MIN_WORDS
    words, equals a training example exactly, or is at least 24 characters
    long and occurs verbatim inside the joined training text.
    """
    text = gen.strip()
    too_short = wc(text) < MIN_WORDS
    exact_copy = text in train_set
    long_substring_copy = len(text) >= 24 and text in train_blob
    return too_short or exact_copy or long_substring_copy

# Choose the artifact to load, in preference order quantized > fused > adapter.
# The first candidate directory that exists is loaded on its own; when neither
# exists, the base model is loaded together with the LoRA adapter.
adapter_path: Optional[str] = None
for artifact_label, model_path in (("quantized", QUANT), ("fused", FUSED)):
    if Path(model_path).exists():
        break
else:
    artifact_label, model_path, adapter_path = "base+adapter", BASE, ADAPTER

model, tok = mlx_load(model_path, adapter_path=adapter_path)

# Tokenizer meta (optional but helpful): attributes the tokenizer does not
# expose are recorded as null.
_tok_meta = {
    attr: getattr(tok, attr, None)
    for attr in ("eos_token", "eos_token_id", "pad_token", "pad_token_id")
}
TOKMETA.write_text(json.dumps(_tok_meta, indent=2), encoding="utf-8")

def generate_once(p: str) -> tuple[str, str, list[str]]:
    """Generate one completion for prompt ``p``, retrying rejected outputs.

    Each attempt draws fresh few-shot examples. If the returned text starts
    with the full prompt, that prefix is removed; the result is stripped.
    An attempt is accepted when ``is_bad`` passes it or when RETRIES extra
    attempts have been used up. Returns (full_prompt, generation, shots) of
    the final attempt.
    """
    retries_left = RETRIES
    while True:
        shots = pick_diverse_shots(N_SHOTS)
        full_prompt = format_fewshot(p, shots)
        raw = mlx_generate(model=model, tokenizer=tok, prompt=full_prompt, max_tokens=MAX_NEW)
        if raw.startswith(full_prompt):
            raw = raw[len(full_prompt):]
        candidate = raw.strip()
        if retries_left <= 0 or not is_bad(candidate):
            return full_prompt, candidate, shots
        retries_left -= 1

# Collect rows: one generation per (prompt, mode). All rows share one timestamp.
ts = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
rows = []


def _generation_row(mode: str, prompt: str, full_prompt: str, text: str, shots: List[str]) -> dict:
    """Build the provenance record for a single generation.

    ``output_text`` and ``generation`` carry the same text; ``generation`` is
    kept as an alias of ``output_text``.
    """
    return {
        "timestamp": ts, "seed": SEED,
        "model_id": BASE, "artifact": artifact_label,
        "artifact_model_path": model_path, "adapter_path": adapter_path or "",
        "prompt_variant": "fewshot-dynamic", "mode": mode,
        "prompt": prompt, "input_text": full_prompt,
        "output_text": text, "generation": text,
        "shots": shots, "max_new_tokens": MAX_NEW,
    }


for p in PROMPTS:
    # Every mode issues the same generate call; the modes are labeled
    # separately for grouping. Only "custom_stop" post-processes the text,
    # by trimming it client-side at CUSTOM_STOP.
    for mode in MODES:
        fp, gen, shots = generate_once(p)
        if mode == "custom_stop":
            gen = trim_on_custom_stop(gen, CUSTOM_STOP).strip()
        row = _generation_row(mode, p, fp, gen, shots)
        if mode == "custom_stop":
            row["custom_stop"] = CUSTOM_STOP
        rows.append(row)
        print(f"\n[{mode}] {p}\n→ {gen}")

# Write JSONL (full records, including input_text and custom_stop)
with JSONL_PATH.open("w", encoding="utf-8") as f:
    for r in rows:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

# Write CSV (selected columns only; the shots list is flattened to one cell)
csv_cols = ["timestamp","seed","model_id","artifact","artifact_model_path","adapter_path",
            "prompt_variant","mode","prompt","generation","output_text","shots","max_new_tokens"]
with CSV_PATH.open("w", encoding="utf-8", newline="") as f:
    w = csv.DictWriter(f, fieldnames=csv_cols)
    w.writeheader()
    for r in rows:
        rr = r.copy(); rr["shots"] = " | ".join(r["shots"])
        w.writerow({k: rr.get(k, "") for k in csv_cols})

print(f"\n=== GENERATION SUMMARY ===")
print(f"Models evaluated: {BASE}")
print(f"Rows: {len(rows)}  |  JSONL: {JSONL_PATH}  |  CSV: {CSV_PATH}")
print(f"Modes: {MODES}")
print(f"Artifacts: ['{artifact_label}']")

In [None]:
# STEP 11 — EOS Behavior Probe & Quick Analysis (revised, JSONL-first)
# Reads eval_out/generations.jsonl to avoid NaN coercion, computes the same stats,
# and shows any rows that CSV parsing would have treated as NaN.

from __future__ import annotations
import json, re, statistics, pandas as pd
from pathlib import Path

from config_loader import load_config
cfg = load_config()

# All locations come from config. (Hard-coded eval_out/... defaults used to
# precede these assignments but were always overwritten, so they are gone.)
OUT_DIR       = Path(cfg.data.output_dir); OUT_DIR.mkdir(exist_ok=True)
EVAL_DIR      = Path(cfg.eval.output_dir); EVAL_DIR.mkdir(exist_ok=True)
RUN_DIR       = Path(cfg.run.output_dir)  # where per-model outputs will go
EXPERIMENTS   = RUN_DIR / cfg.run.experiments
# CONTRACT is required by the existence check below and by the training-text
# load further down; it must be defined here so the cell also works in a
# fresh kernel.
CONTRACT      = OUT_DIR / cfg.paths.contract
GEN_JSONL     = EVAL_DIR / (cfg.paths.generations + ".jsonl" )
GEN_CSV       = EVAL_DIR / (cfg.paths.generations + ".csv")   # optional: for diagnostics only
OUT_SUM       = EVAL_DIR / (cfg.paths.summary + ".csv")
OUT_JSON      = EVAL_DIR / (cfg.paths.analysis + ".json")


# Fail early when an input from an earlier step is missing.
if not GEN_JSONL.exists():
    raise SystemExit("Missing eval_out/generations.jsonl (run Step 10).")
if not CONTRACT.exists():
    raise SystemExit("Missing data_contract.json (from Step 2).")

# ---- Load generations from JSONL (authoritative) ----
# One JSON object per non-blank line.
with GEN_JSONL.open("r", encoding="utf-8") as f:
    rows = [json.loads(line) for line in f if line.strip()]
df = pd.DataFrame(rows)

# ---- (Optional) CSV diagnostics: which rows differ between JSONL and CSV? ----
csv_missing = pd.DataFrame()
if GEN_CSV.exists():
    # keep_default_na=False / na_filter=False keep blank cells as empty strings
    # instead of coercing them to NaN.
    df_csv = pd.read_csv(GEN_CSV, keep_default_na=False, na_filter=False)
    if len(df_csv) != len(df):
        # Compare on the scalar identifying columns only: the JSONL frame holds
        # list values (e.g. `shots`), and drop_duplicates over all columns
        # raises TypeError on unhashable cells.
        csv_missing = pd.concat([df, df_csv]).drop_duplicates(
            subset=["mode", "prompt", "generation"], keep=False
        )

# ----- Helpers -----
def word_count(s: str) -> int:
    """Return the number of whitespace-separated tokens in ``s``."""
    tokens = s.split()
    return len(tokens)
def ends_with_terminator(s: str) -> bool:
    """Return True when ``s``, ignoring surrounding whitespace, ends with
    one of ``.``, ``!``, ``?`` or ``…``."""
    trimmed = s.strip()
    return trimmed.endswith((".", "!", "?", "…"))
def has_trailing_whitespace(s: str) -> bool:
    """Return True when ``s`` is non-empty and its last character is whitespace."""
    last_char = s[-1:]  # empty string for empty input, so isspace() is False
    return last_char.isspace()
def distinct_n(tokens, n=1):
    """Return the distinct-n ratio: unique n-grams divided by total n-grams.

    Returns 0.0 when ``tokens`` has fewer than ``n`` items.
    """
    total = len(tokens) - n + 1  # number of n-gram windows
    if total < 1:
        return 0.0
    unique_ngrams = {tuple(tokens[start:start + n]) for start in range(total)}
    return len(unique_ngrams) / total

# Load training examples for simple memorization checks.
# The contract names the training file; the text field is the first schema
# field whose declared type is "string" (default: "text").
c = json.loads(CONTRACT.read_text(encoding="utf-8"))
train_path = Path(c["filenames"]["train"]["resolved"])
text_field = next((k for k,v in c["schema"]["fields"].items() if str(v).lower()=="string"), "text")
train_texts = []
with train_path.open("r", encoding="utf-8") as f:
    for line in f:
        try:
            t = json.loads(line).get(text_field, "")
        except (ValueError, AttributeError):
            # Unparseable line, or a JSON value that is not an object: skip it.
            continue
        # Blank texts must be skipped: an empty string in train_set would make
        # every empty generation count as "exactly memorized".
        if isinstance(t, str) and t.strip():
            train_texts.append(t.strip())
train_blob = "\n\n".join(train_texts)
train_set = set(train_texts)

# ----- Per-row metrics -----
def row_metrics(r):
    """Return ``r`` extended with per-generation metrics.

    Metrics are computed on ``str(r["generation"])`` (empty string if absent):
    character and word length, sentence-terminator and trailing-whitespace
    flags, distinct-1/2 over whitespace tokens (rounded to 4 places), and two
    memorization flags -- an exact match against the training set, or (only
    when not an exact match) a stripped text of at least 20 characters that
    occurs inside the joined training text.
    """
    gen = str(r.get("generation", ""))
    stripped = gen.strip()
    tokens = gen.split()

    is_exact = stripped in train_set
    is_substring = (not is_exact) and (len(stripped) >= 20) and (stripped in train_blob)

    metrics = {
        "len_chars": len(gen),
        "len_words": word_count(gen),
        "ends_sentence": int(ends_with_terminator(gen)),
        "ends_whitespace": int(has_trailing_whitespace(gen)),
        "distinct1": round(distinct_n(tokens, 1), 4),
        "distinct2": round(distinct_n(tokens, 2), 4),
        "memorized_exact": int(is_exact),
        "memorized_substring": int(is_substring),
    }
    return {**r, **metrics}

# Per-row metrics for every generation record.
m = pd.DataFrame(list(map(row_metrics, df.to_dict(orient="records"))))

# ----- Aggregate by mode -----
# Output column -> (source column, aggregation); order defines the column order.
_agg_spec = {
    "n": ("generation", "count"),
    "avg_len_chars": ("len_chars", "mean"),
    "med_len_chars": ("len_chars", "median"),
    "avg_len_words": ("len_words", "mean"),
    "sent_end_rate": ("ends_sentence", "mean"),
    "trailing_ws_rate": ("ends_whitespace", "mean"),
    "distinct1_mean": ("distinct1", "mean"),
    "distinct2_mean": ("distinct2", "mean"),
    "mem_exact_rate": ("memorized_exact", "mean"),
    "mem_sub_rate": ("memorized_substring", "mean"),
}
agg = m.groupby("mode").agg(**_agg_spec).reset_index()

# Round every statistic except the count to 4 decimal places.
for col in (name for name in _agg_spec if name != "n"):
    if col in agg.columns:
        agg[col] = agg[col].map(lambda x: round(float(x), 4))

# ----- Per-prompt sample table -----
def sample_table(df_in: pd.DataFrame, n=1):
    """Return up to ``n`` generations per (prompt, mode) pair.

    Rows are taken in input order within each pair; pairs are visited grouped
    by prompt, then by mode. The result has columns prompt, mode, generation.
    """
    picked = [
        {"prompt": prompt, "mode": mode, "generation": record["generation"]}
        for prompt, by_prompt in df_in.groupby("prompt")
        for mode, by_mode in by_prompt.groupby("mode")
        for _, record in by_mode.head(n).iterrows()
    ]
    return pd.DataFrame(picked)
preview = sample_table(m, n=1)

# ----- Save outputs -----
OUT_SUM.parent.mkdir(parents=True, exist_ok=True)
agg.to_csv(OUT_SUM, index=False)

# `time` is not imported at the top of this cell, so fetch the module here.
_time = __import__("time")
analysis = {
    "created_utc": _time.strftime("%Y-%m-%dT%H:%M:%SZ", _time.gmtime()),
    "by_mode": agg.to_dict(orient="records"),
    "notes": [
        "JSONL is used as source of truth to avoid NaN coercion from CSV parsing.",
        "distinct* ~ lexical diversity over whitespace tokens.",
        "memorized_* checks generation against training set (exact / long substring).",
    ],
}
OUT_JSON.write_text(json.dumps(analysis, indent=2), encoding="utf-8")

# ----- Console summary -----
print("=== EOS / OUTPUT ANALYSIS (by mode) [JSONL] ===")
print(agg.to_string(index=False))

print("\n=== SAMPLE OUTPUTS (1 per prompt×mode) ===")
for sample in preview.to_dict(orient="records"):
    print(f"\n[{sample['mode']}] {sample['prompt']}\n→ {sample['generation']}")

if not csv_missing.empty:
    print("\n[CSV diagnostic] These rows mismatch when parsing CSV with defaults; JSONL kept them:")
    # `display` is provided by the notebook environment.
    display(csv_missing[["mode","prompt","generation"]].head(6))

print(f"\nWrote: {OUT_SUM} and {OUT_JSON}")

Step 11 – EOS probe & quick analysis
	•	Now: Catches “silent” decodes, short/empty strings, trailing whitespace, sentence-end rate, and simple memorization flags. You already saw it surface empties.
	•	Later: Great triage for new bases or corpora. If empties spike or distinct-n tanks, you know to adjust prompts, iters, or artifacts before deeper eval.


Step 12 – Regeneration sanity checks (artifact × prompt ablation)
	•	Now: Proved “few-shot > directive > plain” and showed quantized ≈ fused ≈ adapter under that template.
	•	Later: This is your universal “why is it quiet?” playbook. Swap in any model; it quickly tells you whether the failure is the artifact (quantized vs fused) or the prompting policy.

In [None]:
# STEP 12 — Regeneration Sanity Checks (artifact + prompt ablation)
# Goal: diagnose empty outputs by trying:
#   1) artifact: quantized vs fused
#   2) prompts: plain / directive / few-shot
# Uses MLX defaults (no sampling kwargs) for broad compatibility.

from __future__ import annotations
import json, textwrap, time
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from mlx_lm import load as mlx_load, generate as mlx_generate

# ---- Paths (all resolved from the layered config) ----
cfg = load_config()
OUT_DIR       = Path(cfg.data.output_dir)
EVAL_DIR      = Path(cfg.eval.output_dir)
# parents=True: a nested output directory must not fail on a fresh checkout
# (the ablation writer further down already creates EVAL_DIR this way).
OUT_DIR.mkdir(parents=True, exist_ok=True)
EVAL_DIR.mkdir(parents=True, exist_ok=True)
RUN_DIR       = Path(cfg.run.output_dir)  # where per-model outputs will go
EXPERIMENTS   = RUN_DIR / cfg.run.experiments
ARTIFACTS     = RUN_DIR / cfg.run.artifacts
CONTRACT      = OUT_DIR / cfg.paths.contract
GEN_JSONL     = EVAL_DIR / (cfg.paths.generations + ".jsonl" )
GEN_CSV       = EVAL_DIR / (cfg.paths.generations + ".csv")
OUT_SUM       = EVAL_DIR / (cfg.paths.summary + ".csv")
OUT_JSON      = EVAL_DIR / (cfg.paths.analysis + ".json")

# ---- Controls ----
ONLY_MODEL_ID = ""  # "" = all; or exact id
# Bare instructions; each one is wrapped by every prompt variant below.
PROMPTS = [
    "Share a saying about time.",
    "Offer a short proverb on patience.",
    "Give a hopeful saying for travelers.",
]
# Two generation budgets are tried for every artifact × prompt variant.
MAX_NEW_TOKENS_SHORT = 64
MAX_NEW_TOKENS_LONG  = 128
# -------------------

def load_runs() -> List[Dict[str, Any]]:
    """Read the artifact registry and return the run entries to ablate.

    The registry is the JSON file at ``ARTIFACTS``; its ``"runs"`` list is
    returned, narrowed to entries whose ``model_id`` equals ``ONLY_MODEL_ID``
    when that control is non-empty.

    Raises:
        SystemExit: if no run entry is left after filtering.
    """
    registry = json.loads(ARTIFACTS.read_text(encoding="utf-8"))
    selected = registry.get("runs", [])
    if ONLY_MODEL_ID:
        selected = list(filter(lambda entry: entry.get("model_id") == ONLY_MODEL_ID, selected))
    if not selected:
        raise SystemExit("No matching runs in artifacts.json.")
    return selected

def pick_artifacts(run_entry: Dict[str, Any]) -> List[Tuple[str, Optional[str], str]]:
    """List the loadable artifacts of one run as (model_path, adapter_path, label).

    Order is quantized, fused, then the base model with its adapter. The first
    two are only included when the registry entry has a truthy directory for
    them; the base+adapter fallback is always added, so ``model_id`` and
    ``adapter_dir`` must be present in ``run_entry``. Entries that repeat an
    earlier (model_path, adapter_path) pair are dropped, keeping the first.
    """
    standalone = (("quantized_dir", "quantized"), ("fused_dir", "fused"))
    candidates = [
        (run_entry[field], None, label)
        for field, label in standalone
        if run_entry.get(field)
    ]
    candidates.append((run_entry["model_id"], run_entry["adapter_dir"], "base+adapter"))

    # dicts keep insertion order, so setdefault gives first-wins dedup in order
    first_seen = {}
    for model_path, adapter_path, label in candidates:
        first_seen.setdefault((model_path, adapter_path or ""), (model_path, adapter_path, label))
    return list(first_seen.values())

# Prompt variants
def pv_plain(prompt: str) -> str:
    """Baseline variant: hand the instruction to the model unchanged."""
    return prompt

def pv_directive(prompt: str) -> str:
    """Directive variant: the instruction, a blank line, then an answer cue.

    The trailing cue is meant to discourage the model from emitting EOS
    immediately.
    """
    cue = "Answer with a single important thought:"
    return prompt + "\n\n" + cue

def pv_fewshot(prompt: str) -> str:
    """Few-shot variant: two example proverbs as bullets, then the instruction.

    The result ends with an open ``- `` bullet so the model continues the list.
    """
    examples = (
        "The moon does not race the tide.",
        "A river carves stone by lingering.",
    )
    bullets = "".join(f"- {line}\n" for line in examples)
    return f"Proverbs:\n{bullets}\n{prompt}\n- "

# (label, formatter) pairs, iterated in this order by the orchestration loop;
# the label is stored as "prompt_variant" in every output row.
PROMPT_VARIANTS = [
    ("plain", pv_plain),
    ("directive", pv_directive),
    ("fewshot", pv_fewshot),
]

def run_generation(model_path: str, adapter_path: Optional[str], prompts: List[str], max_new: int):
    """Load one artifact and generate a continuation for every prompt.

    Args:
        model_path: model directory or id passed to ``mlx_load``.
        adapter_path: adapter directory; an empty string or None means no adapter.
        prompts: fully formatted prompt strings.
        max_new: value passed as ``max_tokens`` to ``mlx_generate``.

    Returns:
        ``(outs, meta)``: ``outs`` holds one stripped continuation per prompt;
        ``meta`` holds the tokenizer's eos/pad token and id (None for any
        attribute the tokenizer does not expose).
    """
    model, tok = mlx_load(model_path, adapter_path=adapter_path or None)

    def _continuation(prompt_text: str) -> str:
        raw = mlx_generate(model=model, tokenizer=tok, prompt=prompt_text, max_tokens=max_new)
        # if the returned text starts with the prompt, keep only what follows it
        if raw.startswith(prompt_text):
            raw = raw[len(prompt_text):]
        return raw.strip()

    outs = [_continuation(p) for p in prompts]
    token_fields = ("eos_token", "eos_token_id", "pad_token", "pad_token_id")
    meta = {field: getattr(tok, field, None) for field in token_fields}
    return outs, meta

def preview(text: str, width=120) -> str:
    """Render ``text`` as a single console line of at most ``width`` characters.

    Newlines are shown as `` ⏎ ``; runs of whitespace are collapsed and an
    over-long result is cut at a word boundary and ends with ``…``.
    """
    one_line = " ⏎ ".join(text.split("\n"))
    return textwrap.shorten(one_line, width=width, placeholder="…")

# Orchestrate: every run × artifact × prompt variant is generated at both budgets.
runs = load_runs()
stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
rows = []

for run in runs:
    for model_path, adapter_path, art_label in pick_artifacts(run):
        for pv_label, pv_fn in PROMPT_VARIANTS:
            prompts_v = list(map(pv_fn, PROMPTS))

            # Both budgets are generated before anything is printed.
            outs_short, meta = run_generation(model_path, adapter_path, prompts_v, MAX_NEW_TOKENS_SHORT)
            outs_long, _ = run_generation(model_path, adapter_path, prompts_v, MAX_NEW_TOKENS_LONG)
            by_budget = (
                ("short", MAX_NEW_TOKENS_SHORT, outs_short),
                ("long", MAX_NEW_TOKENS_LONG, outs_long),
            )

            # Console preview pairs each output with the bare prompt, not the formatted one.
            for _budget, max_new, outs in by_budget:
                print(f"\n=== {run['model_id']} | {art_label} | {pv_label} | max_new={max_new} ===")
                for p, o in zip(PROMPTS, outs):
                    print(f"- {p}\n→ {preview(o)}")

            # record minimal table (tokenizer metadata comes from the short-budget load)
            rows.extend(
                {
                    "timestamp_utc": stamp,
                    "model_id": run["model_id"],
                    "artifact": art_label,
                    "prompt_variant": pv_label,
                    "budget": budget,
                    "model_path": model_path,
                    "adapter_path": adapter_path or "",
                    "eos_token": meta["eos_token"],
                    "eos_token_id": meta["eos_token_id"],
                    "prompt": p,
                    "generation": o,
                    "len_chars": len(o),
                    "len_words": len(o.split()),
                    "is_empty": int(len(o.strip()) == 0),
                }
                for budget, _max_new, outs in by_budget
                for p, o in zip(PROMPTS, outs)
            )

# Save quick table as JSONL, one record per line
out_path = EVAL_DIR / cfg.paths.ablations
out_path.parent.mkdir(parents=True, exist_ok=True)
with out_path.open("w", encoding="utf-8") as f:
    f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in rows)

print(f"\nSaved detailed ablation outputs to {out_path}")
print("Tip: Look for cases where 'fused' + 'fewshot' fills in while 'quantized' + 'plain' is empty.")

Step 13 – Comparative report & policy lock-in
	•	Now: Chose a winning policy (quantized + few-shot) and wrote generation_policy.json + report.md.
	•	Later: This is your guardrail against drift. When you add a new model or dataset, re-run Step 12 → 13 and the notebook self-tunes the generation policy instead of you hand-tuning every time.

In [None]:
# STEP 13 — Comparative Report & Policy Lock-in
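# Inputs:
#   - <eval.output_dir>/<paths.ablations>   (ablation JSONL written by Step 12)
#   - artifacts.json                        (for artifact names)
# Outputs (all paths are resolved from the config below):
#   - <eval.output_dir>/<eval.report>       (markdown report)
#   - <eval.output_dir>/<eval.policy>       (chosen artifact & prompt template with params)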

from __future__ import annotations
import json, pandas as pd, textwrap, time
from pathlib import Path

cfg = load_config()
# Every path comes from the config. RUN_DIR is defined here so this cell does
# not depend on an earlier cell having left it in the notebook namespace.
RUN_DIR       = Path(cfg.run.output_dir)
EVAL_DIR      = Path(cfg.eval.output_dir); EVAL_DIR.mkdir(parents=True, exist_ok=True)
ARTIFACTS     = RUN_DIR / cfg.run.artifacts
OUT_SUM       = EVAL_DIR / (cfg.paths.summary + ".csv")
OUT_JSON      = EVAL_DIR / (cfg.paths.analysis + ".json")
ABL_JSONL     = EVAL_DIR / cfg.paths.ablations      # input: Step 12 ablation rows
REPORT_MD     = EVAL_DIR / cfg.eval.report          # output: markdown report
POLICY_JS     = EVAL_DIR / cfg.eval.policy          # output: locked generation policy
if not ABL_JSONL.exists():
    # Name the path that was actually checked, not a hard-coded default.
    raise SystemExit(f"Missing {ABL_JSONL} (run Step 12).")

# Load
# Iterate the file rather than calling str.splitlines(): Step 12 writes with
# ensure_ascii=False, so a generation may contain U+2028, U+2029 or U+0085
# unescaped. splitlines() treats those as line breaks and would cut a JSON
# record in half; file iteration only breaks on real newlines.
with ABL_JSONL.open("r", encoding="utf-8") as fh:
    rows = [json.loads(line) for line in fh if line.strip()]
df = pd.DataFrame(rows)

# Score per (artifact, prompt_variant)
def summarize(g):
    """Reduce one group of ablation rows to its summary metrics.

    Reads the ``is_empty``, ``generation`` and ``len_words`` columns of ``g``.

    Returns:
        A Series with ``n`` (row count), ``empty_rate``, ``sent_end_rate``
        (share of stripped generations ending in ``.``, ``!``, ``?`` or ``…``),
        ``avg_len`` (mean word count, rounded to 3 decimals) and ``med_len``
        (median word count). Rates use ``max(1, n)`` as the denominator.
    """
    total = len(g)
    denom = max(1, total)
    sentence_enders = (".", "!", "?", "…")

    empties = g["is_empty"].astype(int).sum()
    cleaned = g["generation"].fillna("").str.strip()
    endings = cleaned.str.endswith(sentence_enders).astype(int).sum()
    words = g["len_words"]

    stats = {
        "n": total,
        "empty_rate": empties / denom,
        "sent_end_rate": endings / denom,
        "avg_len": round(float(words.mean()), 3),
        "med_len": float(words.median()),
    }
    return pd.Series(stats)

# One summary row per (artifact, prompt_variant). Rows from all model_ids and
# both budgets that share those two labels are pooled into the same group.
agg = df.groupby(["artifact","prompt_variant"], as_index=False, group_keys=False).apply(summarize)

# Pick winner by heuristic:
# 1) lowest empty_rate, 2) highest sent_end_rate, 3) highest avg_len
# `winner` is the top-ranked row as a plain dict.
winner = (agg.sort_values(["empty_rate","sent_end_rate","avg_len"], ascending=[True,False,False])
            .iloc[0].to_dict())

# Build a human-friendly table
def pct(x):
    """Format a fraction as a percentage string with one decimal place."""
    return "{:.1f}%".format(x * 100)
# Display copy of `agg` with both rate columns rendered as percent strings;
# `agg` itself keeps the numeric values.
table = agg.assign(
    empty_rate=agg["empty_rate"].map(pct),
    sent_end_rate=agg["sent_end_rate"].map(pct),
)

# Draft a short markdown report: title, summary table, chosen policy, samples.
ts = time.strftime("%Y-%m-%d %H:%M:%SZ", time.gmtime())
lines = []
lines += [f"# Learning Ablation Report  \n_{ts}_\n"]
lines += ["## Summary by artifact × prompt_variant"]
# The two label columns are left-aligned, the five numeric columns right-aligned.
lines += ["\n| artifact | prompt_variant | n | empty_rate | sent_end_rate | avg_len | med_len |",
          "|---|---|---:|---:|---:|---:|---:|"]
for _, r in table.iterrows():
    # med_len is printed with :g because a median over an even number of rows
    # can end in .5; int() would truncate it (and raise on NaN).
    lines += [f"| {r['artifact']} | {r['prompt_variant']} | {int(r['n'])} | {r['empty_rate']} | {r['sent_end_rate']} | {r['avg_len']} | {r['med_len']:g} |"]

lines += ["\n## Chosen policy"]
lines += [f"- **artifact**: `{winner['artifact']}`",
          f"- **prompt_variant**: `{winner['prompt_variant']}`",
          "- Rationale: minimize empty outputs, then prefer clean sentence endings and adequate length."]

# Add a tiny sample grid (first long-budget row per prompt for the winner)
win_mask = (df["artifact"]==winner["artifact"]) & (df["prompt_variant"]==winner["prompt_variant"]) & (df["budget"]=="long")
sample = df[win_mask].groupby("prompt").head(1)
lines += ["\n## Sample outputs (winner policy)"]
for _, r in sample.iterrows():
    # one line per sample: newlines shown as ⏎, capped at 160 characters
    gen = textwrap.shorten(str(r["generation"]).replace("\n"," ⏎ "), width=160, placeholder="…")
    lines += [f"- **{r['prompt']}** → {gen}"]

REPORT_MD.write_text("\n".join(lines), encoding="utf-8")
print(f"Wrote {REPORT_MD}")

# Save a reusable generation policy (Step 10 can read this later)
# Encode few-shot template explicitly so you can tweak the shots later.

# Step 12 labels the adapter artifact "base+adapter", while the Step 14 script
# matches preferences against "quantized" / "fused" / "adapter". Translate the
# winner so an adapter win is not silently skipped, then drop duplicate
# fallbacks while keeping order.
_policy_label = {"base+adapter": "adapter"}.get(winner["artifact"], winner["artifact"])
_artifact_preference = list(dict.fromkeys([_policy_label, "fused", "adapter"]))

# NOTE(review): Step 12 evaluated the few-shot prefix "Proverbs:\n- " and the
# directive suffix "\n\nAnswer with a single important thought:". The template
# strings below differ from both — confirm which wording should be locked in.
POLICY = {
    "created_utc": ts,
    "artifact_preference": _artifact_preference,  # winner first, then fallbacks
    "prompt_policy": {
        "name": winner["prompt_variant"],
        "fewshot": {
            "shots": [
                "The moon does not race the tide.",
                "A river carves stone by lingering."
            ],
            "prefix": "some ideas:\n- ",
            "joiner": "\n- ",
            "suffix": "\n\n{prompt}\n- "
        },
        "directive": {
            "suffix": "\n\nAnswer with a single saying:"
        }
    }
}
POLICY_JS.write_text(json.dumps(POLICY, indent=2), encoding="utf-8")
print(f"Wrote {POLICY_JS}")

# Console preview (winner keeps the Step 12 label; agg keeps numeric rates)
print("\n=== WINNER ===")
print(f"artifact={winner['artifact']}  prompt_variant={winner['prompt_variant']}")
print("\n=== TABLE ===")
print(agg.to_string(index=False))

In [None]:
# STEP 14 — Repro Command Builder (exports env vars; brace-safe heredoc)
# Writes: <eval.output_dir>/<eval.recreate> (executable shell script) that runs train→fuse→(optional) quantize→generate

from __future__ import annotations
import os, csv, json, shlex, stat
from pathlib import Path

# ---- Paths (all resolved from the layered config) ----
cfg = load_config()
OUT_DIR       = Path(cfg.data.output_dir); OUT_DIR.mkdir(parents=True, exist_ok=True)
EVAL_DIR      = Path(cfg.eval.output_dir); EVAL_DIR.mkdir(parents=True, exist_ok=True)
RUN_DIR       = Path(cfg.run.output_dir)  # where per-model outputs will go
# Defined here because the row loader below reads it; without this the cell
# only works when an earlier cell has left EXPERIMENTS in the namespace.
EXPERIMENTS   = RUN_DIR / cfg.run.experiments
ARTIFACTS     = RUN_DIR / cfg.run.artifacts

GEN_JSONL     = EVAL_DIR / (cfg.paths.generations + ".jsonl" )
GEN_CSV       = EVAL_DIR / (cfg.paths.generations + ".csv")
OUT_SUM       = EVAL_DIR / (cfg.paths.summary + ".csv")
OUT_JSON      = EVAL_DIR / (cfg.paths.analysis + ".json")
REPRO_SH      = EVAL_DIR / cfg.eval.recreate    # a shell script to run it all again
# ---- knobs ----
ONLY_ROW        = 0            # 0-based index of the experiments row to reproduce
DO_QUANTIZE     = True         # False => the convert step is replaced by a skip notice
Q_BITS          = 4            # passed to --q-bits
Q_GROUP_SIZE    = 64           # passed to --q-group-size
DTYPE           = "bfloat16"   # passed to --dtype
MAX_NEW_TOKENS  = 128          # exported as MAX_NEW_TOKENS for the generate step
PROMPTS         = [
    "Share an important saying about time.",
    "Offer a short proverb on patience.",
    "Give a hopeful saying for travelers.",
]
# ---------------

def die(msg):
    """Abort the cell by raising SystemExit with a ``[step14]``-tagged message."""
    raise SystemExit("[step14] {}".format(msg))

# Load chosen experiment row
if not EXPERIMENTS.exists():
    die("Missing experiments.csv (run Step 5.5 then Step 6).")

# Read inside a `with` so the handle is closed; newline="" is what the csv
# module expects so quoted fields containing newlines are parsed correctly.
with EXPERIMENTS.open("r", encoding="utf-8", newline="") as fh:
    rows = list(csv.DictReader(fh))
if not rows: die("experiments.csv has no rows.")
if not (0 <= ONLY_ROW < len(rows)): die(f"ONLY_ROW {ONLY_ROW} out of range (0..{len(rows)-1}).")

row = rows[ONLY_ROW]
# Every field below must be present and non-empty in the selected row.
need = ["model_id","data_dir","iters","batch_size","learning_rate","max_seq_length","adapter_path"]
miss = [k for k in need if not row.get(k)]
if miss: die(f"Row missing fields: {miss}")

# Normalize fields (int(float(...)) accepts values stored as "200.0")
model_id = str(row["model_id"])
data_dir = str(row["data_dir"])
iters    = int(float(row["iters"]))
bs       = int(float(row["batch_size"]))
lr       = float(row["learning_rate"])
maxlen   = int(float(row["max_seq_length"]))
adapter  = str(row["adapter_path"])

# Fused and quantized outputs live two levels above the adapter path.
run_root  = str(Path(adapter).parent.parent)   # runs/<model_tag>
fused_dir = str(Path(run_root) / "fused" / "model")
quant_dir = str(Path(run_root) / "quantized")

# Build the script (no f-strings inside heredoc; use env vars)
lines = []

def L(s=""):
    """Append one line of shell text to the script buffer."""
    lines.append(s)

for text in (
    "#!/usr/bin/env bash",
    "set -euo pipefail",
    "export TOKENIZERS_PARALLELISM=false",
    "",
    # Export parameters as env vars (so heredoc Python can read them)
    "# === Export experiment parameters ===",
):
    L(text)

# Path-like values are shell-quoted; the token budget is a bare integer.
for var, value in (
    ("MODEL_ID", model_id),
    ("DATA_DIR", data_dir),
    ("ADAPTER", adapter),
    ("FUSED_DIR", fused_dir),
    ("QUANT_DIR", quant_dir),
):
    L(f"export {var}={shlex.quote(value)}")
L(f"export MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
L(f"export PROMPTS_JSON={shlex.quote(json.dumps(PROMPTS, ensure_ascii=False))}")
L()

# 1) TRAIN
for text in (
    'echo "== 1) TRAIN (LoRA) =="',
    'python -m mlx_lm lora \\',
    '  --model "$MODEL_ID" \\',
    '  --data "$DATA_DIR" \\',
    '  --train --fine-tune-type lora \\',
    f'  --batch-size {bs} \\',
    f'  --iters {iters} \\',
    f'  --learning-rate {lr} \\',
    f'  --max-seq-length {maxlen} \\',
    '  --adapter-path "$ADAPTER" \\',
    '  --val-batches 1 \\',
    '  --steps-per-report 10 \\',
    '  --steps-per-eval 50',
    '',
):
    L(text)

# 2) FUSE
for text in (
    'echo "== 2) FUSE (adapter -> fused model) =="',
    'python -m mlx_lm fuse \\',
    '  --model "$MODEL_ID" \\',
    '  --adapter-path "$ADAPTER" \\',
    '  --save-path "$FUSED_DIR"',
    '',
):
    L(text)

# 3) CONVERT (optional): any previous quantized output is removed first
if DO_QUANTIZE:
    for text in (
        f'echo "== 3) CONVERT (fused -> MLX q{Q_BITS}, group={Q_GROUP_SIZE}) =="',
        'rm -rf "$QUANT_DIR"',
        'python -m mlx_lm convert \\',
        '  --hf-path "$FUSED_DIR" \\',
        '  --mlx-path "$QUANT_DIR" \\',
        f'  --q-bits {Q_BITS} \\',
        f'  --q-group-size {Q_GROUP_SIZE} \\',
        f'  --dtype {shlex.quote(DTYPE)} \\',
        '  -q',
    ):
        L(text)
else:
    L('echo "== 3) (SKIP) Quantization disabled =="')
L()

# 4) GENERATE — brace-safe heredoc; reads env vars inside Python
# Step 13 writes the policy to EVAL_DIR / cfg.eval.policy, so export that
# location; the embedded Python falls back to ./generation_policy.json.
L(f'export POLICY_PATH={shlex.quote(str(EVAL_DIR / cfg.eval.policy))}')
heredoc = r"""echo "== 4) GENERATE (policy-locked) =="
python - <<'PY'
import json, os, sys
from pathlib import Path
from mlx_lm import load as mlx_load, generate as mlx_generate

# Read env vars
MODEL_ID       = os.environ["MODEL_ID"]
DATA_DIR       = os.environ["DATA_DIR"]
ADAPTER        = os.environ["ADAPTER"]
FUSED_DIR      = os.environ["FUSED_DIR"]
QUANT_DIR      = os.environ["QUANT_DIR"]
MAX_NEW_TOKENS = int(os.environ["MAX_NEW_TOKENS"])
PROMPTS        = json.loads(os.environ["PROMPTS_JSON"])

# Policy: first existing file among $POLICY_PATH and ./generation_policy.json
policy = {}
policy_candidates = [Path(p) for p in (os.environ.get("POLICY_PATH", ""), "generation_policy.json") if p]
policy_path = next((p for p in policy_candidates if p.is_file()), None)
if policy_path is not None:
    try:
        policy = json.loads(policy_path.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"[warn] failed to parse {policy_path}:", e, file=sys.stderr)

if "artifact_preference" not in policy:
    policy["artifact_preference"] = ["quantized","fused","adapter"]
if "prompt_policy" not in policy:
    policy["prompt_policy"] = {
        "name": "fewshot",
        "fewshot": {
            "shots": [
                "The moon does not race the tide.",
                "A river carves stone by lingering."
            ],
            "prefix": "Proverbs:\n- ",
            "joiner": "\n- ",
            "suffix": "\n\n{prompt}\n- "
        },
        "directive": { "suffix": "\n\nAnswer with a single saying:" }
    }

# Choose artifact by policy + availability
# The ablation step labels the adapter artifact "base+adapter"; accept that spelling too.
pref = ["adapter" if w == "base+adapter" else w
        for w in policy.get("artifact_preference", ["quantized","fused","adapter"])]
candidates = []
if os.path.isdir(QUANT_DIR): candidates.append(("quantized", QUANT_DIR, None))
if os.path.isdir(FUSED_DIR): candidates.append(("fused", FUSED_DIR, None))
candidates.append(("adapter", MODEL_ID, ADAPTER))

choice = None
for want in pref:
    for label, m, a in candidates:
        if want == label:
            choice = (label, m, a); break
    if choice: break
if choice is None:
    choice = candidates[0]

label, model_path, adapter_path = choice
print(f"[info] using artifact: {label} -> model_path={model_path} adapter={adapter_path or ''}", file=sys.stderr)

def format_prompt(pol, p):
    name = pol["prompt_policy"]["name"]
    if name == "fewshot":
        fp = pol["prompt_policy"]["fewshot"]
        return fp["prefix"] + fp["joiner"].join(fp["shots"]) + fp["suffix"].replace("{prompt}", p)
    elif name == "directive":
        return p + pol["prompt_policy"]["directive"]["suffix"]
    return p

model, tok = mlx_load(model_path, adapter_path=adapter_path)
for p in PROMPTS:
    fp  = format_prompt(policy, p)
    out = mlx_generate(model=model, tokenizer=tok, prompt=fp, max_tokens=MAX_NEW_TOKENS)
    gen = out[len(fp):] if out.startswith(fp) else out
    print(f"\n[prompt] {p}\n→ {gen.strip()}")
PY
"""
L(heredoc)

# Write file + chmod + preview
REPRO_SH.write_text("\n".join(lines), encoding="utf-8")

# Add the execute bit for user, group and other on top of the current mode.
exec_bits = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
os.chmod(REPRO_SH, os.stat(REPRO_SH).st_mode | exec_bits)

print(f"Wrote {REPRO_SH} ({REPRO_SH.stat().st_size} bytes)")
# The previews slice the buffer by entry, and the heredoc is a single entry.
for heading, excerpt in (("Preview (head):", lines[:20]), ("Preview (tail):", lines[-20:])):
    print(heading)
    print("\n".join(excerpt))

In [None]:
# STEP 15 — Cleanup & Freeze
# - Non-destructive by default.
# - Optionally prune bulky files (old adapter checkpoints, logs).
# - Creates a reproducible archive with:
#     * run_manifest.(yaml|json), requirements.lock
#     * data_contract.json, data_catalog.json, data_report.json
#     * experiments.csv, artifacts.json, generation_policy.json (if present)
#     * eval_out/ (generations, reports)
#     * runs/<model>/adapter (final) + fused/ + quantized/ (unless PRUNE_* set)
# - Writes: dist/<project>-bundle-YYYYmmdd-HHMMSS.tar.gz + SHA256

from __future__ import annotations
import os, json, time, hashlib, shutil
from pathlib import Path
from typing import List, Dict, Any, Tuple
from config_loader import load_config
cfg = load_config()
out_dir = Path("data")
out_dir.mkdir(exist_ok=True)
RUN_DIR     = Path(cfg.run.output_dir)  # where per-model outputs will go
EXPERIMENTS = RUN_DIR / cfg.run.experiments
ARTIFACTS   = RUN_DIR / cfg.run.artifacts


# ---------- Controls (edit as needed) ----------
# NOTE(review): the three directories below are fixed relative paths, not the
# config-derived locations above — confirm they match the configured layout.
PROJECT_NAME = "mlx training"
DIST_DIR     = Path("dist")
RUNS_DIR     = Path("runs")
EVAL_DIR     = Path("eval_out")
DATA_FILES   = ["data_contract.json", "data_catalog.json", "data_report.json"]

# Pruning (safe defaults: nothing is removed unless a flag is flipped)
PRUNE_OLD_ADAPTER_CHECKPOINTS = False   # True => remove numbered LoRA adapter snapshots, keep final adapters.safetensors
PRUNE_LOGS                     = False   # True => remove runs/*/logs contents (keeps folder)
KEEP_FUSED                     = True    # False => omit fused model from bundle
KEEP_QUANTIZED                 = True    # False => omit quantized model from bundle

# Archiving
INCLUDE_TRAIN_DATA = False              # True => copy data/train+valid JSONL into bundle (usually keep False)
ARCHIVE_TIMESTAMP  = time.strftime("%Y%m%d-%H%M%S", time.gmtime())
ARCHIVE_NAME       = PROJECT_NAME + "-bundle-" + ARCHIVE_TIMESTAMP
# ----------------------------------------------

# Start from an empty staging folder named after the archive.
DIST_DIR.mkdir(parents=True, exist_ok=True)
STAGING = DIST_DIR / ARCHIVE_NAME
if STAGING.exists():
    shutil.rmtree(STAGING)
STAGING.mkdir(parents=True, exist_ok=True)

def sha256_file(p: Path) -> str:
    """Return the hex SHA-256 digest of the file at ``p``, read in 1 MiB chunks."""
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()
    with p.open("rb") as stream:
        while chunk := stream.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

def dir_size(p: Path) -> int:
    """Return the total size in bytes of all files under ``p``, recursively.

    Files that cannot be stat'ed are skipped; a path that ``os.walk`` yields
    nothing for gives 0.
    """
    total = 0
    for root, _dirs, filenames in os.walk(p):
        base = Path(root)
        for filename in filenames:
            try:
                size = (base / filename).stat().st_size
            except Exception:
                # best-effort: an unreadable entry does not abort the report
                continue
            total += size
    return total

def safe_copy(src: Path, dst: Path):
    """Copy a file or a directory tree from ``src`` to ``dst``.

    The parent of ``dst`` is always created. A file is copied with its
    metadata; a directory replaces any existing ``dst`` rather than merging
    into it. A ``src`` that is neither a file nor a directory is ignored.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    if src.is_file():
        shutil.copy2(src, dst)
        return
    if not src.is_dir():
        return
    if dst.exists():
        shutil.rmtree(dst)
    shutil.copytree(src, dst, dirs_exist_ok=True)

# 1) Collect top-level files
top_level = [
    "run_manifest.yaml", "run_manifest.json", "requirements.lock",
    "experiments.csv", "artifacts.json", "generation_policy.json"
]
# Also bundle the config-derived registry files: the earlier steps read the
# experiments table and artifact registry from RUN_DIR, which need not be the
# working directory. They come last so they win when a name exists in both places.
sources = [Path(name) for name in top_level + DATA_FILES] + [EXPERIMENTS, ARTIFACTS]
for p in sources:
    if p.exists():
        # files are flattened into the staging root under their base name
        safe_copy(p, STAGING / p.name)

# 2) Optionally include training data directory per contract
contract_path = Path("data_contract.json")
if INCLUDE_TRAIN_DATA and contract_path.exists():
    try:
        contract = json.loads(contract_path.read_text(encoding="utf-8"))
        data_dir = Path(contract["data_dir"])
        if data_dir.exists():
            # copy only train/valid/val jsonl, whichever of them are present
            for fn in ("train.jsonl", "valid.jsonl", "val.jsonl"):
                candidate = data_dir / fn
                if candidate.exists():
                    safe_copy(candidate, STAGING / "data" / fn)
    except Exception as e:
        # best-effort: a bad contract must not stop the rest of the freeze
        print("[warn] include data failed:", e)

# 3) Prepare model artifacts: runs/
bundle_runs_root = STAGING / "runs"
model_dirs = sorted(RUNS_DIR.iterdir()) if RUNS_DIR.exists() else []
for model_dir in model_dirs:
    if not model_dir.is_dir():
        continue
    out_model_dir = bundle_runs_root / model_dir.name
    out_model_dir.mkdir(parents=True, exist_ok=True)

    # Adapter
    adapter_dir = model_dir / "adapter"
    if adapter_dir.exists():
        # optional pruning: remove numbered checkpoints (e.g., 0000100_adapters.safetensors)
        # NOTE: this deletes from the source run directory, not only from the bundle
        if PRUNE_OLD_ADAPTER_CHECKPOINTS:
            for ck in adapter_dir.glob("*_adapters.safetensors"):
                try:
                    ck.unlink()
                except Exception:
                    pass
        safe_copy(adapter_dir, out_model_dir / "adapter")

    # Logs
    logs_dir = model_dir / "logs"
    if logs_dir.exists():
        if PRUNE_LOGS:
            # keep folder name but clear contents (in the source run directory)
            try:
                shutil.rmtree(logs_dir)
                logs_dir.mkdir(parents=True, exist_ok=True)
            except Exception:
                pass
        # copy (possibly empty)
        safe_copy(logs_dir, out_model_dir / "logs")

    # Fused model
    fused_dir = model_dir / "fused" / "model"
    if fused_dir.exists() and KEEP_FUSED:
        safe_copy(fused_dir, out_model_dir / "fused" / "model")

    # Quantized model
    quant_dir = model_dir / "quantized"
    if quant_dir.exists() and KEEP_QUANTIZED:
        safe_copy(quant_dir, out_model_dir / "quantized")

    # Preserve convenience symlinks as real dirs/files
    for linkname in ("latest_adapter", "latest_logs"):
        link = model_dir / linkname
        if not link.exists():
            continue
        try:
            # resolve target and copy into staging under same name
            target = link.resolve()
            if target.is_dir() or target.is_file():
                safe_copy(target, out_model_dir / linkname)
        except Exception:
            pass

# 4) Evaluation outputs (staged under the fixed name "eval_out")
if EVAL_DIR.exists():
    safe_copy(EVAL_DIR, STAGING / "eval_out")

# 5) Size report before archiving: total first, then one line per top-level entry
staging_bytes = dir_size(STAGING)
print("=== FREEZE PREVIEW ===")
print(f"- Staging folder: {STAGING}  size={staging_bytes/1e6:.2f} MB")
print("  Includes:")
for p in sorted(STAGING.iterdir()):
    if p.is_dir():
        size_mb = dir_size(p) / 1e6
    else:
        size_mb = p.stat().st_size / 1e6
    print(f"   • {p.name:<22} {size_mb:>7.2f} MB")

# 6) Create tar.gz (only the staging folder is archived, rooted at its own name)
archive_path = shutil.make_archive(str(DIST_DIR / ARCHIVE_NAME), "gztar", root_dir=DIST_DIR, base_dir=ARCHIVE_NAME)
archive_file = Path(archive_path)

# 7) Hash the archive
archive_sha = sha256_file(archive_file)

# 8) Write manifest for the bundle
bundle_manifest = {
    "bundle_name": archive_file.name,
    "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "project": PROJECT_NAME,
    "paths": {
        "staging": str(STAGING.resolve()),
        "archive": str(archive_file.resolve())
    },
    "sha256": archive_sha,
    "prune": {
        "old_adapter_checkpoints_removed": bool(PRUNE_OLD_ADAPTER_CHECKPOINTS),
        "logs_pruned": bool(PRUNE_LOGS),
        "keep_fused": bool(KEEP_FUSED),
        "keep_quantized": bool(KEEP_QUANTIZED),
        "include_train_data": bool(INCLUDE_TRAIN_DATA),
    }
}
# Append ".sha256" to the full archive name. with_suffix() would replace only
# the final ".gz" and produce "<bundle>.tar.sha256", which no longer names the archive.
sha_path = archive_file.with_name(archive_file.name + ".sha256")
sha_path.write_text(archive_sha + "\n", encoding="utf-8")
(DIST_DIR / f"{ARCHIVE_NAME}.manifest.json").write_text(json.dumps(bundle_manifest, indent=2), encoding="utf-8")

print("\n=== FREEZE COMPLETE ===")
print(f"Bundle: {archive_path}")
print(f"SHA256: {archive_sha}")
print(f"Manifest: {DIST_DIR / (ARCHIVE_NAME + '.manifest.json')}")
print("Tip: verify with  shasum -a 256 <archive>  and compare.")