diff --git a/tutorials/optimviz/GettingStarted_ModelPreparation_OptimViz.ipynb b/tutorials/optimviz/GettingStarted_ModelPreparation_OptimViz.ipynb
new file mode 100644
index 0000000000..ea83ff0146
--- /dev/null
+++ b/tutorials/optimviz/GettingStarted_ModelPreparation_OptimViz.ipynb
@@ -0,0 +1,469 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "GettingStarted_ModelPreparation_OptimViz.ipynb",
+ "provenance": [],
+ "collapsed_sections": [
+ "3MSB2RhA4h8E"
+ ]
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Preparing Models For Captum's Optim Module\n",
+ "\n",
+ "While most models will work out of the box with the Optim module, some models may require a few minor changes for full compatibility. This tutorial demonstrates how to make the suggested and required changes to models for use with the Optim module."
+ ],
+ "metadata": {
+ "id": "QVpft54KA-P_"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "import captum.optim as opt\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
+ ],
+ "metadata": {
+ "id": "KD5InqKt3Hjc"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Model Layer Changes\n",
+ "\n",
+ "The Optim module's layer-related functions and optimization systems rely on layers being defined as `nn.Module` classes rather than functional layers. Specifically, Optim's loss optimization and activation collection rely on PyTorch's hook system via [`register_forward_hook`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_forward_hook#torch.nn.Module.register_forward_hook), and functional layers do not support hooks.\n",
+ "Other functions like `replace_layers` can only detect `nn.Module` objects inside models.\n",
+ "\n",
+ "\n",
+ "For the purposes of this tutorial, our main toy model does not use any functional layers. If you wish to use your own model, you should verify that all applicable functional layers have been changed to their `nn.Module` equivalents.\n",
+ "\n",
+ "* A list of all of PyTorch's `torch.nn.functional` layers can be found [here](https://pytorch.org/docs/stable/nn.functional.html), and each entry links to its `nn.Module` equivalent.\n",
+ "\n",
+ "* The most common change you will likely encounter is converting functional [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html#torch.nn.functional.relu) calls to [`nn.ReLU`](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) layers (a minimal sketch of this conversion is included in the Tutorial Setup section below)."
+ ],
+ "metadata": {
+ "id": "3MSB2RhA4h8E"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Tutorial Setup\n",
+ "\n",
+ "Below we define a simple toy model and a functional version of the toy model for use in our examples.\n",
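+ "\n",
+ "As a quick reminder of the conversion described in the previous section, here is a minimal sketch of what swapping a functional `F.relu` call for an `nn.ReLU` module looks like. The `BeforeAfterExample` module below is purely hypothetical and is not used anywhere else in this tutorial:\n",
+ "\n",
+ "```python\n",
+ "class BeforeAfterExample(torch.nn.Module):\n",
+ "    def __init__(self) -> None:\n",
+ "        super().__init__()\n",
+ "        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3)\n",
+ "        # Module version of F.relu, so it can be hooked and replaced\n",
+ "        self.relu = torch.nn.ReLU()\n",
+ "\n",
+ "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+ "        x = self.conv(x)\n",
+ "        # Before: x = F.relu(x)  <- functional call, invisible to hooks and replace_layers\n",
+ "        x = self.relu(x)  # After: an nn.Module layer that can be targeted\n",
+ "        return x\n",
+ "```"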
+ ],
+ "metadata": {
+ "id": "QGIfQki3Dn2M"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class ToyModel(torch.nn.Module):\n",
+ "    def __init__(self) -> None:\n",
+ "        super().__init__()\n",
+ "        self.basic_module = torch.nn.Sequential(\n",
+ "            torch.nn.Conv2d(3, 4, kernel_size=3, stride=2),\n",
+ "            torch.nn.ReLU(),\n",
+ "            torch.nn.MaxPool2d(kernel_size=3, stride=2),\n",
+ "        )\n",
+ "        self.conv = torch.nn.Conv2d(4, 4, kernel_size=3, stride=2)\n",
+ "        self.bn = torch.nn.BatchNorm2d(4)\n",
+ "        self.relu = torch.nn.ReLU()\n",
+ "        self.pooling = torch.nn.AdaptiveAvgPool2d((2, 2))\n",
+ "        self.linear = torch.nn.Linear(16, 4)\n",
+ "\n",
+ "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+ "        x = self.basic_module(x)\n",
+ "        x = self.conv(x)\n",
+ "        x = self.bn(x)\n",
+ "        x = self.relu(x)\n",
+ "        x = self.pooling(x)\n",
+ "        x = x.flatten()\n",
+ "        x = self.linear(x)\n",
+ "        return x\n",
+ "\n",
+ "\n",
+ "class ToyModelFunctional(torch.nn.Module):\n",
+ "    \"\"\"Functional-layer-only version of our toy model\"\"\"\n",
+ "\n",
+ "    def __init__(self) -> None:\n",
+ "        super().__init__()\n",
+ "\n",
+ "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+ "        # Placeholder weights are used so that the example runs without training\n",
+ "        x = F.conv2d(x, weight=torch.ones([4, 3, 3, 3]), stride=2)\n",
+ "        x = F.relu(x)\n",
+ "        x = F.max_pool2d(x, kernel_size=3, stride=2)\n",
+ "\n",
+ "        x = F.conv2d(x, weight=torch.ones([4, 4, 3, 3]), stride=2)\n",
+ "        x = F.batch_norm(x, running_mean=torch.ones([4]), running_var=torch.ones([4]))\n",
+ "        x = F.relu(x)\n",
+ "        x = F.adaptive_avg_pool2d(x, (2, 2))\n",
+ "        x = x.flatten()\n",
+ "        x = F.linear(x, weight=torch.ones([4, 16]))\n",
+ "        return x"
+ ],
+ "metadata": {
+ "id": "X79d0fh_3LuT"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## The Basics: Targetable Layers\n",
+ "\n",
+ "The Optim module's `opt.models.collect_activations` function and loss objectives (`opt.loss`) rely on forward hooks using PyTorch's hook system. This means that functional layers cannot be used as optimization targets, and activations cannot be collected for them.\n",
+ "\n",
+ "Models can easily be checked for compatible layers via the `opt.models.get_model_layers` function, as we'll see below."
+ ],
+ "metadata": {
+ "id": "UjEdNgauOdbZ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Functional version of the toy model with no nn.Module layers\n",
+ "toy_model_functional = ToyModelFunctional().eval().to(device)\n",
+ "\n",
+ "# Get hookable layers\n",
+ "possible_targets = opt.models.get_model_layers(toy_model_functional)\n",
+ "\n",
+ "print(\"Possible targets:\", possible_targets)"
+ ],
+ "metadata": {
+ "id": "uEPS3SOqcl47",
+ "outputId": "fe01c649-97e2-4565-db99-96ced48ce15b",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Possible targets: []\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "As you can see, no layers capable of being hooked were found in our functional-layer model.\n",
+ "\n",
+ "Below we use the `opt.models.get_model_layers` function to see a list of all the hookable layers in our `nn.Module`-based model that we can use as targets."
+ ],
+ "metadata": {
+ "id": "46YGHAeRdBmE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Toy model with only nn.Module layers\n",
+ "target_model = ToyModel().eval().to(device)\n",
+ "\n",
+ "# Get hookable layers\n",
+ "possible_targets = opt.models.get_model_layers(target_model)\n",
+ "\n",
+ "# Display hookable layers\n",
+ "print(\"Possible targets:\")\n",
+ "for t in possible_targets:\n",
+ "    print(\" target_model.\" + t)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "TlZ5UwiVPptG",
+ "outputId": "169fb32f-3648-444c-b89b-db9f5cf9121a"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Possible targets:\n",
+ " target_model.basic_module\n",
+ " target_model.basic_module[0]\n",
+ " target_model.basic_module[1]\n",
+ " target_model.basic_module[2]\n",
+ " target_model.conv\n",
+ " target_model.bn\n",
+ " target_model.relu\n",
+ " target_model.pooling\n",
+ " target_model.linear\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We can then easily use any of the targets found above for optimization and activation collection, as we show below."
+ ],
+ "metadata": {
+ "id": "iHTSN71dWh5o"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "target_model = ToyModel().eval().to(device)\n",
+ "\n",
+ "# Set layer target\n",
+ "target_layer = target_model.conv\n",
+ "\n",
+ "# Collect activations from target\n",
+ "activations_dict = opt.models.collect_activations(\n",
+ "    model=target_model, targets=target_layer\n",
+ ")\n",
+ "\n",
+ "# Collect target from activations dict\n",
+ "activations = activations_dict[target_layer]\n",
+ "\n",
+ "# Display activation shape\n",
+ "print(\"Output shape of the {} layer activations:\".format(type(target_layer)))\n",
+ "print(\" {} \\n\".format(activations.shape))\n",
+ "\n",
+ "# We can also use the target for loss objectives\n",
+ "loss_fn = opt.loss.LayerActivation(target=target_layer)\n",
+ "\n",
+ "# Print loss objective\n",
+ "print(\"Loss objective:\", loss_fn)\n",
+ "print(\" target:\", loss_fn.target)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tiD7qBzlQ6Zw",
+ "outputId": "674df320-9fb4-46aa-8bf2-1acd534a7a61"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Output shape of the <class 'torch.nn.modules.conv.Conv2d'> layer activations:\n",
+ " torch.Size([1, 4, 27, 27]) \n",
+ "\n",
+ "Loss objective: LayerActivation []\n",
+ " target: Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2))\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Visualization: Redirected ReLU\n",
+ "\n",
+ "In some cases, the target of interest may not be activated at all by the initial random input. If this happens, the zero derivative stops the gradient from flowing backwards, and we never move towards any meaningful visualization. To solve this problem, we can replace the ReLU layers in a model with a special version of ReLU called `RedirectedReLU`. The `RedirectedReLU` layer temporarily allows the gradient to flow in these zero-gradient situations.\n",
+ "\n",
+ "Below we use the `opt.models.replace_layers` function to replace all instances of `nn.ReLU` in our toy model with `opt.models.RedirectedReluLayer`."
+ ],
+ "metadata": {
+ "id": "MlGvyhd0AalX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "relu_model = ToyModel().eval().to(device)\n",
+ "\n",
+ "# Replace the ReLU layers with RedirectedReluLayer\n",
+ "opt.models.replace_layers(\n",
+ "    relu_model, layer1=torch.nn.ReLU, layer2=opt.models.RedirectedReluLayer\n",
+ ")\n",
+ "\n",
+ "# Show modified model\n",
+ "print(relu_model)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "4w34RcZU_DrU",
+ "outputId": "596aef9f-26d8-4e87-fdaf-71211e29699b"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "ToyModel(\n",
+ "  (basic_module): Sequential(\n",
+ "    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(2, 2))\n",
+ "    (1): RedirectedReluLayer()\n",
+ "    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ "  )\n",
+ "  (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2))\n",
+ "  (bn): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ "  (relu): RedirectedReluLayer()\n",
+ "  (pooling): AdaptiveAvgPool2d(output_size=(2, 2))\n",
+ "  (linear): Linear(in_features=16, out_features=4, bias=True)\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Circuits: Linear Operation Layers\n",
+ "\n",
+ "Certain functions like `opt.circuits.extract_expanded_weights` require using modules that only perform linear operations. This can become slightly more complicated when dealing with layers that have multiple preset variables. Luckily, the `opt.models.replace_layers` function can easily handle these variable transfers for layer types like pooling layers if the `transfer_vars` argument is set to `True`.\n",
+ "\n",
+ "\n",
+ "Common linear layer replacements are as follows:\n",
+ "\n",
+ "* `nn.ReLU` layers need to be skipped, which can be done by replacing them with either `nn.Identity` or Captum's `SkipLayer` layer.\n",
+ "\n",
+ "* `nn.MaxPool2d` layers need to be converted to their linear `nn.AvgPool2d` layer equivalents.\n",
+ "\n",
+ "* `nn.AdaptiveMaxPool2d` layers need to be converted to their linear `nn.AdaptiveAvgPool2d` layer equivalents (see the sketch after the next list).\n",
+ "\n",
+ "Some of the layers that are already linear operations are:\n",
+ "\n",
+ "* `nn.BatchNorm2d` is linear when it's in evaluation mode (`.eval()`).\n",
+ "* `nn.Conv2d` is linear.\n",
+ "* `nn.Linear` is linear.\n",
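+ "\n",
+ "The code cell below demonstrates the common `nn.MaxPool2d` and `nn.ReLU` replacements on our toy model. As a minimal sketch of the two less common replacements mentioned above, here is what converting `nn.AdaptiveMaxPool2d` and skipping `nn.ReLU` with `SkipLayer` might look like. The tiny `head` module is purely hypothetical, and we assume that `SkipLayer` is exposed as `opt.models.SkipLayer` and that `transfer_vars` handles adaptive pooling layers in the same way it handles `MaxPool2d`:\n",
+ "\n",
+ "```python\n",
+ "head = torch.nn.Sequential(\n",
+ "    torch.nn.Conv2d(4, 4, kernel_size=3),\n",
+ "    torch.nn.ReLU(),\n",
+ "    torch.nn.AdaptiveMaxPool2d((2, 2)),\n",
+ ").eval()\n",
+ "\n",
+ "# AdaptiveMaxPool2d -> AdaptiveAvgPool2d, keeping the same output size\n",
+ "opt.models.replace_layers(\n",
+ "    head,\n",
+ "    layer1=torch.nn.AdaptiveMaxPool2d,\n",
+ "    layer2=torch.nn.AdaptiveAvgPool2d,\n",
+ "    transfer_vars=True,\n",
+ ")\n",
+ "\n",
+ "# ReLU -> SkipLayer (torch.nn.Identity works the same way here)\n",
+ "opt.models.replace_layers(head, layer1=torch.nn.ReLU, layer2=opt.models.SkipLayer)\n",
+ "```"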
+ ],
+ "metadata": {
+ "id": "KJVG3KDC31dy"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "linear_only_model = ToyModel().eval().to(device)\n",
+ "\n",
+ "# Replace MaxPool2d with AvgPool2d using the same settings\n",
+ "opt.models.replace_layers(\n",
+ "    linear_only_model,\n",
+ "    layer1=torch.nn.MaxPool2d,\n",
+ "    layer2=torch.nn.AvgPool2d,\n",
+ "    transfer_vars=True,  # Use same MaxPool2d parameters for AvgPool2d\n",
+ ")\n",
+ "\n",
+ "# Replace ReLU with Identity\n",
+ "opt.models.replace_layers(\n",
+ "    linear_only_model, layer1=torch.nn.ReLU, layer2=torch.nn.Identity\n",
+ ")\n",
+ "\n",
+ "# Show modified model\n",
+ "print(linear_only_model)"
+ ],
+ "metadata": {
+ "id": "hYbm5Cg34She",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "a35a33e2-04c3-4563-b139-ab28127b4f90"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "ToyModel(\n",
+ "  (basic_module): Sequential(\n",
+ "    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(2, 2))\n",
+ "    (1): Identity()\n",
+ "    (2): AvgPool2d(kernel_size=3, stride=2, padding=0)\n",
+ "  )\n",
+ "  (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2))\n",
+ "  (bn): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ "  (relu): Identity()\n",
+ "  (pooling): AdaptiveAvgPool2d(output_size=(2, 2))\n",
+ "  (linear): Linear(in_features=16, out_features=4, bias=True)\n",
+ ")\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Other: Relaxed Pooling\n",
+ "\n",
+ "Some attribution-based operations, like those used in activation atlas sample collection, require replacing the `nn.MaxPool2d` layers with a special relaxed version called `MaxPool2dRelaxed`. This is also easy to do with the `replace_layers` function, just as we did above."
+ ],
+ "metadata": {
+ "id": "MXXUIcEBk7_k"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "relaxed_pooling_model = ToyModel().eval().to(device).basic_module\n",
+ "\n",
+ "# Replace MaxPool2d with MaxPool2dRelaxed\n",
+ "opt.models.replace_layers(\n",
+ "    relaxed_pooling_model,\n",
+ "    torch.nn.MaxPool2d,\n",
+ "    opt.models.MaxPool2dRelaxed,\n",
+ "    transfer_vars=True,  # Use same MaxPool2d parameters for MaxPool2dRelaxed\n",
+ ")\n",
+ "\n",
+ "# Show modified model\n",
+ "print(relaxed_pooling_model)"
+ ],
+ "metadata": {
+ "id": "fWjY33RKkFi8",
+ "outputId": "f0e0a0d9-fd1f-4857-ea60-e8a2127607fd",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sequential(\n",
+ "  (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(2, 2))\n",
+ "  (1): ReLU()\n",
+ "  (2): MaxPool2dRelaxed(\n",
+ "    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ "    (avgpool): AvgPool2d(kernel_size=3, stride=2, padding=0)\n",
+ "  )\n",
+ ")\n"
+ ]
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file