In [None]:
!pip install -q transformers datasets evaluate accelerate

import os, time, json
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "offline"

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)
import evaluate

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Output directory for the metric JSON files and confusion-matrix PNGs written by later cells.
os.makedirs("results", exist_ok=True)


## 1. Load dataset

In [None]:
ds = load_dataset("ag_news")

# Carve a validation set out of the official training split: 10% held out,
# stratified on the label column, with a fixed seed for reproducibility.
# The official test split is kept untouched as `test_ds`.
split = ds["train"].train_test_split(test_size=0.1, seed=42, stratify_by_column="label")
train_ds, val_ds = split["train"], split["test"]
test_ds = ds["test"]

print(train_ds, val_ds, sep="\n")

## Subsample the training set (5,000 examples per class)

In [None]:
# Build a class-balanced training subset: 5,000 examples drawn from each label.
#
# DataFrameGroupBy.sample is used instead of groupby(...).apply(lambda g: g.sample(...)):
# the apply form operates on the grouping column, which newer pandas versions
# deprecate and then exclude -- that would drop "label" from the result and break
# the later rename of "label" to "labels".
# NOTE: sampling stays deterministic (random_state=42), but the rows selected are
# not the same ones the apply-based version picked for that seed.
_df = pd.DataFrame({"text": train_ds["text"], "label": train_ds["label"]})
_sub = _df.groupby("label").sample(n=5000, random_state=42)
train_small = Dataset.from_pandas(_sub.reset_index(drop=True), preserve_index=False)

print(train_small)

## 2. DistilBERT v1 (max_length=128, 200 steps)

In [None]:
model_name = "distilbert-base-uncased"
tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)


def preprocess(batch):
    """Tokenize the "text" field of a batch, truncating to at most 128 tokens.

    No padding is added here; batches are padded later by the data collator.
    """
    return tok(batch["text"], truncation=True, padding=False, max_length=128)


# Tokenize, rename the target column to "labels", and return torch tensors.
_tokenized_train = train_small.map(preprocess, batched=True)
enc_train = _tokenized_train.rename_column("label", "labels").with_format("torch")

_tokenized_val = val_ds.map(preprocess, batched=True)
enc_val = _tokenized_val.rename_column("label", "labels").with_format("torch")

# Pads each batch dynamically using the tokenizer's padding settings.
collator = DataCollatorWithPadding(tok)

enc_train, enc_val

In [None]:
# DistilBERT classfication model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=4
)

for p in model.base_model.parameters():
    p.requires_grad = False

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params       = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable_params:,} / {all_params:,}")


In [None]:
#Accuracy + Macro-F1
acc_metric = evaluate.load("accuracy")
f1_metric  = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": acc_metric.compute(predictions=preds, references=labels)["accuracy"],
        "macro_f1": f1_metric.compute(predictions=preds, references=labels, average="macro")["f1"],
    }

In [None]:
# Training configuration for the v1 head-only run (200 optimizer steps).
args = TrainingArguments(
    output_dir="out_distilbert_head",
    max_steps=200,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    do_eval=False,          # no evaluation during training; it is run explicitly afterwards
    logging_steps=50,
    # Disable checkpointing explicitly. A huge `save_steps` value does not do this
    # reliably: with a step-based save strategy still active, recent Trainer
    # versions write a checkpoint when training ends.
    save_strategy="no",
    # Turn off every logging integration via the supported argument instead of
    # relying only on the WANDB_* environment variables.
    report_to="none",
    seed=42,
)

In [None]:
# Train the classification head and evaluate on the validation split.
# The `tokenizer` keyword is deprecated in Trainer (replaced by `processing_class`);
# it was only being passed as None, which is the default, so it is simply omitted.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=enc_train,
    eval_dataset=enc_val,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

# Wall-clock training time, in minutes.
t0 = time.time()
trainer.train()
train_time_min = (time.time() - t0) / 60.0

# Cast every metric to a plain float so the dict can be dumped to JSON later.
metrics = trainer.evaluate(enc_val)
metrics = {k: float(v) for k, v in metrics.items()}
metrics["train_time_min"] = train_time_min

print("\n=== DistilBERT head-only on validation set ===")
for k, v in metrics.items():
    # Every value is a float at this point, so one format covers all entries.
    print(f"{k}: {v:.4f}")

In [None]:
# Persist the v1 validation metrics.
with open("results/distilbert_head_v1.json", "w") as f:
    json.dump(metrics, f, indent=2)

# Class predictions on the validation split (argmax over the logits).
pred_out = trainer.predict(enc_val)
logits = pred_out.predictions
y_true = pred_out.label_ids
y_pred = logits.argmax(axis=-1)

labels4 = ["World", "Sports", "Business", "Sci/Tech"]
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3])

disp = ConfusionMatrixDisplay(cm, display_labels=labels4)
# Create the axes explicitly and hand them to `plot`. Without `ax`,
# ConfusionMatrixDisplay.plot opens its own figure, so a preceding
# plt.figure(figsize=...) is left behind as an empty figure and the
# requested size is never applied to the actual plot.
fig, ax = plt.subplots(figsize=(5.2, 4.4))
disp.plot(ax=ax, cmap="Blues", values_format="d", colorbar=False)
ax.set_title("Confusion Matrix — DistilBERT v1 (128, 200 steps)")
fig.tight_layout()
fig.savefig("results/confmat_distilbert_v1.png", dpi=300)
plt.show()

## Subword Length Analysis

In [None]:
# Subword-length statistics over the first 20,000 training examples, tokenized
# without truncation and with special tokens included.
# Rows are sliced before the column is selected so that only those 20,000
# texts are loaded, instead of materialising the whole "text" column first.
sample_texts = train_ds[:20000]["text"]
enc_no_trunc = tok(sample_texts, truncation=False, add_special_tokens=True)
sub_lens = np.array([len(ids) for ids in enc_no_trunc["input_ids"]])

p50  = float(np.percentile(sub_lens, 50))
p95  = float(np.percentile(sub_lens, 95))
# Fraction of examples longer than each candidate max_length (i.e. would be truncated).
t128 = float((sub_lens > 128).mean())
t256 = float((sub_lens > 256).mean())

# Build the summary once; the same dict is printed and written to disk.
subword_stats = {
    "p50": p50,
    "p95": p95,
    "trunc@128": t128,
    "trunc@256": t256,
}
print(subword_stats)

with open("results/subword_stats.json", "w") as f:
    json.dump(subword_stats, f, indent=2)

## DistilBERT v2 (max_length=192, 1000 steps)

In [None]:
model_name = "distilbert-base-uncased"
tok_192 = AutoTokenizer.from_pretrained(model_name, use_fast=True)


def preprocess_192(batch):
    """Tokenize a batch of texts for the v2 run.

    Differs from the v1 preprocessing only in the truncation limit, which is
    raised from 128 to 192 tokens. No padding is applied at this stage.
    """
    return tok_192(batch["text"], truncation=True, padding=False, max_length=192)


# Columns exposed as torch tensors; everything else is hidden from the output format.
_torch_columns = ["input_ids", "attention_mask", "labels"]

enc_train_192 = (
    train_small
    .map(preprocess_192, batched=True)
    .rename_column("label", "labels")
    .with_format(type="torch", columns=_torch_columns)
)
enc_val_192 = (
    val_ds
    .map(preprocess_192, batched=True)
    .rename_column("label", "labels")
    .with_format(type="torch", columns=_torch_columns)
)

# Pads each batch dynamically using the v2 tokenizer.
collator_192 = DataCollatorWithPadding(tokenizer=tok_192)

In [None]:
# Fresh DistilBERT classifier (4 labels) for the v2 run.
model_v2 = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)

# Head-only training again: switch off gradients for the whole base encoder.
model_v2.base_model.requires_grad_(False)

# Report how many parameters remain trainable after freezing.
_param_sizes_v2 = [(p.numel(), p.requires_grad) for p in model_v2.parameters()]
trainable_params = sum(size for size, is_trainable in _param_sizes_v2 if is_trainable)
all_params = sum(size for size, _ in _param_sizes_v2)
print(f"[v2] Trainable params: {trainable_params:,} / {all_params:,}")

In [None]:
# Training configuration for the v2 head-only run.
args_v2 = TrainingArguments(
    output_dir="out_distilbert_head_v2_max192",
    max_steps=1000,                    # 200 -> 1000
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    do_eval=False,          # no evaluation during training; it is run explicitly afterwards
    logging_steps=50,
    # Disable checkpointing explicitly. A huge `save_steps` value does not do this
    # reliably: with a step-based save strategy still active, recent Trainer
    # versions write a checkpoint when training ends.
    save_strategy="no",
    # Turn off every logging integration via the supported argument instead of
    # relying only on the WANDB_* environment variables.
    report_to="none",
    seed=42,
)

In [None]:
#Trainer
trainer_v2 = Trainer(
    model=model_v2,
    args=args_v2,
    train_dataset=enc_train_192,
    eval_dataset=enc_val_192,
    data_collator=collator_192,
    compute_metrics=compute_metrics,
    tokenizer=None,
)

t0 = time.time()
trainer_v2.train()
train_time_min_v2 = (time.time() - t0) / 60.0

metrics_v2 = trainer_v2.evaluate(enc_val_192)
metrics_v2 = {k: float(v) for k, v in metrics_v2.items()}
metrics_v2["train_time_min"] = train_time_min_v2

print("\n=== DistilBERT head-only v2 (max_length=192, 1000 steps) on validation set ===")
for k, v in metrics_v2.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")

os.makedirs("results", exist_ok=True)
with open("results/distilbert_head_v2_max192.json", "w") as f:
    json.dump(metrics_v2, f, indent=2)

In [None]:
# Class predictions of the v2 model on the validation split (argmax over the logits).
pred_out_v2 = trainer_v2.predict(enc_val_192)
logits_v2 = pred_out_v2.predictions
y_true_v2 = pred_out_v2.label_ids
y_pred_v2 = logits_v2.argmax(axis=-1)

labels4 = ["World", "Sports", "Business", "Sci/Tech"]
cm_v2 = confusion_matrix(y_true_v2, y_pred_v2, labels=[0, 1, 2, 3])

disp_v2 = ConfusionMatrixDisplay(cm_v2, display_labels=labels4)
# Create the axes explicitly and hand them to `plot`. Without `ax`,
# ConfusionMatrixDisplay.plot opens its own figure, so a preceding
# plt.figure(figsize=...) is left behind as an empty figure and the
# requested size is never applied to the actual plot.
fig_v2, ax_v2 = plt.subplots(figsize=(5.2, 4.4))
disp_v2.plot(ax=ax_v2, cmap="Blues", values_format="d", colorbar=False)
ax_v2.set_title("Confusion Matrix - DistilBERT head-only v2 (max_length=192, 1000 steps)")
fig_v2.tight_layout()
fig_v2.savefig("results/confmat_distilbert_head_v2_max192.png", dpi=300)
plt.show()
