In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
num_epochs = 5
batch_size = 40
learning_rate = 0.001
classes = ('plane', 'car' , 'bird',
    'cat', 'deer', 'dog',
    'frog', 'horse', 'ship', 'truck')

cuda


In [2]:
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize( 
       (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 
    )
])
train_dataset = torchvision.datasets.CIFAR10(
    root= './data', train = True,
    download =True, transform = transform)
test_dataset = torchvision.datasets.CIFAR10(
    root= './data', train = False,
    download =True, transform = transform)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
train_loader = torch.utils.data.DataLoader(train_dataset
    , batch_size = batch_size
    , shuffle = True)
test_loader = torch.utils.data.DataLoader(test_dataset
    , batch_size = 256
    , shuffle = True)
n_total_step = len(train_loader)
print(n_total_step)


1250


In [4]:
teacher_model = torch.load("cifar-10.pth")

In [5]:
model = models.vgg16(pretrained = True)
input_lastLayer = model.classifier[6].in_features
model.classifier[6] = nn.Linear(input_lastLayer,10)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate, momentum=0.9,weight_decay=5e-4)
fake_optimizer = torch.optim.SGD(teacher_model.parameters(), lr = 0.0, momentum=0.9,weight_decay=5e-4)


In [None]:
def normalize_max1(w):
    for i in range(len(w)):
        w[i] = w[i] / torch.max(abs(w[i]))
    return w

to_gaussian = lambda arr, mean = 1, std = 1: ((arr - torch.mean(arr))/ (torch.std(arr) + 0.00001)) * std + mean

softmax = torch.nn.Softmax(dim=1)
softmax2d = lambda b: softmax(torch.flatten(b, start_dim = 1)).reshape(b.shape)
f2 = lambda w, _=None: softmax2d(normalize_max1(-w)) * len(w[0])

for epoch in range(num_epochs):
    correct_T, correct_S, all = 0, 0, 0
    for i, (imgs , labels) in enumerate(train_loader):
        all += len(labels)
        imgs = imgs.to(device)
        labels = labels.to(device)
        
        # teacher model
        img_clone = imgs.clone()
        labels_clone = labels.clone()
        
        img_clone.requires_grad = True
        img_clone.retain_grad = True
        
        t_output = teacher_model(img_clone)
        loss = criterion(t_output, labels_clone)
        loss.backward()
        
        fake_optimizer.zero_grad()
        img_lrp = img_clone * img_clone.grad
        img_lrp = f2(img_lrp)
        
        with torch.no_grad():
            for ii in range(len(img_lrp)):
                img_lrp[ii] = to_gaussian(img_lrp[ii], std = 0.1)
            
            img_clone = img_clone*img_clone
            softlabel = model(img_clone)

            correct_T += sum(labels == torch.argmax(softlabel, dim=1))
            
        
        # student model
        output = model(imgs)

        
        correct_S += sum(labels == torch.argmax(output, dim=1))

        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if (i+1) % 250 == 0:
            print(f'epoch {epoch+1}/{num_epochs}, step: {i+1}/{n_total_step}: loss = {loss:.5f}, acc = {100*(correct_S/all):.2f}%')
            print()


# LRP std에관한 실험

In [8]:
def normalize_max1(w):
    for i in range(len(w)):
        w[i] = w[i] / torch.max(abs(w[i]))
    return w

to_gaussian = lambda arr, mean = 1, std = 1: ((arr - torch.mean(arr))/ (torch.std(arr) + 0.00001)) * std + mean

softmax = torch.nn.Softmax(dim=1)
softmax2d = lambda b: softmax(torch.flatten(b, start_dim = 1)).reshape(b.shape)
f2 = lambda w, _=None: softmax2d(normalize_max1(-w)) * len(w[0])


In [9]:
def lrp_normalize(model, dataset, std = 0.01):
    optimizer = torch.optim.SGD(model.parameters(), lr = 0.0, momentum=0.9,weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    f2 = lambda w, _=None: softmax2d(normalize_max1(-w)) * len(w[0])
    
    correct, all = 0,0
    for idx, batch in enumerate(tqdm(dataset)):
        img , label = batch[0].cuda(), batch[1].cuda()

#         return img, label
        img.requires_grad = True
        img.retain_grad = True
        
        fake_label = torch.ones_like(label)
        
        output = model(img)
        output_arg = torch.argmax(output, dim=1)
        fake_label = (fake_label + output_arg) % 1000
        
        loss = criterion(output, label)
        loss.backward()
#         torch.sum(output).backward()
        optimizer.zero_grad()
        
        img_lrp = img*img.grad
        img_lrp = f2(img_lrp)
        with torch.no_grad():
            for i in range(len(img_lrp)):
                img_lrp[i] = to_gaussian(img_lrp[i], std = std)
            
            img = img*img_lrp # img_lrp가 음수값인것 지움
            output = model(img)
        
            all += len(label)
            correct += sum(label == torch.argmax(output, dim=1))
        
    print('Accuracy : %.4f' % (correct/all))
    return img, img.grad
    

In [16]:
# std : 0.1
# Accuracy : 0.9977

from tqdm import tqdm
teacher_model = teacher_model.cuda()
teacher_model.eval()
for std in range(9,0, -1):
    print(f'std : {std/100}')
    dd = lrp_normalize(teacher_model, test_loader, std/100)

  0%|          | 0/40 [00:00<?, ?it/s]

std : 0.09


100%|██████████| 40/40 [00:47<00:00,  1.19s/it]
  0%|          | 0/40 [00:00<?, ?it/s]

Accuracy : 0.9976
std : 0.08


100%|██████████| 40/40 [00:48<00:00,  1.21s/it]
  0%|          | 0/40 [00:00<?, ?it/s]

Accuracy : 0.9976
std : 0.07


100%|██████████| 40/40 [00:48<00:00,  1.21s/it]
  0%|          | 0/40 [00:00<?, ?it/s]

Accuracy : 0.9974
std : 0.06


100%|██████████| 40/40 [00:48<00:00,  1.21s/it]
  0%|          | 0/40 [00:00<?, ?it/s]

Accuracy : 0.9971
std : 0.05


 18%|█▊        | 7/40 [00:08<00:42,  1.28s/it]


KeyboardInterrupt: 

In [12]:
dd[0].shape

torch.Size([16, 3, 224, 224])

In [35]:
from matplotlib import pyplot as plt
def change_format(img):
    return torch.cat((img[0].unsqueeze(-1), img[1].unsqueeze(-1), img[2].unsqueeze(-1)), dim=-1)

plt.imshow(change_format(dd[1][8]).cpu().detach().numpy().reshape(224,224,3))

TypeError: 'NoneType' object is not subscriptable

In [24]:
for i, (imgs , labels) in enumerate(train_loader):
        imgs = imgs.to(device)
        labels = labels.to(device)
        break

In [26]:
imgs.shape

torch.Size([40, 3, 224, 224])