In [1]:
import numpy as np
import random
import torch

In [290]:
class channelPool(torch.nn.Module):
    def __init__(self, in_channels, h, w, batch_size):
        super(channelPool, self).__init__()
        self.batch_size = batch_size
        self.h = h
        self.w = h


    def forward(self, x):
        # torch.reshape(v, (1,2,2))
        x, i = torch.max(x, dim = 1)
        return torch.reshape(x, (self.batch_size, 1, self.h, self.w))


class spatialPool(torch.nn.Module):
    def __init__(self, in_channels, batch_size):
        super(spatialPool, self).__init__()
        self.in_channels = in_channels
        self.batch_size = batch_size

    def forward(self, x):
        x, i = torch.max(x, dim = -1)
        x, i = torch.max(x, dim = -1)
        return torch.reshape(x, (self.batch_size, self.in_channels, 1, 1))

In [291]:
32/16

2.0

In [296]:
IN_CHANNELS = 32
H = 224
W = 224
batch_size = 2

input = torch.tensor(np.random.random((batch_size, IN_CHANNELS, H, W)), dtype = torch.float32)
# input

In [297]:
# sp = spatialPool(in_channels = IN_CHANNELS)
# v = sp(input)
# # v

In [298]:
# cp = channelPool(in_channels = IN_CHANNELS)
# v = cp(input)
# v

In [312]:
class attnNet(torch.nn.Module):
    def __init__(self, in_channels, h, w, batch_size):
        super(attnNet, self).__init__()
        self.sP = spatialPool(in_channels = in_channels, batch_size = batch_size)
        self.cP = channelPool(in_channels = in_channels, h = h, w = w, batch_size = batch_size)
        self.out_channels = int(in_channels/16)
        self.convR = torch.nn.Conv2d(in_channels = in_channels, out_channels = self.out_channels, kernel_size = (1,1), bias=True)
        self.convA = torch.nn.Conv2d(in_channels = self.out_channels, out_channels = self.out_channels, kernel_size = (1,1), bias=True)
        self.convB = torch.nn.Conv2d(in_channels = self.out_channels, out_channels = self.out_channels, kernel_size = (3,3), bias=True, padding = 1)
        self.convC = torch.nn.Conv2d(in_channels = self.out_channels, out_channels = self.out_channels, kernel_size = (7,7), bias=True, padding = 3)
        self.convE = torch.nn.Conv2d(in_channels = self.out_channels * 3, out_channels = in_channels, kernel_size = (1,1), bias=True)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        '''
        For a given input feature mapFin∈RC×H×W, the global pooling  operations along the spatial and channel dimensions are applied respectively to obtain the channel descriptorFc∈RC×1×1and the spatial descriptorFs∈R1×H×W
        '''
        fs = self.sP(x)
        fc = self.cP(x)
        '''
        Then, theFcandFs are expanded to the size ofC×H×W. Element-wise multiplication(Eq.9) is used for computing the initial integrated 3D spatial-channel descriptorFsc∈RC×H×W. Fsc¼Fs⊗Fc
        '''
        fsc = torch.mul(fs, fc)
        '''
        Then, a convolution block is applied toFscto refine the spatial-channel dependencies.
        '''
        '''
        A convolution with 1 × 1 kernel and a channel reduction ratior= 16 is applied in the first subblock for reducingcomputational burden and the number of parameters. 
        '''
        r = self.convR(fsc)
        '''The middle subblock contains three convolution operations with the kernelsizes of 1 × 1, 3 × 3, and 7 × 7 separately. Hence, the effective usageof the contextual information can be guaranteed by the different receptive field
        '''
        a = self.convA(r)
        b = self.convB(r)
        c = self.convC(r)
        '''
        The outputs of the three convolutions are concatenated...
        '''
        cat = torch.cat((a,b,c), dim = 1)
        '''
        ...and fed into the third subblock with a kernel size of 1 × 1 to output a feature map with the same channel number as the input map. 
        '''
        e = self.convE(cat)
        '''Finally, a sigmoid activationσis applied to compute the final3D spatial-channel attention mapMsc(Fin)∈RC×H×W, i.e.'''
        out = self.sigmoid(e)
        return out

In [313]:
model = attnNet(IN_CHANNELS, H, W, batch_size)

model(input)

tensor([[[[0.4665, 0.4559, 0.4573,  ..., 0.4568, 0.4594, 0.4650],
          [0.4688, 0.4553, 0.4590,  ..., 0.4576, 0.4589, 0.4664],
          [0.4651, 0.4520, 0.4546,  ..., 0.4555, 0.4558, 0.4640],
          ...,
          [0.4645, 0.4516, 0.4506,  ..., 0.4525, 0.4520, 0.4596],
          [0.4656, 0.4518, 0.4520,  ..., 0.4537, 0.4517, 0.4596],
          [0.4667, 0.4602, 0.4600,  ..., 0.4646, 0.4613, 0.4679]],

         [[0.4248, 0.4277, 0.4279,  ..., 0.4207, 0.4159, 0.4089],
          [0.4231, 0.4267, 0.4257,  ..., 0.4166, 0.4124, 0.4063],
          [0.4241, 0.4266, 0.4269,  ..., 0.4182, 0.4128, 0.4067],
          ...,
          [0.4207, 0.4258, 0.4273,  ..., 0.4173, 0.4128, 0.4096],
          [0.4185, 0.4220, 0.4226,  ..., 0.4153, 0.4112, 0.4087],
          [0.4141, 0.4152, 0.4149,  ..., 0.4076, 0.4050, 0.4071]],

         [[0.4301, 0.4261, 0.4274,  ..., 0.4272, 0.4283, 0.4269],
          [0.4282, 0.4241, 0.4256,  ..., 0.4253, 0.4264, 0.4268],
          [0.4262, 0.4222, 0.4233,  ..., 0

In [314]:
model(input).shape

torch.Size([2, 32, 224, 224])