In [5]:
from teacher import ResNet50
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import os
from torch import optim

In [2]:

# Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 64
learning_rate = 3e-4  # Reduced learning rate for better convergence
epochs = 5  # Increased epochs

# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),  # Normalize to [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # Additional normalization
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


device="cuda" if torch.cuda.is_available() else "cpu"
model=ResNet50(num_classes=10).to(device)



In [3]:
test_dataset=datasets.MNIST(root='./data',train=False,transform=transform,download=True)
test_loader=DataLoader(test_dataset,batch_size=64,shuffle=True)

In [6]:
from tqdm import tqdm
optimizer=optim.AdamW(model.parameters(),lr=learning_rate)
criterian=torch.nn.CrossEntropyLoss()
# Training loop
model.train()
for epoch in range(epochs):
    correct = 0
    total = 0
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
    
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterian(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # Update running loss
        running_loss += loss.item()
        progress_bar.set_postfix(loss=running_loss / (total / batch_size), accuracy=100 * correct / total)
    
    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

                                                                                       

Epoch 1/5, Loss: 0.1248, Accuracy: 96.14%


                                                                                        

Epoch 2/5, Loss: 0.0509, Accuracy: 98.42%


                                                                                        

Epoch 3/5, Loss: 0.0439, Accuracy: 98.69%


                                                                                        

Epoch 4/5, Loss: 0.0349, Accuracy: 98.86%


                                                                                        

Epoch 5/5, Loss: 0.0314, Accuracy: 99.00%




In [8]:
teacher_model=model

Student Model

In [7]:


class Student(nn.Module):
    def __init__(self, num_classes=10):
        super(Student, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)  # (B, 32, H, W)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)  # (B, 64, H, W)
        self.pool = nn.MaxPool2d(2, 2)  # downsample by 2
        
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # adjust input shape if input size is different
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # -> (B, 32, H/2, W/2)
        x = self.pool(F.relu(self.conv2(x)))  # -> (B, 64, H/4, W/4)
        x = x.view(x.size(0), -1)  # flatten
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

student_model=Student().to(device)

In [9]:
def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1. - alpha) * hard_loss

In [10]:
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)

In [11]:
# Freeze teacher during distillation
for param in teacher_model.parameters():
    param.requires_grad = False

teacher_model.eval()

ResNet50(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1

In [12]:
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

In [22]:
epochs = 5
T = 1.0
alpha = 0.4
for epoch in range(epochs):
    student_model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            teacher_outputs = teacher_model(images)

        student_outputs = student_model(images)
        loss = distillation_loss(student_outputs, teacher_outputs, labels, T, alpha)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Accuracy Calculation
        _, predicted = torch.max(student_outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    epoch_loss = total_loss / len(train_loader)
    accuracy = 100.0 * correct / total
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

Epoch 1/5, Loss: 0.6025, Accuracy: 98.33%
Epoch 2/5, Loss: 0.6011, Accuracy: 98.32%
Epoch 3/5, Loss: 0.6010, Accuracy: 98.35%
Epoch 4/5, Loss: 0.6004, Accuracy: 98.36%
Epoch 5/5, Loss: 0.6001, Accuracy: 98.41%
