In [1]:
import os, re, json, time, random
from datetime import datetime
from pathlib import Path
from decimal import Decimal, InvalidOperation
from typing import Optional, Tuple

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

import outlines
from pydantic import BaseModel, Field, ValidationError

In [2]:
SEED = 20250809
random.seed(SEED)
torch.manual_seed(SEED)

# Device selection: prefer Apple Metal (MPS) on Mac
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

'mps'

In [3]:
# Paths
RUN_NAME = "gsm8k500_deepseek_r1d_qwen1p5b"
BASE_DIR = Path("runs") / datetime.now().strftime("%Y-%m-%d")
BASE_DIR.mkdir(parents=True, exist_ok=True)

SUBSET_PATH = BASE_DIR / f"{RUN_NAME}_indices.json"  # fixed subset indices
A_JSONL = BASE_DIR / f"{RUN_NAME}_baselineA_thinking.jsonl"
B_JSONL = BASE_DIR / f"{RUN_NAME}_baselineB_nothinking.jsonl"

def ensure_fixed_subset(dataset, n=500, key_file=SUBSET_PATH):
    if key_file.exists():
        with open(key_file, "r") as f:
            idxs = json.load(f)
    else:
        all_idxs = list(range(len(dataset)))
        random.Random(SEED).shuffle(all_idxs)
        idxs = sorted(all_idxs[:n])
        with open(key_file, "w") as f:
            json.dump(idxs, f)
    return dataset.select(idxs), idxs

def parse_gold_gsm8k(answer_text: str) -> Optional[str]:
    # GSM8K gold answers end with '#### <num>'
    m = re.search(r"####\s*([-+]?\d+(?:\.\d+)?)", answer_text.strip())
    return m.group(1).replace(",", "") if m else None

def normalize_number_str(s: Optional[str]) -> Optional[str]:
    if s is None: return None
    s = s.strip().replace(",", "")
    try:
        d = Decimal(s)
        s2 = format(d.normalize(), 'f')
        if '.' in s2:
            s2 = s2.rstrip('0').rstrip('.')
        return s2
    except (InvalidOperation, ValueError):
        return None

def equal_numbers(a: Optional[str], b: Optional[str]) -> bool:
    na, nb = normalize_number_str(a), normalize_number_str(b)
    return (na is not None) and (nb is not None) and (na == nb)

def extract_pred_fields(text: str) -> Tuple[Optional[str], Optional[float]]:
    m1 = re.search(r"Final Answer:\s*([-+]?\d+(?:\.\d+)?)", text, flags=re.IGNORECASE)
    m2 = re.search(r"Confidence:\s*(0(?:\.\d+)?|1(?:\.0+)?)", text, flags=re.IGNORECASE)
    ans = m1.group(1) if m1 else None
    conf = float(m2.group(1)) if m2 else None
    if conf is not None:
        conf = max(0.0, min(1.0, conf))
    return ans, conf

def think_answer_split(text: str) -> Tuple[str, str]:
    m = re.search(r"<think>(.*?)</think>", text, flags=re.DOTALL | re.IGNORECASE)
    if m:
        think = m.group(1)
        rest = text[:m.start()] + text[m.end():]
        return think, rest
    return "", text

def count_tokens(tokenizer, text: str) -> int:
    return len(tokenizer.encode(text, add_special_tokens=False))

In [4]:
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

if DEVICE == "mps":
    DTYPE = torch.float16
elif DEVICE == "cuda":
    DTYPE = torch.bfloat16
else:
    DTYPE = torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eos_token_id

hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto" if DEVICE in ("cuda","mps") else None,
    trust_remote_code=True
).eval()

# Wrap the SAME HF model & tokenizer with outlines (no re-download)
omodel = outlines.from_transformers(hf_model, tokenizer)

print(f"Loaded {MODEL_ID} on {DEVICE} dtype={DTYPE}")

Loaded deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B on mps dtype=torch.float16


In [5]:
SYSTEM_BASE = (
    "You are a careful math solver. Solve grade-school math problems reliably.\n"
    "You may use <think>...</think> for private reasoning, but evaluation only reads:\n"
    "  Final Answer: <number>\n"
    "  Confidence: <0-1>\n"
    "The confidence should reflect your belief the final answer is correct."
)

USER_INSTR_A = (
    "Solve the problem step by step. You may include <think>...</think> to reason privately.\n"
    "When you are done, output exactly two lines:\n"
    "Final Answer: <number>\n"
    "Confidence: <0-1>"
)

USER_INSTR_B = (
    "Solve the problem concisely. DO NOT output any <think> tags.\n"
    "Only provide the two final lines in this exact format:\n"
    "Final Answer: <number>\n"
    "Confidence: <0-1>"
)

def build_prompt(question: str, mode: str) -> str:
    user_instr = USER_INSTR_A if mode == "A" else USER_INSTR_B
    messages = [
        {"role": "system", "content": SYSTEM_BASE},
        {"role": "user", "content": f"{user_instr}\n\nQuestion:\n{question}\n"}
    ]
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        return (
            f"<|system|>\n{SYSTEM_BASE}\n<|end|>\n"
            f"<|user|>\n{user_instr}\n\nQuestion:\n{question}\n<|end|>\n<|assistant|>\n"
        )

In [6]:
def generate_freeform(prompt: str,
                      max_new_tokens: int = 512,
                      temperature: float = 0.7,
                      top_p: float = 0.95,
                      seed: int = SEED) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(hf_model.device)
    torch.manual_seed(seed)
    with torch.no_grad():
        out = hf_model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(out[0], skip_special_tokens=True)
    # Heuristic to strip the prompt if included in the decode
    if decoded.startswith(prompt):
        decoded = decoded[len(prompt):]
    return decoded

from typing import List, Dict, Any

def batch_iter(lst, bs):
    for i in range(0, len(lst), bs):
        yield i, lst[i:i+bs]

def generate_freeform_batched(prompts: List[str],
                              max_new_tokens: int = 512,
                              temperature: float = 0.7,
                              top_p: float = 0.95,
                              seed: int = SEED,
                              batch_size: int = 8) -> List[str]:
    """
    Efficient batched generation. Returns only the newly generated suffix for each prompt.
    """
    all_outputs = []
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    gen = torch.Generator(device=hf_model.device)
    gen.manual_seed(seed)

    for offset, batch_prompts in tqdm(list(batch_iter(prompts, batch_size)), desc="Generating (batched)"):
        enc = tokenizer(batch_prompts, return_tensors="pt", padding=True, add_special_tokens=False)
        # True input lengths before padding (per example)
        input_lens = enc["attention_mask"].sum(dim=1).tolist()

        enc = {k: v.to(hf_model.device) for k, v in enc.items()}

        with torch.no_grad():
            out = hf_model.generate(
                **enc,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_id,
                use_cache=True,
                return_dict_in_generate=True,
                output_scores=False,
            )

        seqs = out.sequences  # [B, input_len + new_len]
        # Slice off the prompt portion per item using true input lengths
        for i in range(seqs.size(0)):
            new_tokens = seqs[i, input_lens[i]:]
            text = tokenizer.decode(new_tokens, skip_special_tokens=True)
            all_outputs.append(text)
    return all_outputs

In [7]:
# Typed schema for hard-constrained decoding
class TailSchema(BaseModel):
    final_answer: float
    confidence: float = Field(..., ge=0.0, le=1.0)

def constrained_tail_with_outlines(question: str,
                                   mode: str,
                                   max_new_tokens: int = 64,
                                   temperature: float = 0.2,
                                   top_p: float = 0.95,
                                   seed: int = SEED) -> Tuple[Optional[str], Optional[float], Optional[str]]:
    """
    Returns (final_answer_str, confidence_float, raw_json_string) using outlines typed decoding.
    Produces ONLY the structured object; we later format 'Final Answer:' and 'Confidence:' lines ourselves.
    """
    # A compact instruction that avoids generating thoughts; we just want the numbers.
    inst = (
        "Return only a JSON object with two fields:\n"
        '{ "final_answer": <number>, "confidence": <number between 0 and 1> }.\n'
        "No text or explanation."
    )
    if mode == "A":
        style = "You may think privately first, but do not output the thoughts."
    else:
        style = "Be concise and do not include any hidden or private reasoning."

    prompt = f"{inst}\n\nQuestion:\n{question}\n\n{style}"
    torch.manual_seed(seed)
    raw = omodel(prompt, TailSchema, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
    try:
        parsed = TailSchema.model_validate_json(raw)
        ans_str = normalize_number_str(str(parsed.final_answer))
        conf = float(parsed.confidence)
        conf = max(0.0, min(1.0, conf))
        return ans_str, conf, raw
    except ValidationError as e:
        return None, None, raw

In [8]:
gsm = load_dataset("openai/gsm8k", "main", split="test")
gsm500, indices = ensure_fixed_subset(gsm, n=10, key_file=SUBSET_PATH)

print("GSM8K test size:", len(gsm))
print("Using fixed subset of size:", len(gsm500))
print("Saved indices to:", SUBSET_PATH)

GSM8K test size: 1319
Using fixed subset of size: 500
Saved indices to: runs/2025-08-09/gsm8k500_deepseek_r1d_qwen1p5b_indices.json


In [9]:
def run_benchmark(split,
                  mode: str,
                  out_path: Path,
                  max_new_tokens_free: int,
                  temperature_free: float,
                  top_p_free: float,
                  desc: str):
    """
    mode: 'A' (thinking) or 'B' (no thinking)
    - Freeform pass to capture raw + <think>
    - Parse 'Final Answer'/'Confidence'
    - If missing/invalid, do a constrained repair using outlines typed decoding
    - Log JSONL rows
    """
    out_f = open(out_path, "w", encoding="utf-8")

    correct_count = 0
    conf_sum, conf_n = 0.0, 0
    repaired_ct = 0

    pbar = tqdm(total=len(split), desc=desc)
    for i, ex in enumerate(split):
        q = ex["question"]
        gold_raw = ex["answer"]
        gold = parse_gold_gsm8k(gold_raw)

        prompt = build_prompt(q, mode=mode)

        t0 = time.time()
        gen = generate_freeform(
            prompt=prompt,
            max_new_tokens=max_new_tokens_free,
            temperature=temperature_free,
            top_p=top_p_free,
            seed=SEED + i
        )
        latency_ms = int((time.time() - t0) * 1000)

        # Parse from freeform
        pred_ans, pred_conf = extract_pred_fields(gen)

        # If parsing failed, get a hard-constrained tail with outlines (latest API)
        tail_raw = None
        tail_forced = False
        if (pred_ans is None) or (pred_conf is None):
            tail_forced = True
            repaired_ct += 1
            ans_fix, conf_fix, raw_json = constrained_tail_with_outlines(
                question=q, mode=mode, seed=SEED + i
            )
            tail_raw = raw_json
            if ans_fix is not None:
                pred_ans = ans_fix
            if conf_fix is not None:
                pred_conf = conf_fix

        # Final guards
        if pred_conf is not None:
            pred_conf = max(0.0, min(1.0, float(pred_conf)))

        is_correct = equal_numbers(pred_ans, gold)

        # Token accounting (<think> vs rest)
        think_text, _ = think_answer_split(gen)
        tok_total = count_tokens(tokenizer, gen)
        tok_think = count_tokens(tokenizer, think_text) if think_text else 0
        tok_answer = tok_total - tok_think

        row = {
            "id": f"gsm8k_test_{indices[i]}",
            "dataset": "gsm8k",
            "split": "test",
            "index": indices[i],
            "prompt_version": f"mode_{mode}",
            "model": MODEL_ID,
            "quant": None,
            "seed": SEED + i,
            "temps": {"temperature": temperature_free, "top_p": top_p_free},
            "setting": {
                "device": DEVICE,
                "dtype": str(DTYPE),
                "max_new_tokens_free": max_new_tokens_free,
                "hard_tail_via_outlines": tail_forced
            },
            "final_answer": pred_ans,
            "confidence": pred_conf,
            "gold": gold,
            "correct": bool(is_correct),
            "tok_counts": {
                "generated_total": tok_total,
                "think": tok_think,
                "answer": tok_answer
            },
            "latency_ms": latency_ms,
            "raw_freeform": gen,
            "raw_tail_outlines_json": tail_raw
        }
        out_f.write(json.dumps(row, ensure_ascii=False) + "\n")

        # live stats
        correct_count += int(is_correct)
        if pred_conf is not None:
            conf_sum += pred_conf
            conf_n += 1

        acc_so_far = correct_count / (i + 1)
        avg_conf_so_far = (conf_sum / conf_n) if conf_n else 0.0
        pbar.set_postfix(acc=f"{acc_so_far:.3f}", mean_conf=f"{avg_conf_so_far:.3f}",
                         repaired=repaired_ct, t_ms=latency_ms)
        pbar.update(1)

    pbar.close()
    out_f.close()

    final_acc = correct_count / len(split)
    final_mean_conf = (conf_sum / conf_n) if conf_n else float("nan")
    print(f"\nSaved -> {out_path}")
    print(f"Final Accuracy: {final_acc:.4f}  |  Mean Confidence: {final_mean_conf:.4f}  |  Repairs: {repaired_ct}/{len(split)}")


In [10]:
def run_benchmark_batched(split,
                          mode: str,
                          out_path: Path,
                          max_new_tokens_free: int,
                          temperature_free: float,
                          top_p_free: float,
                          batch_size: int = 8,
                          desc: str = "Benchmark (batched)"):
    """
    2-step per item:
      1) Batched freeform gen to capture raw (and potential <think>).
      2) If parsing fails, repair that item with typed, hard-constrained decoding via outlines.
    """
    # Build prompts up-front so lengths are known and batching is simple
    prompts = [build_prompt(ex["question"], mode=mode) for ex in split]
    questions = [ex["question"] for ex in split]
    golds_raw = [ex["answer"] for ex in split]
    golds = [parse_gold_gsm8k(a) for a in golds_raw]

    t0 = time.time()
    gen_texts = generate_freeform_batched(
        prompts,
        max_new_tokens=max_new_tokens_free,
        temperature=temperature_free,
        top_p=top_p_free,
        seed=SEED,              # reproducible if order & batch size stay the same
        batch_size=batch_size
    )
    t1 = time.time()

    assert len(gen_texts) == len(split)

    out_f = open(out_path, "w", encoding="utf-8")
    correct_count, conf_sum, conf_n, repaired_ct = 0, 0.0, 0, 0

    pbar = tqdm(total=len(split), desc=desc)
    for i in range(len(split)):
        q = questions[i]
        gold = golds[i]
        gen = gen_texts[i]

        # Parse freeform
        pred_ans, pred_conf = extract_pred_fields(gen)

        # Hard-constrained repair (typed) if needed
        tail_forced = False
        tail_raw = None
        if (pred_ans is None) or (pred_conf is None):
            tail_forced = True
            repaired_ct += 1
            ans_fix, conf_fix, raw_json = constrained_tail_with_outlines(
                question=q, mode=mode, seed=SEED + i
            )
            tail_raw = raw_json
            if ans_fix is not None:
                pred_ans = ans_fix
            if conf_fix is not None:
                pred_conf = conf_fix

        if pred_conf is not None:
            pred_conf = max(0.0, min(1.0, float(pred_conf)))

        is_correct = equal_numbers(pred_ans, gold)

        # Token accounting for process metrics
        think_text, _ = think_answer_split(gen)
        tok_total = count_tokens(tokenizer, gen)         # new tokens only (suffix)
        tok_think = count_tokens(tokenizer, think_text) if think_text else 0
        tok_answer = max(0, tok_total - tok_think)

        row = {
            "id": f"gsm8k_test_{indices[i]}",
            "dataset": "gsm8k",
            "split": "test",
            "index": indices[i],
            "prompt_version": f"mode_{mode}",
            "model": MODEL_ID,
            "quant": None,
            "seed": SEED + i,
            "temps": {"temperature": temperature_free, "top_p": top_p_free},
            "setting": {
                "device": DEVICE,
                "dtype": str(DTYPE),
                "max_new_tokens_free": max_new_tokens_free,
                "batch_size": batch_size,
                "hard_tail_via_outlines": tail_forced
            },
            "final_answer": pred_ans,
            "confidence": pred_conf,
            "gold": gold,
            "correct": bool(is_correct),
            "tok_counts": {
                "generated_total": tok_total,
                "think": tok_think,
                "answer": tok_answer
            },
            # per-item latency not meaningful in batched mode; include batch wall time instead
            "latency_ms": None,
            "raw_freeform": gen,
            "raw_tail_outlines_json": tail_raw
        }
        out_f.write(json.dumps(row, ensure_ascii=False) + "\n")

        # live stats
        correct_count += int(is_correct)
        if pred_conf is not None:
            conf_sum += pred_conf
            conf_n += 1

        acc_so_far = correct_count / (i + 1)
        avg_conf_so_far = (conf_sum / conf_n) if conf_n else 0.0
        pbar.set_postfix(acc=f"{acc_so_far:.3f}", mean_conf=f"{avg_conf_so_far:.3f}",
                         repaired=repaired_ct, batches_time_s=f"{(t1-t0):.1f}")
        pbar.update(1)

    pbar.close()
    out_f.close()

    final_acc = correct_count / len(split)
    final_mean_conf = (conf_sum / conf_n) if conf_n else float("nan")
    print(f"\nSaved -> {out_path}")
    print(f"Final Accuracy: {final_acc:.4f} | Mean Confidence: {final_mean_conf:.4f} | Repairs: {repaired_ct}/{len(split)}")
    print(f"Batched generation wall time: {(t1 - t0):.2f}s for {len(split)} items (bs={batch_size})")


In [11]:
TEMP = 0.7
TOP_P = 0.95
BATCH_SIZE = 8

print("Baseline A (thinking, batched)...")
run_benchmark_batched(
    split=gsm500,
    mode="A",
    out_path=A_JSONL,
    max_new_tokens_free=512,
    temperature_free=TEMP,
    top_p_free=TOP_P,
    batch_size=BATCH_SIZE,
    desc="Baseline A (batched)"
)

print("\nBaseline B (no thinking, batched)...")
run_benchmark_batched(
    split=gsm500,
    mode="B",
    out_path=B_JSONL,
    max_new_tokens_free=256,
    temperature_free=TEMP,
    top_p_free=TOP_P,
    batch_size=BATCH_SIZE,
    desc="Baseline B (batched)"
)

Baseline A (thinking, batched)...


Generating (batched):   0%|          | 0/63 [00:00<?, ?it/s]

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

In [None]:
def summarize_jsonl(path: Path):
    n = 0
    correct = 0
    conf_sum = 0.0
    conf_n = 0
    repaired = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            n += 1
            obj = json.loads(line)
            correct += int(bool(obj.get("correct", False)))
            c = obj.get("confidence", None)
            if isinstance(c, (int, float)):
                conf_sum += float(c)
                conf_n += 1
            if obj.get("setting", {}).get("hard_tail_via_outlines", False):
                repaired += 1
    acc = correct / n if n else float("nan")
    mean_conf = conf_sum / conf_n if conf_n else float("nan")
    print(f"{path.name}: n={n} | accuracy={acc:.4f} | mean_conf={mean_conf:.4f} | constrained_repairs={repaired}")

summarize_jsonl(A_JSONL)
summarize_jsonl(B_JSONL)

print("\nArtifacts:")
print("Fixed subset indices:", SUBSET_PATH)
print("Baseline A JSONL:", A_JSONL)
print("Baseline B JSONL:", B_JSONL)