In [1]:
import os
import torch
import pandas as pd
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer

In [2]:
# Use a CUDA device when PyTorch reports one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#### Load dataset

In [3]:
def _load_wikitext_split(split):
    """Read one WikiText parquet split and return its ``text`` column.

    Args:
        split: File-name prefix of the split ("train", "validation" or "test").

    Returns:
        The ``text`` column of the parquet file as a NumPy array.
    """
    # The split name is the only part of the path that differs between files.
    path = f"../../data/wikitext/{split}-00000-of-00001.parquet"
    return pd.read_parquet(path)["text"].to_numpy()


train_data = _load_wikitext_split("train")
val_data = _load_wikitext_split("validation")
test_data = _load_wikitext_split("test")

#### Prepare Dataset and Dataloader

In [4]:
# Name of the pretrained checkpoint; later cells pass it to
# AutoTokenizer.from_pretrained and AutoModelForMaskedLM.from_pretrained.
MODEL_PATH = "prajjwal1/bert-tiny"

In [5]:
class DistilDataset(Dataset):
    """Map-style dataset that tokenizes raw sentences lazily, one item at a time.

    Args:
        sentences: Indexable collection of raw text strings.
        tokenizer: Callable invoked as ``tokenizer(text, **options)``.
        max_length: Length every sentence is truncated/padded to.
    """

    def __init__(self, sentences, tokenizer, max_length=128):
        self.sentences, self.tokenizer, self.max_length = (
            sentences,
            tokenizer,
            max_length,
        )

    def __len__(self):
        """Number of sentences available."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return ``(raw_sentence, tokenizer_output)`` for position ``idx``."""
        text = self.sentences[idx]
        # Tokenization options are fixed; only max_length is configurable.
        options = dict(
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        encoding = self.tokenizer(text, **options)
        return text, encoding

In [None]:
# Load the tokenizer from the same checkpoint name (MODEL_PATH) that the teacher
# model is loaded from; the student's vocab size is later read from this tokenizer.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

In [7]:
# Wrap each raw text split in a DistilDataset that shares the one tokenizer.
train_ds, val_ds, test_ds = (
    DistilDataset(split_sentences, tokenizer)
    for split_sentences in (train_data, val_data, test_data)
)

In [8]:
# Batch size shared by all three loaders.
BATCH_SIZE = 32

# Only the training split is shuffled; validation/test keep a fixed order.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Bind the validation loader to its own name (`val_dl`) so the `val_ds` dataset is
# not overwritten and re-running this cell does not wrap a DataLoader in a DataLoader.
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

#### Define model

In [9]:
class TinyBERT(nn.Module):
    """Small transformer-encoder student that maps token ids to vocabulary logits.

    The model is: token embedding -> ``num_layers`` transformer encoder layers ->
    linear projection back to the vocabulary.

    NOTE(review): there is no positional embedding, so the encoder has no notion
    of token order — confirm whether that is intentional.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        hidden_dim: Width of the embeddings and encoder layers.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, vocab_size, hidden_dim=256, num_heads=4, num_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.encoder_layers = nn.ModuleList(
            [
                # batch_first=True: inputs arrive as (batch, seq, dim). With the
                # default (False) the first axis is treated as the sequence, so
                # attention would mix tokens across different samples.
                nn.TransformerEncoderLayer(
                    d_model=hidden_dim, nhead=num_heads, batch_first=True
                )
                for _ in range(num_layers)
            ]
        )
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, attention_mask=None):
        """Compute per-token vocabulary logits.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.
            attention_mask: Optional ``(batch, seq_len)`` tensor, non-zero for
                real tokens and 0 for padding. Padded positions are excluded
                from attention as keys.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)``.
        """
        hidden = self.embedding(x)

        # Encoder layers expect True at positions that must be ignored, which is
        # the inverse of the tokenizer-style mask (1 = keep, 0 = pad).
        padding_mask = None
        if attention_mask is not None:
            padding_mask = attention_mask == 0

        for layer in self.encoder_layers:
            hidden = layer(hidden, src_key_padding_mask=padding_mask)
        return self.linear(hidden)

#### Instantiate Model and Define Loss & Optimizer

In [10]:
def manual_distillation_loss(student_logits, teacher_logits, temperature=2):
    """Hand-written KL(teacher || student) distillation loss.

    Both logit tensors are divided by ``temperature`` before the softmax. The
    summed KL divergence is divided by the size of the first dimension and
    multiplied by ``temperature ** 2``.

    Args:
        student_logits: Student logits; the last dimension is the class axis.
        teacher_logits: Teacher logits with the same shape as ``student_logits``.
        temperature: Softmax temperature applied to both sets of logits.

    Returns:
        Scalar tensor holding the loss.
    """
    student_log_probs = nn.functional.log_softmax(student_logits / temperature, dim=-1)
    # Take the teacher's log-probabilities from log_softmax rather than
    # softmax(...).log(): when a probability underflows to 0 the latter yields
    # -inf, and 0 * -inf turns the whole loss into NaN.
    teacher_log_probs = nn.functional.log_softmax(teacher_logits / temperature, dim=-1)
    teacher_probs = teacher_log_probs.exp()

    kl_sum = torch.sum(teacher_probs * (teacher_log_probs - student_log_probs))
    return kl_sum / student_log_probs.size(0) * (temperature**2)


def distillation_loss(student_logits, teacher_logits, temperature=2):
    """KL(teacher || student) distillation loss built on PyTorch's ``kl_div``.

    Both logit tensors are divided by ``temperature`` before the softmax. The
    result uses ``reduction="batchmean"`` (sum divided by the size of the first
    dimension) and is multiplied by ``temperature ** 2`` so that it agrees with
    ``manual_distillation_loss`` and the loss scale stays comparable across
    temperatures.

    Args:
        student_logits: Student logits; the last dimension is the class axis.
        teacher_logits: Teacher logits with the same shape as ``student_logits``.
        temperature: Softmax temperature applied to both sets of logits.

    Returns:
        Scalar tensor holding the loss.
    """
    # kl_div expects log-probabilities as input and probabilities as target.
    student_log_probs = nn.functional.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = nn.functional.softmax(teacher_logits / temperature, dim=-1)
    kl = nn.functional.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return kl * (temperature**2)

In [None]:
# Teacher: pretrained masked-LM checkpoint, moved to the selected device.
teacher_model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH)
teacher_model = teacher_model.to(device)

# Student: freshly initialised TinyBERT sized to the tokenizer's vocabulary.
student_model = TinyBERT(vocab_size=tokenizer.vocab_size)
student_model = student_model.to(device)

# Only the student's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

#### Start training

In [12]:
# Number of full passes over train_dl performed by the training loop.
num_epochs = 10

In [None]:
# The student is being optimised; the teacher only supplies targets, so it stays
# in eval mode and is run without gradient tracking.
student_model.train()
teacher_model.eval()

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}, Started..")
    total_loss = 0
    for _sentences, batch in train_dl:
        # Flatten whatever leading axes collation produced down to
        # (num_sequences, seq_len); the last axis is the token axis.
        seq_len = batch["input_ids"].size(-1)
        inputs = batch["input_ids"].reshape(-1, seq_len).to(device)
        attention_mask = batch["attention_mask"].reshape(-1, seq_len).to(device)

        with torch.no_grad():
            teacher_logits = teacher_model(inputs, attention_mask=attention_mask).logits

        # Give the student the same padding mask as the teacher. The evaluation
        # cell also calls the student with the mask, so training without it
        # would be a train/eval mismatch.
        student_logits = student_model(inputs, attention_mask=attention_mask)

        # NOTE(review): temperature=0.1 sharpens the teacher distribution instead
        # of softening it (distillation usually uses T >= 1) — confirm intent.
        loss_distillation = distillation_loss(
            student_logits, teacher_logits, temperature=0.1
        )
        # Cross-entropy of the student's logits against the input token ids.
        loss_ce = loss_fn(
            student_logits.view(-1, student_logits.size(-1)), inputs.view(-1)
        )

        loss = loss_distillation + loss_ce

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_dl)}")

#### Run evaluation

In [None]:
# Evaluate the teacher on the test split: cross-entropy and token accuracy of its
# logits against the input token ids.
teacher_model.eval()
total_loss = 0
total_correct = 0
total_count = 0

with torch.no_grad():
    for _sentences, batch in test_dl:
        # Flatten to (num_sequences, seq_len), exactly as the training loop does,
        # so the model receives 2-D token-id tensors.
        seq_len = batch["input_ids"].size(-1)
        inputs = batch["input_ids"].reshape(-1, seq_len).to(device)
        attention_mask = batch["attention_mask"].reshape(-1, seq_len).to(device)

        # Handle logits extraction for both types of models
        outputs = teacher_model(inputs, attention_mask=attention_mask)
        if isinstance(outputs, tuple):
            logits = outputs[0]  # tuple output: take the first element
        elif hasattr(outputs, "logits"):
            logits = outputs.logits  # Hugging Face model
        else:
            logits = outputs  # TinyBERT

        # For loss: Compare logits with true inputs (MLM task)
        loss = loss_fn(logits.view(-1, logits.size(-1)), inputs.view(-1))
        total_loss += loss.item()

        # For accuracy: Check if the highest logits match the true labels
        predictions = torch.argmax(logits, dim=-1)
        total_correct += (predictions == inputs).sum().item()
        total_count += inputs.numel()

# Average over the loader that was actually iterated (test_dl).
avg_loss = total_loss / len(test_dl)
accuracy = total_correct / total_count
print(f"Teacher Model - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

In [None]:
# Evaluate the student on the test split: cross-entropy and token accuracy of its
# logits against the input token ids.
student_model.eval()
total_loss = 0
total_correct = 0
total_count = 0

with torch.no_grad():
    for _sentences, batch in test_dl:
        # Flatten to (num_sequences, seq_len), exactly as the training loop does,
        # so the model receives 2-D token-id tensors.
        seq_len = batch["input_ids"].size(-1)
        inputs = batch["input_ids"].reshape(-1, seq_len).to(device)
        attention_mask = batch["attention_mask"].reshape(-1, seq_len).to(device)

        # Handle logits extraction for both types of models
        outputs = student_model(inputs, attention_mask=attention_mask)
        if isinstance(outputs, tuple):
            logits = outputs[0]  # tuple output: take the first element
        elif hasattr(outputs, "logits"):
            logits = outputs.logits  # Hugging Face model
        else:
            logits = outputs  # TinyBERT

        # For loss: Compare logits with true inputs (MLM task).
        # `loss_fn` is the CrossEntropyLoss created alongside the optimizer.
        loss = loss_fn(logits.view(-1, logits.size(-1)), inputs.view(-1))
        total_loss += loss.item()

        # For accuracy: Check if the highest logits match the true labels
        predictions = torch.argmax(logits, dim=-1)
        total_correct += (predictions == inputs).sum().item()
        total_count += inputs.numel()

# Average over the loader that was actually iterated (test_dl).
avg_loss = total_loss / len(test_dl)
accuracy = total_correct / total_count
print(f"Student Model - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")