In [31]:
%pip -q install --upgrade --force-reinstall \
  openai>=1.30.0 datasets>=2.19.0 scikit-learn>=1.4.2 tqdm>=4.66.4 \
  pandas==2.2.2 numpy==2.0.2 "pyarrow<20" requests==2.32.4 "pydantic<2.12"


In [32]:

import os, getpass
from datasets import load_dataset

# OpenAI key: prompt interactively only when the environment does not provide one.
if not os.getenv("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

DATASET = "TheFinAI/flare-fiqasa"

# If gated, pass HF_TOKEN (None when the variable is unset).
HF_TOKEN = os.getenv("HF_TOKEN")

try:
    ds_all = load_dataset(DATASET, token=HF_TOKEN)
except Exception:
    # Best-effort fallback: any failure of the first attempt triggers an
    # interactive Hugging Face login followed by one retry.
    # NOTE(review): notebook_login() renders a widget; confirm it blocks until
    # the token is submitted, otherwise the retry below runs before login ends.
    try:
        from huggingface_hub import notebook_login
        notebook_login()
        ds_all = load_dataset(DATASET)
    except Exception as e:
        # Chain explicitly so the traceback names the retry failure as the cause
        # (the first failure stays attached as its context).
        raise RuntimeError(f"Hugging Face access required: {e}") from e

# Prefer the "train" split; otherwise use the first split the dataset exposes.
SPLIT = "train" if "train" in ds_all else list(ds_all.keys())[0]
ds = ds_all[SPLIT]
print(f"Loaded {DATASET} [{SPLIT}]  n={len(ds)}")
print("Columns:", ds.column_names)


Loaded TheFinAI/flare-fiqasa [train]  n=750
Columns: ['id', 'query', 'answer', 'text', 'choices', 'gold']


In [33]:
from typing import List
from collections import OrderedDict
import re

# Eval config (read by the request and evaluation cells below)
MODEL        = "o3"     # must stay "o3": the request cell asserts MODEL == "o3" and passes model="o3" literally
MAX_SAMPLES  = None      # None/0 = evaluate the full split; set a small int for a quick trial run
QPS_DELAY    = 1.0      # seconds slept before each request; gentle throttle to reduce 429s
MAX_RETRIES  = 6        # maximum attempts per request in call_o3_json

def row_text(row):
    """Build the prompt text for one dataset row from its "query" and "text" fields.

    A missing or None field counts as an empty string, and both fields are
    stripped. When both are non-empty they are joined with a blank line and a
    "Text:" header; otherwise whichever one is non-empty is returned alone
    (the empty string if neither is).
    """
    query, body = ((row.get(key) or "").strip() for key in ("query", "text"))
    if not query:
        return body
    if not body:
        return query
    return f"{query}\n\nText:\n{body}"

def normalize_choices(raw):
    """Turn a raw `choices` value into a list of non-empty, stripped strings.

    Only a list input is processed; anything else yields []. Per element:
    strings are stripped; dicts contribute the first truthy value among the
    keys "label", "text", "value", "choice" (stringified and stripped); any
    other object is stringified and stripped. Elements that end up empty are
    dropped, and the original order is kept.
    """
    if not isinstance(raw, list):
        return []

    def _as_text(item):
        # Plain strings pass through; dicts are probed key by key in priority order.
        if isinstance(item, str):
            return item.strip()
        if isinstance(item, dict):
            probed = (item.get(key) for key in ("label", "text", "value", "choice"))
            picked = next((value for value in probed if value), "")
            return str(picked).strip()
        return str(item).strip()

    return [cleaned for cleaned in map(_as_text, raw) if cleaned]

def gold_to_label(row, choices: List[str]):
    """Resolve a row's gold annotation to a label string.

    The "gold" field is used, falling back to "answer" when "gold" is absent
    or None. An int is treated as an index into `choices` (its string form is
    returned when out of range). A string is stripped and mapped to the
    case-insensitively equal entry of `choices` when there is one, otherwise
    returned stripped. Any other value is returned via str().
    """
    gold = row.get("gold", None)
    if gold is None:
        gold = row.get("answer", None)

    if isinstance(gold, int):
        in_range = 0 <= gold < len(choices)
        return choices[gold] if in_range else str(gold)

    if not isinstance(gold, str):
        return str(gold)

    wanted = gold.strip()
    folded = wanted.lower()
    # First choice that matches ignoring case wins; otherwise keep the stripped text.
    return next((choice for choice in choices if choice.lower() == folded), wanted)


In [34]:
import json, re, time, random
from typing import List, Tuple
from openai import OpenAI, APIStatusError, RateLimitError, APIConnectionError, APIError

# One shared SDK client for the notebook (uses OPENAI_API_KEY).
client = OpenAI()

# This cell only supports o3: refuse to continue when the config cell chose another MODEL.
assert globals().get("MODEL", "o3") == "o3", "This cell is o3-only. Set MODEL='o3' in Cell 3."

# Keep the throttle/retry settings from the config cell when it has already run;
# otherwise define them here with the same defaults.
try:
    QPS_DELAY
except NameError:
    QPS_DELAY = 1.0
try:
    MAX_RETRIES
except NameError:
    MAX_RETRIES = 6

def _match_label(candidate, labels: List[str]):
    """Return the entry of `labels` equal to `candidate` ignoring case/outer whitespace, else None."""
    if candidate is None:
        return None
    wanted = str(candidate).strip().lower()
    return next((known for known in labels if known.lower() == wanted), None)


def _parse_o3_reply(raw: str, labels: List[str]) -> Tuple[str, float, dict]:
    """Turn the model's reply text into (label, confidence, obj); never raises for bad replies."""
    # Strict JSON first; then the outermost {...} span, which recovers replies
    # wrapped in markdown fences or surrounded by stray prose.
    attempts = [raw]
    start, end = raw.find("{"), raw.rfind("}")
    if 0 <= start < end:
        attempts.append(raw[start:end + 1])

    obj = None
    for candidate_text in attempts:
        try:
            parsed = json.loads(candidate_text)
        except ValueError:
            continue
        if isinstance(parsed, dict):
            obj = parsed
            break

    if obj is not None:
        lab = _match_label(obj.get("label"), labels)
        if lab is not None:
            # A malformed confidence must not discard an otherwise valid label.
            try:
                conf = float(obj.get("confidence", 0) or 0.0)
            except (TypeError, ValueError):
                conf = 0.0
            return lab, conf, obj

    # Salvage by matching any label token in the raw text (first hit in `labels`
    # order); if nothing matches, fall back to the first label with confidence 0.
    low = raw.lower()
    lab = next(
        (known for known in labels if re.search(rf"\b{re.escape(known.lower())}\b", low)),
        None,
    )
    if lab is None:
        lab = labels[0]
    return lab, 0.0, {"label": lab, "raw": raw}


def call_o3_json(text: str, labels: List[str]) -> Tuple[str, float, dict]:
    """
    Ask o3 to choose ONE label from `labels`.

    Uses Chat Completions with default temperature (required by o3), retrying up
    to MAX_RETRIES times with exponential backoff plus jitter.

    Args:
        text:   the text to classify.
        labels: non-empty list of allowed labels.

    Returns:
        (label, confidence, obj). `label` is always an element of `labels`.
        `obj` is the parsed JSON reply when it carried a valid label; otherwise
        it is {"label": <chosen>, "raw": <reply text>} with confidence 0.0,
        where the label was salvaged from the reply text or, failing that, is
        labels[0]. An empty/None reply body is handled the same way.

    Raises:
        ValueError:   `labels` is empty.
        RuntimeError: HTTP 400/401/403/404 (not retried), or no attempt was made.
        The last API error when every retry fails.
    """
    if not labels:
        raise ValueError("labels must contain at least one label")

    system_prompt = "Return only valid JSON. Choose exactly one label from the provided set."
    user = f"""Choose ONE label from this set:
{labels}

Return ONLY this JSON (no prose):
{{"label": "<one of {labels}>", "confidence": <0..1>}}

Text:
{text}"""

    def _once():
        return client.chat.completions.create(
            model="o3",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user",   "content": user},
            ],
        )

    last_err = None
    for k in range(MAX_RETRIES):
        try:
            r = _once()
        except APIStatusError as e:
            code = getattr(e, "status_code", None)
            if code in (400, 401, 403, 404):
                # Client-side errors will not succeed on retry; fail fast.
                raise RuntimeError(
                    f"o3 request failed (HTTP {code}). Details: {e}"
                ) from e
            last_err = e
        except (RateLimitError, APIConnectionError, APIError) as e:
            last_err = e
        else:
            # `content` can be None (and `choices` empty); treat both as an empty reply.
            content = r.choices[0].message.content if r.choices else None
            return _parse_o3_reply((content or "").strip(), labels)

        # Back off before the next attempt, but do not sleep after the final one.
        if k < MAX_RETRIES - 1:
            time.sleep((2**k) + random.random())

    raise last_err or RuntimeError("call_o3_json failed")


In [35]:
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support, classification_report, confusion_matrix

# Index to evaluate.
# MAX_SAMPLES of None/0 means the whole split; clamp so a value larger than the
# split cannot index past its end.
N = min(MAX_SAMPLES, len(ds)) if MAX_SAMPLES else len(ds)
idx = list(range(N))

gold, pred, rows = [], [], []
labels_seen = OrderedDict()   # insertion-ordered set of every label encountered

for i in tqdm(idx, desc=f"o3 on {SPLIT}"):
    row = ds[i]
    choices = normalize_choices(row.get("choices")) or ["negative","neutral","positive"]
    for c in choices: labels_seen.setdefault(c, None)
    g = gold_to_label(row, choices)
    labels_seen.setdefault(g, None)

    prompt_text = row_text(row)   # built once, reused for the request and the record
    time.sleep(QPS_DELAY)
    lab, conf, raw = call_o3_json(prompt_text, choices)
    # Keep predictions homogeneous strings so a non-string label cannot make the
    # metric calls fail after the long loop (str() is a no-op for str labels).
    lab = str(lab)

    gold.append(g); pred.append(lab)
    rows.append({"text": prompt_text, "gold": g, "pred": lab, "confidence": conf})

# Metrics
label_list = list(labels_seen.keys())
P,R,F1,S = precision_recall_fscore_support(gold, pred, labels=label_list, zero_division=0)
acc = accuracy_score(gold, pred)
# zero_division=0 matches the other metric calls; it yields the same values as
# the default but without undefined-metric warnings.
f1_micro = f1_score(gold, pred, average="micro", zero_division=0)
f1_macro = f1_score(gold, pred, average="macro", zero_division=0)
report = classification_report(gold, pred, labels=label_list, digits=4, zero_division=0)
cm = confusion_matrix(gold, pred, labels=label_list).tolist()

print({"n":len(gold), "accuracy":round(acc,4), "f1_micro":round(f1_micro,4), "f1_macro":round(f1_macro,4)})


o3 on train: 100%|██████████| 750/750 [46:55<00:00,  3.75s/it]

{'n': 750, 'accuracy': 0.8227, 'f1_micro': 0.8227, 'f1_macro': 0.6931}



