# RDFIA TME 9: Visualization of Neural Networks


In [None]:
%load_ext autoreload
%autoreload 2

import os
import random
from pathlib import Path

import torch
from torch.nn import functional as F
import torchvision
import torchvision.transforms as T
import numpy as np
from scipy.ndimage.filters import gaussian_filter1d
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
from PIL import Image

%matplotlib inline
# Default plotting style for the whole notebook.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),  # default figure size (inches)
    'image.interpolation': 'nearest',
    'image.cmap': 'viridis',
    'savefig.dpi': 150,
})


## Loading the model

In [None]:
from utils import *

# Pretrained SqueezeNet 1.1 (Iandola et al., 2016), trained on ImageNet.
# It achieves results comparable to AlexNet on ImageNet while being very
# compact.
#
# Point torchvision at the course mirror of the SqueezeNet weights. The
# override has to go into the *squeezenet* module's URL table:
# `squeezenet1_1` never reads `torchvision.models.vgg.model_urls`, so writing
# there had no effect. Newer torchvision releases may no longer expose a
# `model_urls` table, hence the guard; in that case torchvision's default
# download location is used. If the mirror is unreachable, remove this
# override to fall back to the default URL.
_squeezenet_module = torchvision.models.squeezenet
if hasattr(_squeezenet_module, "model_urls"):
    _squeezenet_module.model_urls["squeezenet1_1"] = (
        "http://webia.lip6.fr/~douillard/rdfia/squeezenet1_1-f364aa15.pth")
# Downloaded weights are cached under ./pytorch_models.
os.environ["TORCH_HOME"] = "./pytorch_models"

# Load the pretrained model.
model = torchvision.models.squeezenet1_1(pretrained=True)
# Put it in evaluation mode: the network is only used for inference here.
model.eval()
# Freeze the parameters; presumably only gradients w.r.t. the input images
# are needed by the visualizations below.
for param in model.parameters():
    param.requires_grad = False

## Loading example images

In [None]:
# Load example images: 25 images from ImageNet (validation split, per the
# file name). `label_map` is stored as a pickled object array (hence the
# `.item()` call), which is why allow_pickle=True is required; only load
# trusted files this way. The context manager closes the underlying file.
with np.load("data/imagenet_val_25.npz", allow_pickle=True) as f:
    data, target, class_names = f["X"], f["y"], f["label_map"].item()
# Keep only the first comma-separated synonym of each class name.
class_names = {k: v.split(',')[0] for (k, v) in class_names.items()}

def show_images(savepath=None, images=None, labels=None, names=None,
                nrows=4, ncols=6):
    """Plot a grid of example images, each titled with its class name.

    Args:
        savepath: if given, the figure is also saved to this path; missing
            parent directories are created.
        images: images accepted by ``plt.imshow``; defaults to the
            module-level ``data``.
        labels: class index of each image; defaults to ``target``.
        names: mapping from class index to display name; defaults to
            ``class_names``.
        nrows, ncols: grid layout. At most ``nrows * ncols`` images are
            shown, fewer if not enough images are available.
    """
    images = data if images is None else images
    labels = target if labels is None else labels
    names = class_names if names is None else names

    plt.figure(figsize=(15, 7))
    # Cap at the number of available images to avoid an IndexError.
    for i in range(min(nrows * ncols, len(images))):
        plt.subplot(nrows, ncols, i + 1)
        plt.imshow(images[i])
        plt.title(names[labels[i]])
        plt.axis('off')
    plt.gcf().tight_layout()
    if savepath:
        Path(savepath).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(savepath)
    plt.show()


show_images(savepath='figures/Images.jpeg')

## Saliency maps

In [None]:
from saliency_maps import compute_saliency_maps

def show_saliency_maps(data, target, model, class_names, savepath=None):
    """Plot images (top row) above their saliency maps (bottom row).

    Args:
        data: array of images; each is passed to ``Image.fromarray`` and
            ``plt.imshow`` (presumably uint8 HxWx3 — confirm).
        target: class index of each image, forwarded to
            ``compute_saliency_maps`` (presumably the class whose score is
            differentiated — confirm).
        model: classifier forwarded to ``compute_saliency_maps``.
        class_names: mapping from class index to display name.
        savepath: if given, the figure is also saved there; missing parent
            directories are created.
    """
    # Preprocess every image and stack the results into one batch.
    data_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in data], dim=0)
    # torch.tensor copies its input, like the legacy torch.LongTensor did.
    target_tensor = torch.tensor(target, dtype=torch.long)

    saliency = compute_saliency_maps(data_tensor, target_tensor, model)
    # detach()/cpu() are no-ops for a plain CPU tensor, but make the NumPy
    # conversion safe if the map still tracks gradients or lives on a GPU.
    saliency = saliency.detach().cpu().numpy()

    n = data.shape[0]
    plt.gcf().set_size_inches(12, 5)
    for i, (image, label, sal_map) in enumerate(zip(data, target, saliency)):
        # Top row: the input image, titled with its class name.
        plt.subplot(2, n, i + 1)
        plt.imshow(image)
        plt.axis('off')
        plt.title(class_names[label])
        # Bottom row: the matching saliency map.
        plt.subplot(2, n, n + i + 1)
        plt.imshow(sal_map, cmap=plt.cm.hot)
        plt.axis('off')
    plt.gcf().tight_layout()
    if savepath:
        Path(savepath).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(savepath)
    plt.show()

# Plot saliency maps for consecutive batches of 5 images (one figure per
# batch), covering every loaded image whatever the dataset size.
saliency_batch_size = 5
for i, start in enumerate(range(0, len(data), saliency_batch_size)):
    stop = start + saliency_batch_size
    show_saliency_maps(data[start:stop], target[start:stop], model, class_names,
                       savepath='./figures/Saliency-maps_batch{}'.format(i))

## Fooling images

In [None]:
from fooling_examples import make_fooling_image

# Try to make the model classify the first example image as class 6
# (its name is class_names[6]), with plot_progress enabled.
dest_y = 6
x_tensor = preprocess(Image.fromarray(data[0]))
_ = make_fooling_image(
    x_tensor,
    dest_y,
    model,
    max_iters=100,
    confidence=0.99,
    plot_progress=True,
)

In [None]:
def _show_fooling_panel(position, image, title):
    """Draw `image` titled `title` in panel `position` of a 1x4 grid, axes hidden."""
    plt.subplot(1, 4, position)
    plt.imshow(image)
    plt.title(title)
    plt.axis('off')


def show_fooling_images(data, target, destination_y, model, class_names, savedir=None):
    """Craft a fooling image towards `destination_y` for each image and plot it.

    Each figure has four panels: the real image with the model's probability
    for its true class, the fooling image with the probability of
    `destination_y`, the perturbation, and the perturbation magnified 10x.

    Args:
        data: images, each passed to ``Image.fromarray`` (presumably uint8
            HxWx3 — confirm).
        target: true class index of each image.
        destination_y: class index the fooling images should be predicted as.
        model: classifier; its output for a one-image batch is indexed as
            ``[0, class]`` after a softmax over dim 1.
        class_names: mapping from class index to display name.
        savedir: optional directory; each figure is saved there as
            ``Fooling_<idx>.jpg`` (the directory is created if missing).

    Raises:
        AssertionError: if the model does not predict `destination_y` for a
            fooling image.
    """
    data_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in data], dim=0)
    if savedir:
        savedir = Path(savedir)
        savedir.mkdir(parents=True, exist_ok=True)

    for idx, y in enumerate(target):
        # One-image batch taken from the preprocessed stack.
        x_tensor = data_tensor[idx][None]
        with torch.no_grad():
            # Probability of the true class y before the attack
            # (not necessarily the predicted label!).
            y_score_before = F.softmax(model(x_tensor), 1)[0, y].item()

        # Pass a copy so x_tensor stays pristine even if the attack edits its
        # input in place; detach so plotting never sees autograd history.
        x_fooling = make_fooling_image(x_tensor.clone(), destination_y, model).detach()

        # Verify the predicted class on the fooling example.
        with torch.no_grad():
            probs = F.softmax(model(x_fooling), 1)
        dest_y_score, pred_y = (t.item() for t in probs.max(1))
        if pred_y != destination_y:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError("The model is not fooled!")

        # x_tensor is the same preprocessed image as
        # preprocess(Image.fromarray(data[idx])), assuming preprocess is
        # deterministic, so it is reused as the reference for the diff.
        perturbation = x_fooling - x_tensor
        # Clone before deprocess in case it modifies its input in place.
        x_fooling_np = np.asarray(deprocess(x_fooling.clone())).astype(np.uint8)

        _show_fooling_panel(
            1, data[idx],
            "Real: {}\nConfidence: {:.2f}%".format(class_names[y], y_score_before*100))
        _show_fooling_panel(
            2, x_fooling_np,
            "Fooled: {}\nConfidence: {:.2f}%".format(class_names[destination_y],
                                                     dest_y_score*100))
        _show_fooling_panel(
            3, np.asarray(deprocess(perturbation.clone(), should_rescale=False)),
            'Difference')
        _show_fooling_panel(
            4, np.asarray(deprocess(10 * perturbation, should_rescale=False)),
            'Magnified difference (10x)')

        plt.gcf().set_size_inches(12, 5)
        plt.gcf().tight_layout()
        if savedir:
            plt.savefig(savedir / 'Fooling_{}.jpg'.format(idx))
        plt.show()

# Build a fooling image towards class 6 for every example image and save
# each figure under ./figures.
show_fooling_images(data, target, 6, model, class_names, savedir='./figures')

## Class visualization

In [None]:
from class_visualization import create_class_visualization

# Tensor type handed to create_class_visualization.
# To run on GPU, use torch.cuda.FloatTensor instead.
dtype = torch.FloatTensor

# ImageNet class index to visualise: 76 (Tarantula).
# Other classes suggested in this TME: 281 (Tabby cat), 187 (Yorkshire
# Terrier), 78 (Tick), 683 (Oboe), 366 (Gorilla), 604 (Hourglass), or a
# random class via np.random.randint(1000).
target_y = 76

# Hyper-parameters, passed positionally to create_class_visualization
# (see that function for their exact meaning).
init_img = None
l2_reg = 1e-2
lr = 5
num_iterations = 1000
blur_every = 10
max_jitter = 16
show_every = 100

out = create_class_visualization(
    target_y, model, dtype, init_img, l2_reg, lr, num_iterations,
    blur_every, max_jitter, show_every, class_names,
)