# Distillation Notebook
This notebook demonstrates knowledge distillation in PyTorch.

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import DistilBertModel, DistilBertTokenizer

# Create dummy dataset
X = torch.randn(100, 20)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=16)

model_name = "distilbert-base-uncased"
model = DistilBertModel.from_pretrained(model_name)
model.eval()

# Define teacher and student models
teacher = nn.Sequential(nn.LSTM(20, 50), nn.ReLU(), nn.Linear(50, 2))
student = nn.Sequential(nn.LSTM(20, 10), nn.ReLU(), nn.Linear(10, 2))

# Assume teacher is pre-trained
with torch.no_grad():
    pass  # placeholder for teacher training

# Distillation training loop
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
temperature = 5.0
alpha = 0.5

for epoch in range(50):
    student.train()
    total_loss = 0
    for data, labels in loader:
        with torch.no_grad():
            teacher_logits = teacher(data)
        student_logits = student(data)

        # Soft targets
        T = temperature
        teacher_probs = F.softmax(teacher_logits / T, dim=1)
        student_log_probs = F.log_softmax(student_logits / T, dim=1)
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T*T)
        task_loss = F.cross_entropy(student_logits, labels)

        loss = alpha * distill_loss + (1 - alpha) * task_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")

TypeError: relu(): argument 'input' (position 1) must be Tensor, not tuple

## Network Parameter Sizes
Compute and display the number of trainable parameters for teacher and student models.

In [None]:
# Compute parameter counts
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher parameters: {teacher_params}")
print(f"Student parameters: {student_params}")

Teacher parameters: 1152
Student parameters: 232


## Simple Tests
Basic runtime checks to verify model outputs and parameter relationships.

In [None]:
# Forward pass shape tests
dummy = torch.randn(1, 20)
t_out = teacher(dummy)
s_out = student(dummy)
assert t_out.shape == (1, 2), f"Unexpected teacher output shape: {t_out.shape}"
assert s_out.shape == (1, 2), f"Unexpected student output shape: {s_out.shape}"

# Compare inference outputs
print("Teacher output:", t_out)
print("Student output:", s_out)

# Ensure outputs are tensors
assert isinstance(t_out, torch.Tensor) and isinstance(s_out, torch.Tensor), "Outputs must be tensors"
print("All simple tests passed and inference comparison printed.")

Teacher output: tensor([[-0.1509, -0.4650]], grad_fn=<AddmmBackward0>)
Student output: tensor([[0.2120, 0.0776]], grad_fn=<AddmmBackward0>)
All simple tests passed and inference comparison printed.
