In [12]:
# 1. Install PyTorch EXACTLY matching Colab’s CUDA 12.6 build
!pip install -q --force-reinstall \
    torch==2.9.0+cu126 torchvision==0.24.0+cu126 torchaudio==2.9.0+cu126 \
    -f https://download.pytorch.org/whl/torch_stable.html

# 2. Install HF tools WITHOUT breaking pyarrow
!pip install -q "datasets<3.0" "pyarrow<17" transformers accelerate scikit-learn

import torch, transformers, datasets, sklearn, numpy as np
from accelerate import Accelerator

print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Transformers:", transformers.__version__)
print("Datasets:", datasets.__version__)

accelerator = Accelerator(mixed_precision="fp16")
device = accelerator.device
print("Using:", device)


[31mERROR: Could not find a version that satisfies the requirement torch==2.9.0+cu126 (from versions: 2.2.0, 2.2.0+cpu, 2.2.0+cpu.cxx11.abi, 2.2.0+cu118, 2.2.0+cu121, 2.2.0+rocm5.6, 2.2.0+rocm5.7, 2.2.1, 2.2.1+cpu, 2.2.1+cpu.cxx11.abi, 2.2.1+cu118, 2.2.1+cu121, 2.2.1+rocm5.6, 2.2.1+rocm5.7, 2.2.2, 2.2.2+cpu, 2.2.2+cpu.cxx11.abi, 2.2.2+cu118, 2.2.2+cu121, 2.2.2+rocm5.6, 2.2.2+rocm5.7, 2.3.0, 2.3.0+cpu, 2.3.0+cpu.cxx11.abi, 2.3.0+cu118, 2.3.0+cu121, 2.3.0+rocm5.7, 2.3.0+rocm6.0, 2.3.1, 2.3.1+cpu, 2.3.1+cpu.cxx11.abi, 2.3.1+cu118, 2.3.1+cu121, 2.3.1+rocm5.7, 2.3.1+rocm6.0, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0, 2.7.0, 2.7.1, 2.8.0, 2.9.0, 2.9.1)[0m[31m
[0m[31mERROR: No matching distribution found for torch==2.9.0+cu126[0m[31m
[0mTorch: 2.9.0+cu126
CUDA available: True
Transformers: 4.57.2
Datasets: 2.21.0
Using: cuda


In [13]:
from datasets import load_dataset, DatasetDict

raw_ds = load_dataset("reuters21578", "ModApte", trust_remote_code=True)
raw_ds


DatasetDict({
    test: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title'],
        num_rows: 3299
    })
    train: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title'],
        num_rows: 9603
    })
    unused: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title'],
        num_rows: 722
    })
})

In [14]:
def extract_first_topic(batch):
    topics = batch["topics"]
    batch["label_text"] = topics[0] if len(topics) > 0 else "__NO_LABEL__"
    return batch

ds = raw_ds.map(extract_first_topic)
ds = ds.filter(lambda x: x["label_text"] != "__NO_LABEL__", batched=False)
ds


DatasetDict({
    test: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text'],
        num_rows: 3019
    })
    train: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text'],
        num_rows: 7775
    })
    unused: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text'],
        num_rows: 565
    })
})

In [15]:
from collections import Counter

# Count all labels across train/test/unused
label_counts = Counter()

for split in ds.keys():
    label_counts.update(ds[split]["label_text"])

print("Total unique before filtering:", len(label_counts))

# Keep only labels with >=20 samples
MIN_SAMPLES = 10
keep_labels = {lbl for lbl, cnt in label_counts.items() if cnt >= MIN_SAMPLES}

print("Kept labels:", len(keep_labels))

def filter_rare(batch):
    return batch["label_text"] in keep_labels

ds = ds.filter(filter_rare)
ds


Total unique before filtering: 82
Kept labels: 49


DatasetDict({
    test: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text'],
        num_rows: 2989
    })
    train: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text'],
        num_rows: 7716
    })
    unused: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text'],
        num_rows: 554
    })
})

In [16]:
# Build final sorted label list
final_labels = sorted(list(keep_labels))
label_to_id = {name: i for i, name in enumerate(final_labels)}
id_to_label = {i: name for name, i in label_to_id.items()}

num_labels = len(final_labels)
print("Final num_labels:", num_labels)
final_labels[:20]


Final num_labels: 49


['acq',
 'alum',
 'bop',
 'carcass',
 'cocoa',
 'coffee',
 'copper',
 'cotton',
 'cpi',
 'crude',
 'dlr',
 'earn',
 'fuel',
 'gas',
 'gnp',
 'gold',
 'grain',
 'heat',
 'hog',
 'housing']

In [17]:
def encode_label(batch):
    batch["label"] = label_to_id[batch["label_text"]]
    return batch

ds = ds.map(encode_label)
ds


DatasetDict({
    test: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text', 'label'],
        num_rows: 2989
    })
    train: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text', 'label'],
        num_rows: 7716
    })
    unused: Dataset({
        features: ['text', 'text_type', 'topics', 'lewis_split', 'cgis_split', 'old_id', 'new_id', 'places', 'people', 'orgs', 'exchanges', 'date', 'title', 'label_text', 'label'],
        num_rows: 554
    })
})

In [18]:
ds = ds.remove_columns([
    "topics", "label_text", "text_type", "lewis_split",
    "cgis_split", "old_id", "new_id", "places", "people",
    "orgs", "exchanges", "date", "title"
])

train_valid = ds["train"].train_test_split(test_size=0.1, seed=42)

dataset = DatasetDict({
    "train": train_valid["train"],
    "validation": train_valid["test"],
    "test": ds["test"]
})

dataset


DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 6944
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 772
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2989
    })
})

In [19]:
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

max_length = 256

def tok(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )

tokenized = dataset.map(tok, batched=True)
tokenized = tokenized.rename_column("label", "labels")

tokenized.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"]
)

tokenized


Map:   0%|          | 0/772 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 6944
    })
    validation: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 772
    })
    test: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask'],
        num_rows: 2989
    })
})

In [20]:
from torch.utils.data import DataLoader

train_loader = DataLoader(tokenized["train"], batch_size=16, shuffle=True)
valid_loader = DataLoader(tokenized["validation"], batch_size=32)
test_loader  = DataLoader(tokenized["test"], batch_size=32)

batch = next(iter(train_loader))
batch.keys(), batch["input_ids"].shape


(dict_keys(['labels', 'input_ids', 'attention_mask']), torch.Size([16, 256]))

In [21]:
from transformers import DistilBertForSequenceClassification
import torch.nn as nn

phase1_model_path = "/content/bert_20ng"

base_model = DistilBertForSequenceClassification.from_pretrained(phase1_model_path)

print("20NG num_labels:", base_model.config.num_labels)

# Replace classifier
hidden = base_model.config.dim
base_model.classifier = nn.Linear(hidden, num_labels)

# Inject correct mappings
base_model.config.num_labels = num_labels
base_model.num_labels = num_labels

base_model.config.label2id = label_to_id
base_model.config.id2label = id_to_label

nn.init.xavier_uniform_(base_model.classifier.weight)
nn.init.zeros_(base_model.classifier.bias)

model = base_model.to(device)
model


20NG num_labels: 20


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [22]:
import json

labels_path = "/content/labels.json"

with open(labels_path, "w") as f:
    json.dump({
        "label_to_id": label_to_id,
        "id_to_label": id_to_label
    }, f, indent=2)

print("Saved:", labels_path)


Saved: /content/labels.json


In [23]:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

steps_per_epoch = len(train_loader)
max_steps = steps_per_epoch * 3
warmup = int(0.1 * max_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer, warmup, max_steps
)

model, optimizer, train_loader, valid_loader, scheduler = accelerator.prepare(
    model, optimizer, train_loader, valid_loader, scheduler
)

model


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [24]:
from tqdm.auto import tqdm

for epoch in range(3):
    model.train()
    total_loss = 0

    pbar = tqdm(train_loader)
    for batch in pbar:
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.item()

        optimizer.zero_grad()
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        pbar.set_description(f"Epoch {epoch+1} Loss {loss.item():.4f}")

    print(f"Epoch {epoch+1} avg loss:", total_loss/len(train_loader))


  0%|          | 0/434 [00:00<?, ?it/s]

Epoch 1 avg loss: 1.3970276484017


  0%|          | 0/434 [00:00<?, ?it/s]

Epoch 2 avg loss: 0.5324962469427267


  0%|          | 0/434 [00:00<?, ?it/s]

Epoch 3 avg loss: 0.40286335956803115


In [25]:
from sklearn.metrics import accuracy_score

def eval_model(model, loader):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            out = model(**batch)
            p = out.logits.argmax(dim=-1)
            preds.append(accelerator.gather_for_metrics(p).cpu())
            labels.append(accelerator.gather_for_metrics(batch["labels"]).cpu())
    preds = torch.cat(preds).numpy()
    labels = torch.cat(labels).numpy()
    return accuracy_score(labels, preds)

acc = eval_model(model, valid_loader)
print("Validation accuracy:", acc)


Validation accuracy: 0.8575129533678757


In [26]:
save_dir = "/content/bert_reuters21578"

unwrapped = accelerator.unwrap_model(model)

unwrapped.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

# Copy labels.json into directory
import shutil
shutil.copy("/content/labels.json", f"{save_dir}/labels.json")

print("Saved to:", save_dir)


Saved to: /content/bert_reuters21578


In [27]:
import shutil
from google.colab import files

zip_path = "/content/bert_reuters21578.zip"
shutil.make_archive("/content/bert_reuters21578", "zip", save_dir)

files.download(zip_path)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [28]:
# ============================================================
# 8. Full evaluation: Loss, Accuracy, Macro F1, per-label scores
# ============================================================
import torch
from tqdm.auto import tqdm
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_recall_fscore_support,
)
import numpy as np

def evaluate_full(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0

    for batch in tqdm(dataloader, desc="Evaluating", disable=not accelerator.is_local_main_process):
        with torch.no_grad():
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
        loss = outputs.loss
        total_loss += loss.item()

        logits = outputs.logits
        preds = logits.argmax(dim=-1)

        all_preds.append(accelerator.gather_for_metrics(preds).cpu())
        all_labels.append(accelerator.gather_for_metrics(batch["labels"]).cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average="macro")

    # Per-class stats
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average=None, labels=list(id_to_label.keys())
    )

    return {
        "loss": avg_loss,
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "y_true": all_labels,
        "y_pred": all_preds,
    }

# ------------------------------------------------------------
# Run evaluation on VALIDATION or TEST set
# ------------------------------------------------------------
results = evaluate_full(model, valid_loader)

# ------------------------------------------------------------
# Print global stats
# ------------------------------------------------------------
print("\n================ Overall Performance ================\n")
print(f"Eval Loss:      {results['loss']:.4f}")
print(f"Accuracy:       {results['accuracy']:.4f}")
print(f"Macro F1:       {results['macro_f1']:.4f}")

# ------------------------------------------------------------
# Compute per-class F1 ranking
# ------------------------------------------------------------
label_f1_pairs = [
    (id_to_label[i], results["f1"][i])
    for i in range(num_labels)
]

# Sort descending
label_f1_pairs_sorted = sorted(label_f1_pairs, key=lambda x: x[1], reverse=True)

best_3 = label_f1_pairs_sorted[:3]
worst_3 = label_f1_pairs_sorted[-3:]

print("\n================ Best 3 Categories ================\n")
for i, (label, score) in enumerate(best_3, start=1):
    print(f"{i}. {label:<20} F1 = {score:.3f}")

print("\n================ Worst 3 Categories ================\n")
for i, (label, score) in enumerate(worst_3, start=1):
    print(f"{i}. {label:<20} F1 = {score:.3f}")


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]



Eval Loss:      0.4958
Accuracy:       0.8575
Macro F1:       0.6210


1. alum                 F1 = 1.000
2. coffee               F1 = 1.000
3. gas                  F1 = 1.000


1. strategic-metal      F1 = 0.000
2. wheat                F1 = 0.000
3. wpi                  F1 = 0.000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [29]:
for i, (label, score) in enumerate(label_f1_pairs_sorted, start=1):
    print(f"{i}. {label:<20} F1 = {score:.3f}")

1. alum                 F1 = 1.000
2. coffee               F1 = 1.000
3. gas                  F1 = 1.000
4. gold                 F1 = 1.000
5. hog                  F1 = 1.000
6. housing              F1 = 1.000
7. jobs                 F1 = 1.000
8. orange               F1 = 1.000
9. retail               F1 = 1.000
10. zinc                 F1 = 1.000
11. sugar                F1 = 0.923
12. grain                F1 = 0.894
13. trade                F1 = 0.893
14. earn                 F1 = 0.891
15. acq                  F1 = 0.890
16. gnp                  F1 = 0.889
17. tin                  F1 = 0.889
18. crude                F1 = 0.862
19. money-fx             F1 = 0.842
20. copper               F1 = 0.800
21. ship                 F1 = 0.800
22. veg-oil              F1 = 0.800
23. oilseed              F1 = 0.727
24. bop                  F1 = 0.667
25. cpi                  F1 = 0.667
26. ipi                  F1 = 0.667
27. pet-chem             F1 = 0.667
28. rubber               F1 = 0.667
2