In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Load CIFAR-10 Dataset

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

100%|██████████| 170M/170M [00:05<00:00, 29.6MB/s]


# Load Teacher & Student Models

In [3]:
teacher = torchvision.models.resnet18(pretrained=True)
teacher.fc = nn.Linear(512, 10)
teacher = teacher.to(device)
teacher.eval()

student = torchvision.models.mobilenet_v2(pretrained=False)
student.classifier[1] = nn.Linear(student.last_channel, 10)
student = student.to(device)
student.train()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 199MB/s]


MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

# Dynamic KD Loss Functions

In [4]:
def entropy(p):
    return -torch.sum(p * torch.log(p + 1e-8), dim=1)

def dynamic_alpha(student_logits):
    prob = F.softmax(student_logits, dim=1)
    ent = entropy(prob)
    max_entropy = torch.log(torch.tensor(prob.size(1), dtype=torch.float32)).to(device)
    return (ent / max_entropy).unsqueeze(1)

def dynamic_kd_loss(student_logits, teacher_logits, labels, T=4):
    alpha = dynamic_alpha(student_logits).detach()
    ce = F.cross_entropy(student_logits, labels, reduction='none').unsqueeze(1)
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='none'
    ).sum(dim=1).unsqueeze(1) * (T * T)
    return ((1 - alpha) * ce + alpha * kd).mean()

# Train Student with Dynamic KD

In [5]:
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

for epoch in range(5):
    running_loss = 0.0
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            t_logits = teacher(images)

        s_logits = student(images)
        loss = dynamic_kd_loss(s_logits, t_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {running_loss / len(train_loader):.4f}")

100%|██████████| 391/391 [00:28<00:00, 13.66it/s]


Epoch 1 | Loss: 0.1151


100%|██████████| 391/391 [00:24<00:00, 16.24it/s]


Epoch 2 | Loss: 0.0957


100%|██████████| 391/391 [00:22<00:00, 17.15it/s]


Epoch 3 | Loss: 0.0945


100%|██████████| 391/391 [00:22<00:00, 17.07it/s]


Epoch 4 | Loss: 0.0933


100%|██████████| 391/391 [00:23<00:00, 16.87it/s]

Epoch 5 | Loss: 0.0923





In [6]:
student_cpu = student.cpu()
quantized_student = torch.quantization.quantize_dynamic(
    student_cpu,
    {nn.Linear},
    dtype=torch.qint8
)
quantized_student.eval()

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

# Selective Dynamic Quantization on Student

In [None]:
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = quantized_student(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Quantized Student Accuracy: {100 * correct / total:.2f}%")