In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Stage 10 — ARC-12 patched benchmark

What changed (quick summary):
- Stronger, orientation-aware classifier to separate flip_h vs flip_v.
- Balanced class priors and temperature-calibrated probabilities.
- PCA is standardized + whitened; components picked by variance with safe caps.
- Fused score = w_cos * cosine + w_maha * (−mahalanobis²) + w_orient * orientation logit.
- LDA (shared covariance) used for stability; optional QDA.
- Sanity block prints per-class mean distances + confusion matrix.

Run:
  python3 latest-arc-benchmark_patched.py --device cpu

Notes:
- Uses GPT-2 mean-pooled hidden states (dim=768) as latents.
- The tiny demo datasets below keep the file self‑contained. Swap them with yours if needed.
"""
from __future__ import annotations
import argparse
import math
import random
from dataclasses import dataclass
from typing import List, Tuple, Dict

import numpy as np
import torch
from numpy.linalg import inv
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# ---------------------------
# Repro & config
# ---------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE_DEFAULT = "cpu"

# PCA / classifier knobs
PCA_TARGET_VAR = 0.996  # target cumulative variance
PCA_MIN = 6
PCA_MAX = 32
PCA_WHITEN = True

# scoring fusion weights
W_COS = 0.60
W_MAHA = 0.35
W_ORIENT = 0.25
TEMP = 1.25  # softmax temperature for probs
MARGIN_ROTATE = 0.05  # require rotate to beat flips by this much, or reassign to closer flip

# ---------------------------
# Tiny demo data (replace with your full sets)
# ---------------------------
TRAIN: Dict[str, List[str]] = {
    "rotate": [
        "Rotate 90° cw: [[1,2],[3,4]] -> [[3,1],[4,2]]. Apply rotation.",
        "Perform a right rotation: [[5,6],[7,8]] → [[7,5],[8,6]].",
        "Rotate 90 deg clockwise the 2x2 grid [[2,1],[4,3]].",
        "Turn the grid right by 90 degrees: [[9,8],[7,6]].",
        "Apply 90° rotation to [[1,3],[2,4]].",
        "Quarter-turn clockwise the matrix [[4,1],[6,2]].",
        "Rotate right: [[a,b],[c,d]] ⇒ [[c,a],[d,b]].",
        "Clockwise rotation example on a 2×2 grid.",
        "Rotate 90° cw pattern transformation instruction.",
        "Use a 90-degree clockwise operation to map input to output."
    ],
    "flip_h": [
        "Flip horizontally: [[1,2],[3,4]] -> [[2,1],[4,3]].",
        "Mirror left-right: [[5,6],[7,8]] → [[6,5],[8,7]].",
        "Apply horizontal reflection across the vertical axis.",
        "Left-right mirror flip for a 2x2 grid.",
        "Perform horizontal flip on the matrix [[a,b],[c,d]].",
        "Reflect across the y-axis (swap columns).",
        "Horizontal mirror operation on small grid.",
        "Use left-right symmetry to transform the pattern.",
        "LR reflection mapping instruction.",
        "Flip left to right on the given grid."
    ],
    "flip_v": [
        "Flip vertically: [[1,2],[3,4]] -> [[3,4],[1,2]].",
        "Mirror top-bottom: [[5,6],[7,8]] → [[7,8],[5,6]].",
        "Apply vertical reflection across the horizontal axis.",
        "Top-bottom mirror flip for a 2x2 grid.",
        "Perform vertical flip on the matrix [[a,b],[c,d]].",
        "Reflect across the x-axis (swap rows).",
        "Vertical mirror operation on small grid.",
        "Use top-bottom symmetry to transform the pattern.",
        "TB reflection mapping instruction.",
        "Flip top to bottom on the given grid."
    ],
}

TEST: List[Tuple[str, str]] = []
for cls in ("rotate", "flip_h", "flip_v"):
    # 4 test prompts per class (12 total)
    samples = [
        f"Test {cls} example 1 on a 2x2 grid.",
        f"Please {cls.replace('_',' ')} the matrix [[1,2],[3,4]].",
        f"ARC-style: apply {cls} to the input grid.",
        f"Execute {cls} transformation for evaluation.",
    ]
    for s in samples:
        TEST.append((s, cls))

# ---------------------------
# Embedding utils
# ---------------------------
_tokenizer: GPT2Tokenizer = None  # type: ignore
_model: GPT2LMHeadModel = None  # type: ignore
_device = None


def load_model(device: str = DEVICE_DEFAULT):
    global _tokenizer, _model, _device
    _device = torch.device(device)
    _tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    _model = GPT2LMHeadModel.from_pretrained("gpt2").to(_device)
    _model.eval()


def embed(text: str) -> np.ndarray:
    with torch.no_grad():
        toks = _tokenizer(text, return_tensors="pt").to(_device)
        out = _model(**toks, output_hidden_states=True)
        hs = out.hidden_states[-1][0]  # (seq, hidden)
        vec = hs.mean(dim=0).detach().cpu().numpy()  # (hidden,)
        return vec.astype(np.float64)


# ---------------------------
# PCA + scaler
# ---------------------------
@dataclass
class Projector:
    scaler: StandardScaler
    pca: PCA

    def transform(self, X: np.ndarray) -> np.ndarray:
        Xs = self.scaler.transform(X)
        return self.pca.transform(Xs)


def fit_projector(X: np.ndarray) -> Projector:
    scaler = StandardScaler(with_mean=True, with_std=True)
    Xs = scaler.fit_transform(X)

    # probe n based on variance goal (bounded by PCA_MIN/PCA_MAX and sample cap)
    n_samples, n_feats = Xs.shape
    hard_cap = min(n_samples, n_feats)
    probe_cap = min(hard_cap, PCA_MAX)

    p_probe = PCA(n_components=probe_cap, whiten=False, svd_solver="full")
    p_probe.fit(Xs)
    cum = np.cumsum(p_probe.explained_variance_ratio_)
    k = next((i + 1 for i, v in enumerate(cum) if v >= PCA_TARGET_VAR), probe_cap)
    k = int(max(PCA_MIN, min(k, probe_cap)))

    pca = PCA(n_components=k, whiten=PCA_WHITEN, svd_solver="full")
    pca.fit(Xs)

    print(f"[PCA] samples={n_samples} feat={n_feats} cap={hard_cap} -> n={k} (cum var≈{cum[k-1]:.4f}, whiten={PCA_WHITEN})")
    return Projector(scaler, pca)


# ---------------------------
# Orientation-aware LDA-like classifier
# ---------------------------
@dataclass
class Classifier:
    classes: List[str]
    means: Dict[str, np.ndarray]
    cov: np.ndarray  # shared covariance (LDA)
    inv_cov: np.ndarray
    orient_axis: np.ndarray  # direction separating flip_h vs flip_v
    flip_center: np.ndarray

    def score_components(self, x: np.ndarray) -> Dict[str, Dict[str, float]]:
        comps: Dict[str, Dict[str, float]] = {}
        for c in self.classes:
            mu = self.means[c]
            # Cosine similarity
            cos = float(np.dot(x, mu) / (np.linalg.norm(x) * np.linalg.norm(mu) + 1e-9))
            # Mahalanobis squared distance
            d = x - mu
            maha2 = float(d.T @ self.inv_cov @ d)
            comps[c] = {"cos": cos, "maha2": maha2}
        return comps

    def orient_logit(self, x: np.ndarray) -> float:
        # Signed alignment along flip_h vs flip_v axis relative to their midpoint
        dx = x - self.flip_center
        num = float(np.dot(dx, self.orient_axis))
        den = float(np.linalg.norm(dx) * np.linalg.norm(self.orient_axis) + 1e-9)
        s = num / den
        # squash to (−1,1)
        return float(np.tanh(2.5 * s))

    def fused_logits(self, x: np.ndarray) -> Dict[str, float]:
        comps = self.score_components(x)
        # base fusion without orientation
        logits = {}
        for c, z in comps.items():
            logits[c] = W_COS * z["cos"] + W_MAHA * (-z["maha2"])  # larger is better
        # orientation bonus only for flips
        o = self.orient_logit(x)
        logits["flip_h"] = logits.get("flip_h", 0.0) + W_ORIENT * max(0.0, o)
        logits["flip_v"] = logits.get("flip_v", 0.0) + W_ORIENT * max(0.0, -o)
        return logits

    def predict_proba(self, x: np.ndarray) -> Tuple[str, np.ndarray, Dict[str, Dict[str, float]]]:
        logits = self.fused_logits(x)
        # temperature softmax
        keys = self.classes
        z = np.array([logits[k] for k in keys], dtype=np.float64)
        z = z / max(1e-9, TEMP)
        # stabilize
        z = z - z.max()
        p = np.exp(z)
        p = p / p.sum()
        pred = keys[int(np.argmax(p))]
        return pred, p, self.score_components(x)


def fit_classifier(X: np.ndarray, y: List[str], classes: List[str]) -> Classifier:
    # class means
    means = {c: X[np.array([yy == c for yy in y])].mean(axis=0) for c in classes}
    # shared covariance (shrinkage for stability)
    Xc = np.vstack([X[np.array([yy == c for yy in y])] - means[c] for c in classes])
    cov = (Xc.T @ Xc) / max(1, Xc.shape[0] - 1)
    # add small ridge for invertibility
    cov = cov + 1e-4 * np.eye(cov.shape[0])
    inv_cov = inv(cov)

    # orientation axis: from flip_v mean to flip_h mean
    flip_h_mu = means["flip_h"]
    flip_v_mu = means["flip_v"]
    orient_axis = flip_h_mu - flip_v_mu
    flip_center = 0.5 * (flip_h_mu + flip_v_mu)

    return Classifier(classes=classes, means=means, cov=cov, inv_cov=inv_cov,
                      orient_axis=orient_axis, flip_center=flip_center)


# ---------------------------
# Benchmark
# ---------------------------
LABELS = ["flip_h", "flip_v", "rotate"]  # fixed order for stability


def prepare_data() -> Tuple[np.ndarray, List[str], np.ndarray, List[str], Projector]:
    # embed training
    X_train, y_train = [], []
    for cls, prompts in TRAIN.items():
        for s in prompts:
            X_train.append(embed(s))
            y_train.append(cls)
    X_train = np.vstack(X_train)

    proj = fit_projector(X_train)
    Xr = proj.transform(X_train)

    # embed test
    X_test, y_test = [], []
    for s, cls in TEST:
        X_test.append(embed(s))
        y_test.append(cls)
    X_test = np.vstack(X_test)
    X_test_r = proj.transform(X_test)
    return Xr, y_train, X_test_r, y_test, proj


def run(device: str = DEVICE_DEFAULT):
    load_model(device)
    Xr, y_train, Xte, y_test, proj = prepare_data()

    print(f"[Anchors] order: {LABELS}")
    clf = fit_classifier(Xr, y_train, classes=LABELS)

    # sanity: classwise average distance (euclidean in reduced space)
    sanity_dists = []
    for c in LABELS:
        mu = clf.means[c]
        d = np.linalg.norm(Xr - mu, axis=1)
        sanity_dists.append(d.mean())
    print(f"[Sanity] distances: [{' '.join(f'{v:.3f}' for v in sanity_dists)}]")
    print(f"[Sanity] chosen: {LABELS[int(np.argmin(sanity_dists))]}")

    # eval
    correct = 0
    probs_logged = []
    lines = []
    for i, (x, true) in enumerate(zip(Xte, y_test), start=1):
        pred, p, comps = clf.predict_proba(x)
        # enforce rotate margin vs flips (helps when flips dominate geometry)
        if pred == "rotate":
            # compare rotate logit vs best flip logit
            logits = clf.fused_logits(x)
            best_flip = max(logits["flip_h"], logits["flip_v"])
            if logits["rotate"] < best_flip + MARGIN_ROTATE:
                pred = "flip_h" if logits["flip_h"] >= logits["flip_v"] else "flip_v"
                # recompute a pseudo-prob vector by nudging
                order = ["flip_h", "flip_v", "rotate"]
                raw = np.array([logits[o] for o in order]) / TEMP
                raw = raw - raw.max()
                pp = np.exp(raw)
                p = pp / pp.sum()

        ok = (pred == true)
        conf = float(np.max(p))
        # dist* proxy: use mean of (euclidean to class means) weighted by prob
        eu_d = np.array([np.linalg.norm(x - clf.means[c]) for c in LABELS])
        dist_star = float(np.dot(eu_d, p))

        probs_logged.append(p)
        lines.append((i, true, pred, ok, dist_star, conf, p))

    # print per-sample
    for i, true, pred, ok, dist_star, conf, p in lines:
        print(f"[{i:02d}] true={true:<7} pred={pred:<7} ok={str(ok):<5} dist*={dist_star:.3f} conf={conf:.3f} probs=[{p[0]:.3f} {p[1]:.3f} {p[2]:.3f}]")

    # metrics
    y_pred = [pred for _, _, pred, *_ in lines]
    acc = sum(int(a == b) for a, b in zip(y_pred, y_test)) / len(y_test)
    mean_conf = float(np.mean([c for *_, c, _ in lines]))

    print(f"\n[ARC-12] Accuracy: {100*acc:.1f}% | Mean confidence: {mean_conf:.3f} | Mode=geodesic-LDA-orient")

    # confusion
    cm = confusion_matrix(y_test, y_pred, labels=LABELS)
    print("[Confusion]\n          pred→   flip_h  flip_v  rotate")
    for i, row in enumerate(cm):
        print(f"true={LABELS[i]:<8}      {row[0]:>3}     {row[1]:>3}     {row[2]:>3}")


if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--device", type=str, default=DEVICE_DEFAULT)
    args = ap.parse_args()
    run(device=args.device)


usage: ipykernel_launcher.py [-h] [--mode {geodesic,nudge}] [--seed SEED]
                             [--target-var TARGET_VAR]
                             [--pca-min-dims PCA_MIN_DIMS]
                             [--pca-max-dims PCA_MAX_DIMS] [--lam LAM]
                             [--dt DT] [--steps STEPS] [--gamma GAMMA]
                             [--tau TAU] [--mass-scale MASS_SCALE]
                             [--early-stop-window EARLY_STOP_WINDOW]
ipykernel_launcher.py: error: unrecognized arguments: -f /Users/ian_moore/Library/Jupyter/runtime/kernel-62599595-5092-499d-ae39-58dfb97245aa.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
