In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="3"


In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

import utils
from tqdm import tqdm

teacher = torchvision.models.vit_b_16(pretrained = True)
student = torchvision.models.vit_b_16(pretrained = False)
# model = torchvision.models.vit_l_32(pretrained=True)

print(f"patch_size : {student.patch_size}")
print(f"image_size : {student.image_size}")
print(f"hidden_dim : {student.hidden_dim}")

student = utils.VisionTransformer(student)
teacher = utils.VisionTransformer(teacher)


patch_size : 16
image_size : 224
hidden_dim : 768


In [None]:
torchvision.datasets.

In [3]:
normalize = transforms.Normalize(mean=[0.5074,0.4867,0.4411],
                                 std=[0.2011,0.1987,0.2025])

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])

transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
])

dataset_val = torchvision.datasets.ImageNet(root= "../ImagenetData/data", split='val', transform = transform)
dataset_train = torchvision.datasets.ImageNet(root= "../ImagenetData/data", split='train', transform = transform_train)

val_loader = torch.utils.data.DataLoader(dataset_val,
                                          batch_size=128,
                                          shuffle=True,
                                          num_workers=4)

train_loader = torch.utils.data.DataLoader(dataset_train,
                                          batch_size=128,
                                          shuffle=True,
                                          num_workers=4)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
S_optimizer = optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
T_optimizer = optim.SGD(teacher.parameters(), lr=0.001, momentum=0.9)
CE_loss = nn.CrossEntropyLoss()

S_scheduler = torch.optim.lr_scheduler.MultiStepLR(S_optimizer, milestones=[1, 2, 3,4], gamma=0.1)
T_scheduler = torch.optim.lr_scheduler.MultiStepLR(T_optimizer, milestones=[40,80], gamma=0.1)

student = student.cuda()
teacher = teacher.cuda()



In [None]:
import numpy as np

criterion_onlylabel = lambda a,b : mse(a*b, b)

criterion_CE = nn.CrossEntropyLoss()
mse = nn.MSELoss()
softmax = torch.nn.Softmax(dim = 1)
criterion_KLD = torch.nn.KLDivLoss(reduction="batchmean")
criterion_response = lambda a,b : criterion_KLD(torch.log_softmax(a, dim=1),torch.softmax(b, dim=1))

    
import torch
import random
import torch.backends.cudnn as cudnn
import numpy as np

torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(0)
past_acc = 0.0
stack = 0  
for epoch in range(100):
    print(f"lr : {S_scheduler.get_last_lr()}")
    T_correct = 0
    S_correct = 0
    all_data = 0
    
    loss_distill = []
    loss_CE = []
    loss_response = []
    student = student.train()
    for img, label in tqdm(train_loader):
        input_data = img.cuda()
        label = label.cuda()
        
        
        all_data += len(input_data)
        input_lrp = utils.get_LRP_img(input_data, label, teacher, criterion_CE, T_optimizer, mean=0.9, std = 0.4).cuda()
        
        S_optimizer.zero_grad()
        T_optimizer.zero_grad()

        layer = random.randint(0,  2)
        input_data, fk = student(input_data,layer)
        input_lrp, fk_lrp = teacher(input_lrp,layer)
        
        distill_loss = mse(fk, fk_lrp)
                    
        CE_loss = criterion_CE(input_data, label)
        
#         response_loss = criterion_response(input_data, input_lrp)
        
        T_correct += sum(label == torch.argmax(input_lrp, dim=1))
        S_correct += sum(label == torch.argmax(input_data, dim=1))
        
        loss_CE.append(CE_loss.item())
        loss_distill.append(distill_loss.item())

        loss = (distill_loss * 2 + CE_loss) / 2
        loss.backward()
        S_optimizer.step()
    print("distill loss : ", sum(loss_distill) / len(loss_distill))
    print("general loss : ", sum(loss_CE) / len(loss_CE))
#     print("response loss : ", sum(loss_response) / len(loss_response))
    
    print(f"Teacher acc: {T_correct / all_data}")
    print(f"Student acc: {S_correct / all_data}")
    test_acc = utils.test(student, val_loader, epoch) # student도 변하는거 확인 완료함
    utils.test(teacher, val_loader, epoch) # student도 변하는거 확인 완료함
    
    if test_acc > past_acc + 0.01:
        stack = 0
    elif past_acc + 0.005 > test_acc:
        stack+=1
    past_acc = test_acc
    
    
    if stack > 2:
        S_scheduler.step()
        stack = 0
        
    print("=" * 100)

lr : [0.01]


100%|██████████| 10010/10010 [3:26:10<00:00,  1.24s/it] 


distill loss :  0.15275222984227269
general loss :  5.479379260218465
Teacher acc: 0.9393154978752136
Student acc: 0.06781473010778427


100%|██████████| 391/391 [01:16<00:00,  5.10it/s]


0 	 test acc : 0.11017999798059464


100%|██████████| 391/391 [01:15<00:00,  5.17it/s]


0 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:16<00:00,  1.24s/it] 


distill loss :  0.14121054266090993
general loss :  4.703286801970803
Teacher acc: 0.9393279552459717
Student acc: 0.1383933573961258


100%|██████████| 391/391 [01:16<00:00,  5.13it/s]


1 	 test acc : 0.17227999866008759


100%|██████████| 391/391 [01:15<00:00,  5.16it/s]


1 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:18<00:00,  1.24s/it] 


distill loss :  0.13989453102042387
general loss :  4.216818486107932
Teacher acc: 0.939051628112793
Student acc: 0.1966504007577896


100%|██████████| 391/391 [01:16<00:00,  5.09it/s]


2 	 test acc : 0.23269999027252197


100%|██████████| 391/391 [01:16<00:00,  5.09it/s]


2 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:09<00:00,  1.24s/it] 


distill loss :  0.1353042886680835
general loss :  3.8429427088081063
Teacher acc: 0.9394692182540894
Student acc: 0.2449212223291397


100%|██████████| 391/391 [01:16<00:00,  5.13it/s]


3 	 test acc : 0.2722199857234955


100%|██████████| 391/391 [01:16<00:00,  5.13it/s]


3 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:27:06<00:00,  1.24s/it] 


distill loss :  0.13091732738362802
general loss :  3.54842607793989
Teacher acc: 0.9390914440155029
Student acc: 0.2866620719432831


100%|██████████| 391/391 [01:17<00:00,  5.04it/s]


4 	 test acc : 0.30945998430252075


100%|██████████| 391/391 [01:17<00:00,  5.05it/s]


4 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:29<00:00,  1.24s/it] 


distill loss :  0.12715057676190977
general loss :  3.3078071740004686
Teacher acc: 0.9394559860229492
Student acc: 0.32266753911972046


100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


5 	 test acc : 0.3480599820613861


100%|██████████| 391/391 [01:16<00:00,  5.08it/s]


5 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:28<00:00,  1.24s/it] 


distill loss :  0.12381749492097686
general loss :  3.1118908363622384
Teacher acc: 0.939480185508728
Student acc: 0.3533294200897217


100%|██████████| 391/391 [01:16<00:00,  5.09it/s]


6 	 test acc : 0.37105998396873474


100%|██████████| 391/391 [01:17<00:00,  5.04it/s]


6 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:27:15<00:00,  1.24s/it] 


distill loss :  0.12117432365035796
general loss :  2.939098648615293
Teacher acc: 0.9390618205070496
Student acc: 0.38042113184928894


100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


7 	 test acc : 0.38989999890327454


100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


7 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:35<00:00,  1.24s/it] 


distill loss :  0.11842786482759528
general loss :  2.7891451473359936
Teacher acc: 0.939531683921814
Student acc: 0.4058479368686676


100%|██████████| 391/391 [01:17<00:00,  5.07it/s]


8 	 test acc : 0.4225800037384033


100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


8 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:27:02<00:00,  1.24s/it] 


distill loss :  0.11601230865204847
general loss :  2.6494894924816434
Teacher acc: 0.9390984773635864
Student acc: 0.42956146597862244


100%|██████████| 391/391 [01:17<00:00,  5.04it/s]


9 	 test acc : 0.43289998173713684


100%|██████████| 391/391 [01:17<00:00,  5.03it/s]


9 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:29<00:00,  1.24s/it] 


distill loss :  0.11385114799503918
general loss :  2.5208358800137316
Teacher acc: 0.9390696287155151
Student acc: 0.4514579176902771


100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


10 	 test acc : 0.45027998089790344


100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


10 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:27:13<00:00,  1.24s/it] 


distill loss :  0.11161426324453148
general loss :  2.4003035551541814
Teacher acc: 0.9394536018371582
Student acc: 0.4724169373512268


100%|██████████| 391/391 [01:17<00:00,  5.04it/s]


11 	 test acc : 0.4627399742603302


100%|██████████| 391/391 [01:16<00:00,  5.09it/s]


11 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:16<00:00,  1.24s/it] 


distill loss :  0.1102092528393814
general loss :  2.290221674399419
Teacher acc: 0.939479410648346
Student acc: 0.49231675267219543


100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


12 	 test acc : 0.46991997957229614


100%|██████████| 391/391 [01:16<00:00,  5.09it/s]


12 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:35<00:00,  1.24s/it] 


distill loss :  0.1087797615576636
general loss :  2.186614729307748
Teacher acc: 0.9388783574104309
Student acc: 0.5108928084373474


100%|██████████| 391/391 [01:16<00:00,  5.08it/s]


13 	 test acc : 0.483599990606308


100%|██████████| 391/391 [01:17<00:00,  5.05it/s]


13 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:10<00:00,  1.24s/it] 


distill loss :  0.10695966703298089
general loss :  2.086952106531088
Teacher acc: 0.9390618205070496
Student acc: 0.5286906361579895


100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


14 	 test acc : 0.48861998319625854


100%|██████████| 391/391 [01:15<00:00,  5.15it/s]


14 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:26:50<00:00,  1.24s/it] 


distill loss :  0.10612615125326368
general loss :  1.994310423699054
Teacher acc: 0.9388877153396606
Student acc: 0.546785831451416


100%|██████████| 391/391 [01:16<00:00,  5.10it/s]


15 	 test acc : 0.4925599992275238


100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


15 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:27:11<00:00,  1.24s/it] 


distill loss :  0.104896271164035
general loss :  1.9026237057281898
Teacher acc: 0.939425528049469
Student acc: 0.5643003582954407


100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


16 	 test acc : 0.4994199872016907


100%|██████████| 391/391 [01:17<00:00,  5.04it/s]


16 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:27:00<00:00,  1.24s/it] 


distill loss :  0.10435682468034409
general loss :  1.8160118732180868
Teacher acc: 0.9392772316932678
Student acc: 0.5806206464767456


100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


17 	 test acc : 0.5039199590682983


100%|██████████| 391/391 [01:18<00:00,  4.97it/s]


17 	 test acc : 0.810699999332428
lr : [0.01]


100%|██████████| 10010/10010 [3:27:10<00:00,  1.24s/it] 


distill loss :  0.10361144915967435
general loss :  1.7329983768763242
Teacher acc: 0.9392038583755493
Student acc: 0.5968285202980042


100%|██████████| 391/391 [01:17<00:00,  5.03it/s]


18 	 test acc : 0.5069000124931335


100%|██████████| 391/391 [01:17<00:00,  5.05it/s]


18 	 test acc : 0.810699999332428
lr : [0.001]


100%|██████████| 10010/10010 [3:27:02<00:00,  1.24s/it] 


distill loss :  0.057223733021067334
general loss :  1.2891622942465764
Teacher acc: 0.9392545819282532
Student acc: 0.6972112059593201


100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


19 	 test acc : 0.5442599654197693


100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


19 	 test acc : 0.810699999332428
lr : [0.001]


100%|██████████| 10010/10010 [3:26:35<00:00,  1.24s/it] 


distill loss :  0.04733513181044773
general loss :  1.1976860988926101
Teacher acc: 0.9391796588897705
Student acc: 0.7193722724914551


100%|██████████| 391/391 [01:16<00:00,  5.08it/s]


20 	 test acc : 0.5487399697303772


100%|██████████| 391/391 [01:16<00:00,  5.10it/s]


20 	 test acc : 0.810699999332428
lr : [0.001]


100%|██████████| 10010/10010 [3:26:57<00:00,  1.24s/it] 


distill loss :  0.043537098678437386
general loss :  1.1501907477131137
Teacher acc: 0.9390961527824402
Student acc: 0.7303606867790222


100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


21 	 test acc : 0.5492199659347534


100%|██████████| 391/391 [01:16<00:00,  5.13it/s]


21 	 test acc : 0.810699999332428
lr : [0.001]


100%|██████████| 10010/10010 [3:25:58<00:00,  1.23s/it] 


distill loss :  0.0406736667618736
general loss :  1.1142054658967417
Teacher acc: 0.9391570091247559
Student acc: 0.7388131022453308


100%|██████████| 391/391 [01:15<00:00,  5.15it/s]


22 	 test acc : 0.5490599870681763


100%|██████████| 391/391 [01:16<00:00,  5.14it/s]


22 	 test acc : 0.810699999332428
lr : [0.0001]


100%|██████████| 10010/10010 [3:25:53<00:00,  1.23s/it] 


distill loss :  0.03474857843146993
general loss :  1.0499860989523457
Teacher acc: 0.9395152926445007
Student acc: 0.7542068958282471


100%|██████████| 391/391 [01:16<00:00,  5.13it/s]


23 	 test acc : 0.553820013999939


100%|██████████| 391/391 [01:16<00:00,  5.12it/s]


23 	 test acc : 0.810699999332428
lr : [0.0001]


100%|██████████| 10010/10010 [3:25:54<00:00,  1.23s/it] 


distill loss :  0.03356486013443856
general loss :  1.0425173548194435
Teacher acc: 0.9393138885498047
Student acc: 0.7561122179031372


100%|██████████| 391/391 [01:16<00:00,  5.14it/s]


24 	 test acc : 0.5536199808120728


100%|██████████| 391/391 [01:15<00:00,  5.15it/s]


24 	 test acc : 0.810699999332428
lr : [0.0001]


100%|██████████| 10010/10010 [3:26:02<00:00,  1.23s/it] 


distill loss :  0.03328156306814927
general loss :  1.0358510947846746
Teacher acc: 0.9393974542617798
Student acc: 0.7577809691429138


100%|██████████| 391/391 [01:16<00:00,  5.14it/s]


25 	 test acc : 0.5523599982261658


100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


25 	 test acc : 0.810699999332428
lr : [1e-05]


 19%|█▉        | 1894/10010 [38:57<2:45:55,  1.23s/it]

In [None]:
def test(model, data, epoch = 0):
    all_data = 0
    correct = 0
    model = model.eval()
    for img, label in tqdm(data):
        model.eval()
        with torch.no_grad():
            img = img.cuda()
            label = label.cuda()

            output, _ = model(img)

            correct += sum(label == torch.argmax(output, dim=1))
            all_data += len(img)
    print(f"{epoch} \t test acc : {correct / all_data}")
    return correct / all_data

In [None]:
student = torchvision.models.vit_b_16(pretrained = True).cuda()
student = utils.VisionTransformer(student)
test(student, val_loader, 0)

In [None]:
torchvision.models.vit_b_16(pretrained = True).encoder.layers