# IMDb Knowledge Distillation Notebook
This notebook fine-tunes a BERT teacher and trains a small student model via knowledge distillation.

In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

In [4]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.3


In [5]:
# Cell 1: Imports & warnings (updated)
import warnings
warnings.filterwarnings(
    "ignore",
    message="The secret `HF_TOKEN` does not exist in your Colab secrets"
)

import os
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW           # ← moved here
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_scheduler                   # ← removed AdamW from here
)
from datasets import load_dataset
from evaluate import load as load_metric
from tqdm.auto import tqdm

In [6]:
def load_data(num_train=20000, num_test=5000, max_length=256, batch_size=16):
    raw = load_dataset("imdb")
    raw_train = raw["train"].shuffle(seed=42).select(range(num_train))
    raw_test  = raw["test"].shuffle(seed=42).select(range(num_test))

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    def preprocess(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )

    train_ds = raw_train.map(preprocess, batched=True)
    test_ds  = raw_test.map(preprocess, batched=True)

    train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    test_ds.set_format(type="torch",  columns=["input_ids", "attention_mask", "label"])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size)
    return train_loader, test_loader, tokenizer

In [7]:
def train_teacher(model, train_loader, device, lr=2e-5, epochs=1):
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = epochs * len(train_loader)
    scheduler = get_scheduler(
        "linear", optimizer=optimizer,
        num_warmup_steps=0, num_training_steps=total_steps
    )

    model.to(device).train()
    for epoch in range(epochs):
        loop = tqdm(train_loader, desc=f"[Teacher] Epoch {epoch+1}", leave=False)
        for batch in loop:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch).loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            loop.set_postfix(loss=loss.item())
    return model

In [16]:
class StudentModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=128, num_labels=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=4,
            dim_feedforward=hidden_dim*2, dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, input_ids, attention_mask=None):
        x = self.embed(input_ids)
        x = x.permute(1, 0, 2)
        mask = ~attention_mask.bool()
        x = self.transformer(x, src_key_padding_mask=mask)
        x = x.mean(dim=0)
        logits = self.classifier(x)  # Assign output to logits
        return logits # Return logits instead of just the tensor

In [17]:
def train_student(student, teacher, train_loader, device,
                  lr=5e-4, epochs=1, alpha=0.5, temperature=2.0):
    ce_loss = nn.CrossEntropyLoss()
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    optimizer = AdamW(student.parameters(), lr=lr)

    student.to(device).train()
    teacher.to(device).eval()

    for epoch in range(epochs):
        loop = tqdm(train_loader, desc="[Student Distill]", leave=False)
        for batch in loop:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            with torch.no_grad():
                t_logits = teacher(
                    input_ids, attention_mask=attention_mask
                ).logits / temperature
                t_soft = torch.softmax(t_logits, dim=-1)

            s_logits = student(input_ids, attention_mask)
            s_logits_temp = s_logits / temperature

            loss_hard = ce_loss(s_logits, labels)
            loss_soft = kl_loss(
                torch.log_softmax(s_logits_temp, dim=-1),
                t_soft
            ) * (temperature ** 2)

            loss = alpha * loss_hard + (1 - alpha) * loss_soft
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loop.set_postfix(distill_loss=loss.item())
    return student

In [18]:
def train_teacher(model, train_loader, device, lr=2e-5, epochs=1):
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = epochs * len(train_loader)
    scheduler = get_scheduler(
        "linear", optimizer=optimizer,
        num_warmup_steps=0, num_training_steps=total_steps
    )

    model.to(device).train()
    for epoch in range(epochs):
        loop = tqdm(train_loader, desc=f"[Teacher] Epoch {epoch+1}", leave=False)
        for batch in loop:
            batch = {k: v.to(device) for k, v in batch.items()}
            # Separate labels from inputs
            labels = batch.pop("label")
            # Pass only the expected inputs to the model
            loss = model(**batch, labels=labels).loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            loop.set_postfix(loss=loss.item())
    return model

# Main orchestration
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_loader, test_loader, tokenizer = load_data()

    teacher = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )
    teacher = train_teacher(teacher, train_loader, device)

    student = StudentModel(vocab_size=tokenizer.vocab_size)
    student = train_student(student, teacher, train_loader, device)

    teacher_acc = evaluate_model(teacher, test_loader, device)
    student_acc = evaluate_model(student, test_loader, device)
    retention = student_acc / teacher_acc * 100 if teacher_acc > 0 else 0

    print(f"Teacher Accuracy : {teacher_acc:.4f}")
    print(f"Student Accuracy : {student_acc:.4f}")
    print(f"Retention        : {retention:.1f}% of teacher")

    os.makedirs("models", exist_ok=True)
    teacher.save_pretrained("models/teacher_bert")
    torch.save(student.state_dict(), "models/student.pt")

Using device: cuda


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Teacher] Epoch 1:   0%|          | 0/1250 [00:00<?, ?it/s]



[Student Distill]:   0%|          | 0/1250 [00:00<?, ?it/s]

AttributeError: 'Tensor' object has no attribute 'logits'

In [19]:
def evaluate_model(model, data_loader, device):
    metric = load_metric("accuracy")
    model.to(device).eval()

    with torch.no_grad():
        for batch in data_loader:
            # move data
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            # get outputs: may be a ModelOutput or a Tensor
            try:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            except TypeError:
                # fallback for student: positional args
                outputs = model(input_ids, attention_mask)

            # extract logits
            if hasattr(outputs, "logits"):
                logits = outputs.logits
            else:
                logits = outputs

            preds = torch.argmax(logits, dim=-1)
            metric.add_batch(
                predictions=preds.cpu(),
                references=labels.cpu()
            )

    return metric.compute()["accuracy"]


In [20]:
# after training both...
teacher_acc = evaluate_model(teacher, test_loader, device)
student_acc = evaluate_model(student, test_loader, device)

print(f"Teacher Accuracy : {teacher_acc:.4f}")
print(f"Student Accuracy : {student_acc:.4f}")


Teacher Accuracy : 0.9104
Student Accuracy : 0.8284
