# Image Classification with VGG and ResNet

This tutorial will introduce the attribution of image classifiers using VGG11
and ResNet18 on ImageNet. Feel free to replace VGG11 and ResNet18 with any
other version of VGG or ResNet respectively.

## Preparation

First, we install **Zennit**. This includes its dependencies `Pillow`,
`torch` and `torchvision`:

In [None]:
%pip install zennit

Then, we import necessary modules, classes and functions:

In [None]:
import logging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop
from torchvision.transforms import ToTensor, Normalize
from torchvision.models import vgg11_bn, resnet18

from zennit.attribution import Gradient
from zennit.composites import EpsilonGammaBox, EpsilonPlusFlat
from zennit.image import imgify, imsave
from zennit.torchvision import VGGCanonizer, ResNetCanonizer

We download an image of the [Dornbusch
Lighthouse](https://en.wikipedia.org/wiki/Dornbusch_Lighthouse) from [Wikimedia
Commons](https://commons.wikimedia.org/wiki/File:2006_09_06_180_Leuchtturm.jpg):

In [None]:
torch.hub.download_url_to_file(
    'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_180_Leuchtturm.jpg',
    'dornbusch-lighthouse.jpg',
)

We load and prepare the data. The image is resized such that its shorter side
is 256 pixels long, then center-cropped to `(224, 224)`, converted to a
`torch.Tensor`, and finally normalized according to the channel-wise mean and
standard deviation of the ImageNet dataset:

In [None]:
# define the base image transform
transform_img = Compose([
    Resize(256),
    CenterCrop(224),
])
# define the normalization transform
transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# define the full tensor transform
transform = Compose([
    transform_img,
    ToTensor(),
    transform_norm,
])

# load the image
image = Image.open('dornbusch-lighthouse.jpg')

# transform the PIL image and insert a batch-dimension
data = transform(image)[None]
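
To make the geometry of `Resize(256)` followed by `CenterCrop(224)` concrete, the following sketch reproduces the size arithmetic in plain Python. The helpers `resized_size` and `center_crop_box` and the 640x480 input size are hypothetical and only serve as an illustration; the rounding is assumed to match what torchvision does for an integer size.

```python
def resized_size(width, height, size=256):
    """Scale so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# hypothetical landscape input of 640x480 pixels
new_w, new_h = resized_size(640, 480)
box = center_crop_box(new_w, new_h)
print((new_w, new_h), box)
```

Whatever the input size, the crop window is always `224` pixels wide and high, which is the input size the models below expect.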

We can look at the original image and the cropped image:

In [None]:
# display the original image
display(image)
# display the resized and cropped image
display(transform_img(image))

## VGG11 - LRP with EpsilonGammaBox
First, we initialize the VGG11 model with batch normalization. Set
`pretrained=True` to load the pre-trained weights instead of using randomly
initialized ones:

In [None]:
# load the model and set it to evaluation mode
model = vgg11_bn(pretrained=False).eval()

Compute the attribution using the `EpsilonGammaBox` **Composite**:

In [None]:
# use the VGG-specific canonizer (alias for SequentialMergeBatchNorm, only
# needed with batch-norm)
canonizer = VGGCanonizer()

# the EpsilonGammaBox composite needs the lowest and highest possible input
# values; for ImageNet these are the pixel values 0. and 1., normalized with
# the mean and standard deviation of each channel
low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))

# create a composite, specifying arguments and the canonizers, if any
composite = EpsilonGammaBox(low=low, high=high, canonizers=[canonizer])

# choose a target class for the attribution (label 437 is lighthouse)
target = torch.eye(1000)[[437]]

# create the attributor, specifying model and composite
with Gradient(model=model, composite=composite) as attributor:
    # compute the model output and attribution
    output, attribution = attributor(data, target)

print(f'Prediction: {output.argmax(1)[0].item()}')
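
The `low` and `high` bounds passed to `EpsilonGammaBox` are the pixel values 0 and 1 pushed through the same channel-wise normalization as the data. As a minimal sketch in plain Python (the list names are illustrative), the per-channel values can be computed by hand:

```python
# ImageNet channel-wise statistics, as used in transform_norm above
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Normalize computes (x - mean) / std for each channel
low_values = [(0.0 - m) / s for m, s in zip(mean, std)]
high_values = [(1.0 - m) / s for m, s in zip(mean, std)]

for channel, lo, hi in zip('RGB', low_values, high_values):
    print(f'{channel}: low={lo:.4f}, high={hi:.4f}')
```

Because the statistics differ per channel, each channel ends up with its own lower and upper bound.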

Visualize the attribution:

In [None]:
# sum over the channels
relevance = attribution.sum(1)

# create an image of the visualized attribution
img = imgify(relevance, symmetric=True, cmap='coldnhot')

# show the image
display(img)

Here, `imgify` produces a PIL-image, which can be saved with `.save()`.
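
The effect of `symmetric=True` can be illustrated with a small sketch. It assumes that symmetric normalization maps the range `[-absmax, absmax]` linearly to `[0, 1]`, so that zero relevance lands exactly in the middle of the color map. This is a simplified stand-in written for illustration, not Zennit's actual implementation, and `symmetric_norm` is a hypothetical helper.

```python
def symmetric_norm(values):
    """Map values from [-absmax, absmax] linearly to [0, 1]."""
    absmax = max(abs(v) for v in values)
    if absmax == 0:
        # all-zero input: place everything at the center
        return [0.5 for _ in values]
    return [(v + absmax) / (2 * absmax) for v in values]


# toy relevance values with a stronger negative than positive extreme
relevance_row = [-4.0, -1.0, 0.0, 2.0]
normed = symmetric_norm(relevance_row)
print(normed)
```

Under this assumption, positive and negative relevance share one scale, so color intensities on both sides of zero are directly comparable.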

## More Visualization
We can try out different color maps, either by using another built-in color
map or by using the color-map specification language (CMSL):

In [None]:
print('Built-in color-map bwr')
display(imgify(relevance, symmetric=True, cmap='bwr'))

print('CMSL code for color-map from cyan to grey to purple')
display(imgify(relevance, symmetric=True, cmap='0ff,444,f0f'))

print('CMSL code for grey scale for negative and red for positive values')
display(imgify(relevance, symmetric=True, cmap='fff,000,f00'))
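
As a rough illustration of how such a specification can be read, the sketch below handles only the simplest form used above: comma-separated three-digit hex colors placed at equal distances, with linear interpolation in between. It assumes that each hex digit is doubled as in CSS shorthand (`f` becomes `ff`). `parse_colors` and `interpolate` are hypothetical helpers, and the sketch does not attempt to cover the full language.

```python
def parse_colors(spec):
    """Parse 'fff,000,f00' into a list of (r, g, b) tuples in 0..255."""
    colors = []
    for code in spec.split(','):
        code = code.strip()
        if len(code) != 3:
            raise ValueError(f'only 3-digit hex colors are handled: {code!r}')
        colors.append(tuple(int(digit * 2, 16) for digit in code))
    return colors


def interpolate(colors, position):
    """Linearly interpolate equidistant colors at position in [0, 1]."""
    if len(colors) == 1:
        return colors[0]
    scaled = min(max(position, 0.0), 1.0) * (len(colors) - 1)
    index = min(int(scaled), len(colors) - 2)
    frac = scaled - index
    first, second = colors[index], colors[index + 1]
    return tuple(round(a + (b - a) * frac) for a, b in zip(first, second))


colors = parse_colors('fff,000,f00')
for position in (0.0, 0.25, 0.5, 1.0):
    print(position, interpolate(colors, position))
```

With symmetric normalization, position `0.5` corresponds to zero relevance, so in this reading the middle color of a three-color specification is the one shown for neutral pixels.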

To directly save the visualized attribution, we can use `imsave` instead:

In [None]:
# directly save the visualized attribution
imsave('attrib-1.png', relevance, symmetric=True, cmap='bwr')

## ResNet18 - LRP with EpsilonPlusFlat
Next, we initialize the ResNet18 model. Set `pretrained=True` to load the
pre-trained weights instead of using randomly initialized ones:

In [None]:
# load the model and set it to evaluation mode
model = resnet18(pretrained=False).eval()

Compute the attribution using the `EpsilonPlusFlat` **Composite**:

In [None]:
# use the ResNet-specific canonizer
canonizer = ResNetCanonizer()

# create a composite, specifying the canonizers, if any
composite = EpsilonPlusFlat(canonizers=[canonizer])

# choose a target class for the attribution (label 437 is lighthouse)
target = torch.eye(1000)[[437]]

# create the attributor, specifying model and composite
with Gradient(model=model, composite=composite) as attributor:
    # compute the model output and attribution
    output, attribution = attributor(data, target)

print(f'Prediction: {output.argmax(1)[0].item()}')
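
The second argument to the attributor selects what is attributed. The following minimal sketch uses plain `torch.autograd` and assumes that the one-hot `target` acts as the gradient of the model output, which is how a plain gradient attribution without a composite would use it. The tiny linear model and its sizes are made up for illustration.

```python
import torch

torch.manual_seed(0)

# toy stand-in for a classifier: 5 input features, 3 classes
toy_model = torch.nn.Linear(5, 3, bias=False)
toy_input = torch.randn(1, 5, requires_grad=True)

# one-hot row for class 2, built the same way as the target above
toy_target = torch.eye(3)[[2]]

toy_output = toy_model(toy_input)
# the one-hot target weights the outputs, so only class 2 contributes
toy_grad, = torch.autograd.grad(toy_output, toy_input, grad_outputs=toy_target)

# for a linear model, this gradient is exactly the weight row of class 2
print(torch.allclose(toy_grad[0], toy_model.weight[2]))
```

Indexing `torch.eye` with a list, as in `torch.eye(1000)[[437]]`, keeps the batch dimension, so the target has the same shape as the model output for a single sample.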

Visualize the attribution:

In [None]:
# sum over the channels
relevance = attribution.sum(1)

# create an image of the visualized attribution
img = imgify(relevance, symmetric=True, cmap='coldnhot')

# show the image
display(img)