In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SimAM(nn.Module):
    def __init__(self, coeff_lambda=1e-4):
        super(SimAM, self).__init__()
        self.coeff_lambda = coeff_lambda

    def forward(self, X):
        """
        X: input tensor with shape (batch_size, num_channels, height, width)
        """
        assert X.dim() == 4, "shape of X must have 4 dimension"

        # spatial size
        n = X.shape[2] * X.shape[3] - 1

        # square of (t - u)
        d = (X - X.mean(dim=[2,3], keepdim=True)).pow(2)
        # print(f'd={d}')

        # d.sum() / n is channel variance
        v = d.sum(dim=[2,3], keepdim=True) / n
        # print(f'v={v}')

        # E_inv groups all importance of X
        E_inv = d / (4 * (v + self.coeff_lambda)) + 0.5
        # print(f'E_inv={E_inv}')

        # return attended features
        return X * F.sigmoid(E_inv)

In [None]:
simam = SimAM()

In [None]:
import torch


for shape in [(1,1,3,3), (1,2,3,3), (2,1,3,3), (32, 3, 32, 32)]:
  X = torch.rand(shape)
  y = simam(X)

  print(X.shape)
  print(y.shape)

  print('\n')


torch.Size([1, 1, 3, 3])
torch.Size([1, 1, 3, 3])


torch.Size([1, 2, 3, 3])
torch.Size([1, 2, 3, 3])


torch.Size([2, 1, 3, 3])
torch.Size([2, 1, 3, 3])


torch.Size([32, 3, 32, 32])
torch.Size([32, 3, 32, 32])




In [None]:
from torchvision import models

model = models.mobilenet_v2(pretrained=True)

In [None]:
class Squeezer(nn.Module):
  def __init__(self):
    super(Squeezer, self).__init__()

  def forward(self, x):
    return torch.squeeze(x, dim=-1)

In [None]:
mobilenet_v2_simam = nn.Sequential()

for name, module in model.named_children():
  if name == 'features':
    feature_seq = nn.Sequential()
    mobilenet_v2_simam.add_module(name, feature_seq)

    for n0, m0 in module.named_children():
      if m0._get_name() == "Conv2dNormActivation":
        block = nn.Sequential()
        feature_seq.add_module(n0, block)

        for n1, m1 in m0.named_children():
          block.add_module(n1, m1)
          if (m1._get_name() == "Conv2d" or
              m1._get_name() == "BatchNorm2d"):
            block.add_module(f'{n1}_simam', SimAM())

      elif m0._get_name() == "InvertedResidual":
        block_big = nn.Sequential()                                             # BLOCKNAME: InvertedResidual
        feature_seq.add_module(n0, block_big)

        for n1, m1 in m0.named_children():
          block_med = nn.Sequential()                                           # BLOCKNAME: Sequential
          block_big.add_module(n1, block_med)

          for n2, m2 in m1.named_children():
            if m2._get_name() == "Conv2dNormActivation":
              block_small = nn.Sequential()                                     # BLOCKNAME: Conv2dNormActivation
              block_med.add_module(n2, block_small)

              for n3, m3 in m2.named_children():
                block_small.add_module(n3, m3)
                if (m3._get_name() == "Conv2d" or
                    m3._get_name() == "BatchNorm2d"):
                  block_small.add_module(f'{n3}_simam', SimAM())
            else:
              block_med.add_module(n2, m2)
              if (m2._get_name() == "Conv2d" or
                m2._get_name() == "BatchNorm2d"):
                block_med.add_module(f'{n2}_simam', SimAM())

      else:
        raise ValueError("Sum ting wong")
  else:
    # continue
    mobilenet_v2_simam.add_module('squeeze0', Squeezer())
    mobilenet_v2_simam.add_module('squeeze1', Squeezer())

    mobilenet_v2_simam.add_module(name, module)
print(mobilenet_v2_simam)


Sequential(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (0_simam): SimAM()
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1_simam): SimAM()
      (2): ReLU6(inplace=True)
    )
    (1): Sequential(
      (conv): Sequential(
        (0): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (0_simam): SimAM()
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1_simam): SimAM()
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1_simam): SimAM()
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2_simam): SimAM()
      )
    )
    (2): Sequential(
      (conv): Sequential(
        (0): Seque

In [None]:
from glob import glob
import os

if not os.path.isfile('/content/cifar-10-python.tar.gz'):
  !wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
if not os.path.exists('/content/cifar-10-batches-py'):
  !tar -xf /content/cifar-10-python.tar.gz

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

data = {}

for file in glob("./cifar-10-batches-py/*"):
  if not file.endswith('.html'):
    data[os.path.basename(file)] = unpickle(file)

--2024-03-22 10:26:18--  https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170498071 (163M) [application/x-gzip]
Saving to: ‘cifar-10-python.tar.gz’


2024-03-22 10:26:21 (55.9 MB/s) - ‘cifar-10-python.tar.gz’ saved [170498071/170498071]



In [None]:
data['data_batch_1'][b'data'].shape

(10000, 3072)

In [None]:
import numpy as np

def read_numpy_to_3d(data_1d):
  data_3d = np.zeros((data_1d.shape[0], 3, 32, 32))
  for i in range(data_1d.shape[0]):
    data_3d[i][0] = data_1d[i][0:1024].reshape(32, 32)
    data_3d[i][1] = data_1d[i][1024:2048].reshape(32, 32)
    data_3d[i][2] = data_1d[i][2048:3072].reshape(32, 32)
  return data_3d


In [None]:
batch_1_data = read_numpy_to_3d(data['data_batch_1'][b'data'])
batch_1_data

array([[[[ 59.,  43.,  50., ..., 158., 152., 148.],
         [ 16.,   0.,  18., ..., 123., 119., 122.],
         [ 25.,  16.,  49., ..., 118., 120., 109.],
         ...,
         [208., 201., 198., ..., 160.,  56.,  53.],
         [180., 173., 186., ..., 184.,  97.,  83.],
         [177., 168., 179., ..., 216., 151., 123.]],

        [[ 62.,  46.,  48., ..., 132., 125., 124.],
         [ 20.,   0.,   8., ...,  88.,  83.,  87.],
         [ 24.,   7.,  27., ...,  84.,  84.,  73.],
         ...,
         [170., 153., 161., ..., 133.,  31.,  34.],
         [139., 123., 144., ..., 148.,  62.,  53.],
         [144., 129., 142., ..., 184., 118.,  92.]],

        [[ 63.,  45.,  43., ..., 108., 102., 103.],
         [ 20.,   0.,   0., ...,  55.,  50.,  57.],
         [ 21.,   0.,   8., ...,  50.,  50.,  42.],
         ...,
         [ 96.,  34.,  26., ...,  70.,   7.,  20.],
         [ 96.,  42.,  30., ...,  94.,  34.,  34.],
         [116.,  94.,  87., ..., 140.,  84.,  72.]]],


       [[[154.

In [None]:
import torch
simam(simam(simam(torch.Tensor(batch_1_data))))

tensor([[[[ 23.0655,  19.5074,  21.2200,  ...,  38.8042,  37.0070,  35.9787],
          [  9.5187,   0.0000,  10.4938,  ...,  31.4538,  30.9181,  31.3168],
          [ 13.5731,   9.5187,  20.9917,  ...,  30.7889,  31.0490,  29.6915],
          ...,
          [ 94.9364,  77.7673,  71.8846,  ...,  39.4849,  22.4895,  21.8755],
          [ 49.9837,  45.3290,  55.3085,  ...,  53.3666,  28.3235,  26.6787],
          [ 47.8167,  42.7382,  49.2294,  ..., 120.2534,  36.7385,  31.4538]],

         [[ 17.7945,  14.8076,  15.2109,  ...,  33.4161,  30.8751,  30.5483],
          [  8.1461,   0.0000,   3.6941,  ...,  22.2171,  21.3453,  22.0403],
          [  9.3946,   3.2675,  10.2667,  ...,  21.5175,  21.5175,  19.6585],
          ...,
          [ 64.3999,  45.2705,  52.6999,  ...,  33.8203,  11.3529,  12.1157],
          [ 36.5110,  30.2295,  39.1713,  ...,  41.6438,  17.7945,  16.1760],
          [ 39.1713,  32.2688,  38.0551,  ...,  92.5291,  28.7433,  22.9396]],

         [[ 15.2872,  11.4134,

In [None]:
model(torch.Tensor(batch_1_data[0:3]))

tensor([[-5.6307,  1.8367,  1.0201,  ...,  0.1076,  6.6831, -3.8734],
        [-0.7303, -0.8216,  0.5331,  ...,  0.7346, -3.8684,  8.8462],
        [ 2.5241, -2.8335, -3.4567,  ..., -0.4353, -0.0102,  1.9500]],
       grad_fn=<AddmmBackward0>)

In [None]:
batch_1_data = batch_1_data/256

In [None]:
mobilenet_v2_simam[0][:-5](torch.Tensor(batch_1_data[0:3]))

tensor([[[[ 0.1933,  0.1395],
          [-0.0240, -0.2220]],

         [[-0.0145,  0.1435],
          [-0.4159,  0.0389]],

         [[-0.3909,  0.1966],
          [-0.0850,  0.0158]],

         ...,

         [[-0.3911, -0.0119],
          [ 0.0265, -0.0870]],

         [[-0.0356, -0.0198],
          [-0.1628, -0.0181]],

         [[ 0.0483, -0.0288],
          [-0.0852, -0.2260]]],


        [[[ 0.0819, -0.3597],
          [-0.0663, -0.3055]],

         [[-0.0030,  0.2442],
          [-0.2517,  0.0160]],

         [[-0.0503,  0.1545],
          [ 0.0985, -0.0068]],

         ...,

         [[-0.0401, -0.1218],
          [ 0.2835, -0.0725]],

         [[-0.2323,  0.5693],
          [-0.3577, -0.0713]],

         [[-0.0723, -0.2503],
          [-0.1873,  0.0859]]],


        [[[ 0.2843,  0.1122],
          [ 0.0503,  0.1194]],

         [[ 0.6483, -0.1161],
          [ 0.1379, -0.4479]],

         [[-0.0209,  0.0130],
          [ 0.2123, -0.1407]],

         ...,

         [[-0.0434,  

In [None]:
nn.Sequential(mobilenet_v2_simam[0][:-5],
              mobilenet_v2_simam[0][-5][0][:-5],
              mobilenet_v2_simam[0][-5][0][-5][:-4],
              mobilenet_v2_simam[0][-5][0][-5][-4])(torch.Tensor(torch.rand(3,3,32,32)))

tensor([[[[-0.0257]],

         [[-0.0045]],

         [[ 0.0171]],

         ...,

         [[-0.0314]],

         [[ 0.0006]],

         [[-0.0120]]],


        [[[-0.0276]],

         [[-0.0285]],

         [[ 0.0091]],

         ...,

         [[-0.0593]],

         [[ 0.0162]],

         [[-0.0264]]],


        [[[-0.0200]],

         [[-0.0220]],

         [[ 0.0000]],

         ...,

         [[-0.0304]],

         [[ 0.0025]],

         [[-0.0186]]]], grad_fn=<ConvolutionBackward0>)

In [None]:
mobilenet_v2_simam[0][-5][0][-5][-4]

SimAM()

In [None]:
mobilenet_v2_simam[0][-5]

Sequential(
  (conv): Sequential(
    (0): Sequential(
      (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (0_simam): SimAM()
      (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1_simam): SimAM()
      (2): ReLU6(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False)
      (0_simam): SimAM()
      (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1_simam): SimAM()
      (2): ReLU6(inplace=True)
    )
    (2): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (2_simam): SimAM()
    (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3_simam): SimAM()
  )
)