# RAUQ-CoT Experiments & Ablations

This notebook prepares calibration artifacts, runs the baseline RAUQ-triggered controller, and executes targeted ablations. Execute the cells sequentially on a GPU-enabled runtime. Update the configuration cell before running if you need to change models, datasets, or output locations. Ensure you have access to the chosen Hugging Face model (login via `huggingface-cli login` if required).

In [None]:
# Install runtime dependencies (run once per environment)
!pip install -q transformers accelerate bitsandbytes datasets evaluate scikit-learn numpy

In [None]:
from pathlib import Path
import os
import random

import numpy as np
import torch

import sys

# Directory layout. The current working directory is treated as the repo root.
REPO_ROOT = Path.cwd()
ARTIFACTS_DIR = REPO_ROOT / "artifacts"
CALIB_DIR = REPO_ROOT / "data" / "calibration"
EVAL_DIR = ARTIFACTS_DIR / "evals"
ABLATION_DIR = ARTIFACTS_DIR / "ablations"

# Model and benchmark selection.
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
TOKENIZER_NAME = None  # Set to override tokenizer (defaults to MODEL_NAME)
BENCHMARK_NAME = "gsm8k"

CALIBRATION_SIZE = None  # Number of calibration prompt/completion pairs; set to None to use all available samples
# NOTE(review): the file name is fixed to "gsm8k" and does not follow
# BENCHMARK_NAME; delete or rename it when switching benchmarks.
CALIBRATION_FILE = CALIB_DIR / "gsm8k_calibration.jsonl"

# Calibration artifacts produced by the head-selection and threshold cells.
HEADS_PATH = ARTIFACTS_DIR / "qwen25_heads.json"
THRESHOLD_PATH = ARTIFACTS_DIR / "qwen25_theta.json"

# NOTE(review): None is forwarded unchanged as `max_new_tokens`. With Hugging
# Face `generate`, that means the model's generation config decides the length
# limit; confirm this is the intended decoding budget.
MAX_NEW_TOKENS = None
ALPHA = 0.3
DEVICE = "cuda"
EVAL_LIMIT = None  # Set to an int for smoke tests (e.g., 32)
SEED = 0

for path in (ARTIFACTS_DIR, CALIB_DIR, EVAL_DIR, ABLATION_DIR):
    path.mkdir(parents=True, exist_ok=True)

# PYTHONPATH only affects interpreters started after this point (e.g. `!python`
# subprocesses). The running kernel resolves imports through sys.path, so the
# repo root is added there as well.
os.environ["PYTHONPATH"] = str(REPO_ROOT)
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
# setdefault keeps an HF_HOME that is already set in the environment.
os.environ.setdefault("HF_HOME", str(REPO_ROOT / ".hf_cache"))

# Seed every RNG used by the notebook for repeatable runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"Repo root: {REPO_ROOT}")
print(f"Model: {MODEL_NAME}")
print(f"Calibration file: {CALIBRATION_FILE}")
print(f"Artifacts dir: {ARTIFACTS_DIR}")
print(f"Evaluation limit: {EVAL_LIMIT}")

In [None]:
import json

from ucot.data.benchmarks import load_benchmark

def prepare_calibration_file() -> None:
    """Create the calibration JSONL file unless it already exists.

    Samples are loaded from the benchmark's ``train`` split; if that fails for
    any reason, the loader's default split is used instead. Each sample becomes
    one JSON line with ``prompt`` and ``completion`` keys.

    The file is written to a temporary sibling and renamed into place, so an
    interrupted run never leaves a truncated file that later runs would accept
    as "already present".

    Raises:
        ValueError: If the benchmark loader returns no samples.
    """
    if CALIBRATION_FILE.exists():
        print(f"Calibration file already present: {CALIBRATION_FILE}")
        return

    try:
        calibration_samples = load_benchmark(BENCHMARK_NAME, split="train", limit=CALIBRATION_SIZE)
    except Exception as exc:
        # Deliberately broad: which exception the loader raises for a missing
        # split is not visible from this notebook.
        # NOTE(review): the evaluation cells also load the default split, so
        # this fallback may calibrate on the same data that is evaluated.
        print(f"Train split unavailable ({exc}); falling back to default evaluation split.")
        calibration_samples = load_benchmark(BENCHMARK_NAME, limit=CALIBRATION_SIZE)

    if not calibration_samples:
        raise ValueError("No calibration samples fetched; update BENCHMARK_NAME or CALIBRATION_SIZE.")

    tmp_path = CALIBRATION_FILE.with_name(CALIBRATION_FILE.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as fp:
            for sample in calibration_samples:
                payload = {"prompt": sample.prompt, "completion": sample.reference}
                fp.write(json.dumps(payload) + "\n")
        tmp_path.replace(CALIBRATION_FILE)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)

    print(f"Wrote {len(calibration_samples)} calibration examples to {CALIBRATION_FILE}")

# Run when the cell executes so the calibration file exists before later cells read it.
prepare_calibration_file()

In [None]:
from ucot.config import HeadSelectionConfig
from ucot.head_selection import select_uncertainty_heads

# Head-selection settings: the calibration JSONL is the input and HEADS_PATH is
# passed as the output location.
head_config = HeadSelectionConfig(
    calibration_paths=[CALIBRATION_FILE],
    model_name=MODEL_NAME,
    tokenizer_name=TOKENIZER_NAME,
    output_path=HEADS_PATH,
    num_examples=CALIBRATION_SIZE,
    device=DEVICE,
)

# Ask PyTorch to release cached, unused CUDA memory before the selection run.
torch.cuda.empty_cache()
head_result = select_uncertainty_heads(head_config)
print(f"Saved head selection to {HEADS_PATH}")
print(f"Layers used ({len(head_result.layers_used)}): {head_result.layers_used}")

In [None]:
from ucot.config import ThresholdTrainingConfig
from ucot.threshold import train_threshold

# Cap the number of calibration pairs at 2048; None ("use everything") is
# passed through unchanged.
max_pairs = None if CALIBRATION_SIZE is None else min(2048, CALIBRATION_SIZE)
# Threshold training reads the calibration JSONL and the head selection at
# HEADS_PATH, and is given THRESHOLD_PATH as its output location.
threshold_config = ThresholdTrainingConfig(
    calibration_paths=[CALIBRATION_FILE],
    model_name=MODEL_NAME,
    tokenizer_name=TOKENIZER_NAME,
    head_indices_path=HEADS_PATH,
    output_path=THRESHOLD_PATH,
    alpha=ALPHA,
    max_samples=max_pairs,
    device=DEVICE,
)
threshold_result = train_threshold(threshold_config)
print(f"Learned theta: {threshold_result.theta:.4f}")

In [None]:
import json
import time
from statistics import mean
from typing import Optional

import torch

from ucot.config import ControllerConfig, RAUQConfig
from ucot.controller import RAUQController
from ucot.data.benchmarks import load_benchmark
from ucot.experiments.metrics import METRICS, exact_match
from ucot.rauq import RAUQScorer
from ucot.threshold import ThresholdResult
from ucot.uncertainty import EntropyScorer, LogitMarginScorer, RAUQScorerWrapper
from ucot.utils.model import load_model

# Notebook-wide cache; filled by the first get_loaded_model() call.
_LOADED_MODEL = None

def get_loaded_model():
    """Return the shared loaded-model object, loading it on first use.

    Later calls return the cached object without calling ``load_model`` again.
    """
    global _LOADED_MODEL
    if _LOADED_MODEL is not None:
        return _LOADED_MODEL

    load_kwargs = {
        "model_name": MODEL_NAME,
        "tokenizer_name": TOKENIZER_NAME,
        "device": DEVICE,
    }
    _LOADED_MODEL = load_model(**load_kwargs)
    return _LOADED_MODEL

def ensure_cuda_sync():
    """Wait for queued CUDA work to finish; do nothing when CUDA is unavailable.

    Called around timed regions so the wall-clock measurements include GPU work
    that was launched asynchronously.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()

def compute_summary(records, latencies):
    """Aggregate per-sample records into one summary dict.

    Args:
        records: Sequence of dicts with numeric ``correct``,
            ``generated_tokens``, ``total_tokens`` and ``triggers`` entries.
        latencies: Sequence of per-sample latencies in seconds.

    Returns:
        A dict with the sample count, accuracy, per-sample averages, the total
        number of generated tokens, generated tokens per correct answer, and
        the average latency. Every average is a ``float`` (0.0 for empty
        input), so the values format and serialize the same way whether or not
        the mean happens to be a whole number.
    """
    def _average(values):
        # statistics.mean returns an int for integral means of int data;
        # coerce so the metric type does not depend on the data.
        return float(mean(values)) if values else 0.0

    correct_flags = [r["correct"] for r in records]
    generated_counts = [r["generated_tokens"] for r in records]
    total_counts = [r["total_tokens"] for r in records]
    trigger_counts = [r["triggers"] for r in records]

    tokens_total = sum(generated_counts)
    correct_total = sum(correct_flags)
    # The denominator is clamped to 1, so with zero correct answers this
    # reports the total generated tokens rather than dividing by zero.
    tokens_per_correct = tokens_total / max(correct_total, 1)

    return {
        "samples": len(records),
        "accuracy": _average(correct_flags),
        "avg_generated_tokens": _average(generated_counts),
        "avg_total_tokens": _average(total_counts),
        "avg_triggers": _average(trigger_counts),
        "tokens_per_correct": tokens_per_correct,
        "total_generated_tokens": tokens_total,
        "avg_latency_sec": _average(latencies),
    }

def print_summary(name, summary, out_path):
    """Print a titled report of ``summary`` followed by the artifact location.

    Float values are shown with four decimal places; everything else is printed
    with its default formatting.
    """
    print(f"=== {name} ===")
    for metric, value in summary.items():
        rendered = f"{value:.4f}" if isinstance(value, float) else f"{value}"
        print(f"{metric}: {rendered}")
    print(f"Artifacts: {out_path}")

def run_vanilla_greedy(experiment_name: str, temperature: float = 0.0, top_p: float = 1.0, show_progress: bool = True):
    """Evaluate plain ``model.generate`` decoding on the benchmark (no controller).

    Each benchmark sample is tokenized, decoded with ``model.generate``, scored
    with the benchmark metric, and timed. The summary and the per-sample
    records are written to ``EVAL_DIR / f"{experiment_name}.json"``.

    Args:
        experiment_name: Label for the progress bar, the printed report and the
            output file name.
        temperature: Values above 0 enable sampling at that temperature;
            otherwise sampling is disabled.
        top_p: Forwarded to ``generate`` unchanged.
        show_progress: Show a tqdm progress bar when tqdm can be imported.

    Returns:
        The summary dict produced by ``compute_summary``.
    """
    loaded = get_loaded_model()
    model = loaded.model
    tokenizer = loaded.tokenizer

    dataset = load_benchmark(BENCHMARK_NAME, limit=EVAL_LIMIT)
    metric_fn = METRICS.get(BENCHMARK_NAME, exact_match)
    results = []
    latencies = []

    progress_bar = None
    if show_progress:
        try:
            from tqdm.auto import tqdm
            progress_bar = tqdm(total=len(dataset), desc=f"{experiment_name}", unit="sample")
        except ImportError:
            # tqdm is optional; run without a progress bar.
            progress_bar = None

    use_sampling = temperature > 0.0
    model.eval()
    try:
        for sample in dataset:
            inputs = tokenizer(sample.prompt, return_tensors="pt").to(model.device)
            # Synchronize on both sides of the timed region so the latency
            # covers the GPU work launched by generate().
            ensure_cuda_sync()
            start = time.perf_counter()
            with torch.no_grad():
                # NOTE(review): MAX_NEW_TOKENS may be None, in which case the
                # model's generation config decides the length limit.
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=use_sampling,
                    # A neutral temperature of 1.0 is passed when not sampling.
                    temperature=temperature if use_sampling else 1.0,
                    top_p=top_p,
                )
            ensure_cuda_sync()
            latency = time.perf_counter() - start

            # Drop the prompt tokens; what remains is the completion.
            completion_ids = outputs[:, inputs["input_ids"].shape[1]:]
            completion = tokenizer.decode(completion_ids[0], skip_special_tokens=True)
            is_correct = bool(metric_fn({"reference": sample.reference}, completion))

            # First identifier that is actually present. Unlike an `or` chain,
            # this keeps legitimate falsy ids such as 0.
            metadata = sample.metadata
            sample_id = next(
                (metadata.get(key) for key in ("id", "task_id", "level") if metadata.get(key) is not None),
                None,
            )
            results.append({
                "id": sample_id,
                "correct": int(is_correct),
                "triggers": 0,
                "generated_tokens": int(completion_ids.shape[1]),
                "total_tokens": int(outputs.shape[1]),
                # NOTE(review): recorded as the full completion length here;
                # confirm this matches what "extra_tokens" means for controller runs.
                "extra_tokens": int(completion_ids.shape[1]),
                "latency_sec": latency,
            })
            latencies.append(latency)
            if progress_bar is not None:
                progress_bar.update(1)
    finally:
        # Close the bar even when generation fails part-way through.
        if progress_bar is not None:
            progress_bar.close()

    summary = compute_summary(results, latencies)
    payload = {"name": experiment_name, "summary": summary, "records": results}
    out_path = EVAL_DIR / f"{experiment_name}.json"
    out_path.write_text(json.dumps(payload, indent=2))
    print_summary(experiment_name, summary, out_path)
    return summary

def build_scorer(trigger: str, alpha: float, num_layers: int):
    """Build the uncertainty scorer for the requested trigger type.

    Args:
        trigger: One of ``"rauq"``, ``"entropy"`` or ``"margin"``.
        alpha: Passed to ``RAUQConfig``; only used for the ``"rauq"`` trigger.
        num_layers: Passed to ``RAUQScorer.from_config``; only used for the
            ``"rauq"`` trigger.

    Returns:
        A ``RAUQScorerWrapper``, ``EntropyScorer`` or ``LogitMarginScorer``.

    Raises:
        ValueError: If ``trigger`` is not one of the supported names.
    """
    if trigger == "entropy":
        return EntropyScorer()
    if trigger == "margin":
        return LogitMarginScorer()
    if trigger != "rauq":
        raise ValueError(f"Unsupported trigger: {trigger}")

    # Build a RAUQScorer from the head-selection file, then copy its settings
    # into the wrapper type.
    config = RAUQConfig(alpha=alpha, head_indices_path=HEADS_PATH, device=DEVICE)
    template = RAUQScorer.from_config(config, num_layers=num_layers)
    wrapper_kwargs = {
        "alpha": template.alpha,
        "head_indices": template.head_indices,
        "layers": template.layers,
        "eps": template.eps,
        "device": template.device,
    }
    return RAUQScorerWrapper(**wrapper_kwargs)

def configure_cot(controller_config, policy: str, override_length: Optional[int]):
    """Apply a chain-of-thought policy to ``controller_config.cot`` in place.

    Policies and what they set:
        ``"rauq"``      stop_mode ``"rauq"``; ``max_cot_tokens`` is changed only
                        when ``override_length`` is given.
        ``"max20"``     stop_mode ``"fixed"``; ``max_cot_tokens`` defaults to 20.
        ``"unlimited"`` stop_mode ``"none"``; ``max_cot_tokens`` defaults to 200.
        ``"none"``      stop_mode ``"none"``; ``max_cot_tokens`` defaults to 0.

    The CoT prefix is then rebuilt to mention the resulting token budget.

    Args:
        controller_config: Object exposing a mutable ``cot`` attribute with
            ``stop_mode``, ``max_cot_tokens`` and ``cot_prefix`` fields.
        policy: One of the policy names above.
        override_length: Explicit token budget. ``None`` means "use the policy
            default"; an explicit ``0`` is honoured rather than being treated
            as missing.

    Raises:
        ValueError: If ``policy`` is not recognised; nothing is modified then.
    """
    # policy -> (stop_mode, default max_cot_tokens)
    fixed_defaults = {
        "max20": ("fixed", 20),
        "unlimited": ("none", 200),
        "none": ("none", 0),
    }

    cot = controller_config.cot
    if policy == "rauq":
        cot.stop_mode = "rauq"
        if override_length is not None:
            cot.max_cot_tokens = override_length
    elif policy in fixed_defaults:
        stop_mode, default_length = fixed_defaults[policy]
        cot.stop_mode = stop_mode
        # `is None` rather than `or`, so an explicit 0 is not replaced by the default.
        cot.max_cot_tokens = default_length if override_length is None else override_length
    else:
        raise ValueError(f"Unknown CoT policy: {policy}")

    cot.cot_prefix = (
        f"Wait, let's quickly think step by step about this (<{cot.max_cot_tokens} tokens)."
    )

def run_controller_experiment(
    experiment_name: str,
    trigger: str = "rauq",
    repair: str = "cot",
    cot_policy: str = "rauq",
    cot_length: Optional[int] = None,
    rollback_depth: int = 2,
    rollback_mode: str = "fixed",
    cooldown: int = 5,
    stability_window: int = 2,
    max_triggers: Optional[int] = 5,
    alpha: Optional[float] = None,
    temperature: float = 0.0,
    top_p: float = 1.0,
    theta: Optional[float] = None,
    use_learned_threshold: bool = True,
    show_progress: bool = True,
):
    """Run one controller configuration over the benchmark and save the results.

    Builds the scorer and controller from the arguments, generates a completion
    for every benchmark sample, scores and times it, then writes the summary,
    per-sample records and the configuration used to
    ``ABLATION_DIR / f"{experiment_name}.json"``.

    Args:
        experiment_name: Label for the progress bar, report and output file.
        trigger: Scorer type passed to ``build_scorer``.
        repair: Value assigned to ``controller_config.repair_strategy``.
        cot_policy: Policy name passed to ``configure_cot``.
        cot_length: Optional token budget passed to ``configure_cot``.
        rollback_depth, rollback_mode, cooldown, stability_window, max_triggers:
            Copied onto ``controller_config.rollback``.
        alpha: Overrides the notebook-level ``ALPHA`` when given.
        temperature, top_p: Passed to ``ControllerConfig``.
        theta: Explicit threshold; takes precedence over the learned value.
        use_learned_threshold: Load the threshold from ``THRESHOLD_PATH`` and
            hand it to the controller.
        show_progress: Show a tqdm progress bar when tqdm can be imported.

    Returns:
        The summary dict produced by ``compute_summary``.

    Raises:
        ValueError: If ``use_learned_threshold`` is False and ``theta`` is None.
    """
    alpha = alpha if alpha is not None else ALPHA
    loaded = get_loaded_model()
    num_layers = loaded.model.config.num_hidden_layers
    scorer = build_scorer(trigger, alpha, num_layers)
    threshold = None
    theta_value = theta

    if use_learned_threshold:
        threshold = ThresholdResult.load(THRESHOLD_PATH)
        # An explicit theta wins over the learned one.
        # NOTE(review): the loaded threshold object is still handed to the
        # controller below; confirm which of the two values it actually uses.
        if theta_value is None:
            theta_value = threshold.theta
    elif theta_value is None:
        raise ValueError("Provide theta when use_learned_threshold=False")

    controller_config = ControllerConfig(
        model_name=MODEL_NAME,
        tokenizer_name=TOKENIZER_NAME,
        theta=theta_value,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=temperature,
        top_p=top_p,
        alpha=alpha,
    )
    configure_cot(controller_config, cot_policy, cot_length)
    controller_config.repair_strategy = repair
    controller_config.rollback.rollback_depth = rollback_depth
    controller_config.rollback.mode = rollback_mode
    controller_config.rollback.cooldown = cooldown
    controller_config.rollback.stability_window = stability_window
    controller_config.rollback.max_triggers = max_triggers

    controller = RAUQController(
        loaded=loaded,
        config=controller_config,
        scorer=scorer,
        # `threshold` is None unless the learned threshold was loaded above.
        threshold=threshold,
    )

    dataset = load_benchmark(BENCHMARK_NAME, limit=EVAL_LIMIT)
    metric_fn = METRICS.get(BENCHMARK_NAME, exact_match)
    results = []
    latencies = []

    progress_bar = None
    if show_progress:
        try:
            from tqdm.auto import tqdm
            progress_bar = tqdm(total=len(dataset), desc=f"{experiment_name}", unit="sample")
        except ImportError:
            # tqdm is optional; run without a progress bar.
            progress_bar = None

    try:
        for sample in dataset:
            # Synchronize on both sides of the timed region so the latency
            # covers the GPU work launched by the controller.
            ensure_cuda_sync()
            start = time.perf_counter()
            output = controller.generate(sample.prompt)
            ensure_cuda_sync()
            latency = time.perf_counter() - start
            is_correct = bool(metric_fn({"reference": sample.reference}, output.completion))

            # First identifier that is actually present. Unlike an `or` chain,
            # this keeps legitimate falsy ids such as 0.
            metadata = sample.metadata
            sample_id = next(
                (metadata.get(key) for key in ("id", "task_id", "level") if metadata.get(key) is not None),
                None,
            )
            results.append({
                "id": sample_id,
                "correct": int(is_correct),
                "triggers": len(output.trigger_events),
                "generated_tokens": len(output.completion_tokens),
                "total_tokens": output.total_tokens,
                "extra_tokens": output.extra_tokens,
                "latency_sec": latency,
            })
            latencies.append(latency)
            if progress_bar is not None:
                progress_bar.update(1)
    finally:
        # Close the bar even when generation fails part-way through.
        if progress_bar is not None:
            progress_bar.close()

    summary = compute_summary(results, latencies)
    payload = {
        "name": experiment_name,
        "summary": summary,
        "records": results,
        # Record the exact settings next to the results for reproducibility.
        "config": {
            "trigger": trigger,
            "repair": repair,
            "cot_policy": cot_policy,
            "cot_length": cot_length,
            "rollback_depth": rollback_depth,
            "rollback_mode": rollback_mode,
            "cooldown": cooldown,
            "stability_window": stability_window,
            "max_triggers": max_triggers,
            "alpha": alpha,
            "temperature": temperature,
            "top_p": top_p,
            "theta": theta_value,
            "use_learned_threshold": use_learned_threshold,
        },
    }
    out_path = ABLATION_DIR / f"{experiment_name}.json"
    out_path.write_text(json.dumps(payload, indent=2))
    print_summary(experiment_name, summary, out_path)
    return summary

In [None]:
# Experiment 0 — Plain greedy decoding (no RAUQ controller)
# Uses the function defaults (temperature=0.0, top_p=1.0), so sampling is disabled.
greedy_summary = run_vanilla_greedy("baseline_greedy")
# Bare expression so the notebook displays the returned summary dict.
greedy_summary

In [None]:
# Experiment 1 — RAUQ-triggered controller with CoT repair (baseline)
# Every argument below restates the run_controller_experiment default; they are
# spelled out so the ablations that follow can be compared line by line.
baseline_summary = run_controller_experiment(
    experiment_name="rauq_cot_baseline",
    trigger="rauq",
    repair="cot",
    cot_policy="rauq",
    rollback_depth=2,
    rollback_mode="fixed",
    max_triggers=5,
)
# Bare expression so the notebook displays the returned summary dict.
baseline_summary

In [None]:
# Experiment 2 — RAUQ controller without micro-CoT repair
# Differs from the baseline in repair="none", cot_policy="none" and cot_length=0.
no_cot_summary = run_controller_experiment(
    experiment_name="rauq_no_repair",
    trigger="rauq",
    repair="none",
    cot_policy="none",
    cot_length=0,
    rollback_depth=2,
    rollback_mode="fixed",
)
# Bare expression so the notebook displays the returned summary dict.
no_cot_summary

In [None]:
# Experiment 3 — RAUQ controller with rerank repair
# Differs from the baseline only in repair="rerank".
rerank_summary = run_controller_experiment(
    experiment_name="rauq_rerank",
    trigger="rauq",
    repair="rerank",
    cot_policy="rauq",
    rollback_depth=2,
    rollback_mode="fixed",
)
# Bare expression so the notebook displays the returned summary dict.
rerank_summary

In [None]:
# Experiment 4 — RAUQ controller with anchor-style rollback
# Differs from the baseline in rollback_mode="anchor" and rollback_depth=3.
anchor_summary = run_controller_experiment(
    experiment_name="rauq_cot_anchor",
    trigger="rauq",
    repair="cot",
    cot_policy="rauq",
    rollback_depth=3,
    rollback_mode="anchor",
)
# Bare expression so the notebook displays the returned summary dict.
anchor_summary

In [None]:
# Experiment 5 — RAUQ controller with fixed-length CoT (20 tokens)
# Differs from the baseline in cot_policy="max20"; cot_length=20 restates that
# policy's default budget.
fixedcot_summary = run_controller_experiment(
    experiment_name="rauq_cot_fixed20",
    trigger="rauq",
    repair="cot",
    cot_policy="max20",
    cot_length=20,
    rollback_depth=2,
    rollback_mode="fixed",
)
# Bare expression so the notebook displays the returned summary dict.
fixedcot_summary

In [None]:
# Experiment 6 — RAUQ controller with lighter attention weighting (alpha = 0.1)
# Note: Reuses the baseline theta; rerun threshold calibration if you expect to productionize this variant.
# The head-selection file at HEADS_PATH is reused as well; only alpha changes.
alpha_low_summary = run_controller_experiment(
    experiment_name="rauq_cot_alpha0_1",
    trigger="rauq",
    repair="cot",
    cot_policy="rauq",
    rollback_depth=2,
    rollback_mode="fixed",
    alpha=0.1,
)
# Bare expression so the notebook displays the returned summary dict.
alpha_low_summary

In [None]:
# Experiment 7 — RAUQ controller with heavier attention weighting (alpha = 0.5)
# Note: Reuses the baseline theta; calibrate a dedicated threshold for a fairer comparison if needed.
# The head-selection file at HEADS_PATH is reused as well; only alpha changes.
alpha_high_summary = run_controller_experiment(
    experiment_name="rauq_cot_alpha0_5",
    trigger="rauq",
    repair="cot",
    cot_policy="rauq",
    rollback_depth=2,
    rollback_mode="fixed",
    alpha=0.5,
)
# Bare expression so the notebook displays the returned summary dict.
alpha_high_summary

In [None]:
# Aggregate experiment summaries for quick comparison.
# Summaries are looked up by variable name, so an experiment whose cell was
# skipped appears as None (and is left out of the table) instead of raising
# NameError when this cell runs.
_SUMMARY_VARIABLES = {
    "baseline_greedy": "greedy_summary",
    "rauq_cot_baseline": "baseline_summary",
    "rauq_no_repair": "no_cot_summary",
    "rauq_rerank": "rerank_summary",
    "rauq_cot_anchor": "anchor_summary",
    "rauq_cot_fixed20": "fixedcot_summary",
    "rauq_cot_alpha0_1": "alpha_low_summary",
    "rauq_cot_alpha0_5": "alpha_high_summary",
}
experiment_summaries = {
    label: globals().get(variable) for label, variable in _SUMMARY_VARIABLES.items()
}

print("name".ljust(26), "accuracy", "avg_triggers", "avg_tokens", "avg_latency", sep=" | ")
for name, summary in experiment_summaries.items():
    if summary is None:
        # Experiment was not run in this session.
        continue
    print(
        name.ljust(26),
        f"{summary['accuracy']:.4f}",
        f"{summary['avg_triggers']:.2f}",
        f"{summary['avg_generated_tokens']:.1f}",
        f"{summary['avg_latency_sec']:.2f}s",
        sep=" | ",
    )