In [4]:
!pip install -q google-generativeai datasets pandas scikit-learn tqdm

import os, getpass, re, time
import google.generativeai as genai
from datasets import load_dataset
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, classification_report

os.environ["GOOGLE_API_KEY"] = getpass.getpass("Paste your Google API key: ")
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

MODEL = "gemini-2.5-flash"
model = genai.GenerativeModel(model_name=MODEL, system_instruction="You are a financial sentiment classifier. Choose exactly one label from the options. Output ONLY the label.")
GEN_CFG = {"temperature": 0, "max_output_tokens": 128}

ds = load_dataset("ChanceFocus/en-fpb", split="test")
LIMIT = 200
ds = ds.select(range(min(LIMIT, len(ds))))

def normalize_to_choice(raw, choices):
    if not raw:
        return None
    s = re.split(r"[\r\n]", str(raw).strip())[0].strip().strip(".:;").lower()
    alias = {"pos":"positive","neg":"negative","neu":"neutral","bullish":"positive","bearish":"negative"}
    s = alias.get(s, s)
    for c in choices:
        if s == c.lower():
            return c
    for c in choices:
        if c.lower().startswith(s):
            return c
    return None

def extract_text(resp):
    for cand in getattr(resp, "candidates", []) or []:
        parts = getattr(getattr(cand, "content", None), "parts", None)
        if parts:
            t = "".join(getattr(p, "text", "") for p in parts).strip()
            if t:
                return t
    return ""

def ask_model(sentence, choices, retries=3, sleep=1):
    user1 = f"Sentence: {sentence}\nOptions: {', '.join(choices)}\nAnswer with ONE of the options only."
    user2 = f"Return exactly one of: {', '.join(choices)}.\nText: {sentence}"
    cfg = GEN_CFG.copy()
    last = None
    for i in range(retries):
        try:
            resp = model.generate_content(user1 if i == 0 else user2, generation_config=cfg)
            if getattr(resp, "prompt_feedback", None) and getattr(resp.prompt_feedback, "block_reason", None):
                return "UNKNOWN"
            txt = extract_text(resp)
            if txt:
                return txt
            cfg["max_output_tokens"] = max(cfg.get("max_output_tokens", 128), 192)
        except Exception as e:
            last = e
        time.sleep(sleep)
    return "UNKNOWN"

ex = ds[0]
pred_raw = ask_model(ex["text"], list(ex["choices"]))
pred = normalize_to_choice(pred_raw, list(ex["choices"]))
print(ex["text"], ex["choices"], pred_raw, pred, ex["answer"])

rows, y_true, y_pred = [], [], []
for i in tqdm(range(len(ds))):
    x = ds[i]
    choices = list(x["choices"])
    gold = x["answer"]
    raw = ask_model(x["text"], choices)
    pred = normalize_to_choice(raw, choices) or "UNKNOWN"
    rows.append({"id": x.get("id", i), "text": x["text"], "choices": "|".join(choices), "pred_raw": raw, "pred": pred, "label": gold})
    y_true.append(gold); y_pred.append(pred)

df = pd.DataFrame(rows)
df.to_csv("/content/fpb_predictions_gemini_200.csv", index=False)
print("Saved to /content/fpb_predictions_gemini_200.csv")

ok = df[df["pred"]!="UNKNOWN"]
print("Used for scoring:", len(ok), "/", len(df))
print("Accuracy:", round(accuracy_score(ok["label"], ok["pred"]), 4))
print("Macro-F1:", round(f1_score(ok["label"], ok["pred"], average="macro"), 4))
print("\nReport:\n", classification_report(ok["label"], ok["pred"]))


Paste your Google API key: ··········
The new agreement , which expands a long-established cooperation between the companies , involves the transfer of certain engineering and documentation functions from Larox to Etteplan . ['positive', 'neutral', 'negative'] UNKNOWN None positive


100%|██████████| 200/200 [11:08<00:00,  3.34s/it]

Saved to /content/fpb_predictions_gemini_200.csv
Used for scoring: 76 / 200
Accuracy: 0.8684
Macro-F1: 0.3099

Report:
               precision    recall  f1-score   support

    negative       0.00      0.00      0.00         0
     neutral       0.00      0.00      0.00         0
    positive       1.00      0.87      0.93        76

    accuracy                           0.87        76
   macro avg       0.33      0.29      0.31        76
weighted avg       1.00      0.87      0.93        76




  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
