# Knowledge Distillation
In this jupyter notebook, I implemented one approach to Knowledge Distillation (KD). Be aware that there are multiple ways to approach this topic and there is not ONE approach to KD.

In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from tqdm import tqdm
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
import time
from datasets import load_dataset
import os

## Knowledge Distillation Loss
This function calculates the knowledge distillation loss by computing the Kullback-Leibler divergence between the softened probability distributions of new and old logits, adjusted by a specified temperature.

In [10]:
def knowledge_distillation_loss(
    new_logits: torch.Tensor,
    old_logits: torch.Tensor,
    temperature: float = 1.0
) -> torch.Tensor:
    """
    Compute KL-div between old and new logits (softened by 'temperature').
    
    new_logits, old_logits: (batch_size, seq_len, vocab_size)
    """
    # Reshape to (batch_size * seq_len, vocab_size)
    new_logits = new_logits.view(-1, new_logits.size(-1)) / temperature
    old_logits = old_logits.view(-1, old_logits.size(-1)) / temperature

    # KL divergence: D_KL( old || new ) — or new vs old, depending on which distribution you want to match
    kd_loss = F.kl_div(
        F.log_softmax(new_logits, dim=-1),
        F.softmax(old_logits, dim=-1),
        reduction="batchmean"
    ) * (temperature ** 2)

    return kd_loss

## Load the models

Note that in my example, I use the same model for teacher and student. This will not lead to any outcome, as they are exactly the same. Adjust to the bigger Qwen2.5-7b for example, if you want to see a difference.

In [13]:
device = "mps"  # or "cuda", "cpu"

MODEL_NAME = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NEW: "Old" (teacher) model for distillation – a frozen copy of the pretrained model
old_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
old_model.eval()
for param in old_model.parameters():
    param.requires_grad = False

# "New" model to be fine-tuned
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

optimizer = AdamW(
    model.parameters(),
    lr=5e-6,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0
)

model.train()
optimizer.zero_grad()

## Generate Dataset

Load the dataset from the excel spreadsheet using pandas.

In [None]:
class QADataset(Dataset):
    def __init__(self, tokenized_data):
        self.data = tokenized_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def tokenize_data(dataframe, tokenizer, max_seq_length: int, device: str):
    tokenized_data = []
    for _, row in dataframe.iterrows():
        question = row['Question']
        answer = row['Answer']
        try:

            # Tokenize the question
            inputs = tokenizer(
                question,
                padding='max_length',
                truncation=True,
                max_length=max_seq_length,
                return_tensors="pt"
            )

            # Tokenize the answer as labels
            labels = tokenizer(
                f"{question}: {answer}",
                padding='max_length',
                truncation=True,
                max_length=max_seq_length,
                return_tensors="pt"
            )['input_ids'].squeeze()

            tokenized_data.append({
                'input_ids': inputs['input_ids'].squeeze().to(device),
                'attention_mask': inputs['attention_mask'].squeeze().to(device),
                'labels': labels.to(device)
            })
        except Exception as e:
            print(f"Error tokenizing row: {e}")
            continue

    return tokenized_data

print("Loading dataset...")
dataframe = pd.read_csv('wiki_qa_by_headline.csv')

max_seq_length = 512

tokenized_data = tokenize_data(dataframe, tokenizer, max_seq_length, device)
qa_dataset = QADataset(tokenized_data)
train_loader = DataLoader(qa_dataset, batch_size=2, shuffle=True)

In [15]:
epochs = 4
alpha = 0.2        # Weight on distillation loss
temperature = 2.0  # Soften predictions to help KD
logging_steps = 100
gradient_steps = 4
total_steps = len(train_loader) * epochs

total_steps = len(train_loader) * epochs
warmup_steps = 200

scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

## The actual training loop

In [16]:
exp_name = "Supervised Knowledge Distilliation"
timestamp = time.time()
writer = SummaryWriter(log_dir=f"./tensorboard_logs/{exp_name}_{timestamp}")
global_step = 0

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    epoch_loss = 0
    progress_bar = tqdm(train_loader, desc="Training")

    for step, batch in enumerate(progress_bar):
        batch = {key: val.to(device) for key, val in batch.items()}
        
        # Forward pass of the new (student) model
        outputs = model(**batch)
        task_loss = outputs.loss  # Standard cross-entropy loss on new data

        # Forward pass of the old (teacher) model (frozen)
        with torch.no_grad():
            old_outputs = old_model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask']
            )

        # Knowledge Distillation loss
        dist_loss = knowledge_distillation_loss(
            new_logits=outputs.logits,
            old_logits=old_outputs.logits,
            temperature=temperature
        )

        # Combine the losses
        loss = task_loss + alpha * dist_loss

        loss.backward()
        writer.add_scalar('Loss/Total', loss.item(), global_step)
        writer.add_scalar('Loss/Task', task_loss.item(), global_step)
        writer.add_scalar('Loss/KD', dist_loss.item(), global_step)
        global_step += 1
        scheduler.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

        if (step % gradient_steps == 0 and step > 0) or step == len(train_loader) - 1:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            optimizer.step()
            optimizer.zero_grad()

    print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(train_loader)}")

Epoch 1/4


Training: 100%|██████████| 174/174 [02:45<00:00,  1.05it/s, loss=1.48]


Epoch 1 Loss: 5.484623649339566
Epoch 2/4


Training: 100%|██████████| 174/174 [02:55<00:00,  1.01s/it, loss=1.25] 


Epoch 2 Loss: 1.596250693345892
Epoch 3/4


Training: 100%|██████████| 174/174 [02:52<00:00,  1.01it/s, loss=1.1]  


Epoch 3 Loss: 1.4266898330600781
Epoch 4/4


Training: 100%|██████████| 174/174 [02:38<00:00,  1.10it/s, loss=1.37] 

Epoch 4 Loss: 1.3529919416740024





## Saving the model

In [17]:
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")

print("Model fine-tuned and saved to ./fine_tuned_model")

Model fine-tuned and saved to ./fine_tuned_model
