<a href="https://colab.research.google.com/github/kitarikes/My_Code-Stash/blob/main/pytorch_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## nnクラスの使い方

In [54]:
import torch
import torch.nn as nn

In [55]:
class SimpleMlp(nn.Module):
  """Minimal two-layer perceptron: Linear -> ReLU -> Linear.

  Args:
    vec_length: Size of each input feature vector.
    hidden_unit_1: Output width of the first linear layer.
    hidden_unit_2: Output width of the second linear layer.
  """

  def __init__(self,
               vec_length: int = 16,
               hidden_unit_1: int = 8,
               hidden_unit_2: int = 2):
    super().__init__()

    # Sub-modules are created in this order so parameter initialisation
    # consumes the random number generator in the same sequence.
    self.layer1 = nn.Linear(in_features=vec_length,
                            out_features=hidden_unit_1)
    self.relu = nn.ReLU()
    self.layer2 = nn.Linear(in_features=hidden_unit_1,
                            out_features=hidden_unit_2)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply layer1, ReLU and layer2 to ``x`` in that order."""
    hidden = self.relu(self.layer1(x))
    return self.layer2(hidden)



In [56]:
# Layer widths for the demo network: input size, then the two layer outputs.
vec, u1, u2 = 16, 8, 2

# Number of samples in the demo batch.
bs = 4

# Random input batch of shape (bs, vec) = (4, 16).
x = torch.randn((bs, vec))

In [57]:
# Display the random (4, 16) input batch created in the previous cell.
x

tensor([[-0.4357, -0.8163,  0.1530, -0.3898,  0.0692, -1.4470,  0.4773,  0.9313,
          0.1823, -0.6926, -0.4840, -0.5682,  0.7434,  0.8738,  0.2804,  0.1215],
        [ 0.1552, -1.0927, -0.3763, -0.5281, -0.0605,  0.5286, -2.3574,  1.1081,
          0.0484, -0.7260,  1.0879,  0.3583, -1.9325,  1.0415,  0.1661, -0.6680],
        [ 0.2197,  0.7555,  0.0322,  1.1035,  0.1629,  0.1661,  0.4502,  0.0176,
         -0.4979,  0.7792, -0.5262,  0.3570, -0.0778,  0.0599, -0.4966, -0.5918],
        [ 2.2016,  0.0503, -1.8721, -0.3182,  0.5474, -0.7765, -0.2650,  0.2996,
          0.4807,  0.4485, -0.0668,  0.8647,  0.1631,  0.8188,  0.9984, -0.0612]])

In [58]:
# Build the MLP with the layer sizes chosen for the demo batch.
net = SimpleMlp(vec_length=vec, hidden_unit_1=u1, hidden_unit_2=u2)

In [59]:
# Display the module summary (its registered sub-modules).
net

SimpleMlp(
  (layer1): Linear(in_features=16, out_features=8, bias=True)
  (relu): ReLU()
  (layer2): Linear(in_features=8, out_features=2, bias=True)
)

In [60]:
# Run a forward pass on the batch; the recorded output has one 2-value row per sample.
net(x)

tensor([[ 0.2400,  0.0343],
        [ 0.3019,  0.1133],
        [ 0.4721,  0.0036],
        [ 0.3196, -0.1167]], grad_fn=<AddmmBackward0>)

## Input Layer

In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [72]:
class VitInputLayer(nn.Module):
  """Input layer of a Vision Transformer (ViT).

  Cuts a square image into ``num_patch_row ** 2`` patches and embeds each
  patch with a convolution whose kernel size and stride both equal the patch
  size. On request it also prepends a learnable class token and adds a
  learnable position embedding.

  Args:
    in_channels: Number of channels of the input image.
    emb_dim: Length of each embedding vector (D).
    num_patch_row: Number of patches along one side of the image.
    image_size: Side length of the (square) input image.
  """

  def __init__(self,
               in_channels: int = 3,
               emb_dim: int = 384,
               num_patch_row: int = 2,
               image_size: int = 32):
    super().__init__()
    self.in_channels = in_channels
    self.emb_dim = emb_dim
    # Number of patches along the vertical direction.
    self.num_patch_row = num_patch_row
    # Side length of the image.
    self.image_size = image_size
    # Total number of patches.
    self.num_patch = self.num_patch_row**2
    # Side length of one patch (P).
    self.patch_size = int(self.image_size // self.num_patch_row)
    # Debug output of the patch size; left unchanged because the notebook
    # records what it prints.
    print(f"{self.patch_size}=")

    # Non-overlapping patch embedding: kernel size == stride == patch size.
    self.patch_emb_layer = nn.Conv2d(
        in_channels=self.in_channels,
        out_channels=self.emb_dim,
        kernel_size=self.patch_size,
        stride=self.patch_size,
    )

    # Class token.
    self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))

    # Position embedding. The class token is concatenated in front of the
    # patches, so (number of patches + 1) vectors of length emb_dim are needed.
    self.pos_emb = nn.Parameter(torch.randn(1, self.num_patch + 1, emb_dim))

  def forward(self,
              x: torch.Tensor,
              *,
              return_patch_emb: bool = True,
              ) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
    """Embed an image batch.

    Args:
      x: Input images of shape (B, C, H, W).
        B: batch size, C: channels, H: height, W: width.
      return_patch_emb: When True (the default), stop after the patch
        embedding and return the raw feature map together with the
        convolution kernel. This is the behaviour existing callers rely on
        (they unpack two values). When False, run the complete input layer
        and return ``z_0``.

    Returns:
      If ``return_patch_emb`` is True: the tuple
      ``(patch_features, patch_emb_layer.weight)`` where ``patch_features``
      has shape (B, D, H/P, W/P).
      Otherwise ``z_0``, the ViT input of shape (B, N, D) with
      N = number of patches + 1 and D = embedding length.
    """
    # Patch embedding: (B, C, H, W) -> (B, D, H/P, W/P), P = patch side length.
    z_0 = self.patch_emb_layer(x)
    if return_patch_emb:
      return z_0, self.patch_emb_layer.weight

    # Flatten the patch grid: (B, D, H/P, W/P) -> (B, D, Np),
    # where Np is the number of patches (= H*W / P^2).
    z_0 = z_0.flatten(2)

    # Swap axes: (B, D, Np) -> (B, Np, D).
    z_0 = z_0.transpose(1, 2)

    # Prepend the class token: (B, Np, D) -> (B, N, D) with N = Np + 1.
    # cls_token has shape (1, 1, D), so it is broadcast to (B, 1, D) first.
    cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
    z_0 = torch.cat([cls_tokens, z_0], dim=1)

    # Add the position embedding: (B, N, D) -> (B, N, D).
    return z_0 + self.pos_emb

# batch_size, channel, height, width= 2, 3, 32, 32
# x = torch.randn(batch_size, channel, height, width) 
# input_layer = VitInputLayer(num_patch_row=2) 
# z_0=input_layer(x)

# # (2, 5, 384)(=(B, N, D))になっていることを確認。 
# print(z_0.shape)


In [73]:

# Dummy image batch for the ViT input layer: (B, C, H, W) = (2, 3, 32, 32).
batch_size = 2
channel = 3
height = width = 32
x = torch.randn((batch_size, channel, height, width))
# Show the shape of the batch.
x.shape

torch.Size([2, 3, 32, 32])

In [77]:
# Build the input layer with 2 patches per side; the constructor prints the patch size.
input_layer = VitInputLayer(num_patch_row=2) 

# The call yields two values: the patch-embedding feature map (y) and the
# weight of the patch-embedding convolution (w).
y, w = input_layer(x)

16=


In [80]:
# Compare the shapes of the patch feature map and the convolution weight.
y.shape, w.shape

(torch.Size([2, 384, 2, 2]), torch.Size([384, 3, 16, 16]))

In [66]:
# Flattening from dim 2 onward merges the two trailing patch-grid axes into one.
y.flatten(2).shape

torch.Size([2, 384, 4])

In [84]:
# Random stand-in weight tensor for the matmul experiment.
# NOTE(review): this evaluates to shape (256, 768, 384); the output recorded
# below this cell shows a different shape and looks stale.
w_tmp = torch.randn(16 * 16, 16 * 16 * 3, 384)

# Show the shape of the stand-in weight.
w_tmp.shape

torch.Size([384, 3, 16, 16])

In [87]:
# NOTE(review): this raises as written (see the recorded RuntimeError). In file
# order x is the (2, 3, 32, 32) image batch and w_tmp is created with shape
# (256, 768, 384), so matmul would have to contract x's last dim (32) with
# w_tmp's second-to-last dim (768). Presumably the goal was to reproduce the
# patch embedding as a matrix product, which would need x rearranged into
# flattened patches and a 2-D weight first — TODO confirm the intent.
torch.matmul(x, w_tmp)

RuntimeError: ignored