In [1]:
# Step 1: Install dependencies & set Hugging Face token
%pip install -q "datasets>=2.19.0" "huggingface_hub>=0.24"
import os
import getpass

# Set the Hugging Face token directly through environment variables, skipping
# the interactive login widget.
# Pasted secrets frequently carry stray whitespace or a trailing newline, so the
# value is trimmed before it is exported.
hf_token = getpass.getpass("Paste your Hugging Face token: ").strip()
if hf_token:
    os.environ['HF_TOKEN'] = hf_token
    os.environ['HUGGINGFACE_HUB_TOKEN'] = hf_token
    print("Hugging Face token set successfully!")
else:
    # Nothing was entered: do not export an empty token and do not claim success.
    print("No Hugging Face token entered; continuing without authentication.")

Note: you may need to restart the kernel to use updated packages.


Paste your Hugging Face token:  ········


Hugging Face token set successfully!


In [2]:
# Step 2: Load FLARE-FPB test set and normalize labels
from datasets import load_dataset, Dataset

# Canonical label set. The list order doubles as the integer-index mapping used
# by _norm_label (0 -> negative, 1 -> neutral, 2 -> positive).
LABELS = ["negative", "neutral", "positive"]

# Load the test split of the FLARE-FPB dataset and report its size and columns.
ds_raw = load_dataset("TheFinAI/flare-fpb", split="test")
print("Loaded flare-fpb test:", len(ds_raw), "columns:", ds_raw.column_names)

# Alternate spellings mapped onto canonical labels; _norm_label looks values up
# here after stripping and lower-casing them.
_alias = {"pos": "positive", "neg": "negative", "neu": "neutral",
          "bullish": "positive", "bearish": "negative"}

def _norm_label(v):
    """Normalize a raw label value to one of the canonical ``LABELS`` strings.

    Accepted inputs:
      * an index into ``LABELS`` given as an ``int``, an integral ``float``
        (e.g. ``2.0``) or an ASCII digit string (e.g. ``"2"``, ``" 2 "``);
      * a label name or a key of ``_alias``, matched case-insensitively after
        trimming surrounding whitespace.

    Returns:
        The canonical label string, or ``None`` when the value is ``None`` or
        cannot be mapped (out-of-range index, NaN/inf, non-integral float,
        bool, unknown string).
    """
    # bool is a subclass of int, but True/False are not class indices.
    if v is None or isinstance(v, bool):
        return None
    if isinstance(v, float):
        # is_integer() is False for NaN, +/-inf and fractional values, all of
        # which would either crash int() or be silently truncated.
        if not v.is_integer():
            return None
        v = int(v)
    if not isinstance(v, int):
        s = str(v).strip().lower()
        # isascii() rejects characters such as superscript digits that satisfy
        # isdigit() but make int() raise ValueError.
        if not (s.isascii() and s.isdigit()):
            s = _alias.get(s, s)
            return s if s in LABELS else None
        v = int(s)
    return LABELS[v] if 0 <= v < len(LABELS) else None

def _map_row(x):
    """Project one raw dataset row onto the ``text`` / ``choices`` / ``answer`` schema.

    Args:
        x: Mapping for a single row. The text is taken from the first non-empty
           of ``text``, ``sentence``, ``content``, ``input``; the label from the
           first of ``label``, ``labels``, ``answer`` whose value is not None.

    Returns:
        dict with ``text`` (``""`` when no text field is populated), ``choices``
        (a fresh copy of ``LABELS``) and ``answer`` (the label normalized by
        ``_norm_label``, or ``None`` when it cannot be parsed).
    """
    text = x.get("text") or x.get("sentence") or x.get("content") or x.get("input") or ""
    # Fall through on None values as well as on missing keys, so a row that
    # carries an empty "label" column next to a populated "answer" still maps.
    raw_label = next(
        (x[key] for key in ("label", "labels", "answer") if x.get(key) is not None),
        None,
    )
    # Copy LABELS so rows never alias (and cannot mutate) the module-level list.
    return {"text": text, "choices": list(LABELS), "answer": _norm_label(raw_label)}

# Overlay the normalized fields on each raw row (normalized values win on key
# collisions) and rebuild the dataset from the merged records.
merged_rows = []
for raw_row in ds_raw:
    merged_rows.append({**raw_row, **_map_row(raw_row)})
ds = Dataset.from_list(merged_rows)

# Collect the indices of rows whose normalized label is not one of LABELS.
bad = []
for row_pos, row in enumerate(ds):
    if row["answer"] not in LABELS:
        bad.append(row_pos)
print("Samples with unusable label:", len(bad))
# Stop here rather than evaluating against labels that could not be parsed.
assert not bad, "Found unparseable labels; please check the field mapping."

Loaded flare-fpb test: 970 columns: ['id', 'query', 'answer', 'text', 'choices', 'gold']
Samples with unusable label: 0


In [3]:
# Step 3: Install dependencies, configure DeepSeek, and record experiment metadata
%pip install -q "openai==1.40.2" "httpx==0.27.2" "httpcore==1.0.5" \
               "pandas>=2.2.2" "tqdm>=4.66.4" "requests>=2.31.0"

import os, getpass, json, time, platform
from importlib.metadata import version, PackageNotFoundError

# DeepSeek adaptation: use the deepseek-chat model.
MODEL = "deepseek-chat"
BASE_URL = "https://api.deepseek.com/v1"

# Resolve the API key: environment first (DEEPSEEK_API_KEY, then API_KEY),
# otherwise prompt without echoing. The value is trimmed because pasted secrets
# often carry stray whitespace or a trailing newline.
api_key = (os.getenv("DEEPSEEK_API_KEY") or os.getenv("API_KEY") or "").strip()
if not api_key:
    api_key = getpass.getpass("Paste your DeepSeek API key: ").strip()
if not api_key:
    # Fail fast: with an empty key every later request would fail and be
    # recorded as an UNKNOWN prediction instead of surfacing the real problem.
    raise ValueError(
        "No DeepSeek API key provided; set DEEPSEEK_API_KEY or paste one when prompted."
    )
os.environ["DEEPSEEK_API_KEY"] = api_key

# DeepSeek adaptation: derive the output file names from the model name so runs
# of different models do not overwrite each other.
run_tag = f"flare_fpb_{MODEL.replace('-', '_')}"
save_dir = "/content"
pred_path = f"{save_dir}/{run_tag}_predictions.csv"
meta_path = f"{save_dir}/{run_tag}_metadata.json"

def ver(pkg: str) -> str:
    """Look up the installed version of a distribution.

    Args:
        pkg: Distribution name as known to ``importlib.metadata``.

    Returns:
        The version string, or ``"not-installed"`` when the distribution
        cannot be found.
    """
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = "not-installed"
    return installed

# DeepSeek adaptation: record the model and its version information in the metadata.
# Snapshot of the experiment configuration and the installed package versions,
# written as JSON to meta_path so the run can be identified later.
meta = {
    "dataset": "TheFinAI/flare-fpb",
    "split": "test",
    "labels": list(LABELS),  # copy, so the metadata does not alias LABELS
    "model": MODEL,
    "model_variant": "standard",
    # Package versions come from ver(), which yields "not-installed" for
    # distributions that cannot be found.
    "openai_sdk": ver("openai"),
    "httpx": ver("httpx"),
    "httpcore": ver("httpcore"),
    "datasets_version": ver("datasets"),
    "pandas": ver("pandas"),
    "tqdm": ver("tqdm"),
    "time_utc": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),  # UTC timestamp
    "python": platform.python_version(),
    "base_url": BASE_URL,
    "note": "DeepSeek-chat model evaluation"
}

# Create the output directory if needed, then write the metadata file.
os.makedirs(save_dir, exist_ok=True)
with open(meta_path, "w") as f:
    json.dump(meta, f, indent=2)

print("Meta saved ->", meta_path)
print("MODEL:", MODEL, "| BASE_URL:", BASE_URL)
# Only report whether a key is present; never print the key itself.
print("DEEPSEEK_API_KEY is set:", bool(os.environ.get("DEEPSEEK_API_KEY")))

Note: you may need to restart the kernel to use updated packages.


Paste your DeepSeek API key:  ········


Meta saved -> /content/flare_fpb_deepseek_chat_metadata.json
MODEL: deepseek-chat | BASE_URL: https://api.deepseek.com/v1
DEEPSEEK_API_KEY is set: True


In [4]:
# Step 4: Inference & evaluation loop (DeepSeek adaptation)
import requests, json, os, re, time
import pandas as pd
from tqdm import tqdm
from openai import OpenAI

def _strip_code_fences(s: str) -> str:
    """Remove a surrounding Markdown code fence from a model reply.

    The text is trimmed first. If it then starts with a triple backtick, the
    opening fence (with an optional language tag) and a trailing closing fence
    are removed. Text that does not start with a fence is returned trimmed.
    """
    text = s.strip()
    if not text.startswith("```"):
        return text
    without_opening = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", text)
    without_closing = re.sub(r"\s*```$", "", without_opening)
    return without_closing.strip()

def _make_user_text(sentence: str, choices=("",)):
    """Build the single-turn classification prompt sent to the model.

    Args:
        sentence: The sentence to classify.
        choices: Iterable of allowed label strings. They are listed in the task
            description and also used to build the JSON template, so the two
            parts of the prompt always agree.

    Returns:
        The prompt text asking for a one-line JSON object ``{"label": ...}``.
    """
    # DeepSeek adaptation: keep the original prompt structure.
    label_list = ", ".join(choices)
    # Derive the template from `choices` instead of hard-coding the three
    # sentiment labels; for the standard label set the prompt is unchanged.
    label_pattern = "|".join(choices)
    return (
        "Task: classify the sentence into exactly one of these labels: "
        f"{label_list}.\n\n"
        f"Sentence: {sentence}\n\n"
        "Return ONLY a JSON object on a single line, exactly in this form:\n"
        f"{{\"label\":\"{label_pattern}\"}}\n"
        "No code fences, no extra text, no explanation."
    )

# Clients keyed by (api key, base URL) so one client is reused across requests
# instead of building a new one per sample.
_DEEPSEEK_CLIENTS = {}


def _get_deepseek_client():
    """Return a cached client for the current DEEPSEEK_API_KEY and BASE_URL."""
    cache_key = (os.environ['DEEPSEEK_API_KEY'], BASE_URL)
    client = _DEEPSEEK_CLIENTS.get(cache_key)
    if client is None:
        client = OpenAI(api_key=cache_key[0], base_url=cache_key[1])
        _DEEPSEEK_CLIENTS[cache_key] = client
    return client


def _parse_label(content, choices):
    """Extract exactly one of `choices` from the model's reply text."""
    # A reply without text content is treated as empty text.
    text = _strip_code_fences(content or "")
    by_lower = {str(choice).lower(): choice for choice in choices}

    # Try to parse the reply as JSON first.
    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        obj = None

    if isinstance(obj, dict):
        lab = obj.get("label")
        key = lab.strip().lower() if isinstance(lab, str) else None
        if key not in by_lower:
            raise RuntimeError(f"Invalid label {lab!r}; raw json: {obj}")
        return by_lower[key]

    # Not a JSON object (plain text, a bare JSON string, ...): look for label
    # words in the text. Only an unambiguous reply is accepted; picking the
    # first label in `choices` order would bias the result toward it.
    lowered = text.lower()
    hits = [
        choice for key, choice in by_lower.items()
        if re.search(rf"\b{re.escape(key)}\b", lowered)
    ]
    if len(hits) == 1:
        return hits[0]
    if hits:
        raise RuntimeError(f"Ambiguous label in response (matched {hits}): {text}")
    raise RuntimeError(f"Could not parse label from response: {text}")


def ask_deepseek_once(sentence, choices=("negative", "neutral", "positive")):
    """Send one classification request and return the predicted label.

    Args:
        sentence: The sentence to classify.
        choices: Allowed labels; the returned value is always one of them.

    Returns:
        The element of `choices` named by the model (matched case-insensitively).

    Raises:
        RuntimeError: with an "API call failed: ..." message when the request
            itself fails (the original exception is chained as the cause), or
            with a parse-specific message when the reply does not name exactly
            one valid label.
        KeyError: if DEEPSEEK_API_KEY is not set in the environment.
    """
    client = _get_deepseek_client()
    user_text = _make_user_text(sentence, choices)

    try:
        response = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "user", "content": user_text}
            ],
            max_tokens=100,
            temperature=0.1
        )
        content = response.choices[0].message.content
    except Exception as e:
        # Only transport/API problems get this prefix, which the retry wrapper
        # inspects; parse problems are reported separately below.
        raise RuntimeError(f"API call failed: {str(e)}") from e

    return _parse_label(content, choices)

def ask_deepseek(sentence, choices=("negative", "neutral", "positive"),
                 max_attempts=5, initial_delay=2.0, max_delay=30):
    """Classify one sentence, retrying transient failures with exponential backoff.

    DeepSeek adaptation: adjusted retry strategy (fewer attempts).

    Args:
        sentence: The sentence to classify.
        choices: Allowed labels, forwarded to ``ask_deepseek_once``.
        max_attempts: Total number of attempts before giving up.
        initial_delay: Seconds to wait after the first transient failure; the
            wait doubles after each further failure.
        max_delay: Upper bound for a single wait, in seconds.

    Returns:
        The label returned by ``ask_deepseek_once``.

    Raises:
        RuntimeError: immediately for errors that do not look transient, or an
            "Exhausted retries" error (chained to the last failure) once all
            attempts have failed.
    """
    # Substrings of the error message that mark a failure as worth retrying.
    retry_markers = ("rate limit", "429", "server", "timeout", "connection")
    delay = initial_delay
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            return ask_deepseek_once(sentence, choices)
        except RuntimeError as e:
            msg = str(e).lower()
            if not any(marker in msg for marker in retry_markers):
                raise
            last_error = e
            # No point in waiting after the final attempt.
            if attempt < max_attempts:
                time.sleep(delay)
                delay = min(delay * 2, max_delay)

    # Keep the last failure in both the message and the exception chain.
    raise RuntimeError(
        f"Exhausted retries for this sample. Last error: {last_error}"
    ) from last_error

# Output locations for this run, derived from the model name.
run_tag = f"flare_fpb_{MODEL.replace('-', '_')}"
save_dir = "/content"
pred_path = f"{save_dir}/{run_tag}_predictions.csv"
err_path = f"{save_dir}/{run_tag}_errors.csv"


def _save_progress(rows, errors):
    """Write predictions (sorted by row_idx) and the error log; return the row count."""
    if not rows:
        # Nothing to write yet; sorting an empty frame by "row_idx" would fail.
        return 0
    out = pd.DataFrame(rows).sort_values("row_idx")
    out.to_csv(pred_path, index=False)
    if errors:
        pd.DataFrame(errors).to_csv(err_path, index=False)
    return len(out)


# Resume support: reuse rows already present in the predictions file.
rows_done = []
done_idx = set()
if os.path.exists(pred_path):
    old = pd.read_csv(pred_path)
    if "row_idx" in old.columns:
        # Rows that failed earlier (pred == "UNKNOWN") are dropped here so they
        # are attempted again instead of being treated as finished.
        if "pred" in old.columns:
            old = old[old["pred"] != "UNKNOWN"]
        rows_done = old.to_dict("records")
        done_idx = set(old["row_idx"].tolist())
        print(f"[resume] loaded {len(done_idx)} completed rows.")

err_rows = []
buf = []
save_every = 30  # checkpoint after this many newly processed rows

total = len(ds)
print(f"Starting DeepSeek model evaluation on {total} samples...")

for i in tqdm(range(total)):
    if i in done_idx:
        continue
    x = ds[i]
    text = x["text"]
    gold = x["answer"]

    try:
        pred = ask_deepseek(text, LABELS)
        raw = json.dumps({"label": pred})
    except Exception as e:
        # Record the failure and keep going; the row is stored as UNKNOWN so
        # the evaluation step can count it separately.
        pred = "UNKNOWN"
        raw = f"ERROR: {type(e).__name__}: {e}"
        err_rows.append({"row_idx": i, "id": x.get("id", i), "error": raw, "text": text})

    buf.append({
        "row_idx": i,
        "id": x.get("id", i),
        "text": text,
        "pred_raw": raw,
        "pred": pred,
        "label": gold
    })

    if len(buf) % save_every == 0:
        saved = _save_progress(rows_done + buf, err_rows)
        print(f"[checkpoint] saved {saved}/{total} -> {pred_path}")

_save_progress(rows_done + buf, err_rows)
print(f"[done] DeepSeek evaluation completed -> {pred_path}")
if err_rows:
    print(f"[errors] {len(err_rows)} errors logged -> {err_path}")
elif os.path.exists(err_path):
    # Every previously failed row was retried in this run and none failed, so
    # an error file left by an earlier run no longer matches the predictions.
    os.remove(err_path)

Starting DeepSeek model evaluation on 970 samples...


  3%|██▌                                                                              | 30/970 [01:11<36:04,  2.30s/it]

[checkpoint] saved 30/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


  6%|█████                                                                            | 60/970 [02:21<36:05,  2.38s/it]

[checkpoint] saved 60/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


  9%|███████▌                                                                         | 90/970 [03:34<35:22,  2.41s/it]

[checkpoint] saved 90/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 12%|█████████▉                                                                      | 120/970 [04:44<31:14,  2.21s/it]

[checkpoint] saved 120/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 15%|████████████▎                                                                   | 150/970 [05:58<33:44,  2.47s/it]

[checkpoint] saved 150/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 19%|██████████████▊                                                                 | 180/970 [07:11<31:13,  2.37s/it]

[checkpoint] saved 180/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 22%|█████████████████▎                                                              | 210/970 [08:23<29:57,  2.37s/it]

[checkpoint] saved 210/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 25%|███████████████████▊                                                            | 240/970 [09:33<29:26,  2.42s/it]

[checkpoint] saved 240/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 28%|██████████████████████▎                                                         | 270/970 [10:43<25:53,  2.22s/it]

[checkpoint] saved 270/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 31%|████████████████████████▋                                                       | 300/970 [11:53<26:41,  2.39s/it]

[checkpoint] saved 300/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 34%|███████████████████████████▏                                                    | 330/970 [13:07<26:46,  2.51s/it]

[checkpoint] saved 330/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 37%|█████████████████████████████▋                                                  | 360/970 [14:19<25:03,  2.47s/it]

[checkpoint] saved 360/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 40%|████████████████████████████████▏                                               | 390/970 [15:34<24:08,  2.50s/it]

[checkpoint] saved 390/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 43%|██████████████████████████████████▋                                             | 420/970 [16:43<21:19,  2.33s/it]

[checkpoint] saved 420/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 46%|█████████████████████████████████████                                           | 450/970 [17:56<21:31,  2.48s/it]

[checkpoint] saved 450/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 49%|███████████████████████████████████████▌                                        | 480/970 [19:10<20:06,  2.46s/it]

[checkpoint] saved 480/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 53%|██████████████████████████████████████████                                      | 510/970 [20:23<20:21,  2.66s/it]

[checkpoint] saved 510/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 56%|████████████████████████████████████████████▌                                   | 540/970 [21:35<16:33,  2.31s/it]

[checkpoint] saved 540/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 59%|███████████████████████████████████████████████                                 | 570/970 [22:47<16:37,  2.49s/it]

[checkpoint] saved 570/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 62%|█████████████████████████████████████████████████▍                              | 600/970 [23:59<15:42,  2.55s/it]

[checkpoint] saved 600/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 65%|███████████████████████████████████████████████████▉                            | 630/970 [25:13<14:05,  2.49s/it]

[checkpoint] saved 630/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 68%|██████████████████████████████████████████████████████▍                         | 660/970 [26:26<12:24,  2.40s/it]

[checkpoint] saved 660/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 71%|████████████████████████████████████████████████████████▉                       | 690/970 [27:42<11:51,  2.54s/it]

[checkpoint] saved 690/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 74%|███████████████████████████████████████████████████████████▍                    | 720/970 [28:54<10:02,  2.41s/it]

[checkpoint] saved 720/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 77%|█████████████████████████████████████████████████████████████▊                  | 750/970 [30:10<09:56,  2.71s/it]

[checkpoint] saved 750/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 80%|████████████████████████████████████████████████████████████████▎               | 780/970 [31:24<07:28,  2.36s/it]

[checkpoint] saved 780/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 84%|██████████████████████████████████████████████████████████████████▊             | 810/970 [32:39<06:52,  2.58s/it]

[checkpoint] saved 810/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 87%|█████████████████████████████████████████████████████████████████████▎          | 840/970 [33:50<04:51,  2.24s/it]

[checkpoint] saved 840/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 90%|███████████████████████████████████████████████████████████████████████▊        | 870/970 [35:03<04:04,  2.45s/it]

[checkpoint] saved 870/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 93%|██████████████████████████████████████████████████████████████████████████▏     | 900/970 [36:14<02:37,  2.25s/it]

[checkpoint] saved 900/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 96%|████████████████████████████████████████████████████████████████████████████▋   | 930/970 [37:28<01:32,  2.32s/it]

[checkpoint] saved 930/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


 99%|███████████████████████████████████████████████████████████████████████████████▏| 960/970 [38:41<00:24,  2.49s/it]

[checkpoint] saved 960/970 -> /content/flare_fpb_deepseek_chat_predictions.csv


100%|████████████████████████████████████████████████████████████████████████████████| 970/970 [39:07<00:00,  2.42s/it]

[done] DeepSeek evaluation completed -> /content/flare_fpb_deepseek_chat_predictions.csv





In [5]:
# Step 5: Install scikit-learn first
%pip install -q scikit-learn

# Then compute Macro-F1 and Accuracy
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

# Load the prediction results: one row per row_idx (the last occurrence wins if
# the file contains duplicates).
# NOTE: this cell also uses `json`, `time`, `LABELS`, `MODEL`, `pred_path`,
# `save_dir` and `run_tag`, which are defined/imported by the earlier cells.
df = pd.read_csv(pred_path).sort_values("row_idx").drop_duplicates("row_idx", keep="last")
# Rows marked "UNKNOWN" are failed requests; they are excluded from the metrics
# below and reported only as a count.
ok = df[df["pred"] != "UNKNOWN"].copy()

print(f"DeepSeek Model Evaluation Results:")
print(f"Total samples: {len(df)}")
print(f"Successful predictions: {len(ok)}")
print(f"Failed predictions: {len(df) - len(ok)}")

if len(ok) > 0:
    # Compute the evaluation metrics (on successful predictions only).
    f1_macro = f1_score(ok["label"], ok["pred"], labels=LABELS, average="macro", zero_division=0)
    f1_micro = f1_score(ok["label"], ok["pred"], labels=LABELS, average="micro", zero_division=0)
    f1_weighted = f1_score(ok["label"], ok["pred"], labels=LABELS, average="weighted", zero_division=0)
    accuracy = accuracy_score(ok["label"], ok["pred"])
    
    print("\n" + "="*50)
    print("EVALUATION RESULTS - DeepSeek Chat")
    print("="*50)
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"F1-Macro:  {f1_macro:.4f}")
    print(f"F1-Micro:  {f1_micro:.4f}")
    print(f"F1-Weighted: {f1_weighted:.4f}")
    
    # Detailed per-class classification report.
    print("\nDetailed Classification Report:")
    print(classification_report(ok["label"], ok["pred"], labels=LABELS, zero_division=0))
    
    # Confusion matrix; rows and columns are both labelled in LABELS order.
    print("Confusion Matrix:")
    cm = confusion_matrix(ok["label"], ok["pred"], labels=LABELS)
    cm_df = pd.DataFrame(cm, index=LABELS, columns=LABELS)
    print(cm_df)
    
    # Save the evaluation results as JSON next to the predictions.
    eval_results = {
        "model": MODEL,
        "dataset": "TheFinAI/flare-fpb",
        "split": "test",
        "total_samples": len(df),
        "successful_predictions": len(ok),
        "failed_predictions": len(df) - len(ok),
        "accuracy": float(accuracy),
        "f1_macro": float(f1_macro),
        "f1_micro": float(f1_micro),
        "f1_weighted": float(f1_weighted),
        "evaluation_time": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),  # UTC
        "confusion_matrix": cm.tolist(),
        "labels": LABELS
    }
    
    eval_path = f"{save_dir}/{run_tag}_evaluation_results.json"
    with open(eval_path, "w") as f:
        json.dump(eval_results, f, indent=2)
    print(f"\nEvaluation results saved -> {eval_path}")
    
else:
    print("No successful predictions to evaluate!")

Note: you may need to restart the kernel to use updated packages.
DeepSeek Model Evaluation Results:
Total samples: 970
Successful predictions: 970
Failed predictions: 0

EVALUATION RESULTS - DeepSeek Chat
Accuracy:  0.7701
F1-Macro:  0.7772
F1-Micro:  0.7701
F1-Weighted: 0.7725

Detailed Classification Report:
              precision    recall  f1-score   support

    negative       0.75      0.95      0.84       116
     neutral       0.87      0.73      0.79       577
    positive       0.63      0.78      0.70       277

    accuracy                           0.77       970
   macro avg       0.75      0.82      0.78       970
weighted avg       0.79      0.77      0.77       970

Confusion Matrix:
          negative  neutral  positive
negative       110        5         1
neutral         33      420       124
positive         4       56       217

Evaluation results saved -> /content/flare_fpb_deepseek_chat_evaluation_results.json
