In [4]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [5]:
import os
import torch
from torch import nn
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import sklearn
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
from torch.optim import AdamW
from torchinfo import summary
import numpy as np

In [6]:
def load_news_data(data_file):
    """Load the news dataset and turn it into model-ready lists.

    The file is read as line-delimited JSON. The "THE WORLDPOST" category is
    folded into "WORLDPOST", the headline and short description are
    lower-cased and joined with a single space, and the category names are
    encoded as integer ids with ``LabelEncoder``.

    Args:
        data_file: Path (or file-like object) accepted by ``pd.read_json``.

    Returns:
        A tuple ``(texts, labels, class_names)`` where ``texts`` is a list of
        strings, ``labels`` is the list of integer ids aligned with ``texts``
        and ``class_names`` lists the category names in encoder order.
    """
    df = pd.read_json(data_file, lines=True)

    # Two spellings of the same category exist in the data; merge them.
    df['category'] = df['category'].replace({"THE WORLDPOST": "WORLDPOST"})

    # Missing values become empty strings instead of the literal text "nan",
    # which is what str(NaN).lower() used to inject into the model input.
    for column in ('headline', 'short_description'):
        df[column] = df[column].fillna("").astype(str).str.lower()

    df['text'] = df['headline'] + " " + df['short_description']

    encoder = LabelEncoder()
    df['label'] = encoder.fit_transform(df['category'])
    print(f"The dataset contains {df['category'].nunique()} unique categories.")

    return df['text'].tolist(), df['label'].tolist(), encoder.classes_.tolist()

In [7]:
# Colab-specific: mount Google Drive at /content/drive so files stored there can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
# Location of the line-delimited JSON dataset on the mounted Drive.
data_file = "/content/drive/MyDrive/News_Category_Dataset_v2.json"
# texts and labels are parallel lists; label_names lists the category names in label-id order.
texts, labels, label_names = load_news_data(data_file)

The dataset contains 40 unique categories.


In [9]:

# List every category next to the integer id it was encoded as.
for label_id, category in enumerate(label_names):
    print(f"{label_id} → {category}")

0 → ARTS
1 → ARTS & CULTURE
2 → BLACK VOICES
3 → BUSINESS
4 → COLLEGE
5 → COMEDY
6 → CRIME
7 → CULTURE & ARTS
8 → DIVORCE
9 → EDUCATION
10 → ENTERTAINMENT
11 → ENVIRONMENT
12 → FIFTY
13 → FOOD & DRINK
14 → GOOD NEWS
15 → GREEN
16 → HEALTHY LIVING
17 → HOME & LIVING
18 → IMPACT
19 → LATINO VOICES
20 → MEDIA
21 → MONEY
22 → PARENTING
23 → PARENTS
24 → POLITICS
25 → QUEER VOICES
26 → RELIGION
27 → SCIENCE
28 → SPORTS
29 → STYLE
30 → STYLE & BEAUTY
31 → TASTE
32 → TECH
33 → TRAVEL
34 → WEDDINGS
35 → WEIRD NEWS
36 → WELLNESS
37 → WOMEN
38 → WORLD NEWS
39 → WORLDPOST


In [27]:
# Hyper-parameters shared by the tokenizer, data loaders, model and optimizer below.
bert_model_name = 'prajjwal1/bert-tiny'  # checkpoint name used for both tokenizer and encoder
num_classes = len(label_names)  # one output logit per category
max_length = 256  # token length every training/validation sample is padded or truncated to
batch_size = 32  # samples per batch in both data loaders
num_epochs = 1  # number of passes over the training set
learning_rate = 2e-5  # learning rate handed to AdamW

In [28]:
class TextClassificationDataset(Dataset):
    """Dataset that tokenizes a single text every time an item is requested.

    Args:
        texts: Indexable collection of raw strings.
        labels: Indexable collection of labels aligned with ``texts``.
        tokenizer: Callable invoked with the text and the keyword arguments
            ``return_tensors``, ``max_length``, ``padding`` and ``truncation``.
        max_length: Length passed to the tokenizer for padding/truncation.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        # One sample per stored text.
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the flattened token ids, attention mask and label tensor for ``idx``."""
        raw_text, raw_label = self.texts[idx], self.labels[idx]
        tokenized = self.tokenizer(
            raw_text,
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        # The tokenizer output carries a leading batch axis; flatten it away.
        sample = {
            field: tokenized[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        sample['label'] = torch.tensor(raw_label)
        return sample

In [29]:
class BERTClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear classification head.

    Args:
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
        num_classes: Number of logits produced by the linear head.
        dropout_prob: Dropout probability applied to the pooled encoder output
            before the head. Defaults to 0.1, the value that used to be
            hard-coded, so existing callers are unaffected.
    """

    def __init__(self, bert_model_name, num_classes, dropout_prob=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(dropout_prob)
        # The head's input width follows the encoder's configured hidden size.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalised) class logits for a batch.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Mask tensor forwarded to the encoder.

        Returns:
            Output of the linear head applied to the encoder's ``pooler_output``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        return self.fc(pooled_output)

In [30]:
def train(model, data_loader, optimizer, scheduler, device):
    """Run one training epoch and report the mean batch loss.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns class logits.
        data_loader: Iterable of batches; each batch is a mapping with the
            keys ``'input_ids'``, ``'attention_mask'`` and ``'label'``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        float: Mean cross-entropy loss over the processed batches, or ``0.0``
        when the loader yields no batches (previously a ZeroDivisionError).
    """
    model.train()
    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(data_loader, desc="Training", leave=True)

    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()
        scheduler.step()

        # Read the scalar once; it feeds both the running total and the bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix(loss=batch_loss)

    # Count the batches actually seen rather than trusting len(data_loader).
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"\nEpoch completed. Average training loss: {avg_loss:.4f}")
    return avg_loss

In [31]:
# Without mixed precision
from torch.profiler import profile, record_function, ProfilerActivity

def train_original(model, dataloader, optimizer, scheduler, device, epoch=None):
    """Train for one epoch without mixed precision while profiling the run.

    The profiler follows a ``wait=1, warmup=1, active=2, repeat=1`` schedule,
    writes its trace through a TensorBoard handler to ``./log_train_profiler``
    and, once the epoch is over, prints three summary tables sorted by CUDA
    time, CUDA memory and CPU time.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns class logits.
        dataloader: Iterable of batches; each batch is a mapping with the
            keys ``'input_ids'``, ``'attention_mask'`` and ``'label'``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: ``torch.device`` the batches are moved to; CUDA activity is
            profiled only when ``device.type == "cuda"``.
        epoch: Optional epoch number shown in the progress-bar description.

    Returns:
        float: Mean cross-entropy loss over the processed batches, or ``0.0``
        when the loader yields no batches (previously a ZeroDivisionError).
    """
    model.train()
    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch
    total_loss = 0.0
    num_batches = 0

    activities = [ProfilerActivity.CPU]
    if device.type == "cuda":
        activities.append(ProfilerActivity.CUDA)

    with profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./log_train_profiler"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as profiler:

        for batch in tqdm(dataloader, desc=f"Training Epoch {epoch if epoch is not None else ''}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()

            with record_function("forward_pass"):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(outputs, labels)

            # Note: the optimizer update is timed inside the "backward_pass" range.
            with record_function("backward_pass"):
                loss.backward()
                optimizer.step()

            scheduler.step()
            profiler.step()  # tell the profiler schedule that this step has finished

            total_loss += loss.item()
            num_batches += 1

    # Aggregate once and reuse the result for all three tables.
    averages = profiler.key_averages()
    for sort_key in ("cuda_time_total", "self_cuda_memory_usage", "cpu_time_total"):
        print(averages.table(sort_by=sort_key, row_limit=20))

    # Count the batches actually seen rather than trusting len(dataloader).
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Average training loss: {avg_loss:.4f}")
    return avg_loss

In [32]:
def evaluate(model, data_loader, device):
    """Score the model on every batch of ``data_loader``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns class logits.
        data_loader: Iterable of batches; each batch is a mapping with the
            keys ``'input_ids'``, ``'attention_mask'`` and ``'label'``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(accuracy, macro_f1, weighted_f1)`` computed over all samples.
    """
    model.eval()
    predicted, expected = [], []

    with torch.no_grad():
        for batch in data_loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['label'].to(device)

            logits = model(input_ids=ids, attention_mask=mask)
            # Index of the largest logit along dim 1 is the predicted class.
            batch_preds = torch.max(logits, dim=1).indices

            predicted += batch_preds.cpu().tolist()
            expected += targets.cpu().tolist()

    return (
        accuracy_score(expected, predicted),
        f1_score(expected, predicted, average='macro'),
        f1_score(expected, predicted, average='weighted'),
    )

In [33]:
def predict_news_category(text, model, tokenizer, device, encoder, max_length=128):
    """Predict the category name for one piece of text.

    Args:
        text: Raw text to classify.
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns class logits.
        tokenizer: Callable producing ``'input_ids'`` and ``'attention_mask'`` tensors.
        device: Device the inputs are moved to.
        encoder: Object whose ``inverse_transform`` maps label ids to names.
        max_length: Length the text is padded or truncated to.

    Returns:
        The first element of ``encoder.inverse_transform`` for the predicted id.
    """
    model.eval()

    tokenized = tokenizer(
        text,
        return_tensors='pt',
        max_length=max_length,
        padding='max_length',
        truncation=True,
    )
    model_inputs = {
        field: tokenized[field].to(device)
        for field in ('input_ids', 'attention_mask')
    }

    with torch.no_grad():
        logits = model(**model_inputs)
        # Index of the largest logit along dim 1 is the predicted class id.
        best_index = torch.max(logits, dim=1).indices

    return encoder.inverse_transform(best_index.cpu().numpy())[0]

In [42]:
from torch.profiler import profile, record_function, ProfilerActivity, schedule, tensorboard_trace_handler
import time
def predict_news_category(text, model, tokenizer, device, encoder, max_length=128,
                          log_dir="./log_predict_base_warmup"):
    """Predict the category of one text while profiling repeated inference.

    The same tokenized input is pushed through the model once per profiler
    step. The schedule idles for one step, warms up for one step and records
    five steps, so the loop length is derived from those three numbers
    instead of being a separate hard-coded 7.

    Args:
        text: Raw text to classify.
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns class logits.
        tokenizer: Callable producing ``'input_ids'`` and ``'attention_mask'`` tensors.
        device: ``torch.device`` the inputs are moved to; CUDA activity is
            profiled only when ``device.type == "cuda"``.
        encoder: Object whose ``inverse_transform`` maps label ids to names.
        max_length: Length the text is padded or truncated to.
        log_dir: Directory handed to the TensorBoard trace handler and echoed
            in the final hint. Defaults to the previously hard-coded path.

    Returns:
        The first element of ``encoder.inverse_transform`` for the id predicted
        on the last profiled run.
    """
    model.eval()

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Profiler schedule: skip the first step, warm up on the second, record the next five.
    wait_steps, warmup_steps, active_steps = 1, 1, 5
    total_steps = wait_steps + warmup_steps + active_steps

    activities = [ProfilerActivity.CPU]
    if device.type == "cuda":
        activities.append(ProfilerActivity.CUDA)

    with profile(
        activities=activities,
        schedule=schedule(wait=wait_steps, warmup=warmup_steps, active=active_steps, repeat=1),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        on_trace_ready=tensorboard_trace_handler(log_dir)
    ) as profiler:
        for i in range(total_steps):  # one forward pass per scheduled profiler step
            with torch.no_grad():
                with record_function(f"model_inference_{i}"):
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    _, predicted_label = torch.max(outputs, dim=1)
            profiler.step()

    # Aggregate once and reuse the result for all three tables.
    averages = profiler.key_averages()
    for sort_key in ("cuda_time_total", "self_cuda_memory_usage", "cpu_time_total"):
        print(averages.table(sort_by=sort_key, row_limit=20))

    predicted_category = encoder.inverse_transform(predicted_label.cpu().numpy())[0]
    print(f"✅ Warmup profiling done. Use TensorBoard:\n\n  %tensorboard --logdir={log_dir}")
    return predicted_category

In [34]:
# Hold out 20% for validation. Stratifying on the labels keeps each category's
# share the same in both splits, which matters for the per-class (macro) F1.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42, stratify=labels
)

In [35]:
# Tokenizer that matches the pretrained checkpoint named in bert_model_name.
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
# Both datasets tokenize lazily, one text per __getitem__ call, to max_length tokens.
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)
# Only the training loader shuffles; the validation loader keeps dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/285 [00:00<?, ?B/s]

In [36]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the classifier around the pretrained encoder and move it to the chosen device.
model = BERTClassifier(bert_model_name, num_classes).to(device)
print(f"Using device: {device}")

pytorch_model.bin:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

Using device: cuda


In [37]:
# AdamW over every model parameter (encoder and classification head alike).
optimizer = AdamW(model.parameters(), lr=learning_rate)
# The scheduler is stepped once per batch, so it needs the total number of batches.
total_steps = len(train_dataloader) * num_epochs
# Linear schedule with no warm-up steps, spanning the whole training run.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

In [38]:
# Dummy batch used only to trace layer shapes; it is 128 tokens long,
# which is shorter than the max_length used for training.
seq_len = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_data = dict(
    input_ids=torch.zeros((batch_size, seq_len), dtype=torch.long).to(device),
    attention_mask=torch.ones((batch_size, seq_len), dtype=torch.long).to(device),
)

summary(model, input_data=input_data)

Layer (type:depth-idx)                                       Output Shape              Param #
BERTClassifier                                               [32, 40]                  --
├─BertModel: 1-1                                             [32, 128]                 --
│    └─BertEmbeddings: 2-1                                   [32, 128, 128]            --
│    │    └─Embedding: 3-1                                   [32, 128, 128]            3,906,816
│    │    └─Embedding: 3-2                                   [32, 128, 128]            256
│    │    └─Embedding: 3-3                                   [1, 128, 128]             65,536
│    │    └─LayerNorm: 3-4                                   [32, 128, 128]            256
│    │    └─Dropout: 3-5                                     [32, 128, 128]            --
│    └─BertEncoder: 2-2                                      [32, 128, 128]            --
│    │    └─ModuleList: 3-6                                  --                   

In [39]:
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    # Pass the 1-based epoch so the progress bar is labelled "Training Epoch 1"
    # rather than the bare "Training Epoch " it shows when epoch is omitted.
    train_original(model, train_dataloader, optimizer, scheduler, device, epoch=epoch + 1)
    accuracy, macro_f1, weighted_f1 = evaluate(model, val_dataloader, device)
    print(f"Validation Accuracy: {accuracy:.4f}")
    print(f"Macro F1: {macro_f1:.4f}")
    print(f"Weighted F1: {weighted_f1:.4f}")

Epoch 1/1


Training Epoch : 100%|██████████| 5022/5022 [03:11<00:00, 26.19it/s]


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         0.00%       0.000us         0.00%       0.000us       0.000us     143.332ms      1331.37%     143.332ms      71.666ms           0 b           0 b           0 b           0 

In [34]:
# Save only the model's state_dict (not the whole module) to a file in the working directory.
torch.save(model.state_dict(), "bert_classifier.pth")

In [43]:
# Reload the dataset to recover the class names, reusing the path defined
# earlier in the notebook instead of repeating the literal.
texts, labels, label_classes = load_news_data(data_file)

# Rebuild an encoder from the class names alone so that label ids can be
# mapped back to category names without fitting it again.
encoder = LabelEncoder()
encoder.classes_ = np.array(label_classes)

test_text = "NASA launches new space telescope to explore exoplanets."
predicted_category = predict_news_category(test_text, model, tokenizer, device, encoder)

print(f"Headline: {test_text}")
print(f"Predicted Category: {predicted_category}")

The dataset contains 40 unique categories.
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      model_inference_2         0.00%       0.000us         0.00%       0.000us       0.000us       4.729ms       312.65%       4.729ms       4.729ms           0