Commit

Added SqueezeAndExcitation Block example
jernsting committed Jul 13, 2021
1 parent cf8e114 commit ae3c3d5
Showing 1 changed file with 165 additions and 0 deletions.
165 changes: 165 additions & 0 deletions examples/SqueezeAndExcitation.ipynb
@@ -0,0 +1,165 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Squeeze and Excitation example\n",
"This notebook will show you how to create a Squeeze and Excitation block (2D)\n",
"described in the original publication (https://arxiv.org/abs/1709.01507)\n",
"\n",
"First step: Include libraries:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"\n",
"import useful_layers as ul\n",
"\n",
"in_channels = 5 # Dummy value for in_channels"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Next step: build a simple SE block"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"class SqueezeAndExcitationBlock2D(nn.Module):\n",
" def __init__(self, in_channels):\n",
" super(SqueezeAndExcitationBlock2D, self).__init__()\n",
" self.se_map_layer = ul.layers.ChannelAttention2D(in_channels=in_channels)\n",
" def forward(self, x):\n",
" attention_map = self.se_map_layer(x) # calculate attention map\n",
" return x * attention_map # multiply the input with the attention map (scaling)"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Here we create a simple layer. It contains the useful_layers ChannelAttention2D layer\n",
"for the calculation of the SE-Map.\n",
"\n",
"The SE map could be used together with different operations, so the SE-Map is not applied\n",
"to the original input by default.\n",
"\n",
"So lets implement the forward-function"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"block1 = SqueezeAndExcitationBlock2D(in_channels)\n",
"block2 = ul.blocks.ScalingBlock(block1.se_map_layer)"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The output of both blocks is the same:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jan/anaconda3/envs/usefullayers/lib/python3.9/site-packages/torch/nn/functional.py:1805: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
" warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
]
},
{
"data": {
"text/plain": "True"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dummy_input = torch.randn(2, 5, 5, 5) # b, c, h, w\n",
"\n",
"block1_output = block1(dummy_input).detach().numpy()\n",
"block2_output = block2(dummy_input).detach().numpy()\n",
"np.array_equal(block1_output, block2_output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
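The notebook relies on useful_layers for the attention map itself. For readers without that library, here is a minimal from-scratch sketch of the same Squeeze-and-Excitation idea from the paper: global-average-pool each channel (squeeze), pass the pooled vector through a bottleneck MLP with a sigmoid (excitation), and rescale the input channels. The class name, `reduction` ratio, and layer choices below are illustrative assumptions, not the useful_layers implementation.

```python
import torch
import torch.nn as nn

class SEBlock2D(nn.Module):
    """Minimal Squeeze-and-Excitation block (sketch, not the useful_layers API)."""

    def __init__(self, in_channels, reduction=4):
        super().__init__()
        hidden = max(in_channels // reduction, 1)  # bottleneck width, assumed ratio
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels),
            nn.Sigmoid(),  # per-channel weights in (0, 1)
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        squeezed = x.mean(dim=(2, 3))                  # squeeze: (b, c) channel descriptors
        weights = self.fc(squeezed).view(b, c, 1, 1)   # excitation: per-channel scale
        return x * weights                             # rescale the input channel-wise

block = SEBlock2D(in_channels=5)
out = block(torch.randn(2, 5, 5, 5))  # same (b, c, h, w) layout as the notebook
print(out.shape)  # torch.Size([2, 5, 5, 5])
```

The output shape matches the input, so the block can be dropped between existing layers of a 2D CNN, just like the notebook's SqueezeAndExcitationBlock2D.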
