In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader, Dataset

# Student model (smaller transformer + MLP)
class StudentModel(nn.Module):
    """Compact student network: a transformer encoder stack plus an MLP head.

    Args:
        hidden_size: Feature width used as the encoder's ``d_model`` and as
            the input/output width of both linear layers in the MLP head.
        num_heads: Number of attention heads in each encoder layer.
        num_layers: Number of encoder layers stacked in the encoder.
    """

    def __init__(self, hidden_size, num_heads, num_layers):
        super().__init__()

        # Settings for the prototype encoder layer. The feed-forward width is
        # four times the model width, and batch_first=True makes the encoder
        # take inputs laid out as (batch, seq, feature).
        layer_kwargs = dict(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=4 * hidden_size,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )

        # Projection head applied to the encoder output: Linear -> GELU -> Linear.
        head = [
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*head)

    def forward(self, x):
        """Run ``x`` through the encoder stack and then the MLP head."""
        encoded = self.transformer(x)
        return self.mlp(encoded)

# Distillation trainer
class DistillationTrainer:
    """Train a student network to reproduce a teacher's final hidden states.

    The teacher is a pretrained Hugging Face model kept in evaluation mode;
    the student is a ``StudentModel``. Each training step minimises the mean
    squared error between the student's output and the teacher's
    ``last_hidden_state``.

    Args:
        teacher_model_name: Name or path passed to ``AutoModel.from_pretrained``.
        student_hidden_size: Hidden size of the student. It has to equal the
            width of the teacher's input embeddings and hidden states, because
            the student consumes the former and is compared against the latter.
        student_num_heads: Number of attention heads per student layer.
        student_num_layers: Number of student encoder layers.
        device: Device both models are placed on. Defaults to CUDA when
            available, otherwise CPU.
    """

    def __init__(
        self,
        teacher_model_name,
        student_hidden_size,
        student_num_heads,
        student_num_layers,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.device = device

        # Teacher: pretrained and only ever used for inference.
        self.teacher = AutoModel.from_pretrained(teacher_model_name).to(device)
        self.teacher.eval()  # Disable dropout so the targets are deterministic

        # Student: the only model whose parameters are optimised.
        self.student = StudentModel(
            hidden_size=student_hidden_size,
            num_heads=student_num_heads,
            num_layers=student_num_layers
        ).to(device)

        # Loss function (MSE for simplicity)
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.AdamW(self.student.parameters())

    def train_step(self, batch):
        """Run one optimisation step on the student and return the loss.

        Args:
            batch: Mapping of teacher inputs. It must contain ``'input_ids'``;
                every entry is forwarded to the teacher as a keyword argument.

        Returns:
            The MSE loss of this step as a Python float.

        Raises:
            ValueError: If the student output and the teacher hidden states
                do not have the same shape.
        """
        self.student.train()

        # Move tensors next to the models; non-tensor entries pass through.
        batch = {
            key: value.to(self.device) if torch.is_tensor(value) else value
            for key, value in batch.items()
        }

        with torch.no_grad():
            # Targets: the teacher's final hidden states.
            teacher_outputs = self.teacher(**batch).last_hidden_state
            # The student has no embedding layer of its own and its encoder
            # needs float features, not integer token ids. Reuse the teacher's
            # input-embedding table (kept frozen by no_grad) to embed the ids.
            student_inputs = self.teacher.get_input_embeddings()(batch['input_ids'])

        student_outputs = self.student(student_inputs)

        # Fail loudly instead of letting MSELoss broadcast mismatched shapes.
        if student_outputs.shape != teacher_outputs.shape:
            raise ValueError(
                f"student output shape {tuple(student_outputs.shape)} does not "
                f"match teacher output shape {tuple(teacher_outputs.shape)}"
            )

        loss = self.criterion(student_outputs, teacher_outputs)

        # Backpropagation (only student parameters are in the optimizer)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

# Example usage
def main():
    # Initialize trainer
    trainer = DistillationTrainer(
        teacher_model_name='bert-base-uncased',  # Can be any HF model
        student_hidden_size=768,  # Match teacher's hidden size for simplicity
        student_num_heads=8,
        student_num_layers=2  # Smaller number of layers than teacher
    )
    