## Interpretability plots for SmallImagenet

In [None]:
import torch
from pytorch_trainers_interpretability.interpretability_eval import IntegratedGrad, ShapEval, RepVisualization
from pytorch_trainers_interpretability.tools import show_image_column, show_image_row
import matplotlib.pyplot as plt
import numpy as np
from textwrap import wrap
from pytorch_trainers_interpretability.interpretability_eval.integrated_grad import IntegratedGrad
from pytorch_trainers_interpretability.attack import Attacker, L2Step, LinfStep
from pytorch_trainers_interpretability.models.resnet  import ResNet50
from torchvision import datasets, transforms
from PIL import Image
import json
import torchvision
import copy
import os
import torch.nn as nn

In [None]:
# Fetch the SmallImagenet archive, the three checkpoints (standard, robust l2,
# robust l_inf) and the class-name JSON from Google Drive.
# Each nested `wget ... | sed` first requests the file page and extracts a
# `confirm=` token, which is then substituted into the real download URL
# (presumably to get past Drive's large-file confirmation page — confirm).
# NOTE(review): the tarball is not extracted in this cell, yet later cells read
# ./smallimagenet/test — make sure `smallimagenet.tar.gz` is unpacked first.
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1619V_hLgH3mhZSVCYuYO1G7y0088A1vq' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1619V_hLgH3mhZSVCYuYO1G7y0088A1vq" -O "smallimagenet.tar.gz" && rm -rf /tmp/cookies.txt
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1zpHIJ_dPYb6-Seqtbk9YoWSItvdwU-GO' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1zpHIJ_dPYb6-Seqtbk9YoWSItvdwU-GO" -O "standard_smimagenet.pt" && rm -rf /tmp/cookies.txt
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1_5bKIy4n0rtbRy0YK64BUblnBqUnISMv' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1_5bKIy4n0rtbRy0YK64BUblnBqUnISMv" -O "robust_l2_smimagenet.pt" && rm -rf /tmp/cookies.txt
# NOTE(review): unlike the other downloads, the token request below goes to
# drive.google.com rather than docs.google.com — confirm this is intended.
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=12O5HxjqcSzjt9-mGfapYeZ-nOfsMopIM' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=12O5HxjqcSzjt9-mGfapYeZ-nOfsMopIM" -O "robust_linf_smimagenet.pt" && rm -rf /tmp/cookies.txt
# Small file: a direct download without the confirm-token step.
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1t71KG_u-X-LCAFJ94Kg0pqNBajumEEsu' -O "smallimagenet150_labels.json"

In [None]:
def int_grads_compare_3(standard, robust_l2, robust_linf, images, labels, classes, normalizer=lambda x: x):
    """Plot integrated-gradient attributions of three models side by side.

    Each image gets a row of four panels: the image itself, then the
    integrated-gradients visualization from ``standard``, ``robust_l2`` and
    ``robust_linf``. Every attribution is computed for that model's own
    predicted class. The panel's x-label shows the predicted class name and
    its softmax probability. Rows are stacked eight per column, and extra
    columns are added when there are more than eight images.

    Args:
        standard, robust_l2, robust_linf: models shown in panels 2-4.
        images: image batch, indexed as ``images[k]`` and permuted
            ``(1, 2, 0)`` for display — assumes (N, C, H, W); TODO confirm.
        labels: ground-truth class indices, one per image.
        classes: indexable mapping from class index (int) to class name.
        normalizer: applied to an image before each model forward pass; also
            passed to every ``IntegratedGrad`` explainer.

    The plot is drawn into a new pyplot figure; call ``plt.show()`` to
    display it. Nothing is returned.
    """
    rows_per_col = 8
    num_img = labels.shape[0]
    if num_img == 0:
        return  # nothing to plot; avoids building a zero-column figure

    models = (standard, robust_l2, robust_linf)
    explainers = [IntegratedGrad(m, normalizer=normalizer) for m in models]
    titles = ("Original image", "Standard model", r"Robust $l_{2}$", r"Robust $l_{\infty}$")
    labels_text = ['\n'.join(wrap(classes[l.item()], 10)) for l in labels.cpu()]
    # Move inputs to wherever the models live instead of hard-coding CUDA.
    device = next(standard.parameters()).device

    num_cols = int(np.ceil(num_img / rows_per_col))
    fig = plt.figure(figsize=(26 * num_cols, 80))
    # squeeze=False: with a single column matplotlib would otherwise return a
    # bare SubFigure instead of an array, and it could not be iterated below.
    columns = fig.subfigures(nrows=1, ncols=num_cols, squeeze=False)[0]
    for col, column in enumerate(columns):
        for row, subfig in enumerate(column.subfigures(nrows=rows_per_col, ncols=1)):
            k = col * rows_per_col + row
            if k >= num_img:
                break
            axs = subfig.subplots(nrows=1, ncols=4)
            if row == 0:
                # Column headers only on the first row of each column.
                for ax, title in zip(axs, titles):
                    ax.set_title(title, fontsize=40)

            image = images[k].cpu().permute(1, 2, 0).numpy()
            axs[0].set_xlabel(labels_text[k], fontsize=35)
            axs[0].imshow(image)

            # Predictions only need scores, not an autograd graph.
            with torch.no_grad():
                x = normalizer(images[k:k + 1].to(device))
                outputs = [m(x) for m in models]

            for ax, explainer, out in zip(axs[1:], explainers, outputs):
                pred_idx = out.argmax(dim=1).item()
                prob = torch.softmax(out, dim=1).amax(dim=1).item()
                grad = explainer.random_baseline_integrated_grads(
                    image, pred_idx, steps=100, num_random_trials=10, batch_size=100)
                pred_name = '\n'.join(wrap(classes[pred_idx], 10))
                ax.set_xlabel(f"Class: {pred_name}\n Prob: {prob:.2f}", fontsize=35)
                ax.imshow(explainer.visualization(grad, image))

            for ax in axs:
                ax.set_xticklabels([])
                ax.set_yticklabels([])


In [None]:
# Test split of SmallImagenet, read with torchvision's ImageFolder directory layout.
valdir = os.path.join("./smallimagenet", 'test')
# Per-channel normalization. It is deliberately NOT part of transform_test;
# it is handed to the explainers/plots separately as `normalizer`.
normalize = transforms.Normalize(mean=[0.4808, 0.4512, 0.4072],
                                 std=[0.2687, 0.2610, 0.2742])
# Resize the short side to 140, then center-crop to 128x128 and convert to a tensor.
transform_test = transforms.Compose([
    transforms.Resize(140),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
])
testset = datasets.ImageFolder(valdir, transform=transform_test)
# shuffle=True so every next(iter(testloader)) draws a different random batch.
testloader = torch.utils.data.DataLoader(testset, batch_size=200,
                                         shuffle=True)
# Class-index -> class-name lookup; the context manager closes the file.
with open('./smallimagenet150_labels.json', encoding='utf-8') as f:
    classes = json.load(f)

In [None]:
# Build the three ResNet-50 variants (standard, robust l2, robust l_inf),
# load their trained weights, and wrap each one in a representation visualizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet50(num_classes=150)
model2 = ResNet50(num_classes=150)
model3 = ResNet50(num_classes=150)
checkpoints = (
    (model, "./standard_smimagenet.pt"),
    (model2, "./robust_l2_smimagenet.pt"),
    (model3, "./robust_linf_smimagenet.pt"),
)
for net, ckpt_path in checkpoints:
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    net.load_state_dict(torch.load(ckpt_path, map_location=device)["model_state_dict"])
    net.to(device)
    net.eval()
vis = RepVisualization(model, normalizer=normalize)
vis2 = RepVisualization(model2, normalizer=normalize)
vis3 = RepVisualization(model3, normalizer=normalize)

In [None]:
# Draw one shuffled batch (images and ground-truth labels) from the test loader.
[images, labels] = next(iter(testloader))

In [None]:
# Representation Inversion
# Target: five copies of test image 3. Seeds: test images 0-2 plus two noise
# images with values in [0.5, 0.55). Each model then inverts the target's
# representation starting from those seeds.
im = torch.concat([images[3:4]] * 5)
r = torch.rand_like(images[0:2]) / 20 + 0.5
im_n = torch.concat([images[0:3], r], 0)
# Rows: target, seeds, then robust l2 (vis2), robust l_inf (vis3), standard (vis).
# Assumes show_image_row pairs titles with rows by position, so the two lists
# must have the same length and order.
res = [im.cpu(), im_n.cpu()]
for visualizer in (vis2, vis3, vis):
    res.append(visualizer.rep_inversion(im, im_n.clone()).cpu())
show_image_row(res, ["Original image", "Source", r"Robust $l_{2}$", r"Robust $l_{\infty}$", "Standard"])

In [None]:
# Fresh shuffled batch, plus five distinct class indices drawn from 0..149.
# `cl` has shape (5, 1), so each entry is a one-element array.
[images, labels] = next(iter(testloader))
cl = np.random.choice(150, size=(5, 1), replace=False)

In [None]:
#Class Specific Image Generation
# Seeds: test images 0-2 plus two noise images with values in [0.5, 0.55).
r = torch.rand_like(images[0:2]) / 20 + 0.5
im_n = torch.concat([images[0:3], r], 0)
# One list per model. Each starts with the seeds and gets one generated batch per target class.
res = [im_n.cpu()]   # robust l2 (vis2)
res2 = [im_n.cpu()]  # robust l_inf (vis3)
res3 = [im_n.cpu()]  # standard (vis)
for target in cl:  # each `target` is a one-element row of cl
    res.append(vis2.class_im_gen(im_n.clone(), target).cpu())
    res2.append(vis3.class_im_gen(im_n.clone(), target).cpu())
    res3.append(vis.class_im_gen(im_n.clone(), target).cpu())
# `.item()` turns the one-element array into a plain int usable as a list index.
list_labels = ["Source"] + [classes[target.item()] for target in cl]
show_image_column(list_labels, res)
show_image_column(list_labels, res2)
show_image_column(list_labels, res3)

In [None]:
#SHAP and Integrated Gradients
torch.cuda.empty_cache()
# images[:176] are passed as the first argument of gradient_exp (presumably
# SHAP background samples — confirm), and images[176:200] with their labels
# are the examples that get explained/plotted.
split, end = 176, 200
shapeval = ShapEval(model, classes, normalize)
shapeval.gradient_exp(images[:split], images[split:end], labels[split:end])
shapeval2 = ShapEval(model2, classes, normalize)
shapeval2.gradient_exp(images[:split], images[split:end], labels[split:end])
int_grads_compare_3(model, model2, model3, images[split:end], labels[split:end], classes, normalize)
plt.show()