# Import

In [1]:
tensorboard_enable=False

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from sklearn import decomposition
from sklearn import manifold
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np

import copy
from collections import namedtuple
import os
import random
import shutil
import time
import argparse
from tqdm import tqdm

from model_archive.ResNet import ResNet, Config
from loss_archive.knowledge_distillation_loss import KD_loss

import argparse

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

  from .autonotebook import tqdm as notebook_tqdm


# Arguments

In [3]:
args = argparse.Namespace(
    model='resnet50',
    batch_size=16,
    lr=1.25e-4,
    epochs=100,
    root_path='/root/datasets/archive/CUB_200_2011/',
    scheduler='no',
    pretrained='no',
    teacher_path='/root/workspace/CNN_work/teacher_model/teacher_model.pt'
)

# trainer.py

In [4]:
# some functions for training / model
def normalize_image(image):
    image_min = image.min()
    image_max = image.max()
    image.clamp_(min = image_min, max = image_max)
    image.add_(-image_min).div_(image_max - image_min + 1e-5)
    return image

def calculate_topk_accuracy(y_pred, y, k = 5):
    with torch.no_grad():
        batch_size = y.shape[0]
        _, top_pred = y_pred.topk(k, 1)
        top_pred = top_pred.t()
        correct = top_pred.eq(y.view(1, -1).expand_as(top_pred))
        correct_1 = correct[:1].reshape(-1).float().sum(0, keepdim = True)
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim = True)
        acc_1 = correct_1 / batch_size
        acc_k = correct_k / batch_size
    return acc_1, acc_k

def train(model, iterator, optimizer, criterion, device, scheduler=None):
    
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_5 = 0
    
    model.train()
    
    for step, (x, y) in enumerate(iterator):
        
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()
                
        y_pred, _ = model(x)
        
        loss = criterion(y_pred, y)
        
        acc_1, acc_5 = calculate_topk_accuracy(y_pred, y)
        
        loss.backward()
        
        optimizer.step()
        
        if scheduler=='yes':
            scheduler.step()
        
        epoch_loss += loss.item()
        epoch_acc_1 += acc_1.item()
        epoch_acc_5 += acc_5.item()
        
    epoch_loss /= len(iterator)
    epoch_acc_1 /= len(iterator)
    epoch_acc_5 /= len(iterator)
        
    return epoch_loss, epoch_acc_1, epoch_acc_5

def evaluate(model, iterator, criterion, device):
    
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_5 = 0
    
    model.eval()
    
    with torch.no_grad():
        
        for (x, y) in iterator:

            x = x.to(device)
            y = y.to(device)

            y_pred, _ = model(x)

            loss = criterion(y_pred, y)

            acc_1, acc_5 = calculate_topk_accuracy(y_pred, y)

            epoch_loss += loss.item()
            epoch_acc_1 += acc_1.item()
            epoch_acc_5 += acc_5.item()
        
    epoch_loss /= len(iterator)
    epoch_acc_1 /= len(iterator)
    epoch_acc_5 /= len(iterator)
        
    return epoch_loss, epoch_acc_1, epoch_acc_5

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


## Load Dataset

In [5]:
# Dataset
pretrained_size = 224
pretrained_means = [0.485, 0.456, 0.406]
pretrained_stds= [0.229, 0.224, 0.225]
train_transforms = transforms.Compose([
                        transforms.Resize(pretrained_size),
                        transforms.RandomRotation(5),
                        transforms.RandomHorizontalFlip(0.5),
                        transforms.RandomCrop(pretrained_size, padding = 10),
                        transforms.ToTensor(),
                        transforms.Normalize(mean = pretrained_means, 
                                                std = pretrained_stds)
                    ])
test_transforms = transforms.Compose([
                        transforms.Resize(pretrained_size),
                        transforms.CenterCrop(pretrained_size),
                        transforms.ToTensor(),
                        transforms.Normalize(mean = pretrained_means, 
                                                std = pretrained_stds)
                    ])

ROOT = args.root_path
data_dir = os.path.join(ROOT, 'CUB_200_2011')
images_dir = os.path.join(data_dir, 'images')
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')

train_data = datasets.ImageFolder(root = train_dir,
                                transform = train_transforms)
test_data = datasets.ImageFolder(root = test_dir,
                                transform = test_transforms)

VALID_RATIO = 0.8
n_train_examples = int(len(train_data)*VALID_RATIO)
n_valid_examples = len(train_data) - n_train_examples

train_data, valid_data = data.random_split(train_data,
                                        [n_train_examples, n_valid_examples])
valid_data = copy.deepcopy(valid_data)
valid_data.dataset.transform = test_transforms

BATCH_SIZE = args.batch_size
train_iterator = data.DataLoader(train_data, 
                                shuffle = True, 
                                batch_size = BATCH_SIZE)
valid_iterator = data.DataLoader(valid_data, 
                                batch_size = BATCH_SIZE)
test_iterator = data.DataLoader(test_data, 
                                batch_size = BATCH_SIZE)

In [6]:
class KD_loss(nn.Module):
    def __init__(self, Temperature):
        super(KD_loss,self).__init__()
        self.T = Temperature
    
    def forward(self, outputs, labels):
        """
            input : 
                y : (gt)
                y_stu : (student output)
                y_tea : (teacher output)
            output : 
                loss (Variable) : 논문's distillation loss
        """
        default_loss = nn.CrossEntropyLoss()(y_stu,y)               # TODO How this could work? --> "default_loss"  be an insatnce carrying some needed values.
        term1 = F.softmax(torch.mul(y_tea,1/T))         # nn.functional 이 softmax의 computational graph를 지원하나?
        term2 = F.softmax(torch.mul(y_stu,1/T))
        distill_loss = T**2 * nn.CrossEntropyLoss()(term1, term2)

        loss = default_loss + distill_loss

        return loss

## Student_model, Optimizer, Criterion, Scheduler

In [7]:
# Train
if __name__ == "__main__":
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--model", type=str, default='resnet34')     # resnet18, 34, 50, 101, 152.
    # parser.add_argument("--batch_size", type=int, default=64)
    # parser.add_argument("--lr", type=float, default=1e-3)
    # parser.add_argument("--epochs", type=int, default=100)
    # parser.add_argument("--root_path", type=str, default='/home/jh/Desktop/VSC/CNN_work/archive/CUB_200_2011/')
    # parser.add_argument("--scheduler", type=str, default='no')      # yes / no
    # parser.add_argument("--pretrained", type=str, default='no')     # yes / no
    # parser.add_argument("--teacher_path", type=str, default='/home/jh/Desktop/VSC/CNN_work/teacher_model/teacher_model.pt')
    # args = parser.parse_args()

    print()
    print('<< Configurations >>')
    print(f'[*] Model       - {args.model}')
    print(f'[*] Batch_size  - {args.batch_size}')
    print(f'[*] LR          - {args.lr}')
    print(f'[*] Epochs      - {args.epochs}')

    # * Dasetset (Above)

    # get pretrained model.
    if args.model=='resnet18':
        if args.pretrained=='yes':
            model = models.resnet18(pretrained = True)
            print('[*] pre-trained model being used!')
    elif args.model=='resnet34':
        if args.pretrained=='yes':
            model = models.resnet34(pretrained = True)
            print('[*] pre-trained model being used!')
    elif args.model=='resnet50':
        if args.pretrained=='yes':
            model = models.resnet50(pretrained = True)
            print('[*] pre-trained model being used!')
    elif args.model=='resnet101':
        if args.pretrained=='yes':
            model = models.resnet101(pretrained = True)
            print('[*] pre-trained model being used!')
    elif args.model=='resnet152':
        if args.pretrained=='yes':
            model = models.resnet152(pretrained = True)
            print('[*] pre-trained model being used!')

    if args.pretrained!='yes':
        # make new model.
        print('[*] train newly initialized model!')
        config = Config()
        resnet_config = config.get_resnet_config(model_name = args.model)
        OUTPUT_DIM = len(test_data.classes)
        model = ResNet(resnet_config, OUTPUT_DIM)  # get resnetXXX

        print(f'[*] Parameters  - {count_parameters(model):,}')
    else:
        # Change FC layer in downloaded model for Transfer Learning.
        IN_FEATURES = model.fc.in_features 
        OUTPUT_DIM = len(test_data.classes)
        model.fc = nn.Linear(IN_FEATURES, OUTPUT_DIM)

        print(f'[*] Parameters  - {count_parameters(model):,}')
    
    # if args.half=='yes':
    #     model = model.half()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ! Knowledge Distillation Loss here
    criterion = KD_loss(Temperature=20)
    # criterion = nn.CrossEntropyLoss()

    if args.scheduler == 'yes':
        # cosine scheduler
        '''
        - 기존 lr overfitting지점인, 30-40 에 다다르기 전에 decay주는게 적합해보여.
        - 한번 그렇게 ``lr==0`` 까지 탐색하는것보단, hard_reset하면서 그 optima에서 빠져나와서 
          주변 다른 optima 들어가보는것도 좋지 않을까?
        - 지금 실험상황은 best 모델 찾는거고, epoch 100 가면서 어차피 튀는 경향성 보이니, 좋은 시도같은데?
        '''
        ITERATIONS = args.epochs * len(train_iterator)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=ITERATIONS, eta_min=1e-8)
    else:
        scheduler=None

    model = model.to(device)
    criterion = criterion.to(device)

    if tensorboard_enable:
        writer = SummaryWriter()


<< Configurations >>
[*] Model       - resnet50
[*] Batch_size  - 16
[*] LR          - 0.000125
[*] Epochs      - 100
[*] train newly initialized model!
[*] Parameters  - 23,917,832


# [Test] Knowledge Distillation

- teacher model 마련
- loss 구현
- train code 손보기

## Load Teacher Model

[TODO] Get Teacher Model

- Train "ResNet152" for "CUB200-2011" Dataset   ->   Get good accuracy

In [8]:
# run trainer_teacher_model.py
PATH = './teacher_model/ResNet152_bs128_lr0.001_epochs150_pretrained-yes.pt'
OUTPUT_DIM = 200    #! delete after 통합 to whole code.
teacher_model = ResNet(resnet_config, OUTPUT_DIM)
teacher_model.load_state_dict(torch.load(PATH))
teacher_model.eval()


RuntimeError: Error(s) in loading state_dict for ResNet:
	Unexpected key(s) in state_dict: "layer2.4.conv1.weight", "layer2.4.bn1.weight", "layer2.4.bn1.bias", "layer2.4.bn1.running_mean", "layer2.4.bn1.running_var", "layer2.4.bn1.num_batches_tracked", "layer2.4.conv2.weight", "layer2.4.bn2.weight", "layer2.4.bn2.bias", "layer2.4.bn2.running_mean", "layer2.4.bn2.running_var", "layer2.4.bn2.num_batches_tracked", "layer2.4.conv3.weight", "layer2.4.bn3.weight", "layer2.4.bn3.bias", "layer2.4.bn3.running_mean", "layer2.4.bn3.running_var", "layer2.4.bn3.num_batches_tracked", "layer2.5.conv1.weight", "layer2.5.bn1.weight", "layer2.5.bn1.bias", "layer2.5.bn1.running_mean", "layer2.5.bn1.running_var", "layer2.5.bn1.num_batches_tracked", "layer2.5.conv2.weight", "layer2.5.bn2.weight", "layer2.5.bn2.bias", "layer2.5.bn2.running_mean", "layer2.5.bn2.running_var", "layer2.5.bn2.num_batches_tracked", "layer2.5.conv3.weight", "layer2.5.bn3.weight", "layer2.5.bn3.bias", "layer2.5.bn3.running_mean", "layer2.5.bn3.running_var", "layer2.5.bn3.num_batches_tracked", "layer2.6.conv1.weight", "layer2.6.bn1.weight", "layer2.6.bn1.bias", "layer2.6.bn1.running_mean", "layer2.6.bn1.running_var", "layer2.6.bn1.num_batches_tracked", "layer2.6.conv2.weight", "layer2.6.bn2.weight", "layer2.6.bn2.bias", "layer2.6.bn2.running_mean", "layer2.6.bn2.running_var", "layer2.6.bn2.num_batches_tracked", "layer2.6.conv3.weight", "layer2.6.bn3.weight", "layer2.6.bn3.bias", "layer2.6.bn3.running_mean", "layer2.6.bn3.running_var", "layer2.6.bn3.num_batches_tracked", "layer2.7.conv1.weight", "layer2.7.bn1.weight", "layer2.7.bn1.bias", "layer2.7.bn1.running_mean", "layer2.7.bn1.running_var", "layer2.7.bn1.num_batches_tracked", "layer2.7.conv2.weight", "layer2.7.bn2.weight", "layer2.7.bn2.bias", "layer2.7.bn2.running_mean", "layer2.7.bn2.running_var", "layer2.7.bn2.num_batches_tracked", "layer2.7.conv3.weight", "layer2.7.bn3.weight", "layer2.7.bn3.bias", "layer2.7.bn3.running_mean", "layer2.7.bn3.running_var", "layer2.7.bn3.num_batches_tracked", "layer3.6.conv1.weight", "layer3.6.bn1.weight", "layer3.6.bn1.bias", "layer3.6.bn1.running_mean", "layer3.6.bn1.running_var", "layer3.6.bn1.num_batches_tracked", "layer3.6.conv2.weight", "layer3.6.bn2.weight", "layer3.6.bn2.bias", "layer3.6.bn2.running_mean", "layer3.6.bn2.running_var", "layer3.6.bn2.num_batches_tracked", "layer3.6.conv3.weight", "layer3.6.bn3.weight", "layer3.6.bn3.bias", "layer3.6.bn3.running_mean", "layer3.6.bn3.running_var", "layer3.6.bn3.num_batches_tracked", "layer3.7.conv1.weight", "layer3.7.bn1.weight", "layer3.7.bn1.bias", "layer3.7.bn1.running_mean", "layer3.7.bn1.running_var", "layer3.7.bn1.num_batches_tracked", "layer3.7.conv2.weight", "layer3.7.bn2.weight", "layer3.7.bn2.bias", "layer3.7.bn2.running_mean", "layer3.7.bn2.running_var", "layer3.7.bn2.num_batches_tracked", "layer3.7.conv3.weight", "layer3.7.bn3.weight", "layer3.7.bn3.bias", "layer3.7.bn3.running_mean", "layer3.7.bn3.running_var", "layer3.7.bn3.num_batches_tracked", "layer3.8.conv1.weight", "layer3.8.bn1.weight", "layer3.8.bn1.bias", "layer3.8.bn1.running_mean", "layer3.8.bn1.running_var", "layer3.8.bn1.num_batches_tracked", "layer3.8.conv2.weight", "layer3.8.bn2.weight", "layer3.8.bn2.bias", "layer3.8.bn2.running_mean", "layer3.8.bn2.running_var", "layer3.8.bn2.num_batches_tracked", "layer3.8.conv3.weight", "layer3.8.bn3.weight", "layer3.8.bn3.bias", "layer3.8.bn3.running_mean", "layer3.8.bn3.running_var", "layer3.8.bn3.num_batches_tracked", "layer3.9.conv1.weight", "layer3.9.bn1.weight", "layer3.9.bn1.bias", "layer3.9.bn1.running_mean", "layer3.9.bn1.running_var", "layer3.9.bn1.num_batches_tracked", "layer3.9.conv2.weight", "layer3.9.bn2.weight", "layer3.9.bn2.bias", "layer3.9.bn2.running_mean", "layer3.9.bn2.running_var", "layer3.9.bn2.num_batches_tracked", "layer3.9.conv3.weight", "layer3.9.bn3.weight", "layer3.9.bn3.bias", "layer3.9.bn3.running_mean", "layer3.9.bn3.running_var", "layer3.9.bn3.num_batches_tracked", "layer3.10.conv1.weight", "layer3.10.bn1.weight", "layer3.10.bn1.bias", "layer3.10.bn1.running_mean", "layer3.10.bn1.running_var", "layer3.10.bn1.num_batches_tracked", "layer3.10.conv2.weight", "layer3.10.bn2.weight", "layer3.10.bn2.bias", "layer3.10.bn2.running_mean", "layer3.10.bn2.running_var", "layer3.10.bn2.num_batches_tracked", "layer3.10.conv3.weight", "layer3.10.bn3.weight", "layer3.10.bn3.bias", "layer3.10.bn3.running_mean", "layer3.10.bn3.running_var", "layer3.10.bn3.num_batches_tracked", "layer3.11.conv1.weight", "layer3.11.bn1.weight", "layer3.11.bn1.bias", "layer3.11.bn1.running_mean", "layer3.11.bn1.running_var", "layer3.11.bn1.num_batches_tracked", "layer3.11.conv2.weight", "layer3.11.bn2.weight", "layer3.11.bn2.bias", "layer3.11.bn2.running_mean", "layer3.11.bn2.running_var", "layer3.11.bn2.num_batches_tracked", "layer3.11.conv3.weight", "layer3.11.bn3.weight", "layer3.11.bn3.bias", "layer3.11.bn3.running_mean", "layer3.11.bn3.running_var", "layer3.11.bn3.num_batches_tracked", "layer3.12.conv1.weight", "layer3.12.bn1.weight", "layer3.12.bn1.bias", "layer3.12.bn1.running_mean", "layer3.12.bn1.running_var", "layer3.12.bn1.num_batches_tracked", "layer3.12.conv2.weight", "layer3.12.bn2.weight", "layer3.12.bn2.bias", "layer3.12.bn2.running_mean", "layer3.12.bn2.running_var", "layer3.12.bn2.num_batches_tracked", "layer3.12.conv3.weight", "layer3.12.bn3.weight", "layer3.12.bn3.bias", "layer3.12.bn3.running_mean", "layer3.12.bn3.running_var", "layer3.12.bn3.num_batches_tracked", "layer3.13.conv1.weight", "layer3.13.bn1.weight", "layer3.13.bn1.bias", "layer3.13.bn1.running_mean", "layer3.13.bn1.running_var", "layer3.13.bn1.num_batches_tracked", "layer3.13.conv2.weight", "layer3.13.bn2.weight", "layer3.13.bn2.bias", "layer3.13.bn2.running_mean", "layer3.13.bn2.running_var", "layer3.13.bn2.num_batches_tracked", "layer3.13.conv3.weight", "layer3.13.bn3.weight", "layer3.13.bn3.bias", "layer3.13.bn3.running_mean", "layer3.13.bn3.running_var", "layer3.13.bn3.num_batches_tracked", "layer3.14.conv1.weight", "layer3.14.bn1.weight", "layer3.14.bn1.bias", "layer3.14.bn1.running_mean", "layer3.14.bn1.running_var", "layer3.14.bn1.num_batches_tracked", "layer3.14.conv2.weight", "layer3.14.bn2.weight", "layer3.14.bn2.bias", "layer3.14.bn2.running_mean", "layer3.14.bn2.running_var", "layer3.14.bn2.num_batches_tracked", "layer3.14.conv3.weight", "layer3.14.bn3.weight", "layer3.14.bn3.bias", "layer3.14.bn3.running_mean", "layer3.14.bn3.running_var", "layer3.14.bn3.num_batches_tracked", "layer3.15.conv1.weight", "layer3.15.bn1.weight", "layer3.15.bn1.bias", "layer3.15.bn1.running_mean", "layer3.15.bn1.running_var", "layer3.15.bn1.num_batches_tracked", "layer3.15.conv2.weight", "layer3.15.bn2.weight", "layer3.15.bn2.bias", "layer3.15.bn2.running_mean", "layer3.15.bn2.running_var", "layer3.15.bn2.num_batches_tracked", "layer3.15.conv3.weight", "layer3.15.bn3.weight", "layer3.15.bn3.bias", "layer3.15.bn3.running_mean", "layer3.15.bn3.running_var", "layer3.15.bn3.num_batches_tracked", "layer3.16.conv1.weight", "layer3.16.bn1.weight", "layer3.16.bn1.bias", "layer3.16.bn1.running_mean", "layer3.16.bn1.running_var", "layer3.16.bn1.num_batches_tracked", "layer3.16.conv2.weight", "layer3.16.bn2.weight", "layer3.16.bn2.bias", "layer3.16.bn2.running_mean", "layer3.16.bn2.running_var", "layer3.16.bn2.num_batches_tracked", "layer3.16.conv3.weight", "layer3.16.bn3.weight", "layer3.16.bn3.bias", "layer3.16.bn3.running_mean", "layer3.16.bn3.running_var", "layer3.16.bn3.num_batches_tracked", "layer3.17.conv1.weight", "layer3.17.bn1.weight", "layer3.17.bn1.bias", "layer3.17.bn1.running_mean", "layer3.17.bn1.running_var", "layer3.17.bn1.num_batches_tracked", "layer3.17.conv2.weight", "layer3.17.bn2.weight", "layer3.17.bn2.bias", "layer3.17.bn2.running_mean", "layer3.17.bn2.running_var", "layer3.17.bn2.num_batches_tracked", "layer3.17.conv3.weight", "layer3.17.bn3.weight", "layer3.17.bn3.bias", "layer3.17.bn3.running_mean", "layer3.17.bn3.running_var", "layer3.17.bn3.num_batches_tracked", "layer3.18.conv1.weight", "layer3.18.bn1.weight", "layer3.18.bn1.bias", "layer3.18.bn1.running_mean", "layer3.18.bn1.running_var", "layer3.18.bn1.num_batches_tracked", "layer3.18.conv2.weight", "layer3.18.bn2.weight", "layer3.18.bn2.bias", "layer3.18.bn2.running_mean", "layer3.18.bn2.running_var", "layer3.18.bn2.num_batches_tracked", "layer3.18.conv3.weight", "layer3.18.bn3.weight", "layer3.18.bn3.bias", "layer3.18.bn3.running_mean", "layer3.18.bn3.running_var", "layer3.18.bn3.num_batches_tracked", "layer3.19.conv1.weight", "layer3.19.bn1.weight", "layer3.19.bn1.bias", "layer3.19.bn1.running_mean", "layer3.19.bn1.running_var", "layer3.19.bn1.num_batches_tracked", "layer3.19.conv2.weight", "layer3.19.bn2.weight", "layer3.19.bn2.bias", "layer3.19.bn2.running_mean", "layer3.19.bn2.running_var", "layer3.19.bn2.num_batches_tracked", "layer3.19.conv3.weight", "layer3.19.bn3.weight", "layer3.19.bn3.bias", "layer3.19.bn3.running_mean", "layer3.19.bn3.running_var", "layer3.19.bn3.num_batches_tracked", "layer3.20.conv1.weight", "layer3.20.bn1.weight", "layer3.20.bn1.bias", "layer3.20.bn1.running_mean", "layer3.20.bn1.running_var", "layer3.20.bn1.num_batches_tracked", "layer3.20.conv2.weight", "layer3.20.bn2.weight", "layer3.20.bn2.bias", "layer3.20.bn2.running_mean", "layer3.20.bn2.running_var", "layer3.20.bn2.num_batches_tracked", "layer3.20.conv3.weight", "layer3.20.bn3.weight", "layer3.20.bn3.bias", "layer3.20.bn3.running_mean", "layer3.20.bn3.running_var", "layer3.20.bn3.num_batches_tracked", "layer3.21.conv1.weight", "layer3.21.bn1.weight", "layer3.21.bn1.bias", "layer3.21.bn1.running_mean", "layer3.21.bn1.running_var", "layer3.21.bn1.num_batches_tracked", "layer3.21.conv2.weight", "layer3.21.bn2.weight", "layer3.21.bn2.bias", "layer3.21.bn2.running_mean", "layer3.21.bn2.running_var", "layer3.21.bn2.num_batches_tracked", "layer3.21.conv3.weight", "layer3.21.bn3.weight", "layer3.21.bn3.bias", "layer3.21.bn3.running_mean", "layer3.21.bn3.running_var", "layer3.21.bn3.num_batches_tracked", "layer3.22.conv1.weight", "layer3.22.bn1.weight", "layer3.22.bn1.bias", "layer3.22.bn1.running_mean", "layer3.22.bn1.running_var", "layer3.22.bn1.num_batches_tracked", "layer3.22.conv2.weight", "layer3.22.bn2.weight", "layer3.22.bn2.bias", "layer3.22.bn2.running_mean", "layer3.22.bn2.running_var", "layer3.22.bn2.num_batches_tracked", "layer3.22.conv3.weight", "layer3.22.bn3.weight", "layer3.22.bn3.bias", "layer3.22.bn3.running_mean", "layer3.22.bn3.running_var", "layer3.22.bn3.num_batches_tracked", "layer3.23.conv1.weight", "layer3.23.bn1.weight", "layer3.23.bn1.bias", "layer3.23.bn1.running_mean", "layer3.23.bn1.running_var", "layer3.23.bn1.num_batches_tracked", "layer3.23.conv2.weight", "layer3.23.bn2.weight", "layer3.23.bn2.bias", "layer3.23.bn2.running_mean", "layer3.23.bn2.running_var", "layer3.23.bn2.num_batches_tracked", "layer3.23.conv3.weight", "layer3.23.bn3.weight", "layer3.23.bn3.bias", "layer3.23.bn3.running_mean", "layer3.23.bn3.running_var", "layer3.23.bn3.num_batches_tracked", "layer3.24.conv1.weight", "layer3.24.bn1.weight", "layer3.24.bn1.bias", "layer3.24.bn1.running_mean", "layer3.24.bn1.running_var", "layer3.24.bn1.num_batches_tracked", "layer3.24.conv2.weight", "layer3.24.bn2.weight", "layer3.24.bn2.bias", "layer3.24.bn2.running_mean", "layer3.24.bn2.running_var", "layer3.24.bn2.num_batches_tracked", "layer3.24.conv3.weight", "layer3.24.bn3.weight", "layer3.24.bn3.bias", "layer3.24.bn3.running_mean", "layer3.24.bn3.running_var", "layer3.24.bn3.num_batches_tracked", "layer3.25.conv1.weight", "layer3.25.bn1.weight", "layer3.25.bn1.bias", "layer3.25.bn1.running_mean", "layer3.25.bn1.running_var", "layer3.25.bn1.num_batches_tracked", "layer3.25.conv2.weight", "layer3.25.bn2.weight", "layer3.25.bn2.bias", "layer3.25.bn2.running_mean", "layer3.25.bn2.running_var", "layer3.25.bn2.num_batches_tracked", "layer3.25.conv3.weight", "layer3.25.bn3.weight", "layer3.25.bn3.bias", "layer3.25.bn3.running_mean", "layer3.25.bn3.running_var", "layer3.25.bn3.num_batches_tracked", "layer3.26.conv1.weight", "layer3.26.bn1.weight", "layer3.26.bn1.bias", "layer3.26.bn1.running_mean", "layer3.26.bn1.running_var", "layer3.26.bn1.num_batches_tracked", "layer3.26.conv2.weight", "layer3.26.bn2.weight", "layer3.26.bn2.bias", "layer3.26.bn2.running_mean", "layer3.26.bn2.running_var", "layer3.26.bn2.num_batches_tracked", "layer3.26.conv3.weight", "layer3.26.bn3.weight", "layer3.26.bn3.bias", "layer3.26.bn3.running_mean", "layer3.26.bn3.running_var", "layer3.26.bn3.num_batches_tracked", "layer3.27.conv1.weight", "layer3.27.bn1.weight", "layer3.27.bn1.bias", "layer3.27.bn1.running_mean", "layer3.27.bn1.running_var", "layer3.27.bn1.num_batches_tracked", "layer3.27.conv2.weight", "layer3.27.bn2.weight", "layer3.27.bn2.bias", "layer3.27.bn2.running_mean", "layer3.27.bn2.running_var", "layer3.27.bn2.num_batches_tracked", "layer3.27.conv3.weight", "layer3.27.bn3.weight", "layer3.27.bn3.bias", "layer3.27.bn3.running_mean", "layer3.27.bn3.running_var", "layer3.27.bn3.num_batches_tracked", "layer3.28.conv1.weight", "layer3.28.bn1.weight", "layer3.28.bn1.bias", "layer3.28.bn1.running_mean", "layer3.28.bn1.running_var", "layer3.28.bn1.num_batches_tracked", "layer3.28.conv2.weight", "layer3.28.bn2.weight", "layer3.28.bn2.bias", "layer3.28.bn2.running_mean", "layer3.28.bn2.running_var", "layer3.28.bn2.num_batches_tracked", "layer3.28.conv3.weight", "layer3.28.bn3.weight", "layer3.28.bn3.bias", "layer3.28.bn3.running_mean", "layer3.28.bn3.running_var", "layer3.28.bn3.num_batches_tracked", "layer3.29.conv1.weight", "layer3.29.bn1.weight", "layer3.29.bn1.bias", "layer3.29.bn1.running_mean", "layer3.29.bn1.running_var", "layer3.29.bn1.num_batches_tracked", "layer3.29.conv2.weight", "layer3.29.bn2.weight", "layer3.29.bn2.bias", "layer3.29.bn2.running_mean", "layer3.29.bn2.running_var", "layer3.29.bn2.num_batches_tracked", "layer3.29.conv3.weight", "layer3.29.bn3.weight", "layer3.29.bn3.bias", "layer3.29.bn3.running_mean", "layer3.29.bn3.running_var", "layer3.29.bn3.num_batches_tracked", "layer3.30.conv1.weight", "layer3.30.bn1.weight", "layer3.30.bn1.bias", "layer3.30.bn1.running_mean", "layer3.30.bn1.running_var", "layer3.30.bn1.num_batches_tracked", "layer3.30.conv2.weight", "layer3.30.bn2.weight", "layer3.30.bn2.bias", "layer3.30.bn2.running_mean", "layer3.30.bn2.running_var", "layer3.30.bn2.num_batches_tracked", "layer3.30.conv3.weight", "layer3.30.bn3.weight", "layer3.30.bn3.bias", "layer3.30.bn3.running_mean", "layer3.30.bn3.running_var", "layer3.30.bn3.num_batches_tracked", "layer3.31.conv1.weight", "layer3.31.bn1.weight", "layer3.31.bn1.bias", "layer3.31.bn1.running_mean", "layer3.31.bn1.running_var", "layer3.31.bn1.num_batches_tracked", "layer3.31.conv2.weight", "layer3.31.bn2.weight", "layer3.31.bn2.bias", "layer3.31.bn2.running_mean", "layer3.31.bn2.running_var", "layer3.31.bn2.num_batches_tracked", "layer3.31.conv3.weight", "layer3.31.bn3.weight", "layer3.31.bn3.bias", "layer3.31.bn3.running_mean", "layer3.31.bn3.running_var", "layer3.31.bn3.num_batches_tracked", "layer3.32.conv1.weight", "layer3.32.bn1.weight", "layer3.32.bn1.bias", "layer3.32.bn1.running_mean", "layer3.32.bn1.running_var", "layer3.32.bn1.num_batches_tracked", "layer3.32.conv2.weight", "layer3.32.bn2.weight", "layer3.32.bn2.bias", "layer3.32.bn2.running_mean", "layer3.32.bn2.running_var", "layer3.32.bn2.num_batches_tracked", "layer3.32.conv3.weight", "layer3.32.bn3.weight", "layer3.32.bn3.bias", "layer3.32.bn3.running_mean", "layer3.32.bn3.running_var", "layer3.32.bn3.num_batches_tracked", "layer3.33.conv1.weight", "layer3.33.bn1.weight", "layer3.33.bn1.bias", "layer3.33.bn1.running_mean", "layer3.33.bn1.running_var", "layer3.33.bn1.num_batches_tracked", "layer3.33.conv2.weight", "layer3.33.bn2.weight", "layer3.33.bn2.bias", "layer3.33.bn2.running_mean", "layer3.33.bn2.running_var", "layer3.33.bn2.num_batches_tracked", "layer3.33.conv3.weight", "layer3.33.bn3.weight", "layer3.33.bn3.bias", "layer3.33.bn3.running_mean", "layer3.33.bn3.running_var", "layer3.33.bn3.num_batches_tracked", "layer3.34.conv1.weight", "layer3.34.bn1.weight", "layer3.34.bn1.bias", "layer3.34.bn1.running_mean", "layer3.34.bn1.running_var", "layer3.34.bn1.num_batches_tracked", "layer3.34.conv2.weight", "layer3.34.bn2.weight", "layer3.34.bn2.bias", "layer3.34.bn2.running_mean", "layer3.34.bn2.running_var", "layer3.34.bn2.num_batches_tracked", "layer3.34.conv3.weight", "layer3.34.bn3.weight", "layer3.34.bn3.bias", "layer3.34.bn3.running_mean", "layer3.34.bn3.running_var", "layer3.34.bn3.num_batches_tracked", "layer3.35.conv1.weight", "layer3.35.bn1.weight", "layer3.35.bn1.bias", "layer3.35.bn1.running_mean", "layer3.35.bn1.running_var", "layer3.35.bn1.num_batches_tracked", "layer3.35.conv2.weight", "layer3.35.bn2.weight", "layer3.35.bn2.bias", "layer3.35.bn2.running_mean", "layer3.35.bn2.running_var", "layer3.35.bn2.num_batches_tracked", "layer3.35.conv3.weight", "layer3.35.bn3.weight", "layer3.35.bn3.bias", "layer3.35.bn3.running_mean", "layer3.35.bn3.running_var", "layer3.35.bn3.num_batches_tracked". 

Load my teacher model

In [None]:
# load model architecture
teacher_model = models.resnet152(pretrained = False)

IN_FEATURES = downloaded_model.fc.in_features 
OUTPUT_DIM = len(test_data.classes)
teacher_model.fc = nn.Linear(IN_FEATURES, OUTPUT_DIM)

print(f'[*] Teacher Model Parameters  - {count_parameters(teacher_model):,}')

# load trained model from ./teacher_model/teacher_model.pt
teacher_model.load_state_dict(torch.load(args.teacher_path))

## Distillation Loss

##### [Solving]  
I am not sure that this function properly do an backpropagation or grad storing stuffs. I need to check it before real training.

In [138]:
from loss_archive.knowledge_distillation_loss import KD_loss

import torch
import torch.nn as nn
import torch.nn.functional as F

class KD_loss(nn.Module):
    def __init__(self, Temperature):
        super(KD_loss,self).__init__()
        self.T = Temperature
    
    def forward(self, outputs, labels):
        """
            input : 
                y : (gt)
                y_stu : (student output)
                y_tea : (teacher output)
            output : 
                loss (Variable) : 논문's distillation loss
        """
        default_loss = nn.CrossEntropyLoss()(y_stu,y)               # TODO How this could work? --> "default_loss"  be an insatnce carrying some needed values.
        term1 = F.softmax(torch.mul(y_tea,1/T))         # nn.functional 이 softmax의 computational graph를 지원하나?
        term2 = F.softmax(torch.mul(y_stu,1/T))
        distill_loss = T**2 * nn.CrossEntropyLoss()(term1, term2)

        loss = default_loss + distill_loss

        return loss

In [139]:
# criterion = KD_loss(Temperature=0.1)

functional loss testing.  

but don't know why it is needed.

In [133]:
import torch.nn.functional as F

def kd_loss_functional(y, y_stu, y_tea, T):
    default_loss = nn.CrossEntropyLoss()(y_stu,y)               # TODO How this could work? --> "default_loss"  be an insatnce carrying some needed values.
    term1 = F.softmax(torch.mul(y_tea,1/T))         # nn.functional 이 softmax의 computational graph를 지원하나?
    term2 = F.softmax(torch.mul(y_stu,1/T))
    distill_loss = T**2 * nn.CrossEntropyLoss()(term1, term2)

    loss = default_loss + distill_loss

    return loss

In [None]:
kd_loss_functional()

Q. The difference btw nn.functional.crossentropy() and nn.CrossEntropyLoss()  
A. nn.functional.crossentropy() doesn't exists.

In [121]:
a=torch.Tensor([[0,0,0],[1,1,1]])
b=torch.Tensor([[1,1,1],[0,0,0]])

torch.sum(a,b)

TypeError: sum() received an invalid combination of arguments - got (Tensor, Tensor), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)


## Train Code 손보기

# Train

In [None]:
# Model training.
best_valid_loss = float('inf')
best_valid_epoch = 0

print('[*] Start Training !', end='\n\n')
for epoch in range(args.epochs):
    start_time = time.monotonic()
    
    train_loss, train_acc_1, train_acc_5 = train(model, train_iterator, optimizer, criterion, device, scheduler)
    valid_loss, valid_acc_1, valid_acc_5 = evaluate(model, valid_iterator, criterion, device)

    if tensorboard_enable:
        writer.add_scalar("loss/train", train_loss, epoch)  # tensorboard
        writer.add_scalar("loss/val", valid_loss, epoch)    # tensorboard

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_valid_epoch = epoch
        torch.save(model.state_dict(), f'./saved/{args.model}_bs{args.batch_size}_lr{args.lr}_epochs{args.epochs}_pretrained-{args.pretrained}.pt')
    
    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc @1: {train_acc_1*100:6.2f}% | ' \
        f'Train Acc @5: {train_acc_5*100:6.2f}%')
    print(f'\tValid Loss: {valid_loss:.3f} | Valid Acc @1: {valid_acc_1*100:6.2f}% | ' \
        f'Valid Acc @5: {valid_acc_5*100:6.2f}%')

print()
print(f"Best valid epoch : {best_valid_epoch}/{args.epochs} epochs")

if tensorboard_enbale:
    writer.flush()