In [16]:
import argparse

from matplotlib import pyplot as plt

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from conf import settings
from utils import get_network, get_test_dataloader, get_training_dataloader


In [2]:
class args:
    """Bare-bones stand-in for a parsed command-line namespace.

    Only the network name is set here. A later cell in this notebook
    redefines ``args`` with more fields, which replaces this definition.
    """

    def __init__(self):
        # Architecture identifier; the sole attribute of this variant.
        self.net = 'vgg16'


In [27]:
# Checkpoint that is loaded into the network further down in this cell.
file_path = "checkpoint/vgg16/Friday_04_March_2022_19h_25m_09s/vgg16-200-regular.pth"


class args:
    """Evaluation settings, used in place of an ``argparse`` namespace.

    Every setting is a keyword argument whose default reproduces the values
    this notebook was run with, so ``args()`` behaves as before.

    Args:
        net: architecture name, stored on ``self.net`` (the instance is
            passed to ``get_network``).
        gpu: stored on ``self.gpu``; the evaluation loop moves each batch to
            CUDA when it is truthy.
        b: stored on ``self.b``; passed as ``batch_size`` to the test loader.
        weights: checkpoint path stored on ``self.weights``. ``None`` means
            "use the module-level ``file_path``", looked up at construction
            time so a later reassignment of ``file_path`` is honoured.
    """

    def __init__(self, net='vgg16', gpu=True, b=250, weights=None):
        self.net = net
        self.gpu = gpu
        self.b = b
        self.weights = file_path if weights is None else weights
        
# Rebinds the name ``args`` from the settings class to an instance of it;
# the cells below read ``args.gpu`` / ``args.b`` from this instance.
args = args()
net = get_network(args)

cifar100_test_loader = get_test_dataloader(
    settings.CIFAR100_TRAIN_MEAN,
    settings.CIFAR100_TRAIN_STD,
    num_workers=4,
    batch_size=args.b,
)

# Deserialise the checkpoint onto the CPU first: load_state_dict then copies
# each tensor into the network's own parameters, so the file loads even when
# it was saved from a device this machine does not have.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
net.load_state_dict(torch.load(args.weights, map_location="cpu"))
print(net)
net.eval()

# Running accuracy counters (re-initialised by the evaluation cell as well).
correct_1 = 0.0
correct_5 = 0.0
total = 0


Files already downloaded and verified
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil

In [28]:
import torch.nn as nn
import torch.optim as optim

# Loss whose gradient with respect to the input images is taken in the
# evaluation loop.
criterion = nn.CrossEntropyLoss()

# In the cells shown here the optimizer is only used for zero_grad(); step()
# is never called, so the weights are not updated.
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-5,
    momentum=0.9,
    weight_decay=5e-4,
)


In [29]:
from tqdm import tqdm

def normalize_max1(w):
    """Scale every entry along the first axis of ``w`` to a max-abs of 1.

    The tensor is modified in place and the same object is returned.
    An entry that is entirely zero is left as zeros; dividing it by its
    (zero) peak would otherwise fill it with NaN.

    Args:
        w: tensor indexed by sample along its first axis.

    Returns:
        ``w`` itself, after in-place scaling.
    """
    for i in range(len(w)):
        peak = torch.max(abs(w[i]))
        # (peak == 0) adds 1 only when the peak is exactly zero, so non-zero
        # samples are divided by their true peak and no host sync is needed.
        w[i] = w[i] / (peak + (peak == 0))
    return w

def to_gaussian(arr, mean=1, std=1):
    """Standardise ``arr`` over all its elements, then rescale to ``mean``/``std``.

    Args:
        arr: input tensor.
        mean: value added after standardising (default 1).
        std: factor applied to the standardised values (default 1).

    Returns:
        A new tensor ``(arr - mean(arr)) / (std(arr) + 1e-5) * std + mean``.
    """
    centered = arr - torch.mean(arr)
    # Small constant keeps the division finite when arr has zero spread.
    spread = torch.std(arr) + 0.00001
    return (centered / spread) * std + mean

# Softmax module applied along dim 1.
softmax = torch.nn.Softmax(dim=1)


def softmax2d(b):
    """Softmax taken jointly over all non-batch elements of ``b``.

    Everything after the first axis is flattened, ``softmax`` is applied along
    the flattened axis, and the result is reshaped back to ``b.shape``.
    """
    flat = torch.flatten(b, start_dim=1)
    return softmax(flat).reshape(b.shape)


def f2(w, _=None):
    """Turn ``w`` into positive per-element weights.

    ``-w`` is scaled per sample by ``normalize_max1``, passed through
    ``softmax2d`` and multiplied by ``len(w[0])``. ``normalize_max1`` works in
    place, but it receives the temporary ``-w``, so ``w`` itself is untouched.

    Args:
        w: tensor indexed by sample along its first axis.
        _: ignored.
    """
    weights = softmax2d(normalize_max1(-w))
    return weights * len(w[0])

def normalize(img):
    """Min-max rescale ``img`` to integers starting at 0 and staying below 255.

    The minimum is subtracted, the result is divided by ``max + 0.01`` and
    multiplied by 255, then truncated with ``.int()``. The caller's tensor is
    not modified: the subtraction already produces a new tensor.
    """
    shifted = img - torch.min(img)
    # In-place division on the fresh tensor; the +0.01 avoids dividing by zero
    # for a constant image.
    shifted /= torch.max(shifted) + 0.01
    return (shifted * 255).int()

def change_format(img):
    """Move the first three planes of ``img`` to a new last axis, reversed.

    ``img[2]``, ``img[1]`` and ``img[0]`` (in that order) each get a trailing
    axis and are concatenated along it.
    NOTE(review): presumably a channels-first to channels-last conversion with
    the channel order flipped -- confirm the order the consumer expects.
    """
    planes = [img[channel].unsqueeze(-1) for channel in (2, 1, 0)]
    return torch.cat(planes, dim=-1)

def image_unnormalize(img):
    """Convert an image tensor into a ``(32, 32, 3)`` integer NumPy array.

    The tensor is rescaled by ``normalize``, rearranged by ``change_format``,
    moved to the CPU, detached, and reshaped to the fixed 32x32x3 layout.
    """
    as_ints = normalize(img)
    channels_last = change_format(as_ints)
    return channels_last.cpu().detach().numpy().reshape(32, 32, 3)




In [30]:
net.eval()

# One dict per test sample (label, logits, input image, re-weighted image);
# written to disk by a later cell.
write_data = []

# Top-1 / top-5 hit counters on the unmodified images.
correct_1 = 0.0
correct_5 = 0.0
total = 0

# The same counters, measured on the gradient-re-weighted images.
correct_1_after = 0.0
correct_5_after = 0.0

for n_iter, (image, label) in enumerate(cifar100_test_loader):

    if args.gpu:
        image = image.cuda()
        label = label.cuda()

    # The batch is a leaf tensor, so requires_grad alone makes backward()
    # populate image.grad. (Assigning to ``image.retain_grad`` would only
    # overwrite that method with a bool, so it is not done here.)
    image.requires_grad = True

    output = net(image)

    # calc acc
    labels_origin = label.clone()
    _, pred = output.topk(5, 1, largest=True, sorted=True)
    # Broadcast the labels to (batch, 5) so they compare against each of the
    # top-5 predictions; reused below for the "after" accuracy.
    label = label.view(label.size(0), -1).expand_as(pred)
    correct = pred.eq(label).float()
    correct_5 += correct[:, :5].sum()
    correct_1 += correct[:, :1].sum()

    # Back-propagate the loss to get d(loss)/d(image). Parameter gradients
    # are cleared before and after; the optimizer never takes a step.
    optimizer.zero_grad()
    loss = criterion(output, labels_origin)
    loss.backward()
    optimizer.zero_grad()

    # Gradient x input, detached so the post-processing below does not keep
    # an autograd graph alive (the values are unaffected).
    # NOTE(review): the gradient comes from the ground-truth labels, so the
    # "after" accuracy is computed on inputs that depend on the true label.
    img_lrp = (image * image.grad).detach()
    img_lrp = f2(img_lrp)

    with torch.no_grad():
        # Re-centre each sample's weights to mean 1 with spread 0.02.
        for i in range(len(img_lrp)):
            img_lrp[i] = to_gaussian(img_lrp[i], std = 0.02)

        # Original note (translated from Korean): "drop the entries where
        # img_lrp is negative". The statement itself multiplies the image
        # element-wise by the weights.
        img_lrp = image*img_lrp
        softlabel = net(img_lrp)

        _, pred = softlabel.topk(5, 1, largest=True, sorted=True)
        correct = pred.eq(label).float()
        correct_5_after += correct[:, :5].sum()
        correct_1_after += correct[:, :1].sum()

    for it in range(len(img_lrp)):
        write_pickle = {
            "label" : labels_origin[it].item(),
            # Raw network output on the re-weighted image (no softmax applied).
            "softlabel" : softlabel[it].cpu().numpy(),
            "img" : image[it].detach().cpu().numpy(),
            "lrp_img": img_lrp[it].detach().cpu().numpy()
        }
        write_data.append(write_pickle)




# Values shared by both reports.
dataset_size = len(cifar100_test_loader.dataset)
param_count = sum(p.numel() for p in net.parameters())

# Error rates on the unmodified test images.
print()
print("Top 1 err: ", 1 - correct_1 / dataset_size)
print("Top 5 err: ", 1 - correct_5 / dataset_size)
print("Parameter numbers: {}".format(param_count))

# Error rates on the gradient-re-weighted test images.
print()
print("Top 1 err: ", 1 - correct_1_after / dataset_size)
print("Top 5 err: ", 1 - correct_5_after / dataset_size)
print("Parameter numbers: {}\n\n".format(param_count))



Top 1 err:  tensor(0.2776, device='cuda:0')
Top 5 err:  tensor(0.1021, device='cuda:0')
Parameter numbers: 34015396

Top 1 err:  tensor(0.0994, device='cuda:0')
Top 5 err:  tensor(0.0313, device='cuda:0')
Parameter numbers: 34015396




In [31]:
import pickle

import os

# Create the output folder first so a fresh checkout does not fail with
# FileNotFoundError on the first open().
os.makedirs("LRP_Data/test/", exist_ok=True)

# One pickle per sample, named by its zero-padded position in write_data.
for idx, d in enumerate(tqdm(write_data)):
    with open("LRP_Data/test/" + str(idx).zfill(6) + ".pickle", "wb") as f:
        pickle.dump(d, f)


100%|██████████| 10000/10000 [00:00<00:00, 10799.23it/s]
