<a href="https://colab.research.google.com/github/gupta799/LLMFinetuning/blob/main/Knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers datasets



In [2]:
import torch
from torch import nn, optim
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Hyperparameters
alpha = 0.5
beta = 0.5
temperature = 2.0
batch_size = 16
learning_rate = 5e-5
num_epochs = 3

In [5]:
teacher_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
teacher_model.eval()
student_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
student_model.train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [6]:
# Move models to device
teacher_model.to(device)
student_model.to(device)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [7]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")



In [8]:
dataset = load_dataset("glue", "sst2")
def tokenize_function(example):
    return tokenizer(example['sentence'], padding='max_length', truncation=True, max_length=128)

In [9]:
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

# Set the format to return PyTorch tensors
tokenized_datasets.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

In [10]:
train_dataset = tokenized_datasets['train']
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [11]:
ce_loss = nn.CrossEntropyLoss()
kl_div_loss = nn.KLDivLoss(reduction='batchmean')

In [12]:
optimizer = optim.AdamW(student_model.parameters(), lr=learning_rate)


In [14]:
for epoch in range(num_epochs):
    total_loss = 0
    student_model.train()  # Ensure student is in training mode
    for batch in train_loader:
        # Prepare inputs
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device)
        }
        # Add 'token_type_ids' if they exist
        if 'token_type_ids' in batch and batch['token_type_ids'] is not None:
            inputs['token_type_ids'] = batch['token_type_ids'].to(device)

        labels = batch['labels'].to(device)

        # Forward pass through teacher
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits

        # Forward pass through student
        student_outputs = student_model(**inputs)
        student_logits = student_outputs.logits

        # Compute losses
        loss_ce = ce_loss(student_logits, labels)
        loss_kl = kl_div_loss(
            nn.functional.log_softmax(student_logits / temperature, dim=-1),
            nn.functional.softmax(teacher_logits / temperature, dim=-1)
        ) * (temperature ** 2)

        loss = alpha * loss_ce + beta * loss_kl

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()


    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

KeyboardInterrupt: 