In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from torch.utils.data import DataLoader

In [21]:
# 1 Data Transformer
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize((0.5,),(0.5,))])

# 2 Create Train Dataset
trainset = torchvision.datasets.MNIST(root='./data', train =True,
                                    download = True, transform = transform)
trainloader = DataLoader(trainset, batch_size =64, shuffle =True)

#3 Create Test Dataset
testset = torchvision.datasets.MNIST(root = "./data", train = False,
                                     download = True, transform = transform)
testloader = DataLoader(testset, batch_size=64, shuffle=True)

In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [10]:
class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet,self).__init__()
        self.conv = nn.Conv2d(1,32,5)
        self.pool = nn.MaxPool2d(5,5)
        self.fc1 = nn.Linear(32*4*4, 128)
        self.fc2 = nn.Linear(128,10)
        
    def forward(self,x):
        x = self.pool(F.relu(self.conv(x)))
        x = x.view(x.size(0),-1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
        

In [15]:
# initialize model
teacher_model = TeacherNet().to(device)

# Define Optimizer
teacher_optimizer = optim.Adam(teacher_model.parameters(),
                              lr = 0.001)

# Define loss function
teacher_criterion  = nn.CrossEntropyLoss().to(device)

In [26]:
for epoch in range(5):
    epoch_loss =0.0
    correct_predictions = 0
    total_predictions = 0

    #set to train mode
    teacher_model.train()

    #train for all batches of data
    for data in trainloader:
        X, y = data
        X , y = X.to(device), y.to(device)
        y_pred = teacher_model(X)
        loss = teacher_criterion(y_pred,y)
        loss.backward()
        teacher_optimizer.step()
        teacher_optimizer.zero_grad()
                
        epoch_loss += loss.item()

        #Calculate accuracy
        _, predicted = torch.max(y_pred,1)
        correct_predictions += (predicted ==y).sum().item()
        total_predictions += y.size(0)

    accuracy = correct_predictions / total_predictions
    average_loss = epoch_loss / len(trainloader)

    print(f'Epoch {epoch+1}, Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}')

    

Epoch 1, Loss: 0.2257, Accuracy: 0.9367
Epoch 2, Loss: 0.0764, Accuracy: 0.9769
Epoch 3, Loss: 0.0585, Accuracy: 0.9821
Epoch 4, Loss: 0.0475, Accuracy: 0.9851
Epoch 5, Loss: 0.0416, Accuracy: 0.9868


In [104]:
class StudentNet(nn.Module):

    def __init__(self):
        super(StudentNet,self).__init__()
        self.fc1 = nn.Linear(28*28,256)
        self.fc2 = nn.Linear(256,10)

    def forward(self, x):
        x = x.view(x.size(0),-1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
         

In [105]:
# Knowledge distilation loss (KL divergence):
def KL_loss(student_logits, teacher_logits,temperature = 1):
    # convert teacher model output to probabilities
    p_teacher = F.softmax(teacher_logits / temperature, dim= 1).to(device)

    # convert student model output to probabilities
    p_student  = F.softmax(student_logits / temperature, dim = 1).to(device)

    #compute KL divergence loss (PyTorch's method)
    loss = F.kl_div(p_student, p_teacher, reduction="batchmean").to(device)

    return loss



In [106]:
(5)*(1/16)

0.3125

In [107]:
torch.exp(torch.tensor(0.4 / 4))

tensor(1.1052)

In [108]:
# Innitialize Model
student_model = StudentNet().to(device)

#Define Optimizer
student_optimizer = optim.Adam(student_model.parameters(), lr =0.001)
student_criterion = nn.CrossEntropyLoss().to(device)

a = 0.6
tempature = 6

for epoch in range(10):
    # set to train mode
    student_model.train()

    epoch_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    student_model.train()
    #train for all batches of data
    for data in trainloader:
        inputs, labels = data[0].to(device),data[1].to(device)
        student_optimizer.zero_grad()

        # get student outputs
        student_logits = student_model(inputs)

        #get teacher outputs and detach them
        #to avoid backpropagation
        teacher_logits = teacher_model(inputs).detach()

        #compute KL Divergence loss 
        loss_distill = KL_loss(student_logits, teacher_logits,temperature=tempature) 
        loss_criterion = student_criterion(student_logits,labels)
        
        loss = a * (1/(tempature**2))* loss_distill + (1-a)* loss_criterion
        #run backpropagation step
        loss.backward()
        student_optimizer.step()

        epoch_loss += loss.item()

        #Calculate accuracy
        _, predicted = torch.max(student_logits,1)
        correct_predictions += (predicted ==labels).sum().item()
        total_predictions += labels.size(0)
    accuracy = correct_predictions / total_predictions
    average_loss = epoch_loss / len(trainloader)

    print(f'Epoch {epoch+1}, Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}')


Epoch 1, Loss: 0.1104, Accuracy: 0.8946
Epoch 2, Loss: 0.0369, Accuracy: 0.9498
Epoch 3, Loss: 0.0176, Accuracy: 0.9633
Epoch 4, Loss: 0.0077, Accuracy: 0.9701
Epoch 5, Loss: 0.0024, Accuracy: 0.9738
Epoch 6, Loss: -0.0024, Accuracy: 0.9771
Epoch 7, Loss: -0.0058, Accuracy: 0.9799
Epoch 8, Loss: -0.0087, Accuracy: 0.9816
Epoch 9, Loss: -0.0114, Accuracy: 0.9836
Epoch 10, Loss: -0.0132, Accuracy: 0.9848


In [110]:
student = StudentNet().to(device)
optimizer = optim.Adam(student.parameters(), lr =0.001)
criterion  = nn.CrossEntropyLoss().to(device)

for epoch in range(10):
    epoch_loss =0.0
    correct_predictions = 0
    total_predictions = 0

    #set to train mode
    student.train()

    #train for all batches of data
    for data in trainloader:
        X, y = data
        X , y = X.to(device), y.to(device)
        y_pred = student(X)
        loss = criterion(y_pred,y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
                
        epoch_loss += loss.item()

        #Calculate accuracy
        _, predicted = torch.max(y_pred,1)
        correct_predictions += (predicted ==y).sum().item()
        total_predictions += y.size(0)

    accuracy = correct_predictions / total_predictions
    average_loss = epoch_loss / len(trainloader)

    print(f'Epoch {epoch+1}, Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}')

    


Epoch 1, Loss: 0.3427, Accuracy: 0.8975
Epoch 2, Loss: 0.1604, Accuracy: 0.9535
Epoch 3, Loss: 0.1162, Accuracy: 0.9648
Epoch 4, Loss: 0.0950, Accuracy: 0.9708
Epoch 5, Loss: 0.0789, Accuracy: 0.9757
Epoch 6, Loss: 0.0698, Accuracy: 0.9776
Epoch 7, Loss: 0.0607, Accuracy: 0.9809
Epoch 8, Loss: 0.0568, Accuracy: 0.9819
Epoch 9, Loss: 0.0505, Accuracy: 0.9834
Epoch 10, Loss: 0.0460, Accuracy: 0.9843
