In [247]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [248]:
transform = transforms.Compose(([transforms.ToTensor()]))
train_data = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_data = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [249]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)
        
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2, 2)
        
        x = x.view(x.size(0), -1)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x

In [277]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        
        x = x.view(x.size(0), -1)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x

In [252]:
def distillation_loss(teacher_logits, student_logits, hard_labels, T, alpha):
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    students_probes = F.log_softmax(student_logits / T, dim=1)

    soft_loss = F.kl_div(students_probes, soft_targets, reduction='batchmean') * (T**2)
    hard_loss = F.cross_entropy(students_probes, hard_labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

In [253]:
def train(model, save, name, num_epochs=10):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}, Loss {loss.item()}")

    if save:
        torch.save(model.state_dict(), f"{name}.pth")

In [254]:
teacher = TeacherModel().to(device)
train(teacher, save=True, name="teacher_model")

Epoch 1, Loss 1.1378507614135742
Epoch 2, Loss 1.2058892250061035
Epoch 3, Loss 0.41773444414138794
Epoch 4, Loss 1.000694990158081
Epoch 5, Loss 0.9638948440551758
Epoch 6, Loss 0.4364369511604309
Epoch 7, Loss 0.3945637345314026
Epoch 8, Loss 0.39768368005752563
Epoch 9, Loss 0.050330691039562225
Epoch 10, Loss 0.4710502326488495


In [275]:
def distillation_train(student, teacher_model, num_epochs=10):
    teacher.load_state_dict(torch.load(f"{teacher_model}.pth"))
    teacher.eval()

    optimizer = optim.Adam(student.parameters(), lr=0.001)
    T = 3.0
    alpha = 0.5

    for epoch in range(num_epochs):
        student.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            student_logits = student(images)
            teacher_logits = teacher(images)

            loss = distillation_loss(student_logits, teacher_logits, labels, T, alpha)
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

In [278]:
student = StudentModel().to(device)
distillation_train(student, "teacher_model", 10)

  teacher.load_state_dict(torch.load(f"{teacher_model}.pth"))


Epoch 1, Loss: 4.814355373382568
Epoch 2, Loss: 3.061305284500122
Epoch 3, Loss: 1.5099807977676392
Epoch 4, Loss: 0.9559910297393799
Epoch 5, Loss: 2.609994888305664
Epoch 6, Loss: 1.1867704391479492
Epoch 7, Loss: 1.2134482860565186
Epoch 8, Loss: 1.9121482372283936
Epoch 9, Loss: 0.9148890972137451
Epoch 10, Loss: 1.9633268117904663


In [279]:
student2 = StudentModel().to(device)
train(student2, False, "no", 10)

Epoch 1, Loss 0.8132889270782471
Epoch 2, Loss 1.2162551879882812
Epoch 3, Loss 0.8355135917663574
Epoch 4, Loss 0.835460901260376
Epoch 5, Loss 0.34013789892196655
Epoch 6, Loss 0.7490552067756653
Epoch 7, Loss 1.3399980068206787
Epoch 8, Loss 0.5107120871543884
Epoch 9, Loss 0.30884769558906555
Epoch 10, Loss 0.4490154981613159


In [280]:
def evaluate(model):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    print(f"Accuracy: {100 * correct / total:.2f}%")

In [281]:
evaluate(student)

Accuracy: 71.84%


In [282]:
evaluate(teacher)

Accuracy: 79.12%


In [283]:
evaluate(student2)

Accuracy: 69.26%
