In [None]:
class MLPHead(nn.Module):
    """Two-layer classification head: (LayerNorm) -> Linear -> Tanh -> Dropout -> Linear.

    Args:
        in_dim: Size of the last dimension of the input features.
        hid_dim: Width of the hidden linear layer.
        num_classes: Number of output logits.
        drop: Dropout probability applied after the Tanh activation.
        use_ln: When true, inputs are normalised with ``nn.LayerNorm(in_dim)``
            first; otherwise they pass through ``nn.Identity``.
    """

    def __init__(self, in_dim, hid_dim, num_classes, drop=0.0, use_ln=True):
        super().__init__()
        # Optional pre-normalisation of the incoming features.
        if use_ln:
            self.pre = nn.LayerNorm(in_dim)
        else:
            self.pre = nn.Identity()
        self.fc1 = nn.Linear(in_dim, hid_dim)
        # Tanh hidden representation, as often used in the paper(s).
        self.act = nn.Tanh()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Linear(hid_dim, num_classes)

    def forward(self, x):
        """Map input features to class logits.

        Args:
            x: Feature tensor whose last dimension is ``in_dim``
                (per the original note: the [CLS] token, shape (B, D)).

        Returns:
            Tensor of logits whose last dimension is ``num_classes``.
        """
        hidden = self.drop(self.act(self.fc1(self.pre(x))))
        return self.fc2(hidden)

class ViTClassifier(nn.Module):
    """Image classifier built from a patch embedding, an encoder and an MLP head.

    Args:
        img_size: Forwarded to ``PatchEmbed``.
        patch_size: Forwarded to ``PatchEmbed``.
        in_chans: Forwarded to ``PatchEmbed``.
        embed_dim: Forwarded to ``PatchEmbed`` and ``encoder``; also used as
            both the input and hidden width of the classification head.
        depth: Forwarded to ``encoder``.
        num_heads: Forwarded to ``encoder``.
        mlp_ratio: Accepted but currently unused (``mlp_dim`` is what gets
            passed to ``encoder``).
        num_classes: Number of logits produced by the head.
        attn_drop: Accepted but currently unused.
        drop: Accepted but currently unused; the head's dropout is fixed
            at 0.1 below.
        mlp_dim: Forwarded to ``encoder``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
                 num_classes=1000, attn_drop=0.0, drop=0.0, mlp_dim=3072):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.encoder = encoder(embed_dim, depth, num_heads, mlp_dim)
        self.head = MLPHead(in_dim=embed_dim, hid_dim=embed_dim,
                            num_classes=num_classes, drop=0.1)

        # Simple initialisation of the head. MLPHead itself exposes no
        # ``weight``/``bias`` attributes (accessing them raises
        # AttributeError), so initialise each linear layer it contains.
        for module in self.head.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch; per the original note, shape (B, 3, H, W),
                e.g. (B, 3, 224, 224).

        Returns:
            Logits tensor; per the original note, shape (B, num_classes).
        """
        tokens = self.patch_embed(x)    # per original note: (B, N+1, D)
        # NOTE(review): the encoder is assumed to return only the [CLS]
        # representation; if it returns the full token sequence, select
        # index 0 along the token axis here -- confirm against ``encoder``.
        cls = self.encoder(tokens)
        logits = self.head(cls)
        return logits