Docs: Tutorial: Image Classification VGG #133

Merged 1 commit on May 20, 2022
26 changes: 26 additions & 0 deletions docs/source/conf.py
@@ -46,6 +46,7 @@
'sphinx_copybutton',
'sphinxcontrib.datatemplates',
'sphinxcontrib.bibtex',
'nbsphinx',
]


@@ -54,6 +55,7 @@ def config_inited_handler(app, config):


def setup(app):
app.add_config_value('REVISION', 'master', 'env')
app.add_config_value('generated_path', '_generated', 'env')
app.connect('config-inited', config_inited_handler)

@@ -66,6 +68,30 @@ def setup(app):
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

# interactive badges for binder and colab
nbsphinx_prolog = r"""
{% set docname = 'docs/source/' + env.doc2path(env.docname, base=False) %}

.. raw:: html

<div class="admonition note">
This page was generated from
<a class="reference external" href="https://github.com/chr5tphr/zennit/blob/{{ env.config.REVISION }}/{{ docname|e }}">{{ docname|e }}</a>
<br />
Interactive online version:
<span style="white-space: nowrap;">
<a href="https://mybinder.org/v2/gh/chr5tphr/zennit/{{ env.config.REVISION|e }}?filepath={{ docname|e }}">
<img alt="launch binder" src="https://mybinder.org/badge_logo.svg" style="vertical-align:text-bottom">
</a>
</span>
<span style="white-space: nowrap;">
<a href="https://colab.research.google.com/github/chr5tphr/zennit/blob/{{ env.config.REVISION|e }}/{{ docname|e }}">
<img alt="Open in Colab" src="https://colab.research.google.com/assets/colab-badge.svg" style="vertical-align:text-bottom">
</a>
</span>
</div>
"""

# autosummary_generate = True

copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
248 changes: 248 additions & 0 deletions docs/source/tutorial/image-classification-vgg.ipynb
@@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "810f3c46",
"metadata": {},
"source": [
"# Image Classification with VGG\n",
"\n",
"This tutorial will introduce the attribution of image classifiers using VGG11\n",
"on ImageNet. Feel free to replace VGG11 with any other version of VGG.\n",
"\n",
"First, we install **Zennit**. This includes its dependencies `Pillow`,\n",
"`torch` and `torchvision`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0e2f1fe",
"metadata": {},
"outputs": [],
"source": [
"%pip install zennit"
]
},
{
"cell_type": "markdown",
"id": "3a2dc4cd",
"metadata": {},
"source": [
"Then, we import necessary modules, classes and functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9a3fa5e",
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"\n",
"import torch\n",
"from PIL import Image\n",
"from torchvision.transforms import Compose, Resize, CenterCrop\n",
"from torchvision.transforms import ToTensor, Normalize\n",
"from torchvision.models import vgg11_bn\n",
"\n",
"from zennit.attribution import Gradient, SmoothGrad\n",
"from zennit.composites import EpsilonPlusFlat, EpsilonGammaBox\n",
"from zennit.image import imgify, imsave\n",
"from zennit.torchvision import VGGCanonizer"
]
},
{
"cell_type": "markdown",
"id": "e48e434e",
"metadata": {},
"source": [
"We download an image of the [Dornbusch\n",
"Lighthouse](https://en.wikipedia.org/wiki/Dornbusch_Lighthouse) from [Wikimedia\n",
"Commons](https://commons.wikimedia.org/wiki/File:2006_09_06_180_Leuchtturm.jpg):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfaa9f3a",
"metadata": {},
"outputs": [],
"source": [
"torch.hub.download_url_to_file(\n",
" 'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',\n",
" 'dornbusch-lighthouse.jpg',\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0798a0c1",
"metadata": {},
"source": [
"We load and prepare the data. The image is resized such that the shorter side\n",
"is 256 pixels in size, then center-cropped to `(224, 224)`, converted to a\n",
"`torch.Tensor`, and then normalized according the channel-wise mean and\n",
"standard deviation of the ImageNet dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4989b65c",
"metadata": {},
"outputs": [],
"source": [
"# define the base image transform\n",
"transform_img = Compose([\n",
" Resize(256),\n",
" CenterCrop(224),\n",
"])\n",
"# define the full tensor transform\n",
"transform = Compose([\n",
" transform_img,\n",
" ToTensor(),\n",
" Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
"])\n",
"\n",
"# load the image\n",
"image = Image.open('dornbusch-lighthouse.jpg')\n",
"\n",
"# transform the PIL image and insert a batch-dimension\n",
"data = transform(image)[None]"
]
},
{
"cell_type": "markdown",
"id": "882b4dd8",
"metadata": {},
"source": [
"We can look at the original image and the cropped image:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "072a3ad0",
"metadata": {},
"outputs": [],
"source": [
"# display the original image\n",
"display(image)\n",
"# display the resized and cropped image\n",
"display(transform_img(image))"
]
},
{
"cell_type": "markdown",
"id": "3d45bd9b",
"metadata": {},
"source": [
"Then, we initialize the model and load the hyperparameters. Set\n",
"`pretrained=True` to use the pre-trained model instead of the random one:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16b50d14",
"metadata": {},
"outputs": [],
"source": [
"# load the model and set it to evaluation mode\n",
"model = vgg11_bn(pretrained=False).eval()"
]
},
{
"cell_type": "markdown",
"id": "7f677171",
"metadata": {},
"source": [
"Compute the attribution using the ``EpsilonPlusFlat`` **Composite**:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e17da931",
"metadata": {},
"outputs": [],
"source": [
"# use the VGG-specific canonizer (alias for SequentialMergeBatchNorm, only\n",
"# needed with batch-norm)\n",
"canonizer = VGGCanonizer()\n",
"\n",
"# create a composite, specifying the canonizers, if any\n",
"composite = EpsilonPlusFlat(canonizers=[canonizer])\n",
"\n",
"# choose a target class for the attribution (label 437 is lighthouse)\n",
"target = torch.eye(1000)[[437]]\n",
"\n",
"# create the attributor, specifying model and composite\n",
"with Gradient(model=model, composite=composite) as attributor:\n",
" # compute the model output and attribution\n",
" output, attribution = attributor(data, target)\n",
"\n",
"print(f'Prediction: {output.argmax(1)[0].item()}')"
]
},
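{
"cell_type": "markdown",
"id": "9a1b2c3d",
"metadata": {},
"source": [
"As an aside, the imported `EpsilonGammaBox` **Composite** can be used in\n",
"place of `EpsilonPlusFlat`. It additionally requires the lowest and highest\n",
"possible values of the input space; a minimal sketch, where `low=-3.0` and\n",
"`high=3.0` are assumed bounds rather than values computed from the\n",
"normalization above. The rest of the tutorial continues with the\n",
"`EpsilonPlusFlat` result:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d6e7f80",
"metadata": {},
"outputs": [],
"source": [
"# assumed bounds of the normalized input space; exact values would follow\n",
"# from the Normalize parameters used above\n",
"composite_box = EpsilonGammaBox(low=-3.0, high=3.0, canonizers=[VGGCanonizer()])\n",
"\n",
"# compute an alternative attribution with the EpsilonGammaBox composite\n",
"with Gradient(model=model, composite=composite_box) as attributor:\n",
"    output_box, attribution_box = attributor(data, target)"
]
},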
{
"cell_type": "markdown",
"id": "e49b5056",
"metadata": {},
"source": [
"Visualize the attribution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7eda5200",
"metadata": {},
"outputs": [],
"source": [
"# sum over the channels\n",
"relevance = attribution.sum(1)\n",
"\n",
"# create an image of the visualize attribution\n",
"img = imgify(relevance, symmetric=True, cmap='coldnhot')\n",
"\n",
"# show the image\n",
"display(img)"
]
},
{
"cell_type": "markdown",
"id": "b6bf7029",
"metadata": {},
"source": [
"Here, `imgify` produces a PIL-image, which can be saved with `.save()`.\n",
"To directly save the visualized attribution, we can use `imsave` instead:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "838c2b4b",
"metadata": {},
"outputs": [],
"source": [
"# directly save the visualized attribution\n",
"imsave('attrib-1.png', relevance, symmetric=True, cmap='bwr')"
]
}
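,
{
"cell_type": "markdown",
"id": "1f2e3d4c",
"metadata": {},
"source": [
"Finally, the imported `SmoothGrad` **Attributor** can be used in place of\n",
"`Gradient` to average gradients over multiple noisy copies of the input.\n",
"A minimal sketch; `noise_level` and `n_iter` are illustrative values, not\n",
"tuned settings:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a7b8c9d",
"metadata": {},
"outputs": [],
"source": [
"# without a composite, SmoothGrad computes the plain, smoothed gradient\n",
"with SmoothGrad(model=model, noise_level=0.1, n_iter=20) as attributor:\n",
"    output_sg, attribution_sg = attributor(data, target)\n",
"\n",
"# visualize the smoothed gradient as before\n",
"display(imgify(attribution_sg.sum(1), symmetric=True, cmap='coldnhot'))"
]
}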
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions docs/source/tutorial/index.rst
@@ -5,6 +5,7 @@
.. toctree::
:maxdepth: 1

image-classification-vgg
..
image-classification-with-vgg-and-resnet
image-segmentation-with-unet
2 changes: 2 additions & 0 deletions setup.py
@@ -66,6 +66,8 @@ def replace(mobj):
'sphinx-rtd-theme>=1.0.0',
'sphinxcontrib.datatemplates>=0.9.0',
'sphinxcontrib.bibtex>=2.4.1',
'nbsphinx>=0.8.8',
'ipykernel>=6.13.0',
],
'tests': [
'pytest',