In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="1"


import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms


import utils
import random
import numpy as np
from tqdm import tqdm

teacher = torchvision.models.vit_b_16(pretrained = True)
student = torchvision.models.vit_b_16(pretrained = True)
# model = torchvision.models.vit_l_32(pretrained=True)

print(f"patch_size : {student.patch_size}")
print(f"image_size : {student.image_size}")
print(f"hidden_dim : {student.hidden_dim}")

use_trained_model = False

# 현재 testset과 trainset의 normalize가 다르다. 그런데 student model의 성능이 teacher model에 비해 월등하다는 것은 domain invariant하다는 것이 아닌가?

normalize = transforms.Normalize(mean=[0.5074,0.4867,0.4411],
                                 std=[0.2011,0.1987,0.2025])

# transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
# ])

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])


transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
])

dataset_val = torchvision.datasets.CIFAR100(root="./", train=False, transform=transform, download=True)
dataset_train = torchvision.datasets.CIFAR100(root="./", train=True, transform=transform_train, download=True)

val_loader = torch.utils.data.DataLoader(dataset_val,
                                          batch_size=128,
                                          shuffle=True,
                                          num_workers=4)

train_loader = torch.utils.data.DataLoader(dataset_train,
                                          batch_size=128,
                                          shuffle=True,
                                          num_workers=4)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


patch_size : 16
image_size : 224
hidden_dim : 768
Files already downloaded and verified
Files already downloaded and verified


In [2]:
if use_trained_model:
    teacher = torch.load("CIFAR100_vit_b_16_71.pth")
else:
    teacher = utils.VisionTransformer(teacher)
    teacher.heads.head = nn.Linear(768, 100)
                         
student = utils.VisionTransformer(student)
student.heads.head = nn.Linear(768, 100)



S_optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
T_optimizer = optim.SGD(teacher.parameters(), lr=0.01, momentum=0.9)
CE_loss = nn.CrossEntropyLoss()

S_scheduler = torch.optim.lr_scheduler.MultiStepLR(S_optimizer, milestones=[1,2,3,4,5,6,7], gamma=0.1)
T_scheduler = torch.optim.lr_scheduler.MultiStepLR(T_optimizer, milestones=[1,2,3,4,5,6,7], gamma=0.1)

student = student.cuda()
teacher = teacher.cuda()



In [None]:


if not use_trained_model:
    criterion_onlylabel = lambda a,b : mse(a*b, b)
    best_acc = 0.0
    stack = 0 
    criterion_CE = nn.CrossEntropyLoss()
    mse = nn.MSELoss()
    softmax = torch.nn.Softmax(dim = 1)
    criterion_KLD = torch.nn.KLDivLoss(reduction="batchmean")
    criterion_response = lambda a,b : criterion_KLD(torch.log_softmax(a, dim=1),torch.softmax(b, dim=1))




    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(0)


    for epoch in range(100):
        print(f"lr : {S_scheduler.get_last_lr()}")
        T_correct = 0
        S_correct = 0
        all_data = 0

        loss_distill = []
        loss_CE = []
        loss_response = []
        student = student.train()
        for img, label in tqdm(train_loader):
            img = img.cuda()
            label = label.cuda()

            T_optimizer.zero_grad()
            output, _ = teacher(img, 0)

#             output = softmax(output)
            loss = criterion_CE(output, label)
            loss.backward()
            T_optimizer.step()
            
        T_scheduler.step()
        test_acc = utils.test(teacher, val_loader, epoch)


        if test_acc > best_acc:
            stack = 0
            best_acc = test_acc
        else:
            stack+=1
        past_acc = test_acc

        if stack > 5:
            S_scheduler.step()
            stack = 0

        print("=" * 100)        

lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.81it/s]


0 	 test acc : 0.7590000033378601
lr : [0.01]


100%|██████████| 391/391 [03:29<00:00,  1.86it/s]
100%|██████████| 79/79 [00:16<00:00,  4.80it/s]


1 	 test acc : 0.8314999938011169
lr : [0.01]


100%|██████████| 391/391 [03:29<00:00,  1.86it/s]
100%|██████████| 79/79 [00:16<00:00,  4.81it/s]


2 	 test acc : 0.8355000019073486
lr : [0.01]


100%|██████████| 391/391 [03:29<00:00,  1.86it/s]
100%|██████████| 79/79 [00:16<00:00,  4.76it/s]


3 	 test acc : 0.835599958896637
lr : [0.01]


100%|██████████| 391/391 [03:29<00:00,  1.87it/s]
100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


4 	 test acc : 0.8356999754905701
lr : [0.01]


100%|██████████| 391/391 [03:29<00:00,  1.87it/s]
100%|██████████| 79/79 [00:16<00:00,  4.80it/s]


5 	 test acc : 0.8356999754905701
lr : [0.01]


  1%|          | 4/391 [00:02<03:56,  1.63it/s]

In [3]:


if not use_trained_model:
    criterion_onlylabel = lambda a,b : mse(a*b, b)
    best_acc = 0.0
    stack = 0 
    criterion_CE = nn.CrossEntropyLoss()
    mse = nn.MSELoss()
    softmax = torch.nn.Softmax(dim = 1)
    criterion_KLD = torch.nn.KLDivLoss(reduction="batchmean")
    criterion_response = lambda a,b : criterion_KLD(torch.log_softmax(a, dim=1),torch.softmax(b, dim=1))




    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(0)


    for epoch in range(100):
        print(f"lr : {S_scheduler.get_last_lr()}")
        T_correct = 0
        S_correct = 0
        all_data = 0

        loss_distill = []
        loss_CE = []
        loss_response = []
        student = student.train()
        for img, label in tqdm(train_loader):
            img = img.cuda()
            label = label.cuda()

            T_optimizer.zero_grad()
            output, _ = teacher(img, 0)

#             output = softmax(output)
            loss = criterion_CE(output, label)
            loss.backward()
            T_optimizer.step()
            
        T_scheduler.step()
        test_acc = utils.test(teacher, val_loader, epoch)


        if test_acc > best_acc:
            stack = 0
            best_acc = test_acc
        else:
            stack+=1
        past_acc = test_acc

        if stack > 5:
            S_scheduler.step()
            stack = 0

        print("=" * 100)        

lr : [0.01]


100%|██████████| 391/391 [03:25<00:00,  1.90it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


0 	 test acc : 0.5884999632835388
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.91it/s]


1 	 test acc : 0.7601000070571899
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


2 	 test acc : 0.7750999927520752
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


3 	 test acc : 0.7770999670028687
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


4 	 test acc : 0.7773000001907349
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


5 	 test acc : 0.7771999835968018
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


6 	 test acc : 0.7773000001907349
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


7 	 test acc : 0.7771999835968018
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


8 	 test acc : 0.7771999835968018
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


9 	 test acc : 0.7771999835968018
lr : [0.01]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


10 	 test acc : 0.7773000001907349
lr : [0.001]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


11 	 test acc : 0.7771999835968018
lr : [0.001]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


12 	 test acc : 0.7773000001907349
lr : [0.001]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


13 	 test acc : 0.7773000001907349
lr : [0.001]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


14 	 test acc : 0.7771999835968018
lr : [0.001]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


15 	 test acc : 0.7773000001907349
lr : [0.001]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


16 	 test acc : 0.7773000001907349
lr : [0.0001]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


17 	 test acc : 0.7773000001907349
lr : [0.0001]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


18 	 test acc : 0.7771999835968018
lr : [0.0001]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


19 	 test acc : 0.7773000001907349
lr : [0.0001]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


20 	 test acc : 0.7771999835968018
lr : [0.0001]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


21 	 test acc : 0.7773000001907349
lr : [0.0001]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.91it/s]


22 	 test acc : 0.7771999835968018
lr : [1e-05]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


23 	 test acc : 0.7771999835968018
lr : [1e-05]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


24 	 test acc : 0.7773000001907349
lr : [1e-05]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


25 	 test acc : 0.7771999835968018
lr : [1e-05]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


26 	 test acc : 0.7771999835968018
lr : [1e-05]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


27 	 test acc : 0.7773000001907349
lr : [1e-05]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


28 	 test acc : 0.7773000001907349
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


29 	 test acc : 0.7771999835968018
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


30 	 test acc : 0.7771999835968018
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


31 	 test acc : 0.7773000001907349
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


32 	 test acc : 0.7770999670028687
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


33 	 test acc : 0.7773000001907349
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


34 	 test acc : 0.7773000001907349
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


35 	 test acc : 0.7773000001907349
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


36 	 test acc : 0.7771999835968018
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


37 	 test acc : 0.7773000001907349
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


38 	 test acc : 0.7773000001907349
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


39 	 test acc : 0.7773000001907349
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


40 	 test acc : 0.7771999835968018
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


41 	 test acc : 0.7771999835968018
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


42 	 test acc : 0.7771999835968018
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


43 	 test acc : 0.7771999835968018
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


44 	 test acc : 0.7773000001907349
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


45 	 test acc : 0.7773000001907349
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


46 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


47 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


48 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


49 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


50 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


51 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


52 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


53 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


54 	 test acc : 0.7773999571800232
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


55 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


56 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


57 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


58 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


59 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


60 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


61 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


62 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


63 	 test acc : 0.7770999670028687
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


64 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


65 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


66 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


67 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


68 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


69 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


70 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


71 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


72 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


73 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


74 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


75 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


76 	 test acc : 0.7773999571800232
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


77 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


78 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


79 	 test acc : 0.7773999571800232
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


80 	 test acc : 0.7773999571800232
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


81 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.89it/s]
100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


82 	 test acc : 0.7773999571800232
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


83 	 test acc : 0.7773999571800232
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


84 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


85 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


86 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


87 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


88 	 test acc : 0.7773000001907349
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


89 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


90 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


91 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


92 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


93 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


94 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


95 	 test acc : 0.7773999571800232
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


96 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


97 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


98 	 test acc : 0.7771999835968018
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [03:27<00:00,  1.88it/s]
100%|██████████| 79/79 [00:16<00:00,  4.87it/s]

99 	 test acc : 0.7771999835968018





In [4]:
utils.test(teacher, val_loader)

100%|██████████| 79/79 [00:16<00:00,  4.91it/s]

0 	 test acc : 0.7110999822616577





tensor(0.7111, device='cuda:0')

In [5]:
import numpy as np

criterion_onlylabel = lambda a,b : mse(a*b, b)

criterion_CE = nn.CrossEntropyLoss()
mse = nn.MSELoss()
softmax = torch.nn.Softmax(dim = 1)
criterion_KLD = torch.nn.KLDivLoss(reduction="batchmean")
criterion_response = lambda a,b : criterion_KLD(torch.log_softmax(a, dim=1),torch.softmax(b, dim=1))

    
import torch
import random
import torch.backends.cudnn as cudnn
import numpy as np

torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(0)
past_acc = 0.0
stack = 0  
for epoch in range(100):
    print(f"lr : {S_scheduler.get_last_lr()}")
    T_correct = 0
    S_correct = 0
    all_data = 0
    
    loss_distill = []
    loss_CE = []
    loss_response = []
    student = student.train()
    for img, label in tqdm(train_loader):
        input_data = img.cuda()
        label = label.cuda()
        
        
        all_data += len(input_data)
        input_lrp = utils.get_LRP_img(input_data, label, teacher, criterion_CE, T_optimizer, mean=0.9, std = 0.4).cuda()
        
        S_optimizer.zero_grad()
        T_optimizer.zero_grad()

        layer = random.randint(0,  3+len(teacher.encoder.layers))
        input_data, fk = student(input_data,layer)
        input_lrp, fk_lrp = teacher(input_lrp,layer)
        
        distill_loss = mse(fk, fk_lrp)
                    
        CE_loss = criterion_CE(input_data, label)
        
#         response_loss = criterion_response(input_data, input_lrp)
        
        T_correct += sum(label == torch.argmax(input_lrp, dim=1))
        S_correct += sum(label == torch.argmax(input_data, dim=1))
        
        loss_CE.append(CE_loss.item())
        loss_distill.append(distill_loss.item())

        loss = (distill_loss * 2 + CE_loss) / 2
        loss.backward()
        S_optimizer.step()
    print("distill loss : ", sum(loss_distill) / len(loss_distill))
    print("general loss : ", sum(loss_CE) / len(loss_CE))
#     print("response loss : ", sum(loss_response) / len(loss_response))
    
    print(f"Teacher acc: {T_correct / all_data}")
    print(f"Student acc: {S_correct / all_data}")
    test_acc = utils.test(student, val_loader, epoch) # student도 변하는거 확인 완료함
    utils.test(teacher, val_loader, epoch) # student도 변하는거 확인 완료함
    
    if test_acc > past_acc + 0.01:
        stack = 0
    elif past_acc + 0.005 > test_acc:
        stack+=1
    past_acc = test_acc
    
    if stack > 5:
        S_scheduler.step()
        stack = 0
        
    print("=" * 100)

lr : [0.01]


100%|██████████| 391/391 [09:07<00:00,  1.40s/it]


distill loss :  0.2791871545107468
general loss :  1.9853936442938607
Teacher acc: 0.5732799768447876
Student acc: 0.5217800140380859


100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


0 	 test acc : 0.796999990940094


100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


0 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.2587928894306997
general loss :  1.1723008570463762
Teacher acc: 0.574999988079071
Student acc: 0.6806399822235107


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


1 	 test acc : 0.8190000057220459


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


1 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.23086316939776816
general loss :  0.9988112303302111
Teacher acc: 0.570360004901886
Student acc: 0.7262199521064758


100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


2 	 test acc : 0.8448999524116516


100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


2 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.22888771772308422
general loss :  0.9020073906235073
Teacher acc: 0.5747399926185608
Student acc: 0.7513399720191956


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


3 	 test acc : 0.843999981880188


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


3 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:08<00:00,  1.40s/it]


distill loss :  0.21709994659246995
general loss :  0.84144124609735
Teacher acc: 0.5732399821281433
Student acc: 0.7674799561500549


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


4 	 test acc : 0.8513999581336975


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


4 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.21088632300038776
general loss :  0.7746223479585574
Teacher acc: 0.573199987411499
Student acc: 0.7830199599266052


100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


5 	 test acc : 0.8611999750137329


100%|██████████| 79/79 [00:16<00:00,  4.91it/s]


5 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.2167472858410662
general loss :  0.7404265484541578
Teacher acc: 0.5740199685096741
Student acc: 0.7953799962997437


100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


6 	 test acc : 0.8592000007629395


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


6 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.21336919223637227
general loss :  0.6863094514135815
Teacher acc: 0.5758399963378906
Student acc: 0.8079800009727478


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


7 	 test acc : 0.8570999503135681


100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


7 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.2118691765629422
general loss :  0.6582499099966815
Teacher acc: 0.5755800008773804
Student acc: 0.8161799907684326


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


8 	 test acc : 0.8671000003814697


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


8 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.2043076970750261
general loss :  0.6256764249880905
Teacher acc: 0.5733399987220764
Student acc: 0.8267799615859985


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


9 	 test acc : 0.8623999953269958


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


9 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:13<00:00,  1.41s/it]


distill loss :  0.21551310827436349
general loss :  0.6115903490034821
Teacher acc: 0.5699999928474426
Student acc: 0.8310799598693848


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


10 	 test acc : 0.8673999905586243


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


10 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:16<00:00,  1.42s/it]


distill loss :  0.21621162094690305
general loss :  0.5898118304931904
Teacher acc: 0.5728999972343445
Student acc: 0.8352999687194824


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


11 	 test acc : 0.8725999593734741


100%|██████████| 79/79 [00:16<00:00,  4.91it/s]


11 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.19831636242206443
general loss :  0.56981111761859
Teacher acc: 0.5725199580192566
Student acc: 0.8426399827003479


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


12 	 test acc : 0.8661999702453613


100%|██████████| 79/79 [00:16<00:00,  4.80it/s]


12 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:07<00:00,  1.40s/it]


distill loss :  0.19100781327204022
general loss :  0.5465096454791096
Teacher acc: 0.5719599723815918
Student acc: 0.8497999906539917


100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


13 	 test acc : 0.8693000078201294


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


13 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:09<00:00,  1.41s/it]


distill loss :  0.19582368168608308
general loss :  0.5245184423521047
Teacher acc: 0.5735399723052979
Student acc: 0.8559199571609497


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


14 	 test acc : 0.8711999654769897


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


14 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.1995292913235362
general loss :  0.5097843849140665
Teacher acc: 0.5777400135993958
Student acc: 0.8604599833488464


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


15 	 test acc : 0.865399956703186


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


15 	 test acc : 0.7110999822616577
lr : [0.01]


100%|██████████| 391/391 [09:08<00:00,  1.40s/it]


distill loss :  0.18783534421106737
general loss :  0.49779722307953994
Teacher acc: 0.5717799663543701
Student acc: 0.8645199537277222


100%|██████████| 79/79 [00:16<00:00,  4.79it/s]


16 	 test acc : 0.8643999695777893


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


16 	 test acc : 0.7110999822616577
lr : [0.001]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.1829110577826381
general loss :  0.43408583897306485
Teacher acc: 0.5723999738693237
Student acc: 0.8817399740219116


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


17 	 test acc : 0.8811999559402466


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


17 	 test acc : 0.7110999822616577
lr : [0.001]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.1803018406052571
general loss :  0.40732614535962225
Teacher acc: 0.5715799927711487
Student acc: 0.8900399804115295


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


18 	 test acc : 0.8822000026702881


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


18 	 test acc : 0.7110999822616577
lr : [0.001]


100%|██████████| 391/391 [09:16<00:00,  1.42s/it]


distill loss :  0.18986452629556283
general loss :  0.38676848405462394
Teacher acc: 0.5744999647140503
Student acc: 0.8943799734115601


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


19 	 test acc : 0.8823999762535095


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


19 	 test acc : 0.7110999822616577
lr : [0.001]


100%|██████████| 391/391 [09:09<00:00,  1.41s/it]


distill loss :  0.16950765475059104
general loss :  0.3873971340525181
Teacher acc: 0.5708999633789062
Student acc: 0.8959199786186218


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


20 	 test acc : 0.8833999633789062


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


20 	 test acc : 0.7110999822616577
lr : [0.001]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.17231001194013887
general loss :  0.3827662371156161
Teacher acc: 0.5745399594306946
Student acc: 0.8978399634361267


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


21 	 test acc : 0.8842999935150146


100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


21 	 test acc : 0.7110999822616577
lr : [0.001]


100%|██████████| 391/391 [09:08<00:00,  1.40s/it]


distill loss :  0.16316691712926493
general loss :  0.3789715252416518
Teacher acc: 0.5748999714851379
Student acc: 0.8977999687194824


100%|██████████| 79/79 [00:16<00:00,  4.79it/s]


22 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.81it/s]


22 	 test acc : 0.7110999822616577
lr : [0.001]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.16629526727711377
general loss :  0.37228952988486763
Teacher acc: 0.5763199925422668
Student acc: 0.9003199934959412


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


23 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


23 	 test acc : 0.7110999822616577
lr : [0.0001]


100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


distill loss :  0.180665944945877
general loss :  0.37050951836282947
Teacher acc: 0.5718199610710144
Student acc: 0.9003599882125854


100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


24 	 test acc : 0.8851999640464783


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


24 	 test acc : 0.7110999822616577
lr : [0.0001]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.17086017728471162
general loss :  0.3624957078481879
Teacher acc: 0.5743199586868286
Student acc: 0.9032999873161316


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


25 	 test acc : 0.8852999806404114


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


25 	 test acc : 0.7110999822616577
lr : [0.0001]


100%|██████████| 391/391 [09:13<00:00,  1.41s/it]


distill loss :  0.1742067485428451
general loss :  0.372338997776551
Teacher acc: 0.5713199973106384
Student acc: 0.899399995803833


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


26 	 test acc : 0.8853999972343445


100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


26 	 test acc : 0.7110999822616577
lr : [0.0001]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.1682044557102806
general loss :  0.3630634741405087
Teacher acc: 0.5742599964141846
Student acc: 0.901479959487915


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


27 	 test acc : 0.8847999572753906


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


27 	 test acc : 0.7110999822616577
lr : [0.0001]


100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


distill loss :  0.17780681710947505
general loss :  0.3668943999520958
Teacher acc: 0.5727199912071228
Student acc: 0.9012199640274048


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


28 	 test acc : 0.885699987411499


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


28 	 test acc : 0.7110999822616577
lr : [0.0001]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.17189663831058824
general loss :  0.36275468038780917
Teacher acc: 0.5750200152397156
Student acc: 0.9025799632072449


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


29 	 test acc : 0.883899986743927


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


29 	 test acc : 0.7110999822616577
lr : [1e-05]


100%|██████████| 391/391 [09:13<00:00,  1.41s/it]


distill loss :  0.17347860887475178
general loss :  0.36792863345207155
Teacher acc: 0.5733799934387207
Student acc: 0.9008599519729614


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


30 	 test acc : 0.8840999603271484


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


30 	 test acc : 0.7110999822616577
lr : [1e-05]


100%|██████████| 391/391 [09:13<00:00,  1.41s/it]


distill loss :  0.1762973770065724
general loss :  0.3648754615155632
Teacher acc: 0.5761600136756897
Student acc: 0.9022799730300903


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


31 	 test acc : 0.884399950504303


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


31 	 test acc : 0.7110999822616577
lr : [1e-05]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.16465632181347864
general loss :  0.3603579907694741
Teacher acc: 0.5748800039291382
Student acc: 0.9028599858283997


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


32 	 test acc : 0.8844999670982361


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


32 	 test acc : 0.7110999822616577
lr : [1e-05]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.17395699403160597
general loss :  0.35766971610543674
Teacher acc: 0.5747199654579163
Student acc: 0.9037399888038635


100%|██████████| 79/79 [00:16<00:00,  4.78it/s]


33 	 test acc : 0.8842999935150146


100%|██████████| 79/79 [00:16<00:00,  4.78it/s]


33 	 test acc : 0.7110999822616577
lr : [1e-05]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.17519570228017275
general loss :  0.36483620511144016
Teacher acc: 0.5724599957466125
Student acc: 0.9025599956512451


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


34 	 test acc : 0.8840999603271484


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


34 	 test acc : 0.7110999822616577
lr : [1e-05]


100%|██████████| 391/391 [09:09<00:00,  1.41s/it]


distill loss :  0.16241367483783103
general loss :  0.3594444555699673
Teacher acc: 0.5719199776649475
Student acc: 0.9047799706459045


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


35 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


35 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.1745558247622817
general loss :  0.3606051645239296
Teacher acc: 0.5733999609947205
Student acc: 0.9030599594116211


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


36 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


36 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [09:13<00:00,  1.42s/it]


distill loss :  0.17374316555783725
general loss :  0.36464879892366314
Teacher acc: 0.5722399950027466
Student acc: 0.9021999835968018


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


37 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


37 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [09:09<00:00,  1.41s/it]


distill loss :  0.16340650206007767
general loss :  0.3681108741199269
Teacher acc: 0.5737800002098083
Student acc: 0.9004799723625183


100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


38 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


38 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.1728246240993328
general loss :  0.36586139314924665
Teacher acc: 0.5738599896430969
Student acc: 0.902239978313446


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


39 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


39 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.16669509393851395
general loss :  0.35939141353377907
Teacher acc: 0.5753799676895142
Student acc: 0.9040199518203735


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


40 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


40 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-06]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.16491674473914114
general loss :  0.3548551569585605
Teacher acc: 0.5746399760246277
Student acc: 0.905299961566925


100%|██████████| 79/79 [00:16<00:00,  4.89it/s]


41 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.78it/s]


41 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [09:13<00:00,  1.42s/it]


distill loss :  0.17143276792800868
general loss :  0.3587899985139632
Teacher acc: 0.5738599896430969
Student acc: 0.9038800001144409


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


42 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.71it/s]


42 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [09:13<00:00,  1.41s/it]


distill loss :  0.171794343408664
general loss :  0.36399323830519187
Teacher acc: 0.5746399760246277
Student acc: 0.9008399844169617


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


43 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


43 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.1654130686550875
general loss :  0.3684512031124071
Teacher acc: 0.5738399624824524
Student acc: 0.9007999897003174


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


44 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


44 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.16471724426064194
general loss :  0.3626778230566503
Teacher acc: 0.5723199844360352
Student acc: 0.9015799760818481


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


45 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


45 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


distill loss :  0.17568932467704768
general loss :  0.3587466734640129
Teacher acc: 0.572219967842102
Student acc: 0.9029399752616882


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


46 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


46 	 test acc : 0.7110999822616577
lr : [1.0000000000000002e-07]


100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


distill loss :  0.17430496692914715
general loss :  0.3562188328760664
Teacher acc: 0.5740599632263184
Student acc: 0.9035999774932861


100%|██████████| 79/79 [00:16<00:00,  4.76it/s]


47 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


47 	 test acc : 0.7110999822616577
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


distill loss :  0.17639137210938938
general loss :  0.3655168451654637
Teacher acc: 0.5740999579429626
Student acc: 0.901699960231781


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


48 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


48 	 test acc : 0.7110999822616577
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.16159601437161342
general loss :  0.35979594917172364
Teacher acc: 0.5741199851036072
Student acc: 0.9022799730300903


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


49 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.73it/s]


49 	 test acc : 0.7110999822616577
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [09:06<00:00,  1.40s/it]


distill loss :  0.1554586699189585
general loss :  0.3623568988822
Teacher acc: 0.5753399729728699
Student acc: 0.9012999534606934


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


50 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


50 	 test acc : 0.7110999822616577
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


distill loss :  0.1752947530401942
general loss :  0.36487908124009055
Teacher acc: 0.5740999579429626
Student acc: 0.9022600054740906


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


51 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


51 	 test acc : 0.7110999822616577
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [09:13<00:00,  1.42s/it]


distill loss :  0.17388577264783633
general loss :  0.35804871042899766
Teacher acc: 0.5749599933624268
Student acc: 0.9034799933433533


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


52 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


52 	 test acc : 0.7110999822616577
lr : [1.0000000000000004e-08]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.16369069414570583
general loss :  0.36262036661815156
Teacher acc: 0.573199987411499
Student acc: 0.9034599661827087


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


53 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


53 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.16947133576406923
general loss :  0.360100318403805
Teacher acc: 0.5755199790000916
Student acc: 0.9023399949073792


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


54 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


54 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:09<00:00,  1.41s/it]


distill loss :  0.1658831101430156
general loss :  0.3600477395993669
Teacher acc: 0.5738799571990967
Student acc: 0.9036799669265747


100%|██████████| 79/79 [00:16<00:00,  4.73it/s]


55 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.75it/s]


55 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


distill loss :  0.17783231973467048
general loss :  0.36450652458021404
Teacher acc: 0.5749599933624268
Student acc: 0.9023599624633789


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


56 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


56 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.17318483884863156
general loss :  0.3671830446671342
Teacher acc: 0.5743199586868286
Student acc: 0.9014999866485596


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


57 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


57 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.16950758778349595
general loss :  0.36263579537953866
Teacher acc: 0.5749599933624268
Student acc: 0.9032399654388428


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


58 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


58 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.16907735712959637
general loss :  0.36981731100612897
Teacher acc: 0.5730400085449219
Student acc: 0.9012399911880493


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


59 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


59 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.17278968713120046
general loss :  0.363977666858517
Teacher acc: 0.574180006980896
Student acc: 0.9023799896240234


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


60 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


60 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:13<00:00,  1.42s/it]


distill loss :  0.17689025187936355
general loss :  0.3583888207250239
Teacher acc: 0.5760599970817566
Student acc: 0.9029399752616882


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


61 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


61 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:15<00:00,  1.42s/it]


distill loss :  0.1801542773666551
general loss :  0.3560677874652321
Teacher acc: 0.5753399729728699
Student acc: 0.903719961643219


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


62 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


62 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.16814250997303393
general loss :  0.3656421763169796
Teacher acc: 0.5717999935150146
Student acc: 0.9014599919319153


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


63 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


63 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:16<00:00,  1.42s/it]


distill loss :  0.18133531152830482
general loss :  0.3637543016153833
Teacher acc: 0.5740999579429626
Student acc: 0.9015799760818481


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


64 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


64 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.16662162303200465
general loss :  0.3623293414902504
Teacher acc: 0.5748199820518494
Student acc: 0.902459979057312


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


65 	 test acc : 0.883899986743927


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


65 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.17040698302199925
general loss :  0.36296832149900743
Teacher acc: 0.5744799971580505
Student acc: 0.9025399684906006


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


66 	 test acc : 0.8836999535560608


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


66 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:05<00:00,  1.39s/it]


distill loss :  0.14962250759820347
general loss :  0.36208930005655265
Teacher acc: 0.5749599933624268
Student acc: 0.9035199880599976


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


67 	 test acc : 0.883899986743927


100%|██████████| 79/79 [00:16<00:00,  4.72it/s]


67 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.17256823107314384
general loss :  0.3689099314534451
Teacher acc: 0.5749199986457825
Student acc: 0.9007200002670288


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


68 	 test acc : 0.883899986743927


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


68 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:08<00:00,  1.40s/it]


distill loss :  0.16682461775896495
general loss :  0.3591243032070682
Teacher acc: 0.5722999572753906
Student acc: 0.9030799865722656


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


69 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.81it/s]


69 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:09<00:00,  1.41s/it]


distill loss :  0.16152779952339505
general loss :  0.36136739146526514
Teacher acc: 0.5730999708175659
Student acc: 0.9025200009346008


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


70 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


70 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.16753781977993296
general loss :  0.36566747961294316
Teacher acc: 0.5723999738693237
Student acc: 0.9016199707984924


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


71 	 test acc : 0.883899986743927


100%|██████████| 79/79 [00:16<00:00,  4.87it/s]


71 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:05<00:00,  1.39s/it]


distill loss :  0.15324081347116728
general loss :  0.36742367086660527
Teacher acc: 0.5734800100326538
Student acc: 0.9009999632835388


100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


72 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


72 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:13<00:00,  1.41s/it]


distill loss :  0.17449446425527868
general loss :  0.37083438991585654
Teacher acc: 0.5714399814605713
Student acc: 0.9000799655914307


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


73 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


73 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:08<00:00,  1.40s/it]


distill loss :  0.16564487477959802
general loss :  0.36178511745103487
Teacher acc: 0.5733799934387207
Student acc: 0.9025999903678894


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


74 	 test acc : 0.883899986743927


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


74 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.1738620039737781
general loss :  0.3651951948547607
Teacher acc: 0.5727199912071228
Student acc: 0.9010799527168274


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


75 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


75 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.16658769801373371
general loss :  0.36485399157189957
Teacher acc: 0.5726400017738342
Student acc: 0.9020999670028687


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


76 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


76 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:14<00:00,  1.42s/it]


distill loss :  0.17566037379786412
general loss :  0.36857282699983746
Teacher acc: 0.5738399624824524
Student acc: 0.9012399911880493


100%|██████████| 79/79 [00:16<00:00,  4.82it/s]


77 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.85it/s]


77 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:07<00:00,  1.40s/it]


distill loss :  0.16166088190834846
general loss :  0.36479347220162295
Teacher acc: 0.5748400092124939
Student acc: 0.901919960975647


100%|██████████| 79/79 [00:16<00:00,  4.86it/s]


78 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.83it/s]


78 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:12<00:00,  1.41s/it]


distill loss :  0.17713107757003563
general loss :  0.3562869277527875
Teacher acc: 0.5744199752807617
Student acc: 0.9039799571037292


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


79 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


79 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:11<00:00,  1.41s/it]


distill loss :  0.17082321077175533
general loss :  0.36606407192204615
Teacher acc: 0.5765199661254883
Student acc: 0.9013399481773376


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


80 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.84it/s]


80 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


100%|██████████| 391/391 [09:10<00:00,  1.41s/it]


distill loss :  0.16948118027004286
general loss :  0.3638801385679513
Teacher acc: 0.571619987487793
Student acc: 0.9030799865722656


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


81 	 test acc : 0.8837999701499939


100%|██████████| 79/79 [00:16<00:00,  4.88it/s]


81 	 test acc : 0.7110999822616577
lr : [1.0000000000000005e-09]


 29%|██▉       | 113/391 [02:39<06:33,  1.41s/it]

KeyboardInterrupt



In [15]:
test_acc = utils.test(student, val_loader, epoch)

100%|██████████| 79/79 [00:16<00:00,  4.93it/s]

82 	 test acc : 0.8837999701499939





In [14]:
torch.save(student, "models/CIFAR100_VIT_student_8837.pth")