In [2]:
import torch
import torch.nn as nn

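References:
- Towards Data Science article on implementing the Vision Transformer (ViT) in PyTorch: https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632
- ViG MindSpore community implementation, `src/models/vig/vig.py` (line 101): https://github.com/LKLQQ/ViG/blob/dea9ad27c2e5514ec85c3c2a082267d5d427e1e3/src/models/vig/vig.py#L101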

In [3]:
# Fix the RNG seed so the dummy input (and every module initialised after it)
# is reproducible under Restart & Run All.
torch.manual_seed(0)
# Dummy batch in NCHW layout: one 3-channel 224x224 image.
image = torch.randn(1, 3, 224, 224)

---

In [20]:
H = W = 224  # image height / width in pixels
C = 3        # input channels
P = 16       # side length of each square patch
# Patches must tile the image exactly; with floor division a non-divisible
# size would silently drop the leftover border instead of failing loudly.
assert H % P == 0 and W % P == 0, "image size must be divisible by the patch size"
N = (H // P) * (W // P)  # number of patches: 14 * 14 = 196

In [24]:
# Cut the image into non-overlapping P x P patches and flatten each patch to
# P*P*C values -> (batch, N, P*P*C), patches in row-major order.
# A plain `image.reshape(-1, N, P**2*C)` only has the right *shape*: it slices the
# contiguous (C, H, W) buffer, so each "patch" would be a run of consecutive values
# (a few whole image rows) rather than a spatial P x P block across all channels.
img = (
    image.unfold(2, P, P)          # (batch, C, H/P, W, P)
    .unfold(3, P, P)               # (batch, C, H/P, W/P, P, P)
    .permute(0, 2, 3, 4, 5, 1)     # (batch, H/P, W/P, P, P, C)
    .reshape(image.shape[0], N, P**2 * C)
)

In [25]:
img.size()  # expected (batch, N, P*P*C) = (1, 196, 768)

torch.Size([1, 196, 768])

In [26]:
# From the paper: token embedding dimension D.
emb_dim = 768

# Shared linear projection of every flattened P*P*C patch to an emb_dim token.
# Uses emb_dim rather than repeating the literal 768 so the two cannot drift apart.
linear_proy = nn.Linear(P**2 * C, emb_dim)

linear_proy(img).shape  # (batch, N, emb_dim)

torch.Size([1, 196, 768])

In [44]:
# Code implementation (performance gain): a convolution with kernel_size == stride == P
# applies one shared linear map to each non-overlapping P x P patch, which is the
# same patch projection as above without the explicit reshape step.
conv_proy = nn.Conv2d(in_channels=C, out_channels=emb_dim, kernel_size=P, stride=P)

# Output layout is (batch, emb_dim, H/P, W/P).
# TODO: flatten to (batch, N, emb_dim) tokens and add positional embeddings.
conv_proy(image).shape

torch.Size([1, 768, 14, 14])

## ViG community code (MindSpore)

In [11]:
class PatchEmbed(nn.Module):
    """ Image to Visual Embeddings

    Five-stage convolutional stem with output stride 16. Four stride-2 3x3 convs
    (channels dim//8 -> dim//4 -> dim//2 -> dim) are each followed by BatchNorm
    and GELU. A final stride-1 3x3 conv is followed by BatchNorm only.
    The stem maps (B, in_dim, H, W) to (B, dim, ceil(H/16), ceil(W/16)),
    e.g. 224x224 -> 14x14.

    Args:
        dim: Number of output embedding channels.
        in_dim: Number of input image channels (default 3).
    """

    def __init__(self, dim=768, in_dim=3):
        super().__init__()
        # Attribute names conv1..conv5 (and their creation order) are kept so
        # state_dict keys and parameter initialisation match the reference code.
        self.conv1 = self._conv_bn(in_dim, dim // 8, stride=2)
        self.conv2 = self._conv_bn(dim // 8, dim // 4, stride=2)
        self.conv3 = self._conv_bn(dim // 4, dim // 2, stride=2)
        self.conv4 = self._conv_bn(dim // 2, dim, stride=2)
        # The last stage is a linear projection: no activation after its BatchNorm.
        self.conv5 = self._conv_bn(dim, dim, stride=1, act=False)

    @staticmethod
    def _conv_bn(in_ch, out_ch, stride, act=True):
        """Build a bias-free 3x3 conv (padding 1) + BatchNorm stage, optionally ending in GELU."""
        layers = [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3,
                      stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        if act:
            layers.append(nn.GELU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the five stages in order: (B, in_dim, H, W) -> (B, dim, ceil(H/16), ceil(W/16))."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        return x

In [12]:
pe = PatchEmbed(dim=768)  # same as the default embedding width

In [15]:
# Four stride-2 convs downsample by 2**4 = 16: 224 -> 14, the same 14x14 grid as the
# 16x16 patch projection above. Open question from the original note ("ilegal?"):
# is this overlapping conv stem a legitimate stand-in for non-overlapping patches?
pe(image).size()

torch.Size([1, 768, 14, 14])

## ViG official implementation

In [4]:
class Stem(nn.Module):
    """ Image to Visual Embedding
    Overlap: https://arxiv.org/pdf/2106.13797.pdf

    Overlapping convolutional stem with output stride 4. It consists of two
    stride-2 3x3 convs and one stride-1 3x3 conv, each followed by BatchNorm,
    with an activation after the first two BatchNorms. It maps
    (B, in_dim, H, W) to (B, out_dim, ceil(H/4), ceil(W/4)), e.g. 224x224 -> 56x56.

    Args:
        img_size: Unused; accepted only for signature compatibility.
        in_dim: Number of input image channels.
        out_dim: Number of output embedding channels (the first conv uses out_dim // 2).
        act: Activation name, case-insensitive: 'relu', 'relu6', 'leakyrelu',
            'prelu', 'gelu' or 'hswish'. Pass None to omit activations entirely.

    Raises:
        NotImplementedError: If ``act`` is not one of the supported names.
    """

    def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
        super().__init__()
        layers = [
            nn.Conv2d(in_dim, out_dim // 2, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim // 2),
        ]
        # Without a nonlinearity in between, stacked conv+BN layers collapse into a
        # single affine map (in eval mode), so the extra depth would add nothing.
        if act is not None:
            layers.append(self._act_layer(act))
        layers += [
            nn.Conv2d(out_dim // 2, out_dim, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_dim),
        ]
        if act is not None:
            layers.append(self._act_layer(act))
        # The final projection is left linear (no activation after the last BatchNorm).
        layers += [
            nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_dim),
        ]
        self.convs = nn.Sequential(*layers)

    @staticmethod
    def _act_layer(act):
        """Return a new activation module for the case-insensitive name ``act``."""
        name = act.lower()
        if name == 'relu':
            return nn.ReLU()
        if name == 'relu6':
            return nn.ReLU6()
        if name == 'leakyrelu':
            return nn.LeakyReLU(0.2)
        if name == 'prelu':
            return nn.PReLU(num_parameters=1, init=0.2)
        if name == 'gelu':
            return nn.GELU()
        if name == 'hswish':
            return nn.Hardswish()
        raise NotImplementedError('activation layer [%s] is not found' % act)

    def forward(self, x):
        """Apply the stem: (B, in_dim, H, W) -> (B, out_dim, ceil(H/4), ceil(W/4))."""
        return self.convs(x)

In [5]:
patchifier = Stem(in_dim=3, out_dim=768)  # 3 input channels -> 768 embedding channels (defaults)

In [7]:
# Two stride-2 convs give output stride 4: 224 -> 56, a 4x finer grid per side than PatchEmbed's 14.
patchifier(image).size()

torch.Size([1, 768, 56, 56])