In [2]:
import torch
import torch.nn as nn

In [9]:
class VitInputLayer(nn.Module):
    """Convert an image batch into the token sequence consumed by the ViT encoder.

    A strided ``Conv2d`` splits the image into non-overlapping
    ``patch_size x patch_size`` patches and linearly projects each patch to an
    ``emb_dim`` vector in a single step (the convolution kernel is the learned
    patch embedding). A learnable class token is prepended and a learnable
    positional embedding is added to every token.
    """

    def __init__(self, 
                 in_channels:int=3,
                 emb_dim:int=384,
                 num_patch_row:int=2,
                 image_size:int=32
                ):
        """
        Args:
            in_channels: number of channels of the input image.
            emb_dim: length of each token vector after embedding.
            num_patch_row: number of patches along the height axis; the same
                number is used along the width axis, so there are
                ``num_patch_row ** 2`` patches in total.
            image_size: side length of the input image; height and width are
                assumed to be equal.

        Raises:
            ValueError: if ``num_patch_row`` is not positive, or if the patch
                convolution would not produce a ``num_patch_row x num_patch_row``
                grid for an ``image_size`` image (the positional embedding could
                then never be added in ``forward``).
        """
        super().__init__()
        if num_patch_row <= 0:
            raise ValueError(f"num_patch_row must be positive, got {num_patch_row}")

        self.in_channels = in_channels
        self.emb_dim = emb_dim
        self.num_patch_row = num_patch_row
        self.image_size = image_size

        self.num_patch = self.num_patch_row ** 2

        # Side length of one patch, e.g. image_size=32, num_patch_row=2 -> 16.
        # Floor division: leftover border pixels are simply not covered by any patch.
        self.patch_size = int(self.image_size // self.num_patch_row)

        # With kernel_size == stride == patch_size (no padding) the convolution
        # yields image_size // patch_size patches per side. That must equal
        # num_patch_row, otherwise pos_emb (sized for num_patch tokens) can never
        # be added to the patch tokens.
        if self.patch_size <= 0 or self.image_size // self.patch_size != self.num_patch_row:
            raise ValueError(
                f"image_size={image_size} with num_patch_row={num_patch_row} does not "
                f"produce a {num_patch_row}x{num_patch_row} patch grid"
            )

        # Splits the image into patches and embeds each patch in one layer:
        # the learned convolution kernel acts as the patch-embedding matrix.
        self.patch_emb_layer = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.emb_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size
        )

        # Learnable class token, shared by every sample in the batch.
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, emb_dim)
        )

        # Learnable positional embedding: one vector per token (class token + patches).
        self.pos_emb = nn.Parameter(
            torch.randn(1, self.num_patch+1, emb_dim)
        )

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input images, shape (B, C, H, W)
                B: batch_size, C: num_channels, H: height, W: width
        Returns:
            z_0: input tokens for the ViT encoder, shape (B, N, D)
                B: batch_size, N: num_token (= num_patch + 1), D: emb_dim
        Raises:
            ValueError: if the input's spatial size does not yield ``num_patch``
                patches (i.e. it does not match the size the layer was built for).
        """
        # Patch embedding: (B, C, H, W) -> (B, D, H/P, W/P)
        z_0 = self.patch_emb_layer(x)
        # Flatten the patch grid: (B, D, H/P, W/P) -> (B, D, Np)
        z_0 = z_0.flatten(2)
        # Put tokens on axis 1: (B, D, Np) -> (B, Np, D)
        z_0 = z_0.transpose(1, 2)

        if z_0.size(1) != self.num_patch:
            raise ValueError(
                f"input of spatial size {tuple(x.shape[-2:])} produced {z_0.size(1)} patches, "
                f"expected {self.num_patch} (layer built for image_size={self.image_size})"
            )

        # Prepend the class token to every sequence: (B, Np, D) -> (B, N, D).
        # expand() broadcasts without copying; gradients still flow into cls_token.
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        z_0 = torch.cat([cls_tokens, z_0], dim=1)

        # Add the positional embedding (broadcast over the batch): (B, N, D) -> (B, N, D)
        z_0 = z_0 + self.pos_emb

        return z_0

In [11]:
batch_size, channel, height, width = 2, 3, 32, 32
x = torch.randn(batch_size, channel, height, width)
input_layer = VitInputLayer()
z_0 = input_layer(x)
print(z_0.shape)
# One class token plus num_patch patch tokens, each of length emb_dim.
assert z_0.shape == (batch_size, input_layer.num_patch + 1, input_layer.emb_dim)

torch.Size([2, 5, 384])


In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an output projection."""

    def __init__(self,
                 emb_dim:int=384,
                 head:int=3,
                 dropout:float=0.
    ):
        """
        Args:
            emb_dim: length of each token embedding.
            head: number of attention heads; must be positive and divide ``emb_dim``.
            dropout: dropout probability applied to the attention weights and
                after the output projection.

        Raises:
            ValueError: if ``head`` is not positive or does not divide ``emb_dim``.
        """
        super().__init__()
        if head <= 0 or emb_dim % head != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by a positive number of heads, got head={head}"
            )
        self.head = head
        self.emb_dim = emb_dim
        self.head_dim = emb_dim // self.head
        # sqrt(D_h): q k^T is divided by this to keep the softmax inputs well scaled.
        self.sqrt_dh = self.head_dim**0.5

        # Linear maps that embed the input into queries, keys and values.
        self.w_q = nn.Linear(emb_dim, emb_dim, bias=False)
        self.w_k = nn.Linear(emb_dim, emb_dim, bias=False)
        self.w_v = nn.Linear(emb_dim, emb_dim, bias=False)

        # Dropout on the attention weights.
        self.attn_drop = nn.Dropout(dropout)

        # Output projection W_o that mixes the concatenated heads back to emb_dim.
        self.w_o = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.Dropout(dropout)
        )

    def forward(self, z:torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: input tokens, shape (B, N, D)
                B: batch_size, N: number of tokens, D: embedding length
        Returns:
            out: attended tokens after the output projection, shape (B, N, D)
        """
        batch_size, num_tokens, _ = z.shape

        # Embed into queries, keys and values.
        q = self.w_q(z)
        k = self.w_k(z)
        v = self.w_v(z)

        # Split into heads: (B, N, D) -> (B, N, h, D/h)
        q = q.view(batch_size, num_tokens, self.head, self.head_dim)
        k = k.view(batch_size, num_tokens, self.head, self.head_dim)
        v = v.view(batch_size, num_tokens, self.head, self.head_dim)

        # Move heads next to batch so each head attends independently:
        # (B, N, h, D/h) -> (B, h, N, D/h)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot products: (B, h, N, D/h) x (B, h, D/h, N) -> (B, h, N, N)
        dots = (q @ k.transpose(2, 3)) / self.sqrt_dh
        # Normalize over the key axis (last dim): each query's weights sum to 1.
        attn = F.softmax(dots, dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values: (B, h, N, N) x (B, h, N, D/h) -> (B, h, N, D/h)
        out = attn @ v
        # (B, h, N, D/h) -> (B, N, h, D/h)
        out = out.transpose(1, 2)
        # Concatenate heads: (B, N, h, D/h) -> (B, N, D)
        out = out.reshape(batch_size, num_tokens, self.emb_dim)

        # Output projection (linear + dropout): (B, N, D) -> (B, N, D)
        out = self.w_o(out)

        return out

In [21]:
mhsa = MultiHeadSelfAttention()
out = mhsa(z_0)
print(out.shape)
# Self-attention preserves the (B, N, D) token shape.
assert out.shape == z_0.shape

torch.Size([2, 5, 384])


In [28]:
class VitEncoderBlock(nn.Module):
    """One pre-norm Transformer encoder block.

    ``z -> z + MHSA(LN(z))`` followed by ``out -> out + MLP(LN(out))``.
    """

    def __init__(self,
                 emb_dim:int=384,
                 head:int=8,
                 hidden_dim:int=384*4,
                 dropout:float=0.
                ):
        """
        Args:
            emb_dim: length of each token embedding.
            head: number of attention heads.
            hidden_dim: width of the hidden layer of the block's MLP.
            dropout: dropout probability used in attention and in the MLP.
        """
        super().__init__()

        self.ln1 = nn.LayerNorm(emb_dim)

        self.msa = MultiHeadSelfAttention(
            emb_dim=emb_dim,
            head=head,
            dropout=dropout
        )

        self.ln2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, emb_dim),
            nn.Dropout(dropout)
        )

    def forward(self, z:torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: input tokens, shape (B, N, D)
                B: batch_size, N: number of tokens, D: embedding length
        Returns:
            out: block output, shape (B, N, D)
        """
        # Attention sub-layer with residual connection.
        out = self.msa(self.ln1(z)) + z

        # MLP sub-layer with residual connection. It consumes the attention
        # sub-layer's output (not the raw input z) and is added, not multiplied.
        out = self.mlp(self.ln2(out)) + out

        return out

In [29]:
vit_enc = VitEncoderBlock()
z_1 = vit_enc(z_0)
print(z_1.shape)
# An encoder block maps (B, N, D) tokens to (B, N, D) tokens.
assert z_1.shape == z_0.shape

torch.Size([2, 5, 384])


In [31]:
class Vit(nn.Module):
    """Vision Transformer classifier.

    Pipeline: input layer (patch embedding, class token, positional embedding)
    -> ``num_blocks`` stacked encoder blocks -> MLP head on the class token.
    """

    def __init__(self,
                 in_channel:int=3,
                 num_classes:int=10,
                 emb_dim:int=384,
                 num_patch_row:int=2,
                 image_size:int=32,
                 num_blocks:int=7,
                 head:int=8,
                 hidden_dim:int=384*4,
                 dropout=0.
                ):
        """
        Args:
            in_channel: number of channels of the input image.
            num_classes: number of output classes.
            emb_dim: length of each token embedding.
            num_patch_row: number of patches along each image side.
            image_size: side length of the (square) input image.
            num_blocks: number of stacked encoder blocks.
            head: number of attention heads per encoder block.
            hidden_dim: hidden width of each encoder block's MLP.
            dropout: dropout probability passed to every encoder block.
        """
        super().__init__()

        # Image -> token sequence.
        self.input_layer = VitInputLayer(
            in_channels=in_channel,
            emb_dim=emb_dim,
            num_patch_row=num_patch_row,
            image_size=image_size,
        )

        # Encoder blocks applied one after another.
        blocks = []
        for _ in range(num_blocks):
            blocks.append(
                VitEncoderBlock(
                    emb_dim=emb_dim,
                    head=head,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                )
            )
        self.encoder = nn.Sequential(*blocks)

        # Classification head applied to the class-token representation.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input images, shape (B, C, H, W)
        Returns:
            class logits, shape (B, M) where M is num_classes
        """
        tokens = self.encoder(self.input_layer(x))  # (B, C, H, W) -> (B, N, D)
        cls_repr = tokens[:, 0]                      # class token only: (B, D)
        return self.mlp_head(cls_repr)               # (B, D) -> (B, M)
        

In [None]:
num_classes = 10
batch_size, channel, height, width = 2, 3, 32, 32
x = torch.randn(batch_size, channel, height, width)
vit = Vit(in_channel=channel, num_classes=num_classes, image_size=height)
pred = vit(x)
print(pred.shape)
# One logit vector of length num_classes per image.
assert pred.shape == (batch_size, num_classes)
