## Interpretability plots for SmallImagenet

In [None]:
import torch
from pytorch_trainers_interpretability.interpretability_eval import IntegratedGrad
from pytorch_trainers_interpretability.interpretability_eval.lime import LimeEval
from pytorch_trainers_interpretability.interpretability_eval import ShapEval
import matplotlib.pyplot as plt
import numpy as np
from textwrap import wrap
from pytorch_trainers_interpretability.trainers import BasicTrainer, AdversarialTrainer
from pytorch_trainers_interpretability.interpretability_eval.integrated_grad import IntegratedGrad
from pytorch_trainers_interpretability.attack import Attacker, L2Step, LinfStep
from pytorch_trainers_interpretability.models.resnet  import ResNet50
from torchvision import datasets, transforms
from PIL import Image
import json
import torchvision
import copy
import os
import torch.nn as nn

In [None]:
# Download the notebook's assets from Google Drive into the working directory.
# For the first three files, the inner wget saves session cookies and pipes the
# response to sed, which extracts the `confirm=` token; the outer wget then
# repeats the request with those cookies and that token, and the cookie file is
# removed afterwards.
# NOTE(review): smallimagenet.tar.gz is only downloaded here; no visible cell
# extracts it, while the data cell reads from /home/server/smallimagenet.
# 1) dataset archive -> smallimagenet.tar.gz
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1619V_hLgH3mhZSVCYuYO1G7y0088A1vq' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1619V_hLgH3mhZSVCYuYO1G7y0088A1vq" -O "smallimagenet.tar.gz" && rm -rf /tmp/cookies.txt
# 2) checkpoint loaded below as the "regular" model -> regular_smimagenet.pt
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1zpHIJ_dPYb6-Seqtbk9YoWSItvdwU-GO' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1zpHIJ_dPYb6-Seqtbk9YoWSItvdwU-GO" -O "regular_smimagenet.pt" && rm -rf /tmp/cookies.txt
# 3) checkpoint loaded below as the "robust" model -> robust_smimagenet.pt
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1_5bKIy4n0rtbRy0YK64BUblnBqUnISMv' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1_5bKIy4n0rtbRy0YK64BUblnBqUnISMv" -O "robust_smimagenet.pt" && rm -rf /tmp/cookies.txt
# 4) class-name file read by the data cell -> smallimagenet150_labels.json
#    (fetched directly, without the confirm-token step)
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1t71KG_u-X-LCAFJ94Kg0pqNBajumEEsu' -O "smallimagenet150_labels.json"

In [None]:
def int_grads_compare_regular_robust(regular, robust, images, labels, classes, normalizer=lambda x: x):
    """Plot integrated-gradient attributions of two models side by side.

    Each image gets one row of three panels: the original image, the
    attribution computed for ``regular`` and the attribution computed for
    ``robust``. Rows are stacked nine per column, and the first row of every
    column carries the panel titles. The figure is created through
    ``plt.figure`` and is not returned; callers save or show the current
    figure themselves.

    Args:
        regular: Model shown in the "Regular model" panel.
        robust: Model shown in the "Robust model" panel.
        images: Batch of images indexed along its first dimension. Each image
            is permuted with ``(1, 2, 0)`` before plotting, so it is
            presumably channel-first (N, C, H, W) - TODO confirm.
        labels: Tensor of ground-truth class indices, one per image.
        classes: Lookup from an integer class index to its class name.
        normalizer: Callable applied to the image batch before the forward
            pass and handed to ``IntegratedGrad``. Defaults to the identity.

    Raises:
        ValueError: If ``labels`` is empty, since no figure can be laid out.
    """
    rows_per_col = 9
    num_img = labels.shape[0]
    if num_img == 0:
        raise ValueError("at least one image is required to build the comparison plot")

    integrated_grad = IntegratedGrad(regular, normalizer=normalizer)
    integrated_grad2 = IntegratedGrad(robust, normalizer=normalizer)
    # Feed the models on whatever device they live on instead of assuming CUDA,
    # so the notebook's CPU fallback keeps working.
    device = next(regular.parameters()).device
    labels_text = ['\n'.join(wrap(classes[l.item()], 20)) for l in labels.cpu()]

    num_cols = int(np.ceil(num_img / rows_per_col))
    fig = plt.figure(figsize=(20 * num_cols, 85))
    # squeeze=False always yields an array; with the default a single column
    # comes back as a bare SubFigure, which cannot be iterated.
    subfigs = fig.subfigures(nrows=1, ncols=num_cols, squeeze=False)
    k = 0
    for sub in subfigs.flat:
        sub2 = sub.subfigures(nrows=rows_per_col, ncols=1, squeeze=False)
        for i, subfig in enumerate(sub2.flat):
            if k == num_img:
                break
            axs = subfig.subplots(nrows=1, ncols=3)
            if i == 0:
                axs[0].set_title("Original image", fontsize=40)
                axs[1].set_title("Regular model", fontsize=40)
                axs[2].set_title("Robust model", fontsize=40)

            img = normalizer(images[k:k+1].to(device))
            # These forward passes only provide the predicted class and its
            # probability, so no autograd graph is needed.
            with torch.no_grad():
                pr = regular(img)
                pr2 = robust(img)
            target = pr.argmax(dim=1)
            target2 = pr2.argmax(dim=1)

            # The explainer and imshow both receive the un-normalized image.
            image = images[k].cpu().permute(1, 2, 0).numpy()
            grad = integrated_grad.random_baseline_integrated_grads(image, target, steps=100, num_random_trials=10, batch_size=100)
            grad2 = integrated_grad2.random_baseline_integrated_grads(image, target2, steps=100, num_random_trials=10, batch_size=100)

            pred = '\n'.join(wrap(classes[target.item()], 10))
            pred2 = '\n'.join(wrap(classes[target2.item()], 10))
            axs[0].set_xlabel('\n'.join(wrap(labels_text[k], 10)), fontsize=35)
            axs[0].imshow(image)
            axs[1].set_xlabel(f"Class: {pred}\n Prob: {torch.softmax(pr, dim=1).amax(dim=1).item():.2f}", fontsize=35)
            axs[1].imshow(integrated_grad.visualization(grad, image))
            axs[2].set_xlabel(f"Class: {pred2}\n Prob: {torch.softmax(pr2, dim=1).amax(dim=1).item():.2f}", fontsize=35)
            axs[2].imshow(integrated_grad2.visualization(grad2, image))
            for ax in axs:
                ax.set_xticklabels([])
                ax.set_yticklabels([])
            k += 1


In [None]:
def lime_compare_regular_robust(regular, robust, images, labels, classes, normalizer=lambda x: x):
    """Plot LIME explanations of two models side by side.

    Each image gets one row of three panels: the original image, the result
    of ``LimeEval.explain_model`` for ``regular`` and the same for ``robust``.
    Rows are stacked nine per column, and the first row of every column
    carries the panel titles. The figure is created through ``plt.figure``
    and is not returned; callers save or show the current figure themselves.

    Args:
        regular: Model shown in the "Regular model" panel.
        robust: Model shown in the "Robust model" panel.
        images: Batch of images indexed along its first dimension. Each image
            is permuted with ``(1, 2, 0)`` before plotting, so it is
            presumably channel-first (N, C, H, W) - TODO confirm.
        labels: Tensor of ground-truth class indices, one per image.
        classes: Lookup from an integer class index to its class name.
        normalizer: Callable applied to the image batch before the forward
            pass and handed to ``LimeEval``. Defaults to the identity.

    Raises:
        ValueError: If ``labels`` is empty, since no figure can be laid out.
    """
    rows_per_col = 9
    num_img = labels.shape[0]
    if num_img == 0:
        raise ValueError("at least one image is required to build the comparison plot")

    leval = LimeEval(regular, normalizer=normalizer)
    leval2 = LimeEval(robust, normalizer=normalizer)
    # Feed the models on whatever device they live on instead of assuming CUDA,
    # so the notebook's CPU fallback keeps working.
    device = next(regular.parameters()).device
    labels_text = ['\n'.join(wrap(classes[l.item()], 20)) for l in labels.cpu()]

    num_cols = int(np.ceil(num_img / rows_per_col))
    fig = plt.figure(figsize=(20 * num_cols, 80))
    # squeeze=False always yields an array; with the default a single column
    # comes back as a bare SubFigure, which cannot be iterated.
    subfigs = fig.subfigures(nrows=1, ncols=num_cols, squeeze=False)
    k = 0
    for sub in subfigs.flat:
        sub2 = sub.subfigures(nrows=rows_per_col, ncols=1, squeeze=False)
        for i, subfig in enumerate(sub2.flat):
            if k == num_img:
                break
            axs = subfig.subplots(nrows=1, ncols=3)
            if i == 0:
                axs[0].set_title("Original image", fontsize=40)
                axs[1].set_title("Regular model", fontsize=40)
                axs[2].set_title("Robust model", fontsize=40)

            img = normalizer(images[k:k+1].to(device))
            # These forward passes only provide the predicted class and its
            # probability, so no autograd graph is needed.
            with torch.no_grad():
                pr = regular(img)
                pr2 = robust(img)

            # The explainer and imshow both receive the un-normalized image.
            image = images[k].cpu().permute(1, 2, 0).numpy()
            plot1 = leval.explain_model(image)
            plot2 = leval2.explain_model(image)

            pred = '\n'.join(wrap(classes[pr.argmax(dim=1).item()], 10))
            pred2 = '\n'.join(wrap(classes[pr2.argmax(dim=1).item()], 10))
            axs[0].set_xlabel('\n'.join(wrap(labels_text[k], 10)), fontsize=35)
            axs[0].imshow(image)
            axs[1].set_xlabel(f"Class: {pred}\n Prob: {torch.softmax(pr, dim=1).amax(dim=1).item():.2f}", fontsize=35)
            axs[1].imshow(plot1)
            axs[2].set_xlabel(f"Class: {pred2}\n Prob: {torch.softmax(pr2, dim=1).amax(dim=1).item():.2f}", fontsize=35)
            axs[2].imshow(plot2)
            for ax in axs:
                ax.set_xticklabels([])
                ax.set_yticklabels([])
            k += 1


In [None]:
# Dataset location; the train and test splits live in sub-directories of it.
data_root = "/home/server/smallimagenet"
traindir = os.path.join(data_root, 'train')
valdir = os.path.join(data_root, 'test')

# Normalization is kept out of the transform pipelines below: the loaders
# yield un-normalized tensors and `normalize` is passed separately to the
# explainers and comparison functions as their `normalizer`.
normalize = transforms.Normalize(mean=[0.4808, 0.4512, 0.4072],
                                 std=[0.2687, 0.2610, 0.2742])

# Training pipeline: resize, random 128-pixel crop, random horizontal flip.
transform_train = transforms.Compose([
    transforms.Resize(140),
    transforms.RandomResizedCrop(128),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# Test pipeline: resize and a deterministic 128-pixel center crop.
transform_test = transforms.Compose([
    transforms.Resize(140),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
])

trainset = datasets.ImageFolder(traindir, transform=transform_train)
testset = datasets.ImageFolder(valdir, transform=transform_test)
# Both loaders shuffle, so the batch drawn for plotting differs between runs.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=200,
                                          shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=200,
                                         shuffle=True)

# Class names for the plot labels; the context manager closes the file
# instead of leaving the handle open for the rest of the session.
with open('./smallimagenet150_labels.json', encoding='utf-8') as f:
    classes = json.load(f)

In [None]:
# Pick the device first so the checkpoints can be mapped onto it while loading.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `model` receives the regular checkpoint, `model2` the robust one.
model = ResNet50(num_classes=150)
model2 = ResNet50(num_classes=150)
# map_location lets checkpoints that were saved from CUDA tensors load on a
# CPU-only machine; without it torch.load fails before the CPU fallback helps.
model.load_state_dict(torch.load("./regular_smimagenet.pt", map_location=device)["model_state_dict"])
model2.load_state_dict(torch.load("./robust_smimagenet.pt", map_location=device)["model_state_dict"])
model.to(device)
model2.to(device)
# Inference only: switch both networks to evaluation mode.
model.eval()
model2.eval()

integrated_grad = IntegratedGrad(model, normalizer=normalize)
integrated_grad2 = IntegratedGrad(model2, normalizer=normalize)

In [None]:
# Draw one batch from each loader: test images are the ones explained, and
# train images are passed as the first argument of `nat_deep_exp`
# (presumably the SHAP background set - TODO confirm).
images, labels = next(iter(testloader))
images_b, _ = next(iter(trainloader))

# The same 36 test samples and 100 train samples feed every plot below.
train_subset = images_b[:100]
sample_images = images[:36]
sample_labels = labels[:36]

# SHAP plot for the regular model.
shapeval = ShapEval(model, classes, normalize)
shapeval.nat_deep_exp(train_subset, sample_images, sample_labels)
plt.savefig("./smallimagenet_shap1.pdf")

# SHAP plot for the robust model.
shapeval2 = ShapEval(model2, classes, normalize)
shapeval2.nat_deep_exp(train_subset, sample_images, sample_labels)
plt.savefig("./smallimagenet_shap2.pdf")

# Integrated-gradients comparison, regular vs. robust.
int_grads_compare_regular_robust(model, model2, sample_images, sample_labels, classes, normalize)
plt.savefig("./smallimagenet_int_grads.pdf")

# LIME comparison, regular vs. robust.
lime_compare_regular_robust(model, model2, sample_images, sample_labels, classes, normalize)
plt.savefig("./smallimagenet_lime_plots.pdf")