In [3]:
# --- Imports (kept together so a fresh kernel can run this cell first) ---
from google.colab import drive

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm

# Mount Google Drive: the teacher checkpoint is read from it and the student
# checkpoint is written to it later in the notebook.
drive.mount('/content/drive')

# Report the accelerator. get_device_name(0) raises when no CUDA device is
# present, so guard it; the rest of the notebook already falls back to CPU.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available; running on CPU")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Tesla T4


In [4]:
# Fix the RNG seed so weight initialisation and DataLoader shuffling are
# repeatable across "Restart & Run All" (torch.manual_seed seeds CPU and CUDA).
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Single place to change where checkpoints live on the mounted Drive.
MODEL_DIR = "/content/drive/MyDrive/AI_MODEL_OPTIMIZATION/models"
baseline_model_path = f"{MODEL_DIR}/mobilenetv2_cifar10_baseline.pth"  # teacher weights (read)
student_model_path = f"{MODEL_DIR}/student_kd.pth"                     # student weights (written)


Using device: cuda


In [5]:
# Preprocessing shared by train and test: upscale the images to 224 on the
# shorter side, convert to tensors, then map each channel from [0, 1] to [-1, 1].
# NOTE(review): the 0.5/0.5 normalisation should match whatever preprocessing
# the teacher checkpoint was trained with — confirm.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# The resize runs on the CPU for every image; without worker processes the
# main process does all of it and the GPU waits. Pinned host memory only helps
# (and only avoids a warning) when a CUDA device is present.
NUM_WORKERS = 2
loader_kwargs = dict(num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, **loader_kwargs)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, **loader_kwargs)


100%|██████████| 170M/170M [00:13<00:00, 12.6MB/s]


In [6]:
# Build the MobileNetV2 architecture with randomly initialised weights.
# Calling it with no arguments means "no pretrained weights" in both the old
# (`pretrained=`) and new (`weights=`) torchvision APIs and avoids the
# deprecation warning that `pretrained=False` triggers on recent versions.
teacher = models.mobilenet_v2()

# Replace the final classifier layer with a 10-way head, then load the
# fine-tuned baseline checkpoint on top of it.
teacher.classifier[1] = nn.Linear(teacher.last_channel, 10)
# NOTE(security): torch.load unpickles the file; only load checkpoints you
# trust (consider weights_only=True on PyTorch versions that support it).
teacher.load_state_dict(torch.load(baseline_model_path, map_location=device))
teacher = teacher.to(device)

# The teacher is a fixed reference: inference mode for dropout/batch-norm,
# and no gradients are ever needed for its weights.
teacher.eval()
for param in teacher.parameters():
    param.requires_grad_(False)
print(" Teacher loaded.")




 Teacher loaded.


In [7]:
class StudentNet(nn.Module):
    """Small two-convolution CNN used as the distillation student.

    Layout: Conv(3->16) -> ReLU -> MaxPool(2) -> Conv(16->32) -> ReLU ->
    MaxPool(2) -> Flatten -> Linear. The two 2x2 poolings shrink each spatial
    side by a factor of 4, so the linear layer sees 32 * (input_size // 4)**2
    features.

    Args:
        num_classes: number of output logits (default 10).
        input_size: height/width of the square input images (default 224,
            which gives the 32*56*56 features of the original definition).

    Raises:
        ValueError: if ``input_size`` is too small to survive two poolings.
    """

    def __init__(self, num_classes=10, input_size=224):
        super().__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")
        # Layer order is kept as-is so state_dict keys (features.0, features.3,
        # classifier) stay compatible with checkpoints saved earlier.
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )
        # padding=1 with 3x3 kernels preserves spatial size; only the two
        # poolings reduce it (floor division, applied twice == // 4).
        pooled_side = input_size // 4
        self.classifier = nn.Linear(32 * pooled_side * pooled_side, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

# Instantiate the student with its default configuration, then move its
# parameters to the selected device (Module.to returns the same module).
student = StudentNet()
student = student.to(device)
print(" Student defined.")


 Student defined.


In [8]:
def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """Knowledge-distillation loss: weighted sum of a hard and a soft term.

    loss = alpha * CE(student, labels)
         + (1 - alpha) * T^2 * KL(softmax(teacher / T) || softmax(student / T))

    Note that ``alpha`` weights the *hard* (ground-truth) term here and
    ``1 - alpha`` weights the *soft* (teacher) term.

    Args:
        student_logits: raw student outputs, classes along dim 1.
        teacher_logits: raw teacher outputs with the same shape; treated as
            constants (detached), so no gradient reaches the teacher.
        labels: ground-truth class indices for the cross-entropy term.
        T: softmax temperature, must be > 0.
        alpha: weight of the hard loss, must lie in [0, 1].

    Returns:
        Scalar tensor with the combined loss.

    Raises:
        ValueError: if ``T`` is not positive or ``alpha`` is outside [0, 1].
    """
    if T <= 0:
        raise ValueError(f"temperature T must be > 0, got {T}")
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    hard_loss = F.cross_entropy(student_logits, labels)

    # Detach so the teacher is a fixed target even if the caller forgot to
    # wrap its forward pass in torch.no_grad().
    soft_targets = F.softmax(teacher_logits.detach() / T, dim=1)
    # F.kl_div expects log-probabilities as input and probabilities as target.
    # The T*T factor compensates for the 1/T^2 scaling that temperature
    # introduces in the soft-term gradients.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        soft_targets,
        reduction='batchmean'
    ) * (T * T)
    return alpha * hard_loss + (1 - alpha) * soft_loss


In [9]:
def evaluate(model, loader=None, device=None):
    """Return top-1 accuracy (in percent) of ``model`` over a data loader.

    Args:
        model: module mapping an image batch to per-class scores.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``testloader`` so ``evaluate(model)`` keeps working.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model).

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).

    Side effects:
        Leaves ``model`` in eval mode; call ``model.train()`` before resuming
        training.
    """
    if loader is None:
        loader = testloader
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest score along dim 1.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return 100 * correct / total


In [11]:
# Adam optimises the student's parameters only; the teacher is not part of
# the optimizer and is only queried for its logits.
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

# Distillation training: each step compares the student's logits with both
# the ground-truth labels and the teacher's logits via kd_loss.
epochs = 5
for epoch in range(epochs):
    student.train()  # evaluate() below switches the student to eval mode
    running_loss = 0.0
    progress = tqdm(trainloader, desc=f"Epoch [{epoch+1}/{epochs}]")
    for step, (images, labels) in enumerate(progress, start=1):
        images = images.to(device)
        labels = labels.to(device)

        # Teacher forward pass, without building an autograd graph.
        with torch.no_grad():
            teacher_logits = teacher(images)

        # Standard update: clear old grads, forward, backward, step.
        optimizer.zero_grad()
        student_logits = student(images)
        loss = kd_loss(student_logits, teacher_logits, labels)
        loss.backward()
        optimizer.step()

        # Show the mean loss over the batches seen so far in this epoch.
        running_loss += loss.item()
        progress.set_postfix({"Loss": f"{running_loss/step:.4f}"})

    # Score on the test set after every epoch.
    acc = evaluate(student)
    print(f" Epoch {epoch+1}, Test Accuracy: {acc:.2f}%")



Epoch [1/5]: 100%|██████████| 196/196 [03:00<00:00,  1.09it/s, Loss=2.5518]


🎯 Epoch 1, Test Accuracy: 45.36%


Epoch [2/5]: 100%|██████████| 196/196 [02:50<00:00,  1.15it/s, Loss=2.0232]


🎯 Epoch 2, Test Accuracy: 49.57%


Epoch [3/5]: 100%|██████████| 196/196 [02:49<00:00,  1.16it/s, Loss=1.8337]


🎯 Epoch 3, Test Accuracy: 52.23%


Epoch [4/5]: 100%|██████████| 196/196 [02:47<00:00,  1.17it/s, Loss=1.6349]


🎯 Epoch 4, Test Accuracy: 53.94%


Epoch [5/5]: 100%|██████████| 196/196 [02:48<00:00,  1.16it/s, Loss=1.5354]


🎯 Epoch 5, Test Accuracy: 55.62%


In [12]:
import os

# Persist only the student's weights (its state_dict), not the module object.
torch.save(student.state_dict(), student_model_path)
print(f" Student model saved: {student_model_path}")

# Compare the two checkpoints' on-disk sizes, in megabytes (10^6 bytes).
teacher_size, student_size = (
    os.path.getsize(checkpoint) / 1e6
    for checkpoint in (baseline_model_path, student_model_path)
)
print(f" Teacher Size: {teacher_size:.2f} MB")
print(f" Student Size: {student_size:.2f} MB")


 Student model saved: /content/drive/MyDrive/AI_MODEL_OPTIMIZATION/models/student_kd.pth
 Teacher Size: 9.19 MB
 Student Size: 4.04 MB


In [None]:
# Builds a fresh Adam optimizer for the student (learning rate 1e-3). Running
# this replaces `optimizer`, so any accumulated Adam moment estimates are lost.
optimizer = optim.Adam(params=student.parameters(), lr=1e-3)