## Process of PVT 


참고 : https://github.com/whai362/PVT/blob/v2/detection/pvt.py


class pvt_tiny(PyramidVisionTransformer)의 stage 1에서 Transformer block 1개를 통과하는 forward 과정(1 iteration)에 대해

전반적인 연산 흐름과 각 단계 output의 크기(shape)를 이해하는 것에 초점 (레이어는 학습되지 않은 랜덤 초기화 상태이므로 출력 값 자체에는 의미가 없음)

In [162]:
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

In [163]:
# pvt_tiny hyper-parameters; list-valued entries hold one value per stage.
img_size = 224
in_chans = 3

patch_size = 4  # stage-1 patch size: 224 // 4 = 56 tokens per side

embed_dims = [64, 128, 320, 512]
num_heads = [1, 2, 5, 8]
# No trailing commas here: `[...],` would build a one-element tuple, making
# mlp_ratios[0] the whole list instead of 8 (same for depths).
mlp_ratios = [8, 8, 4, 4]
qkv_bias = True
depths = [2, 2, 2, 2]
sr_ratios = [8, 4, 2, 1]
drop_rate = 0.0
drop_path_rate = 0.1

### Input

In [164]:
# Dummy input batch built from the config instead of repeating the literals.
x = torch.randn(1, in_chans, img_size, img_size)  # B, C, H, W = 1, 3, 224, 224

### Patch Embedding

In [165]:
# Disabled: timm's to_2tuple would turn the scalar sizes into (H, W) pairs,
# e.g. 224 -> (224, 224). The walkthrough keeps scalars because both the
# input image and the patches are square.
# img_size = to_2tuple(img_size)
# patch_size = to_2tuple(patch_size)

# print(img_size)
# print(patch_size)

In [166]:
# Stage-1 patch embedding: a conv whose kernel == stride == patch_size cuts
# the image into non-overlapping 4x4 patches and projects each to 64 channels.
embed_dim = 64  # embed_dims[0]
proj = nn.Conv2d(
    in_channels=in_chans,
    out_channels=embed_dim,
    kernel_size=patch_size,
    stride=patch_size,
)
norm = nn.LayerNorm(normalized_shape=embed_dim)

In [167]:
# Trace the patch-embedding shapes: conv -> flatten spatial dims -> tokens.
patch_map = proj(x)               # (B, C, H/4, W/4)
patch_seq = patch_map.flatten(2)  # (B, C, N)
print(patch_map.shape)
print(patch_seq.shape)
print(patch_seq.transpose(1, 2).shape)  # (B, N, C): 56 * 56 = 3136 patches of 4x4

torch.Size([1, 64, 56, 56])
torch.Size([1, 64, 3136])
torch.Size([1, 3136, 64])


In [168]:
# Patch tokens (B, N, C), then LayerNorm over the channel dimension.
tokens = proj(x).flatten(2).transpose(1, 2)
x = norm(tokens)

In [169]:
print(x.size())  # (B, N, C) after patch embedding

torch.Size([1, 3136, 64])


In [170]:
# Number of patch tokens = (patch-grid side) ** 2 = (224 // 4) ** 2 = 3136,
# derived from the config rather than hard-coded.
num_patches = (img_size // patch_size) ** 2

In [171]:
# Input resolution (square); the patch grid below is img_size // patch_size per side.
img_size

224

In [172]:
# Patch-grid resolution: 224 // 4 = 56 tokens along each side.
H = W = img_size // patch_size

print(H, W)

56 56


### Positional Encoding

In [173]:
# Learnable positional embedding: one vector per patch token, shaped like x
# (1, num_patches, embed_dim) so it can be added element-wise.
pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
# Truncated-normal init as in the reference PVT; with all zeros the add
# below would be a no-op.
nn.init.trunc_normal_(pos_embed, std=.02)
pos_drop = nn.Dropout(drop_rate)
# The embedding was previously built but never applied to the tokens.
x = pos_drop(x + pos_embed)

In [174]:
# Same (B, N, C) shape as x, so it can be added element-wise to the patch embeddings.
print(pos_embed.size())

torch.Size([1, 3136, 64])


### Spatial Reduction Attention

![image](https://user-images.githubusercontent.com/44194558/145922486-e2dded10-1c63-4230-b5f0-faea32b6268e.png)

In [175]:
# Stage-1 settings picked out of the per-stage config lists.
dim = 64  # embed_dims[0]
num_head, mlp_ratio, sr_ratio = num_heads[0], mlp_ratios[0], sr_ratios[0]  # sr_ratio == 8
qkv_bias = True
drop = drop_rate  # dropout probability for this stage
qk_scale = None   # NOTE(review): the reference Attention uses qk_scale or head_dim ** -0.5

In [176]:
# Per-head channel width. Heads must split dim evenly (the reference
# Attention checks this too); fail loudly instead of silently truncating.
if dim % num_head:
    raise ValueError(f"dim {dim} should be divided by num_heads {num_head}.")
head_dim = dim // num_head
head_dim

64

In [177]:
# Projections and dropouts of one spatial-reduction attention layer.
q = nn.Linear(in_features=dim, out_features=dim, bias=qkv_bias)       # query
kv = nn.Linear(in_features=dim, out_features=2 * dim, bias=qkv_bias)  # key + value, fused
attn_drop = nn.Dropout(p=0)
proj = nn.Linear(in_features=dim, out_features=dim)                   # output projection
proj_drop = nn.Dropout(p=0)

In [178]:
# Spatial reduction for key/value: a strided conv (kernel = stride = sr_ratio)
# shrinks the token grid by sr_ratio along each side.
sr = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=sr_ratio, stride=sr_ratio)

In [179]:
B, N, C = x.size()
print(B, N, C)  # batch, sequence length (number of tokens), channels

1 3136 64


In [180]:
# Query projection split into heads:
# (B, N, C) -> (B, N, heads, C // heads) -> (B, heads, N, head_dim).
# The last axis must be the per-head width C // num_head, not C; using C
# only happens to work in stage 1 because num_head == 1.
q_out = q(x)
print(q_out.shape)
print(q_out.reshape(B, N, num_head, C // num_head).shape)
print(q_out.reshape(B, N, num_head, C // num_head).permute(0, 2, 1, 3).shape)  # B, heads, N, head_dim

torch.Size([1, 3136, 64])
torch.Size([1, 3136, 1, 64])
torch.Size([1, 1, 3136, 64])


In [181]:
# Split the query into heads: (B, N, C) -> (B, heads, N, C // heads).
q = q(x).reshape(B, N, num_head, C // num_head).permute(0, 2, 1, 3)

In [182]:
# (B, heads, N, head_dim): same tokens and width as x, plus a head axis.
q.size()

torch.Size([1, 1, 3136, 64])

In [183]:
# Tokens back to a feature map for the strided conv: (B, N, C) -> (B, C, H, W).
x_ = x.transpose(1, 2).reshape(B, C, H, W)
print(x.shape)
print(x_.shape)

torch.Size([1, 3136, 64])
torch.Size([1, 64, 56, 56])


In [184]:
# Spatial reduction: (B, 64, 56, 56) -> (B, 64, 7, 7), i.e. 56 * 56 / 8 ** 2 = 49 tokens.
reduced = sr(x_)
print(reduced.shape)
print(reduced.flatten(2).shape)
print(reduced.flatten(2).transpose(1, 2).shape)

torch.Size([1, 64, 7, 7])
torch.Size([1, 64, 49])
torch.Size([1, 49, 64])


In [185]:
# Flatten the reduced 7x7 map back into tokens: (B, C, 7, 7) -> (B, 49, C).
x_ = sr(x_).reshape(B, C, -1).permute(0, 2, 1)
# SRA normalizes the reduced tokens with a LayerNorm before the key/value
# projection (Attention.norm in the reference); that step was missing.
sr_norm = nn.LayerNorm(C)
x_ = sr_norm(x_)
print(x_.shape)

torch.Size([1, 49, 64])


In [186]:
kv_out = kv(x_)  # (B, N', 2 * C)
kv_split = kv_out.reshape(B, -1, 2, num_head, C // num_head)  # (B, N', 2 [k/v], heads, head_dim)
print(kv_out.shape)
print(kv_split.shape)
print(kv_split.permute(2, 0, 3, 1, 4).shape)  # (2, B, heads, N', head_dim)

torch.Size([1, 49, 128])
torch.Size([1, 49, 2, 1, 64])
torch.Size([2, 1, 1, 49, 64])


In [187]:
# Separate key and value along the leading k/v axis; each is (B, heads, N', head_dim).
kv = kv(x_).reshape(B, -1, 2, num_head, C // num_head).permute(2, 0, 3, 1, 4)
k, v = kv.unbind(0)

In [188]:
# q keeps all N = 3136 tokens; k and v keep only the N' = 49 reduced tokens.
# Each is laid out as (B, heads, tokens, head_dim).
for t in (q, k, v):
    print(t.shape)

torch.Size([1, 1, 3136, 64])
torch.Size([1, 1, 49, 64])
torch.Size([1, 1, 49, 64])


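query는 56x56 = 3136개의 토큰을 그대로 사용하지만, key/value는 spatial reduction(sr_ratio = 8)을 거쳐 7x7 = 49개로 줄어든다. 따라서 attention map의 크기는 3136 x 49로, 일반 self-attention의 3136 x 3136보다 8^2 = 64배 작다.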

In [189]:
k_t = k.transpose(-2, -1)  # (B, heads, head_dim, N')
print(k_t.shape)
print((q @ k_t).shape)     # (B, heads, N, N') = 3136 x 49 attention logits

torch.Size([1, 1, 64, 49])
torch.Size([1, 1, 3136, 49])


In [190]:
# Scaled dot-product logits: (B, heads, N, head_dim) @ (B, heads, head_dim, N')
# -> (B, heads, N, N'). The reference scales by qk_scale or head_dim ** -0.5;
# without it the softmax below is too peaked (the shape is unaffected).
scale = qk_scale or head_dim ** -0.5
attn = (q @ k.transpose(-2, -1)) * scale
print(attn.shape)

torch.Size([1, 1, 3136, 49])


In [191]:
# Turn each query's scores over the 49 keys into a probability distribution.
attn = torch.softmax(attn, dim=-1)
print(attn.shape)

torch.Size([1, 1, 3136, 49])


In [192]:
# Dropout on the attention map (built with p=0, so an identity here).
attn = attn_drop(attn)
print(attn.size())

torch.Size([1, 1, 3136, 49])


In [193]:
attn_out = attn @ v                 # (B, heads, N, head_dim)
attn_out_t = attn_out.transpose(1, 2)  # (B, N, heads, head_dim)
print(attn_out.shape)
print(attn_out_t.shape)
print(attn_out_t.reshape(B, N, C).shape)  # heads merged back: (B, N, C)

torch.Size([1, 1, 3136, 64])
torch.Size([1, 3136, 1, 64])
torch.Size([1, 3136, 64])


In [194]:
# Weighted sum of values, then merge heads: (B, heads, N, head_dim) -> (B, N, C).
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
print(x.shape)

torch.Size([1, 3136, 64])


In [195]:
# Output projection followed by its dropout (proj_drop, built with p=0),
# completing the attention sub-layer; proj_drop was created but never applied.
# NOTE(review): in the reference PVT Block this result is added back to the
# block input (x = x + attn(norm1(x))). The input x was overwritten when the
# attention output was computed, so that residual is not shown here.
x = proj(x)
x = proj_drop(x)
print(x.shape)

torch.Size([1, 3136, 64])


### MLP

In [196]:
# Feed-forward (Mlp) layers for stage 1.
in_features = 64                   # embed_dims[0]
# Hidden width is dim * mlp_ratio: 64 * mlp_ratios[0] = 64 * 8 = 512 for
# pvt_tiny stage 1. The previous 64 dropped the MLP expansion entirely.
hidden_features = in_features * 8
out_features = in_features
act_layer = nn.GELU
fc1 = nn.Linear(in_features, hidden_features)
act = act_layer()
fc2 = nn.Linear(hidden_features, out_features)
# Build from the float drop_rate instead of rebinding `drop` from itself, so
# re-running this cell does not try to wrap a Dropout module in a Dropout.
drop = nn.Dropout(drop_rate)

In [197]:
# Mlp forward: Linear -> GELU -> Dropout -> Linear -> Dropout; (B, N, C) -> (B, N, C).
# NOTE(review): in the reference Block this runs on norm2(x) and its output is
# added back residually (x = x + mlp(norm2(x))); that glue is not shown here.
for layer in (fc1, act, drop, fc2, drop):
    x = layer(x)

In [198]:
# (B, N, C): the block keeps the token sequence shape.
x.size()

torch.Size([1, 3136, 64])

### Output of stage1

![image](https://user-images.githubusercontent.com/44194558/145932869-41e889e8-601d-4598-a67c-8ff9c90d24cb.png)

<br/>

* 224 x 224 x 3 이미지를 입력으로 받아 56 x 56 x 64의 feature map 출력


* 이 feature map(F1)이 stage 2의 입력으로 제공됨

In [200]:
# Stage-1 output F1 as a feature map: (B, N, C) -> (B, C, H/4, W/4) = (1, 64, 56, 56).
x.transpose(1, 2).reshape(B, -1, H, W).shape

torch.Size([1, 64, 56, 56])