In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

# 残差块

In [3]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

In [5]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

In [8]:
# 应用于resnet中的残差模块
net = BasicBlock(256, 256)
display(net)

BasicBlock(
  (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

# Bottleneck Residual Block

In [10]:
class Bottleneck(nn.Module):

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


In [11]:
net = Bottleneck(256, 256)
display(net)

Bottleneck(
  (conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)

# Inception

In [12]:
class InceptionModule(nn.Module):
    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj):
        super(InceptionModule, self).__init__()
        
        # 1x1 卷积
        self.conv1x1 = nn.Conv2d(in_channels, out_1x1, kernel_size=1)
        
        # 3x3 卷积和降维卷积
        self.conv3x3_reduce = nn.Conv2d(in_channels, red_3x3, kernel_size=1)
        self.conv3x3 = nn.Conv2d(red_3x3, out_3x3, kernel_size=3, padding=1)
        
        # 5x5 卷积和降维卷积
        self.conv5x5_reduce = nn.Conv2d(in_channels, red_5x5, kernel_size=1)
        self.conv5x5 = nn.Conv2d(red_5x5, out_5x5, kernel_size=5, padding=2)
        
        # 池化和降维卷积
        self.pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.pool_proj = nn.Conv2d(in_channels, pool_proj, kernel_size=1)

    def forward(self, x):
        output_1x1 = torch.relu(self.conv1x1(x))
        
        output_3x3 = torch.relu(self.conv3x3_reduce(x))
        output_3x3 = torch.relu(self.conv3x3(output_3x3))
        
        output_5x5 = torch.relu(self.conv5x5_reduce(x))
        output_5x5 = torch.relu(self.conv5x5(output_5x5))
        
        output_pool = self.pool(x)
        output_pool = torch.relu(self.pool_proj(output_pool))
        
        # 拼接所有特征图
        outputs = [output_1x1, output_3x3, output_5x5, output_pool]
        return torch.cat(outputs, dim=1)  # 在通道维度上拼接


In [13]:
# 假设输入通道数为 3，输出各部分通道数为 64, 128, 32, 32, 32, 32
in_channels = 3
out_1x1 = 64
red_3x3 = 128
out_3x3 = 32
red_5x5 = 32
out_5x5 = 32
pool_proj = 32

# 创建Inception模块
inception = InceptionModule(in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj)

display("网络的形状是 : ", inception)

# 随机生成一个输入张量，大小为 (batch_size, channels, height, width)
batch_size = 1
height, width = 64, 64
inputs = torch.randn(batch_size, in_channels, height, width)

# 使用Inception模块进行前向传播
outputs = inception(inputs)


# 输出每个模块的形状
display("输出张量形状:", outputs.shape)


'网络的形状是 : '

InceptionModule(
  (conv1x1): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv3x3_reduce): Conv2d(3, 128, kernel_size=(1, 1), stride=(1, 1))
  (conv3x3): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5x5_reduce): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv5x5): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
  (pool_proj): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
)

'输出张量形状:'

torch.Size([1, 160, 64, 64])

# Non-Local Block

In [15]:
class NLBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, mode='embedded', 
                 dimension=3, bn_layer=True):
        """Implementation of Non-Local Block with 4 different pairwise functions but doesn't include subsampling trick
        args:
            in_channels: original channel size (1024 in the paper)
            inter_channels: channel size inside the block if not specifed reduced to half (512 in the paper)
            mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation
            dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal)
            bn_layer: whether to add batch norm
        """
        super(NLBlockND, self).__init__()

        assert dimension in [1, 2, 3]
        
        if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']:
            raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`')
            
        self.mode = mode
        self.dimension = dimension

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # the channel size is reduced to half inside the block
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        
        # assign appropriate convolutional, max pool, and batch norm layers for different dimensions
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        # function g in the paper which goes through conv. with kernel size 1
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        # add BatchNorm layer after the last conv layer
        if bn_layer:
            self.W_z = nn.Sequential(
                    conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
                    bn(self.in_channels)
                )
            # from section 4.1 of the paper, initializing params of BN ensures that the initial state of non-local block is identity mapping
            nn.init.constant_(self.W_z[1].weight, 0)
            nn.init.constant_(self.W_z[1].bias, 0)
        else:
            self.W_z = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1)

            # from section 3.3 of the paper by initializing Wz to 0, this block can be inserted to any existing architecture
            nn.init.constant_(self.W_z.weight, 0)
            nn.init.constant_(self.W_z.bias, 0)

        # define theta and phi for all operations except gaussian
        if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate":
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
        
        if self.mode == "concatenate":
            self.W_f = nn.Sequential(
                    nn.Conv2d(in_channels=self.inter_channels * 2, out_channels=1, kernel_size=1),
                    nn.ReLU()
                )
            
    def forward(self, x):
        """
        args
            x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1
        """

        batch_size = x.size(0)
        
        # (N, C, THW)
        # this reshaping and permutation is from the spacetime_nonlocal function in the original Caffe2 implementation
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        if self.mode == "gaussian":
            theta_x = x.view(batch_size, self.in_channels, -1)
            phi_x = x.view(batch_size, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            f = torch.matmul(theta_x, phi_x)

        elif self.mode == "embedded" or self.mode == "dot":
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
            phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            f = torch.matmul(theta_x, phi_x)

        elif self.mode == "concatenate":
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)
            
            h = theta_x.size(2)
            w = phi_x.size(3)
            theta_x = theta_x.repeat(1, 1, 1, w)
            phi_x = phi_x.repeat(1, 1, h, 1)
            
            concat = torch.cat([theta_x, phi_x], dim=1)
            f = self.W_f(concat)
            f = f.view(f.size(0), f.size(2), f.size(3))
        
        if self.mode == "gaussian" or self.mode == "embedded":
            f_div_C = F.softmax(f, dim=-1)
        elif self.mode == "dot" or self.mode == "concatenate":
            N = f.size(-1) # number of position in x
            f_div_C = f / N
        
        y = torch.matmul(f_div_C, g_x)
        
        # contiguous here just allocates contiguous chunk of memory
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        
        W_y = self.W_z(y)
        # residual connection
        z = W_y + x

        return z


In [16]:
net = NLBlockND(256)
display(net)

NLBlockND(
  (g): Conv3d(256, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (W_z): Sequential(
    (0): Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (theta): Conv3d(256, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (phi): Conv3d(256, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)

# Spatial Transformer

In [17]:
class SpatialTransformer(nn.Module):
    def __init__(self, in_channels, spatial_size):
        super(SpatialTransformer, self).__init__()
        self.localization_network = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # 线性层用于预测空间变换参数
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * spatial_size * spatial_size, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)  # 3个参数（旋转、缩放、平移）每个有2个值
        )

        # 采样器
        self.sampler = torch.nn.functional.affine_grid

    def forward(self, x):
        # 获取变换参数
        theta = self.localization_network(x)
        theta = theta.view(-1, 10 * 4 * 4)  # 假设输入大小是 (10, 4, 4)
        theta = self.fc_loc(theta)
        theta = theta.view(-1, 2, 3)  # 调整参数形状为 (batch_size, 2, 3)

        # 生成采样网格
        grid = self.sampler(theta, x.size())

        # 执行空间变换
        x_transformed = F.grid_sample(x, grid)

        return x_transformed

In [18]:
net = SpatialTransformer(3, 3)
display(net)

SpatialTransformer(
  (localization_network): Sequential(
    (0): Conv2d(3, 8, kernel_size=(7, 7), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(8, 10, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU(inplace=True)
  )
  (fc_loc): Sequential(
    (0): Linear(in_features=90, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=6, bias=True)
  )
)

# ResNeXt Block

In [19]:
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, out_channels, cardinality=32, stride=1):
        super(ResNeXtBlock, self).__init__()

        # 每个分组内部的通道数
        group_channels = out_channels // cardinality
        
        # 第一个卷积层
        self.conv1 = nn.Conv2d(in_channels, group_channels, kernel_size=1, stride=stride, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(group_channels)
        self.relu = nn.ReLU(inplace=True)

        # 分组卷积
        self.conv2 = nn.Conv2d(group_channels, group_channels, kernel_size=3, stride=1, padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(group_channels)

        # 最后一个卷积层
        self.conv3 = nn.Conv2d(group_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # 如果输入通道数和输出通道数不一致，使用 1x1 卷积调整
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        shortcut = self.shortcut(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn3(x)

        x += shortcut
        x = self.relu(x)

        return x

In [20]:
net = ResNeXtBlock(3, 3)
display(net)



ResNeXtBlock(
  (conv1): Conv2d(3, 0, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(0, 0, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
  (bn2): BatchNorm2d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(0, 3, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (shortcut): Sequential()
)

# Channel Attention

In [3]:
def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
        self.conv = nn.Conv2d(gate_channels, gate_channels // 2, kernel_size=(1, 1), stride=1)

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )
            elif pool_type=='lp':
                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( lp_pool )
            elif pool_type=='lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp( lse_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return self.conv(x * scale)
    
    

In [5]:
net = ChannelGate(256*2)
display(net)

ChannelGate(
  (mlp): Sequential(
    (0): Flatten()
    (1): Linear(in_features=256, out_features=16, bias=True)
    (2): ReLU()
    (3): Linear(in_features=16, out_features=256, bias=True)
  )
  (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
)

In [35]:
summary(net, (256*10, 128, 128), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                 [-1, 2560]               0
            Linear-2                  [-1, 160]         409,760
              ReLU-3                  [-1, 160]               0
            Linear-4                 [-1, 2560]         412,160
           Flatten-5                 [-1, 2560]               0
            Linear-6                  [-1, 160]         409,760
              ReLU-7                  [-1, 160]               0
            Linear-8                 [-1, 2560]         412,160
            Conv2d-9       [-1, 1280, 128, 128]       3,278,080
Total params: 4,921,920
Trainable params: 4,921,920
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 160.00
Forward/backward pass size (MB): 160.08
Params size (MB): 18.78
Estimated Total Size (MB): 338.86
--------------------------------

In [36]:
x1 = torch.randn(1, 256*10, 128, 128)
net(x1).shape

torch.Size([1, 1280, 128, 128])

# SE block

In [3]:
from inspect import isfunction

In [7]:
class Swish(nn.Module):
    """
    Swish activation function from 'Searching for Activation Functions,' https://arxiv.org/abs/1710.05941.
    """
    def forward(self, x):
        return x * torch.sigmoid(x)
    
class HSwish(nn.Module):
    """
    H-Swish activation function from 'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244.

    Parameters:
    ----------
    inplace : bool
        Whether to use inplace version of the module.
    """
    def __init__(self, inplace=False):
        super(HSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
    
class HSigmoid(nn.Module):
    """
    Approximated sigmoid function, so-called hard-version of sigmoid from 'Searching for MobileNetV3,'
    https://arxiv.org/abs/1905.02244.
    """
    def forward(self, x):
        return F.relu6(x + 3.0, inplace=True) / 6.0

In [11]:
def conv1x1(in_channels,
            out_channels,
            stride=1,
            groups=1,
            bias=False):
    """
    Convolution 1x1 layer.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int or tuple/list of 2 int, default 1
        Strides of the convolution.
    groups : int, default 1
        Number of groups.
    bias : bool, default False
        Whether the layer uses a bias vector.
    """
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
        stride=stride,
        groups=groups,
        bias=bias)


def conv3x3(in_channels,
            out_channels,
            stride=1,
            padding=1,
            dilation=1,
            groups=1,
            bias=False):
    """
    Convolution 3x3 layer.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    stride : int or tuple/list of 2 int, default 1
        Strides of the convolution.
    padding : int or tuple/list of 2 int, default 1
        Padding value for convolution layer.
    groups : int, default 1
        Number of groups.
    bias : bool, default False
        Whether the layer uses a bias vector.
    """
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias)


def depthwise_conv3x3(channels,
                      stride):
    """
    Depthwise convolution 3x3 layer.

    Parameters:
    ----------
    channels : int
        Number of input/output channels.
    strides : int or tuple/list of 2 int
        Strides of the convolution.
    """
    return nn.Conv2d(
        in_channels=channels,
        out_channels=channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=channels,
        bias=False)


In [8]:
def get_activation_layer(activation):
    """
    Create activation layer from string/function.

    Parameters:
    ----------
    activation : function, or str, or nn.Module
        Activation function or name of activation function.

    Returns
    -------
    nn.Module
        Activation layer.
    """
    assert (activation is not None)
    if isfunction(activation):
        return activation()
    elif isinstance(activation, str):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "relu6":
            return nn.ReLU6(inplace=True)
        elif activation == "swish":
            return Swish()
        elif activation == "hswish":
            return HSwish(inplace=True)
        else:
            raise NotImplementedError()
    else:
        assert (isinstance(activation, nn.Module))
        return activation



In [9]:
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block from 'Squeeze-and-Excitation Networks,' https://arxiv.org/abs/1709.01507.

    Parameters:
    ----------
    channels : int
        Number of channels.
    reduction : int, default 16
        Squeeze reduction value.
    approx_sigmoid : bool, default False
        Whether to use approximated sigmoid function.
    activation : function, or str, or nn.Module
        Activation function or name of activation function.
    """
    def __init__(self,
                 channels,
                 reduction=16,
                 approx_sigmoid=False,
                 activation=(lambda: nn.ReLU(inplace=True))):
        super(SEBlock, self).__init__()
        mid_cannels = channels // reduction

        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = conv1x1(
            in_channels=channels,
            out_channels=mid_cannels,
            bias=True)
        self.activ = get_activation_layer(activation)
        self.conv2 = conv1x1(
            in_channels=mid_cannels,
            out_channels=channels,
            bias=True)
        self.sigmoid = HSigmoid() if approx_sigmoid else nn.Sigmoid()

    def forward(self, x):
        w = self.pool(x)
        w = self.conv1(w)
        w = self.activ(w)
        w = self.conv2(w)
        w = self.sigmoid(w)
        x = x * w
        return x

In [12]:
se = SEBlock(256)
display(se)

SEBlock(
  (pool): AdaptiveAvgPool2d(output_size=1)
  (conv1): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1))
  (activ): ReLU(inplace=True)
  (conv2): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
  (sigmoid): Sigmoid()
)

In [15]:
x = (256, 128, 128)
summary(se, x, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
 AdaptiveAvgPool2d-1            [-1, 256, 1, 1]               0
            Conv2d-2             [-1, 16, 1, 1]           4,112
              ReLU-3             [-1, 16, 1, 1]               0
            Conv2d-4            [-1, 256, 1, 1]           4,352
           Sigmoid-5            [-1, 256, 1, 1]               0
Total params: 8,464
Trainable params: 8,464
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 16.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.03
Estimated Total Size (MB): 16.04
----------------------------------------------------------------


In [20]:
x1 = torch.randn(1, 256, 128, 128)
se(x1).shape

torch.Size([1, 256, 128, 128])

In [26]:
# 模型设计
# 输入
x = torch.randn(1, 256*5, 128, 128)

In [42]:
# 模型
cnn = SEBlock(256*5)
csa = ChannelGate(256*10)

In [38]:
x1 = cnn(x)
x1 = torch.add(x1, x)
display(x1.shape)

torch.Size([1, 1280, 128, 128])

In [40]:
# 对x 和 x1进行concat
x2 = torch.cat((x1, x), dim = 1)
display(x2.shape)

torch.Size([1, 2560, 128, 128])

In [43]:
x2 = csa(x2)
display(x2.shape)



torch.Size([1, 1280, 128, 128])

In [44]:
x3 = torch.add(x2, x)
display(x2.shape)

torch.Size([1, 1280, 128, 128])

In [51]:
class CCMBlock(nn.Module):
    def __init__(self, input_channel):
        super(CCMBlock, self).__init__()
        self.cnn = SEBlock(input_channel)
        self.csa = ChannelGate(input_channel * 2)
        self.conv = nn.Conv2d(input_channel, input_channel, kernel_size=(1, 1), stride=1)

    def forward(self, x):
        x1 = self.cnn(x)
        x1 = torch.add(x1, x)
        x2 = torch.cat((x1, x), dim=1)
        x2 = self.csa(x2)
        return self.conv(x2)

In [52]:
net = CCMBlock(256*5)
display(net)

CCMBlock(
  (cnn): SEBlock(
    (pool): AdaptiveAvgPool2d(output_size=1)
    (conv1): Conv2d(1280, 80, kernel_size=(1, 1), stride=(1, 1))
    (activ): ReLU(inplace=True)
    (conv2): Conv2d(80, 1280, kernel_size=(1, 1), stride=(1, 1))
    (sigmoid): Sigmoid()
  )
  (csa): ChannelGate(
    (mlp): Sequential(
      (0): Flatten()
      (1): Linear(in_features=2560, out_features=160, bias=True)
      (2): ReLU()
      (3): Linear(in_features=160, out_features=2560, bias=True)
    )
    (conv): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
)

In [53]:
summary(net, (256*5, 128, 128), device="cpu")



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
 AdaptiveAvgPool2d-1           [-1, 1280, 1, 1]               0
            Conv2d-2             [-1, 80, 1, 1]         102,480
              ReLU-3             [-1, 80, 1, 1]               0
            Conv2d-4           [-1, 1280, 1, 1]         103,680
           Sigmoid-5           [-1, 1280, 1, 1]               0
           SEBlock-6       [-1, 1280, 128, 128]               0
           Flatten-7                 [-1, 2560]               0
            Linear-8                  [-1, 160]         409,760
              ReLU-9                  [-1, 160]               0
           Linear-10                 [-1, 2560]         412,160
          Flatten-11                 [-1, 2560]               0
           Linear-12                  [-1, 160]         409,760
             ReLU-13                  [-1, 160]               0
           Linear-14                 [-

In [54]:
print(torch.__version__)

1.13.0
