# Task
Use Labelled data to create an LLM judge.

* Understand the data
* Split data into train/dev/test
* Write a judge prompt with some few shot examples from train
* Iterate on dev
* Measure TPR and TNR on test
* Use Judgy to estimate unbiased performance and confidence interval



In [None]:
import json
from typing import Literal

from pydantic import BaseModel


class LabeledTrace(BaseModel):
    """One record of the labelled-trace JSONL file, validated by pydantic.

    Every field is required (no defaults are declared); ``error`` may be null.
    """

    # The three fields below are the ones interpolated into the judge prompt.
    query: str
    dietary_restriction: str
    response: str
    # NOTE(review): presumably whether the bot call itself succeeded — confirm.
    success: bool
    error: str | None
    trace_id: str
    query_id: str
    # Ground-truth label the judge verdict is compared against.
    label: Literal["PASS", "FAIL"]
    # Free-text rationale accompanying the label; reused in few-shot examples.
    reasoning: str
    # NOTE(review): origin of this confidence rating is not visible here.
    confidence: Literal["HIGH", "MEDIUM", "LOW"]
    labeled: bool


# Load the labelled traces: one JSON object per line (JSONL).
# The encoding is pinned to UTF-8 so the result does not depend on the
# platform default, and blank lines (e.g. a trailing newline at the end of
# the file) are skipped instead of crashing json.loads.
with open("reference_files/labeled_traces.jsonl", encoding="utf-8") as f:
    traces: list[LabeledTrace] = [
        LabeledTrace(**json.loads(line)) for line in f if line.strip()
    ]

print(f"Loaded {len(traces)} traces")
traces[0]

In [None]:
import polars as pl
from polars import DataFrame

# One dict per trace -> one row per trace in the polars frame.
trace_rows = [trace.model_dump() for trace in traces]
df = pl.DataFrame(trace_rows)
df.head()

In [None]:
import altair as alt

# Number of traces per dietary restriction, with the count column named "count".
# NOTE(review): not referenced again in the visible notebook — confirm it is still needed.
diet_counts: DataFrame = df.group_by("dietary_restriction").len().rename({"len": "count"})

def stacked_bar(df, col):
    """Return an Altair bar chart of row counts per dietary restriction.

    Rows of ``df`` are counted per (``dietary_restriction``, ``col``) pair;
    the bars are coloured by the value of ``col``.
    """
    counts = df.group_by("dietary_restriction", col).len().rename({"len": "count"})
    bars = alt.Chart(counts).mark_bar()
    encoded = bars.encode(
        x="dietary_restriction:N",
        y="count:Q",
        color=f"{col}:N",
        tooltip=["dietary_restriction", col, "count"],
    )
    return encoded.properties(
        title=f"{col} by Dietary Restriction",
        width=400,
        height=250,
    )

Ok - so it looks like they are all successes

In [None]:
# Trace counts per dietary restriction, coloured by the `success` flag.
stacked_bar(df, "success")

There are a few categories like whole30 where we have only FAIL examples

In [None]:
# Trace counts per dietary restriction, coloured by the PASS/FAIL label.
stacked_bar(df, "label")

The confidence level is almost always high (it is not clear where this value comes from). It does look like we struggled with diabetic-friendly and gluten-free, though.

In [None]:
# Trace counts per dietary restriction, coloured by labelling confidence.
stacked_bar(df, "confidence")

# Split Data

In [None]:
from sklearn.model_selection import train_test_split

# Stratified split: 20% train, 40% dev, 40% test
# Step 1: hold out 80% of the rows (temp_df); the remaining 20% is train.
# Stratifying on "label" keeps the PASS/FAIL ratio the same in both parts.
train_df, temp_df = train_test_split(
    df, test_size=0.8, random_state=42, stratify=df["label"]
)
# Step 2: cut the held-out 80% in half -> 40% dev and 40% test of the original.
dev_df, test_df = train_test_split(
    temp_df, test_size=0.5, random_state=42, stratify=temp_df["label"]
)

# Balance each split by downsampling PASS to match FAIL count
def balance_on_label(split: pl.DataFrame) -> pl.DataFrame:
    """Return a label-balanced, shuffled subset of ``split``.

    The larger of the PASS / FAIL groups is downsampled (seeded, without
    replacement) to the size of the smaller one, so the result holds the same
    number of rows for each label. If either label is absent the result is
    empty.
    """
    passes = split.filter(pl.col("label") == "PASS")
    fails = split.filter(pl.col("label") == "FAIL")
    n = min(len(passes), len(fails))
    balanced = pl.concat([passes.sample(n, seed=42), fails.sample(n, seed=42)])
    # `sample` only guarantees a randomised row order when shuffle=True; with
    # the default (False) a full-size sample is not guaranteed to be reordered,
    # which would leave all PASS rows ahead of all FAIL rows.
    return balanced.sample(fraction=1.0, shuffle=True, seed=42)

train_bal = balance_on_label(train_df)
dev_bal = balance_on_label(dev_df)
test_bal = balance_on_label(test_df)

# Sanity check: report the size and per-label counts of every balanced split.
balanced_splits = {"train": train_bal, "dev": dev_bal, "test": test_bal}
for name, split in balanced_splits.items():
    counts = split.group_by("label").len().sort("label")
    label_counts = dict(zip(counts["label"], counts["len"]))
    print(f"{name}: {len(split)} rows — {label_counts}")

In [None]:
# Judge Prompt

class JudgeResponse(BaseModel):
    """Structured answer expected back from the LLM judge."""

    # Restricted to the same two values as LabeledTrace.label: the metric code
    # compares verdicts to the literal strings "PASS" / "FAIL", so any other
    # spelling would silently fall outside every confusion-matrix cell.
    verdict: Literal["PASS", "FAIL"]
    # Short free-text reason for the verdict.
    justification: str


def _format_example(trace: LabeledTrace) -> str:
    """Render one labelled trace as an ``<example>`` block for the judge prompt.

    The block shows the query, dietary restriction and bot response, followed
    by a ``Judge:`` line built from the trace's label and reasoning.
    """
    lines = [
        "<example>",
        f"<query>{trace.query}</query>",
        f"<dietary_restriction>{trace.dietary_restriction}</dietary_restriction>",
        f"<response>{trace.response}</response>",
        f"Judge: {{'verdict': '{trace.label}', 'justification': '{trace.reasoning}'}}",
        "</example>",
    ]
    return "\n".join(lines)


def judge_prompt(examples: list[LabeledTrace], n_examples: int = 1) -> str:
    """Build the judge prompt template with few-shot examples baked in.

    Args:
        examples: Labelled traces to draw few-shot examples from.
        n_examples: Maximum number of examples to include *per label*; the
            first ``n_examples`` PASS and first ``n_examples`` FAIL traces in
            ``examples`` are used.

    Returns:
        The prompt text. It still contains the placeholders ``{query}``,
        ``{dietary_restriction}`` and ``{response}`` for the trace under
        judgement.
    """

    def _section(wanted_label: str) -> str:
        # First n_examples traces carrying the requested label, rendered and
        # separated by a blank line.
        chosen = [ex for ex in examples if ex.label == wanted_label][:n_examples]
        return "\n\n".join(map(_format_example, chosen))

    pass_section = _section("PASS")
    fail_section = _section("FAIL")

    return f"""# Scenario
You are a dietician assessing the output of a new automatic recipe bot.
Users have provided their dietary requirements and sometimes the bot doesnt fully meet them.

# Task
Assess the following user query, dietary requirement and response from the bot
Mark it either a PASS or a FAIL:
* PASS - The response provided meets the dietary requirement
* FAIL - Any aspect of the response clashes with the users dietary requirement

# Examples

## Pass
{pass_section}

## Fail
{fail_section}

# Trace to Judge
<query>{{query}}</query>
<dietary_restriction>{{dietary_restriction}}</dietary_restriction>
<response>{{response}}</response>


# Formatting
Return your response as json
Use the format `{{'verdict':'PASS', 'justification':'sentence or two on reason for the verdict'}}`
"""


# Rebuild typed traces from the balanced train split and preview the prompt
# with a single example per label.
train_examples: list[LabeledTrace] = [
    LabeledTrace(**row) for row in train_bal.to_dicts()
]
single_shot_preview = judge_prompt(train_examples, n_examples=1)
print(single_shot_preview)

# Run Judge on Dev

In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed, Future

import litellm
from dotenv import load_dotenv
from tqdm.notebook import tqdm

# Load variables from a .env file into the environment.
# NOTE(review): presumably this is where the model provider's API key comes from — confirm.
load_dotenv()

# Model identifier passed to litellm.completion for every judge call.
MODEL_NAME = "gpt-4o-mini"


def judge_trace(trace: LabeledTrace, prompt_template: str) -> JudgeResponse:
    """Run the judge on a single trace and return the parsed response.

    Args:
        trace: The trace to judge; its query, dietary restriction and response
            are substituted into the template.
        prompt_template: Prompt text containing the placeholders ``{query}``,
            ``{dietary_restriction}`` and ``{response}``.

    Returns:
        The judge's answer parsed into a ``JudgeResponse``.
    """
    import re

    fields = {
        "query": trace.query,
        "dietary_restriction": trace.dietary_restriction,
        "response": trace.response,
    }
    # Only the three known placeholders are substituted, in a single pass.
    # str.format cannot be used here: the template also carries literal braces
    # (the few-shot "Judge: {'verdict': ...}" lines and the output-format
    # hint), which str.format parses as replacement fields and fails on with
    # KeyError. A callable replacement also keeps backslashes in the trace
    # text from being treated as regex escapes.
    placeholder = re.compile(r"\{(" + "|".join(map(re.escape, fields)) + r")\}")
    prompt = placeholder.sub(lambda match: fields[match.group(1)], prompt_template)

    resp = litellm.completion(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
        response_format=JudgeResponse,
        temperature=0,
    )
    return JudgeResponse(**json.loads(resp.choices[0].message.content))


def run_judge(
    traces: list[LabeledTrace], prompt_template: str, desc: str = "Judging"
) -> list[tuple[LabeledTrace, JudgeResponse]]:
    """Run judge on a list of traces in parallel and return (trace, response) pairs.

    Args:
        traces: Traces to judge.
        prompt_template: Template forwarded unchanged to ``judge_trace``.
        desc: Label shown on the progress bar.

    Returns:
        One ``(trace, response)`` pair per input trace, in the same order as
        ``traces``, so repeated runs yield identically ordered results.

    Raises:
        Whatever ``judge_trace`` raises for any trace (re-raised by
        ``Future.result``).
    """
    responses: dict[int, JudgeResponse] = {}
    with ThreadPoolExecutor(max_workers=16) as executor:
        # Map each future back to the position of its trace so the output can
        # be re-assembled in input order regardless of completion order.
        futures: dict[Future[JudgeResponse], int] = {
            executor.submit(judge_trace, trace, prompt_template): position
            for position, trace in enumerate(traces)
        }
        # as_completed keeps the progress bar ticking as soon as any call ends.
        for fut in tqdm(as_completed(futures), total=len(futures), desc=desc):
            responses[futures[fut]] = fut.result()
    return [(trace, responses[position]) for position, trace in enumerate(traces)]


def calc_metrics(
    results: list[tuple[LabeledTrace, JudgeResponse]],
) -> dict[str, float]:
    """Calculate TPR, TNR, and accuracy for a set of judge results.

    "Positive" means PASS. Pairs whose label/verdict combination is not one of
    the four PASS/FAIL cells are not counted in tp/fn/tn/fp, but they still
    count towards the accuracy denominator (``len(results)``).

    Returns:
        Dict with the rates ``TPR``, ``TNR``, ``accuracy`` (0.0 when the
        respective denominator is zero) and the raw counts ``tp``, ``fn``,
        ``tn``, ``fp``.
    """
    tp = fn = tn = fp = 0
    for trace, judged in results:
        outcome = (trace.label, judged.verdict)
        if outcome == ("PASS", "PASS"):
            tp += 1
        elif outcome == ("PASS", "FAIL"):
            fn += 1
        elif outcome == ("FAIL", "FAIL"):
            tn += 1
        elif outcome == ("FAIL", "PASS"):
            fp += 1

    actual_pass = tp + fn
    actual_fail = tn + fp

    tpr = tp / actual_pass if actual_pass > 0 else 0.0
    tnr = tn / actual_fail if actual_fail > 0 else 0.0
    accuracy = (tp + tn) / len(results) if results else 0.0

    return {
        "TPR": tpr,
        "TNR": tnr,
        "accuracy": accuracy,
        "tp": tp,
        "fn": fn,
        "tn": tn,
        "fp": fp,
    }


def print_metrics(metrics: dict[str, float], name: str) -> None:
    """Print a four-line summary of a metrics dict.

    Args:
        metrics: Mapping with the keys ``TPR``, ``TNR``, ``accuracy``, ``tp``,
            ``fn``, ``tn`` and ``fp``.
        name: Heading used on the first line (e.g. the split name).
    """
    tp, fn, tn, fp = (metrics[key] for key in ("tp", "fn", "tn", "fp"))
    total = int(tp + fn + tn + fp)
    report = [
        f"{name} results (n={total}):",
        f"  TPR (sensitivity): {metrics['TPR']:.1%}  ({tp:.0f}/{tp+fn:.0f})",
        f"  TNR (specificity): {metrics['TNR']:.1%}  ({tn:.0f}/{tn+fp:.0f})",
        f"  Accuracy: {metrics['accuracy']:.1%}",
    ]
    print("\n".join(report))



In [None]:
# Build the prompt with two few-shot examples per label from the train split.
prompt_template: str = judge_prompt(train_examples, n_examples=2)
# Rebuild typed traces from the balanced dev split.
dev_traces: list[LabeledTrace] = [LabeledTrace(**d) for d in dev_bal.to_dicts()]

# Judge every dev trace, then score the verdicts against the labels.
dev_results: list[tuple[LabeledTrace, JudgeResponse]] = run_judge(
    dev_traces, prompt_template, desc="Judging dev set"
)
dev_metrics = calc_metrics(dev_results)
print_metrics(dev_metrics, "Dev set")

# Bias & Confidence Interval

In [None]:
from judgy import estimate_success_rate

# judgy works on binary arrays: 1 = PASS, 0 = anything else.
dev_labels = [int(trace.label == "PASS") for trace, _ in dev_results]
dev_preds = [int(judge_resp.verdict == "PASS") for _, judge_resp in dev_results]

# Share of traces the judge marked PASS, before any correction.
raw_pass_rate = sum(dev_preds) / len(dev_preds)

# estimate_success_rate derives the judge's TPR/TNR from the labelled pairs,
# corrects the pass rate observed in `unlabeled_preds`, and bootstraps a
# confidence interval around the corrected value.
# NOTE(review): this notebook has no unlabelled traces, so the dev verdicts
# stand in for them (the same data the previous version of this cell used).
# For a final report, pass the test-set labels/verdicts as the labelled pair
# and the judge's verdicts on real unlabelled traces as `unlabeled_preds`.
corrected, ci_lower, ci_upper = estimate_success_rate(
    test_labels=dev_labels,
    test_preds=dev_preds,
    unlabeled_preds=dev_preds,
)

print(f"Raw observed pass rate: {raw_pass_rate:.1%}")
print(f"Corrected pass rate: {corrected:.1%}")
print(f"Confidence interval: [{ci_lower:.1%}, {ci_upper:.1%}]")

# Mistake Viewer

In [None]:
from IPython.display import display, Markdown, clear_output
import ipywidgets as widgets


def _render_error(trace: LabeledTrace, judge: JudgeResponse, error_type: str) -> str:
    """Return a Markdown card describing one judge mistake.

    The card shows the expected label next to the judge's verdict, the query,
    the dietary restriction, the first 500 characters of the bot response, and
    both the labeller's reasoning and the judge's justification.
    """
    # Keep the card readable: cap the response and flag it when it was cut.
    preview = trace.response[:500]
    if len(trace.response) > 500:
        preview += "..."

    return f"""### {error_type}: expected **{trace.label}**, judge said **{judge.verdict}**

**Query:** {trace.query}

**Dietary Restriction:** {trace.dietary_restriction}

**Response (truncated):** {preview}

**Human reasoning:** {trace.reasoning}

**Judge justification:** {judge.justification}

---
"""


# Collect the judge's mistakes on the dev set. (The previous version iterated
# an undefined name `results`, which raised NameError; the dev run stores its
# output in `dev_results`.)
false_negatives = [(t, r) for t, r in dev_results if t.label == "PASS" and r.verdict == "FAIL"]
false_positives = [(t, r) for t, r in dev_results if t.label == "FAIL" and r.verdict == "PASS"]

# Flat list of (error type, trace, judge response), false negatives first.
errors = (
    [("False Negative", t, r) for t, r in false_negatives]
    + [("False Positive", t, r) for t, r in false_positives]
)

if not errors:
    print("No false positives or false negatives!")
else:
    # A dict holds the current index so the button callbacks can mutate it.
    idx_state = {"current": 0}
    label = widgets.HTML(value="")
    out = widgets.Output()

    def show(idx):
        """Render error number ``idx`` into the output area and update the counter."""
        error_type, trace, judge = errors[idx]
        label.value = f"<b>{idx + 1} / {len(errors)}</b> — {error_type}"
        with out:
            clear_output(wait=True)
            display(Markdown(_render_error(trace, judge, error_type)))

    def on_prev(_):
        """Step back one error, stopping at the first."""
        idx_state["current"] = max(0, idx_state["current"] - 1)
        show(idx_state["current"])

    def on_next(_):
        """Step forward one error, stopping at the last."""
        idx_state["current"] = min(len(errors) - 1, idx_state["current"] + 1)
        show(idx_state["current"])

    prev_btn = widgets.Button(description="← Prev")
    next_btn = widgets.Button(description="Next →")
    prev_btn.on_click(on_prev)
    next_btn.on_click(on_next)

    display(widgets.HBox([prev_btn, next_btn, label]))
    display(out)
    show(0)

# Run Judge on Test

In [None]:
# Rebuild typed traces from the balanced test split.
test_traces = [LabeledTrace(**d) for d in test_bal.to_dicts()]

# Judge the test split with the same prompt template used on dev, then score it.
test_results = run_judge(test_traces, prompt_template, desc="Judging test set")
test_metrics = calc_metrics(test_results)
print_metrics(test_metrics, "Test set")