In [1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F 

In [2]:
""" 
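h_sigmoid is a custom implementation of the hard sigmoid activation,
a piecewise-linear variant of sigmoid whose output is confined to [0, 1].
It is built on the ReLU6 activation to enforce that limit:
ReLU6 clamps its input to the range [0, 6], so ReLU6(x + 3) / 6 is 0 for x <= -3,
1 for x >= 3, and rises linearly in between.
It can be seen as a cheaper stand-in for sigmoid: it needs no exponential,
which gives it better numerical behaviour.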
"""
class h_sigmoid(nn.Module):
    """Hard sigmoid: a piecewise-linear approximation of the logistic sigmoid.

    Evaluates ``ReLU6(x + 3) / 6`` element-wise. ReLU6 clamps to ``[0, 6]``,
    so the result always lies in ``[0, 1]``.

    Args:
        inplace: Forwarded to ``nn.ReLU6``. The clamp runs on the temporary
            ``x + 3``, so the tensor passed to ``forward`` is not overwritten.
    """

    def __init__(self, inplace=True) -> None:
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        """Return the hard sigmoid of ``x``, with the same shape as ``x``."""
        shifted = x + 3  # fresh tensor; safe target for an in-place clamp
        clamped = self.relu(shifted)
        return clamped / 6

In [3]:
""" 
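h_swish is a custom PyTorch implementation of the hard-swish activation.
Swish is defined as the input multiplied by the sigmoid of the input, x * sigmoid(x);
hard swish swaps that sigmoid for a piecewise-linear one, giving x * ReLU6(x + 3) / 6.
This implementation uses h_sigmoid as its building block, a sigmoid variant
whose ReLU6 can run in place to save memory.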
"""
class h_swish(nn.Module):
    """Hard swish activation: ``x * h_sigmoid(x)``.

    The gate is produced by ``h_sigmoid`` and multiplied element-wise with
    the input.

    Args:
        inplace: Forwarded to the internal ``h_sigmoid``. Defaults to
            ``True``, which matches the behaviour when the argument is
            omitted.
    """

    def __init__(self, inplace=True) -> None:
        super().__init__()
        # Pass the caller's flag through instead of hard-coding True, so that
        # h_swish(inplace=False) actually disables the in-place clamp.
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        """Return ``x`` scaled element-wise by its hard-sigmoid gate."""
        return x * self.sigmoid(x)

In [8]:
class CoordAtt(nn.Module):
    """Coordinate-attention gate for a 4-D ``(N, C, H, W)`` feature map.

    The input is average-pooled along each spatial axis separately, the two
    pooled strips are concatenated and passed through a shared 1x1 conv +
    batch norm + ``h_swish`` bottleneck, then split back and turned into two
    sigmoid attention maps (one per axis) that rescale the input.

    Args:
        inp: Number of channels of the input tensor.
        oup: Number of channels of each attention map. The maps are
            multiplied with the input, so this must broadcast against the
            input's channel dimension.
        reduction: Divisor for the bottleneck width, which is
            ``max(8, inp // reduction)``.
    """

    def __init__(self, inp, oup, reduction=22) -> None:
        super().__init__()
        # Output size None keeps that axis; 1 averages it away.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        # Bottleneck width, never narrower than 8 channels.
        mip = max(8, inp // reduction)

        # Shared transform applied to the concatenated H and W strips.
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        # One 1x1 projection per direction, back up to `oup` channels.
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1)

    def forward(self, x):
        """Return ``x`` reweighted by its width and height attention maps."""
        # Unpacking into four names also enforces a 4-D input.
        _, _, height, width = x.size()

        # (N, C, H, 1) and, after the transpose, (N, C, W, 1).
        strip_h = self.pool_h(x)
        strip_w = self.pool_w(x).permute(0, 1, 3, 2)

        # Stack both strips along dim 2 so one bottleneck sees them together.
        joint = torch.cat([strip_h, strip_w], dim=2)
        joint = self.act(self.bn1(self.conv1(joint)))

        # Undo the concatenation, then restore the W strip to (N, mip, 1, W).
        feat_h, feat_w = joint.split([height, width], dim=2)
        feat_w = feat_w.permute(0, 1, 3, 2)

        gate_h = torch.sigmoid(self.conv_h(feat_h))
        gate_w = torch.sigmoid(self.conv_w(feat_w))

        # Broadcasting expands the two maps over the full H x W grid.
        return x * gate_w * gate_h

In [13]:
class CoordAtt_block(nn.Module):
    """Shape-tracing variant of the coordinate-attention module.

    Performs the same computation as ``CoordAtt`` and additionally prints
    the size of every intermediate tensor during ``forward``.

    Args:
        inp: Number of channels of the input tensor.
        oup: Number of channels of each attention map; must broadcast
            against the input's channel dimension.
        reduction: Divisor for the bottleneck width, which is
            ``max(8, inp // reduction)``.
    """

    def __init__(self, inp, oup, reduction=22) -> None:
        super().__init__()
        # Output size None keeps that axis; 1 averages it away.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        # Bottleneck width, never narrower than 8 channels.
        mip = max(8, inp // reduction)

        # Shared transform applied to the concatenated H and W strips.
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        # One 1x1 projection per direction, back up to `oup` channels.
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1)

    def forward(self, x):
        """Return ``x`` reweighted by both attention maps, printing sizes en route."""
        # Unpacking into four names also enforces a 4-D input.
        _, _, h, w = x.size()
        print(f'x : {x.size()}')

        # Directional pooling; the W strip is transposed so both strips can be
        # stacked along dim 2.
        x_h = self.pool_h(x)
        print(f'x_h : {x_h.size()}')
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        print(f'x_w : {x_w.size()}')

        y = torch.cat([x_h, x_w], dim=2)
        print(f'y : {y.size()}')

        # Shared bottleneck: 1x1 conv -> batch norm -> hard swish.
        y = self.act(self.bn1(self.conv1(y)))
        print(f'y_conv1 : {y.size()}')

        # Undo the concatenation and restore the W strip's orientation.
        x_h, x_w = y.split([h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        print(f'x_h : {x_h.size()}')
        print(f'x_w : {x_w.size()}')

        a_h = torch.sigmoid(self.conv_h(x_h))
        a_w = torch.sigmoid(self.conv_w(x_w))
        print(f'a_h : {a_h.size()}')
        print(f'a_w : {a_w.size()}')

        # Broadcasting expands the two maps over the full H x W grid.
        return x * a_w * a_h

In [14]:
# Shape trace: push one random (32, 16, 8, 9) batch through CoordAtt_block so
# that its print statements report the tensor size at every stage of forward().
# H (8) and W (9) differ on purpose, which makes the two directions easy to
# tell apart in the printed sizes.
x = torch.randn(32, 16, 8, 9)
# inp == oup == 16, so the attention maps line up channel-for-channel with x.
model = CoordAtt_block(16, 16)
y = model(x)

x : torch.Size([32, 16, 8, 9])
x_h : torch.Size([32, 16, 8, 1])
x_w : torch.Size([32, 16, 9, 1])
y : torch.Size([32, 16, 17, 1])
y_conv1 : torch.Size([32, 8, 17, 1])
x_h : torch.Size([32, 8, 8, 1])
x_w : torch.Size([32, 8, 1, 9])
a_h : torch.Size([32, 16, 8, 1])
a_w : torch.Size([32, 16, 1, 9])


In [7]:
# Adaptive average pooling target sizes: None leaves that axis at its input
# length, 1 averages the axis down to a single element.
x = torch.randn(16, 8, 3, 4)
y_h = F.adaptive_avg_pool2d(x, (None, 1))  # average over width  -> keeps H
y_w = F.adaptive_avg_pool2d(x, (1, None))  # average over height -> keeps W
y_c = F.adaptive_avg_pool2d(x, 1)          # average over both   -> 1 x 1
print(f'{y_h.size()}\n{y_w.size()}\n{y_c.size()}')

torch.Size([16, 8, 3, 1])
torch.Size([16, 8, 1, 4])
torch.Size([16, 8, 1, 1])
