In [1]:
import torch
model = torch.hub.load("pytorch/vision", "vit_b_16")

import torchvision
from Models.transformer import VisionTransformer as vit
import Models.Conv as conv

from DataLoader import CIFAR100
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn

import utils

import numpy as np
import torch.backends.cudnn as cudnn
import random

  from .autonotebook import tqdm as notebook_tqdm
Using cache found in /root/.cache/torch/hub/pytorch_vision_main
  warn(f"Failed to load image Python extension: {e}")
Using cache found in /root/.cache/torch/hub/pytorch_vision_main
Using cache found in /root/.cache/torch/hub/pytorch_vision_main
Using cache found in /root/.cache/torch/hub/pytorch_vision_main


In [2]:
train_loader, test_loader = CIFAR100.get_data(58*3)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
teacher = vit(class_num = 100, pretrained = True)
teacher.load_state_dict(torch.load("saved_models/vit_b_teacher_16_88_00.pth").module.state_dict())
student = vit(class_num = 100, pretrained = True)

In [4]:
device = "cuda"

teacher = teacher.to(device)
teacher = torch.nn.DataParallel(teacher, device_ids=[0, 1, 2])

student = student.to(device)
student = torch.nn.DataParallel(student, device_ids=[0, 1, 2])


In [5]:
criterion_onlylabel = lambda a,b : mse(a*b, b)
criterion_CE = nn.CrossEntropyLoss()
mse = nn.MSELoss()
softmax = torch.nn.Softmax(dim = 1)
criterion_KLD = torch.nn.KLDivLoss(reduction="batchmean")
criterion_response = lambda a,b : criterion_KLD(torch.log_softmax(a, dim=1),torch.softmax(b, dim=1))


In [6]:
S_optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
T_optimizer = optim.SGD(teacher.parameters(), lr=0.01, momentum=0.9)
CE_loss = nn.CrossEntropyLoss()

In [7]:
S_scheduler = torch.optim.lr_scheduler.MultiStepLR(S_optimizer, milestones=[1,2,3,4,5,6,7], gamma=0.1)
T_scheduler = torch.optim.lr_scheduler.MultiStepLR(T_optimizer, milestones=[1,2,3,4,5,6,7], gamma=0.1)

In [8]:
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(0)
best_acc = 0.0
stack = 0

accs_train = []
accs_test = []

In [9]:
utils.test(teacher, test_loader,device) # student도 변하는거 확인 완료함


100%|██████████| 58/58 [00:27<00:00,  2.08it/s]

0 	 test acc : 0.8799999952316284





tensor(0.8800, device='cuda:0')

In [None]:
student_test_accs = []

try:
    encoder_length = len(teacher.encoder.layers)
except:
    encoder_length = len(teacher.module.encoder.layers)

for epoch in range(100):
    
    print(f"lr : {S_scheduler.get_last_lr()}")
    T_correct = 0
    S_correct = 0
    all_data = 0
    
    loss_distill = []
    loss_CE = []
    loss_response = []
    student.train()
    for img, label in tqdm(train_loader):
        teacher.train()
        input_data = img.to(device)
        label = label.to(device)
        
        
        all_data += len(input_data)
        input_lrp = utils.get_LRP_img(input_data, label, teacher, criterion_CE, T_optimizer, mean=1.5, std = 0.1).cuda()
        
        S_optimizer.zero_grad()
        T_optimizer.zero_grad()

        layer = random.randint(0,  2+encoder_length)
        input_data, fk = student(input_data,layer)
        input_lrp, fk_lrp = teacher(input_lrp,layer)
        
        distill_loss = mse(fk, fk_lrp)
                    
        CE_loss = criterion_CE(torch.softmax(input_data, dim=1), label)
        
        response_loss = criterion_response(input_data, input_lrp)
        
        T_correct += sum(label == torch.argmax(input_lrp, dim=1))
        S_correct += sum(label == torch.argmax(input_data, dim=1))
        
        loss_CE.append(CE_loss.item())
        loss_distill.append(distill_loss.item())
        loss_response.append(response_loss.item())
        
        loss = (distill_loss * 2 + CE_loss + response_loss * 0.25) / 2
        loss.backward()
        S_optimizer.step()

    print("distill loss : ", sum(loss_distill) / len(loss_distill))
    print("general loss : ", sum(loss_CE) / len(loss_CE))
    print("response loss : ", sum(loss_response) / len(loss_response))
    
    print(f"Teacher acc: {T_correct / all_data}")
    print(f"Student acc: {S_correct / all_data}")

    test_acc = utils.test(student, test_loader,device, epoch) # student도 변하는거 확인 완료함
    
    if test_acc > best_acc + 0.01:
        stack = 0
        best_acc = test_acc
        
    else:
        stack+=1
    
    if stack > 3:
        S_scheduler.step()
        stack = 0
        
    student_test_accs.append(test_acc.item())
    print("=" * 100)

lr : [0.01]


100%|██████████| 288/288 [11:33<00:00,  2.41s/it]


distill loss :  0.050382229143805385
general loss :  4.430208048886723
response loss :  1.9460779428482056
Teacher acc: 0.8040399551391602
Student acc: 0.41495999693870544


100%|██████████| 58/58 [00:20<00:00,  2.83it/s]


0 	 test acc : 0.7555999755859375
lr : [0.01]


100%|██████████| 288/288 [11:33<00:00,  2.41s/it]


distill loss :  0.03021993418062468
general loss :  4.054238852527407
response loss :  0.42281760109795463
Teacher acc: 0.8028199672698975
Student acc: 0.6793599724769592


100%|██████████| 58/58 [00:20<00:00,  2.82it/s]


1 	 test acc : 0.8166999816894531
lr : [0.01]


100%|██████████| 288/288 [11:31<00:00,  2.40s/it]


distill loss :  0.02684872253869091
general loss :  3.9682750064465733
response loss :  0.2716040295652217
Teacher acc: 0.8033199906349182
Student acc: 0.7242000102996826


100%|██████████| 58/58 [00:20<00:00,  2.82it/s]


2 	 test acc : 0.8417999744415283
lr : [0.01]


100%|██████████| 288/288 [11:34<00:00,  2.41s/it]


distill loss :  0.01910198407939687
general loss :  3.9314322719971337
response loss :  0.21525637241494325
Teacher acc: 0.8041799664497375
Student acc: 0.7483599781990051


100%|██████████| 58/58 [00:20<00:00,  2.83it/s]


3 	 test acc : 0.8518999814987183
lr : [0.01]


100%|██████████| 288/288 [11:35<00:00,  2.41s/it]


distill loss :  0.017521508744620305
general loss :  3.907356503109137
response loss :  0.19140260591585603
Teacher acc: 0.8059999942779541
Student acc: 0.7656399607658386


100%|██████████| 58/58 [00:20<00:00,  2.80it/s]


4 	 test acc : 0.8549999594688416
lr : [0.01]


100%|██████████| 288/288 [11:35<00:00,  2.42s/it]


distill loss :  0.017187951648439694
general loss :  3.893997141884433
response loss :  0.17455330560915172
Teacher acc: 0.8041799664497375
Student acc: 0.7763400077819824


100%|██████████| 58/58 [00:20<00:00,  2.81it/s]


5 	 test acc : 0.8593999743461609
lr : [0.01]


100%|██████████| 288/288 [11:32<00:00,  2.40s/it]


distill loss :  0.01651309929729905
general loss :  3.8837345871660442
response loss :  0.16490654041990638
Teacher acc: 0.8046599626541138
Student acc: 0.7843599915504456


100%|██████████| 58/58 [00:20<00:00,  2.97it/s]

In [None]:
# distill loss를 2배 키워보는것도 좋을지도

In [None]:
utils.test(teacher, test_loader,device, epoch) # student도 변하는거 확인 완료함
utils.test(student, test_loader,device, epoch) # student도 변하는거 확인 완료함


In [None]:
torch.save(student, "vit_b_16_m_0_9_std_0_4_acc_89_01.pth")