In [None]:
import torch
import torch.nn as nn


def autopad(k, p=None, d=1):
    """Compute the padding for a convolution, defaulting to "same"-style padding.

    Args:
        k: kernel size, either an int or a sequence of per-dimension sizes.
        p: explicit padding. When not ``None`` it is returned unchanged.
        d: dilation. When greater than 1, the effective kernel size
            ``d * (k - 1) + 1`` is used.

    Returns:
        ``p`` if it was given. Otherwise, half the (effective) kernel size,
        rounded down: an int for an int kernel, a list for a sequence kernel.
        With stride 1 and an odd effective kernel, this keeps the spatial size
        unchanged.
    """
    if d > 1:
        # Dilation spreads the kernel taps apart, enlarging its effective extent.
        if isinstance(k, int):
            k = d * (k - 1) + 1
        else:
            k = [d * (size - 1) + 1 for size in k]
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [size // 2 for size in k]

class SiLU(nn.Module):
    """SiLU (sigmoid-weighted linear unit) activation: ``x * sigmoid(x)``.

    It has no parameters. ``forward`` is a staticmethod, so it can be called as
    either ``SiLU()(x)`` or ``SiLU.forward(x)``.
    """

    @staticmethod
    def forward(x):
        gate = torch.sigmoid(x)
        return x * gate
    
class Conv(nn.Module):
    """Standard convolution unit: Conv2d (no bias), then BatchNorm2d, then activation.

    Args:
        c1: number of input channels.
        c2: number of output channels.
        k: kernel size (int or per-dimension sequence).
        s: stride.
        p: explicit padding. ``None`` means it is derived by ``autopad``.
        g: number of convolution groups.
        d: dilation.
        act: ``True`` selects the shared ``default_act`` (SiLU). An
            ``nn.Module`` is used as given. Any other value disables the
            activation (``nn.Identity``).
    """

    # One SiLU instance shared by every Conv that asks for the default activation.
    # SiLU holds no parameters, so sharing it has no training side effects.
    default_act = SiLU()

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        padding = autopad(k, p, d)
        self.conv = nn.Conv2d(
            c1,
            c2,
            kernel_size=k,
            stride=s,
            padding=padding,
            groups=g,
            dilation=d,
            bias=False,  # BatchNorm supplies the bias term
        )
        self.bn = nn.BatchNorm2d(
            c2, eps=0.001, momentum=0.03, affine=True, track_running_stats=True
        )
        if act is True:
            activation = self.default_act
        elif isinstance(act, nn.Module):
            activation = act
        else:
            activation = nn.Identity()
        self.act = activation

    def forward(self, x):
        """Apply convolution, batch normalization, then the activation."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

    def forward_fuse(self, x):
        """Apply convolution and activation only, skipping ``self.bn``.

        NOTE(review): presumably meant for use after BatchNorm has been folded
        into ``self.conv`` — confirm against the fusing code.
        """
        return self.act(self.conv(x))
    
class CoordAttention(nn.Module):
    """Coordinate attention: reweight a feature map with per-row and per-column gates.

    The input is average-pooled along each spatial axis. The two pooled
    descriptors are concatenated and encoded jointly by ``conv1``, then split
    back apart. Each part is turned into a sigmoid gate by ``conv2``. The
    output is the input multiplied by both gates, so it has the input's shape.

    Args:
        in_channels: channel count of the input (and output) feature map.
        reduction: reduction ratio for the intermediate encoding. The
            intermediate width is ``max(8, in_channels // reduction)``.

    NOTE(review): a single ``conv2`` (1x1 conv + BatchNorm, no activation) is
    shared by the height and width branches, so its BatchNorm is called twice
    per forward pass. The coordinate-attention paper uses separate transforms
    per direction. This is kept as-is to preserve parameters and checkpoint
    keys; confirm it is intentional.
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        # Collapse one spatial axis each: (N, C, H, W) -> (N, C, H, 1) / (N, C, 1, W).
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        # Width of the intermediate encoding; never fewer than 8 channels.
        mip = max(8, in_channels // reduction)

        self.conv1 = Conv(in_channels, mip, 1, 1, act=True)   # joint encoder (default SiLU)
        self.conv2 = Conv(mip, in_channels, 1, 1, act=False)  # gate projection (Identity act)

    def forward(self, x):
        """Return ``x`` (N, C, H, W) scaled by sigmoid gates computed along H and W.

        The result has the same shape as ``x``.
        """
        _, _, h, w = x.size()

        # Directional pooling. The height descriptor is transposed so both
        # descriptors lie along the last axis and can be concatenated.
        x_h = self.pool_h(x).permute(0, 1, 3, 2)  # (N, C, H, 1) -> (N, C, 1, H)
        x_w = self.pool_w(x)                      # (N, C, 1, W)

        # Encode both directions jointly: (N, mip, 1, H + W).
        y = self.conv1(torch.cat([x_h, x_w], dim=3))

        # Split back into the two directions and restore the height layout.
        y_h, y_w = torch.split(y, [h, w], dim=3)
        y_h = y_h.permute(0, 1, 3, 2)  # (N, mip, H, 1)

        a_h = self.conv2(y_h).sigmoid()  # (N, C, H, 1): one gate per row
        a_w = self.conv2(y_w).sigmoid()  # (N, C, 1, W): one gate per column

        # Broadcasting applies the row and column gates at every position.
        return x * a_h * a_w