<a href="https://colab.research.google.com/github/goelnikhils-lgtm/languagemodels/blob/main/Student_Teacher_Knowledge_Distillation_Sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Code applicable to knowledge distillation.
# Disclaimer: this code is reproduced by me. The original code can be found by doing a Google search on "knowledge distillation":
# https://wandb.ai/byyoung3/ML_NEWS3/reports/Knowledge-distillation-Teaching-LLM-s-with-synthetic-data--Vmlldzo5MTMyMzA2
# Contrastive learning / metric learning uses a contrastive loss. The loss trains the NN by pulling positive pairs closer together and pushing negative pairs farther apart.
# Contrastive losses can be used in supervised learning, where the anchor is the ground truth (GT) and the positive and negative labels are logits.
# Self-supervised learning <-> augment the data to get positive pairs.

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertModel
from sklearn.metrics import accuracy_score

#define teacher
class TeacherModel(nn.Module):
    """Teacher network: a pretrained ``bert-base-uncased`` sequence classifier.

    The wrapped model is asked for its hidden states both at load time and on
    every forward call, so the output always carries ``hidden_states``.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            output_hidden_states=True,
        )

    def forward(self, input_ids, attention_mask):
        """Run the teacher and return ``(logits, hidden_states)``.

        ``hidden_states`` is passed through exactly as exposed by the wrapped
        model's output object (presumably one entry per layer -- confirm
        against the transformers documentation before indexing into it).
        """
        result = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        teacher_logits = result.logits
        all_hidden_states = result.hidden_states
        return teacher_logits, all_hidden_states

#a smaller transformer model
class StudentModel(nn.Module):
    """Small transformer-encoder student for sequence classification.

    The student embeds the token ids (plus learned position embeddings), runs
    them through a shallow ``nn.TransformerEncoder``, mean-pools the
    non-padded positions, and produces:

    * classification logits, and
    * a sequence representation projected to the teacher's hidden size, so it
      can be compared with teacher representations in a contrastive loss.

    Args:
        num_labels: Number of output classes.
        vocab_size: Size of the token vocabulary. The default matches the
            ``bert-base-uncased`` tokenizer used by the teacher.
        hidden_size: Width of the student's embeddings / encoder.
        num_heads: Attention heads per encoder layer (must divide
            ``hidden_size``).
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Width of the feed-forward sub-layer.
        max_position_embeddings: Longest supported sequence length.
        teacher_hidden_size: Size the pooled representation is projected to.
    """

    def __init__(
        self,
        num_labels=2,
        vocab_size=30522,
        hidden_size=768,
        num_heads=8,
        num_layers=1,
        dim_feedforward=2048,
        max_position_embeddings=512,
        teacher_hidden_size=768,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings

        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True,  # inputs are (batch, seq, hidden)
        )
        # Nested tensors are disabled so the output always keeps the padded
        # (batch, seq, hidden) layout that the pooling below relies on.
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False,
        )

        # nn.TransformerEncoderLayer has no ``.config``; sizes come from the
        # constructor arguments instead.
        self.classifier = nn.Linear(hidden_size, num_labels)
        # Project student hidden state to teacher hidden size.
        self.hidden_state_projector = nn.Linear(hidden_size, teacher_hidden_size)

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return ``(logits, hidden_states)``.

        Args:
            input_ids: Integer tensor of token ids, shape ``(batch, seq)``.
            attention_mask: Tensor of shape ``(batch, seq)`` with non-zero
                entries for real tokens and 0 for padding. ``None`` is
                treated as "no padding".

        Returns:
            A tuple ``(logits, hidden_states)`` where ``logits`` has shape
            ``(batch, num_labels)`` and ``hidden_states`` has shape
            ``(batch, teacher_hidden_size)``.

        Raises:
            ValueError: If the sequence is longer than
                ``max_position_embeddings``.
        """
        seq_len = input_ids.shape[1]
        if seq_len > self.max_position_embeddings:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({self.max_position_embeddings})"
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        positions = torch.arange(seq_len, device=input_ids.device)
        embeddings = self.token_embeddings(input_ids) + self.position_embeddings(positions)

        # True marks padding positions that attention must ignore.
        padding_mask = attention_mask == 0
        encoded = self.encoder(embeddings, src_key_padding_mask=padding_mask)

        # Masked mean pooling: zero out padded positions, then average over
        # the real tokens only (clamp guards against an all-padding row).
        encoded = encoded.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        token_counts = (~padding_mask).sum(dim=1, keepdim=True).clamp(min=1)
        pooled = encoded.sum(dim=1) / token_counts.to(encoded.dtype)

        logits = self.classifier(pooled)
        hidden_states = self.hidden_state_projector(pooled)
        return logits, hidden_states  # logits and hidden_states for contrastive loss

  #define distillation loss
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss: soft-target KL term plus hard-label CE.

    ``total = alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE``

    where ``*_T`` are the softmax distributions of the logits divided by the
    temperature ``T``.

    Args:
        alpha: Weight of the distillation (KL) term; ``1 - alpha`` weights the
            cross-entropy against the ground-truth labels.
        temperature: Softening temperature applied to both sets of logits.
    """

    def __init__(self, alpha = 0.5,temperature=2.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.kl_divergence = nn.KLDivLoss(reduction='batchmean')
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """Return the combined distillation + cross-entropy loss (a scalar).

        Args:
            student_logits: Student outputs, classes along ``dim=1``.
            teacher_logits: Teacher outputs, same shape as ``student_logits``.
            labels: Ground-truth class indices for the cross-entropy term.
        """
        # Soft targets from the teacher model (probabilities).
        soft_teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        # nn.KLDivLoss expects its *input* as log-probabilities, so the student
        # side must go through log_softmax rather than softmax.
        soft_student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
        # The T^2 factor keeps the soft-target term on the same scale as the
        # cross-entropy term when the temperature changes.
        distillation_loss = (
            self.kl_divergence(soft_student_log_probs, soft_teacher_probs)
            * (self.temperature ** 2)
        )
        # Cross-entropy for the student on the ground truth; the ground truth
        # is the same for the student and the teacher.
        ce_loss = self.cross_entropy(student_logits, labels)
        total_loss = self.alpha * distillation_loss + (1 - self.alpha) * ce_loss
        return total_loss

class ContrastiveLoss(nn.Module):
    """In-batch contrastive loss between student and teacher representations.

    Row ``i`` of the student batch is treated as matching row ``i`` of the
    teacher batch (the positive pair); every other teacher row in the batch
    acts as a negative. Cosine similarities, scaled by ``1 / temperature``,
    are fed to a cross-entropy whose target for row ``i`` is class ``i``.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, student_reps, teacher_reps):
        """Return the scalar contrastive loss for one batch of paired rows."""
        # Unit-normalise along dim 1 so dot products are cosine similarities.
        unit_student = F.normalize(student_reps, dim=1)
        unit_teacher = F.normalize(teacher_reps, dim=1)

        # Entry (i, j) scores student row i against teacher row j.
        scores = torch.matmul(unit_student, torch.transpose(unit_teacher, 0, 1))
        scores = scores / self.temperature

        # Positive pairs are (student_rep_i, teacher_rep_i): the diagonal.
        num_rows = scores.shape[0]
        targets = torch.arange(num_rows).to(scores.device)
        return self.cross_entropy(scores, targets)
 #training loop
def _to_sequence_representation(hidden):
    """Reduce a model's hidden-state output to one vector per example.

    Accepts either a single tensor or a tuple/list of per-layer tensors (in
    which case the last entry is used). A 3-D ``(batch, seq, hidden)`` tensor
    is reduced to ``(batch, hidden)`` by taking the first token's vector;
    a 2-D tensor is returned unchanged.
    """
    if isinstance(hidden, (tuple, list)):
        hidden = hidden[-1]
    if hidden.dim() == 3:
        hidden = hidden[:, 0, :]
    return hidden


def train_distilled_model(teacher_model,student_model,tokenizer,dataloader,epochs=5,lr=1e-4):
    """Train ``student_model`` to mimic ``teacher_model``.

    Each step combines a distillation loss (soft teacher targets + hard
    labels) with a contrastive loss that aligns student and teacher sequence
    representations. Only the student's parameters are optimised.

    Args:
        teacher_model: Callable as ``teacher_model(input_ids, attention_mask)``
            returning a tuple whose first item is the logits and whose last
            item is the hidden state(s).
        student_model: Callable the same way, returning
            ``(logits, hidden_state)``.
        tokenizer: Unused here; kept for interface compatibility.
        dataloader: Iterable of batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` entries.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
    """
    distillation_criterion = DistillationLoss(alpha=0.7, temperature=2.0)
    contrastive_criterion = ContrastiveLoss(temperature=0.1)
    optimizer = optim.Adam(student_model.parameters(), lr=lr)

    # Teacher is frozen in inference mode; the student must be in train mode
    # so layers such as dropout behave as intended during optimisation.
    teacher_model.eval()
    student_model.train()

    for epoch in range(epochs):
        for batch in dataloader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']

            # Teacher forward pass (no gradients needed).
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids, attention_mask)
                teacher_logits = teacher_outputs[0]
                # The teacher's last output may be a tuple of per-layer
                # (batch, seq, hidden) tensors; reduce it to one vector per
                # example before it reaches the contrastive loss.
                teacher_hidden_state = _to_sequence_representation(teacher_outputs[-1])

            # Student forward pass.
            student_logits, student_hidden_state = student_model(input_ids, attention_mask)
            student_hidden_state = _to_sequence_representation(student_hidden_state)

            distillation_loss = distillation_criterion(student_logits, teacher_logits, labels)
            contrastive_loss = contrastive_criterion(student_hidden_state, teacher_hidden_state)
            total_loss = distillation_loss + contrastive_loss

            # Backprop and parameter update.
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Log the loss for this batch.
            print(f"Epoch {epoch+1}/{epochs} - Total Loss: {total_loss.item()}")


#teacher_model = TeacherModel()
#student_model = StudentModel()
#tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#train_distilled_model(teacher_model,student_model,tokenizer,dataloader)