In [1]:
# Install necessary packages
!pip install torch torchvision torchsummary --quiet

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torchsummary import summary

# Check GPU: pick CUDA when PyTorch reports it as available, otherwise use the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print("Running on:", device)


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Load CIFAR-10 Dataset

In [2]:
# Preprocessing shared by both splits: tensor conversion, then normalization
# with mean 0.5 / std 0.5 (given as single-element tuples).
preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocess_steps)

# Both CIFAR-10 splits use the same root directory, download flag and transform.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, **cifar_kwargs)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)


100%|██████████| 170M/170M [00:02<00:00, 82.5MB/s]


# Define the Teacher (ResNet18 pretrained)

In [3]:
# Teacher: ResNet-18 initialised from the ImageNet checkpoint. The `weights=`
# argument replaces the deprecated `pretrained=True` flag.
teacher = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Swap the classifier for a 10-way head, reading the input width from the
# existing layer instead of hard-coding it.
teacher.fc = nn.Linear(teacher.fc.in_features, 10)  # Adapt for CIFAR-10
teacher = teacher.to(device)

# Optional: Fine-tune teacher here, but for demo we use it as-is.
# NOTE(review): the `fc` head created above is freshly initialised and is not
# trained in this cell, so the teacher's logits come from an untrained
# classifier until it is fine-tuned on CIFAR-10.
teacher.eval()


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 81.5MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# Define the Student Model (Tiny CNN)

In [4]:
class StudentCNN(nn.Module):
    """Small two-convolution CNN used as the distillation student.

    Layout: conv(3->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU ->
    2x2 max-pool -> flatten -> linear(32*8*8 -> 256) -> ReLU -> linear(256 -> num_classes).
    The first linear layer is sized for 32x32 inputs (two 2x2 pools leave an 8x8 map).

    Args:
        num_classes: Width of the output layer. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class logits for a batch of images.

        Raises:
            RuntimeError: If the spatial size of `x` does not produce the
                32*8*8 features expected by `fc1`.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # `view(-1, 32 * 8 * 8)`, this never folds a wrong spatial size into
        # the batch axis; a mismatched input fails loudly at `fc1` instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Build the student, place it on the selected device, and print a per-layer
# summary for a 3x32x32 input.
student = StudentCNN()
student = student.to(device)
summary(student, (3, 32, 32))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             448
         MaxPool2d-2           [-1, 16, 16, 16]               0
            Conv2d-3           [-1, 32, 16, 16]           4,640
         MaxPool2d-4             [-1, 32, 8, 8]               0
            Linear-5                  [-1, 256]         524,544
            Linear-6                   [-1, 10]           2,570
Total params: 532,202
Trainable params: 532,202
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.24
Params size (MB): 2.03
Estimated Total Size (MB): 2.28
----------------------------------------------------------------


# Define Distillation Loss

In [5]:
def distillation_loss(y_student, y_teacher, labels, T=4, alpha=0.7):
    """Weighted sum of a temperature-softened KL term and hard-label cross-entropy.

    Args:
        y_student: Student logits; softmax is taken over dim 1.
        y_teacher: Teacher logits; softmax is taken over dim 1.
        labels: Target class indices passed to `F.cross_entropy`.
        T: Temperature dividing both sets of logits before the softmax.
        alpha: Weight of the KL term; the cross-entropy term gets `1 - alpha`.

    Returns:
        Scalar tensor `alpha * KL * T**2 + (1 - alpha) * CE`.
    """
    soft_student = F.log_softmax(y_student / T, dim=1)
    soft_teacher = F.softmax(y_teacher / T, dim=1)
    # 'batchmean' sums the pointwise KL terms and divides by the batch size;
    # the result is rescaled by T^2.
    soft_term = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)
    hard_term = F.cross_entropy(y_student, labels)
    return alpha * soft_term + (1 - alpha) * hard_term


# Train the Student Using Distillation

In [6]:
# Adam updates only the student's parameters.
optimizer = optim.Adam(student.parameters(), lr=0.001)
print("\nTraining student using distillation...\n")

NUM_EPOCHS = 5  # Keep it low for Colab runtime
for epoch in range(NUM_EPOCHS):
    student.train()
    total_loss = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The teacher forward pass runs without gradient tracking.
        with torch.no_grad():
            teacher_logits = teacher(inputs)

        # Clear gradients left over from the previous step before the new backward pass.
        optimizer.zero_grad()
        student_logits = student(inputs)
        loss = distillation_loss(student_logits, teacher_logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")



Training student using distillation...

Epoch 1, Loss: 0.6335
Epoch 2, Loss: 0.5839
Epoch 3, Loss: 0.5624
Epoch 4, Loss: 0.5486
Epoch 5, Loss: 0.5383


# Evaluate the Student Accuracy

In [7]:
def evaluate(model, loader=None, eval_device=None):
    """Return the top-1 accuracy of `model` over a data loader.

    Args:
        model: Module mapping an input batch to per-class scores (argmax over dim 1).
        loader: Iterable of `(inputs, targets)` batches. Defaults to the
            notebook-level `test_loader`.
        eval_device: Device the batches are moved to before the forward pass.
            Defaults to the notebook-level `device`. Pass `'cpu'` for models
            that live on the CPU (e.g. the dynamically quantized student).

    Returns:
        Fraction of correctly classified samples as a Python float.

    Raises:
        ValueError: If the loader yields no samples (accuracy is undefined).
    """
    if loader is None:
        loader = test_loader
    if eval_device is None:
        eval_device = device

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(eval_device), y.to(eval_device)
            preds = model(x).argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError("evaluate() received no samples; accuracy is undefined")
    return correct / total

# Score the distilled student with `evaluate` and report the result.
student_accuracy = evaluate(student)
print("\nFinal Accuracy of student:", student_accuracy)



Final Accuracy of student: 0.6833


# Apply Quantization

In [8]:
# Move the student to the CPU and put it in eval mode before quantizing.
student = student.cpu().eval()

# Dynamically quantize only the nn.Linear layers, using qint8.
layers_to_quantize = {nn.Linear}
quantized_model = torch.quantization.quantize_dynamic(
    student,
    layers_to_quantize,
    dtype=torch.qint8,
)


# Compare Model Sizes

In [9]:
import os

# Write both state dicts to disk so their file sizes can be compared.
torch.save(student.state_dict(), "student.pth")
torch.save(quantized_model.state_dict(), "student_quantized.pth")

# File sizes are converted from bytes to KB (1 KB = 1024 bytes).
size_before_kb = os.path.getsize("student.pth") / 1024
size_after_kb = os.path.getsize("student_quantized.pth") / 1024
print("\nSize before quantization:", size_before_kb, "KB")
print("Size after quantization:", size_after_kb, "KB")



Size before quantization: 2081.5625 KB
Size after quantization: 539.955078125 KB


In [12]:
import torch.nn.functional as F
import torch.nn as nn
import torch

def entropy(p):
    """Entropy (natural log) of `p`, summed over dim 1.

    A small constant (1e-8) is added inside the logarithm so that zero
    entries do not produce `-inf`.
    """
    return -torch.sum(p * torch.log(p + 1e-8), dim=1)

def dynamic_alpha(student_logits):
    """Per-sample mixing weight: the student's softmax entropy divided by log(num_classes).

    Args:
        student_logits: Student logits; softmax is taken over dim 1.

    Returns:
        Tensor with one value per sample and a trailing singleton dimension,
        on the same device as `student_logits`.
    """
    prob = F.softmax(student_logits, dim=1)
    ent = entropy(prob)
    # Create the normaliser next to `prob` (same device and dtype) instead of
    # forcing it onto CUDA, so this also works on CPU-only machines and for
    # models that have been moved to the CPU.
    max_entropy = torch.log(prob.new_tensor(float(prob.size(1))))
    return (ent / max_entropy).unsqueeze(1)  # shape: [batch_size, 1]

def dynamic_kd_loss(student_logits, teacher_logits, labels, T=4):
    """Distillation loss whose CE/KD balance is chosen per sample by `dynamic_alpha`.

    Args:
        student_logits: Student logits; softmax is taken over dim 1.
        teacher_logits: Teacher logits; softmax is taken over dim 1.
        labels: Target class indices passed to `F.cross_entropy`.
        T: Temperature dividing both sets of logits in the KD term.

    Returns:
        Scalar tensor: the mean over samples of `(1 - alpha) * CE + alpha * KD`.
    """
    # The weight is detached, so no gradient flows through it.
    kd_weight = dynamic_alpha(student_logits).detach()

    # Per-sample hard-label loss, kept as a column to line up with `kd_weight`.
    hard_loss = F.cross_entropy(student_logits, labels, reduction='none').unsqueeze(1)

    # Per-sample KL term: pointwise values summed over dim 1, then scaled by T^2.
    soft_student = F.log_softmax(student_logits / T, dim=1)
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    pointwise_kl = F.kl_div(soft_student, soft_teacher, reduction='none')
    soft_loss = pointwise_kl.sum(dim=1, keepdim=True) * (T * T)

    blended = (1 - kd_weight) * hard_loss + kd_weight * soft_loss
    return blended.mean()
