In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
import random
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Part 2 (original)
dataset = load_dataset('ag_news')
train_df = pd.DataFrame(dataset['train'])
test_df = pd.DataFrame(dataset['test'])

print("Train:", train_df.shape, "| Test:", test_df.shape)

def tokenize(text):
    return text.lower().split()

counter = Counter()
for text in train_df['text']:
    counter.update(tokenize(text))

vocab = {word: idx + 2 for idx, (word, _) in enumerate(counter.items())}
vocab["<pad>"] = 0
vocab["<unk>"] = 1

# NEW CODE TO ADD HERE
import pickle
with open('ag_news_vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)
print(f"Vocabulary saved with {len(vocab)} words")
# END OF NEW CODE

def text_pipeline(text):
    return [vocab.get(tok, vocab["<unk>"]) for tok in tokenize(text)]

def label_pipeline(label):
    return label

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Train: (120000, 2) | Test: (7600, 2)
Vocabulary saved with 158735 words


In [3]:
class AGNewsDataset(Dataset):
    def __init__(self, df):
        self.texts = df['text'].tolist()
        self.labels = df['label'].tolist()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = torch.tensor(text_pipeline(self.texts[idx]), dtype=torch.long)
        label = torch.tensor(label_pipeline(self.labels[idx]), dtype=torch.long)
        return text, label

def collate_batch(batch):
    texts, labels = zip(*batch)
    texts = pad_sequence(texts, batch_first=True, padding_value=vocab["<pad>"])
    return texts.to(device), torch.tensor(labels, dtype=torch.long).to(device)

train_loader = DataLoader(AGNewsDataset(train_df), batch_size=64, shuffle=True, collate_fn=collate_batch)
test_loader = DataLoader(AGNewsDataset(test_df), batch_size=64, collate_fn=collate_batch)

In [4]:
class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab["<pad>"])
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        emb = self.embedding(x)
        pooled = emb.mean(dim=1)
        return self.fc(pooled)


In [5]:
def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss, all_preds, all_labels = 0, [], []

    for x, y in loader:
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * y.size(0)
        all_preds.extend(logits.argmax(1).tolist())
        all_labels.extend(y.tolist())

    return total_loss / len(loader.dataset), accuracy_score(all_labels, all_preds)

def evaluate(model, loader, criterion):
    model.eval()
    total_loss, all_preds, all_labels = 0, [], []

    with torch.no_grad():
        for x, y in loader:
            logits = model(x)
            loss = criterion(logits, y)
            total_loss += loss.item() * y.size(0)
            all_preds.extend(logits.argmax(1).tolist())
            all_labels.extend(y.tolist())

    return total_loss / len(loader.dataset), accuracy_score(all_labels, all_preds)

In [8]:
teacher = TextClassifier(len(vocab), embed_dim=128, num_classes=4).to(device)
opt_teacher = optim.Adam(teacher.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(10):
    train_loss, train_acc = train_one_epoch(teacher, train_loader, opt_teacher, loss_fn)
    val_loss, val_acc = evaluate(teacher, test_loader, loss_fn)
    print(f"[Epoch {epoch+1}] Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

[Epoch 1] Train Acc: 0.8093 | Val Acc: 0.8911
[Epoch 2] Train Acc: 0.9157 | Val Acc: 0.9057
[Epoch 3] Train Acc: 0.9363 | Val Acc: 0.9130
[Epoch 4] Train Acc: 0.9507 | Val Acc: 0.9141
[Epoch 5] Train Acc: 0.9610 | Val Acc: 0.9149
[Epoch 6] Train Acc: 0.9697 | Val Acc: 0.9166
[Epoch 7] Train Acc: 0.9758 | Val Acc: 0.9136
[Epoch 8] Train Acc: 0.9812 | Val Acc: 0.9134
[Epoch 9] Train Acc: 0.9853 | Val Acc: 0.9111
[Epoch 10] Train Acc: 0.9885 | Val Acc: 0.9105


In [9]:
def get_soft_labels(model, loader, temp=3.0):
    model.eval()
    soft_outputs = []

    with torch.no_grad():
        for x, _ in loader:
            logits = model(x)
            softened = F.log_softmax(logits / temp, dim=1)
            soft_outputs.append(softened)

    return torch.cat(soft_outputs, dim=0)

soft_labels = get_soft_labels(teacher, train_loader)

In [11]:
student = TextClassifier(len(vocab), embed_dim=64, num_classes=4).to(device)
opt_student = optim.Adam(student.parameters(), lr=1e-3)
kl_div = nn.KLDivLoss(reduction="batchmean")
T = 3.0
alpha = 0.7

def train_student_kd(model, loader, soft_logits):
    model.train()
    all_preds, all_labels = [], []
    total_loss = 0
    idx = 0

    for x, y in loader:
        opt_student.zero_grad()
        out = model(x)

        soft = soft_logits[idx:idx + y.size(0)]
        idx += y.size(0)

        loss = alpha * kl_div(F.log_softmax(out / T, dim=1), soft.exp()) + (1 - alpha) * loss_fn(out, y)
        loss.backward()
        opt_student.step()

        total_loss += loss.item() * y.size(0)
        all_preds.extend(out.argmax(1).tolist())
        all_labels.extend(y.tolist())

    return total_loss / len(loader.dataset), accuracy_score(all_labels, all_preds)

for epoch in range(10):
    train_loss, train_acc = train_student_kd(student, train_loader, soft_labels)
    val_loss, val_acc = evaluate(student, test_loader, loss_fn)
    print(f"[Epoch {epoch+1}] Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

[Epoch 1] Train Acc: 0.7637 | Val Acc: 0.8693
[Epoch 2] Train Acc: 0.8952 | Val Acc: 0.8891
[Epoch 3] Train Acc: 0.9171 | Val Acc: 0.8997
[Epoch 4] Train Acc: 0.9289 | Val Acc: 0.9058
[Epoch 5] Train Acc: 0.9372 | Val Acc: 0.9101
[Epoch 6] Train Acc: 0.9436 | Val Acc: 0.9104
[Epoch 7] Train Acc: 0.9492 | Val Acc: 0.9134
[Epoch 8] Train Acc: 0.9538 | Val Acc: 0.9142
[Epoch 9] Train Acc: 0.9581 | Val Acc: 0.9141
[Epoch 10] Train Acc: 0.9611 | Val Acc: 0.9149


In [12]:
torch.save(student.state_dict(), "AG_SafeStudent.pt")


In [13]:
from sklearn.metrics import classification_report

# Evaluate on test set
student.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for x, y in test_loader:
        logits = student(x)
        preds = logits.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(y.cpu().numpy())

# Classification Report
target_names = ['World', 'Sports', 'Business', 'Sci/Tech']
print("\nClassification Report for Student Model:")
print(classification_report(all_labels, all_preds, target_names=target_names))



Classification Report for Student Model:
              precision    recall  f1-score   support

       World       0.92      0.91      0.92      1900
      Sports       0.96      0.98      0.97      1900
    Business       0.90      0.86      0.88      1900
    Sci/Tech       0.88      0.90      0.89      1900

    accuracy                           0.91      7600
   macro avg       0.91      0.91      0.91      7600
weighted avg       0.91      0.91      0.91      7600

