In [1]:
import paddle
import paddle.nn.functional as func

from paddle import  nn
from paddle.nn.initializer import TruncatedNormal, Constant

'''
#### Shared parameter initializers: truncated normal (std=0.02), constant 0 and constant 1 ####
'''
# Initializers built once and shared by the layers below.
# NOTE(review): `truncated_normal_initial` and `one` are not referenced anywhere in this notebook.
truncated_normal_initial = TruncatedNormal(std=0.02)
zero, one = Constant(value=0.0), Constant(value=1.0)  # `zero` initializes the CvT class token


class Identity(nn.Layer):
    """Pass-through layer: ``forward`` returns its input unchanged.

    Used as the stand-in when an optional projection is disabled
    (e.g. ``ConAtt.out`` when no output projection is needed).
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        return inputs

In [2]:
'''
#### Reshape tensors between image layout (b c h w) and token layout (b (h w) c) for the convolutional token embedding ####
'''
class Reshape(nn.Layer):
    """Rearrange a tensor between image layout and token-sequence layout.

    Supported patterns (einops-style notation):
      * ``'b c h w -> b (h w) c'``: flatten the spatial grid into a token axis.
        The spatial size is taken from the input tensor itself.
      * ``'b (h w) c -> b c h w'``: fold the token axis back into an
        ``h x w`` grid using the configured ``h`` and ``w``.

    Args:
        string: one of the two patterns above.
        h: grid height used by the unflatten direction.
        w: grid width used by the unflatten direction.

    Raises:
        ValueError: if ``string`` is not a supported pattern.
    """

    _FLATTEN = 'b c h w -> b (h w) c'
    _UNFLATTEN = 'b (h w) c -> b c h w'

    def __init__(self, string, h, w):
        super().__init__()

        # Fail at construction time instead of with an UnboundLocalError in forward().
        if string not in (self._FLATTEN, self._UNFLATTEN):
            raise ValueError(f'Unsupported reshape pattern: {string!r}')

        self.string = string
        self.h = h
        self.w = w

    def forward(self, inputs):
        if self.string == self._FLATTEN:
            # (b, c, h, w) -> (b, c, h*w) -> (b, h*w, c). Flattening the real spatial
            # axes avoids silently mixing channels and positions when the configured
            # h/w differ from the actual feature-map size.
            return paddle.flatten(inputs, start_axis=2).transpose((0, 2, 1))

        if self.string == self._UNFLATTEN:
            # (b, h*w, c) -> (b, h, w, c) -> (b, c, h, w)
            batch = inputs.shape[0]
            return paddle.reshape(x=inputs, shape=(batch, self.h, self.w, -1)).transpose((0, 3, 1, 2))

        raise ValueError(f'Unsupported reshape pattern: {self.string!r}')

In [3]:
'''
#### Depthwise separable convolutional layer ####
'''
class DepSepConv2D(nn.Layer):
    """Depthwise-separable 2-D convolution: depthwise conv -> BatchNorm -> 1x1 pointwise conv.

    Args:
        in_channels: input channels (also the width of the depthwise stage).
        out_channels: output channels produced by the pointwise stage.
        kernel_size, stride, padding, dilation: geometry of the depthwise conv;
            the pointwise conv is always 1x1, stride 1, no padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()

        # One filter per input channel (groups == in_channels).
        self.depthwise = nn.Conv2D(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
        )

        self.bn = nn.BatchNorm(num_channels=in_channels)

        # 1x1 convolution that mixes channels and sets the output width.
        self.pointwise = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
        )

    def forward(self, input):
        return self.pointwise(self.bn(self.depthwise(input)))

In [4]:
'''
#### Reshape helper used inside the transformer (token, grid and per-head layouts). ####
'''
def reshape_in_transformer(x, string, l=0, w=0, h=0, **kwargs):
    """Rearrange attention tensors between token, grid and per-head layouts.

    Supported ``string`` patterns (einops-style notation):
      * ``'b n (h d) -> b h n d'``: split the channel axis of a token tensor
        into ``h`` heads. Requires ``h``.
      * ``'b (l w) n -> b n l w'``: fold the token axis into an ``l x w`` grid,
        channels first. Requires ``l`` and ``w``.
      * ``'b (h d) l w -> b h (l w) d'``: split the channel axis of a feature
        map into ``h`` heads and flatten the grid into a token axis.
        Requires ``h``.

    Args:
        x: input tensor (3-D for the first two patterns, 4-D for the third).
        string: the pattern to apply.
        l, w: grid height / width for ``'b (l w) n -> b n l w'``.
        h: number of heads for the head-splitting patterns.
        **kwargs: accepted and ignored.

    Returns:
        The rearranged tensor.

    Raises:
        ValueError: if ``string`` is not one of the supported patterns
            (previously the input was returned unchanged, hiding typos).
    """
    if string == 'b n (h d) -> b h n d':
        b, n = x.shape[:2]
        # (b, n, h*d) -> (b, n, h, d) -> (b, h, n, d)
        return paddle.reshape(x=x, shape=(b, n, h, -1)).transpose((0, 2, 1, 3))

    if string == 'b (l w) n -> b n l w':
        b = x.shape[0]
        # (b, l*w, n) -> (b, l, w, n) -> (b, n, l, w)
        return paddle.reshape(x=x, shape=(b, l, w, -1)).transpose((0, 3, 1, 2))

    if string == 'b (h d) l w -> b h (l w) d':
        b, _, grid_l, grid_w = x.shape
        # The channel axis is laid out as (h, d) with h outermost, so split it first:
        # (b, h*d, l, w) -> (b, h, d, l*w), then move d behind the token axis.
        # Reshaping directly to (b, h, l*w, d) would interleave channel and spatial values.
        return paddle.reshape(x=x, shape=(b, h, -1, grid_l * grid_w)).transpose((0, 1, 3, 2))

    raise ValueError(f'Unsupported rearrange pattern: {string!r}')

In [5]:
'''
#### Multi-Head Attention ####
#### Here we use the depthwise separable convolutional layer defined above ####
#### instead of the linear projection used in ViT. ####
'''
class ConAtt(nn.Layer):
    """Convolutional multi-head self-attention (CvT style).

    Queries, keys and values are produced by depthwise-separable convolutions
    over the token grid instead of the linear projections used in ViT.

    Args:
        dim: channel width of the input tokens.
        img_size: side length of the square token grid; the token count
            (excluding a class token) is expected to be ``img_size ** 2``.
        heads: number of attention heads.
        dim_head: channels per head; the q/k/v projections output
            ``heads * dim_head`` channels.
        kernel_size: kernel size of the q/k/v convolutions.
        q_stride, k_stride, v_stride: strides of the q/k/v convolutions.
        dropout: dropout applied after the output projection (only when a
            projection is used).
        last_stage: if True, token 0 of the input is treated as a class token
            that bypasses the convolutions and is prepended to q, k and v.
    """

    def __init__(self, dim,img_size, heads=8, dim_head=64, 
                 kernel_size=3, q_stride=1, k_stride=1, v_stride=1, 
                 dropout=0., last_stage=False):
        
        super().__init__()
        
        self.last_stage = last_stage
        self.img_size = img_size
        inner_dim = dim_head * heads
        # The output projection is skipped only when it would map dim -> dim with one head.
        project_out = not (heads==1 and dim_head==dim)

        self.heads = heads
        self.scale = dim_head ** (-0.5)
        # Padding derived from q_stride is shared by all three projections.
        pad = (kernel_size - q_stride) // 2
        
        ## using Depth Separable Convolution to reduce the parameters.
        self.to_q = DepSepConv2D(in_channels=dim, out_channels=inner_dim, 
                                 kernel_size=kernel_size, stride=q_stride, padding=pad)
        self.to_k = DepSepConv2D(in_channels=dim, out_channels=inner_dim, 
                                 kernel_size=kernel_size, stride=k_stride, padding=pad)
        self.to_v = DepSepConv2D(in_channels=dim, out_channels=inner_dim, 
                                 kernel_size=kernel_size, stride=v_stride, padding=pad)

        # Maps the merged heads (inner_dim) back to dim.
        self.out = nn.Sequential(nn.Linear(in_features=inner_dim, out_features=dim), 
                                 nn.Dropout(dropout)) if project_out else Identity()

    def forward(self, x):
        """Attend over the token sequence ``x`` (assumed shape (b, n, dim))."""
        h = self.heads

        if self.last_stage:
            # Split off the class token; it skips the convolutional projections.
            cls_token = x[:, 0]
            x = x[:, 1:]
            # (b, dim) -> (b, 1, dim) -> (b, h, 1, dim // h)
            cls_token = reshape_in_transformer(x=paddle.unsqueeze(cls_token, axis=1), 
                                               string='b n (h d) -> b h n d', h=h)

        # Tokens -> feature map: (b, img_size*img_size, dim) -> (b, dim, img_size, img_size)
        x = reshape_in_transformer(x=x, string='b (l w) n -> b n l w',
                                   l=self.img_size, w=self.img_size)

        # Conv projections, then split heads: (b, h*dim_head, l, w) -> (b, h, l*w, dim_head)
        q = reshape_in_transformer(x=self.to_q(x), string='b (h d) l w -> b h (l w) d', h=h)
        k = reshape_in_transformer(x=self.to_k(x), string='b (h d) l w -> b h (l w) d', h=h)
        v = reshape_in_transformer(x=self.to_v(x), string='b (h d) l w -> b h (l w) d', h=h)

        if self.last_stage:
            # Prepend the class token along the token axis. The last axes only
            # match when dim // heads == dim_head (i.e. inner_dim == dim).
            q = paddle.concat((cls_token, q), axis=2)
            v = paddle.concat((cls_token, v), axis=2)
            k = paddle.concat((cls_token, k), axis=2)

        # Scaled dot-product attention weights: (b, h, n_q, n_k)
        attention = q.matmul(k.transpose((0, 1, 3, 2))) * self.scale
        attention = func.softmax(attention, axis=-1)

        # (b, h, n_q, dim_head) -> (b, n_q, h, dim_head) -> (b, n_q, h*dim_head).
        # Merge heads into inner_dim rather than the input width `dim`: the two may
        # differ (self.out maps inner_dim -> dim), and n_q may differ from the
        # input token count when q_stride != 1.
        out = paddle.flatten(attention.matmul(v).transpose((0, 2, 1, 3)), start_axis=2)

        # Linear projection back to dim (or identity).
        return self.out(out)

In [6]:
'''
#### Residual skip connection in transformer module ####
'''
class Residual(nn.Layer):
    """Wrap a layer with an identity skip connection: ``fn(x, **kwargs) + x``.

    Args:
        fn: the wrapped layer; keyword arguments passed to ``forward`` are
            forwarded to it unchanged.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, inputs, **kwargs):
        skip = inputs
        return self.fn(inputs, **kwargs) + skip

In [7]:
'''
#### Add Layer Normalization ####
'''
class LayerNorm(nn.Layer):
    """Pre-normalization wrapper: apply ``nn.LayerNorm(dim)`` and then ``fn``.

    Args:
        dim: size of the normalized (last) axis.
        fn: layer applied to the normalized input; keyword arguments passed to
            ``forward`` are forwarded to it.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, inputs, **kwargs):
        normalized = self.norm(inputs)
        return self.fn(normalized, **kwargs)

In [8]:
'''
#### MLP Layer ####
'''
class FeedForward(nn.Layer):
    """Transformer MLP block.

    Linear(dim -> hidden_dim) -> GELU -> Dropout -> Linear(hidden_dim -> dim) -> Dropout.

    Args:
        dim: input and output width.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.Linear(in_features=dim, out_features=hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_dim, out_features=dim),
            nn.Dropout(dropout),
        ]
        self.neu_net = nn.Sequential(*stages)

    def forward(self, inputs):
        return self.neu_net(inputs)

In [9]:
'''
#### Definition of the Transformer layer ####
'''
class Transformer(nn.Layer):
    """Stack of ``depth`` pre-norm blocks: ``x = attn(norm(x)) + x`` then ``x = mlp(norm(x)) + x``.

    Args:
        dim: token width.
        img_size: side length of the square token grid passed to each ConAtt.
        depth: number of blocks.
        heads: attention heads per block.
        dim_head: channels per attention head.
        mlp_dim: hidden width of each FeedForward.
        dropout: dropout used in attention and MLP.
        last_stage: forwarded to ConAtt (class-token handling).
    """

    def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False):
        super().__init__()

        blocks = []
        for _ in range(depth):
            attention_block = LayerNorm(dim=dim, fn=ConAtt(dim, img_size, heads=heads, dim_head=dim_head,
                                                           dropout=dropout, last_stage=last_stage))
            mlp_block = LayerNorm(dim=dim, fn=FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout))
            blocks.append(nn.LayerList([attention_block, mlp_block]))
        self.layers = nn.LayerList(blocks)

    def forward(self, x):
        for block in self.layers:
            attention_block, mlp_block = block
            x = attention_block(x) + x
            x = mlp_block(x) + x
        return x

In [10]:
'''
#### Final CvT ####
'''
class CvT(nn.Layer):
    """Convolutional vision Transformer (CvT) classifier.

    Three stages, each a strided convolutional token embedding followed by a
    stack of convolutional-attention Transformer blocks. A learnable class
    token is prepended before the last stage. The classification head reads
    either that token (``pool='cls'``) or the mean over all tokens
    (``pool='mean'``).

    Args:
        image_size: side length of the square input image.
        in_channels: channels of the input image.
        num_classes: number of output logits.
        dim: embedding width of stage 1; stage i widens it by
            ``heads[i] // heads[i - 1]``. Also used as ``dim_head`` in every stage.
        kernels: kernel size of each stage's embedding convolution.
        strides: stride of each stage's embedding convolution.
        heads: attention heads per stage.
        depth: Transformer blocks per stage.
        pool: ``'cls'`` or ``'mean'``.
        dropout: dropout inside the attention and MLP blocks.
        emb_dropout: dropout applied to the last-stage token sequence right
            after the class token is prepended.
        scale_dim: MLP hidden-width multiplier.

    Raises:
        ValueError: if ``pool`` is neither ``'cls'`` nor ``'mean'``.
    """

    # Padding of the stage-1/2/3 embedding convolutions.
    _EMBED_PADDINGS = (2, 1, 1)

    @staticmethod
    def _conv_out_size(size, kernel, stride, padding):
        """Output side length of a square Conv2D (dilation 1) for an input of side ``size``."""
        return (size + 2 * padding - kernel) // stride + 1

    def __init__(self, image_size, in_channels, num_classes, 
                 dim=64, kernels=(7, 3, 3), strides=(4, 2, 2),
                 heads=(1, 3, 6), depth=(1, 2, 10), 
                 pool='cls', dropout=0., emb_dropout=0., scale_dim=4):
        
        super().__init__()

        # Previously any other value silently behaved like 'cls'.
        if pool not in ('cls', 'mean'):
            raise ValueError(f"pool type must be either 'cls' or 'mean', got {pool!r}")
        self.pool = pool
        self.dim = dim
        pad1, pad2, pad3 = self._EMBED_PADDINGS

        # Token-grid side lengths follow the real conv geometry instead of the
        # hard-coded image_size // 4, // 8, // 16. The values are identical for
        # every configuration that worked before, and stay correct for other
        # kernels / strides / image sizes.
        size1 = self._conv_out_size(image_size, kernels[0], strides[0], pad1)
        
        ## stage 1
        self.stage1_conv_embed = nn.Sequential(nn.Conv2D(in_channels=in_channels, out_channels=dim, 
                                                         kernel_size=kernels[0], stride=strides[0], 
                                                         padding=pad1),
                                               Reshape('b c h w -> b (h w) c', h=size1, w=size1),
                                               nn.LayerNorm(dim))
        
        self.stage1_transformer = nn.Sequential(Transformer(dim=dim, img_size=size1, 
                                                            depth=depth[0], heads=heads[0], 
                                                            dim_head=self.dim, mlp_dim=dim*scale_dim, 
                                                            dropout=dropout), 
                                                Reshape(string='b (h w) c -> b c h w', h=size1, w=size1))

        ## stage 2: widen channels by the head ratio
        in_channels = dim
        dim = (heads[1] // heads[0]) * dim
        size2 = self._conv_out_size(size1, kernels[1], strides[1], pad2)
        
        self.stage2_conv_embed = nn.Sequential(nn.Conv2D(in_channels=in_channels, out_channels=dim, 
                                                         kernel_size=kernels[1], stride=strides[1], 
                                                         padding=pad2),
                                               Reshape(string='b c h w -> b (h w) c', h=size2, w=size2), 
                                               nn.LayerNorm(dim))
        
        self.stage2_transformer = nn.Sequential(Transformer(dim=dim, img_size=size2, 
                                                            depth=depth[1], heads=heads[1], 
                                                            dim_head=self.dim, mlp_dim=dim*scale_dim, 
                                                            dropout=dropout), 
                                                Reshape(string='b (h w) c -> b c h w', h=size2, w=size2))

        ## stage 3: widen channels by the head ratio; this stage also carries the class token
        in_channels = dim
        dim = (heads[2] // heads[1]) * dim
        size3 = self._conv_out_size(size2, kernels[2], strides[2], pad3)
        
        self.stage3_conv_embed = nn.Sequential(nn.Conv2D(in_channels=in_channels, out_channels=dim, 
                                                         kernel_size=kernels[2], stride=strides[2], 
                                                         padding=pad3),
                                               Reshape(string='b c h w -> b (h w) c', h=size3, w=size3),
                                               nn.LayerNorm(dim))
        
        self.stage3_transformer = nn.Sequential(Transformer(dim=dim, img_size=size3, 
                                                            depth=depth[2], heads=heads[2], 
                                                            dim_head=self.dim, mlp_dim=dim*scale_dim, 
                                                            dropout=dropout, last_stage=True))

        ## class token, zero-initialized
        self.cls_token = self.create_parameter(shape=(1, 1, dim), default_initializer=zero)
        # Explicit registration under the name "cls_token".
        # NOTE(review): possibly redundant with the attribute assignment above; kept for safety.
        self.add_parameter("cls_token", self.cls_token)

        # Applied in forward() after the class token is prepended.
        self.dropout_large = nn.Dropout(emb_dropout)

        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(in_features=dim, 
                                                                   out_features=num_classes))

    def forward(self, inputs):
        """Map an image batch ``inputs`` (b, in_channels, H, W) to logits (b, num_classes)."""
        x = self.stage1_conv_embed(inputs)
        x = self.stage1_transformer(x)

        x = self.stage2_conv_embed(x)
        x = self.stage2_transformer(x)

        x = self.stage3_conv_embed(x)

        # Prepend one class token per sample: (b, n, dim) -> (b, n + 1, dim)
        b = x.shape[0]
        cls_tokens = self.cls_token.expand((b, -1, -1))
        x = paddle.concat((cls_tokens, x), axis=1)

        # emb_dropout used to be stored but never applied (Dropout(0.) is a no-op).
        x = self.dropout_large(x)

        x = self.stage3_transformer(x)

        x = paddle.mean(x, axis=1) if self.pool == 'mean' else x[:, 0]

        return self.mlp_head(x)

In [11]:
'''
###############
#### Check ####
###############
'''
def main():
    """Smoke check: push one random 1x512x512 image through CvT and print the output shape."""
    net = CvT(image_size=512, in_channels=1, num_classes=4)
    # For a layer-by-layer table: paddle.summary(net, input_size=(1, 1, 512, 512))

    dummy_batch = paddle.randn(shape=(1, 1, 512, 512))
    logits = net(dummy_batch)

    print(f'Shape of output: {logits.shape}')


if __name__ == '__main__':
    main()
    

Shape of output: [1, 4]
