In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import inspect
import time

from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

In [0]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### В качестве датасета возьмем CIFAR10, поскольку на MNIST оказалось очень сложно оценивать результаты работы дистилляция. Похоже на то, что в случае MNIST самые простые архитектуры (даже двухслойная полносвязная сеть) способны выдавать неплохой скор на валидации, однако они не могут обучиться генерировать распределение учителя. Поэтому архитектуру модели-студента приходится улучшать, чтобы повысить ее способность к подражанию. Чем сложнее модель-студент, тем лучше она сама справляется с MNIST. Поэтому выбор пал на CIFAR10, с этими данными простые модели не справляются.  

In [25]:
BATCH_SIZE = 128
transform = transforms.Compose([
     transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True)

valset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE,
                                         shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [0]:
def calculate_metric(metric_fn, true_y, pred_y):
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores) / batch_size:.4f}")

In [0]:
def train_val(model, epochs=1, lr=1e-2):
    #optimizer = optim.Adam(model.parameters(), lr)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_func = nn.CrossEntropyLoss()

    for ep in range(epochs):
        running_loss = 0
        pbar = tqdm(enumerate(trainloader), total=len(trainloader))
        model.train()
        
        for i, (data, target) in pbar:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = loss_func(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_description('loss: {:.4f}'.format(running_loss / (i + 1)))
            
        torch.cuda.empty_cache()
        
        val_loss = 0
        precision, recall, f1, accuracy = [], [], [], []
        
        model.eval()
        with torch.no_grad():
            for i, (data, target) in enumerate(valloader):
                data, target = data.to(device), target.to(device)
                outputs = model(data)
                val_loss += loss_func(outputs, target)

                predicted_classes = torch.argmax(outputs, 1)            
                for acc, metric in zip((precision, recall, f1, accuracy), 
                                    (precision_score, recall_score, f1_score, accuracy_score)):
                    acc.append(
                        calculate_metric(metric, target.cpu(), predicted_classes.cpu())
                    )        
            
        print(f"Epoch {ep + 1}/{epochs}, training loss: {running_loss / len(trainloader)}, validation loss: {val_loss / len(valloader)}")
        print_scores(precision, recall, f1, accuracy, len(valloader))

In [0]:
def train_val_distill(student_model, teacher_model, epochs=1, lr=1e-3, alpha=0.5, temperature=1):
    #optimizer = optim.Adam(student_model.parameters(), lr=lr)
    optimizer = optim.SGD(student_model.parameters(), lr=lr, momentum=0.9)
    loss_func = nn.CrossEntropyLoss()
    loss_dist_func = nn.KLDivLoss(reduction='batchmean')
    #loss_dist_func = nn.MSELoss()

    for ep in range(epochs):
        total_loss = 0
        total_real_loss = 0
        total_dist_loss = 0
        pbar = tqdm(enumerate(trainloader), total=len(trainloader))
        student_model.train()
        teacher_model.eval()
        
        for i, (data, target) in pbar:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = student_model(data)
            teacher_output = teacher_model(data).detach()

            distill_loss = temperature**2 * loss_dist_func(F.log_softmax(output / temperature, dim=1), 
                                                           F.softmax(teacher_output / temperature, dim=1))
            #distill_loss = loss_dist_func(output, teacher_output)
            
            real_loss = loss_func(output, target)
            loss = alpha * distill_loss + (1 - alpha) * real_loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_real_loss += real_loss.item()
            total_dist_loss += distill_loss.item()
            pbar.set_description('L: {:.2f}, \
                                  RL: {:.2f}, \
                                  DL: {:.2f}'.format(total_loss / (i + 1),
                                                      total_real_loss / (i + 1),
                                                      total_dist_loss / (i + 1)))

        torch.cuda.empty_cache()
        
        val_loss = 0
        precision, recall, f1, accuracy = [], [], [], []
        
        student_model.eval()
        with torch.no_grad():
            for i, (data, target) in enumerate(valloader):
                data, target = data.to(device), target.to(device)
                outputs = student_model(data)
                val_loss += loss_func(outputs, target)

                predicted_classes = torch.argmax(outputs, 1)            
                for acc, metric in zip((precision, recall, f1, accuracy), 
                                    (precision_score, recall_score, f1_score, accuracy_score)):
                    acc.append(
                        calculate_metric(metric, target.cpu(), predicted_classes.cpu())
                    )        
            
        print(f"Epoch {ep + 1}/{epochs}, training loss: {total_loss / len(trainloader)}, validation loss: {val_loss / len(valloader)}")
        print_scores(precision, recall, f1, accuracy, len(valloader))

## Учителем будет resnet18, адаптированный под 10 классов.

In [0]:
class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = models.resnet18(pretrained=True)
        self.fc = nn.Linear(self.model.fc.out_features, 10)

    def forward(self, x):
        x = self.model.forward(x)
        x = self.fc(x)
        return x


resnet = ResNet().to(device)

## Дообучим на cifar и посмотрим, что выдает учитель.

In [62]:
train_val(resnet, epochs=5, lr=1e-3)

HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 1/5, training loss: 0.5299041196132255, validation loss: 0.43157848715782166
	     precision: 0.8618
	        recall: 0.8472
	            F1: 0.8432
	      accuracy: 0.8513


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 2/5, training loss: 0.2848538991892734, validation loss: 0.35211697220802307
	     precision: 0.8882
	        recall: 0.8795
	            F1: 0.8765
	      accuracy: 0.8807


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 3/5, training loss: 0.20055560641886327, validation loss: 0.3238263428211212
	     precision: 0.8998
	        recall: 0.8950
	            F1: 0.8911
	      accuracy: 0.8979


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 4/5, training loss: 0.15155743245426043, validation loss: 0.31796908378601074
	     precision: 0.9120
	        recall: 0.9075
	            F1: 0.9052
	      accuracy: 0.9081


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 5/5, training loss: 0.12952035506877602, validation loss: 0.3444471061229706
	     precision: 0.9030
	        recall: 0.8933
	            F1: 0.8911
	      accuracy: 0.8959


## Учеником будет сеть с 2 conv и 3 fc слоями.

In [0]:
class StudentCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5) # 220
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5) # 106
        self.fc1 = nn.Linear(16 * 53 * 53, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 53 * 53)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [54]:
student_cnn = StudentCNN().to(device)
train_val(student_cnn, 10, 1e-3)

HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 1/10, training loss: 1.507739981422034, validation loss: 1.3007999658584595
	     precision: 0.5454
	        recall: 0.5311
	            F1: 0.5219
	      accuracy: 0.5327


  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 2/10, training loss: 1.1767757041070162, validation loss: 1.274371862411499
	     precision: 0.5657
	        recall: 0.5396
	            F1: 0.5321
	      accuracy: 0.5418


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 3/10, training loss: 0.9957785290830276, validation loss: 1.1804554462432861
	     precision: 0.5964
	        recall: 0.5822
	            F1: 0.5700
	      accuracy: 0.5854


  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 4/10, training loss: 0.8189343791788496, validation loss: 1.2489019632339478
	     precision: 0.5888
	        recall: 0.5849
	            F1: 0.5732
	      accuracy: 0.5876


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 5/10, training loss: 0.6415191535907023, validation loss: 1.362184762954712
	     precision: 0.5863
	        recall: 0.5779
	            F1: 0.5688
	      accuracy: 0.5813


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 6/10, training loss: 0.47887141289918317, validation loss: 1.587994933128357
	     precision: 0.5748
	        recall: 0.5664
	            F1: 0.5588
	      accuracy: 0.5696


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 7/10, training loss: 0.353288105267393, validation loss: 1.8255711793899536
	     precision: 0.5673
	        recall: 0.5710
	            F1: 0.5568
	      accuracy: 0.5741


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 8/10, training loss: 0.25257434854117194, validation loss: 2.2035276889801025
	     precision: 0.5624
	        recall: 0.5556
	            F1: 0.5444
	      accuracy: 0.5598


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 9/10, training loss: 0.1903197676172037, validation loss: 2.436962127685547
	     precision: 0.5649
	        recall: 0.5563
	            F1: 0.5484
	      accuracy: 0.5591


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 10/10, training loss: 0.15359124032508992, validation loss: 2.586085319519043
	     precision: 0.5535
	        recall: 0.5480
	            F1: 0.5386
	      accuracy: 0.5524


## Видно, что модель не справляется с данными и уже после 4-5 эпохи происходит переобучение и качество на валидационном датасете падает. Лучший скор - 0.587. Посмотрим, какое преимущество даст дистилляция.

In [63]:
student_cnn = StudentCNN().to(device)
train_val_distill(student_cnn, resnet, epochs=10, lr=1e-3, alpha=0.9, temperature=20) # L - total loss, RL - real loss, DL - distill loss

HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))




  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/10, training loss: 17.07943392897506, validation loss: 2.841996431350708
	     precision: 0.4484
	        recall: 0.4089
	            F1: 0.3802
	      accuracy: 0.4108


  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 2/10, training loss: 13.176162922168936, validation loss: 2.5575122833251953
	     precision: 0.4772
	        recall: 0.4699
	            F1: 0.4499
	      accuracy: 0.4719


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 3/10, training loss: 11.627193826543705, validation loss: 2.3405535221099854
	     precision: 0.5201
	        recall: 0.5024
	            F1: 0.4852
	      accuracy: 0.5056


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 4/10, training loss: 9.858404558332985, validation loss: 2.2976937294006348
	     precision: 0.5633
	        recall: 0.5494
	            F1: 0.5320
	      accuracy: 0.5528


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 5/10, training loss: 8.410825457414397, validation loss: 2.2153639793395996
	     precision: 0.5965
	        recall: 0.5554
	            F1: 0.5382
	      accuracy: 0.5581


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 6/10, training loss: 7.240879034447243, validation loss: 2.182177782058716
	     precision: 0.5845
	        recall: 0.5683
	            F1: 0.5523
	      accuracy: 0.5725


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 7/10, training loss: 6.294640425221084, validation loss: 2.1011600494384766
	     precision: 0.6162
	        recall: 0.5899
	            F1: 0.5808
	      accuracy: 0.5928


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 8/10, training loss: 5.451742529564196, validation loss: 2.2198352813720703
	     precision: 0.6132
	        recall: 0.5865
	            F1: 0.5709
	      accuracy: 0.5890


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 9/10, training loss: 4.7671665316042695, validation loss: 2.0691678524017334
	     precision: 0.6138
	        recall: 0.5909
	            F1: 0.5855
	      accuracy: 0.5937


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Epoch 10/10, training loss: 4.157553797182829, validation loss: 1.9686923027038574
	     precision: 0.6098
	        recall: 0.5974
	            F1: 0.5875
	      accuracy: 0.5974


## С помощью дистилляции (при таком же количестве эпох и оптимизаторе) удалось улучшить скор до 0.593 (против 0.587). Причем лосс стабильно падает и переобучения не происходит, если увеличить количество эпох, то модель обучится еще лучше. Как и следовало ожидать: одно из свойств дистилляции - регуляризация обучения.

In [0]:
student_cnn = StudentCNN().to(device)
train_val_distill(student_cnn, resnet, epochs=15, lr=1e-3, alpha=0.9, temperature=20) # L - total loss, RL - real loss, DL - distill loss

HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))