In [None]:
import torch
import torchvision.models as models
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F

## Load Dataset

In [None]:
# Define transformations
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load CIFAR10 dataset
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

val_dataset = CIFAR10(root='./data', train=False, transform=transform, download=True)
val_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Define ResNet18 and ResNet50 models
resnet18 = models.resnet18(pretrained=True)
resnet50 = models.resnet50(pretrained=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29145367.80it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 162MB/s]
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 149MB/s]


## Train Model

In [None]:
# Define a function for model validation
def validate_model(model, val_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    return accuracy

In [None]:
def train_model(model, train_loader, epochs=10, learning_rate=1e-3):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)  # Added weight decay
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # Learning rate scheduler

    best_acc = 0

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        scheduler.step()  # Adjust learning rate based on scheduler
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total

        # Validation to check generalization
        val_acc = validate_model(model, val_loader, device)
        print(f'Epoch {epoch+1},  Loss: {epoch_loss}, Train Accuracy: {epoch_acc}, Validation Accuracy: {val_acc}')

        if val_acc > best_acc:
            best_acc = val_acc

    print(f'Epoch {epoch+1}, Loss: {epoch_loss}, Accuracy: {best_acc}')

    print('Finished Training')
    return (epoch_acc, model)



In [None]:
def knowledge_distillation(teacher_model, student_model, train_loader, val_loader, alpha = 0.5, temperature = 2.5, epochs=10, learning_rate=1e-3):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    teacher_model = teacher_model.to(device)
    student_model = student_model.to(device)
    optimizer = Adam(student_model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # Learning rate scheduler
    criterion = nn.CrossEntropyLoss()

    # Hyperparameters for grid search

    best_acc = 0

    teacher_model.eval()
    student_model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)

            student_outputs = student_model(inputs)
            loss = (1. - alpha) * criterion(student_outputs, labels)
            loss += alpha * (F.kl_div(F.log_softmax(student_outputs / temperature, dim=1),
                                      F.softmax(teacher_outputs / temperature, dim=1),
                                      reduction='batchmean') * (temperature ** 2))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(student_outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        scheduler.step()  # Update learning rate
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total

        # Validation to check generalization
        val_acc = validate_model(student_model, val_loader, device)
        print(f'Epoch {epoch+1},  Loss: {epoch_loss}, Train Accuracy: {epoch_acc}, Validation Accuracy: {val_acc}')

        if val_acc > best_acc:
            best_acc = val_acc
            best_alpha, best_temperature = alpha, temperature

    print('Finished Knowledge Distillation')
    return (best_acc, student_model)

In [None]:
# Train both models to get baseline accuracy
baseline_acc_resnet18, trained_resnet18 = train_model(resnet18, train_loader)

Epoch 1, Loss: 1.0684906738569668, Accuracy: 0.65396
Epoch 2, Loss: 0.6608771151670104, Accuracy: 0.77786
Epoch 3, Loss: 0.5224957946316361, Accuracy: 0.82368
Epoch 4, Loss: 0.42069070103109035, Accuracy: 0.8574
Epoch 5, Loss: 0.32570542236956795, Accuracy: 0.88888
Epoch 6, Loss: 0.1332750970812138, Accuracy: 0.95668
Epoch 7, Loss: 0.06760632546643591, Accuracy: 0.97858
Epoch 8, Loss: 0.039955910762398, Accuracy: 0.98772
Epoch 9, Loss: 0.02546420908352757, Accuracy: 0.9918
Epoch 10, Loss: 0.017183984738497964, Accuracy: 0.99454
Finished Training


In [None]:
validate_model(trained_resnet18, val_loader, device)

0.9732


In [None]:
baseline_acc_resnet50, trained_resnet50 = train_model(resnet50, train_loader)

Epoch 1,  Loss: 1.1241562697664855, Train Accuracy: 0.6378, Validation Accuracy: 0.65502
Epoch 2,  Loss: 1.2624519733745423, Train Accuracy: 0.5492, Validation Accuracy: 0.72952
Epoch 3,  Loss: 0.7654269236280485, Train Accuracy: 0.73858, Validation Accuracy: 0.78462
Epoch 4,  Loss: 0.6440953198067673, Train Accuracy: 0.7813, Validation Accuracy: 0.78498
Epoch 5,  Loss: 0.5555187167261567, Train Accuracy: 0.81098, Validation Accuracy: 0.83528
Epoch 6,  Loss: 0.27584131566993414, Train Accuracy: 0.90572, Validation Accuracy: 0.94058
Epoch 7,  Loss: 0.17696004295173814, Train Accuracy: 0.9397, Validation Accuracy: 0.96646
Epoch 8,  Loss: 0.10826695287038031, Train Accuracy: 0.96582, Validation Accuracy: 0.9842
Epoch 9,  Loss: 0.05647473648289347, Train Accuracy: 0.98342, Validation Accuracy: 0.9932
Epoch 10,  Loss: 0.03306476617893835, Train Accuracy: 0.99042, Validation Accuracy: 0.99346
Epoch 10, Loss: 0.03306476617893835, Accuracy: 0.99346
Finished Training


In [None]:
validate_model(trained_resnet50, val_loader, device)

0.99346

In [None]:
# Perform normal KD with ResNet50 as teacher and ResNet18 as student
resnet18 = models.resnet18(weights=True)
kd_acc_resnet18, distilled_resnet18 = knowledge_distillation(trained_resnet50, resnet18, train_loader, val_loader)



Epoch 1,  Loss: 3.0607115101936224, Train Accuracy: 0.68296, Validation Accuracy: 0.7939
Epoch 2,  Loss: 2.159079546208882, Train Accuracy: 0.75192, Validation Accuracy: 0.81804
Epoch 3,  Loss: 1.5197044298090898, Train Accuracy: 0.81476, Validation Accuracy: 0.85136
Epoch 4,  Loss: 1.2159842743593103, Train Accuracy: 0.84836, Validation Accuracy: 0.87724
Epoch 5,  Loss: 1.023549463819055, Train Accuracy: 0.87, Validation Accuracy: 0.90286
Epoch 6,  Loss: 0.47884095383955694, Train Accuracy: 0.93834, Validation Accuracy: 0.9598
Epoch 7,  Loss: 0.33244786446775926, Train Accuracy: 0.96022, Validation Accuracy: 0.97396
Epoch 8,  Loss: 0.25939579493821124, Train Accuracy: 0.97212, Validation Accuracy: 0.98326
Epoch 9,  Loss: 0.20615337638522657, Train Accuracy: 0.982, Validation Accuracy: 0.9889
Epoch 10,  Loss: 0.16909092522757438, Train Accuracy: 0.98818, Validation Accuracy: 0.99134
Finished Knowledge Distillation


In [None]:
resnet18 = models.resnet18(weights=True)
self_kd_acc_resnet18, self_distilled_resnet18 = knowledge_distillation(trained_resnet18, resnet18, train_loader, val_loader, epochs = 10)



Epoch 1,  Loss: 3.3653346842817027, Train Accuracy: 0.67918, Validation Accuracy: 0.77538
Epoch 2,  Loss: 2.38172355134164, Train Accuracy: 0.75374, Validation Accuracy: 0.79686
Epoch 3,  Loss: 1.732421047013739, Train Accuracy: 0.81382, Validation Accuracy: 0.84936
Epoch 4,  Loss: 1.4007619300766674, Train Accuracy: 0.84642, Validation Accuracy: 0.8807
Epoch 5,  Loss: 1.1003576603234577, Train Accuracy: 0.87774, Validation Accuracy: 0.91566
Epoch 6,  Loss: 0.4496425986671082, Train Accuracy: 0.95102, Validation Accuracy: 0.97302
Epoch 7,  Loss: 0.2801854875119751, Train Accuracy: 0.972, Validation Accuracy: 0.98472
Epoch 8,  Loss: 0.19662773227104752, Train Accuracy: 0.98434, Validation Accuracy: 0.99184
Epoch 9,  Loss: 0.14477787831383745, Train Accuracy: 0.99204, Validation Accuracy: 0.99514
Epoch 10,  Loss: 0.11595752836702883, Train Accuracy: 0.9947, Validation Accuracy: 0.99602
Finished Knowledge Distillation


In [None]:
# Perform reverse KD with ResNet18 as teacher and ResNet50 as student
resnet50 = models.resnet50(weights=True)
kd_acc_resnet50 = knowledge_distillation(trained_resnet18, resnet50, train_loader, val_loader)



Epoch 1,  Loss: 1.5946509189465468, Train Accuracy: 0.84074, Validation Accuracy: 0.86998
Epoch 2,  Loss: 1.303880830059576, Train Accuracy: 0.86488, Validation Accuracy: 0.91306
Epoch 3,  Loss: 0.9351826997669151, Train Accuracy: 0.90242, Validation Accuracy: 0.933
Epoch 4,  Loss: 0.7957348483983818, Train Accuracy: 0.91916, Validation Accuracy: 0.9229
Epoch 5,  Loss: 0.7239070548235303, Train Accuracy: 0.92786, Validation Accuracy: 0.95434
Epoch 6,  Loss: 0.2926031673026969, Train Accuracy: 0.97718, Validation Accuracy: 0.99192
Epoch 7,  Loss: 0.18359896431074424, Train Accuracy: 0.99254, Validation Accuracy: 0.99682
Epoch 8,  Loss: 0.14672014138678, Train Accuracy: 0.9967, Validation Accuracy: 0.9984
Epoch 9,  Loss: 0.12295342332033245, Train Accuracy: 0.99848, Validation Accuracy: 0.99894
Epoch 10,  Loss: 0.10445033968962214, Train Accuracy: 0.99888, Validation Accuracy: 0.99918
Finished Knowledge Distillation


In [None]:
resnet50 = models.resnet50(weights=True)
self_kd_acc_resnet50 = knowledge_distillation(trained_resnet50, resnet50, train_loader, val_loader, epochs = 10)



Epoch 1,  Loss: 2.8490625717450895, Train Accuracy: 0.69612, Validation Accuracy: 0.78756
Epoch 2,  Loss: 2.222142069540975, Train Accuracy: 0.74498, Validation Accuracy: 0.8032
Epoch 3,  Loss: 1.6210009162993078, Train Accuracy: 0.80504, Validation Accuracy: 0.83558
Epoch 4,  Loss: 1.3904464440729918, Train Accuracy: 0.82642, Validation Accuracy: 0.8514
Epoch 5,  Loss: 1.1793581144050564, Train Accuracy: 0.85068, Validation Accuracy: 0.87878
Epoch 6,  Loss: 0.5157197678218717, Train Accuracy: 0.92608, Validation Accuracy: 0.95658
Epoch 7,  Loss: 0.3186811959690145, Train Accuracy: 0.9562, Validation Accuracy: 0.9729
Epoch 8,  Loss: 0.22367067955186604, Train Accuracy: 0.97294, Validation Accuracy: 0.98554
Epoch 9,  Loss: 0.15974410416563148, Train Accuracy: 0.98512, Validation Accuracy: 0.99142
Epoch 10,  Loss: 0.12574605879080875, Train Accuracy: 0.99054, Validation Accuracy: 0.994
Finished Knowledge Distillation
