In [1]:
from __future__ import annotations
import json, csv, sys, os, time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import joblib
import pandas as pd
import numpy as np

# --- sklearn bits
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.pipeline import make_pipeline
from sklearn.base import BaseEstimator, TransformerMixin

# --- sentence-transformers for embeddings
try:
    from sentence_transformers import SentenceTransformer
except Exception as e:
    SentenceTransformer = None
    print("[WARN] sentence-transformers not importable right now. Install it to train/encode.", file=sys.stderr)

In [2]:
# ---------------- SBERT encoder (pipeline compatible) ----------------
class SBERTEncoder(BaseEstimator, TransformerMixin):
    """Lightweight sentence-embedding transformer for sklearn pipelines.

    Maps an iterable of texts to a 2-D array of SentenceTransformer embeddings.
    The underlying model is loaded lazily on the first ``fit``/``transform``.

    Parameters
    ----------
    model_name : str
        SentenceTransformer model id or path.
    batch_size : int
        Batch size forwarded to ``SentenceTransformer.encode``.
    normalize : bool
        Forwarded as ``normalize_embeddings``.
    """
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2", batch_size: int = 64, normalize: bool = True):
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize = normalize
        # Lazily-loaded model; not a constructor parameter, so get_params() ignores it.
        self._model = None

    def _ensure_model(self):
        """Load the SentenceTransformer for ``self.model_name`` if not loaded yet (or if it changed)."""
        # Objects unpickled from older artifacts have no ``_loaded_name``; assume their
        # cached model matches ``model_name`` so loading them never forces a re-download.
        loaded_name = getattr(self, "_loaded_name", self.model_name)
        # Reload when set_params(model_name=...) changed the model after a previous load;
        # otherwise the stale cached model would silently keep being used.
        if self._model is None or loaded_name != self.model_name:
            if SentenceTransformer is None:
                raise RuntimeError("sentence-transformers is required to encode prompts.")
            self._model = SentenceTransformer(self.model_name)
            self._loaded_name = self.model_name

    def fit(self, X, y=None):
        """No-op fit (the encoder is pretrained); only ensures the model is loaded."""
        self._ensure_model()
        return self

    def transform(self, X):
        """Encode texts in ``X`` and return the embeddings as a NumPy array."""
        self._ensure_model()
        # A bare string would otherwise be iterated character by character by list().
        texts = [X] if isinstance(X, str) else list(X)
        embs = self._model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=False,
            normalize_embeddings=self.normalize,
        )
        return np.asarray(embs)

In [3]:
# ------------------------ Training ------------------------
def _pick_calibration(y: List[str], calibration_method: str, calibration_cv: int) -> Tuple[Optional[str], int]:
    """Resolve ``(method, cv)`` for CalibratedClassifierCV; ``method=None`` means skip calibration.

    With ``"auto"``:
    - isotonic is used when the rarest class has at least ``max(3, cv)`` samples;
    - sigmoid, with cv shrunk to the rarest class size, when it has at least 2;
    - otherwise calibration is skipped and a warning goes to stderr.

    An explicit ``"isotonic"``/``"sigmoid"`` is returned unchanged together with ``calibration_cv``.

    Raises:
        ValueError: for a method other than 'auto', 'isotonic' or 'sigmoid'
            (previously such values silently disabled calibration).
    """
    from collections import Counter

    valid = ("auto", "isotonic", "sigmoid")
    if calibration_method not in valid:
        raise ValueError(
            f"[train_mapper] calibration_method must be one of {valid}, got {calibration_method!r}"
        )
    if calibration_method != "auto":
        return calibration_method, calibration_cv

    m = min(Counter(y).values())  # size of the rarest class
    if m >= max(3, calibration_cv):
        return "isotonic", max(3, calibration_cv)
    if m >= 2:
        return "sigmoid", max(2, min(m, calibration_cv))
    print("[train_mapper] Not enough per-class samples for calibration; skipping.", file=sys.stderr)
    return None, calibration_cv


def train_mapper(
    labels_csv: str,
    out_path: str = ".artifacts/defi_mapper.joblib",
    sbert_model: str = "sentence-transformers/all-MiniLM-L6-v2",
    C: float = 8.0,
    max_iter: int = 2000,
    calibrate: bool = True,
    calibration_method: str = "auto",  # 'auto' | 'isotonic' | 'sigmoid'
    calibration_cv: int = 3,
) -> str:
    """Train an SBERT + LogisticRegression pipeline, optionally calibrated, and dump it with joblib.

    Args:
        labels_csv: CSV with ``prompt`` and ``label`` columns.
        out_path: destination of the joblib artifact (parent directory is created).
        sbert_model: SentenceTransformer model used by the SBERTEncoder step.
        C, max_iter: LogisticRegression hyper-parameters (class_weight="balanced").
        calibrate: wrap the classifier in CalibratedClassifierCV when possible.
        calibration_method: 'auto' | 'isotonic' | 'sigmoid' (see ``_pick_calibration``).
        calibration_cv: requested CV folds for calibration.

    Returns:
        ``out_path``.

    Raises:
        SystemExit: if required columns are missing, no usable rows remain after
            cleaning, or fewer than two distinct labels remain.
        ValueError: for an unknown ``calibration_method`` when ``calibrate`` is true.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    df = pd.read_csv(labels_csv)
    need = {"prompt", "label"}
    if not need.issubset(df.columns):
        raise SystemExit(f"[train_mapper] labels_csv must have columns {need}, got {df.columns.tolist()}")
    df = df.dropna(subset=["prompt", "label"]).copy()
    df["prompt"] = df["prompt"].astype(str).str.strip()
    df["label"] = df["label"].astype(str).str.strip()
    # Blank labels would otherwise be trained as a class of their own.
    df = df[(df["prompt"].str.len() > 0) & (df["label"].str.len() > 0)]
    if df.empty:
        raise SystemExit("[train_mapper] No non-empty prompts after cleaning.")
    n_classes = df["label"].nunique()
    if n_classes < 2:
        # Fail before the slow SBERT encoding; a classifier needs at least two classes.
        raise SystemExit(f"[train_mapper] Need at least 2 distinct labels, got {n_classes}.")

    X = df["prompt"].tolist()
    y = df["label"].tolist()

    base = LogisticRegression(max_iter=max_iter, C=C, class_weight="balanced", random_state=0)

    model = base
    if calibrate:
        method, cv = _pick_calibration(y, calibration_method, calibration_cv)
        if method is not None:
            try:
                model = CalibratedClassifierCV(estimator=base, method=method, cv=cv)  # sklearn >= 1.2
            except TypeError:
                model = CalibratedClassifierCV(base_estimator=base, method=method, cv=cv)  # older sklearn

    pipe = make_pipeline(SBERTEncoder(sbert_model), model)
    pipe.fit(X, y)
    joblib.dump(pipe, out_path)
    print(f"[train_mapper] wrote: {out_path}  (n={len(X)})")
    return out_path

In [4]:
import argparse, csv, json, os, sys, time
def get_args(argv: list[str] | None = None):
    """Build the mapper-benchmark argument parser and parse arguments.

    Args:
        argv: explicit argument list. When None (the default, which the notebook
            relies on), the hard-coded ``notebook_args`` below are parsed instead
            of ``sys.argv`` — a Jupyter kernel's own argv would confuse argparse.

    Returns:
        argparse.Namespace. ``max_iter`` is an int; ``C``, ``threshold`` and
        ``min_overall_acc`` are floats (``min_overall_acc`` may be None);
        ``calibrate`` stays a string (interpret it with ``_as_bool``).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--backend",       default="wordmap", choices=["wordmap", "sbert"], help="wordmap|sbert")
    ap.add_argument("--model_path",    default=".artifacts/defi_mapper.joblib")
    # Default now ends in .jsonl, matching the flag name and the notebook value.
    ap.add_argument("--prompts_jsonl", default="tests/fixtures/defi/defi_mapper_5k_prompts.jsonl")
    ap.add_argument("--labels_csv_pred", default="tests/fixtures/defi/defi_mapper_labeled_5k.csv")
    ap.add_argument("--train_labels_csv", default="tests/fixtures/defi/defi_mapper_labeled_large.csv")
    ap.add_argument("--thresholds",    default="0.5,0.55,0.6,0.65,0.7")
    # Numeric options are typed so callers get numbers; int()/float() on them still works.
    ap.add_argument("--max_iter",      type=int, default=2000)
    ap.add_argument("--C",             type=float, default=8.0)
    ap.add_argument("--calibrate",     default="True")
    ap.add_argument("--calibration_method", choices=["auto", "isotonic", "sigmoid"], default="auto")
    ap.add_argument("--calibration_cv", type=int, default=3)
    ap.add_argument("--sbert_model", default="sentence-transformers/all-MiniLM-L6-v2")
    ap.add_argument("--threshold", type=float, default=0.6)
    ap.add_argument("--out_path", default="defi_mapper_embed.joblib")
    ap.add_argument("--out_dir",       default="")
    ap.add_argument("--out_rows_csv", default=".artifacts/m8_rows_simple.csv")
    ap.add_argument("--min_overall_acc", type=float, default=None)

    if argv is None:
        argv = [
            "--backend", "sbert",
            "--model_path", ".artifacts/defi_mapper.joblib",
            "--prompts_jsonl", "tests/fixtures/defi/defi_mapper_5k_prompts.jsonl",
            "--labels_csv_pred",    "tests/fixtures/defi/defi_mapper_labeled_5k.csv",
            "--train_labels_csv", "tests/fixtures/defi/defi_mapper_labeled_large.csv",
            "--thresholds", "0.5,0.55,0.6,0.65,.7",
            "--max_iter", "2000",
            "--C", "8",
            "--calibrate", "True",
            "--calibration_method", "auto",
            "--calibration_cv", "3",
            "--threshold", "0.5",
            "--min_overall_acc", "0.75",
            "--sbert_model", "sentence-transformers/all-MiniLM-L6-v2",
            "--out_path", "defi_mapper_embed.joblib",
            "--out_rows_csv", ".artifacts/m8_rows_simple.csv",
            "--out_dir", ".artifacts/defi/mapper_bench",
        ]

    return ap.parse_args(argv)
    
# ------------------------ CLI (optional) ------------------------
def _as_bool(x: str) -> bool:
    return str(x).strip().lower() in {"1","true","t","yes","y"}

In [5]:
import os
from pathlib import Path

# Run from the repository root so the relative fixture/artifact paths resolve.
# Only a trailing "notebooks" directory is stripped. The old
# os.getcwd().replace("/notebooks", "") rewrote *every* occurrence, e.g.
# "/a/notebooks_old/notebooks" -> "/a_old".
_here = Path.cwd()
cwd = str(_here.parent if _here.name == "notebooks" else _here)
os.chdir(cwd)

args = get_args()

model_path = train_mapper(
    labels_csv=args.train_labels_csv,
    out_path=args.out_path,
    sbert_model=args.sbert_model,
    C=float(args.C),
    max_iter=int(args.max_iter),
    calibrate=_as_bool(args.calibrate),
    calibration_method=args.calibration_method,
    calibration_cv=args.calibration_cv,
)

  return forward_call(*args, **kwargs)


[train_mapper] wrote: defi_mapper_embed.joblib  (n=1000)


## ChatGPT fix: lexical span audit fused with the mapper

### 0) seed terms + encoder + prototypes

In [6]:
# ===== Cell 0: seed term bank (tight, unambiguous) =====
# primitive -> short seed phrases; each primitive's prototype vector is the SBERT
# centroid of its seed phrases (built by build_prototypes).
# NOTE(review): "draw", "lock"/"unlock" and "unstow" are broad or unusual for DeFi
# text — confirm they do not pull unrelated prompts toward these primitives.
TERM_BANK = dict(
    deposit_asset=["deposit", "supply", "provide"],
    withdraw_asset=["withdraw", "redeem", "unstow"],
    swap_asset=["swap", "convert", "trade", "exchange"],
    borrow_asset=["borrow", "draw"],
    repay_asset=["repay", "pay back"],
    stake_asset=["stake", "lock", "bond"],
    unstake_asset=["unstake", "unlock", "unbond"],
    claim_rewards=["claim", "harvest", "collect rewards"],
)
# Primitive names in declaration order.
PRIMS = [*TERM_BANK]


In [7]:
# ===== Cell 1: SBERT encoder wrapper =====
from sentence_transformers import SentenceTransformer
import numpy as np

class Emb:
    """Thin SentenceTransformer wrapper used by the span audit.

    ``encode`` forwards texts and flags to ``SentenceTransformer.encode``;
    ``encode_one`` embeds a single text and returns its row.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        self._m = SentenceTransformer(model_name)

    def encode(self, texts, normalize_embeddings=True, show_progress_bar=False):
        """Embed a batch of texts; both flags are passed through unchanged."""
        model = self._m
        return model.encode(
            texts,
            normalize_embeddings=normalize_embeddings,
            show_progress_bar=show_progress_bar,
        )

    def encode_one(self, text, normalize_embeddings=True):
        """Embed one text and return the first row of the batch result."""
        batch = self.encode([text], normalize_embeddings=normalize_embeddings)
        return batch[0]

emb = Emb(model_name="sentence-transformers/all-MiniLM-L6-v2")  # alternative: Emb("sentence-transformers/all-mpnet-base-v2")


In [8]:
# ===== Cell 2: build primitive -> prototype vector (centroid over seed terms) =====
def build_prototypes(term_bank: dict[str, list[str]], emb: Emb) -> dict[str, np.ndarray]:
    """Embed each primitive's seed terms and return their unit-length centroid.

    Args:
        term_bank: primitive name -> list of seed phrases.
        emb: encoder exposing ``encode(texts, normalize_embeddings=...)``.

    Returns:
        primitive name -> 1-D prototype vector. The mean of the term embeddings is
        re-normalized to unit length, so prototypes work with a plain dot product
        as well as with cosine similarity (cosine scores are unaffected).

    Raises:
        ValueError: if a primitive has no seed terms. Previously this produced a
            NaN centroid that silently never matched anything.
    """
    protos = {}
    for prim, terms in term_bank.items():
        if not terms:
            raise ValueError(f"build_prototypes: primitive {prim!r} has no seed terms")
        V = np.asarray(emb.encode(terms, normalize_embeddings=True))
        centroid = V.mean(axis=0)
        norm = float(np.linalg.norm(centroid))
        # A zero centroid (terms cancel out) is kept as-is rather than divided by 0.
        protos[prim] = centroid / norm if norm > 0 else centroid
    return protos

# primitive name -> prototype vector, one entry per TERM_BANK primitive
prototypes = build_prototypes(term_bank=TERM_BANK, emb=emb)


  return forward_call(*args, **kwargs)


### 1) spans_from_prompt (returns a dict only)

In [9]:
# ===== Cell 3: spans_from_prompt (dict-only return) =====
import re

_WORD = re.compile(r"[a-z0-9]+(?:'[a-z0-9]+)?")

def _norm_tokens(txt: str) -> list[str]:
    txt = txt.lower()
    return _WORD.findall(txt)

def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    da = float(np.linalg.norm(a) + 1e-9)
    db = float(np.linalg.norm(b) + 1e-9)
    return float(np.dot(a, b) / (da * db))

def spans_from_prompt(prompt: str,
                      prototypes: dict[str, np.ndarray],
                      emb: Emb,
                      tau_span: float = 0.55,
                      n_max: int = 5,
                      topk_per_prim: int = 3):
    """
    Returns: dict primitive -> [ {primitive, term, score, t_center, start, len}, ... ] (top-k per primitive)
    """
    toks = _norm_tokens(prompt)
    if not toks:
        return {}

    grams, meta = [], []   # meta holds (start, n, t_center)
    for n in range(1, min(n_max, len(toks)) + 1):
        for i in range(0, len(toks) - n + 1):
            s = " ".join(toks[i:i+n])
            t_center = (i + n/2.0) / max(1.0, len(toks))
            grams.append(s)
            meta.append((i, n, t_center))

    V = emb.encode(grams, normalize_embeddings=True, show_progress_bar=False)  # [M, D]

    by_prim: dict[str, list[dict]] = {k: [] for k in prototypes.keys()}
    for m, (i, n, t_center) in enumerate(meta):
        v = V[m]
        for prim, proto in prototypes.items():
            sc = max(0.0, _cosine(v, proto))
            if sc >= tau_span:
                by_prim[prim].append({
                    "primitive": prim,
                    "term": grams[m],
                    "score": round(sc, 4),
                    "t_center": round(t_center, 4),
                    "start": i,
                    "len": n,
                })

    # keep top-k per primitive by score
    for prim in list(by_prim.keys()):
        arr = sorted(by_prim[prim], key=lambda x: x["score"], reverse=True)[:topk_per_prim]
        if arr:
            by_prim[prim] = arr
        else:
            # drop empty lists to make downstream checks easy
            by_prim.pop(prim, None)

    return by_prim


### 2) audit wrapper (robust to dict-only or (dict, meta) variants)

In [10]:
# ===== Cell 4: audit_prompt_with_spans (robust wrapper) =====
def audit_prompt_with_spans(prompt: str,
                            prototypes: dict[str, np.ndarray],
                            emb: Emb,
                            tau_span: float = 0.55,
                            rel_margin: float = 0.06):
    """Summarize spans_from_prompt output into a lexical verdict.

    Returned keys:
      - best_primitive: top-scoring primitive, but only when its score is
        ``>= tau_span`` AND it beats the runner-up by at least ``rel_margin``;
        otherwise None.
      - scores: best span score per primitive in ``prototypes`` (0.0 when absent).
      - spans: the raw span map from spans_from_prompt.
      - rel_margin: observed best-minus-second-best score (not the threshold).
      - meta: extra info when spans_from_prompt returns ``(dict, meta)``, else {}.
      - params: the tau_span / rel_margin thresholds used.
    Accepts a spans_from_prompt that returns either a dict or a (dict, meta) tuple.
    """
    raw = spans_from_prompt(prompt, prototypes, emb, tau_span=tau_span)
    span_map, meta = raw if isinstance(raw, tuple) else (raw, {})

    def _top_score(prim):
        # Span lists come back sorted best-first, so the head is the top score.
        hits = span_map.get(prim, [])
        return hits[0]["score"] if hits else 0.0

    scores = {prim: _top_score(prim) for prim in prototypes.keys()}

    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    best_prim, best_sc = ranked[0]
    runners_up = ranked[1:]
    second_sc = runners_up[0][1] if runners_up else 0.0
    margin = best_sc - second_sc

    confident = best_sc >= tau_span and margin >= rel_margin
    return {
        "best_primitive": best_prim if confident else None,
        "scores": scores,
        "spans": span_map,
        "rel_margin": margin,
        "meta": meta,
        "params": {"tau_span": tau_span, "rel_margin": rel_margin},
    }


### 3) Opposite-primitive veto (rejects contradictory mapper/audit pairs)

In [11]:
# ===== Cell 5: opposite/veto helpers =====
# Mutually contradictory primitive pairs; OPPOSITE maps each side to the other.
# claim_rewards and swap_asset have no strict opposite in this simple map.
_OPPOSITE_PAIRS = (
    ("deposit_asset", "withdraw_asset"),
    ("stake_asset", "unstake_asset"),
    ("borrow_asset", "repay_asset"),
)
OPPOSITE = {
    src: dst
    for pair in _OPPOSITE_PAIRS
    for src, dst in (pair, pair[::-1])
}

def should_veto(mapper_top1: str | None, audit_best: str | None) -> bool:
    """True when the lexical audit's best primitive is the exact opposite of the mapper's top-1."""
    if mapper_top1 and audit_best:
        return OPPOSITE.get(mapper_top1) == audit_best
    return False


### 4) fuse with mapper (single call to get final decision)

In [12]:
# ===== Cell 6: fuse_decision (mapper + spans audit) =====
import joblib

def load_mapper(path=".artifacts/defi_mapper.joblib"):
    """Load the trained mapper pipeline saved by train_mapper.

    Security: joblib.load unpickles arbitrary objects, so only load artifacts you trust.
    """
    mapper = joblib.load(path)
    return mapper

def mapper_top1_label(mapper, prompt: str):
    """Return ``(label, confidence)`` for the mapper's top-1 prediction on ``prompt``.

    Models with ``predict_proba`` report the top class probability as the
    confidence. Otherwise ``predict`` is used and the confidence is 1.0.
    The label is always returned as a plain ``str``.

    Raises:
        ValueError: if the class list and the probability vector differ in length.
            Previously this mislabeled silently or raised IndexError.
    """
    if not hasattr(mapper, "predict_proba"):
        return str(mapper.predict([prompt])[0]), 1.0

    probs = np.asarray(mapper.predict_proba([prompt])[0])
    classes = getattr(mapper, "classes_", None)
    if classes is None:
        # NOTE(review): this fallback assumes the mapper was trained on exactly the
        # seed primitives, in PRIMS order. Prefer models that expose classes_.
        classes = PRIMS
    classes = list(classes)
    if len(classes) != len(probs):
        raise ValueError(
            f"mapper_top1_label: {len(classes)} class labels but {len(probs)} probabilities"
        )
    top_idx = int(np.argmax(probs))
    return str(classes[top_idx]), float(probs[top_idx])

def fuse_decision(prompt: str,
                  mapper,
                  prototypes: dict[str, np.ndarray],
                  emb: Emb,
                  conf_thr: float = 0.70,
                  tau_span: float = 0.55,
                  rel_margin: float = 0.06):
    """Combine the mapper's top-1 prediction with the lexical span audit.

    Decisions, checked in order:
      - "abstain_non_exec" (reason "low_conf_mapper") when the mapper's
        confidence is not >= conf_thr;
      - "reject" with "tautology_veto:<mapper>_vs_<audit>" when the audit's best
        primitive is the OPPOSITE of the mapper's;
      - "reject" with "audit_mismatch:<mapper>_vs_<audit>" when the audit names
        any other primitive than the mapper;
      - "approve" otherwise (no "reason" key).
    Every result carries the prompt, the mapper's top/conf and the full audit.
    """
    m_top, m_conf = mapper_top1_label(mapper, prompt)
    audit = audit_prompt_with_spans(prompt, prototypes, emb, tau_span=tau_span, rel_margin=rel_margin)
    a_best = audit["best_primitive"]

    def _verdict(decision, reason=None):
        result = {"prompt": prompt, "decision": decision}
        if reason is not None:
            result["reason"] = reason
        result["mapper"] = {"top": m_top, "conf": m_conf}
        result["audit"] = audit
        return result

    # Written as "not >=" so a NaN confidence still abstains.
    if not (m_conf >= conf_thr):
        return _verdict("abstain_non_exec", "low_conf_mapper")
    if should_veto(m_top, a_best):
        return _verdict("reject", f"tautology_veto:{m_top}_vs_{a_best}")
    # When the audit is decisive, require it to agree with the mapper.
    if a_best and a_best != m_top:
        return _verdict("reject", f"audit_mismatch:{m_top}_vs_{a_best}")
    return _verdict("approve")


### 5) Quick smoke test

In [13]:
# ===== Cell 7: quick smoke =====
# Decision thresholds for this run (same values as the fuse_decision defaults).
CONF_THR, TAU_SPAN, REL_MARGIN = 0.70, 0.55, 0.06

PRIMS = list(prototypes.keys())
mapper = load_mapper(model_path)  # reuse the Cell 6 helper instead of calling joblib directly
tests = [
    "deposit 10 ETH into aave",
    "withdraw 5 ETH",
    "swap 2 ETH for USDC on uniswap — minimize gas",
    "check balance",
    "repay 300 USDC to aave",
    "unstake 1000 USDC from rocket pool",
    "sing me a lullaby",
    "convert centimeters to inches",
]
smoke_results = []  # full decision dicts, kept for inspection after the run
for p in tests:
    out = fuse_decision(p, mapper, prototypes, emb,
                        conf_thr=CONF_THR, tau_span=TAU_SPAN, rel_margin=REL_MARGIN)
    smoke_results.append(out)
    print(out["decision"], "—", p)
    if out["decision"] != "approve":
        print("  reason:", out.get("reason"))


  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


approve — deposit 10 ETH into aave
approve — withdraw 5 ETH
approve — swap 2 ETH for USDC on uniswap — minimize gas


  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)


abstain_non_exec — check balance
  reason: low_conf_mapper
approve — repay 300 USDC to aave
approve — unstake 1000 USDC from rocket pool
abstain_non_exec — sing me a lullaby
  reason: low_conf_mapper
reject — convert centimeters to inches
  reason: audit_mismatch:withdraw_asset_vs_swap_asset
