In [1]:
%pip -q install --upgrade --force-reinstall \
  openai>=1.30.0 datasets>=2.19.0 scikit-learn>=1.4.2 tqdm>=4.66.4 huggingface_hub>=0.24.0 \
  pandas==2.2.2 numpy==2.0.2 "pyarrow<20" requests==2.32.4 "pydantic<2.12"


In [9]:
import os, getpass
from openai import OpenAI

# Prompt for the xAI key only when the environment does not already supply one,
# so the secret is never written into the notebook itself.
if os.environ.get("XAI_API_KEY") in (None, ""):
    os.environ["XAI_API_KEY"] = getpass.getpass("xAI API key: ")

# xAI exposes an OpenAI-compatible API, so the OpenAI client is pointed at its base URL.
BASE_URL = os.environ.get("XAI_BASE_URL", "https://api.x.ai/v1")

# Grok model id; override with the GROK_MODEL environment variable.
MODEL_NAME = os.environ.get("GROK_MODEL", "grok-4")

client = OpenAI(base_url=BASE_URL, api_key=os.environ["XAI_API_KEY"])
print("Grok client ready ✅", "Model:", MODEL_NAME)


Grok client ready ✅ Model: grok-4


In [10]:
# Verify token & access, then load the gated dataset
import os, getpass
from huggingface_hub import HfApi, HfFolder
from datasets import load_dataset

# Resolve the Hugging Face token: stored login first, then HF_TOKEN, then an interactive prompt.
token = (
    HfFolder.get_token()
    or os.getenv("HF_TOKEN")
    or getpass.getpass("Paste your HF token (hf_...): ")
)

api = HfApi()
print("Who am I:", api.whoami(token=token).get("name"))

# Fetch the dataset metadata first so an access problem surfaces before the download starts.
info = api.dataset_info("TheFinAI/en-fpb", token=token)
print("en-fpb access OK. SHA:", info.sha[:8], "...")

# Download every split using the same token.
ds_all = load_dataset("TheFinAI/en-fpb", token=token)

# Take the first preferred split that exists; otherwise fall back to the first split available.
for split in ("test", "validation", "valid", "train"):
    if split in ds_all:
        break
else:
    split = list(ds_all.keys())[0]

ds = ds_all[split]
print(f"Loaded split: {split} (n={len(ds)})")


Paste your HF token (hf_...): ··········
Who am I: Hanlin0914
en-fpb access OK. SHA: 98f0a91c ...
Loaded split: test (n=970)


In [11]:
from typing import List, Tuple
import json, re

def row_text(row):
    """Build the prompt text for one dataset row.

    Reads the "query" and "text" fields (missing or empty values count as ""),
    strips both, and joins them as "<query>\\n\\nText:\\n<text>" when both are
    non-empty; otherwise returns whichever one is non-empty (or "").
    """
    query = (row.get("query") or "").strip()
    body = (row.get("text") or "").strip()
    if query and body:
        return f"{query}\n\nText:\n{body}"
    return query or body

def normalize_choices(raw):
    """Flatten a raw `choices` field into a list of non-empty, stripped strings.

    Only a list input is processed; anything else yields []. Each item is
    handled by type:
      * str  -> used as is;
      * dict -> the first truthy value among the keys "label", "text",
                "value", "choice" (empty string if none is truthy);
      * else -> str(item).
    Entries that are empty after stripping are dropped.
    """
    if not isinstance(raw, list):
        return []

    cleaned = []
    for item in raw:
        if isinstance(item, dict):
            picked = ""
            for key in ("label", "text", "value", "choice"):
                candidate = item.get(key)
                if candidate:
                    picked = candidate
                    break
            entry = str(picked).strip()
        elif isinstance(item, str):
            entry = item.strip()
        else:
            entry = str(item).strip()

        if entry:
            cleaned.append(entry)
    return cleaned

def gold_to_label(row, choices: List[str]):
    """Resolve a row's gold answer to a label string.

    The value comes from row["gold"], falling back to row["answer"] when
    "gold" is missing or None.
      * int -> treated as an index into `choices`; an out-of-range index is
               returned as its string form.
      * str -> stripped, then matched case-insensitively against `choices`
               (returning the matching choice verbatim); the stripped string
               itself is returned when nothing matches.
      * anything else (including None) -> str(value).
    """
    target = row.get("gold", None)
    if target is None:
        target = row.get("answer", None)

    if isinstance(target, int):
        in_range = 0 <= target < len(choices)
        return choices[target] if in_range else str(target)

    if not isinstance(target, str):
        return str(target)

    stripped = target.strip()
    wanted = stripped.lower()
    return next((choice for choice in choices if choice.lower() == wanted), stripped)


In [12]:
import time, random
from openai import APIError, APIConnectionError, RateLimitError

MAX_RETRIES = 6
QPS_DELAY   = 1.0  # polite throttle: seconds to pause before each classification call


def _parse_grok_reply(raw: str, labels: List[str]) -> Tuple[str, float, dict]:
    """Turn the model's raw reply text into (label, confidence, info)."""
    obj = None

    # Try the reply verbatim first, then the outermost {...} span, so JSON that is
    # wrapped in code fences or surrounded by prose still yields its confidence.
    candidates = [raw]
    braced = re.search(r"\{.*\}", raw, re.DOTALL)
    if braced:
        candidates.append(braced.group(0))
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:
            continue
        # Valid JSON that is not an object (list, bare string, number) is unusable here.
        if isinstance(parsed, dict):
            obj = parsed
            break

    if obj is None:
        # Salvage: look for a label token anywhere in the reply (checked in `labels` order).
        low = raw.lower()
        for lab in labels:
            if re.search(rf"\b{re.escape(lab.lower())}\b", low):
                return lab, 0.0, {"label": lab, "raw": raw}
        # Fall back to the first label so a long run is not aborted by one bad reply.
        return labels[0], 0.0, {"label": labels[0], "raw": raw}

    label = obj.get("label")
    try:
        conf = float(obj.get("confidence", 0.0))
    except (TypeError, ValueError):
        # e.g. "high" or null: keep the label, report zero confidence.
        conf = 0.0

    if label not in labels:
        # Accept case/whitespace variants; otherwise fall back to the first label.
        wanted = str(label).strip().lower()
        label = next((lab for lab in labels if lab.lower() == wanted), labels[0])
    return label, conf, obj


def call_grok_json(text: str, labels: List[str]) -> Tuple[str, float, dict]:
    """Ask Grok to return: {"label":"<one>", "confidence": float}

    Args:
        text: The text to classify.
        labels: Non-empty list of allowed labels.

    Returns:
        (label, confidence, info). `label` is always an element of `labels`.
        `info` is the parsed JSON object, or {"label": ..., "raw": ...} when
        the reply could not be parsed as a JSON object.

    Raises:
        ValueError: if `labels` is empty.
        RateLimitError / APIConnectionError / APIError: the last API error,
            once MAX_RETRIES attempts have all failed.

    Note:
        A reply that names no valid label is mapped to labels[0] with
        confidence 0.0, so unparseable replies count as predictions of the
        first label in any downstream metric.
    """
    if not labels:
        # Fail before spending an API call; the fallbacks below need labels[0].
        raise ValueError("labels must contain at least one label")

    system_msg = "Return only valid JSON. Choose exactly one label from the provided set."
    user_msg = f"""Choose ONE label from this set:
{labels}

Return ONLY this JSON (no prose):
{{"label": "<one of {labels}>", "confidence": <0..1>}}

Text:
{text}"""

    # Apply the configured throttle once per call (retries have their own backoff).
    if QPS_DELAY > 0:
        time.sleep(QPS_DELAY)

    last = None
    for attempt in range(MAX_RETRIES):
        try:
            r = client.chat.completions.create(
                model=MODEL_NAME,
                temperature=0,
                messages=[{"role": "system", "content": system_msg},
                          {"role": "user", "content": user_msg}]
            )
        except (RateLimitError, APIConnectionError, APIError) as e:
            last = e
            # Exponential backoff with jitter; no point sleeping after the final attempt.
            if attempt < MAX_RETRIES - 1:
                time.sleep((2**attempt) + random.random())
            continue

        # The message content can be None; treat that as an empty reply.
        raw = (r.choices[0].message.content or "").strip()
        return _parse_grok_reply(raw, labels)

    # last-ditch
    raise last or RuntimeError("Grok call failed")


In [13]:
from tqdm import tqdm
import pandas as pd
from collections import OrderedDict
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support, classification_report, confusion_matrix
from pathlib import Path

MAX_SAMPLES = None  # set to an int to evaluate only the first N rows; None (or 0) runs the whole split

idx = list(range(len(ds)))[: (MAX_SAMPLES or len(ds))]
gold, pred, rows = [], [], []
labels_seen = OrderedDict()  # insertion-ordered set of every label met (choices first, then gold)

# Ensure SPLIT is defined for logging/progress. An explicitly set SPLIT wins; otherwise
# reuse the split chosen when the dataset was loaded, and only re-derive it as a last resort.
if 'SPLIT' not in globals():
    if 'split' in globals():
        SPLIT = split
    elif 'ds_all' in globals() and isinstance(ds_all, dict) and len(ds_all):
        SPLIT = next((s for s in ["test","validation","valid","train"] if s in ds_all),
                     list(ds_all.keys())[0])
    else:
        SPLIT = "custom"
print("Using SPLIT:", SPLIT)


for i in tqdm(idx, desc=f"Grok on {SPLIT}"):
    row = ds[i]
    # Rows without usable choices fall back to the three sentiment labels.
    choices = normalize_choices(row.get("choices")) or ["negative","neutral","positive"]
    for c in choices: labels_seen.setdefault(c, None)
    g = gold_to_label(row, choices)
    labels_seen.setdefault(g, None)

    prompt_text = row_text(row)  # built once; used for both the request and the record
    guess, conf, raw = call_grok_json(prompt_text, choices)
    gold.append(g); pred.append(guess)
    # Keep the model's confidence with each record instead of discarding it.
    rows.append({"text": prompt_text, "gold": g, "pred": guess, "confidence": conf})

label_list = list(labels_seen.keys())
P,R,F1,S = precision_recall_fscore_support(gold, pred, labels=label_list, zero_division=0)
acc = accuracy_score(gold, pred)
# Score F1 over the same explicit label set (and zero_division policy) as the other
# metrics, so the headline numbers agree with the per-class report and confusion matrix.
f1_micro = f1_score(gold, pred, labels=label_list, average="micro", zero_division=0)
f1_macro = f1_score(gold, pred, labels=label_list, average="macro", zero_division=0)
report = classification_report(gold, pred, labels=label_list, digits=4, zero_division=0)
cm = confusion_matrix(gold, pred, labels=label_list).tolist()

print({"n":len(gold), "acc":round(acc,4), "f1_micro":round(f1_micro,4), "f1_macro":round(f1_macro,4)})


Using SPLIT: test


Grok on test: 100%|██████████| 970/970 [2:25:29<00:00,  9.00s/it]

{'n': 970, 'acc': 0.7639, 'f1_micro': 0.7639, 'f1_macro': 0.783}



