In [1]:
import torch 
from torch import nn
from torch.nn import functional as F 

In [20]:
class _GridAttentionBlockND(nn.Module):

    def __init__(
        self,
        in_channels,
        gating_channels,
        inter_channels=None,
        dimension=3,
        mode="concatenation",
        sub_sample_factor=(2, 2, 2),
    ) -> None:
        """空间注意力：注意力门

        Args:
            in_channels (_type_): 输入特征的通道数
            gating_channels (_type_): 门控信号 g 的通道数
            inter_channelse (_type_, optional): 中间特征的通道数. Defaults to None.
            dimension (int, optional): 卷积维度. Defaults to 3.
            mode (str, optional): _description_. Defaults to 'concatenation'.
            sub_sample_factor (tuple, optional): 卷积核大小和步长，同比缩小倍数. Defaults to (2, 2, 2).
        """
        super().__init__()

        assert dimension in [2, 3]
        assert mode in [
            "concatenation",
            "concatenation_debug",
            "concatenation_residual",
        ]

        # downsampling rate for the input featuremap
        if isinstance(sub_sample_factor, tuple):
            self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list):
            self.sub_sample_factor = tuple(sub_sample_factor)
        else:
            self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = "trilinear"
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = "bilinear"
        else:
            raise NotImplemented

        # output transform
        self.W = nn.Sequential(
            conv_nd(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            bn(self.in_channels),
        )

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        # theta 需要进行上采样，将 W_x 变为 W_g
        self.W_x = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=self.sub_sample_kernel_size,
            stride=self.sub_sample_factor,
            padding=0,
            bias=False,
        )
        # phi
        self.W_g = conv_nd(
            in_channels=self.gating_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        # inter
        self.psi = conv_nd(
            in_channels=self.inter_channels,
            out_channels=1,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Define the operation
        if mode == "concatenation":
            self.operation_function = self._concatenation
        elif mode == "concatenation_debug":
            self.operation_function = self._concatenation_debug
        elif mode == "concatenation_residual":
            self.operation_function = self._concatenation_residual
        else:
            raise NotImplementedError("Unknown operation function.")

    def forward(self, x, g):
        """
        x: (b, c, t, h, w)
        g: (b, c_g, t', h', w')
        """
        output = self.operation_function(x, g)
        return output

    def _concatenation(self, x, g):
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        # theta => (b, c, t, h, w) -> -> (b, i_c, t/s1, h/s2, w/s3)
        theta_x = self.W_x(x)
        theta_x_size = theta_x.size()

        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
        phi_g = self.W_g(g)
        # phi_g (b, i_c, t', h', w') -> phi_g (b, i_c, t/s1, h/s2, w/s3) ps: h'=h/s1
        phi_g = F.upsample(phi_g, size=theta_x_size[2:], mode=self.upsample_mode)
        # ReLU激活
        f = F.relu(theta_x + phi_g, inplace=True)

        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
        sigm_psi_f = F.sigmoid(self.psi(f))

        # upsample the attentions and multiply
        # 进行上采样，将 h/s1 恢复到 h
        sigm_psi_f = F.upsmaple(
            sigm_psi_f, size=input_size[2:], model=self.upsample_mode
        )
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)

        return W_y, sigm_psi_f

    def _concatenation_debug(self, x, g):
        input_size = x.size()
        # input_size : torch.Size([16, 16, 16, 16])
        print(f'input_size : {input_size}')
        batch_size = input_size[0]
        assert batch_size == g.size(0)
        
        # theta => (b, c, t, h, w) -> -> (b, i_c, t/s1, h/s2, w/s3)
        theta_x = self.W_x(x)
        theta_x_size = theta_x.size()
        # theta_x_size : torch.Size([16, 8, 8, 8])
        print(f'theta_x_size : {theta_x_size}')

        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
        phi_g = self.W_g(g)
        # phi_g (b, i_c, t', h', w') -> phi_g (b, i_c, t/s1, h/s2, w/s3) ps: h'=h/s1
        phi_g = F.upsample(phi_g, size=theta_x_size[2:], mode=self.upsample_mode)
        # phi_g : torch.Size([16, 8, 8, 8])
        print(f'phi_g : {phi_g.size()}')
        # ReLU激活
        f = F.softplus(theta_x + phi_g)

        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
        sigm_psi_f = F.sigmoid(self.psi(f))

        # upsample the attentions and multiply
        # 进行上采样，将 h/s1 恢复到 h
        sigm_psi_f = F.upsample(
            sigm_psi_f, size=input_size[2:], mode=self.upsample_mode
        )
        # sigm_psi_f : torch.Size([16, 1, 16, 16])
        print(f'sigm_psi_f : {sigm_psi_f.size()}')
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)

        return W_y, sigm_psi_f
    
    def _concatenation_residual(self, x, g):
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        # theta => (b, c, t, h, w) -> -> (b, i_c, t/s1, h/s2, w/s3)
        theta_x = self.W_x(x)
        theta_x_size = theta_x.size()

        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
        phi_g = self.W_g(g)
        # phi_g (b, i_c, t', h', w') -> phi_g (b, i_c, t/s1, h/s2, w/s3) ps: h'=h/s1
        phi_g = F.upsample(phi_g, size=theta_x_size[2:], mode=self.upsample_mode)
        # ReLU激活
        f = F.relu(theta_x + phi_g, inplace=True)

        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
        f = self.psi(f).view(batch_size, 1, -1)
        sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x_size[2:])

        # upsample the attentions and multiply
        # 进行上采样，将 h/s1 恢复到 h
        sigm_psi_f = F.upsmaple(
            sigm_psi_f, size=input_size[2:], model=self.upsample_mode
        )
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)

        return W_y, sigm_psi_f

In [21]:
class GridAttentionBlock2D(_GridAttentionBlockND):
    """2-D attention gate: the N-D block fixed to ``dimension=2``.

    Uses Conv2d / BatchNorm2d and bilinear resizing; ``sub_sample_factor``
    defaults to a 2x2 downsampling of the input features.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2, 2)):
        super().__init__(
            in_channels,
            gating_channels=gating_channels,
            inter_channels=inter_channels,
            dimension=2,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
        )


class GridAttentionBlock3D(_GridAttentionBlockND):
    """3-D attention gate: the N-D block fixed to ``dimension=3``.

    Uses Conv3d / BatchNorm3d and trilinear resizing; ``sub_sample_factor``
    defaults to a 2x2x2 downsampling of the input features.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2, 2, 2)):
        super().__init__(
            in_channels,
            gating_channels=gating_channels,
            inter_channels=inter_channels,
            dimension=3,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
        )

In [22]:
# Smoke test: run the 2-D gate in debug mode (prints intermediate shapes).
# Seeded so a Restart & Run All reproduces the same tensors.
torch.manual_seed(0)
x = torch.randn(16, 16, 16, 16)  # input features (b, c, h, w)
g = torch.randn(16, 12, 8, 8)    # gating signal at half the spatial resolution
model = GridAttentionBlock2D(
    in_channels=16, gating_channels=12, inter_channels=8, mode="concatenation_debug"
)
y = model(x, g)  # tuple: (gated output W_y, attention map)
assert y[0].shape == x.shape
assert tuple(y[1].shape) == (16, 1, 16, 16)

input_size : torch.Size([16, 16, 16, 16])
theta_x_size : torch.Size([16, 8, 8, 8])
phi_g : torch.Size([16, 8, 8, 8])
sigm_psi_f : torch.Size([16, 1, 16, 16])


