In [1]:
import timm
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from typing import Optional, Dict, Tuple
from ptflops import get_model_complexity_info
from pprint import pprint
import torch
import pdb
from torch import nn
from thop import profile
import torchsummary
from micronet.utils import cfg_m0
from micronet.backbone import *
%load_ext autoreload
%autoreload 2
# model_name = 'micronet_m2'
# model = MicroNet(cfg_m0)

In [26]:
# channel wise attention
class CA_layer(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Global-average-pools the input to one value per channel, passes it through a
    1x1-conv bottleneck (conv -> BN -> Hardswish -> conv -> BN -> Hardsigmoid) and
    rescales every channel of the input by the resulting gate in [0, 1].

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck reduction ratio; the hidden width is
            ``channel // reduction``, clamped to at least 1.

    Note:
        The BatchNorm layers run on the pooled (N, C, 1, 1) tensor, so in
        training mode PyTorch requires a batch size greater than 1.
    """

    def __init__(self, channel, reduction=16):
        super(CA_layer, self).__init__()
        # Without the clamp, channel < reduction yields a 0-wide bottleneck: the
        # second conv then always outputs zeros and the gate no longer depends
        # on the input at all.
        hidden = max(1, channel // reduction)
        # global average pooling
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channel, hidden, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(hidden),
            nn.Hardswish(),
            nn.Conv2d(hidden, channel, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(channel),
            nn.Hardsigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by the learned gate (same shape as ``x``)."""
        # y has shape (N, C, 1, 1) after pooling and is broadcast over H and W.
        y = self.fc(self.gap(x))
        return x * y.expand_as(x)
class gcc_ca_mf_block(nn.Module):
    """Two-branch global circular convolution block with a channel-attention FFN.

    The input is split along channels into two halves of ``dim`` channels each,
    so the block expects ``2 * dim`` input channels and returns the same shape.

    Spatial part (per half, with a residual connection):
      * optional learnable positional encoding, then BatchNorm;
      * stage 1: depthwise circular convolution along H (first half) or
        W (second half) with a kernel as long as the feature map;
      * optional channel shuffle (``mid_mix``) mixing the two halves;
      * optional second positional encoding;
      * stage 2: depthwise circular convolution along the other axis.

    Channel part: ``x + CA_layer(ffn(x))`` on the re-concatenated halves.

    The circular convolution is built by concatenating the feature map with all
    but its last row/column and convolving without padding, so every output
    position sees the entire row/column.

    Args:
        dim: channels per half (the input has ``2 * dim`` channels).
        meta_kernel_size: length of the learnable meta kernels and positional
            encodings. With ``instance_kernel_method='crop'`` it must be at
            least max(H, W) of the input feature map.
        instance_kernel_method: ``'crop'`` or ``'interpolation_bilinear'``; how
            the meta kernels/encodings are adapted to the feature-map size.
        use_pe: add learnable positional encodings before each stage.
        mid_mix: channel-shuffle the two halves between stage 1 and stage 2.
        bias: use learnable (``torch.randn``-initialised) biases in the
            circular convolutions.
        ffn_dim: hidden width of the 1x1-conv FFN.
        ffn_dropout: dropout probability after the FFN activation.
        dropout: dropout probability after the FFN output projection.

    Raises:
        ValueError: if ``instance_kernel_method`` is not supported.
    """

    _INSTANCE_KERNEL_METHODS = ('crop', 'interpolation_bilinear')

    def __init__(self,
                 dim: int,
                 meta_kernel_size: int,
                 instance_kernel_method='crop',
                 use_pe:Optional[bool]=True,
                 mid_mix: Optional[bool]=True,
                 bias: Optional[bool]=True,
                 ffn_dim: Optional[int]=2,
                 ffn_dropout=0.0,
                 dropout=0.1):

        super(gcc_ca_mf_block, self).__init__()

        # Fail fast: an unsupported method previously only printed a message and
        # later crashed in forward() with an opaque "cannot unpack NoneType".
        if instance_kernel_method not in self._INSTANCE_KERNEL_METHODS:
            raise ValueError('{} is not supported!'.format(instance_kernel_method))

        # spatial part,
        self.pre_Norm_1 = nn.BatchNorm2d(num_features=dim)
        self.pre_Norm_2 = nn.BatchNorm2d(num_features=dim)

        # Only the depthwise weights are kept (shape (dim, 1, k, 1) for H and
        # (dim, 1, 1, k) for W); the Conv2d modules themselves are discarded.
        self.meta_kernel_1_H = nn.Conv2d(dim, dim, (meta_kernel_size, 1), groups=dim).weight
        self.meta_kernel_1_W = nn.Conv2d(dim, dim, (1, meta_kernel_size), groups=dim).weight
        self.meta_kernel_2_H = nn.Conv2d(dim, dim, (meta_kernel_size, 1), groups=dim).weight
        self.meta_kernel_2_W = nn.Conv2d(dim, dim, (1, meta_kernel_size), groups=dim).weight

        if bias:
            self.meta_1_H_bias = nn.Parameter(torch.randn(dim))
            self.meta_1_W_bias = nn.Parameter(torch.randn(dim))
            self.meta_2_H_bias = nn.Parameter(torch.randn(dim))
            self.meta_2_W_bias = nn.Parameter(torch.randn(dim))
        else:
            self.meta_1_H_bias = None
            self.meta_1_W_bias = None
            self.meta_2_H_bias = None
            self.meta_2_W_bias = None

        self.instance_kernel_method = instance_kernel_method

        if use_pe:
            self.meta_pe_1_H = nn.Parameter(torch.randn(1, dim, meta_kernel_size, 1))
            self.meta_pe_1_W = nn.Parameter(torch.randn(1, dim, 1, meta_kernel_size))
            self.meta_pe_2_H = nn.Parameter(torch.randn(1, dim, meta_kernel_size, 1))
            self.meta_pe_2_W = nn.Parameter(torch.randn(1, dim, 1, meta_kernel_size))

        if mid_mix:
            self.mixer = nn.ChannelShuffle(groups=2)

        self.mid_mix = mid_mix
        self.use_pe = use_pe
        self.dim = dim

        # channel part
        self.ffn = nn.Sequential(
            nn.BatchNorm2d(num_features=2*dim),
            nn.Conv2d(2*dim, ffn_dim, kernel_size=(1, 1), bias=True),
            nn.Hardswish(),
            nn.Dropout(p=ffn_dropout),
            nn.Conv2d(ffn_dim, 2*dim, kernel_size=(1, 1), bias=True),
            nn.Dropout(p=dropout)
        )

        self.ca = CA_layer(channel=2*dim)

    def get_instance_kernel(self, instance_kernel_size, instance_kernel_size_w=None):
        """Adapt the four meta kernels to the current feature-map size.

        Args:
            instance_kernel_size: target length of the H-direction kernels, and
                also of the W-direction kernels when ``instance_kernel_size_w``
                is None (square feature map).
            instance_kernel_size_w: optional target length of the W-direction
                kernels, for non-square feature maps.

        Returns:
            ``(K_1_H, K_1_W, K_2_H, K_2_W)``: H kernels of shape (dim, 1, h, 1)
            and W kernels of shape (dim, 1, 1, w). For ``'crop'`` this holds only
            while the sizes do not exceed ``meta_kernel_size``.

        Raises:
            ValueError: if ``instance_kernel_method`` is not supported.
        """
        size_h = instance_kernel_size
        size_w = instance_kernel_size if instance_kernel_size_w is None else instance_kernel_size_w

        if self.instance_kernel_method == 'crop':
            return (self.meta_kernel_1_H[:, :, :size_h, :],
                    self.meta_kernel_1_W[:, :, :, :size_w],
                    self.meta_kernel_2_H[:, :, :size_h, :],
                    self.meta_kernel_2_W[:, :, :, :size_w])

        if self.instance_kernel_method == 'interpolation_bilinear':
            h_shape = [size_h, 1]
            w_shape = [1, size_w]
            return (F.interpolate(self.meta_kernel_1_H, h_shape, mode='bilinear', align_corners=True),
                    F.interpolate(self.meta_kernel_1_W, w_shape, mode='bilinear', align_corners=True),
                    F.interpolate(self.meta_kernel_2_H, h_shape, mode='bilinear', align_corners=True),
                    F.interpolate(self.meta_kernel_2_W, w_shape, mode='bilinear', align_corners=True))

        raise ValueError('{} is not supported!'.format(self.instance_kernel_method))

    def get_instance_pe(self, instance_kernel_size, instance_kernel_size_w=None):
        """Adapt the four positional encodings to the current feature-map size.

        Args:
            instance_kernel_size: feature-map height, and also its width when
                ``instance_kernel_size_w`` is None (square feature map).
            instance_kernel_size_w: optional feature-map width for non-square
                feature maps.

        Returns:
            ``(pe_1_H, pe_1_W, pe_2_H, pe_2_W)``, each broadcast (via ``expand``)
            to shape (1, dim, h, w).

        Raises:
            ValueError: if ``instance_kernel_method`` is not supported.
        """
        size_h = instance_kernel_size
        size_w = instance_kernel_size if instance_kernel_size_w is None else instance_kernel_size_w
        full_shape = (1, self.dim, size_h, size_w)

        if self.instance_kernel_method == 'crop':
            # With size > meta_kernel_size the slice is too short and expand() raises.
            return (self.meta_pe_1_H[:, :, :size_h, :].expand(*full_shape),
                    self.meta_pe_1_W[:, :, :, :size_w].expand(*full_shape),
                    self.meta_pe_2_H[:, :, :size_h, :].expand(*full_shape),
                    self.meta_pe_2_W[:, :, :, :size_w].expand(*full_shape))

        if self.instance_kernel_method == 'interpolation_bilinear':
            return (F.interpolate(self.meta_pe_1_H, [size_h, 1], mode='bilinear', align_corners=True)
                    .expand(*full_shape),
                    F.interpolate(self.meta_pe_1_W, [1, size_w], mode='bilinear', align_corners=True)
                    .expand(*full_shape),
                    F.interpolate(self.meta_pe_2_H, [size_h, 1], mode='bilinear', align_corners=True)
                    .expand(*full_shape),
                    F.interpolate(self.meta_pe_2_W, [1, size_w], mode='bilinear', align_corners=True)
                    .expand(*full_shape))

        raise ValueError('{} is not supported!'.format(self.instance_kernel_method))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` of shape (B, 2*dim, H, W); returns the same shape.

        H and W may differ; with ``'crop'`` both must be <= meta_kernel_size.
        """
        # Split channels into the two branches (dim channels each).
        x_1, x_2 = torch.chunk(x, 2, 1)
        x_1_res, x_2_res = x_1, x_2
        # H and W are read separately so each axis gets kernels/PEs of matching
        # length (the W direction used to reuse H, breaking non-square inputs).
        _, _, f_h, f_w = x_1.shape

        K_1_H, K_1_W, K_2_H, K_2_W = self.get_instance_kernel(f_h, f_w)

        if self.use_pe:
            pe_1_H, pe_1_W, pe_2_H, pe_2_W = self.get_instance_pe(f_h, f_w)

        # ******************************************************************************************** spatial part
        # pre norm
        if self.use_pe:
            x_1, x_2 = x_1 + pe_1_H, x_2 + pe_1_W

        x_1, x_2 = self.pre_Norm_1(x_1), self.pre_Norm_2(x_2)

        # stage 1: circular padding (append all but the last row / column), then a
        # full-length depthwise conv without zero padding keeps the spatial size.
        x_1_1 = F.conv2d(torch.cat((x_1, x_1[:, :, :-1, :]), dim=2), weight=K_1_H, bias=self.meta_1_H_bias, padding=0,
                         groups=self.dim)
        x_2_1 = F.conv2d(torch.cat((x_2, x_2[:, :, :, :-1]), dim=3), weight=K_1_W, bias=self.meta_1_W_bias, padding=0,
                         groups=self.dim)
        if self.mid_mix:
            mid_rep = torch.cat((x_1_1, x_2_1), dim=1)
            x_1_1, x_2_1 = torch.chunk(self.mixer(mid_rep), chunks=2, dim=1)

        if self.use_pe:
            x_1_1, x_2_1 = x_1_1 + pe_2_W, x_2_1 + pe_2_H

        # stage 2: same circular conv along the other axis of each branch.
        x_1_2 = F.conv2d(torch.cat((x_1_1, x_1_1[:, :, :, :-1]), dim=3), weight=K_2_W, bias=self.meta_2_W_bias,
                         padding=0, groups=self.dim)
        x_2_2 = F.conv2d(torch.cat((x_2_1, x_2_1[:, :, :-1, :]), dim=2), weight=K_2_H, bias=self.meta_2_H_bias,
                         padding=0, groups=self.dim)

        # residual
        x_1 = x_1_res + x_1_2
        x_2 = x_2_res + x_2_2

        # ******************************************************************************************** channel part
        x_ffn = torch.cat((x_1, x_2), dim=1)
        x_ffn = x_ffn + self.ca(self.ffn(x_ffn))

        return x_ffn

In [31]:
# Standalone block: 32 channels per branch (64 in total), 12-tap meta kernels
# resized to the feature map by bilinear interpolation.
model = gcc_ca_mf_block(dim=32, meta_kernel_size=12,
                        instance_kernel_method='interpolation_bilinear')

In [32]:

# Smoke test: the block must preserve the input shape. The input has 64 channels,
# matching the block built with dim=32 in the previous cell (2 * dim).
sample_batch = torch.randn(10, 64, 32, 32)
with torch.no_grad():
    sample_out = model(sample_batch)
assert sample_out.shape == sample_batch.shape, sample_out.shape
sample_out.shape

torch.Size([10, 64, 32, 32])

In [33]:
# Build the model by its registered timm name (random init, no pretrained weights).
model_name = 'micro_former_shift_26m'
model = timm.create_model(model_name=model_name, pretrained=False)


block: 12544, cnn-drop 0.0000, mlp-drop 0.0000
block: 3136, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 12, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 12, token: 128
block: 3136, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 12, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 24, token: 128
block: 784, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 24, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 24, token: 128
block: 784, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 24, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 48, token: 128
block: 196, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 48, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 48, token: 128
block: 196, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 48, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 64, token: 128
block: 196, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 64, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 96, token: 128
block: 49, cn

In [3]:
# model_name = 'mobile_former_26m_micro_m0'
# model = timm.create_model(model_name)


In [4]:
# Build the model by its registered timm name (random init, no pretrained weights).
model_name = 'mobile_former_26m'
model = timm.create_model(model_name=model_name, pretrained=False)


block: 12544, cnn-drop 0.0000, mlp-drop 0.0000
block: 3136, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 12, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 12, token: 128
block: 3136, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 12, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 24, token: 128
block: 784, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 24, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 24, token: 128
block: 784, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 24, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 48, token: 128
block: 196, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 48, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 48, token: 128
block: 196, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 48, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 64, token: 128
block: 196, cnn-drop 0.0000, mlp-drop 0.0000
L2G: 2 heads, inp: 64, token: 128
G2G: 4 heads
use ffn
G2L: 2 heads, inp: 96, token: 128
block: 49, cn

In [2]:
# Build the model by its registered timm name (random init, no pretrained weights).
model_name = 'mobilevitv2_050'
model = timm.create_model(model_name=model_name, pretrained=False)

In [None]:
# MAC count with thop. The dummy input must match the model's stem: use the same
# 3-channel 224x224 input as the ptflops and torchinfo cells (a 6-channel tensor
# does not fit a 3-channel stem conv). Named to avoid shadowing builtin `input`.
dummy_input = torch.randn(1, 3, 224, 224)
macs, params, layer_info = profile(model, inputs=(dummy_input, ), ret_layer_info=True)
print(model_name, macs/1e6)
# previously recorded result (millions of MACs): mobile_former_294m 293.074048

In [3]:
import contextlib

# `torch.cuda.device(0)` fails on machines without a usable CUDA device, so only
# enter it when CUDA is available; otherwise run the complexity count on CPU.
device_ctx = torch.cuda.device(0) if torch.cuda.is_available() else contextlib.nullcontext()
with device_ctx:
    net = model
    macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))

ByobNet(
  1.36 M, 99.129% Params, 363.17 MMac, 100.000% MACs, 
  (stem): ConvNormAct(
    432, 0.032% Params, 5.42 MMac, 1.492% MACs, 
    (conv): Conv2d(432, 0.032% Params, 5.42 MMac, 1.492% MACs, 3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNormAct2d(
      0, 0.000% Params, 0.0 Mac, 0.000% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
      (act): SiLU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, inplace=True)
    )
  )
  (stages): Sequential(
    1.1 M, 80.347% Params, 357.48 MMac, 98.433% MACs, 
    (0): Sequential(
      1.82 k, 0.133% Params, 22.88 MMac, 6.300% MACs, 
      (0): BottleneckBlock(
        1.82 k, 0.133% Params, 22.88 MMac, 6.300% MACs, 
        (conv1_1x1): ConvNormAct(
          512, 0.037% Params, 6.42 MMac, 1.768% MACs, 
          (conv): Conv2d(512, 0.037% Params, 6.42 MMac, 1.768% MACs, 16, 32, kernel_size=(1, 1), stride=(1, 1), bias

In [6]:
from torchinfo import summary

# verbose=0 keeps torchinfo from printing on its own; the returned statistics
# object is the cell's last expression, so the notebook renders it instead.
summary(model=model, input_size=(1, 3, 224, 224), verbose=0)

Layer (type:depth-idx)                        Output Shape              Param #
MobileFormer                                  [1, 1000]                 --
├─Embedding: 1-1                              --                        512
├─Sequential: 1-2                             [1, 12, 112, 112]         --
│    └─Conv2d: 2-1                            [1, 12, 112, 112]         324
│    └─BatchNorm2d: 2-2                       [1, 12, 112, 112]         24
│    └─ReLU6: 2-3                             [1, 12, 112, 112]         --
├─Sequential: 1-3                             [1, 128, 7, 7]            --
│    └─DnaBlock3: 2-4                         [1, 12, 112, 112]         --
│    │    └─Sequential: 3-1                   [1, 12, 112, 112]         576
│    └─DnaBlock3: 2-5                         [1, 16, 56, 56]           --
│    │    └─Local2Global: 3-2                 [4, 1, 128]               5,016
│    │    └─GlobalBlock: 3-3                  [4, 1, 128]               99,456
│    │    

mobile_former_96m 96.817488 dw  
mobile_former_96m 95.488608 sepdw