# Barista Bench

### The Reasoning Challenge

### 1. Imports

In [27]:
# Core libraries, LLM clients, and progress utilities
import os
import json
import time
import pandas as pd
from google import genai
from openai import OpenAI
from tqdm import tqdm

### 2. Load Data

In [28]:
# Load competition training and test data from CSV files
train_df = pd.read_csv('train.csv')
test_df = pd.read_csv('test.csv')

# Quick sanity check on dataset sizes
print("Train shape:", train_df.shape)
print("Test shape:", test_df.shape)

Train shape: (500, 3)
Test shape: (3500, 2)


### 3. API Keys

In [29]:
# Load API keys from environment variables or a secrets manager (recommended)
GEMINI_KEY = os.getenv("GEMINI_KEY")
OPENAI_KEY = os.getenv("OPENAI_KEY")

### 4. Menu Data For Price Validator

In [30]:
# Static menu configuration used to recompute prices deterministically
BASE_PRICES = {
    "Espresso": 3.00, "Americano": 3.50, "Drip Coffee": 2.50,
    "Latte": 4.50, "Cappuccino": 4.50, "Flat White": 4.75,
    "Mocha": 5.00, "Caramel Macchiato": 5.25,
    "Cold Brew": 4.25, "Iced Coffee": 3.00,
    "Frappe (Coffee)": 5.50, "Frappe (Mocha)": 5.75,
    "Strawberry Smoothie": 6.00,
    "Chai Latte": 4.75, "Matcha Latte": 5.25,
    "Earl Grey Tea": 3.00, "Green Tea": 3.00, "Hot Chocolate": 4.00,
    "Butter Croissant": 3.50, "Blueberry Muffin": 3.75,
    "Bagel": 2.50, "Avocado Toast": 7.00, "Bacon Gouda Sandwich": 5.50,
}

# Per-size price adjustments (applied only to drinks)
SIZE_ADJ = {
    "Short": -0.50, "Tall": 0.00, "Grande": 0.50,
    "Venti": 1.00, "Trenta": 1.50,
}

# Individual modifier price deltas
MOD_COSTS = {
    "Oat Milk": 0.80, "Almond Milk": 0.60, "Soy Milk": 0.60,
    "Coconut Milk": 0.70, "Breve": 0.80, "Skim Milk": 0.00,
    "Vanilla Syrup": 0.50, "Caramel Syrup": 0.50,
    "Hazelnut Syrup": 0.50, "Peppermint Syrup": 0.50,
    "Sugar Free Vanilla": 0.50, "Classic Syrup": 0.00,
    "Extra Shot": 1.00, "Whip Cream": 0.50,
    "No Whip": 0.00, "Cold Foam": 1.25, "Caramel Drizzle": 0.50,
    "Extra Hot": 0.00, "Light Ice": 0.00, "No Ice": 0.00,
}

# Food is priced without sizes or drink-style modifiers
FOOD_ITEMS = {"Butter Croissant", "Blueberry Muffin", "Bagel",
              "Avocado Toast", "Bacon Gouda Sandwich"}

def fix_price(parsed):
    """Recalculate total_price from menu rules. Overrides LLM math."""
    try:
        total = 0.0
        for item in parsed.get("items", []):
            name = item.get("name", "")
            size = item.get("size")
            qty = item.get("quantity", 1)
            mods = item.get("modifiers", [])

            # Look up base drink/food price
            base = BASE_PRICES.get(name, 0.0)

            # Add size adjustment only for drinks (food has no size)
            s = SIZE_ADJ.get(size, 0.0) if (name not in FOOD_ITEMS and size) else 0.0

            # Sum the price of all modifiers
            m = sum(MOD_COSTS.get(mod, 0.0) for mod in mods)

            # Price formula: (base + size + modifiers) × quantity
            total += (base + s + m) * qty

        parsed["total_price"] = round(total, 2)
    except Exception as e:
        print(f"  Price fix error: {e}")
    return parsed

def merge_items(parsed):
    """Merge duplicate items with same name, size, and modifiers."""
    if not parsed.get('items'):
        return parsed
    
    merged = {}
    for item in parsed['items']:
        # Key includes name, size, and a canonicalized list of modifiers
        key = (item['name'], item.get('size'), tuple(sorted(item.get('modifiers', []))))
        if key in merged:
            # If item already exists, just bump the quantity
            merged[key]['quantity'] += item.get('quantity', 1)
        else:
            merged[key] = {
                'name': item['name'],
                'size': item.get('size'),
                'quantity': item.get('quantity', 1),
                'modifiers': item.get('modifiers', [])
            }
    
    parsed['items'] = list(merged.values())
    return parsed

### 5. Verify Price Validator Against Training Data

In [31]:
# Sanity-check that our price calculator exactly matches the labeled training prices
correct, wrong = 0, 0
for _, row in train_df.iterrows():
    try:
        exp = json.loads(row['expected_json'])
        # Clone JSON to avoid mutating the original as we recompute prices
        recalc = fix_price(json.loads(json.dumps(exp)))
        if abs(recalc['total_price'] - exp['total_price']) < 0.01:
            correct += 1
        else:
            wrong += 1
            if wrong <= 3:
                print(f"  Mismatch: expected ${exp['total_price']}, got ${recalc['total_price']}")
                print(f"    Items: {exp['items']}")
    except:
        # Skip any malformed rows without failing the notebook
        continue
print(f"\nPrice validator: {correct}/{correct+wrong} correct ({100*correct/(correct+wrong):.1f}%)")


Price validator: 500/500 correct (100.0%)


### 6. System Prompt

In [32]:
# Construct few-shot examples from labeled data and define the system instruction for the LLM
# (examples pasted directly into the system message at the end)
# Few-shot examples from training
examples = ""
for i, (_, row) in enumerate(train_df.sample(n=8, random_state=42).iterrows()):
    examples += f"\nExample {i+1}:\nOrder: {row['order']}\nOutput: {row['expected_json']}\n"

cancellation_examples = """
Example 9 (CANCELLATION):
Order: Can I get a Bagel and a Tall Latte with vanilla syrup. Actually remove the bagel.
Output: {"items": [{"name": "Latte", "size": "Tall", "quantity": 1, "modifiers": ["Vanilla Syrup"]}], "total_price": 5.00}

Example 10 (MODIFIER CANCELLATION):
Order: I'll have a Grande Iced Coffee with almond milk, wait scratch that almond milk.
Output: {"items": [{"name": "Iced Coffee", "size": "Grande", "quantity": 1, "modifiers": []}], "total_price": 3.50}

Example 11 (QUANTITY CORRECTION):
Order: Gimme two blueberry muffins, no actually make that three blueberry muffins.
Output: {"items": [{"name": "Blueberry Muffin", "size": null, "quantity": 3, "modifiers": []}], "total_price": 11.25}

Example 12 (MULTI-ITEM WITH REMOVAL):
Order: I want a Tall Cold Brew, a Bagel, and a Grande Green Tea with Skim Milk. Cancel the bagel.
Output: {"items": [{"name": "Cold Brew", "size": "Tall", "quantity": 1, "modifiers": []}, {"name": "Green Tea", "size": "Grande", "quantity": 1, "modifiers": ["Skim Milk"]}], "total_price": 7.75}
"""

SYSTEM_PROMPT = """You are a coffee shop POS system. Parse natural language orders into structured JSON.

# MENU (Base = Tall price)
## HOT COFFEES
Espresso $3.00 | Americano $3.50 | Drip Coffee $2.50 | Latte $4.50
Cappuccino $4.50 | Flat White $4.75 | Mocha $5.00 | Caramel Macchiato $5.25

## COLD/BLENDED
Cold Brew $4.25 | Iced Coffee $3.00 | Frappe (Coffee) $5.50
Frappe (Mocha) $5.75 | Strawberry Smoothie $6.00

## TEAS & OTHERS
Chai Latte $4.75 | Matcha Latte $5.25 | Earl Grey Tea $3.00
Green Tea $3.00 | Hot Chocolate $4.00

## FOOD (size: null, no modifiers)
Butter Croissant $3.50 | Blueberry Muffin $3.75 | Bagel $2.50
Avocado Toast $7.00 | Bacon Gouda Sandwich $5.50

# SIZES: Short -$0.50 | Tall +$0.00 | Grande +$0.50 | Venti +$1.00 | Trenta +$1.50 (cold only)

# MODIFIERS
Milks (replace default): Oat Milk +$0.80 | Almond Milk +$0.60 | Soy Milk +$0.60 | Coconut Milk +$0.70 | Breve +$0.80 | Skim Milk +$0.00
Syrups (stack): Vanilla Syrup +$0.50 | Caramel Syrup +$0.50 | Hazelnut Syrup +$0.50 | Peppermint Syrup +$0.50 | Sugar Free Vanilla +$0.50 | Classic Syrup +$0.00
Toppings: Extra Shot +$1.00 | Whip Cream +$0.50 | No Whip +$0.00 | Cold Foam +$1.25 | Caramel Drizzle +$0.50 | Extra Hot +$0.00 | Light Ice +$0.00 | No Ice +$0.00

# RULES
1. Milk mods REPLACE default (charge replacement only)
2. "No Whip" on Mocha/Frappe/Hot Chocolate: price unchanged, record as modifier
3. Multiple syrups stack
4. Food: size=null, no modifiers
5. Handle corrections: "actually"/"change to"/"switch to" = modify; "cancel"/"nevermind"/"scratch that" = remove
6. Ignore fillers: "um","uh","like","literally","basically" — EVEN when they appear between a modifier and its keyword (e.g. "peppermint uh syrup" = "Peppermint Syrup")
7. Quantities: a/one/single=1, couple/two/pair=2, triple/three/a few=3, four=4, five=5
8. Title Case for all names/modifiers
9. Price formula: (base + size_adj + modifier_costs) × quantity

# CRITICAL STATE TRACKING RULES (read carefully)
- "remove that", "remove the", "drop the" = DELETE the referenced item or modifier entirely
- "actually" followed by a correction = REPLACE the previous version, do NOT keep both
- If user orders "two X actually three X", final quantity is 3, NOT 5
- If user says "add vanilla, actually make it caramel", final modifier is Caramel Syrup ONLY
- NEVER output more items than the customer actually wants in their final order
- Count items AFTER applying all corrections and cancellations

# MODIFIER PARSING RULES
- Filler words inside modifier phrases should be ignored: "peppermint uh syrup" = "Peppermint Syrup"
- "add hazelnut" = "Hazelnut Syrup", "add vanilla" = "Vanilla Syrup" etc
- Parse ALL modifiers mentioned for each item, even if separated by filler words
- Do NOT add "No Ice" unless the customer explicitly says "no ice"

# OUTPUT: ONLY valid JSON, no markdown, no explanation
{"items": [{"name": "...", "size": "..." or null, "quantity": N, "modifiers": [...]}], "total_price": X.XX}

""" + examples

### 7. Model Setup

STEP 1 OPTIONS (cheap, for prompt iteration):
PROVIDER = "gemini",  MODEL = "gemini-2.0-flash"
PROVIDER = "gemini",  MODEL = "gemini-2.5-flash-lite"

STEP 2 OPTIONS (better quality, for final submission):
PROVIDER = "openai",  MODEL = "gpt-4o-mini"
PROVIDER = "gemini",  MODEL = "gemini-2.5-flash"

In [39]:
# ---- PROVIDER CONFIG: Change this to switch models ----
PROVIDER = "openai"
MODEL = "gpt-4.1"

# Global settings for robustness and rate limiting
MAX_RETRIES = 3
DELAY = 2.0  # seconds between requests (increase if rate limited)

# Instantiate the chosen LLM client once at notebook start
if PROVIDER == "gemini":
    gemini_client = genai.Client(api_key=GEMINI_KEY)
elif PROVIDER == "openai":
    openai_client = OpenAI(api_key=OPENAI_KEY)


def call_llm(order_text):
    """Call the configured LLM provider with the parsing prompt."""
    if PROVIDER == "gemini":
        resp = gemini_client.models.generate_content(
            model=MODEL,
            contents=f"Parse this order:\n{order_text}",
            config={
                "system_instruction": SYSTEM_PROMPT,
                "temperature": 0.0,
                "max_output_tokens": 1024,
            }
        )
        return resp.text.strip()
    elif PROVIDER == "openai":
        resp = openai_client.chat.completions.create(
            model=MODEL,
            temperature=0.0,
            max_tokens=1024,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": f"Parse this order:\n{order_text}"}
            ]
        )
        return resp.choices[0].message.content.strip()


# Quick smoke test to confirm the model and prompt behave as expected
print(f"Using: {PROVIDER} / {MODEL}")
test_resp = call_llm("One tall latte with oat milk")
print(f"Test response:\n{test_resp}")

Using: openai / gpt-4.1
Test response:
{"items": [{"name": "Latte", "size": "Tall", "quantity": 1, "modifiers": ["Oat Milk"]}], "total_price": 5.3}


### 8. Parse Orders

In [40]:
# Main inference routine: robust per-order parsing plus batch processing over test set
def parse_order(order_text):
    for attempt in range(MAX_RETRIES):
        try:
            raw = call_llm(order_text)
            raw = raw.replace("```json", "").replace("```", "").strip()
            parsed = json.loads(raw)
            if 'items' not in parsed or 'total_price' not in parsed:
                raise ValueError("Missing keys")
            parsed = merge_items(parsed)  # ← NEW: merge duplicates
            parsed = fix_price(parsed)    # ← fix arithmetic
            return parsed
        except (json.JSONDecodeError, ValueError) as e:
            if attempt < MAX_RETRIES - 1:
                time.sleep(2)
            continue
        except Exception as e:
            err = str(e).lower()
            if '429' in err or 'rate' in err or 'quota' in err:
                wait = 30 * (attempt + 1)
                print(f"\n  Rate limited! Waiting {wait}s...")
                time.sleep(wait)
            else:
                print(f"\n  Error (attempt {attempt+1}): {e}")
                if attempt < MAX_RETRIES - 1:
                    time.sleep(5)
            continue
    return {"items": [], "total_price": 0.0}

# Storage for model predictions and basic failure stats over the test set
results = []
failed = 0
est_min = len(test_df) * DELAY / 60
print(f"\nParsing {len(test_df)} orders | {PROVIDER}/{MODEL} | ~{est_min:.0f} min estimated\n")

for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Parsing",
                   bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'):
    parsed = parse_order(row['order'])
    if not parsed.get('items'):
        failed += 1
    results.append(json.dumps(parsed))
    time.sleep(DELAY)

print(f"\nDone! {failed} failed out of {len(results)}")


Parsing 3500 orders | openai/gpt-4.1 | ~117 min estimated



Parsing:  31%|███       | 1068/3500 [1:19:08<2:35:25,  3.83s/it]


  Rate limited! Waiting 30s...


Parsing:  53%|█████▎    | 1870/3500 [2:18:47<1:39:23,  3.66s/it]


  Rate limited! Waiting 30s...


Parsing: 100%|██████████| 3500/3500 [4:20:02<00:00,  4.46s/it]  


Done! 68 failed out of 3500





### 9. Generate Submission

In [41]:
# Build the competition submission DataFrame with IDs and model predictions
submission = pd.DataFrame({
    'id': test_df['id'],
    'predicted_json': results
})

# Save with model name for tracking
fname = f"submission_{PROVIDER}_{MODEL.replace('/', '_')}.csv"
submission.to_csv(fname, index=False)
submission.to_csv('submission.csv', index=False)  # also save as default
print(f"Saved: {fname}")

# Show samples
for i in range(min(5, len(submission))):
    print(f"\n  ID {submission.iloc[i]['id']}:")
    print(f"  {submission.iloc[i]['predicted_json']}")

Saved: submission_openai_gpt-4.1.csv

  ID 500:
  {"items": [{"name": "Bacon Gouda Sandwich", "size": null, "quantity": 1, "modifiers": []}, {"name": "Americano", "size": "Grande", "quantity": 3, "modifiers": ["Caramel Drizzle"]}], "total_price": 19.0}

  ID 501:
  {"items": [{"name": "Espresso", "size": "Venti", "quantity": 1, "modifiers": []}], "total_price": 4.0}

  ID 502:
  {"items": [{"name": "Caramel Macchiato", "size": "Tall", "quantity": 3, "modifiers": []}, {"name": "Flat White", "size": "Trenta", "quantity": 1, "modifiers": ["Extra Shot", "Almond Milk"]}], "total_price": 23.6}

  ID 503:
  {"items": [{"name": "Avocado Toast", "size": null, "quantity": 3, "modifiers": []}], "total_price": 21.0}

  ID 504:
  {"items": [{"name": "Flat White", "size": "Short", "quantity": 1, "modifiers": ["Oat Milk"]}, {"name": "Cold Brew", "size": "Venti", "quantity": 1, "modifiers": ["Whip Cream"]}, {"name": "Flat White", "size": "Short", "quantity": 1, "modifiers": ["Extra Hot"]}], "total_pri

### 10. Validation Stats

In [42]:
# High-level sanity checks on prediction distribution and quick accuracy on training data
empty = sum(1 for r in results if json.loads(r).get('items') == [])
prices = [json.loads(r)['total_price'] for r in results if json.loads(r)['total_price'] > 0]

print(f"\n--- Quality Report ---")
print(f"Empty predictions: {empty}/{len(results)} ({100*empty/len(results):.1f}%)")
if prices:
    print(f"Price range: ${min(prices):.2f} - ${max(prices):.2f}")
    print(f"Mean price: ${sum(prices)/len(prices):.2f}")
    print(f"Median price: ${sorted(prices)[len(prices)//2]:.2f}")

# Check against training data (quick accuracy estimate)
print(f"\n--- Training Accuracy Check ---")
match, total_checked = 0, 0
for _, row in train_df.sample(n=min(20, len(train_df)), random_state=99).iterrows():
    expected = json.loads(row['expected_json'])
    predicted = parse_order(row['order'])
    total_checked += 1
    time.sleep(DELAY)

    exp_names = sorted([i['name'] for i in expected['items']])
    pred_names = sorted([i['name'] for i in predicted['items']])

    if (exp_names == pred_names and
        abs(predicted['total_price'] - expected['total_price']) < 0.01):
        match += 1
    else:
        print(f"  Miss: expected {exp_names} ${expected['total_price']}, "
              f"got {pred_names} ${predicted['total_price']}")

print(f"Quick accuracy: {match}/{total_checked} ({100*match/total_checked:.1f}%)")


--- Quality Report ---
Empty predictions: 68/3500 (1.9%)
Price range: $2.00 - $92.50
Mean price: $26.72
Median price: $24.25

--- Training Accuracy Check ---
Quick accuracy: 20/20 (100.0%)


In [26]:
# Quick offline check: compare price validator against training data structure
issues = {"name": [], "size": [], "modifier": [], "quantity": []}

# Parse 50 training orders and compare
sample = train_df.sample(n=50, random_state=42)
for _, row in tqdm(sample.iterrows(), total=50, desc="Checking"):
    expected = json.loads(row['expected_json'])
    predicted = parse_order(row['order'])
    time.sleep(DELAY)

    exp_items = sorted(expected['items'], key=lambda x: x['name'])
    pred_items = sorted(predicted['items'], key=lambda x: x['name'])

    if len(exp_items) != len(pred_items):
        issues["name"].append(f"Count: exp {len(exp_items)} vs got {len(pred_items)} | {row['order'][:80]}")
        continue

    for e, p in zip(exp_items, pred_items):
        if e['name'] != p['name']:
            issues["name"].append(f"'{e['name']}' vs '{p['name']}' | {row['order'][:80]}")
        if e.get('size') != p.get('size'):
            issues["size"].append(f"'{e.get('size')}' vs '{p.get('size')}' for {e['name']} | {row['order'][:80]}")
        if e.get('quantity') != p.get('quantity'):
            issues["quantity"].append(f"{e.get('quantity')} vs {p.get('quantity')} for {e['name']} | {row['order'][:80]}")
        if sorted(e.get('modifiers', [])) != sorted(p.get('modifiers', [])):
            issues["modifier"].append(
                f"exp {sorted(e.get('modifiers',[]))} vs got {sorted(p.get('modifiers',[]))} | {row['order'][:80]}"
            )

print(f"\n{'='*50}")
print(f"ERROR BREAKDOWN (50 training orders)")
print(f"{'='*50}")
for cat, errs in issues.items():
    print(f"\n{cat.upper()} errors: {len(errs)}")
    for e in errs[:3]:  # show first 3 of each
        print(f"  → {e}")

Checking: 100%|██████████| 50/50 [02:44<00:00,  3.29s/it]


ERROR BREAKDOWN (50 training orders)

NAME errors: 2
  → Count: exp 1 vs got 0 | Start me off with couple of tall Flat Whites and you know, hold the ice plus one
  → Count: exp 4 vs got 5 | I'd like to order three Grande ICED COFFEEs include almond milk wait scratch tha

SIZE errors: 0

MODIFIER errors: 2
  → exp ['Breve', 'Oat Milk'] vs got ['Oat Milk'] | Could you get me one grande americano, plus one TALL cold brew add actually brev
  → exp ['Oat Milk', 'Skim Milk'] vs got ['Oat Milk'] | I'll have five literally blueberry muffin and four trenta earl grey teas add hol

QUANTITY errors: 0



