<a href="https://colab.research.google.com/github/cheul0518/Competitions/blob/main/HuBMAP_HPA/CoaT_small.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install timm &> /dev/null
!pip install einops &> /dev/null

In [21]:
# Modified from # https://github.com/mlpc-ucsd/CoaT/blob/main/src/models/coat.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model

from einops import rearrange
from functools import partial
from torch import nn, einsum

class LayerNorm2d(nn.Module):
    """Layer normalisation across the channel axis of an NCHW tensor.

    Every spatial position is normalised over its C channels (biased variance),
    then scaled and shifted by learnable per-channel ``weight`` / ``bias``.

    Args:
        dim: number of channels C.
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        # Unpacking doubles as a check that x is 4-D (B, C, H, W).
        _b, _c, _h, _w = x.shape
        mean = x.mean(dim=1, keepdim=True)
        centered = x - mean
        var = centered.pow(2).mean(dim=1, keepdim=True)
        normed = centered / torch.sqrt(var + self.eps)
        # Broadcast the per-channel affine parameters over H and W: [C] -> [C, 1, 1].
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normed + shift

# def _cfg_coat(url='', **kwargs):
#     return {'url':url,
#             'num_classes':1000,
#             'input_size':(3, 224, 224),
#             'pool_size':None,
#             'crop_pct':.9,
#             'interpolation':'bicubic',
#             'mean': IMAGENET_DEFAULT_MEAN,
#             'std': IMAGENET_DEFAULT_STD,
#             'first_conv': 'patch_embed.proj',
#             'classifier': 'head',
#             **kwargs
#             }
class Mlp(nn.Module):
    """Feed-forward network (FFN): Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the last input dimension.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        width_out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, width_out)
        # A single Dropout module is reused for both dropout sites.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class ConvRelPosEnc(nn.Module):
    """Convolutional relative position encoding (CRPE).

    The value tokens are reshaped back into an image, split into channel groups
    (one group per head split), and each group goes through its own depthwise
    convolution. The result is multiplied element-wise by the queries. The
    cls-token row of the output is all zeros.
    """

    def __init__(self, Ch, h, window):  # Ch=embed_dims//num_heads, h=num_heads, window=crpe_window
        """Initialization.

        Args:
            Ch: channels per attention head.
            h: number of attention heads.
            window: window size(s) of the depthwise convolutions. Two forms:
                1. an int: every head uses the same window size;
                2. a dict mapping window size -> number of heads using it,
                   e.g. ``{3: 2, 5: 3, 7: 3}``. The head counts must sum to ``h``.

        Raises:
            ValueError: if ``window`` is neither an int nor a dict, or if the head
                counts of a dict ``window`` do not sum to ``h``.
        """
        super().__init__()

        if isinstance(window, int):
            window = {window: h}
        elif not isinstance(window, dict):
            raise ValueError(
                "window must be an int or a dict mapping window size to head count, "
                f"got {type(window).__name__}")
        self.window = window

        # Without this check a mismatch only surfaces later, inside torch.split in forward().
        total_heads = sum(window.values())
        if total_heads != h:
            raise ValueError(
                f"head splits in window {window} sum to {total_heads}, expected {h} heads")

        self.conv_list = nn.ModuleList()
        self.head_splits = []

        for cur_window, cur_head_split in window.items():  # e.g. window = {3: 2, 5: 3, 7: 3}
            dilation = 1
            # padding == cur_window // 2 here, so odd windows preserve the H x W size.
            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
            channels = cur_head_split * Ch
            cur_conv = nn.Conv2d(channels,
                                 channels,
                                 kernel_size=(cur_window, cur_window),
                                 padding=(padding_size, padding_size),
                                 dilation=(dilation, dilation),
                                 groups=channels)  # depthwise
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)  # e.g. [2, 3, 3] at the end
        # Channel count of each group, e.g. [38, 57, 57] for Ch=19 with the window above.
        self.channel_splits = [x * Ch for x in self.head_splits]

    def forward(self, q, v, size):
        """Compute the CRPE term.

        Args:
            q: queries, shape [B, h, N, Ch] with N = 1 + H*W (cls token first).
            v: values, same shape as ``q``.
            size: (H, W) spatial size of the image tokens.

        Returns:
            Tensor of shape [B, h, N, Ch] whose cls-token row is zero.
        """
        B, h, N, Ch = q.shape
        H, W = size
        assert N == 1 + H * W

        # Only image tokens get a positional term; drop the cls token.
        q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
        v_img = v[:, :, 1:, :]  # [B, h, H*W, Ch]

        v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W)  # [B, h*Ch, H, W]

        # Each channel group (one head split) gets its own depthwise conv.
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)
        conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)]
        conv_v_img = torch.cat(conv_v_img_list, dim=1)  # [B, h*Ch, H, W]
        conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h)  # [B, h, H*W, Ch]

        EV_hat_img = q_img * conv_v_img
        # Zero row in the cls-token slot so the output lines up with q / v.
        zero = torch.zeros((B, h, 1, Ch), dtype=q.dtype, layout=q.layout, device=q.device)
        EV_hat = torch.cat((zero, EV_hat_img), dim=2)
        return EV_hat  # [B, h, N, Ch]

class FactorAtt_ConvRelPosEnc(nn.Module):
    """Factorized attention combined with a convolutional relative position encoding.

    Per head it computes ``scale * Q @ (softmax_N(K)^T @ V) + crpe(Q, V)``, where the
    softmax runs over the token axis. It then merges the heads and applies an
    output projection.

    Args:
        dim: token embedding size C, split evenly across ``num_heads``.
        num_heads: number of attention heads h.
        qkv_bias: whether the fused Q/K/V projection has a bias.
        qk_scale: replaces the default ``(dim // num_heads) ** -0.5`` scale when truthy.
        attn_drop: stored as ``self.attn_drop``; not applied in ``forward``.
        proj_drop: dropout probability after the output projection.
        shared_crpe: callable ``crpe(q, v, size=...)`` returning a tensor shaped like ``q``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # not applied in forward
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Convolutional relative position encoding (may be shared between blocks).
        self.crpe = shared_crpe

    def forward(self, x, size):
        batch, tokens, channels = x.shape  # tokens = 1 (cls) + H*W
        assert tokens == 1 + size[0] * size[1]
        heads = self.num_heads

        # Fused projection -> [3, B, h, N, Ch], then split into Q, K, V.
        qkv = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Factorized attention: build the small [Ch, Ch] context K^T V first, then apply Q.
        context = einsum('b h n k, b h n v -> b h k v', k.softmax(dim=2), v)
        factor_att = einsum('b h n k, b h k v -> b h n v', q, context)  # [B, h, N, Ch]

        crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]

        # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]
        merged = (self.scale * factor_att + crpe).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

class ConvPosEnc(nn.Module):
    """Convolutional position encoding (similar to the conditional positional encoding in CPVT).

    The image tokens are reshaped to [B, C, H, W], passed through a depthwise k x k
    convolution, and added back residually. The cls token passes through unchanged.

    Args:
        dim: channel count C.
        k: kernel size of the depthwise convolution.
    """

    def __init__(self, dim, k=3):
        super().__init__()
        # Depthwise conv with stride 1; padding k // 2 keeps H x W for odd k.
        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)

    def forward(self, x, size):
        B, N, C = x.shape  # N = 1 (cls) + H*W
        H, W = size
        assert N == 1 + H * W

        cls_token = x[:, :1]
        feat = x[:, 1:].transpose(1, 2).view(B, C, H, W)  # [B, C, H, W]
        img_tokens = (self.proj(feat) + feat).flatten(2).transpose(1, 2)  # [B, H*W, C]
        # The cls token goes back in front of the image tokens.
        return torch.cat((cls_token, img_tokens), dim=1)

class SerialBlock(nn.Module):
    """Serial block: convolutional position encoding followed by two pre-norm
    residual sub-layers, a factorized conv-attention and an FFN.

    Note: in this implementation, each serial block only contains a conv-attention
    and a FFN module.

    Args:
        dim: token embedding size.
        num_heads: attention heads.
        mlp_ratio: FFN hidden width as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: forwarded to ``FactorAtt_ConvRelPosEnc``.
        drop: dropout for the attention projection and the FFN.
        drop_path: stochastic-depth rate for both residual branches (0 disables it).
        act_layer, norm_layer: activation and normalisation classes.
        shared_cpe: position-encoding callable ``cpe(x, size)``.
        shared_crpe: relative position encoding passed to the attention layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None):
        super().__init__()

        # Conv-Attention
        self.cpe = shared_cpe
        self.norm1 = norm_layer(dim)
        self.factoratt_crpe = FactorAtt_ConvRelPosEnc(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                                      attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, size):
        """x: [B, 1 + H*W, C] tokens (cls first); size: (H, W). Returns the same shape."""
        # Conv-Attention
        x = self.cpe(x, size)  # convolutional position encoding, shape preserved
        cur = self.norm1(x)
        cur = self.factoratt_crpe(cur, size)
        # Residual connection for the attention branch. Without it, the attention output is discarded.
        x = x + self.drop_path(cur)

        # MLP
        cur = self.norm2(x)
        cur = self.mlp(cur)
        x = x + self.drop_path(cur)
        return x

class ParallelBlock(nn.Module):
    """Parallel block: conv-attention and FFN on stages 2-4 side by side, with
    cross-scale fusion of the attention outputs.

    Stage 1 (``x1``) is returned unchanged. After per-stage attention, each stage's
    attention output is summed with the other two stages' outputs, which are
    bilinearly resampled to its resolution. Stages 2-4 must have equal embedding
    dims and MLP ratios, because a single ``Mlp`` instance serves all three.

    Args:
        dims: per-stage embedding dims, indexed 0-3 (index 0 is not used here).
        num_heads: attention heads per stage.
        mlp_ratios: per-stage FFN expansion ratios, indexed 0-3 (indices 1-3 must be equal).
        qkv_bias, qk_scale, drop, attn_drop: forwarded to the attention / FFN layers.
        drop_path: stochastic-depth rate shared by every residual branch.
        act_layer, norm_layer: activation and normalisation classes.
        shared_cpes, shared_crpes: per-stage shared position-encoding modules, indexed 0-3.
    """
    def __init__(self, dims, num_heads, mlp_ratios=(), qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 shared_cpes=None, shared_crpes=None):
        super().__init__()

        # Conv-Attention
        self.cpes = shared_cpes

        self.norm12 = norm_layer(dims[1])
        self.norm13 = norm_layer(dims[2])
        self.norm14 = norm_layer(dims[3])
        self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                                       attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[1])
        self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                                       attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[2])
        self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                                       attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[3])
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP
        self.norm22 = norm_layer(dims[1])
        self.norm23 = norm_layer(dims[2])
        self.norm24 = norm_layer(dims[3])
        assert dims[1] == dims[2] == dims[3]  # In parallel block, we assume dimensions are the same and share the linear transformation.
        assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
        mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
        # One Mlp instance, registered under three names (weights shared across stages 2-4).
        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def upsample(self, x, output_size, size):
        """ Feature map up-sampling"""
        return self.interpolate(x, output_size=output_size, size=size)

    def downsample(self, x, output_size, size):
        """ Feature map down-sampling"""
        return self.interpolate(x, output_size=output_size, size=size)

    def interpolate(self, x, output_size, size):
        """Bilinearly resample the image tokens of ``x`` from ``size`` to ``output_size``.

        Args:
            x: tokens [B, 1 + H*W, C] with the cls token first and (H, W) = size.
            output_size: target (H', W').
            size: current (H, W).

        Returns:
            Tokens [B, 1 + H'*W', C]; the cls token is copied through unchanged.
        """
        B, N, C = x.shape
        H, W = size
        assert N == 1 + H * W

        cls_token = x[:, :1, :]
        img_tokens = x[:, 1:, :]

        img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
        # align_corners=False is PyTorch's default for bilinear; passed explicitly (may have alignment issue).
        img_tokens = F.interpolate(img_tokens, size=output_size, mode='bilinear', align_corners=False)
        img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)

        out = torch.cat((cls_token, img_tokens), dim=1)

        return out

    def forward(self, x1, x2, x3, x4, sizes):
        """Refine stages 2-4 jointly.

        Args:
            x1..x4: per-stage tokens [B, 1 + H_i*W_i, C_i] (cls first).
            sizes: sequence of four (H_i, W_i); the first entry is ignored.

        Returns:
            (x1, x2, x3, x4) with x1 unchanged.
        """
        _, (H2, W2), (H3, W3), (H4, W4) = sizes

        # Conv-Attention: Note: x1 is ignored
        x2 = self.cpes[1](x2, size=(H2, W2))
        x3 = self.cpes[2](x3, size=(H3, W3))
        x4 = self.cpes[3](x4, size=(H4, W4))

        cur2 = self.norm12(x2)
        cur3 = self.norm13(x3)
        cur4 = self.norm14(x4)
        cur2 = self.factoratt_crpe2(cur2, size=(H2, W2))
        cur3 = self.factoratt_crpe3(cur3, size=(H3, W3))
        cur4 = self.factoratt_crpe4(cur4, size=(H4, W4))
        # Cross-scale fusion: resample every stage's attention output to the other two resolutions.
        upsample3_2 = self.upsample(cur3, output_size=(H2, W2), size=(H3, W3))
        upsample4_3 = self.upsample(cur4, output_size=(H3, W3), size=(H4, W4))
        upsample4_2 = self.upsample(cur4, output_size=(H2, W2), size=(H4, W4))
        downsample2_3 = self.downsample(cur2, output_size=(H3, W3), size=(H2, W2))
        downsample3_4 = self.downsample(cur3, output_size=(H4, W4), size=(H3, W3))
        downsample2_4 = self.downsample(cur2, output_size=(H4, W4), size=(H2, W2))
        cur2 = cur2 + upsample3_2 + upsample4_2
        cur3 = cur3 + upsample4_3 + downsample2_3
        cur4 = cur4 + downsample3_4 + downsample2_4
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        # MLP
        cur2 = self.norm22(x2)
        cur3 = self.norm23(x3)
        cur4 = self.norm24(x4)
        cur2 = self.mlp2(cur2)
        cur3 = self.mlp3(cur3)
        cur4 = self.mlp4(cur4)
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        return x1, x2, x3, x4
        
class PatchEmbed(nn.Module):
    """Image to patch embedding.

    A convolution with kernel = stride = ``patch_size`` maps [B, C, H, W] to
    [B, C', H // p0, W // p1]. The result is flattened to tokens and layer-normalised.

    Args:
        patch_size: int or pair, the patch height/width.
        in_chans: input channels C.
        embed_dim: output embedding size C'.
    """
    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)  # e.g. 16 -> (16, 16)
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return (tokens [B, out_H*out_W, C'], (out_H, out_W))."""
        # Run the patch convolution once. The previous code ran it twice and discarded one result.
        feat = self.proj(x)  # [B, C', out_H, out_W]
        out_H, out_W = feat.shape[-2:]  # == H // patch_size[0], W // patch_size[1] (no padding)
        tokens = feat.flatten(2).transpose(1, 2)  # [B, out_H*out_W, C']
        return self.norm(tokens), (out_H, out_W)

class CoaT(nn.Module):
    """CoaT backbone that returns four multi-scale feature maps (no classification head).

    Each of the four serial stages patch-embeds its input, prepends a learnable cls
    token and runs its serial blocks. If ``parallel_depth > 0``, the token sequences
    are then refined jointly by the parallel blocks. Each output drops its cls token,
    is reshaped to [B, embed_dims[i], H_i, W_i] and goes through ``out_norm[i]``.

    Args:
        patch_size: patch size of the first patch embedding (later stages use 2).
        in_chans: input image channels.
        embed_dims: embedding dim of each of the 4 stages.
        serial_depths: number of serial blocks per stage.
        parallel_depth: number of parallel blocks; values <= 0 select the "lite" variant.
        num_heads: attention heads, shared by every stage.
        mlp_ratios: FFN expansion ratio per stage.
        qkv_bias, qk_scale: forwarded to the attention layers.
        drop_rate, attn_drop_rate: dropout probabilities.
        drop_path_rate: stochastic-depth rate, used unchanged for every block.
        norm_layer: normalisation layer used inside the blocks.
        return_interm_layers, out_features: stored on the instance, not used by this class.
        crpe_window: ConvRelPosEnc window spec; ``None`` means ``{3: 2, 5: 3, 7: 3}``.
        pretrain: checkpoint name stored as ``self.pretrain``; this class does not load it.
        out_norm: norm class applied to each output map (nn.Identity, nn.BatchNorm2d, LayerNorm2d).
        **kwargs: accepted and ignored.
    """
    def __init__(self, patch_size=16, in_chans=3, embed_dims=(0, 0, 0, 0),
                 serial_depths=(0, 0, 0, 0), parallel_depth=0,
                 num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 return_interm_layers=True, out_features=None,
                 crpe_window=None,
                 pretrain=None,
                 out_norm=nn.Identity,  # Use nn.Identity, nn.BatchNorm2d, LayerNorm2d
                 **kwargs):
        super().__init__()
        if crpe_window is None:
            # Built per call so instances never share one mutable default dict.
            crpe_window = {3: 2, 5: 3, 7: 3}
        self.return_interm_layers = return_interm_layers
        self.pretrain = pretrain
        self.embed_dims = embed_dims
        self.out_features = out_features
        # No num_classes / head: this backbone only produces feature maps.

        # Patch embeddings: stage 1 uses `patch_size`, later stages downsample by 2.
        self.patch_embed1 = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
        self.patch_embed2 = PatchEmbed(patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = PatchEmbed(patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = PatchEmbed(patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])

        # Class tokens
        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
        self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
        self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
        self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))

        # Convolutional position encodings (one per stage, shared by its serial and parallel blocks)
        self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
        self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
        self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
        self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)

        # Convolutional relative position encodings (one per stage, shared likewise)
        self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window)
        self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window)
        self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window)
        self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window)

        # Stochastic depth: the same drop-path rate is used for every block (no linear decay).
        dpr = drop_path_rate

        # Serial Blocks 1.
        self.serial_blocks1 = nn.ModuleList([
            SerialBlock(dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                        shared_cpe=self.cpe1, shared_crpe=self.crpe1) for _ in range(serial_depths[0])])

        # Serial Blocks 2.
        self.serial_blocks2 = nn.ModuleList([
            SerialBlock(dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                        shared_cpe=self.cpe2, shared_crpe=self.crpe2) for _ in range(serial_depths[1])])

        # Serial Blocks 3.
        self.serial_blocks3 = nn.ModuleList([
            SerialBlock(dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                        shared_cpe=self.cpe3, shared_crpe=self.crpe3) for _ in range(serial_depths[2])])

        # Serial Blocks 4.
        self.serial_blocks4 = nn.ModuleList([
            SerialBlock(dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                        shared_cpe=self.cpe4, shared_crpe=self.crpe4) for _ in range(serial_depths[3])])

        # Parallel Blocks.
        self.parallel_depth = parallel_depth
        if self.parallel_depth > 0:
            self.parallel_blocks = nn.ModuleList([
                ParallelBlock(dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                              shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4],
                              shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4]) for _ in range(parallel_depth)])

        self.out_norm = nn.ModuleList([out_norm(embed_dims[i]) for i in range(4)])

        # Initialize weights
        trunc_normal_(self.cls_token1, std=.02)
        trunc_normal_(self.cls_token2, std=.02)
        trunc_normal_(self.cls_token3, std=.02)
        trunc_normal_(self.cls_token4, std=.02)
        # nn.Module.apply calls _init_weights on every submodule (recursively) and on self.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std .02) Linear weights with zero bias; LayerNorm weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}

    def insert_cls(self, x, cls_token):
        """ Insert CLS token: [B, N, C'] -> [B, 1 + N, C'] with the cls token first."""
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # [1, 1, C'] -> [B, 1, C']
        x = torch.cat((cls_tokens, x), dim=1)
        return x

    def remove_cls(self, x):
        """ Remove CLS token (the first token)."""
        return x[:, 1:, :]

    def _to_feature_map(self, x, size):
        """Drop the cls token and reshape [B, 1 + H*W, C] tokens to a contiguous [B, C, H, W] map."""
        H, W = size
        return self.remove_cls(x).reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()

    def _serial_stage(self, x, patch_embed, cls_token, blocks):
        """Patch-embed ``x``, prepend ``cls_token`` and run ``blocks``.

        Returns the token sequence [B, 1 + H*W, C] and its spatial size (H, W).
        """
        x, size = patch_embed(x)
        x = self.insert_cls(x, cls_token)
        for blk in blocks:
            x = blk(x, size=size)
        return x, size

    def forward(self, x0):
        """Return four feature maps [B, embed_dims[i], H_i, W_i], i = 0..3, from image batch ``x0``."""
        # Serial stages; each stage consumes the previous stage's feature map.
        x1, size1 = self._serial_stage(x0, self.patch_embed1, self.cls_token1, self.serial_blocks1)
        x1_nocls = self._to_feature_map(x1, size1)

        x2, size2 = self._serial_stage(x1_nocls, self.patch_embed2, self.cls_token2, self.serial_blocks2)
        x2_nocls = self._to_feature_map(x2, size2)

        x3, size3 = self._serial_stage(x2_nocls, self.patch_embed3, self.cls_token3, self.serial_blocks3)
        x3_nocls = self._to_feature_map(x3, size3)

        x4, size4 = self._serial_stage(x3_nocls, self.patch_embed4, self.cls_token4, self.serial_blocks4)
        sizes = [size1, size2, size3, size4]

        if self.parallel_depth > 0:
            for blk in self.parallel_blocks:
                x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=sizes)
            feats = [self._to_feature_map(x, size) for x, size in zip((x1, x2, x3, x4), sizes)]
        else:
            # Lite version: the serial-stage outputs are the final features.
            feats = [x1_nocls, x2_nocls, x3_nocls, self._to_feature_map(x4, size4)]

        return [norm(feat) for norm, feat in zip(self.out_norm, feats)]

class coat_small(CoaT):
    """CoaT-Small configuration.

    Uses patch size 4, stage dims (152, 320, 320, 320), 2 serial blocks per stage,
    6 parallel blocks, 8 heads and MLP ratio 4. Any extra keyword arguments are
    forwarded to ``CoaT``; repeating one of the fixed keys raises ``TypeError``.
    """
    def __init__(self, **kwargs):
        config = dict(
            patch_size=4,
            embed_dims=[152, 320, 320, 320],
            serial_depths=[2, 2, 2, 2],
            parallel_depth=6,
            num_heads=8,
            mlp_ratios=[4, 4, 4, 4],
            pretrain='coat_small_7479cf9b.pth',
        )
        super().__init__(**config, **kwargs)
        
if __name__ == "__main__":
    # Smoke test: two 800x800 RGB images -> print the four feature-map shapes.
    # Guarded so that importing this module does not build the model and run it.
    x = torch.randn((2, 3, 800, 800))  # drawn before model init, as before, so the RNG sequence is unchanged
    model = coat_small()
    # Only shapes are inspected, so skip building the autograd graph.
    with torch.no_grad():
        out = model(x)
    for feat in out:
        print(feat.shape)

torch.Size([2, 152, 200, 200])
torch.Size([2, 320, 100, 100])
torch.Size([2, 320, 50, 50])
torch.Size([2, 320, 25, 25])
