In [1]:
import torch
import torch.nn as nn 

In [2]:
class sSE(nn.Module):
    """Spatial squeeze-and-excitation gate.

    A bias-free 1x1 convolution collapses the channel axis of the input to a
    single map, a sigmoid turns that map into per-pixel weights in (0, 1), and
    the input is rescaled by those weights (broadcast over channels).

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels) -> None:
        super().__init__()
        # in_channels -> 1 channel, kernel size 1, no bias term.
        self.Conv1x1 = nn.Conv2d(in_channels, 1, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its spatial attention map (same shape as ``x``)."""
        # x[bs, c, h, w] -> spatial_gate[bs, 1, h, w]
        spatial_gate = self.sigmoid(self.Conv1x1(x))
        return x * spatial_gate

In [3]:
class cSE(nn.Module):
    """Channel squeeze-and-excitation gate.

    The input is globally average-pooled to one value per channel, passed
    through a bottleneck of two bias-free 1x1 convolutions (squeeze to
    ``in_channels // 2`` channels, then excite back to ``in_channels``), and a
    sigmoid turns the result into per-channel weights in (0, 1). The input is
    rescaled by those weights (broadcast over the spatial axes).

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels) -> None:
        super().__init__()
        # Bottleneck width; clamped to 1 so that in_channels == 1 does not
        # create zero-channel convolutions.
        reduced_channels = max(1, in_channels // 2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.Conv_Squeeze = nn.Conv2d(in_channels, reduced_channels, kernel_size=1, bias=False)
        self.Conv_Excitation = nn.Conv2d(reduced_channels, in_channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its channel attention vector (same shape as ``x``)."""
        y = self.avgpool(x)          # [bs, c, h, w] -> [bs, c, 1, 1]
        # The squeeze must consume the pooled descriptor, not the raw input;
        # feeding ``x`` here discards the pooling and yields a spatial map.
        y = self.Conv_Squeeze(y)     # [bs, c, 1, 1] -> [bs, c // 2, 1, 1]
        # Restore the channel count so there is one weight per input channel.
        y = self.Conv_Excitation(y)  # [bs, c // 2, 1, 1] -> [bs, c, 1, 1]
        y = self.sigmoid(y)
        return x * y.expand_as(x)

In [4]:
class scSE(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation.

    Runs the input through an ``sSE`` branch and a ``cSE`` branch (both defined
    elsewhere in this notebook) and returns the element-wise sum of the two
    branch outputs.

    Args:
        in_channels: number of channels of the input feature map; forwarded to
            both branches.
    """

    def __init__(self, in_channels) -> None:
        super().__init__()
        self.cSE = cSE(in_channels)
        self.sSE = sSE(in_channels)

    def forward(self, x):
        """Return ``sSE(x) + cSE(x)``."""
        # Spatial branch is evaluated first, then the channel branch.
        return self.sSE(x) + self.cSE(x)

In [5]:
# Smoke test: push an all-ones batch through scSE and print the input and
# output shapes.
bs, c, h, w = 10, 3, 64, 64
in_tensor = torch.ones((bs, c, h, w))
cs_se = scSE(in_channels=c)

print("in shape:", in_tensor.shape)
out_tensor = cs_se(in_tensor)
print("out shape:", out_tensor.shape)

in shape: torch.Size([10, 3, 64, 64])
out shape: torch.Size([10, 3, 64, 64])
