# ViT Code 1: Implementing ViT with the timm Library

**Reference**

- rwightman's source: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py



**The timm library**

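PyTorch Image Models (timm for short) is a large collection of PyTorch code that includes:

- image models
- layers
- utilities
- optimizers
- schedulers
- data loaders / augmentations
- training / validation scripts

Its goal is to bring together a wide range of state-of-the-art (SOTA) models, along with the ability to reproduce ImageNet training results.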

**ViT Network architecture**

<img src='./images/ViT.png' width="600" align="left">

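The overall ViT architecture can be described in five steps:

1. Split the input image into patches: the image is cut into small blocks (patches), which are the basic units for all later processing.

2. Get a linear embedding (representation) for each patch, called a Patch Embedding: a linear projection maps each patch to a one-dimensional vector.

3. Add Position Embeddings and a `[cls] token` to the Patch Embeddings: position information is added to every patch to preserve spatial structure, and a special `[cls] token` is inserted whose output is used for classification.

4. Pass everything through the Transformer Encoder and take the output at the `[cls] token`: all embeddings (including the `[cls] token`) go through a multi-layer Transformer Encoder, and the output at the `[cls] token` position is extracted at the end.

5. Pass the `[cls] token` representation through the `MLP Head` to get the final class prediction: the final `[cls] token` representation is fed to the `MLP Head`, which produces the classification result.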

Note:
The Transformer Encoder contains an `MLP` internally, but it is a different structure from the `MLP Head` used for classification.

<img src="./images/vit-01.png" width="1000" align="left" alt="" title="fig-2 Simplified Model Overview" />

Now let's look at these five steps concretely. Suppose we want to classify a `224 x 224` RGB (`3`-channel) input image of a frog.

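<strong>Step one</strong> is to create `16 x 16` patches over the image. Since `224 / 16 = 14` along each side, this gives `14 x 14 = 196` patches. The patches can be laid out in a line, with the first patch taken from the top-left corner of the input image and the last from the bottom-right corner. As the figure shows, each patch contains `3 x 16 x 16 = 768` values, where `3` is the number of channels (RGB).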
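Step one can be sketched in plain PyTorch. This is a minimal sketch, assuming a batch of `224 x 224` RGB images. A reshape/permute splits each image into `196` flattened patches, ordered row by row starting from the top-left corner. The function name `patchify` is hypothetical and not part of timm.

```python
import torch

def patchify(img: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Split (B, C, H, W) images into (B, N, C*P*P) flattened patches, row-major order."""
    B, C, H, W = img.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by the patch size"
    gh, gw = H // P, W // P                    # 14 x 14 grid for 224 / 16
    x = img.reshape(B, C, gh, P, gw, P)        # split H and W into (grid index, pixel in patch)
    x = x.permute(0, 2, 4, 1, 3, 5)            # (B, gh, gw, C, P, P)
    return x.reshape(B, gh * gw, C * P * P)    # (B, 196, 768)

img = torch.randn(1, 3, 224, 224)
patches = patchify(img)
print(patches.shape)  # torch.Size([1, 196, 768])
```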

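In <strong>step two</strong>, these patches are passed through a linear projection layer to obtain a `1 x 768` vector representation for each image patch (for ViT-Base this happens to equal the `3 x 16 x 16 = 768` values per patch). These representations are shown in purple in the figure, and the paper calls them Patch Embeddings. With `196` patches, each represented by a `1 x 768` vector, the patch embedding matrix has a total size of `196 x 768`.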
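Step two can be sketched under the same assumptions. A single `nn.Linear` layer, shared by all patches, maps each flattened 768-value patch to a `D = 768` embedding. The patch size `3 x 16 x 16` and the embedding width `D` only coincide for ViT-Base/16. ViT-Large, for example, uses `D = 1024` with the same `16 x 16` patches.

```python
import torch
import torch.nn as nn

# Flattened patches from step one: (batch, num_patches, 3*16*16)
patches = torch.randn(1, 196, 3 * 16 * 16)

# Step two: one linear layer, shared by every patch, maps 768 values -> D = 768
proj = nn.Linear(3 * 16 * 16, 768)
patch_embeddings = proj(patches)
print(patch_embeddings.shape)  # torch.Size([1, 196, 768])
```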

In <strong>step three</strong>, a `[cls]` token, similar to the one in BERT, is prepended to the `196 x 768` sequence of patch embeddings. After the `[cls]` token is added, the Patch Embeddings have size `197 x 768`. They are then summed with the Position Embeddings, which are also `197 x 768`.
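Here is a sketch of step three that uses random stand-in patch embeddings. The `[cls]` token and the position embeddings are learnable tensors, zero-initialized here for simplicity. The `[cls]` token is prepended along the token dimension, and the position embeddings are added element-wise.

```python
import torch
import torch.nn as nn

B, N, D = 1, 196, 768
patch_embeddings = torch.randn(B, N, D)

# Learnable [cls] token and 1-D position embeddings (one per token, [cls] included)
cls_token = nn.Parameter(torch.zeros(1, 1, D))
pos_embed = nn.Parameter(torch.zeros(1, N + 1, D))

# Prepend [cls]: (B, 196, 768) -> (B, 197, 768)
tokens = torch.cat([cls_token.expand(B, -1, -1), patch_embeddings], dim=1)
# Add position embeddings element-wise; the shape stays (B, 197, 768)
tokens = tokens + pos_embed
print(tokens.shape)  # torch.Size([1, 197, 768])
```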

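In <strong>step four</strong>, these preprocessed patch embeddings, which carry position information and the prepended `[cls]` token, are passed to the Transformer encoder to learn features. The encoder output is `197 x 768`, one vector per token. In the final <strong>step five</strong>, the `1 x 768` output at the `[cls]` token position is fed to the `MLP Head` (a linear layer) to obtain the class prediction.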
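The shape bookkeeping of steps four and five can be checked with a stand-in encoder. In place of ViT's own encoder layer, this sketch assumes PyTorch's built-in `nn.TransformerEncoderLayer` in pre-norm mode (`norm_first=True`). It also uses 2 layers instead of 12 to keep it light. The sketch shows that the encoder returns one `768`-dim vector per token and that only the `[cls]` position is sent to the head.

```python
import torch
import torch.nn as nn

D, num_classes = 768, 1000
tokens = torch.randn(1, 197, D)  # [cls] + 196 patch embeddings, positions already added

# Stand-in encoder: PyTorch's pre-norm Transformer layer, not timm's Block
layer = nn.TransformerEncoderLayer(d_model=D, nhead=12, dim_feedforward=4 * D,
                                   dropout=0.0, activation="gelu",
                                   batch_first=True, norm_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2, norm=nn.LayerNorm(D))  # ViT-Base: 12 layers

out = encoder(tokens)                        # (1, 197, 768): one output per token
cls_out = out[:, 0]                          # (1, 768): the [cls] position only
logits = nn.Linear(D, num_classes)(cls_out)  # step five: MLP Head -> (1, 1000)
print(out.shape, cls_out.shape, logits.shape)
```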

With the overall architecture in view, the following sections look at each step in detail.

## Patch Embeddings

<p>In this section we look at <strong>steps one and two</strong> in detail, i.e. the process of <u>getting patch embeddings from an input image</u>.</p>

<img src="./images/vit-02.png" width="1000" align="left" />

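So far, we have obtained patch embeddings from an input image in two stages. First, the image is split into fixed-size patches. Then a linear projection layer produces a linear embedding for each patch.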

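In practice, however, a <strong>2D convolution</strong> can merge these two steps into one. This is better from an implementation standpoint, because GPUs are optimized for convolutions and the image no longer has to be split into patches explicitly.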

Set `out_channels` to `768` and both `kernel_size` and `stride` to `16`. As the figure shows, a single convolution with a `3 x 16 x 16` kernel then produces a `14 x 14` grid of `768`-dim outputs. Flattening that grid gives a `196 x 768` <strong>Patch Embeddings</strong> matrix:

```python
import torch
import torch.nn as nn

# input image: (B, C, H, W)
x = torch.randn(1, 3, 224, 224)
# 2D conv: in_channels=3, out_channels=768, kernel_size=16, stride=16
conv = nn.Conv2d(3, 768, 16, 16)
# (1, 768, 14, 14) -> (768, 196) -> (196, 768)
conv(x).reshape(-1, 196).transpose(0, 1).shape
# torch.Size([196, 768])
```
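The claim that a strided convolution equals "split into patches, then apply a linear layer" can be checked numerically. In this sketch, the linear layer reuses the convolution kernel reshaped to `768 x 768`. Both paths then produce the same `196 x 768` embeddings, up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
img = torch.randn(1, 3, 224, 224)

conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
# Conv path: (1, 768, 14, 14) -> (1, 768, 196) -> (1, 196, 768)
conv_out = conv(img).flatten(2).transpose(1, 2)

# Explicit path: cut into 16x16 patches (row-major) and flatten each to 3*16*16 values...
patches = (img.reshape(1, 3, 14, 16, 14, 16)
              .permute(0, 2, 4, 1, 3, 5)
              .reshape(1, 196, 3 * 16 * 16))
# ...then apply a linear layer whose weight is the conv kernel, flattened per output channel
linear = nn.Linear(3 * 16 * 16, 768)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(768, -1))  # (768, 3, 16, 16) -> (768, 768)
    linear.bias.copy_(conv.bias)
linear_out = linear(patches)

print(torch.allclose(conv_out, linear_out, atol=1e-4))  # True
```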

## [cls] token & Position Embeddings

In <strong>step three</strong>, the `[cls]` token is prepended and the Positional Embeddings are added to the Patch Embeddings.

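Similar to BERT's `[class]` token, a learnable embedding is prepended to the sequence of patch embeddings. Its state at the output of the Transformer encoder (denoted <strong>z<sub>L</sub><sup>0</sup></strong>) serves as the image representation. A classification head is attached to <strong>z<sub>L</sub><sup>0</sup></strong> during both pre-training and fine-tuning.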

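Position Embeddings are also added to the Patch Embeddings to retain positional information. ViT uses standard learnable 1D position embeddings, and the resulting sequence of embedding vectors serves as input to the encoder.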

This process is easy to visualize:

<img src="./images/vit-03.png" width="1000" align="left" />

<p>As can be seen, the <code>[cls]</code> token is a vector of size <code>1 x 768</code>. We <strong>prepend</strong> it to the <strong>Patch Embeddings</strong>, so the updated size of the <strong>Patch Embeddings</strong> becomes <code>197 x 768</code>.</p>

<p>Next, we add <strong>Positional Embeddings</strong> of size <code>197 x 768</code> to the <strong>Patch Embeddings</strong> (including the <code>[cls]</code> token) to get the <strong>combined embeddings</strong>, which are then fed to the <code>Transformer Encoder</code>. Adding position information to the token embeddings is a standard step that comes from the original Transformer paper, <a href="https://arxiv.org/abs/1706.03762">Attention is all you need</a>.</p>

<blockquote>
  <p>Note that the Positional Embeddings and the <code>cls</code> token vector are nothing fancy: each is just a trainable <code>nn.Parameter</code> matrix/vector.</p>
</blockquote>
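Here is a small module illustrating this, assuming ViT-Base sizes. The class name `TokensAndPositions` is hypothetical, not a timm class. Both tensors are wrapped in `nn.Parameter`, so they are registered with the module and receive gradients like any other weight.

```python
import torch
import torch.nn as nn

class TokensAndPositions(nn.Module):
    """Hypothetical helper: prepend a learnable [cls] token, then add learnable position embeddings."""
    def __init__(self, num_patches: int = 196, dim: int = 768):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        cls = self.cls_token.expand(x.shape[0], -1, -1)  # one [cls] copy per batch item
        return torch.cat([cls, x], dim=1) + self.pos_embed

m = TokensAndPositions()
print([name for name, _ in m.named_parameters()])  # ['cls_token', 'pos_embed']
out = m(torch.randn(2, 196, 768))
out.sum().backward()  # gradients flow into both parameters
```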


## Transformer Encoder


This section examines the <strong>Transformer Encoder</strong> in detail. The <strong>Transformer Encoder</strong> consists of alternating <strong>Multi-Head Attention</strong> and <strong>MLP</strong> blocks. Layer Norm is applied before every block, and a residual connection after every block.
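Written as equations, the layer looks as follows. The notation follows the ViT paper: $\mathbf{z}_0$ is the combined embeddings, MSA is the multi-head self-attention block, LN is Layer Norm, and ViT-Base has $L = 12$ layers. The last line is the image representation taken at the `[cls]` position.

$$
\begin{aligned}
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, && \ell = 1,\dots,L \\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, && \ell = 1,\dots,L \\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0)
\end{aligned}
$$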

A single layer/block of the <strong>Transformer Encoder</strong> can be visualized as follows:

<img src="./images/vit-07.png"  width="800" />

<p>The first layer of the <strong>Transformer Encoder</strong> accepts the <strong>combined embeddings</strong> of shape <code>197 x 768</code> as input. Every subsequent layer takes as input the output matrix <code>Out</code>, also of shape <code>197 x 768</code>, from the previous layer. The <strong>Transformer Encoder</strong> of the ViT-Base architecture has <u>12 such layers</u> in total.</p>

<p>Inside a layer, the inputs first pass through a <strong>Layer Norm</strong> and are then fed to the <strong>Multi-Head Attention</strong> block.</p>

<p>Inside the <strong>Multi-Head Attention</strong> block, a <strong>Linear layer</strong> first maps the inputs to a <code>197 x 2304 (768*3)</code> <strong>qkv</strong> matrix. This <strong>qkv</strong> matrix is reshaped to <code>197 x 3 x 768</code>, where the three <code>197 x 768</code> matrices are <strong>q</strong>, <strong>k</strong> and <strong>v</strong>. Each of these is further reshaped to <code>12 x 197 x 64</code>, one <code>197 x 64</code> slice per attention head. With <strong>q</strong>, <strong>k</strong> and <strong>v</strong> in hand, the <strong>Multi-Head Attention</strong> block performs the attention operation given by the equation:</p>

<img src="./images/vit-08.png" alt="" width="1000" />
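The shape changes described above can be traced with random tensors. This sketch covers the per-head computation $\mathrm{softmax}(QK^\top / \sqrt{d_k})\,V$ with $d_k = 64$. It mirrors the reshapes used in timm's `Attention.forward`, but it omits the output projection and dropout.

```python
import torch
import torch.nn as nn

B, N, D, H = 1, 197, 768, 12
head_dim = D // H                                   # 64
x = torch.randn(B, N, D)

qkv = nn.Linear(D, 3 * D)(x)                        # (1, 197, 2304)
qkv = qkv.reshape(B, N, 3, H, head_dim).permute(2, 0, 3, 1, 4)  # (3, 1, 12, 197, 64)
q, k, v = qkv[0], qkv[1], qkv[2]                    # each (1, 12, 197, 64)

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5  # (1, 12, 197, 197), scaled by 1/sqrt(64)
attn = attn.softmax(dim=-1)                         # each row: weights over the 197 tokens
out = (attn @ v).transpose(1, 2).reshape(B, N, D)   # concatenate the heads: (1, 197, 768)
print(qkv.shape, attn.shape, out.shape)
```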

<p>The outputs of the <strong>Multi-Head Attention</strong> block are added to the block's inputs (skip connection). The result passes through another <strong>Layer Norm</strong> before being fed to the <strong>MLP</strong> block.</p>

<p>The <strong>MLP</strong> is a Multi-Layer Perceptron block consisting of two linear layers with a GELU non-linearity in between. Its outputs are again added to its inputs (skip connection) to produce the final output of one <strong>Transformer Encoder</strong> layer.</p>
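For ViT-Base, the MLP hidden size is `4 x 768 = 3072`. The quick sketch below builds just the two linear layers and the GELU (no dropout) and counts their parameters. It also checks that the block keeps the `197 x 768` shape.

```python
import torch
import torch.nn as nn

D, mlp_ratio = 768, 4
mlp = nn.Sequential(
    nn.Linear(D, D * mlp_ratio),   # 768 -> 3072
    nn.GELU(),
    nn.Linear(D * mlp_ratio, D),   # 3072 -> 768
)
num_params = sum(p.numel() for p in mlp.parameters())
print(num_params)  # 4722432

y = mlp(torch.randn(1, 197, D))   # token-wise, shape-preserving
```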

<p>Having looked at a single layer inside the <strong>Transformer Encoder</strong>, let’s now zoom out and look at the complete <strong>Transformer Encoder</strong>.</p>

<img src="./images/vit-06.png" alt="" title="fig-6 Transformer Encoder" />
    
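As the figure above shows, the <strong>Transformer Encoder</strong> consists of 12 layers. The output of the first layer is fed to the second, the output of the second to the third, and so on, until the final output comes out of the 12th layer of the <strong>Transformer Encoder</strong>. The `[cls]` token part of this output is then fed to the <strong>MLP Head</strong> to get the class prediction.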

## The Vision Transformer in PyTorch 

Having looked at the Vision Transformer architecture in detail, let's now see how to implement it in PyTorch. We will use the code from <a href="https://github.com/rwightman/pytorch-image-models">timm</a> to explain the implementation. The code snippets below are copied directly from <a href="https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py">here</a>.

We build the Vision Transformer from the bottom up. First, how do we get the Patch Embeddings?

```python
import torch
import torch.nn as nn


def to_2tuple(x):
    # Minimal stand-in for timm's `to_2tuple` helper: 224 -> (224, 224)
    return tuple(x) if isinstance(x, (tuple, list)) else (x, x)


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # kernel_size == stride == patch_size: one output position per non-overlapping patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, embed_dim, N) -> (B, N, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
```

<p>As discussed, we use a <strong>2-D Convolution</strong> whose <code>stride</code> and <code>kernel_size</code> are both set to <code>patch_size</code>, and that is exactly what the class above does. <code>self.proj</code> is an <code>nn.Conv2d</code> that maps 3 input channels to <code>768</code> output channels. Flattening and transposing its output gives a <code>196 x 768</code> patch embedding matrix per image.</p>

```python
patch_embed = PatchEmbed()
x = torch.randn(1, 3, 224, 224)
patch_embed(x).shape
# torch.Size([1, 196, 768])
```

The MLP block inside the Transformer Encoder is also easy to implement:

```python
class Mlp(nn.Module):
    # Linear -> activation (GELU by default) -> dropout -> Linear -> dropout
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
```

It consists of two linear layers with a GELU activation in between, plus dropout. There isn't much happening in this class. Next, we implement Attention:

```python
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # 1 / sqrt(d_k)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, C/heads) -> (3, B, heads, N, C/heads)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # concatenate heads back to (B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
```

<p>As described for the <strong>Multi-Head Attention</strong> block, a Linear layer produces the <strong>qkv</strong> matrix. The attention operation itself is applied in the <code>forward</code> method above:</p>

```python
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
```

<p>These lines implement the attention equation; the product with <strong>v</strong> follows in <code>forward</code>. With the <strong>Attention</strong> layer and the <strong>MLP</strong> block in place, let's quickly implement a single layer of the <strong>Transformer Encoder</strong>. A single <code>Block</code> consists of Layer Norm, Attention and the MLP block.</p>

```python
# DropPath (stochastic depth) comes from timm; requires `pip install timm`
from timm.models.layers import DropPath


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # pre-norm residual branches: x + f(LN(x))
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```

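As the `forward` method above shows, the `Block` first normalizes its input `x` with `self.norm1` (a `LayerNorm`), applies attention, and adds the result back to `x` (residual connection). It then normalizes this sum with `self.norm2`, passes it through `self.mlp`, and adds the result back again, which gives the block's output matrix `Out`. Both residual branches are wrapped in `self.drop_path` (stochastic depth, an identity when `drop_path=0`), not in ordinary `Dropout`.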
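`DropPath` is timm's stochastic depth layer. During training, it drops the entire residual branch for randomly chosen samples. The class below is a minimal sketch of the same idea, not timm's implementation, and the name `DropPathSketch` is hypothetical. Surviving samples are divided by the keep probability so the expected output is unchanged. At inference, the layer is the identity.

```python
import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Hypothetical stochastic-depth sketch: zero a sample's whole branch with prob drop_prob."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x                                   # identity at inference
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.bernoulli(torch.full(mask_shape, keep_prob, dtype=x.dtype, device=x.device))
        return x * mask / keep_prob                    # rescale survivors to keep the expectation

dp = DropPathSketch(0.5)
x = torch.ones(8, 197, 768)
dp.train()
y = dp(x)          # each sample is either all 0.0 or all 2.0
dp.eval()
print(torch.equal(dp(x), x))  # True
```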

With all the pieces in place, the full Vision Transformer architecture can be implemented as follows:

```python
class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))                 # learnable [cls]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))   # learnable 1-D positions
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)                         # (B, 196, 768)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)           # (B, 197, 768)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0]                                  # [cls] token output: (B, 768)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)                                # (B, num_classes)
        return x
```
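The `dpr` line in `__init__` spreads the stochastic depth rate linearly across the blocks. Early blocks are rarely dropped, and the last block uses the full `drop_path_rate`. The quick check below assumes an example `drop_path_rate` of `0.1`; the class default is `0.`, which disables stochastic depth.

```python
import torch

depth, drop_path_rate = 12, 0.1  # 0.1 is an example value, not the class default
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print([round(r, 4) for r in dpr])  # 0.0 for block 0, rising evenly to 0.1 for block 11
```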

## Official timm code

```python
import math
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn

# This cell also relies on helpers from timm's vision_transformer.py and layer utilities
# (trunc_normal_, named_apply, _init_vit_weights, _load_weights); their names and
# locations differ between timm versions. PatchEmbed and Block are the classes defined above.


class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if mode.startswith('jax'):
            # leave cls token as zeros to match jax impl
            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions separately
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x
```
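As a sanity check on the defaults above, the parameter count can be tallied by hand. The check assumes `embed_dim=768`, `depth=12`, `mlp_ratio=4`, `qkv_bias=True` and `num_classes=1000`, with no distillation token and no representation layer. The function below is a hypothetical helper written only for this check. Its total of about 86.6M is in line with the 86M listed for ViT-Base in the ViT paper.

```python
def vit_param_count(img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                    depth=12, mlp_ratio=4, num_classes=1000, qkv_bias=True):
    """Hypothetical helper: hand count of VisionTransformer parameters (no dist token / pre-logits)."""
    D = embed_dim
    num_patches = (img_size // patch_size) ** 2
    hidden = D * mlp_ratio
    patch_embed = in_chans * patch_size * patch_size * D + D   # Conv2d weight + bias
    tokens = D + (num_patches + 1) * D                          # cls_token + pos_embed
    block = (
        2 * D                                       # norm1 (weight + bias)
        + D * 3 * D + (3 * D if qkv_bias else 0)    # attn.qkv
        + D * D + D                                 # attn.proj
        + 2 * D                                     # norm2
        + D * hidden + hidden                       # mlp.fc1
        + hidden * D + D                            # mlp.fc2
    )
    final_norm = 2 * D
    head = D * num_classes + num_classes
    return patch_embed + tokens + depth * block + final_norm + head

print(vit_param_count())  # 86567656
```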