In [88]:
# Minimal M8: load → score → sweep thresholds → pick → inspect
# Paste this whole cell into a Jupyter notebook.

import json, csv, os, math, joblib, numpy as np, pandas as pd
from pathlib import Path
from typing import List, Dict, Any
import os
def resolve_project_root(path: str) -> str:
    """Return the project root for a notebook working directory.

    The notebook is presumably run from ``<root>/notebooks`` while the paths it
    uses (``.artifacts/...``, ``tests/fixtures/...``) are relative to
    ``<root>``. When the last component of *path* is exactly ``notebooks``,
    its parent is returned; any other path is returned unchanged
    (normalized by ``pathlib``).
    """
    p = Path(path)
    return str(p.parent) if p.name == "notebooks" else str(p)


# Re-anchor relative paths at the project root. Unlike a raw "/notebooks"
# substring replace, this only strips a trailing "notebooks" directory (so
# "/x/notebooks_old/y" is left alone) and also works with Windows separators.
cwd = resolve_project_root(os.getcwd())
os.chdir(cwd)

In [79]:
# ---------------- I/O ----------------

def read_prompts_jsonl(fp: str) -> List[str]:
    """Read prompts from a JSONL file, one prompt per non-blank line.

    Each line is decoded as JSON:

    * a JSON object contributes its ``"prompt"`` value (stringified and
      stripped); if that value is missing or falsy, the raw stripped line is
      used instead so the row is not silently dropped;
    * a bare JSON string contributes the decoded text (without its quotes);
    * any other JSON value falls back to the raw stripped line.

    Lines that are not valid JSON are treated as plain-text prompts. Lines that
    are empty after stripping are skipped.

    Args:
        fp: Path to the JSONL file (read as UTF-8; a leading BOM is ignored).

    Returns:
        Prompts in file order.
    """
    rows: List[str] = []
    with open(fp, "r", encoding="utf-8-sig") as f:
        for line in f:
            raw = line.strip()
            try:
                J = json.loads(line)
            except json.JSONDecodeError:
                # Not JSON (including blank lines): keep the line as plain text.
                if raw:
                    rows.append(raw)
                continue
            if isinstance(J, dict):
                p = J.get("prompt") or raw
            elif isinstance(J, str):
                # Use the decoded string, not the quoted source line, so it can
                # match prompts in the labels CSV.
                p = J
            else:
                p = raw
            p = str(p).strip()
            if p:
                rows.append(p)
    return rows

def read_labels_csv(fp: str) -> Dict[str, str]:
    """Load a ``prompt -> gold label`` mapping from a CSV file.

    The file is expected to have ``prompt`` and ``label`` columns. Cells are
    read as literal text and stripped, and rows with an empty prompt are
    skipped. A later row with the same prompt overwrites an earlier one. A
    missing label becomes ``""``, never the string ``"nan"``.

    pandas is tried first. If it fails for any reason, the file is re-read
    with the stdlib ``csv`` module as a best-effort fallback.

    Args:
        fp: Path to the CSV. A falsy value means "no labels".

    Returns:
        Mapping from stripped prompt text to stripped label; ``{}`` when ``fp``
        is falsy.
    """
    if not fp:
        return {}
    try:
        # dtype=str + keep_default_na=False keep every cell as literal text,
        # so empty cells become "" and prompts such as "NA"/"null" are not
        # converted to NaN (which str() would turn into "nan").
        df = pd.read_csv(fp, dtype=str, keep_default_na=False)
        gold: Dict[str, str] = {}
        for prompt, label in zip(df["prompt"], df["label"]):
            p = str(prompt).strip()
            if p:
                gold[p] = str(label).strip()
        return gold
    except Exception:
        # Deliberate best-effort fallback. Start from a fresh dict so nothing
        # from a partially-read pandas attempt leaks in.
        gold = {}
        with open(fp, newline="", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                # DictReader fills short rows with None; treat that as "".
                p = str(row.get("prompt") or "").strip()
                y = str(row.get("label") or "").strip()
                if p:
                    gold[p] = y
        return gold

# ------------- Scoring helpers -------------

def predict_scores(mapper, prompts: List[str], class_names: List[str]) -> List[Dict[str,float]]:
    """Score each prompt against every class.

    The first scoring method the mapper exposes is used:

    1. ``predict_proba`` — probabilities used as-is;
    2. ``decision_function`` — converted with a row-wise softmax. A 1-D output
       from a two-class model (one margin per sample, positive meaning
       ``classes_[1]`` as in scikit-learn) is expanded to ``[0, margin]`` so
       the softmax equals the logistic sigmoid of the margin;
    3. ``predict`` — one-hot scores (1.0 for the predicted class).

    Class labels come from ``mapper.classes_`` when present, else from
    ``class_names``. Every name in ``class_names`` is guaranteed to be a key
    (0.0 if the model never produced it).

    Args:
        mapper: Fitted model exposing at least ``predict``.
        prompts: Texts to score.
        class_names: Classes that must appear in every score dict.

    Returns:
        One ``{class_name: score}`` dict per prompt, in input order.
    """
    if not prompts:
        # Nothing to score; many estimators also reject zero-sample input.
        return []

    def _to_maps(matrix, classes) -> List[Dict[str, float]]:
        # Pair each row's values with class labels and fill in missing classes.
        out: List[Dict[str, float]] = []
        for row in matrix:
            row_map = {str(c): float(p) for c, p in zip(classes, row)}
            for cname in class_names:
                row_map.setdefault(cname, 0.0)
            out.append(row_map)
        return out

    if hasattr(mapper, "predict_proba"):
        probs = mapper.predict_proba(prompts)
        classes = list(getattr(mapper, "classes_", class_names))
        return _to_maps(probs, classes)

    if hasattr(mapper, "decision_function"):
        logits = np.asarray(mapper.decision_function(prompts), dtype=float)
        classes = list(getattr(mapper, "classes_", class_names))
        if logits.ndim == 1:
            if len(classes) == 2:
                # Binary margin: softmax([0, m]) == [1 - sigmoid(m), sigmoid(m)].
                logits = np.column_stack([np.zeros_like(logits), logits])
            else:
                logits = logits.reshape(-1, 1)
        # Numerically stable row-wise softmax.
        ex = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs = ex / (ex.sum(axis=1, keepdims=True) + 1e-12)
        return _to_maps(probs, classes)

    # Fallback: hard predictions only.
    scores: List[Dict[str, float]] = []
    for y in mapper.predict(prompts):
        row_map = {c: 0.0 for c in class_names}
        row_map[str(y)] = 1.0
        scores.append(row_map)
    return scores

def metrics_for_threshold(prompts: List[str],
                          scores: List[Dict[str,float]],
                          class_names: List[str],
                          thr: float,
                          gold: Dict[str,str]) -> Dict[str,Any]:
    """Summarise abstention and accuracy at one confidence threshold.

    For each (prompt, scores) pair, the top-scoring class (earliest in
    ``class_names`` on ties) "fires" when its score is ``>= thr``; otherwise
    the prompt abstains. A fired prediction is correct when it equals
    ``gold[prompt]``.

    Returns:
        dict with ``threshold``, ``total``, ``abstain``, ``abstain_rate``,
        ``coverage``, ``fired``, ``correct_on_fired``, ``accuracy_on_fired``,
        ``overall_correct`` and ``overall_accuracy``. Accuracy fields are
        ``None`` when undefined: ``accuracy_on_fired`` when nothing fired, and
        both accuracies when no gold labels were supplied (instead of a
        misleading 0.0).
    """
    total = len(prompts)
    fired = abstain = correct = 0
    for p, smap in zip(prompts, scores):
        # max() returns the first maximal element, so ties favour earlier classes.
        top = max(class_names, key=lambda c: smap.get(c, 0.0))
        if smap.get(top, 0.0) >= thr:
            fired += 1
            if gold and gold.get(p) == top:
                correct += 1
        else:
            abstain += 1

    has_gold = bool(gold)
    coverage = fired / max(1, total)
    abstain_rate = abstain / max(1, total)
    acc_on_fired = (correct / fired) if (has_gold and fired) else None
    overall_acc = (correct / max(1, total)) if has_gold else None
    return dict(
        threshold=thr, total=total,
        abstain=abstain, abstain_rate=abstain_rate,
        coverage=coverage, fired=fired,
        correct_on_fired=correct,
        accuracy_on_fired=acc_on_fired,
        # Abstained prompts can never be correct, so overall == on-fired count.
        overall_correct=correct, overall_accuracy=overall_acc
    )

def choose_operating_point(metrics: List[Dict[str,Any]],
                           max_abstain_rate: float = 0.10,
                           choose_by: str = "abstain_then_acc") -> Dict[str,Any]:
    """Pick one threshold's metrics dict as the operating point.

    1. If any entry has ``abstain_rate <= max_abstain_rate``, return the
       admissible entry with the highest ``threshold``, ties broken by higher
       ``accuracy_on_fired``. ``choose_by`` is not consulted in this case.
    2. Otherwise, if ``choose_by == "utility"``, maximise the harmonic mean of
       ``accuracy_on_fired`` and ``coverage``.
    3. Otherwise (any other ``choose_by``), minimise ``abstain_rate``, ties
       broken by higher ``accuracy_on_fired``.

    Missing or ``None`` accuracies count as 0.0. On full ties the earliest
    entry in ``metrics`` wins. ``metrics`` itself is not modified.

    Raises:
        ValueError: if ``metrics`` is empty.
    """
    if not metrics:
        raise ValueError("choose_operating_point: no metrics to choose from")

    def acc(m: Dict[str, Any]) -> float:
        return m.get("accuracy_on_fired") or 0.0

    admissible = [m for m in metrics if m["abstain_rate"] <= max_abstain_rate]
    if admissible:
        # Prefer higher thresholds; break ties by accuracy_on_fired.
        return max(admissible, key=lambda m: (m["threshold"], acc(m)))

    if choose_by == "utility":
        def util(m: Dict[str, Any]) -> float:
            a = acc(m)
            c = m["coverage"]
            return 0.0 if a == 0 or c == 0 else (2*a*c)/(a+c)
        return max(metrics, key=util)

    # Default: minimise abstain, prefer higher accuracy_on_fired. min() avoids
    # the in-place sort that used to reorder the caller's list.
    return min(metrics, key=lambda m: (m["abstain_rate"], -acc(m)))

def rows_at_threshold(prompts: List[str],
                      scores: List[Dict[str,float]],
                      class_names: List[str],
                      thr: float,
                      gold: Dict[str,str]) -> pd.DataFrame:
    """Build a per-prompt drill-down table at a single confidence threshold.

    Columns, in order: ``prompt``; ``gold_label`` ("" when the prompt has no
    label); ``predicted`` (the top class, or "" when the model abstains);
    ``confidence`` (the top-class score); ``abstain`` (True when
    ``confidence < thr``); ``threshold``.
    """
    def as_record(text: str, class_scores: Dict[str, float]) -> Dict[str, Any]:
        def score_of(label: str) -> float:
            return class_scores.get(label, 0.0)

        # Ties resolve to the earliest entry of class_names.
        best = max(class_names, key=score_of)
        confidence = float(score_of(best))
        fires = confidence >= thr
        return {
            "prompt": text,
            "gold_label": gold.get(text, ""),
            "predicted": best if fires else "",
            "confidence": confidence,
            "abstain": not fires,
            "threshold": thr,
        }

    records = [as_record(text, class_scores) for text, class_scores in zip(prompts, scores)]
    return pd.DataFrame(records)

# ------------- One-call runner for notebooks -------------

def m8_notebook_simple(
    mapper_path: str,
    prompts_jsonl: str,
    labels_csv: str = "",
    class_names: List[str] = ("deposit_asset","withdraw_asset","swap_asset","check_balance"),
    thresholds: List[float] = (0.5,0.6,0.7,0.8,0.9),
    max_abstain_rate: float = 0.10,
    min_overall_acc: float | None = 0.85,  # ignored if no labels
    choose_by: str = "abstain_then_acc",
):
    """Run the M8 gate end to end: load, score once, sweep, pick, report.

    Args:
        mapper_path: joblib file holding the fitted mapper.
        prompts_jsonl: JSONL prompts file (parsed by ``read_prompts_jsonl``).
        labels_csv: Optional CSV with ``prompt``/``label`` columns; "" means
            unlabeled.
        class_names: Classes every score dict must cover.
        thresholds: Confidence thresholds to sweep.
        max_abstain_rate: Highest acceptable abstain rate at the chosen point.
        min_overall_acc: Minimum ``overall_accuracy`` at the chosen point;
            skipped when None or when no labels were loaded.
        choose_by: Fallback selection rule for ``choose_operating_point``.

    Returns:
        dict with ``status`` ("pass"/"fail"), ``metrics`` (DataFrame, one row
        per threshold, sorted by threshold), ``rows`` (per-prompt DataFrame at
        the chosen threshold), ``chosen`` (its metrics dict) and
        ``has_labels``.

    Raises:
        ValueError: if ``thresholds`` is empty, or if no prompts were read.
            With zero prompts the abstain rate is 0.0, so an unlabeled run
            would otherwise report a vacuous "pass".
    """
    thr_values = [float(t) for t in thresholds]
    if not thr_values:
        raise ValueError("m8_notebook_simple: `thresholds` must not be empty")
    classes = list(class_names)

    # Read prompts first so an empty or mis-pointed file fails before the
    # (potentially slow) model load.
    prompts = read_prompts_jsonl(prompts_jsonl)
    if not prompts:
        raise ValueError(f"m8_notebook_simple: no prompts read from {prompts_jsonl!r}")
    # NOTE(security): joblib.load unpickles the file and can execute arbitrary
    # code; only load mapper artifacts from a trusted source.
    mapper = joblib.load(mapper_path)
    gold = read_labels_csv(labels_csv) if labels_csv else {}

    # Score once; every threshold reuses the same scores.
    scores = predict_scores(mapper, prompts, classes)

    # Sweep
    per_thr = [metrics_for_threshold(prompts, scores, classes, t, gold) for t in thr_values]
    metrics_df = pd.DataFrame(per_thr).sort_values("threshold")

    # Choose operating point
    chosen = choose_operating_point(per_thr, max_abstain_rate=max_abstain_rate, choose_by=choose_by)

    # Gate status. The accuracy check applies only when labels were loaded
    # and min_overall_acc is set.
    has_labels = bool(gold)
    pass_abstain = chosen["abstain_rate"] <= max_abstain_rate
    if not has_labels or min_overall_acc is None:
        pass_accuracy = True
    else:
        pass_accuracy = (chosen.get("overall_accuracy") or 0.0) >= float(min_overall_acc)
    status = "pass" if (pass_abstain and pass_accuracy) else "fail"

    # Per-row table at the chosen threshold (easy drilldown)
    rows_df = rows_at_threshold(prompts, scores, classes, chosen["threshold"], gold)

    # Pretty print summary
    display_cols = ["threshold","abstain_rate","coverage","accuracy_on_fired","overall_accuracy","fired","abstain","total"]
    print("[M8] chosen:", {k: chosen.get(k) for k in display_cols})
    print("[M8] status:", status)

    return dict(
        status=status,
        metrics=metrics_df,
        rows=rows_df,
        chosen=chosen,
        has_labels=has_labels,
    )

# ---------------- Example call (edit paths as needed) ----------------
# result = m8_notebook_simple(
#     mapper_path=".artifacts/defi_mapper_embed.joblib",
#     prompts_jsonl="tests/fixtures/defi/defi_mapper_5k_prompts.jsonl",
#     labels_csv="tests/fixtures/defi/defi_mapper_labeled_5k.csv",  # or "" if unlabeled
#     thresholds=[0.2,0.25,0.3,0.35,0.4],
#     max_abstain_rate=0.20,
#     min_overall_acc=0.85,
#     choose_by="utility",  # or "abstain_then_acc"
# )
# display(result["metrics"].head())
# display(result["rows"].head())
