In [1]:
import torch
import torch.nn as nn

In [38]:
class SRMLayer(nn.Module):
    """Channel-wise gating layer built from per-channel spatial statistics.

    (Presumably the "Style-based Recalibration Module", SRM; confirm against
    the paper this notebook follows.)

    Pipeline, for an input of shape ``(b, c, h, w)``:

    1. Style pooling: the spatial mean and standard deviation of every
       channel are stacked into a ``(b, c, 2)`` tensor.
    2. Style integration: a grouped 1-D convolution (CFC) combines the two
       statistics of each channel into one value, which is batch-normalised
       and squashed by a sigmoid into a gate in ``(0, 1)``.
    3. The input is multiplied channel-wise by that gate.

    Args:
        channel: Number of channels ``c`` of the feature maps fed to
            ``forward``.
    """

    def __init__(self, channel) -> None:
        super().__init__()
        # CFC: channel-wise fully connected layer. With groups=channel and
        # kernel_size=2 every channel owns exactly two weights (one for its
        # mean, one for its std) and channels are never mixed.
        self.cfc = nn.Conv1d(channel, channel, kernel_size=2, bias=False, groups=channel)
        self.bn = nn.BatchNorm1d(channel)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale every channel of ``x`` by its learned gate.

        Args:
            x: 4-D tensor of shape ``(b, c, h, w)``; ``c`` must equal the
                ``channel`` value given to the constructor.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D (raised by the shape unpacking).
        """
        b, c, _, _ = x.size()

        # --- Style pooling ---
        # flatten() falls back to a copy when the spatial dims cannot be
        # merged as a view (cropped / strided / transposed inputs), whereas
        # .view() would raise a RuntimeError on such tensors.
        flat = x.flatten(2)  # (b, c, h*w)
        # Global average pooling.
        mean = flat.mean(dim=-1, keepdim=True)  # (b, c, 1)
        # Global standard-deviation pooling.
        if flat.size(-1) > 1:
            std = flat.std(dim=-1, keepdim=True)  # (b, c, 1)
        else:
            # The unbiased std of a single value is NaN, which would poison
            # the output and the BatchNorm running statistics. A 1x1 map has
            # no spatial spread, so use 0 instead.
            std = torch.zeros_like(mean)
        u = torch.cat((mean, std), dim=-1)  # (b, c, 2)

        # --- Style integration ---
        # CFC: per-channel linear combination of (mean, std).
        z = self.cfc(u)  # (b, c, 1)
        # Batch normalisation.
        z = self.bn(z)
        # Sigmoid turns the normalised score into a gate.
        g = self.sigmoid(z)

        g = g.view(b, c, 1, 1)
        return x * g.expand_as(x)
        
        

In [39]:
class SRM_block(nn.Module):
    """Verbose variant of the SRM gating layer that prints intermediate shapes.

    Performs the same computation as ``SRMLayer`` (per-channel mean/std
    pooling, grouped 1-D convolution, batch norm, sigmoid gate) and prints the
    shapes of the intermediate tensors on every forward pass, for
    demonstration purposes.

    Args:
        channel: Number of channels ``c`` of the feature maps fed to
            ``forward``.
    """

    def __init__(self, channel) -> None:
        super().__init__()
        # CFC: channel-wise fully connected layer, implemented as a grouped
        # 1-D convolution: with groups=channel and kernel_size=2 each channel
        # has its own two weights (for mean and std).
        self.cfc = nn.Conv1d(channel, channel, kernel_size=2, bias=False, groups=channel)
        self.bn = nn.BatchNorm1d(channel)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Gate every channel of ``x`` and print intermediate tensor shapes.

        Args:
            x: 4-D tensor of shape ``(b, c, h, w)``; ``c`` must equal the
                ``channel`` value given to the constructor.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D (raised by the shape unpacking).
        """
        b, c, _, _ = x.size()
        # Shape walk-through below uses the example x = [2, 10, 4, 4].

        # --- Style pooling ---
        # flatten() copies when the spatial dims cannot be merged as a view
        # (cropped / strided / transposed inputs); .view() would raise there.
        flat = x.flatten(2)                      # [2, 10, 16]
        # Global average pooling; keepdim keeps the trailing axis.
        mean = flat.mean(dim=-1, keepdim=True)   # [2, 10, 1]
        print(f'mean : {mean.shape}')
        # Global standard-deviation pooling.
        if flat.size(-1) > 1:
            std = flat.std(dim=-1, keepdim=True)  # [2, 10, 1]
        else:
            # Unbiased std of a single value is NaN; a 1x1 map has no spatial
            # spread, so use 0 to keep the gate and BN statistics finite.
            std = torch.zeros_like(mean)
        u = torch.cat((mean, std), dim=-1)  # (b, c, 2)
        print(f'u : {u.shape}')  # [2, 10, 2]

        # --- Style integration ---
        # CFC: per-channel linear combination of (mean, std).
        z = self.cfc(u)  # (b, c, 1)
        print(f'z : {z.shape}')
        # Batch normalisation.
        z = self.bn(z)
        # Sigmoid turns the normalised score into a gate.
        g = self.sigmoid(z)

        g = g.view(b, c, 1, 1)
        return x * g.expand_as(x)
        

In [40]:
# Smoke test: push a random feature map (batch=2, channels=10, 4x4 spatial)
# through SRM_block and confirm that the gating preserves the input shape.
# Seed the global RNG so both the input and the layer's initial weights are
# the same after every "Restart Kernel -> Run All".
torch.manual_seed(0)
x = torch.randn((2, 10, 4, 4))
model = SRM_block(10)
y = model(x)
assert y.shape == x.shape, "SRM gating must preserve the input shape"
print(f'y shape : {y.shape}')

mean : torch.Size([2, 10, 1])
u : torch.Size([2, 10, 2])
z : torch.Size([2, 10, 1])
y shape : torch.Size([2, 10, 4, 4])
