Docs: Tutorial: Image Classification VGG
- added support for tutorial ipynb using nbsphinx: my initial plan was
  to add MyST markdown notebooks; however, these do not work with Binder
  and Colab, since those services can only open notebooks stored in
  GitHub repositories
- added buttons for Colab and Binder to the tutorial notebooks; I would
  have preferred to use Pyodide, but torch does not yet work with
  Pyodide
- added an image classification tutorial using VGG11
- two new documentation requirements: nbsphinx and ipykernel
- the tutorial IPython notebooks are executed when building the
  documentation
chr5tphr committed May 20, 2022
1 parent 2c15e9f commit 85e5e2e
Showing 4 changed files with 277 additions and 0 deletions.
26 changes: 26 additions & 0 deletions docs/source/conf.py
@@ -46,6 +46,7 @@
    'sphinx_copybutton',
    'sphinxcontrib.datatemplates',
    'sphinxcontrib.bibtex',
    'nbsphinx',
]


@@ -54,6 +55,7 @@ def config_inited_handler(app, config):


def setup(app):
    app.add_config_value('REVISION', 'master', 'env')
    app.add_config_value('generated_path', '_generated', 'env')
    app.connect('config-inited', config_inited_handler)

@@ -66,6 +68,30 @@ def setup(app):
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

# interactive badges for binder and colab
nbsphinx_prolog = r"""
{% set docname = 'docs/source/' + env.doc2path(env.docname, base=False) %}
.. raw:: html

    <div class="admonition note">
      This page was generated from
      <a class="reference external" href="https://github.com/chr5tphr/zennit/blob/{{ env.config.REVISION }}/{{ docname|e }}">{{ docname|e }}</a>
      <br />
      Interactive online version:
      <span style="white-space: nowrap;">
        <a href="https://mybinder.org/v2/gh/chr5tphr/zennit/{{ env.config.REVISION|e }}?filepath={{ docname|e }}">
          <img alt="launch binder" src="https://mybinder.org/badge_logo.svg" style="vertical-align:text-bottom">
        </a>
      </span>
      <span style="white-space: nowrap;">
        <a href="https://colab.research.google.com/github/chr5tphr/zennit/blob/{{ env.config.REVISION|e }}/{{ docname|e }}">
          <img alt="Open in Colab" src="https://colab.research.google.com/assets/colab-badge.svg" style="vertical-align:text-bottom">
        </a>
      </span>
    </div>
"""

# autosummary_generate = True

copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
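As the commit message notes, the tutorial notebooks are executed while the documentation is built; this is nbsphinx's default for notebooks without stored outputs. A minimal sketch of how that behavior could be pinned explicitly in conf.py (an assumed addition for illustration, not part of this diff):

    # 'auto' (the default) executes only notebooks without stored outputs;
    # 'always' forces re-execution on every build
    nbsphinx_execute = 'always'
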
248 changes: 248 additions & 0 deletions docs/source/tutorial/image-classification-vgg.ipynb
@@ -0,0 +1,248 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "810f3c46",
"metadata": {},
"source": [
"# Image Classification with VGG\n",
"\n",
"This tutorial will introduce the attribution of image classifiers using VGG11\n",
"on ImageNet. Feel free to replace VGG11 with any other version of VGG.\n",
"\n",
"First, we install **Zennit**. This includes its dependencies `Pillow`,\n",
"`torch` and `torchvision`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0e2f1fe",
"metadata": {},
"outputs": [],
"source": [
"%pip install zennit"
]
},
{
"cell_type": "markdown",
"id": "3a2dc4cd",
"metadata": {},
"source": [
"Then, we import necessary modules, classes and functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9a3fa5e",
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"\n",
"import torch\n",
"from PIL import Image\n",
"from torchvision.transforms import Compose, Resize, CenterCrop\n",
"from torchvision.transforms import ToTensor, Normalize\n",
"from torchvision.models import vgg11_bn\n",
"\n",
"from zennit.attribution import Gradient, SmoothGrad\n",
"from zennit.composites import EpsilonPlusFlat, EpsilonGammaBox\n",
"from zennit.image import imgify, imsave\n",
"from zennit.torchvision import VGGCanonizer"
]
},
{
"cell_type": "markdown",
"id": "e48e434e",
"metadata": {},
"source": [
"We download an image of the [Dornbusch\n",
"Lighthouse](https://en.wikipedia.org/wiki/Dornbusch_Lighthouse) from [Wikimedia\n",
"Commons](https://commons.wikimedia.org/wiki/File:2006_09_06_180_Leuchtturm.jpg):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfaa9f3a",
"metadata": {},
"outputs": [],
"source": [
"torch.hub.download_url_to_file(\n",
" 'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',\n",
" 'dornbusch-lighthouse.jpg',\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0798a0c1",
"metadata": {},
"source": [
"We load and prepare the data. The image is resized such that the shorter side\n",
"is 256 pixels in size, then center-cropped to `(224, 224)`, converted to a\n",
"`torch.Tensor`, and then normalized according the channel-wise mean and\n",
"standard deviation of the ImageNet dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4989b65c",
"metadata": {},
"outputs": [],
"source": [
"# define the base image transform\n",
"transform_img = Compose([\n",
" Resize(256),\n",
" CenterCrop(224),\n",
"])\n",
"# define the full tensor transform\n",
"transform = Compose([\n",
" transform_img,\n",
" ToTensor(),\n",
" Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
"])\n",
"\n",
"# load the image\n",
"image = Image.open('dornbusch-lighthouse.jpg')\n",
"\n",
"# transform the PIL image and insert a batch-dimension\n",
"data = transform(image)[None]"
]
},
{
"cell_type": "markdown",
"id": "882b4dd8",
"metadata": {},
"source": [
"We can look at the original image and the cropped image:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "072a3ad0",
"metadata": {},
"outputs": [],
"source": [
"# display the original image\n",
"display(image)\n",
"# display the resized and cropped image\n",
"display(transform_img(image))"
]
},
{
"cell_type": "markdown",
"id": "3d45bd9b",
"metadata": {},
"source": [
"Then, we initialize the model and load the hyperparameters. Set\n",
"`pretrained=True` to use the pre-trained model instead of the random one:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16b50d14",
"metadata": {},
"outputs": [],
"source": [
"# load the model and set it to evaluation mode\n",
"model = vgg11_bn(pretrained=False).eval()"
]
},
{
"cell_type": "markdown",
"id": "7f677171",
"metadata": {},
"source": [
"Compute the attribution using the ``EpsilonPlusFlat`` **Composite**:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e17da931",
"metadata": {},
"outputs": [],
"source": [
"# use the VGG-specific canonizer (alias for SequentialMergeBatchNorm, only\n",
"# needed with batch-norm)\n",
"canonizer = VGGCanonizer()\n",
"\n",
"# create a composite, specifying the canonizers, if any\n",
"composite = EpsilonPlusFlat(canonizers=[canonizer])\n",
"\n",
"# choose a target class for the attribution (label 437 is lighthouse)\n",
"target = torch.eye(1000)[[437]]\n",
"\n",
"# create the attributor, specifying model and composite\n",
"with Gradient(model=model, composite=composite) as attributor:\n",
" # compute the model output and attribution\n",
" output, attribution = attributor(data, target)\n",
"\n",
"print(f'Prediction: {output.argmax(1)[0].item()}')"
]
},
{
"cell_type": "markdown",
"id": "e49b5056",
"metadata": {},
"source": [
"Visualize the attribution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7eda5200",
"metadata": {},
"outputs": [],
"source": [
"# sum over the channels\n",
"relevance = attribution.sum(1)\n",
"\n",
"# create an image of the visualize attribution\n",
"img = imgify(relevance, symmetric=True, cmap='coldnhot')\n",
"\n",
"# show the image\n",
"display(img)"
]
},
{
"cell_type": "markdown",
"id": "b6bf7029",
"metadata": {},
"source": [
"Here, `imgify` produces a PIL-image, which can be saved with `.save()`.\n",
"To directly save the visualized attribution, we can use `imsave` instead:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "838c2b4b",
"metadata": {},
"outputs": [],
"source": [
"# directly save the visualized attribution\n",
"imsave('attrib-1.png', relevance, symmetric=True, cmap='bwr')"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
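
The notebook's imports also include SmoothGrad and EpsilonGammaBox, which the cells shown above do not use. A minimal sketch of how they could be applied to the same model, data and target; the low/high bounds and the SmoothGrad hyperparameters are illustrative assumptions, not values from this commit:

    # EpsilonGammaBox requires the bounds of the input domain; for
    # ImageNet-normalized inputs these are roughly -3 to 3 per channel
    # (an approximation, not the exact bounds implied by the normalization)
    composite = EpsilonGammaBox(low=-3., high=3., canonizers=[VGGCanonizer()])
    with Gradient(model=model, composite=composite) as attributor:
        output, attribution = attributor(data, target)

    # SmoothGrad averages the gradient over several noisy copies of the
    # input; it is commonly used without a composite
    with SmoothGrad(model=model, noise_level=0.1, n_iter=20) as attributor:
        output, attribution = attributor(data, target)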
1 change: 1 addition & 0 deletions docs/source/tutorial/index.rst
@@ -5,6 +5,7 @@
.. toctree::
   :maxdepth: 1

   image-classification-vgg

..
   image-classification-with-vgg-and-resnet
   image-segmentation-with-unet
2 changes: 2 additions & 0 deletions setup.py
@@ -66,6 +68,8 @@ def replace(mobj):
        'sphinx-rtd-theme>=1.0.0',
        'sphinxcontrib.datatemplates>=0.9.0',
        'sphinxcontrib.bibtex>=2.4.1',
        'nbsphinx>=0.8.8',
        'ipykernel>=6.13.0',
    ],
    'tests': [
        'pytest',
