In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
from torch import nn
from torch.utils import data
from torchvision import transforms

import training_utils as utils
import distillation_methods_module

import numpy as np

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

torch.autograd.set_detect_anomaly(False)
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)

cuda:0


<torch.autograd.profiler.emit_nvtx at 0x25048ae1450>

In [2]:
# Download the training and test sets
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

train_iter = data.DataLoader(mnist_train, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)
test_iter = data.DataLoader(mnist_test, batch_size=256, shuffle=False, num_workers=4, pin_memory=True)

In [3]:
class Teacher_Net(nn.Module):
    def __init__(self, dropout=0.5):
        super().__init__()
        
        self.linear_1 = nn.Linear(784, 1200)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

        self.linear_2 = nn.Linear(1200, 1200)

        self.linear_3 = nn.Linear(1200, 1200)

        self.linear_4 = nn.Linear(1200, 10)
    
    def forward(self, input):
        input = torch.flatten(input, start_dim=1)
        out = self.linear_1(input)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.linear_2(out)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.linear_3(out)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.linear_4(out)

        return out

In [4]:
class Student_Net(nn.Module):
    def __init__(self, dropout=0.5):
        super().__init__()
        
        self.linear_1 = nn.Linear(784, 600)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

        self.linear_2 = nn.Linear(600, 1200)

        self.linear_3 = nn.Linear(1200, 1200)

        self.linear_4 = nn.Linear(1200, 10)
    
    def forward(self, input):
        input = torch.flatten(input, start_dim=1)
        out = self.linear_1(input)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.linear_2(out)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.linear_3(out)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.linear_4(out)

        return out

In [5]:
teacher = Teacher_Net().to(device)
num_epochs = 10
loss_fn = nn.CrossEntropyLoss(reduction='none').to(device)
optimizer = torch.optim.Adam(teacher.parameters(), lr=0.001)
history = utils.train(teacher, train_iter, test_iter, loss_fn, num_epochs, optimizer)

Epoch:		  0
train_metrics:	  (0.6241023885091146, 0.7685666666666666)
test_accuracy:	  0.8423
Epoch:		  1
train_metrics:	  (0.4410204090118408, 0.8393)
test_accuracy:	  0.8476
Epoch:		  2
train_metrics:	  (0.4139511412302653, 0.8494666666666667)
test_accuracy:	  0.8557
Epoch:		  3
train_metrics:	  (0.3902605449040731, 0.8578333333333333)
test_accuracy:	  0.8649
Epoch:		  4
train_metrics:	  (0.38794019667307533, 0.8571166666666666)
test_accuracy:	  0.8657
Epoch:		  5
train_metrics:	  (0.367439803536733, 0.8639833333333333)
test_accuracy:	  0.869
Epoch:		  6
train_metrics:	  (0.35739705238342284, 0.8687)
test_accuracy:	  0.8716
Epoch:		  7
train_metrics:	  (0.3514067399342855, 0.87185)
test_accuracy:	  0.8724
Epoch:		  8
train_metrics:	  (0.34620792344411216, 0.8738833333333333)
test_accuracy:	  0.8704
Epoch:		  9
train_metrics:	  (0.3404334603309631, 0.8742166666666666)
test_accuracy:	  0.8698


In [6]:
student = Student_Net().to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=5e-3)

d_logits = distillation_methods_module.Logits_Distiller(5, teacher=teacher, student=student, optimizer=optimizer)
d_logits.train(train_iter, test_iter, num_epochs)

Epoch:		  0
train_metrics:	  (1.7454750814552728e-05, 0.7055166666666667)
test_accuracy:	  0.8083
Epoch:		  1
train_metrics:	  (7.618954948460062e-06, 0.7979833333333334)
test_accuracy:	  0.8232
Epoch:		  2
train_metrics:	  (7.396421593148261e-06, 0.80325)
test_accuracy:	  0.8254
Epoch:		  3
train_metrics:	  (7.100268139038235e-06, 0.80775)
test_accuracy:	  0.8364
Epoch:		  4
train_metrics:	  (6.944385714208086e-06, 0.8093166666666667)
test_accuracy:	  0.8299
Epoch:		  5
train_metrics:	  (7.0928943479278435e-06, 0.8086833333333333)
test_accuracy:	  0.8279
Epoch:		  6
train_metrics:	  (7.3111612213930735e-06, 0.8059333333333333)
test_accuracy:	  0.8294
Epoch:		  7
train_metrics:	  (7.171109330374747e-06, 0.8088)
test_accuracy:	  0.8282
Epoch:		  8
train_metrics:	  (7.329882475702713e-06, 0.801)
test_accuracy:	  0.8324
Epoch:		  9
train_metrics:	  (7.197150789822142e-06, 0.8058666666666666)
test_accuracy:	  0.8188


([0.7055166666666667,
  0.7979833333333334,
  0.80325,
  0.80775,
  0.8093166666666667,
  0.8086833333333333,
  0.8059333333333333,
  0.8088,
  0.801,
  0.8058666666666666],
 [1.7454750814552728e-05,
  7.618954948460062e-06,
  7.396421593148261e-06,
  7.100268139038235e-06,
  6.944385714208086e-06,
  7.0928943479278435e-06,
  7.3111612213930735e-06,
  7.171109330374747e-06,
  7.329882475702713e-06,
  7.197150789822142e-06],
 [0.8083,
  0.8232,
  0.8254,
  0.8364,
  0.8299,
  0.8279,
  0.8294,
  0.8282,
  0.8324,
  0.8188])

In [11]:
student = Student_Net().to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

d_features = distillation_methods_module.Features_Distiller(hint_layer=teacher.linear_2, hinted_layer=student.linear_2, teacher=teacher, student=student, optimizer=optimizer)
d_features.train(train_iter, test_iter, num_epochs)

Epoch:		  0
train_metrics:	  (0.0008552989792078733, 0.0008333333333333334)
test_accuracy:	  0.0037
Epoch:		  1
train_metrics:	  (0.00033188448411722977, 0.00215)
test_accuracy:	  0.0028
Epoch:		  2
train_metrics:	  (0.0002571209882075588, 0.00265)
test_accuracy:	  0.0029
Epoch:		  3
train_metrics:	  (0.00022459706577161948, 0.0026)
test_accuracy:	  0.004
Epoch:		  4
train_metrics:	  (0.00020735235592971246, 0.0034666666666666665)
test_accuracy:	  0.0045
Epoch:		  5
train_metrics:	  (0.00019487899063775938, 0.00385)
test_accuracy:	  0.0029
Epoch:		  6
train_metrics:	  (0.00018680839594453573, 0.0041)
test_accuracy:	  0.0054
Epoch:		  7
train_metrics:	  (0.00018028159526487192, 0.0047)
test_accuracy:	  0.0055
Epoch:		  8
train_metrics:	  (0.00017540895709147056, 0.00535)
test_accuracy:	  0.0057
Epoch:		  9
train_metrics:	  (0.00017121644373983144, 0.0052)
test_accuracy:	  0.0063
Epoch:		  0
train_metrics:	  (0.7611187035878499, 0.7745)
test_accuracy:	  0.8396
Epoch:		  1
train_metrics:	

In [12]:
student = Student_Net().to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

d_relation = distillation_methods_module.Relations_Distiller(teacher=teacher, student=student, optimizer=optimizer)
d_relation.train(train_iter, test_iter, num_epochs)

Epoch:		  0
train_metrics:	  (1456.674359375, 0.10603333333333333)
test_accuracy:	  0.1451
Epoch:		  1
train_metrics:	  (1454.6212447916666, 0.10473333333333333)
test_accuracy:	  0.1451
Epoch:		  2
train_metrics:	  (1455.2199770833333, 0.1072)
test_accuracy:	  0.1451
Epoch:		  3
train_metrics:	  (1455.8520908854166, 0.10521666666666667)
test_accuracy:	  0.1451
Epoch:		  4
train_metrics:	  (1455.6877171875, 0.10593333333333334)
test_accuracy:	  0.1451
Epoch:		  5
train_metrics:	  (1458.8693299479166, 0.10378333333333334)
test_accuracy:	  0.1451
Epoch:		  6
train_metrics:	  (1453.3930104166666, 0.10323333333333333)
test_accuracy:	  0.1451
Epoch:		  7
train_metrics:	  (1455.8364981770833, 0.10503333333333334)
test_accuracy:	  0.1451
Epoch:		  8
train_metrics:	  (1457.3352338541667, 0.10425)
test_accuracy:	  0.1451
Epoch:		  9
train_metrics:	  (1456.2441893229166, 0.10586666666666666)
test_accuracy:	  0.1451
Epoch:		  0
train_metrics:	  (0.6521746459325155, 0.7565166666666666)
test_accurac