# Task
Use Labelled data to create an LLM judge.

* Understand the data
* Split data into train/dev/test
* Write a judge prompt with some few shot examples from train
* Iterate on dev
* Measure TPR and TNR on test
* Use Judgy to estimate unbiased performance and a confidence interval



In [None]:
import json
from typing import Literal

from pydantic import BaseModel


class LabeledTrace(BaseModel):
    # One labelled recipe-bot trace: a single JSON object from
    # reference_files/labeled_traces.jsonl, validated by pydantic.
    # `#` comments are used instead of a docstring so the model's generated
    # schema is left untouched.
    query: str  # the user query shown to the judge inside <query>...</query>
    dietary_restriction: str  # restriction the response is judged against; also the chart grouping key
    response: str  # the bot's response that gets a PASS/FAIL verdict
    success: bool  # presumably whether the bot call completed — TODO confirm against the data source
    error: str | None  # presumably the error message when success is False, else None — TODO confirm
    trace_id: str  # identifier of this trace (uniqueness not verified here)
    query_id: str  # identifier of the originating query (uniqueness not verified here)
    label: Literal["PASS", "FAIL"]  # ground-truth verdict; used for stratifying and balancing the splits
    reasoning: str  # explanation accompanying the label; reused as the justification in few-shot examples
    confidence: Literal["HIGH", "MEDIUM", "LOW"]  # labeller confidence; origin of this field is unclear
    labeled: bool  # presumably marks that the trace has been labelled — TODO confirm


# Load the labelled traces: one JSON object per line (JSONL).
# The encoding is pinned so parsing does not depend on the platform default,
# and stray blank lines are skipped instead of crashing json.loads.
with open("reference_files/labeled_traces.jsonl", encoding="utf-8") as f:
    traces: list[LabeledTrace] = [
        LabeledTrace(**json.loads(line)) for line in f if line.strip()
    ]

print(f"Loaded {len(traces)} traces")
traces[0]  # last expression of the cell: displays the first trace as a sanity check

In [None]:
import polars as pl
from polars import DataFrame

# Tabular view of the traces: model_dump() turns each pydantic model into a
# plain dict, and the dict keys become the DataFrame columns.
df = pl.DataFrame(
    [trace.model_dump() for trace in traces]
)
df.head()  # preview the first rows

In [None]:
import altair as alt

# Number of traces per dietary restriction, with the count column named "count".
diet_counts: DataFrame = (
    df.group_by("dietary_restriction")
    .len()
    .rename({"len": "count"})
)

def stacked_bar(df, col):
    """Plot trace counts per dietary restriction, coloured by the values of ``col``.

    Rows of ``df`` are grouped by ``dietary_restriction`` and ``col`` and
    counted; the counts are drawn as bars with one colour per ``col`` value.
    Returns the Altair chart so a notebook cell can display it.
    """
    counts = (
        df.group_by("dietary_restriction", col)
        .len()
        .rename({"len": "count"})
    )
    bars = alt.Chart(counts).mark_bar()
    encoded = bars.encode(
        x="dietary_restriction:N",
        y="count:Q",
        color=f"{col}:N",
        tooltip=["dietary_restriction", col, "count"],
    )
    return encoded.properties(
        title=f"{col} by Dietary Restriction",
        width=400,
        height=250,
    )

OK, so it looks like all of the traces are successes.

In [None]:
# Distribution of the `success` flag within each dietary restriction.
stacked_bar(df, col="success")

There are a few categories like whole30 where we have only FAIL examples

In [None]:
# PASS/FAIL label counts within each dietary restriction.
stacked_bar(df, col="label")

The confidence level is almost always HIGH, and it is not clear where this field comes from. It does look like we struggled with diabetic-friendly and gluten-free, though.

In [None]:
# Labeller confidence levels within each dietary restriction.
stacked_bar(df, col="confidence")

# Split Data

In [None]:
from sklearn.model_selection import train_test_split

# Stratified split: 20% train, 40% dev, 40% test
# Step 1: test_size=0.8 sets 80% of the rows aside, leaving 20% as train
# (train only has to supply few-shot examples for the judge prompt).
train_df, temp_df = train_test_split(
    df, test_size=0.8, random_state=42, stratify=df["label"]
)
# Step 2: halve the held-out 80% into dev and test (40% of the data each).
# Both calls stratify on "label" and pin random_state=42 so that rerunning
# the cell reproduces the same split.
dev_df, test_df = train_test_split(
    temp_df, test_size=0.5, random_state=42, stratify=temp_df["label"]
)

# Balance a split by downsampling the larger class to the size of the smaller one
def balance_on_label(split: pl.DataFrame, seed: int = 42) -> pl.DataFrame:
    """Return a class-balanced, shuffled subset of ``split``.

    Rows are partitioned on the ``label`` column into PASS and FAIL. Whichever
    class is larger is randomly downsampled to the size of the smaller one,
    and the combined rows are then shuffled.

    Args:
        split: Frame with a ``label`` column holding "PASS" / "FAIL".
        seed: Seed used for both the downsampling and the shuffle, so the
            result is reproducible. Defaults to 42.

    Returns:
        A frame with equally many PASS and FAIL rows, in shuffled order.
        It is empty when either class is missing from ``split``.
    """
    passes = split.filter(pl.col("label") == "PASS")
    fails = split.filter(pl.col("label") == "FAIL")
    n = min(len(passes), len(fails))
    balanced = pl.concat([passes.sample(n, seed=seed), fails.sample(n, seed=seed)])
    # shuffle=True is needed for a real shuffle: with the default
    # shuffle=False polars does not promise a fully random row order.
    return balanced.sample(fraction=1.0, shuffle=True, seed=seed)

train_bal = balance_on_label(train_df)
dev_bal = balance_on_label(dev_df)
test_bal = balance_on_label(test_df)

# Sanity check: print the size and per-label counts of every balanced split.
# (The separator was a mis-encoded em dash in the printed text; restored here.)
for name, split in [("train", train_bal), ("dev", dev_bal), ("test", test_bal)]:
    counts = split.group_by("label").len().sort("label")
    print(f"{name}: {len(split)} rows — {dict(zip(counts['label'], counts['len']))}")

In [None]:
# Judge Prompt

class JudgeResponse(BaseModel):
    # Structured judge output, mirroring the JSON object the prompt asks for.
    # verdict is limited to the same two values as LabeledTrace.label, so an
    # off-spec verdict (e.g. "pass", "Maybe") fails validation instead of
    # being silently counted as a disagreement with the ground truth.
    verdict: Literal["PASS", "FAIL"]
    # One or two sentences explaining the verdict.
    justification: str


def _format_example(trace: LabeledTrace) -> str:
    """Render one labelled trace as a few-shot ``<example>`` block.

    The expected judge output is serialised with ``json.dumps`` so it is
    valid JSON (double-quoted, with quotes and newlines escaped) even when
    the reasoning text itself contains apostrophes or quotation marks.
    """
    expected_output = json.dumps(
        {"verdict": trace.label, "justification": trace.reasoning},
        ensure_ascii=False,  # keep non-ASCII reasoning text readable in the prompt
    )
    return (
        f"<example>\n"
        f"<query>{trace.query}</query>\n"
        f"<dietary_restriction>{trace.dietary_restriction}</dietary_restriction>\n"
        f"<response>{trace.response}</response>\n"
        f"Judge: {expected_output}\n"
        f"</example>"
    )


def judge_prompt(examples: list[LabeledTrace], n_examples: int = 1) -> str:
    """Build the judge prompt template with few-shot examples.

    Args:
        examples: Labelled traces to draw few-shot examples from.
        n_examples: Maximum number of PASS and of FAIL examples to include;
            the first ``n_examples`` of each label in ``examples`` are used.

    Returns:
        The prompt text. It still contains the literal placeholders
        ``{query}``, ``{dietary_restriction}`` and ``{response}`` for the
        trace to be judged. The text also contains literal JSON braces, so
        fill the placeholders with ``str.replace`` rather than ``str.format``.
    """
    passes = [e for e in examples if e.label == "PASS"][:n_examples]
    fails = [e for e in examples if e.label == "FAIL"][:n_examples]

    pass_section = "\n\n".join(_format_example(e) for e in passes)
    fail_section = "\n\n".join(_format_example(e) for e in fails)

    # Doubled braces survive the f-string as single literal braces.
    # The format line uses double quotes so the sample is valid JSON,
    # consistent with the few-shot examples above it.
    prompt = f"""# Scenario
You are a dietician assessing the output of a new automatic recipe bot.
Users have provided their dietary requirements and sometimes the bot doesnt fully meet them.

# Task
Assess the following user query, dietary requirement and response from the bot
Mark it either a PASS or a FAIL:
* PASS - The response provided meets the dietary requirement
* FAIL - Any aspect of the response clashes with the users dietary requirement

# Examples

## Pass
{pass_section}

## Fail
{fail_section}

# Trace to Judge
<query>{{query}}</query>
<dietary_restriction>{{dietary_restriction}}</dietary_restriction>
<response>{{response}}</response>


# Formatting
Return your response as json
Use the format `{{"verdict": "PASS", "justification": "sentence or two on reason for the verdict"}}`
"""
    return prompt


# Turn the balanced train rows back into typed traces so they can serve as
# few-shot examples, then preview the prompt with one example per label.
train_examples: list[LabeledTrace] = [
    LabeledTrace(**row) for row in train_bal.to_dicts()
]
print(judge_prompt(train_examples, n_examples=1))