Docs: Tutorial: Image Classification VGG
- added support for tutorial ipynb using nbsphinx: my initial plan was to use MyST markdown notebooks; however, these do not work with Binder and Colab, since these services can only use notebooks hosted in GitHub repositories
- added buttons for Colab and Binder to the tutorial notebooks; I would have preferred to use Pyodide, but torch does not yet work with Pyodide
- added an image classification tutorial using VGG11
- two new documentation requirements: nbsphinx and ipykernel
- the tutorial IPython notebooks are executed when building the documentation
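The nbsphinx setup described above can be sketched as a Sphinx `conf.py` fragment. This is a minimal illustration based on nbsphinx defaults, not the repository's actual configuration; the extension list and the `nbsphinx_execute` value shown here are assumptions:

```python
# docs/source/conf.py (illustrative fragment, not the repository's actual file)

# enable nbsphinx so Sphinx can render and execute .ipynb tutorials;
# the other extensions listed are placeholders for an existing setup
extensions = [
    'sphinx.ext.autodoc',
    'nbsphinx',  # renders Jupyter notebooks as documentation pages
]

# 'auto' executes notebooks that have no stored outputs during the docs
# build; 'always' re-executes every notebook, 'never' skips execution
nbsphinx_execute = 'auto'

# note: ipykernel must be installed so nbsphinx can start the Python
# kernel that executes the notebooks
```

With `nbsphinx_execute = 'auto'`, committing the notebooks without outputs (as done here) causes them to be run at documentation build time, which is what the last bullet point above describes.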
Showing 4 changed files with 277 additions and 0 deletions.
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "810f3c46",
   "metadata": {},
   "source": [
    "# Image Classification with VGG\n",
    "\n",
    "This tutorial will introduce the attribution of image classifiers using VGG11\n",
    "on ImageNet. Feel free to replace VGG11 with any other version of VGG.\n",
    "\n",
    "First, we install **Zennit**. This includes its dependencies `Pillow`,\n",
    "`torch` and `torchvision`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0e2f1fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install zennit"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a2dc4cd",
   "metadata": {},
   "source": [
    "Then, we import the necessary modules, classes and functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9a3fa5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torchvision.transforms import Compose, Resize, CenterCrop\n",
    "from torchvision.transforms import ToTensor, Normalize\n",
    "from torchvision.models import vgg11_bn\n",
    "\n",
    "from zennit.attribution import Gradient, SmoothGrad\n",
    "from zennit.composites import EpsilonPlusFlat, EpsilonGammaBox\n",
    "from zennit.image import imgify, imsave\n",
    "from zennit.torchvision import VGGCanonizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e48e434e",
   "metadata": {},
   "source": [
    "We download an image of the [Dornbusch\n",
    "Lighthouse](https://en.wikipedia.org/wiki/Dornbusch_Lighthouse) from [Wikimedia\n",
    "Commons](https://commons.wikimedia.org/wiki/File:2006_09_06_180_Leuchtturm.jpg):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfaa9f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.hub.download_url_to_file(\n",
    "    'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_180_Leuchtturm.jpg',\n",
    "    'dornbusch-lighthouse.jpg',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0798a0c1",
   "metadata": {},
   "source": [
    "We load and prepare the data. The image is resized such that its shorter side\n",
    "is 256 pixels, then center-cropped to `(224, 224)`, converted to a\n",
    "`torch.Tensor`, and then normalized according to the channel-wise mean and\n",
    "standard deviation of the ImageNet dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4989b65c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the base image transform\n",
    "transform_img = Compose([\n",
    "    Resize(256),\n",
    "    CenterCrop(224),\n",
    "])\n",
    "# define the full tensor transform\n",
    "transform = Compose([\n",
    "    transform_img,\n",
    "    ToTensor(),\n",
    "    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
    "])\n",
    "\n",
    "# load the image\n",
    "image = Image.open('dornbusch-lighthouse.jpg')\n",
    "\n",
    "# transform the PIL image and insert a batch-dimension\n",
    "data = transform(image)[None]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "882b4dd8",
   "metadata": {},
   "source": [
    "We can look at the original image and the cropped image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "072a3ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# display the original image\n",
    "display(image)\n",
    "# display the resized and cropped image\n",
    "display(transform_img(image))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d45bd9b",
   "metadata": {},
   "source": [
    "Then, we initialize the model. Set `pretrained=True` to use the pre-trained\n",
    "weights instead of a randomly initialized model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16b50d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the model and set it to evaluation mode\n",
    "model = vgg11_bn(pretrained=False).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f677171",
   "metadata": {},
   "source": [
    "Compute the attribution using the `EpsilonPlusFlat` **Composite**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e17da931",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the VGG-specific canonizer (alias for SequentialMergeBatchNorm, only\n",
    "# needed with batch-norm)\n",
    "canonizer = VGGCanonizer()\n",
    "\n",
    "# create a composite, specifying the canonizers, if any\n",
    "composite = EpsilonPlusFlat(canonizers=[canonizer])\n",
    "\n",
    "# choose a target class for the attribution (label 437 is lighthouse)\n",
    "target = torch.eye(1000)[[437]]\n",
    "\n",
    "# create the attributor, specifying model and composite\n",
    "with Gradient(model=model, composite=composite) as attributor:\n",
    "    # compute the model output and attribution\n",
    "    output, attribution = attributor(data, target)\n",
    "\n",
    "print(f'Prediction: {output.argmax(1)[0].item()}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e49b5056",
   "metadata": {},
   "source": [
    "Visualize the attribution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eda5200",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sum over the channels\n",
    "relevance = attribution.sum(1)\n",
    "\n",
    "# create an image of the visualized attribution\n",
    "img = imgify(relevance, symmetric=True, cmap='coldnhot')\n",
    "\n",
    "# show the image\n",
    "display(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6bf7029",
   "metadata": {},
   "source": [
    "Here, `imgify` produces a PIL image, which can be saved with `.save()`.\n",
    "To directly save the visualized attribution, we can use `imsave` instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "838c2b4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# directly save the visualized attribution (here with the 'bwr' color map)\n",
    "imsave('attrib-1.png', relevance, symmetric=True, cmap='bwr')"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}