In [2]:
import os, sys

# Put the parent of the current working directory on the import path so the
# notebook can import the project's `src` package.
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir))

# Only append once, so re-running the cell does not add duplicates.
if project_root not in sys.path:
    sys.path.append(project_root)

print("Project root:", project_root)

Project root: E:\cas2105_hw6\agnews-pipeline


In [3]:
import random

import pandas as pd
from datasets import load_dataset

from src.baseline import baseline_predict_batch, clean_text
from src.embeddings_pipeline import AGNewsEmbeddingClassifier
from src.evaluation_utils import evaluate_multiclass, print_classification_summary

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load the AG News dataset through the Hugging Face `datasets` library.
raw = load_dataset("ag_news")

# Keep a handle on each split, then convert both to pandas DataFrames.
train_ds, test_ds = raw["train"], raw["test"]
df_train_full, df_test_full = train_ds.to_pandas(), test_ds.to_pandas()

# Sanity check: preview the first rows and the per-label counts of the train split.
print(df_train_full.head())
print(df_train_full["label"].value_counts())

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Generating train split: 100%|███████████████████████████████████████| 120000/120000 [00:00<00:00, 708880.48 examples/s]
Generating test split: 100%|████████████████████████████████████████████| 7600/7600 [00:00<00:00, 574956.00 examples/s]


                                                text  label
0  Wall St. Bears Claw Back Into the Black (Reute...      2
1  Carlyle Looks Toward Commercial Aerospace (Reu...      2
2  Oil and Economy Cloud Stocks' Outlook (Reuters...      2
3  Iraq Halts Oil Exports from Main Southern Pipe...      2
4  Oil prices soar to all-time record, posing new...      2
label
2    30000
3    30000
1    30000
0    30000
Name: count, dtype: int64


In [5]:
# Sampling configuration: fixed seed plus the subset sizes used below.
RANDOM_SEED = 42
N_TRAIN = 4000
N_TEST = 1000

# Shuffle both full splits (frac=1.0 keeps every row, in random order).
# NOTE: the shuffled frames overwrite the originals, so re-running this cell
# without re-running the loading cell shuffles already-shuffled data.
df_train_full = df_train_full.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)
df_test_full = df_test_full.sample(frac=1.0, random_state=RANDOM_SEED).reset_index(drop=True)

# Downsample: keep the leading rows of each shuffled split.
df_train = df_train_full.head(N_TRAIN).reset_index(drop=True)
df_test = df_test_full.head(N_TEST).reset_index(drop=True)

print("Train size:", len(df_train))
print("Test size:", len(df_test))

# Persist the small splits as CSV (optional).
os.makedirs("../data", exist_ok=True)
df_train.to_csv("../data/agnews_train_small.csv", index=False)
df_test.to_csv("../data/agnews_test_small.csv", index=False)

Train size: 4000
Test size: 1000


In [6]:
# Add a cleaned-text column to both subsets using the project's `clean_text` helper.
for frame in (df_train, df_test):
    frame["text_clean"] = frame["text"].apply(clean_text)

# Plain Python lists of inputs (cleaned text) and targets (integer labels).
X_train, y_train = df_train["text_clean"].tolist(), df_train["label"].tolist()
X_test, y_test = df_test["text_clean"].tolist(), df_test["label"].tolist()

# Display names for the labels; later cells index this list by integer label.
label_names = ["World", "Sports", "Business", "Sci/Tech"]

In [7]:
# Run the baseline predictor over the cleaned test texts.
y_pred_baseline = baseline_predict_batch(X_test)

# Aggregate metrics first, then the per-class classification summary.
baseline_metrics = evaluate_multiclass(y_test, y_pred_baseline)
print("Baseline metrics:", baseline_metrics)
print_classification_summary(
    y_test,
    y_pred_baseline,
    target_names=label_names,
)

Baseline metrics: {'accuracy': 0.537, 'f1_macro': 0.5243673998965013}
              precision    recall  f1-score   support

       World     0.4011    0.9364    0.5616       236
      Sports     0.6898    0.6157    0.6507       242
    Business     0.6744    0.3551    0.4652       245
    Sci/Tech     0.7692    0.2888    0.4199       277

    accuracy                         0.5370      1000
   macro avg     0.6336    0.5490    0.5244      1000
weighted avg     0.6399    0.5370    0.5203      1000



In [8]:
# Identifier passed to the embedding classifier's constructor.
embedding_model_name = "all-MiniLM-L6-v2"

# Build the classifier and fit it on the training subset.
pipeline = AGNewsEmbeddingClassifier(embedding_model_name)
pipeline.fit(X_train, y_train)

Batches: 100%|███████████████████████████████████████████████████████████████████████| 125/125 [00:49<00:00,  2.51it/s]


In [9]:
# Predict test labels with the embedding pipeline fitted in the previous cell.
y_pred_embed = pipeline.predict(X_test)

# Aggregate metrics first, then the per-class classification summary.
embed_metrics = evaluate_multiclass(y_test, y_pred_embed)
print("Embedding pipeline metrics:", embed_metrics)
print_classification_summary(
    y_test,
    y_pred_embed,
    target_names=label_names,
)

Batches: 100%|█████████████████████████████████████████████████████████████████████████| 32/32 [00:09<00:00,  3.23it/s]


Embedding pipeline metrics: {'accuracy': 0.873, 'f1_macro': 0.8741399886441847}
              precision    recall  f1-score   support

       World     0.8694    0.9025    0.8857       236
      Sports     0.9512    0.9669    0.9590       242
    Business     0.7930    0.8286    0.8104       245
    Sci/Tech     0.8814    0.8051    0.8415       277

    accuracy                         0.8730      1000
   macro avg     0.8737    0.8758    0.8741      1000
weighted avg     0.8738    0.8730    0.8727      1000



In [10]:
# Attach both models' predictions to the test frame so every row carries its
# text, true label, and the two predicted labels.
df_test["baseline_pred"] = y_pred_baseline
df_test["embed_pred"] = y_pred_embed

# Rows the baseline got wrong but the embedding pipeline got right.
# Compare against the frame's own "label" column rather than the detached
# `y_test` list, so the mask cannot drift out of alignment with `df_test`.
baseline_wrong = df_test["baseline_pred"] != df_test["label"]
embed_right = df_test["embed_pred"] == df_test["label"]
mask_embed_better = baseline_wrong & embed_right

# Show at most the first ten such rows.
examples = df_test.loc[mask_embed_better].head(10)

# itertuples avoids the unused index variable and the per-row Series that
# iterrows builds; fields are read by column name.
for row in examples.itertuples(index=False):
    print("----")
    print("Text:", row.text)
    print("True label:", label_names[row.label])
    print("Baseline:", label_names[row.baseline_pred])
    print("Embed:", label_names[row.embed_pred])

----
Text: Dependent species risk extinction The global extinction crisis is worse than thought, because thousands of  quot;affiliated quot; species also at risk do not figure in calculations.
True label: Sci/Tech
Baseline: World
Embed: Sci/Tech
----
Text: Profit Plunges at International Game Tech International Game Technology, the world #39;s biggest maker of slot machines, Tuesday said said profit for its latest quarter fell 50 percent from a year ago due to a charge for early redemption of debt and a tax adjustment.
True label: Business
Baseline: Sci/Tech
Embed: Business
----
Text: General Mills goes whole grains NEW YORK (CNN/Money) - General Mills announced plans Thursday to start using healthier whole grains in all of its ready-to-eat cereals, including children #39;s cereals such as Trix, Cocoa Puffs and Lucky Charms.
True label: Business
Baseline: World
Embed: Business
----
Text: United Apology over Website Abuse Manchester United have been forced to issue an embarrassing apolo