In [1]:
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import to_2tuple

  from .autonotebook import tqdm as notebook_tqdm


# 01. Convolutional Token Embedding

In [2]:
class ConvEmbed(nn.Module):
    '''
    Convolutional token embedding.

    A strided Conv2d maps an image (or the previous stage's feature map) to a
    coarser map with ``embed_dim`` channels. An optional normalization is then
    applied to each spatial position's channel vector, and the result is
    returned in (B, C, H, W) layout.

    Args:
        patch_size: kernel size of the projection conv (also stored as a
            2-tuple in ``self.patch_size``).
        in_chans: number of input channels.
        embed_dim: number of output channels.
        stride: stride of the projection conv.
        padding: padding of the projection conv.
        norm_layer: factory called as ``norm_layer(embed_dim)``; when falsy,
            no normalization is applied (``nn.Identity``).

    The comments on the defaults list per-stage values [stage1, stage2, stage3].
    '''

    def __init__(self,
                 patch_size=7,  # [7, 3, 3]
                 in_chans=3,    # [3, dim of stage1, dim of stage2]
                 embed_dim=64,  # [64, 192, 384]
                 stride=4,      # [4, 2, 2]
                 padding=2,     # [2, 1, 1]
                 norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)

        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size,
                              stride=stride,
                              padding=padding)

        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        feat = self.proj(x)
        _, _, height, width = feat.shape

        # The norm acts on the last axis, so move channels last while it runs,
        # then restore the (B, C, H, W) layout.
        tokens = self.norm(rearrange(feat, 'b c h w -> b (h w) c'))
        return rearrange(tokens, 'b (h w) c -> b c h w', h=height, w=width)
    

In [3]:
class AttentionConv(nn.Module):
    '''
    Multi-head self-attention whose Q/K/V come from convolutional projections.

    Input tokens of shape (B, h*w, dim) are folded back into a (B, dim, h, w)
    map. Each of Q, K and V is produced by its own
    depthwise Conv2d -> BatchNorm2d -> 1x1 Conv2d stack (flattened back to
    tokens), followed by a Linear layer. K/V use ``stride_kv`` (default 2), so
    with the defaults they carry fewer tokens than Q. The output has Q's token
    count and ``dim`` channels: (B, T_q, dim).

    Args:
        dim: embedding dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the Q/K/V Linear layers use a bias.
        kernel_size: kernel size of the depthwise projection convs.
        padding_q, padding_kv: padding of the Q / K,V depthwise convs.
        stride_q, stride_kv: stride of the Q / K,V depthwise convs.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    '''

    def __init__(self,
                 dim=64,        # [64,192,384]
                 num_heads=4,   # paper: [1,3,6], this notebook: [4,8,16]
                 qkv_bias=False,
                 kernel_size=3,
                 padding_q=1,
                 padding_kv=1,
                 stride_q=1,
                 stride_kv=2,
                 ):
        super().__init__()
        # Fail at construction instead of deep inside an einops pattern later.
        if dim % num_heads != 0:
            raise ValueError(
                f'dim ({dim}) must be divisible by num_heads ({num_heads})')

        self.stride_q = stride_q
        self.stride_kv = stride_kv
        self.dim = dim
        self.num_heads = num_heads
        # NOTE: scaled by the full embedding dim, not by the per-head dim
        # (dim // num_heads) of textbook scaled dot-product attention.
        # Kept as-is to preserve the existing numerics.
        self.scale = dim ** -0.5

        self.conv_proj_q = self._build_projection(dim, kernel_size, padding_q, stride_q)
        self.conv_proj_k = self._build_projection(dim, kernel_size, padding_kv, stride_kv)
        self.conv_proj_v = self._build_projection(dim, kernel_size, padding_kv, stride_kv)

        self.linear_proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.linear_proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.linear_proj_v = nn.Linear(dim, dim, bias=qkv_bias)

    def _build_projection(self,
                          dim,
                          kernel_size,
                          padding,
                          stride,
                          ):
        '''
        Build one conv projection: depthwise conv -> BatchNorm -> 1x1 conv,
        then flatten the (B, dim, H', W') map into tokens (B, H'*W', dim).
        '''
        proj = nn.Sequential(OrderedDict([
            ('depthwise', nn.Conv2d(
                dim,
                dim,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                bias=False,     # followed by BatchNorm, which has its own shift
                groups=dim)),   # groups=dim -> one filter per channel (depthwise)
            ('bn', nn.BatchNorm2d(dim)),
            ('pointwise', nn.Conv2d(
                dim,
                dim,
                kernel_size=1)),
            ('rearrange', Rearrange('b c h w -> b (h w) c'))
        ]))

        return proj

    def forward(self, x, h, w):
        '''
        Args:
            x: tokens of shape (B, h*w, dim).
            h, w: spatial size used to fold the tokens back into a map.

        Returns:
            Tensor of shape (B, T_q, dim), where T_q is the number of Q tokens.
        '''
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        q = self.conv_proj_q(x)
        k = self.conv_proj_k(x)
        v = self.conv_proj_v(x)

        # Split channels into heads: (B, T, heads*d) -> (B, heads, T, d).
        q = rearrange(self.linear_proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.linear_proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.linear_proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', q, k) * self.scale
        attn = F.softmax(attn_score, dim=-1)

        x = torch.matmul(attn, v)  # (B, heads, T_q, d)
        # Merge heads back into channels. The head axis must be moved next to d
        # first; a plain .view(B, T_q, heads*d) on this (B, heads, T_q, d)
        # tensor would interleave tokens and heads and scramble the output.
        x = rearrange(x, 'b h t d -> b t (h d)')

        return x

In [4]:
class Block(nn.Module):
    '''
    Residual attention block computing ``norm2(x + attn(norm1(x), h, w))``.

    Args:
        dim: token embedding dimension.
        num_heads: number of attention heads, forwarded to AttentionConv.
        qkv_bias: whether AttentionConv's Q/K/V Linear layers use a bias.
        norm_layer: normalization factory called as ``norm_layer(dim)``.

    NOTE(review): there is no MLP sub-layer here, unlike a standard
    Transformer block; ``norm2`` is applied to the residual output instead.
    '''

    def __init__(self, dim, num_heads, qkv_bias=False, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = AttentionConv(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.norm2 = norm_layer(dim)

    def forward(self, x, h, w):
        # x: tokens (B, h*w, dim); h and w let the attention fold them into a map.
        shortcut = x
        attended = self.attn(self.norm1(x), h, w)
        return self.norm2(shortcut + attended)

In [5]:
# Dummy all-zero input batch (B, C, H, W), used only to trace shapes through
# the stages. Seeding makes the randomly initialised modules reproducible.
torch.manual_seed(0)
test_img = torch.zeros(2, 3, 224, 224)

block = Block(dim=64,
              num_heads=4)

In [6]:
# Stage 1

## Patch Embedding
convembed = ConvEmbed(patch_size=7, stride=4, padding=2)

# Shape check only: no_grad stops an autograd graph from being built and
# carried from cell to cell through the stageN_img globals.
with torch.no_grad():
    stage1_img = convembed(test_img)

    ## Attention with Convolution
    # The block takes tokens (B, H*W, C) plus H and W to fold them back.
    b, c, h, w = stage1_img.shape
    stage1_img = rearrange(stage1_img, 'b c h w -> b (h w) c')
    stage1_img = block(stage1_img, h=h, w=w)
    stage1_img = rearrange(stage1_img, 'b (h w) c -> b c h w', h=h, w=w)

## Check Result
# The attention block must preserve the embedding's (B, C, H, W) shape.
assert stage1_img.shape == (b, c, h, w)
print(f'stage 1 | img shape: {test_img.shape} → Conv Embed Shape: {stage1_img.shape}')

stage 1 | img shape: torch.Size([2, 3, 224, 224]) → Conv Embed Shape: torch.Size([2, 64, 56, 56])


In [7]:
# Stage 2

## Patch Embedding
# in_chans follows the previous stage's output channels instead of a hard-coded 64.
convembed = ConvEmbed(patch_size=3, in_chans=stage1_img.shape[1], stride=2, padding=1)

# Shape check only: no_grad avoids building/keeping an autograd graph.
with torch.no_grad():
    stage2_img = convembed(stage1_img)

    ## Attention with Convolution
    b, c, h, w = stage2_img.shape
    # A dedicated Block sized to this stage's channels. Reusing stage 1's
    # `block` would share its weights across stages, and it would break as
    # soon as this stage's embed_dim differs from 64.
    stage2_block = Block(dim=c, num_heads=4)
    stage2_img = rearrange(stage2_img, 'b c h w -> b (h w) c')
    stage2_img = stage2_block(stage2_img, h=h, w=w)
    stage2_img = rearrange(stage2_img, 'b (h w) c -> b c h w', h=h, w=w)

## Check Result
assert stage2_img.shape == (b, c, h, w)
print(f'stage 2 | img shape: {stage1_img.shape} → Conv Embed Shape: {stage2_img.shape}')

stage 2 | img shape: torch.Size([2, 64, 56, 56]) → Conv Embed Shape: torch.Size([2, 64, 28, 28])


In [8]:
# Stage 3

## Patch Embedding
# in_chans follows the previous stage's output channels instead of a hard-coded 64.
convembed = ConvEmbed(patch_size=3, in_chans=stage2_img.shape[1], stride=2, padding=1)

# Shape check only: no_grad avoids building/keeping an autograd graph.
with torch.no_grad():
    stage3_img = convembed(stage2_img)

    ## Attention with Convolution
    b, c, h, w = stage3_img.shape
    # A dedicated Block sized to this stage's channels. Reusing stage 1's
    # `block` would share its weights across stages, and it would break as
    # soon as this stage's embed_dim differs from 64.
    stage3_block = Block(dim=c, num_heads=4)
    stage3_img = rearrange(stage3_img, 'b c h w -> b (h w) c')
    stage3_img = stage3_block(stage3_img, h=h, w=w)
    stage3_img = rearrange(stage3_img, 'b (h w) c -> b c h w', h=h, w=w)

## Check Result
assert stage3_img.shape == (b, c, h, w)
print(f'stage 3 | img shape: {stage2_img.shape} → Conv Embed Shape: {stage3_img.shape}')

stage 3 | img shape: torch.Size([2, 64, 28, 28]) → Conv Embed Shape: torch.Size([2, 64, 14, 14])
