# 图片来源
https://github.com/lucidrains/vit-pytorch
# 代码来源
https://github.com/gupta-abhay/pytorch-vit

# layer

In [None]:
import torch
import torch.nn as nn


class PreNorm(nn.Module):
    """Wrap a callable so its input is layer-normalised first.

    Args:
        dim: Size of the last input dimension, handed to ``nn.LayerNorm``.
        fn: Callable (typically a module) applied to the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and forward it, with any extra kwargs, to ``fn``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)


class Attention(nn.Module):
    """Multi-head self-attention over an input of shape ``(B, N, C)``.

    Args:
        dim: Embedding size ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
        self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()

        assert (
            dim % num_heads == 0
        ), "Embedding dimension should be divisible by number of heads"

        self.num_heads = num_heads
        # Dot products are scaled by 1 / sqrt(per-head width).
        self.scale = (dim // num_heads) ** -0.5

        # One linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended tensor, same shape as ``x``."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, per_head)
        packed = self.qkv(x)
        packed = packed.reshape(batch, tokens, 3, self.num_heads, per_head)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Explicit indexing (not tuple unpacking) keeps TorchScript happy.
        q = packed[0]
        k = packed[1]
        v = packed[2]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back: (B, heads, N, per_head) -> (B, N, C)
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class FeedForward(nn.Module):
    """
    Transformer MLP block in one of two flavours.

    With ``revised=False`` it is Linear -> GELU -> Dropout -> Linear
    (original ViT: https://arxiv.org/pdf/2010.11929.pdf).
    With ``revised=True`` it is built from 1x1 ``Conv1d`` layers, each
    followed by ``BatchNorm1d`` and GELU
    (Scaled ReLU: https://arxiv.org/pdf/2109.03810.pdf).

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the intermediate layer.
        dropout_rate: Dropout probability after the first activation.
        revised: Selects the convolutional variant when true.
    """

    def __init__(self, dim, hidden_dim, dropout_rate=0.0, revised=False):
        super().__init__()
        if revised:
            layers = [
                nn.Conv1d(dim, hidden_dim, kernel_size=1, stride=1),
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                nn.Dropout(p=dropout_rate),
                nn.Conv1d(hidden_dim, dim, kernel_size=1, stride=1),
                nn.BatchNorm1d(dim),
                nn.GELU(),
            ]
        else:
            layers = [
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(p=dropout_rate),
                nn.Linear(hidden_dim, dim),
            ]
        self.net = nn.Sequential(*layers)

        self.revised = revised
        self._init_weights()

    def _init_weights(self):
        """Re-draw the bias of every ``nn.Linear`` in ``net`` with std 1e-6."""
        for layer in self.net.children():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """Apply the MLP; the revised variant runs channels-first."""
        if not self.revised:
            return self.net(x)

        # Conv1d/BatchNorm1d expect the feature axis at position 1, so swap
        # the last two axes on the way in and restore them on the way out.
        channels_first = x.permute(0, 2, 1)
        return self.net(channels_first).permute(0, 2, 1)


class OutputLayer(nn.Module):
    """Classification head applied to the transformer output sequence.

    The sequence is reduced to one vector per sample (first token when
    ``cls_head`` is true, otherwise the mean over dimension 1) and then
    passed through either a single linear classifier or, when
    ``representation_size`` is truthy, Linear -> Tanh -> Linear.

    Args:
        embedding_dim: Feature size of the incoming tokens.
        num_classes: Number of output logits.
        representation_size: Optional width of the hidden "pre-logits"
            layer; falsy values disable it.
        cls_head: Use the token at index 0 instead of mean pooling.
    """

    def __init__(
        self,
        embedding_dim,
        num_classes=1000,
        representation_size=None,
        cls_head=False,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.cls_head = cls_head

        modules = []
        if representation_size:
            modules.append(nn.Linear(embedding_dim, representation_size))
            modules.append(nn.Tanh())
            modules.append(nn.Linear(representation_size, num_classes))
        else:
            modules.append(nn.Linear(embedding_dim, num_classes))

        self.net = nn.Sequential(*modules)

        if cls_head:
            self.to_cls_token = nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Zero-initialise the final classifier layer only.

        The classifier is identified by position (last module of ``net``)
        rather than by matching its output width against ``num_classes``:
        a shape match would also zero the hidden projection whenever
        ``representation_size == num_classes``, leaving the hidden
        activations stuck at zero.
        """
        classifier = self.net[-1]
        nn.init.zeros_(classifier.weight)
        nn.init.zeros_(classifier.bias)

    def forward(self, x):
        """Pool ``x`` over dimension 1 and return the class logits."""
        if self.cls_head:
            x = self.to_cls_token(x[:, 0])
        else:
            # Mean pooling over tokens, as in
            # Scaling Vision Transformers: https://arxiv.org/abs/2106.04560
            x = torch.mean(x, dim=1)

        return self.net(x)

# patch_embed

In [None]:
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from utils import trunc_normal_


def pair(t):
    """Return ``t`` as-is if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


class EmbeddingStem(nn.Module):
    """Convert an image batch into a sequence of embedding vectors.

    Exactly one of three modes must be enabled:

    * ``conv_patch``: one ``Conv2d`` whose kernel and stride equal the
      patch size, followed by flattening of the spatial grid.
    * ``linear_patch``: patches are cut out with ``Rearrange`` and
      projected by a single ``nn.Linear``.
    * ``conv_stem``: a stack of convolutions, either the "original"
      stem (https://arxiv.org/pdf/2106.14881.pdf) or the scaled-ReLU
      stem (https://arxiv.org/pdf/2109.03810.pdf).

    The two patch modes add a learned position embedding (and optionally
    prepend a [CLS] token); the conv-stem mode does neither.

    Args:
        image_size: Int or ``(height, width)`` tuple.
        patch_size: Int or ``(height, width)`` tuple; must divide the
            image size.
        channels: Number of input image channels.
        embedding_dim: Size of each output embedding.
        hidden_dims: List of intermediate channel counts for the conv
            stem (exactly one entry for the scaled-ReLU stem). The list
            is not modified.
        conv_patch / linear_patch / conv_stem: Mode switches.
        conv_stem_original / conv_stem_scaled_relu: Conv-stem flavour;
            exactly one must be true when ``conv_stem`` is used.
        position_embedding_dropout: Dropout probability applied after the
            position embedding is added; ``None`` means no dropout.
        cls_head: Prepend a learnable [CLS] token (patch modes only).

    Note:
        ``conv_stem`` and ``cls_head`` both default to ``True`` but are
        mutually exclusive, so a conv-stem caller must pass
        ``cls_head=False`` explicitly.

    Raises:
        AssertionError: If the mode flags or sizes are inconsistent.
        ValueError: If a conv stem is requested without a list of
            ``hidden_dims``.
    """

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        channels=3,
        embedding_dim=768,
        hidden_dims=None,
        conv_patch=False,
        linear_patch=False,
        conv_stem=True,
        conv_stem_original=True,
        conv_stem_scaled_relu=False,
        position_embedding_dropout=None,
        cls_head=True,
    ):
        super().__init__()

        assert (
            sum([conv_patch, conv_stem, linear_patch]) == 1
        ), "Only one of three modes should be active"

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert (
            image_height % patch_height == 0 and image_width % patch_width == 0
        ), "Image dimensions must be divisible by the patch size."

        assert not (
            conv_stem and cls_head
        ), "Cannot use [CLS] token approach with full conv stems for ViT"

        if linear_patch or conv_patch:
            self.grid_size = (
                image_height // patch_height,
                image_width // patch_width,
            )
            num_patches = self.grid_size[0] * self.grid_size[1]

            if cls_head:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
                num_patches += 1

            # positional embedding
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embedding_dim)
            )
            # nn.Dropout rejects p=None, so the documented "no dropout"
            # default is mapped to a probability of zero.
            drop_p = (
                0.0
                if position_embedding_dropout is None
                else position_embedding_dropout
            )
            self.pos_drop = nn.Dropout(p=drop_p)

        if conv_patch:
            self.projection = nn.Sequential(
                nn.Conv2d(
                    channels,
                    embedding_dim,
                    kernel_size=patch_size,
                    stride=patch_size,
                ),
            )
        elif linear_patch:
            patch_dim = channels * patch_height * patch_width
            self.projection = nn.Sequential(
                Rearrange(
                    'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                    p1=patch_height,
                    p2=patch_width,
                ),
                nn.Linear(patch_dim, embedding_dim),
            )
        elif conv_stem:
            assert (
                conv_stem_scaled_relu ^ conv_stem_original
            ), "Can use either the original or the scaled relu stem"

            if not isinstance(hidden_dims, list):
                raise ValueError("Cannot create stem without list of sizes")

            if conv_stem_original:
                self.projection = self._build_original_stem(
                    channels, hidden_dims, embedding_dim
                )
            elif conv_stem_scaled_relu:
                assert (
                    len(hidden_dims) == 1
                ), "Only one value for hidden_dim is allowed"
                self.projection = self._build_scaled_relu_stem(
                    channels,
                    hidden_dims[0],
                    embedding_dim,
                    (patch_height // 2, patch_width // 2),
                )
            else:
                raise ValueError("Undefined convolutional stem type defined")

        self.conv_stem = conv_stem
        self.conv_patch = conv_patch
        self.linear_patch = linear_patch
        self.cls_head = cls_head

        self._init_weights()

    @staticmethod
    def _build_original_stem(channels, hidden_dims, embedding_dim):
        """Build the conv stem from https://arxiv.org/pdf/2106.14881.pdf."""
        # Work on a fresh list so the caller's ``hidden_dims`` is untouched.
        dims = [channels, *hidden_dims]
        layers = []
        for in_ch, out_ch in zip(dims[:-1], dims[1:]):
            layers.append(
                nn.Conv2d(
                    in_ch,
                    out_ch,
                    kernel_size=3,
                    # Downsample only where the channel count changes.
                    stride=2 if in_ch != out_ch else 1,
                    padding=1,
                    bias=False,
                )
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))

        # Final 1x1 convolution maps to the embedding size.
        layers.append(
            nn.Conv2d(dims[-1], embedding_dim, kernel_size=1, stride=1)
        )
        return nn.Sequential(*layers)

    @staticmethod
    def _build_scaled_relu_stem(channels, mid_ch, embedding_dim, half_patch):
        """Build the conv stem from https://arxiv.org/pdf/2109.03810.pdf.

        ``half_patch`` is the (height, width) kernel/stride of the last
        convolution, i.e. half the patch size along each axis.
        """
        # fmt: off
        return nn.Sequential(
            nn.Conv2d(
                channels, mid_ch,
                kernel_size=7, stride=2, padding=3, bias=False,
            ),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                mid_ch, mid_ch,
                kernel_size=3, stride=1, padding=1, bias=False,
            ),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                mid_ch, mid_ch,
                kernel_size=3, stride=1, padding=1, bias=False,
            ),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                mid_ch, embedding_dim,
                kernel_size=half_patch, stride=half_patch,
            ),
        )
        # fmt: on

    def _init_weights(self):
        """Initialise the position embedding (absent in conv-stem mode)."""
        if not self.conv_stem:
            trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Embed image batch ``x`` into a token sequence."""
        if self.conv_stem:
            x = self.projection(x)
            # Flatten the spatial grid and move channels last.
            x = x.flatten(2).transpose(1, 2)
            return x

        # paths for cls_token / position embedding
        elif self.linear_patch:
            x = self.projection(x)
        elif self.conv_patch:
            x = self.projection(x)
            x = x.flatten(2).transpose(1, 2)

        if self.cls_head:
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_token, x), dim=1)
        return self.pos_drop(x + self.pos_embed)

# transformer

<div align=center>
<img src="./img/transformerencoder.png" width="40%"/>
</div>

In [None]:
from torch import nn
from modules import Attention, FeedForward, PreNorm


class Transformer(nn.Module):
    """Stack of ``depth`` encoder layers, each attention then feed-forward.

    Every layer is applied residually (``x = block(x) + x``). Attention
    is always wrapped in ``PreNorm``; the feed-forward block is wrapped
    in ``PreNorm`` unless ``revised`` is true, in which case the revised
    ``FeedForward`` variant is used without the wrapper.

    Args:
        dim: Embedding size of the tokens.
        depth: Number of encoder layers.
        heads: Number of attention heads per layer.
        mlp_ratio: Hidden width of the feed-forward block as a multiple
            of ``dim``; an int or float (the product is truncated to int).
        attn_dropout: Dropout probability on the attention weights.
        dropout: Dropout probability for the attention output projection
            and the feed-forward block.
        qkv_bias: Whether the query/key/value projection has a bias.
        revised: Use the revised feed-forward block.

    Raises:
        AssertionError: If ``mlp_ratio`` is not an int or float.
    """

    def __init__(
        self,
        dim,
        depth,
        heads,
        mlp_ratio=4.0,
        attn_dropout=0.0,
        dropout=0.0,
        qkv_bias=True,
        revised=False,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        # Whole-number ratios such as 4 are as valid as 4.0; booleans are
        # excluded even though bool is a subclass of int.
        assert isinstance(mlp_ratio, (int, float)) and not isinstance(
            mlp_ratio, bool
        ), "MLP ratio should be a number (int or float)"
        mlp_dim = int(mlp_ratio * dim)

        for _ in range(depth):
            attention = PreNorm(
                dim,
                Attention(
                    dim,
                    num_heads=heads,
                    qkv_bias=qkv_bias,
                    attn_drop=attn_dropout,
                    proj_drop=dropout,
                ),
            )
            if revised:
                feed_forward = FeedForward(
                    dim, mlp_dim, dropout_rate=dropout, revised=True,
                )
            else:
                feed_forward = PreNorm(
                    dim,
                    FeedForward(dim, mlp_dim, dropout_rate=dropout),
                )
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        """Run ``x`` through every layer with residual connections."""
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# model

<div align=center>
<img src="./img/vit.gif" width="40%"/>
</div>

In [None]:
import torch.nn as nn

from patch_embed import EmbeddingStem
from transformer import Transformer
from modules import OutputLayer


class VisionTransformer(nn.Module):
    """Vision Transformer: embedding stem, encoder, LayerNorm, output head.

    The constructor arguments are forwarded to the three sub-modules:

    * ``EmbeddingStem`` receives the image/patch geometry, the stem mode
      flags (``use_conv_stem``, ``use_conv_patch``, ``use_linear_patch``,
      ``use_conv_stem_original``, ``use_stem_scaled_relu``),
      ``hidden_dims``, ``cls_head`` and ``dropout_rate`` (as the
      position-embedding dropout).
    * ``Transformer`` receives the depth, head count, ``mlp_ratio``,
      dropout rates, ``qkv_bias`` and ``use_revised_ffn``.
    * ``OutputLayer`` receives ``num_classes``, ``representation_size``
      and ``cls_head``.
    """

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        in_channels=3,
        embedding_dim=768,
        num_layers=12,
        num_heads=12,
        qkv_bias=True,
        mlp_ratio=4.0,
        use_revised_ffn=False,
        dropout_rate=0.0,
        attn_dropout_rate=0.0,
        use_conv_stem=True,
        use_conv_patch=False,
        use_linear_patch=False,
        use_conv_stem_original=True,
        use_stem_scaled_relu=False,
        hidden_dims=None,
        cls_head=False,
        num_classes=1000,
        representation_size=None,
    ):
        super().__init__()

        # Image -> token sequence.
        stem_options = dict(
            image_size=image_size,
            patch_size=patch_size,
            channels=in_channels,
            embedding_dim=embedding_dim,
            hidden_dims=hidden_dims,
            conv_patch=use_conv_patch,
            linear_patch=use_linear_patch,
            conv_stem=use_conv_stem,
            conv_stem_original=use_conv_stem_original,
            conv_stem_scaled_relu=use_stem_scaled_relu,
            position_embedding_dropout=dropout_rate,
            cls_head=cls_head,
        )
        self.embedding_layer = EmbeddingStem(**stem_options)

        # Encoder stack followed by a final layer norm.
        encoder_options = dict(
            dim=embedding_dim,
            depth=num_layers,
            heads=num_heads,
            mlp_ratio=mlp_ratio,
            attn_dropout=attn_dropout_rate,
            dropout=dropout_rate,
            qkv_bias=qkv_bias,
            revised=use_revised_ffn,
        )
        self.transformer = Transformer(**encoder_options)
        self.post_transformer_ln = nn.LayerNorm(embedding_dim)

        # Classification head.
        head_options = dict(
            num_classes=num_classes,
            representation_size=representation_size,
            cls_head=cls_head,
        )
        self.cls_layer = OutputLayer(embedding_dim, **head_options)

    def forward(self, x):
        """Embed, encode, normalise and classify the input batch ``x``."""
        tokens = self.embedding_layer(x)
        encoded = self.post_transformer_ln(self.transformer(tokens))
        return self.cls_layer(encoded)