Added SqueezeAndExcitation Block example
Showing 1 changed file with 165 additions and 0 deletions.
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Squeeze and Excitation example\n",
    "This notebook shows how to create a Squeeze and Excitation block (2D)\n",
    "as described in the original publication (https://arxiv.org/abs/1709.01507).\n",
    "\n",
    "First step: import the libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "\n",
    "import useful_layers as ul\n",
    "\n",
    "in_channels = 5  # dummy value for in_channels"
   ]
  },
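  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Before building the block with useful_layers, here is a minimal plain-torch sketch of the\n",
    "SE map computation from the paper: squeeze via global average pooling, then excitation via\n",
    "two fully connected layers and a sigmoid. The name SEMapSketch and the reduction factor are\n",
    "illustrative only; the ChannelAttention2D layer used below encapsulates this kind of\n",
    "computation, though its exact implementation may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class SEMapSketch(nn.Module):\n",
    "    # reduction=2 instead of the paper's 16, because in_channels is only 5 here\n",
    "    def __init__(self, in_channels, reduction=2):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(in_channels, in_channels // reduction)\n",
    "        self.fc2 = nn.Linear(in_channels // reduction, in_channels)\n",
    "\n",
    "    def forward(self, x):\n",
    "        n, c, _, _ = x.shape\n",
    "        squeezed = x.mean(dim=(2, 3))  # squeeze: global average pool -> (n, c)\n",
    "        excited = torch.sigmoid(self.fc2(torch.relu(self.fc1(squeezed))))\n",
    "        return excited.view(n, c, 1, 1)  # per-channel scale factors"
   ]
  },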
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Next step: build a simple SE block:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class SqueezeAndExcitationBlock2D(nn.Module):\n",
    "    def __init__(self, in_channels):\n",
    "        super(SqueezeAndExcitationBlock2D, self).__init__()\n",
    "        self.se_map_layer = ul.layers.ChannelAttention2D(in_channels=in_channels)\n",
    "\n",
    "    def forward(self, x):\n",
    "        attention_map = self.se_map_layer(x)  # calculate the attention map\n",
    "        return x * attention_map  # multiply the input with the attention map (scaling)"
   ]
  },
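  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "As a quick sanity check: assuming ChannelAttention2D returns one scaling factor per channel,\n",
    "i.e. a map of shape (batch, channels, 1, 1), the multiplication in forward broadcasts it over\n",
    "the spatial dimensions and preserves the input shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "se = SqueezeAndExcitationBlock2D(in_channels)\n",
    "x = torch.randn(2, in_channels, 8, 8)\n",
    "print(se.se_map_layer(x).shape)  # expected (2, 5, 1, 1) under the assumption above\n",
    "print(se(x).shape)  # (2, 5, 8, 8): same shape as the input"
   ]
  },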
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The block above contains the useful_layers ChannelAttention2D layer,\n",
    "which calculates the SE map.\n",
    "\n",
    "The SE map can be combined with different operations, so it is not applied\n",
    "to the original input by default; the forward function applies it explicitly.\n",
    "\n",
    "Alternatively, useful_layers provides a ready-made ScalingBlock that wraps the\n",
    "attention layer and performs the same scaling:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "block1 = SqueezeAndExcitationBlock2D(in_channels)\n",
    "block2 = ul.blocks.ScalingBlock(block1.se_map_layer)  # reuses the same attention layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Since both blocks share the same attention layer, their outputs are identical:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jan/anaconda3/envs/usefullayers/lib/python3.9/site-packages/torch/nn/functional.py:1805: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
      "  warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
     ]
    },
    {
     "data": {
      "text/plain": "True"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dummy_input = torch.randn(2, 5, 5, 5)  # b, c, h, w\n",
    "\n",
    "block1_output = block1(dummy_input).detach().numpy()\n",
    "block2_output = block2(dummy_input).detach().numpy()\n",
    "np.array_equal(block1_output, block2_output)"
   ]
  },
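  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Finally, a minimal sketch of how such a block can be dropped into a network;\n",
    "the surrounding layers here are only an example, not a recommended architecture:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model = nn.Sequential(\n",
    "    nn.Conv2d(3, in_channels, kernel_size=3, padding=1),\n",
    "    nn.ReLU(),\n",
    "    SqueezeAndExcitationBlock2D(in_channels),  # recalibrate channel responses\n",
    "    nn.Conv2d(in_channels, 10, kernel_size=3, padding=1),\n",
    ")\n",
    "\n",
    "model(torch.randn(2, 3, 32, 32)).shape  # torch.Size([2, 10, 32, 32])"
   ]
  }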
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}