<a href="https://colab.research.google.com/github/hhaemin/computer_vision/blob/main/10_multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch, torchvision
import torchvision.models as models
import torchvision.datasets as datasets

import torch.nn as nn

import matplotlib.pyplot as plt
from PIL import Image

## Multi-head Attention


https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#MultiheadAttention


## 1. Do pixel-level attention

In [2]:
B, C, H, W = 8, 3, 80, 80

query = torch.randn(B, C, H, W).reshape(B, C, -1)
key = torch.randn(B, C, H, W).reshape(B, C, -1)
value = torch.randn(B, C, H, W).reshape(B, C, -1)

print(query.shape, key.shape, value.shape)

embed_dim= H*W  ## input embedding dim
num_heads= 8    ## hidden layer flexiblity

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

attn_output, attn_output_weights = multihead_attn(query, key, value)

print(attn_output.shape, attn_output_weights.shape)
attn_output = attn_output.reshape(B, C, H, W)
print(attn_output.shape)

torch.Size([8, 3, 6400]) torch.Size([8, 3, 6400]) torch.Size([8, 3, 6400])
torch.Size([8, 3, 6400]) torch.Size([8, 3, 3])
torch.Size([8, 3, 80, 80])


## 2. Sampling patch

In [3]:
unfold = nn.Unfold(kernel_size=(16, 16), stride = (16,16))

q_patches = unfold(query.reshape(B, C, H, W))
k_patches = unfold(key.reshape(B, C, H, W))
v_patches = unfold(value.reshape(B, C, H, W))

print(q_patches.shape, k_patches.shape, v_patches.shape)

embed_dim= H*W // 16 // 16 
num_heads = 5


multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)


attn_output, attn_output_weights = multihead_attn(q_patches, k_patches, v_patches)


print(attn_output.shape, attn_output_weights.shape)

attn_output = attn_output.reshape(B, 3, 16, 16, 25)

print(attn_output.shape)
# attn_output = attn_output.reshape(B, C, H, W)
# print(attn_output.shape)

torch.Size([8, 768, 25]) torch.Size([8, 768, 25]) torch.Size([8, 768, 25])
torch.Size([8, 768, 25]) torch.Size([8, 768, 768])
torch.Size([8, 3, 16, 16, 25])


## 3. After MLP

In [4]:
mlp_embed_dim = 16

q_linear = nn.Linear(embed_dim, mlp_embed_dim)
k_linear = nn.Linear(embed_dim, mlp_embed_dim)
v_linear = nn.Linear(embed_dim, mlp_embed_dim)

q_embed = q_linear(q_patches)
k_embed = k_linear(k_patches)
v_embed = v_linear(v_patches)

print(q_embed.shape, k_embed.shape, v_embed.shape)

num_heads = 4


multihead_attn = nn.MultiheadAttention(mlp_embed_dim, num_heads, batch_first=True)


attn_output, attn_output_weights = multihead_attn(q_embed, k_embed, v_embed)

print(attn_output.shape, attn_output_weights.shape)

torch.Size([8, 768, 16]) torch.Size([8, 768, 16]) torch.Size([8, 768, 16])
torch.Size([8, 768, 16]) torch.Size([8, 768, 768])
