In [2]:
import os
if os.getcwd().split('/')[-1] == "notebooks":
    os.chdir('..')

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import torchvision.transforms as transforms
from src.kd import DistillationLoss
from src.train import train

data = torch.load("./data/cifar10_training_data.pt")
labels = torch.load("./data/cifar10_training_labels.pt")
logits = torch.load("./data/cifar10_logits.pt")

batch_size = 128

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data/', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=2, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data/', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

kdtrain = torch.utils.data.TensorDataset(data, labels, logits)
kdloader = torch.utils.data.DataLoader(kdtrain, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 500)
        self.fc2 = nn.Linear(500, 250)
        self.fc3 = nn.Linear(250, 100)
        self.fc4 = nn.Linear(100, 50)
        self.fc5 = nn.Linear(50, 25)
        self.fc6 = nn.Linear(25, 10)
        self.fc7 = nn.Linear(10, 10)
        

    def forward(self, x):
        x = x.view(-1, 32 * 32 * 3)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = self.fc7(x)
        return x

In [15]:
model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.001)
if torch.cuda.is_available():
    model.cuda()
train(model, trainloader, testloader, optimizer, nn.CrossEntropyLoss(), epochs=20, writer=None)

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Epoch 1 accuracy = 43.57%
Epoch 2 accuracy = 46.54%
Epoch 3 accuracy = 49.63%
Epoch 4 accuracy = 50.53%
Epoch 5 accuracy = 52.12%
Epoch 6 accuracy = 52.08%
Epoch 7 accuracy = 52.38%
Epoch 8 accuracy = 52.09%
Epoch 9 accuracy = 51.66%
Epoch 10 accuracy = 51.44%
Epoch 11 accuracy = 51.06%
Epoch 12 accuracy = 51.72%
Epoch 13 accuracy = 51.84%
Epoch 14 accuracy = 51.90%
Epoch 15 accuracy = 52.42%
Epoch 16 accuracy = 51.61%
Epoch 17 accuracy = 52.54%
Epoch 18 accuracy = 52.20%
Epoch 19 accuracy = 52.42%
Epoch 20 accuracy = 51.97%



In [16]:
model = Net()
distillation_loss = DistillationLoss(5, 0.8)

def kd_ce_loss(logits_S, labels, logits_T, temperature=5):
    '''
    Calculate the cross entropy between logits_S and logits_T
    :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
    '''
    if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
        temperature = temperature.unsqueeze(-1)
    beta_logits_T = logits_T / temperature
    beta_logits_S = logits_S / temperature
    p_T = F.softmax(beta_logits_T, dim=-1)
    distillation_loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
    target_loss = nn.CrossEntropyLoss()(logits_S, labels)
    loss = 0.8 * distillation_loss + (1 - 0.8) * target_loss
    return loss

optimizer = optim.Adam(model.parameters(), lr=0.001)
if torch.cuda.is_available():
    model.cuda()
train(model, kdloader, testloader, optimizer, kd_ce_loss, epochs=20, writer=None)

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

Epoch 1 accuracy = 42.22%
Epoch 2 accuracy = 46.11%
Epoch 3 accuracy = 49.15%
Epoch 4 accuracy = 49.53%
Epoch 5 accuracy = 50.87%
Epoch 6 accuracy = 51.23%
Epoch 7 accuracy = 51.53%
Epoch 8 accuracy = 50.72%
Epoch 9 accuracy = 52.13%
Epoch 10 accuracy = 51.61%
Epoch 11 accuracy = 51.51%
Epoch 12 accuracy = 52.43%
Epoch 13 accuracy = 51.32%
Epoch 14 accuracy = 51.35%
Epoch 15 accuracy = 51.59%
Epoch 16 accuracy = 51.27%
Epoch 17 accuracy = 52.12%
Epoch 18 accuracy = 51.77%
Epoch 19 accuracy = 51.46%
Epoch 20 accuracy = 52.68%

