# Basics | Multi Headed Attention in Computer Vision

By [Akshaj Verma](https://akshajverma.com)

This notebook takes you through the different types of attention methods used in computer vision, implemented in PyTorch.

In [1]:
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, utils

## Self Attention for Images

Let's define a tensor that we obtain after passing an image through multiple conv layers.

Let the size of this tensor be `(4, 5, 5)`. This means that our image (latent representation after multiple conv operations) is of size `(5 x 5)` and has `4` channels.

![Self attention in SAGAN paper](../../assets/sagan_att.png)

[Reference](https://arxiv.org/pdf/1905.08008v1.pdf)

In [2]:
img = [float(i) for i in range(100)]
img = torch.tensor(img)

In [3]:
img = img.view([4, 5, 5])
img

tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.],
         [20., 21., 22., 23., 24.]],

        [[25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.],
         [40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.]],

        [[50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.],
         [60., 61., 62., 63., 64.],
         [65., 66., 67., 68., 69.],
         [70., 71., 72., 73., 74.]],

        [[75., 76., 77., 78., 79.],
         [80., 81., 82., 83., 84.],
         [85., 86., 87., 88., 89.],
         [90., 91., 92., 93., 94.],
         [95., 96., 97., 98., 99.]]])

We `unsqueeze(0)` to add a dimension of size 1 at index 0. This dimension corresponds to the batch size.
We do it because `nn.Conv2d()` expects a batched input of shape `B x C x H x W`.

In [4]:
input_img = img.unsqueeze(0)
input_img.shape

torch.Size([1, 4, 5, 5])

We will now pass this image representation through `3` conv operations with `1x1` filters. This will give us `3` different representations: `f`, `g`, and `h`. For `f` and `g`, we reduce the number of channels from `4` to `2` (let's say) using `1x1` conv filters; `h` keeps all `4` channels.

In [5]:
cnn_f = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=1, stride=1)
cnn_g = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=1, stride=1)
cnn_h = nn.Conv2d(in_channels=4, out_channels=4, kernel_size=1, stride=1)

In [6]:
f = cnn_f(input_img) # B x C/k x H x W
g = cnn_g(input_img) # B x C/k x H x W
h = cnn_h(input_img) # B x C x H x W

print("op_cnn_f: ", f.shape)
print("op_cnn_g: ", g.shape)
print("op_cnn_h: ", h.shape)

op_cnn_f:  torch.Size([1, 2, 5, 5])
op_cnn_g:  torch.Size([1, 2, 5, 5])
op_cnn_h:  torch.Size([1, 4, 5, 5])
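A `1x1` conv mixes channels only: it applies the same linear map to the channel vector at every pixel. The sketch below checks this on a random stand-in tensor; the names `x`, `conv`, and `manual` are illustrative and not part of the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 4, 5, 5)  # B x C x H x W (random stand-in input)
conv = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=1)

with torch.no_grad():
    # A 1x1 conv weight has shape (out_channels, in_channels, 1, 1).
    w = conv.weight.view(2, 4)         # C_out x C_in
    bias = conv.bias.view(1, 2, 1, 1)  # broadcast over batch and pixels

    # Apply the same C_out x C_in matrix to the channel vector of every pixel.
    manual = torch.einsum("oc,bchw->bohw", w, x) + bias
    same = torch.allclose(conv(x), manual, atol=1e-6)

print(same)
```

This is why the spatial size stays `5 x 5` while only the channel count changes.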


Now, we will flatten the spatial dimensions of each representation. The shape becomes `B x C x N`, where `N = H * W`.


Recall that the `1x1` convs already reduced the channels of `f` and `g` by a factor of `k` (here `k = 2`), so they become `B x C/k x N`. For `h`, the number of channels stays the same.

In [7]:
f = f.view(1, 2, -1) # B x C/k x N
g = g.view(1, 2, -1) # B x C/k x N
h = h.view(1, 4, -1) # B x C x N

print("changed shape of f: ", f.shape)
print("changed shape of g: ", g.shape)
print("changed shape of h: ", h.shape)

changed shape of f:  torch.Size([1, 2, 25])
changed shape of g:  torch.Size([1, 2, 25])
changed shape of h:  torch.Size([1, 4, 25])
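To see where each pixel ends up after flattening, here is a small sketch on a stand-in tensor. `view` keeps row-major order, so the pixel at `(row, col)` lands at index `row * W + col`. The names `x`, `flat`, `row`, and `col` are illustrative.

```python
import torch

H, W = 5, 5
x = torch.arange(4 * H * W, dtype=torch.float32).view(1, 4, H, W)  # B x C x H x W

flat = x.view(1, 4, -1)  # B x C x N, with N = H * W

# Row-major flattening: pixel (row, col) maps to index row * W + col.
row, col = 3, 2
n = row * W + col
match = torch.equal(flat[:, :, n], x[:, :, row, col])

# flatten(start_dim=2) gives the same result without hard-coding B and C.
same_as_flatten = torch.equal(flat, x.flatten(start_dim=2))

print(match, same_as_flatten)
```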


Now, we will perform matrix multiplication on `f` and `g`. The output of this matrix multiplication should have the shape `B x N x N`.

Current shapes:  
`f` = `B x C/k x N`  
`g` = `B x C/k x N`


For the product to have the shape `B x N x N`, we transpose `f` so that its shape becomes `B x N x C/k`.

In [8]:
ft = f.permute(0, 2, 1) # B x N x C/k
ft.shape

torch.Size([1, 25, 2])

We will perform matrix multiplication of `f` (transposed) and `g` to obtain `s`. 

`s = f.T @ g`

In [9]:
s = torch.bmm(ft, g) # B x N x N
s.shape

torch.Size([1, 25, 25])
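Each entry `s[i, j]` is the dot product between the channel vector of `f` at position `i` and the channel vector of `g` at position `j`. A minimal sketch with random stand-ins for `f` and `g` (the positions `i` and `j` are picked arbitrarily):

```python
import torch

torch.manual_seed(0)
f = torch.randn(1, 2, 25)  # B x C/k x N (random stand-in)
g = torch.randn(1, 2, 25)  # B x C/k x N (random stand-in)

s = torch.bmm(f.permute(0, 2, 1), g)  # B x N x N

# s[i, j] compares position i of f with position j of g.
i, j = 3, 7
dot = (f[0, :, i] * g[0, :, j]).sum()
entry_matches = torch.allclose(s[0, i, j], dot, atol=1e-6)

# The same product written as an einsum over the shared channel axis.
s_einsum = torch.einsum("bci,bcj->bij", f, g)
einsum_matches = torch.allclose(s, s_einsum, atol=1e-6)

print(entry_matches, einsum_matches)
```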

In [10]:
a = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]])
print(a.shape)
print(a)

torch.Size([2, 3, 3])
tensor([[[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]],

        [[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]])


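Now, we'll pass `s` through a softmax function to obtain the attention map denoted by `b`. Since `s[i, j]` pairs position `i` of `f` with position `j` of `g`, and the next step multiplies `h` by `b` from the right, we take the softmax over `dim=1` (the `i` index). This way, the weights used to build each output position `j` sum to 1, as in the SAGAN formulation.

In [11]:
b = F.softmax(s, dim=1) # B x N x N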
b.shape

torch.Size([1, 25, 25])
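The `dim` argument decides which axis of the attention map sums to one. The sketch below uses a toy `1 x 3 x 3` tensor to make the difference visible; the names `last` and `first` are illustrative.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[[1., 2., 3.],
                   [4., 5., 6.],
                   [7., 8., 9.]]])  # 1 x 3 x 3

last = F.softmax(a, dim=-1)  # normalize within each row
first = F.softmax(a, dim=1)  # normalize within each column

rows_sum = last.sum(dim=-1)  # one value per row
cols_sum = first.sum(dim=1)  # one value per column

print(rows_sum)
print(cols_sum)
```

Both sums come out as ones, but along different axes, so the choice of `dim` has to match the side from which the attention map is multiplied later.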

After this, we will perform matrix multiplication between `h` and the attention map `b`.

In [12]:
print(f"Shape of h: \n{h.shape}\n")
print(f"Shape of attention map b: \n{b.shape}")

Shape of h: 
torch.Size([1, 4, 25])

Shape of attention map b: 
torch.Size([1, 25, 25])


Next, we calculate `hb`, where `hb = h @ b`.

In [13]:
# h : B x C x N
# b : B x N x N

hb = torch.bmm(h, b) # B x C x N 
hb.shape

torch.Size([1, 4, 25])
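Each column of `hb` is a weighted sum of the columns of `h`, with the weights taken from the matching column of `b`. A small sketch with random stand-ins; here the stand-in `b` is normalized so that each column sums to 1, which is an assumption of this example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
h = torch.randn(1, 4, 25)                     # B x C x N (random stand-in)
b = F.softmax(torch.randn(1, 25, 25), dim=1)  # B x N x N, columns sum to 1

hb = torch.bmm(h, b)  # B x C x N

# Output position j mixes every position i of h using the weights b[i, j].
j = 7
manual = (h[0] * b[0, :, j]).sum(dim=1)  # C values
matches = torch.allclose(hb[0, :, j], manual, atol=1e-5)

print(matches)
```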

Finally, we will pass `hb` through a `1 x 1` conv layer to obtain `o`.

`o` is the output of this attention module. 


Before we do that, we need to first convert our `B x C x N` representation back to `B x C x H x W`. 

In [14]:
hb_reshaped = hb.view(input_img.shape) # B x C x H x W
hb_reshaped.shape

torch.Size([1, 4, 5, 5])

In [15]:
cnn_o = nn.Conv2d(in_channels=4, out_channels=4, kernel_size=1, stride=1)

In [16]:
o = cnn_o(hb_reshaped) # B x C x H x W
o.shape

torch.Size([1, 4, 5, 5])
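The steps above can be collected into one module. This is a minimal sketch, not a reference implementation: the class name `SelfAttention2d` and the argument `k` are hypothetical, and the softmax normalizes over the first `N` axis so that the weights for each output position sum to 1, which is an assumption of this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention2d(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, channels: int, k: int = 2):
        super().__init__()
        self.cnn_f = nn.Conv2d(channels, channels // k, kernel_size=1)
        self.cnn_g = nn.Conv2d(channels, channels // k, kernel_size=1)
        self.cnn_h = nn.Conv2d(channels, channels, kernel_size=1)
        self.cnn_o = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        f = self.cnn_f(x).view(B, -1, H * W)  # B x C/k x N
        g = self.cnn_g(x).view(B, -1, H * W)  # B x C/k x N
        h = self.cnn_h(x).view(B, C, H * W)   # B x C x N

        s = torch.bmm(f.permute(0, 2, 1), g)  # B x N x N
        b = F.softmax(s, dim=1)               # assumption: normalize over i
        hb = torch.bmm(h, b)                  # B x C x N

        return self.cnn_o(hb.view(B, C, H, W))  # B x C x H x W


torch.manual_seed(0)
attn = SelfAttention2d(channels=4, k=2)
out = attn(torch.randn(3, 4, 5, 5))  # batch of 3 random feature maps
print(out.shape)
```

Reading the shapes from `x` instead of hard-coding them lets the same module handle any batch size and spatial resolution.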