In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SpatialAttention(nn.Module):
    """Spatial (position-wise) self-attention with a learnable residual scale.

    Three 1x1 convolutions project the input into feature maps ``b``, ``c`` and
    ``d``. An ``(H*W) x (H*W)`` attention map is computed from ``c`` and ``b``,
    used to mix the columns of ``d``, and the result is added back onto the
    input scaled by the learnable scalar ``alpha`` (initialised to 0, so the
    block starts out as the identity).

    Note:
        The attention map holds ``(H*W)**2`` values per sample, so memory grows
        quadratically with the spatial size of the input.

    Args:
        input_shape: ``(C, H, W)`` of the expected input. ``C`` sizes the
            convolutions; ``H`` and ``W`` are stored as attributes, but the
            spatial size actually used is read from the tensor in ``forward``.
    """

    def __init__(self, input_shape):
        super(SpatialAttention, self).__init__()

        self.C = input_shape[0]
        self.H = input_shape[1]
        self.W = input_shape[2]

        # Residual gate; zero-initialised so the attention branch is off at start.
        self.alpha = nn.Parameter(torch.tensor(0.0))

        self.conv1 = nn.Conv2d(in_channels=self.C, out_channels=self.C, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(in_channels=self.C, out_channels=self.C, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels=self.C, out_channels=self.C, kernel_size=1, stride=1)

    def forward(self, x):
        """Apply spatial attention to ``x``.

        Args:
            x: Tensor whose last three dimensions are ``(C, H, W)`` with
                ``C == self.C``; any spatial size is accepted.

        Returns:
            Tensor of shape ``(B, C, H, W)``: ``x + alpha * attention_output``.
        """
        # Take the spatial size from the input rather than the constructor so
        # the same module works on feature maps of any resolution.
        height, width = x.shape[-2:]
        n_positions = height * width

        b = self.conv1(x).reshape(-1, self.C, n_positions)
        c = self.conv2(x).reshape(-1, self.C, n_positions)
        d = self.conv3(x).reshape(-1, self.C, n_positions)

        # energy[:, i, j] = <c_i, b_j>, shape (B, N, N).
        energy = torch.bmm(c.transpose(1, 2), b)

        # Normalise over the last axis so each position's weights sum to 1.
        # The dim must be explicit: with no dim, PyTorch picks dim 0 for 3-D
        # tensors, which would normalise across the batch.
        attention = torch.softmax(energy, dim=-1)

        # out[:, :, i] = sum_j d[:, :, j] * attention[:, i, j]
        attended = self.alpha * torch.bmm(d, attention.transpose(1, 2))
        attended = attended.view(-1, self.C, height, width)

        return x + attended

In [3]:
# Attention block sized for 64-channel feature maps of 256x256.
sa = SpatialAttention(input_shape=(64, 256, 256))

In [4]:
# All-zero dummy batch: one sample, 64 channels, 256x256.
test = torch.zeros(1, 64, 256, 256)

In [5]:
# Smoke-test the forward pass. With a 256x256 input, N = H*W = 65,536, so the
# (N x N) attention matrix built inside forward() holds ~4.3e9 values per sample
# (about 16 GiB at the default float32); expect this cell to be very slow or to
# run out of memory on most machines.
e = sa(test)

  S = nn.Softmax()(c)


: 

: 

### Channel-Split Attention

In [None]:
class CSA(nn.Module):
    """Channel-split attention block.

    The input is projected by a 1x1 convolution and split into two halves
    along the channel axis. Each half goes through a 3x3 convolution; the
    second half is additionally fused with the first through another 3x3
    convolution. The two branches are globally average-pooled, passed through
    ``group_1 -> BatchNorm -> ReLU -> group_2`` and split again into one gate
    per branch. Each gate is softmax-normalised over its channels and used to
    re-weight its branch. The re-weighted branches are concatenated, projected
    by a final 1x1 convolution and added to the input.

    Note:
        The BatchNorm layer sees a ``(B, C, 1, 1)`` tensor, so in training
        mode it needs a batch size greater than 1.

    Args:
        in_channels: Number of input (and output) channels. Must be a
            positive even integer so the channels split into two equal halves.

    Raises:
        ValueError: If ``in_channels`` is not a positive even integer.
    """

    def __init__(self, in_channels):
        # Must run before any sub-module is assigned, otherwise nn.Module
        # refuses the assignment.
        super(CSA, self).__init__()

        if in_channels < 2 or in_channels % 2 != 0:
            raise ValueError(
                "in_channels must be a positive even integer, got {}".format(in_channels)
            )

        # Integer division: Conv2d and Tensor.split need int sizes.
        half = in_channels // 2

        self.C = in_channels
        self.conv_1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.conv_3x3_1 = nn.Conv2d(half, half, kernel_size=3, stride=1, padding="same")
        self.conv_3x3_2 = nn.Conv2d(half, half, kernel_size=3, stride=1, padding="same")
        self.conv_3x3_3 = nn.Conv2d(in_channels, half, kernel_size=3, stride=1, padding="same")

        # NOTE(review): forward() pools the concatenation of both branches,
        # which has `in_channels` channels, so group_1 takes `in_channels`
        # inputs. If the design intends the element-wise sum of the branches
        # instead, this should be `half` and the concat in forward() a sum.
        self.group_1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.group_2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
        self.softmax = nn.Softmax(dim=1)
        self.final_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, input):
        """Apply channel-split attention.

        Args:
            input: Tensor of shape ``(B, in_channels, H, W)``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        half = self.C // 2

        features = self.conv_1(input)
        branch_1, branch_2 = features.split(half, dim=1)

        branch_1 = self.conv_3x3_1(branch_1)
        branch_2 = self.conv_3x3_2(branch_2)
        # Fuse the first branch into the second.
        branch_2 = self.conv_3x3_3(torch.cat([branch_1, branch_2], dim=1))

        # Global average pooling; keepdim leaves a (B, C, 1, 1) tensor that the
        # 1x1 convolutions and BatchNorm2d below can consume.
        pooled = torch.cat([branch_1, branch_2], dim=1).mean(dim=(2, 3), keepdim=True)

        gates = self.group_1(pooled)
        gates = self.bn(gates)
        gates = self.relu(gates)
        gates = self.group_2(gates)

        # One gate per branch, normalised over that branch's channels.
        gate_1, gate_2 = gates.split(half, dim=1)
        gate_1 = self.softmax(gate_1)
        gate_2 = self.softmax(gate_2)

        # (B, C/2, 1, 1) gates broadcast over the spatial dimensions.
        fused = torch.cat([branch_1 * gate_1, branch_2 * gate_2], dim=1)
        fused = self.final_conv(fused)

        return fused + input

In [2]:
# All-zero dummy batch: 4 samples, 512 channels, 32x32.
test = torch.zeros((4, 512, 32, 32))

In [3]:
# Stack the tensor with itself along the channel axis (dim=1).
ta = torch.cat((test, test), dim=1)

In [4]:
# Channel count doubles after concatenation along dim=1.
ta.size()

torch.Size([4, 1024, 32, 32])