# Homework 3: Knowledge Distillation for AI Dermatologist

## CS 4774 Machine Learning - University of Virginia

In this notebook, you'll implement knowledge distillation to improve your skin disease classifier by learning from **MedSigLIP** (from Google), a powerful medical imaging model.

**Key Requirements:**
- Student model must be < **25 MB** on disk
- Use MedSigLIP as frozen teacher model (inference only)
- Implement temperature-scaled knowledge distillation following Hinton et al. (2015)

**Recommended Starting Point:** Use ShuffleNetV2 for your student model (~5 MB)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
import os
import requests
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score

device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
print(f'Using device: {device}')

Using device: mps


## Step 1: Load Data (Same as HW1)

In [2]:
# Define dataset class
class SkinDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.image_paths = []
        self.labels = []
        valid_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.jfif')
        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            for fname in os.listdir(cls_dir):
                if fname.lower().endswith(valid_exts):
                    self.image_paths.append(os.path.join(cls_dir, fname))
                    self.labels.append(self.class_to_idx[cls_name])
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Load data with image size
# Training transform (Do not change)
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation transform (Do not change)
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = SkinDataset('training_dataset/train_dataset', transform=train_transform)
print(f'Dataset loaded with {len(dataset)} images and {len(dataset.classes)} classes')

Dataset loaded with 10000 images and 10 classes


## Step 2: Load Teacher Model (MedSigLIP from Google)

**Important:** Load the pre-trained MedSigLIP model for inference only. Do NOT fine-tune it.

In [3]:
# Load MedSigLIP teacher model
from transformers import AutoModel, AutoProcessor
from huggingface_hub import login, HfFolder

print("=" * 70)
print("IMPORTANT: Before running this cell, you must:")
print("1. Go to https://huggingface.co/google/medsiglip-448")
print("2. Click 'Request Access' and wait for approval (usually instant)")
print("3. Get your HuggingFace token from https://huggingface.co/settings/tokens")
print("=" * 70)

# Login to HuggingFace - this will prompt you to enter your token
login()

# Verify login
token = HfFolder.get_token()
if token:
    print("‚úÖ Successfully logged in to HuggingFace!")
else:
    print("‚ùå Login failed. Please try again.")
    raise ValueError("HuggingFace authentication required")


def load_teacher_model():
    """Load MedSigLIP-448 from HuggingFace."""

    print("\nLoading MedSigLIP-448 teacher model...")
    model_name = "google/medsiglip-448"
    
    # Get token to pass explicitly
    token = HfFolder.get_token()
    
    # Load model and processor with token
    teacher_model = AutoModel.from_pretrained(
        model_name, 
        trust_remote_code=True,
        token=token
    )
    processor = AutoProcessor.from_pretrained(
        model_name, 
        trust_remote_code=True,
        token=token
    )
    
    teacher_model = teacher_model.to(device)
    teacher_model.eval()
    
    # Freeze all parameters
    for param in teacher_model.parameters():
        param.requires_grad = False
    
    print("‚úÖ MedSigLIP loaded successfully!")
    return teacher_model, processor

# Load teacher
teacher_model, teacher_processor = load_teacher_model()

# Define student model: ShuffleNetV2 (Recommended, ~5MB)
from torchvision.models import shufflenet_v2_x0_5

def create_student_shufflenet(num_classes=10):
    """Create a ShuffleNetV2 student model (~5 MB)."""
    model = shufflenet_v2_x0_5(pretrained=False)
    # Replace final classifier
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# Create student model
student_model = create_student_shufflenet(num_classes=10).to(device)

print(f'Student model created with {sum(p.numel() for p in student_model.parameters()):,} parameters')

IMPORTANT: Before running this cell, you must:
1. Go to https://huggingface.co/google/medsiglip-448
2. Click 'Request Access' and wait for approval (usually instant)
3. Get your HuggingFace token from https://huggingface.co/settings/tokens


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶

‚úÖ Successfully logged in to HuggingFace!

Loading MedSigLIP-448 teacher model...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


‚úÖ MedSigLIP loaded successfully!
Student model created with 352,042 parameters




## Step 3: Define Distillation Loss

Implement the knowledge distillation loss following Hinton et al. (2015):
- **Hard loss**: Cross-entropy with ground truth labels
- **Soft loss**: KL divergence between teacher and student soft predictions
- **Temperature scaling**: Soften distributions for better knowledge transfer

In [4]:
class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.3):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        # Initialize cross-entropy loss for hard targets
        self.ce_loss = nn.CrossEntropyLoss()
    
    def forward(self, student_logits, teacher_logits, labels):
        # Hard loss: standard cross-entropy with ground truth labels
        hard_loss = self.ce_loss(student_logits, labels)
        
        # Soft loss: KL divergence between teacher and student soft predictions
        # Temperature scaling softens the distributions for better knowledge transfer
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.temperature ** 2)
        
        # Combine hard and soft losses
        # alpha controls the balance between hard labels and soft teacher predictions
        total_loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
        
        return total_loss, hard_loss, soft_loss

# Create an instance of DistillationLoss
distillation_loss = DistillationLoss(temperature=4.0, alpha=0.3)
print("‚úÖ Distillation loss initialized with temperature=4.0, alpha=0.3")

‚úÖ Distillation loss initialized with temperature=4.0, alpha=0.3


## Step 4: Train with Knowledge Distillation

Implement training loop that:
1. Gets teacher's soft predictions (with torch.no_grad())
2. Gets student's predictions
3. Computes distillation loss
4. Updates only student model parameters

In [5]:
# Prepare data loaders
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

# Setup training
optimizer = optim.Adam(student_model.parameters(), lr=1e-3)
criterion = distillation_loss

# Initialize teacher's classifier head for generating logits
# Note: This is a simple approach - the head generates soft targets from MedSigLIP features
print("Initializing teacher classifier head...")

# MedSigLIP expects 448x448 images, so we need a resize transform for teacher
teacher_resize = transforms.Resize((448, 448))

with torch.no_grad():
    # Get sample batch to determine feature dimension (448x448 for MedSigLIP)
    sample_images = torch.randn(1, 3, 448, 448).to(device)
    teacher_features = teacher_model.vision_model(sample_images).pooler_output
    hidden_dim = teacher_features.shape[1]
    
# Create classifier head (fixed random projection for soft targets)
teacher_model.classifier_head = nn.Linear(hidden_dim, 10).to(device)
print(f"‚úÖ Teacher classifier head created: {hidden_dim} -> 10 classes")

# Training function
def train_epoch(student, teacher, teacher_proc, dataloader, criterion, optimizer):
    student.train()
    total_loss = 0
    
    for images, labels in tqdm(dataloader, desc='Training'):
        images, labels = images.to(device), labels.to(device)
        
        # Get teacher predictions (no gradients)
        with torch.no_grad():
            # Resize images to 448x448 for MedSigLIP (expects larger input)
            teacher_images = F.interpolate(images, size=(448, 448), mode='bilinear', align_corners=False)
            # Get MedSigLIP vision embeddings and project to class logits
            teacher_features = teacher.vision_model(teacher_images).pooler_output
            teacher_logits = teacher.classifier_head(teacher_features)
        
        # Get student predictions
        student_logits = student(images)
        
        # Compute distillation loss
        loss, hard_loss, soft_loss = criterion(student_logits, teacher_logits, labels)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

# Validation function
def validate(student, dataloader):
    student.eval()
    all_preds, all_labels = [], []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Validation'):
            images = images.to(device)
            outputs = student(images)
            preds = torch.argmax(outputs, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')
    return accuracy, f1

# Training loop
NUM_EPOCHS = 10
best_f1 = 0

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')
    
    # Train
    train_loss = train_epoch(student_model, teacher_model, teacher_processor, 
                             train_loader, criterion, optimizer)
    
    # Validate
    val_acc, val_f1 = validate(student_model, val_loader)
    
    print(f'Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}')
    
    if val_f1 > best_f1:
        best_f1 = val_f1
        print(f'‚úÖ New best F1: {best_f1:.4f}')

print(f'\nTraining complete! Best F1: {best_f1:.4f}')

Initializing teacher classifier head...
‚úÖ Teacher classifier head created: 1152 -> 10 classes

Epoch 1/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [37:32<00:00,  7.99s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  6.68it/s]


Train Loss: 0.6468 | Val Acc: 0.3470 | Val F1: 0.2195
‚úÖ New best F1: 0.2195

Epoch 2/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [38:48<00:00,  8.26s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  6.65it/s]


Train Loss: 0.6217 | Val Acc: 0.4760 | Val F1: 0.3454
‚úÖ New best F1: 0.3454

Epoch 3/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [4:26:03<00:00, 56.61s/it]     
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  7.78it/s]


Train Loss: 0.6022 | Val Acc: 0.5100 | Val F1: 0.3718
‚úÖ New best F1: 0.3718

Epoch 4/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [34:10<00:00,  7.27s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  7.62it/s]


Train Loss: 0.5868 | Val Acc: 0.5300 | Val F1: 0.3959
‚úÖ New best F1: 0.3959

Epoch 5/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [36:01<00:00,  7.67s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  6.97it/s]


Train Loss: 0.5761 | Val Acc: 0.5560 | Val F1: 0.4546
‚úÖ New best F1: 0.4546

Epoch 6/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [35:50<00:00,  7.63s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  7.12it/s]


Train Loss: 0.5692 | Val Acc: 0.5760 | Val F1: 0.4590
‚úÖ New best F1: 0.4590

Epoch 7/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [37:07<00:00,  7.90s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  7.84it/s]


Train Loss: 0.5617 | Val Acc: 0.5700 | Val F1: 0.4383

Epoch 8/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [34:18<00:00,  7.30s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  7.68it/s]


Train Loss: 0.5582 | Val Acc: 0.5810 | Val F1: 0.4530

Epoch 9/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [33:42<00:00,  7.17s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  7.98it/s]


Train Loss: 0.5524 | Val Acc: 0.5950 | Val F1: 0.4808
‚úÖ New best F1: 0.4808

Epoch 10/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 282/282 [33:45<00:00,  7.18s/it]
Validation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 32/32 [00:04<00:00,  7.79it/s]

Train Loss: 0.5485 | Val Acc: 0.6110 | Val F1: 0.5179
‚úÖ New best F1: 0.5179

Training complete! Best F1: 0.5179





## Step 5: Save and Submit

Save your student model (< 25 MB) and submit to the HW3 leaderboard.

**Important:** Only submit the student model, NOT the teacher!

In [6]:
# Save student model
student_model.eval()
student_model.cpu()
scripted_model = torch.jit.script(student_model)
scripted_model.save('student_model_hw3.pt')

# Check model size
import os
size_mb = os.path.getsize('student_model_hw3.pt') / (1024 * 1024)
print(f'‚úÖ Model saved: student_model_hw3.pt')
print(f'üì¶ Model size: {size_mb:.2f} MB')

if size_mb >= 25.0:
    print('‚ùå WARNING: Model exceeds 25 MB limit!')
else:
    print('‚úÖ Model size is within the 25 MB limit')

# Submit to HW3 leaderboard
def submit_model(token, model_path, server_url='http://hadi.cs.virginia.edu:8000'):
    """Submit model to the HW3 leaderboard."""
    with open(model_path, 'rb') as f:
        files = {'file': f}
        data = {'token': token}
        response = requests.post(f'{server_url}/submit', data=data, files=files)
        resp_json = response.json()
        if 'message' in resp_json:
            print(f"‚úÖ {resp_json['message']}")
        else:
            print(f"‚ùå {resp_json.get('error', 'Unknown error')}")

# Check submission status
def check_status(token, server_url='http://hadi.cs.virginia.edu:8000'):
    """Check your submission status."""
    url = f'{server_url}/submission-status/{token}'
    response = requests.get(url)
    
    if response.status_code == 200:
        attempts = response.json()
        for a in attempts:
            score = f"{a['score']:.4f}" if isinstance(a['score'], (float, int)) else "Pending"
            size = f"{a['model_size']:.2f}" if isinstance(a['model_size'], (float, int)) else "N/A"
            print(f"Attempt {a['attempt']}: Score={score}, Size={size} MB, Status={a['status']}")
    else:
        print(f"Error: {response.status_code}")

# Use your token from registration
my_token = 'your_token_here'

# Uncomment to submit:
# submit_model(my_token, 'student_model_hw3.pt')
# check_status(my_token)

print('\nüéØ View the HW3 leaderboard at: http://hadi.cs.virginia.edu:8000/leaderboard3')

‚úÖ Model saved: student_model_hw3.pt
üì¶ Model size: 1.53 MB
‚úÖ Model size is within the 25 MB limit

üéØ View the HW3 leaderboard at: http://hadi.cs.virginia.edu:8000/leaderboard3
