# layer

## ConvAftermath

In [None]:
class ConvAftermath(nn.Module):
    """Post-convolution stage: optional bias, optional scale, then norm and activation.

    ``forward`` applies, in this order:
      1. ``net + self.b``  when ``use_bias`` is true and ``self.b`` is not None,
      2. ``net * self.s``  when ``use_scale`` is true and ``self.s`` is not None,
      3. ``self.norm(net)`` when ``norm`` is given,
      4. ``self.act(net)``  when ``act`` is given.

    NOTE: ``self.b`` and ``self.s`` are initialised to ``None`` and are never
    assigned inside this class, so steps 1 and 2 only run if a caller assigns
    them (e.g. tensors broadcastable against the input).

    Args:
        in_channels: Stored for reference; not used in the computation.
        out_channels: Stored for reference; not used in the computation.
        use_bias: Enable adding the bias tensor ``self.b``.
        use_scale: Enable multiplying by the scale tensor ``self.s``.
        norm: Optional callable (e.g. an ``nn.Module``) applied after bias/scale.
        act: Optional activation callable applied last.
    """

    def __init__(self, in_channels, out_channels, use_bias=True, use_scale=True, norm=None, act=None):
        super(ConvAftermath, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_bias = use_bias
        self.use_scale = use_scale
        self.norm = norm
        self.act = act
        # Optional bias / scale tensors; None disables the corresponding step.
        self.b = None
        self.s = None

    def forward(self, input):
        net = input
        if self.use_bias and self.b is not None:
            # Add the bias tensor itself (previously the boolean flag was added, i.e. +1).
            net = net + self.b
        if self.use_scale and self.s is not None:
            net = net * self.s
        if self.norm is not None:
            net = self.norm(net)
        if self.act is not None:
            net = self.act(net)
        return net

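疑问：`self.b` 与 `self.s` 在 `__init__` 中被初始化为 `None`，且在本类中从未被重新赋值，因此 `forward` 中的偏置与缩放分支不会执行，`use_bias` / `use_scale` 目前不起作用——是否应将它们定义为可学习参数（`nn.Parameter`）？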

## Conv2D_ReflectPad

`padding_mode` 可选值：`zeros`（常量 0 填充，默认）、`reflect`（反射填充）、`replicate`（复制填充）、`circular`（循环填充）。

In [None]:
class Conv2D_ReflectPad(nn.Module):
    """Bias-free 2-D convolution followed by a ``ConvAftermath`` stage.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        use_bias: Forwarded to ``ConvAftermath``.
        use_scale: Forwarded to ``ConvAftermath``.
        norm: Optional normalisation, forwarded to ``ConvAftermath``.
        act: Optional activation, forwarded to ``ConvAftermath``.
        padding: ``'same'`` or any padding value accepted by ``nn.Conv2d``.
            ``'same'`` resolves to ``kernel_size // 2`` when ``stride == 1`` and
            to ``0`` otherwise, and also sets ``self.pad_flag = True``.
        padding_algorithm: Passed to ``nn.Conv2d`` as ``padding_mode``
            (``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, use_bias=True, use_scale=True, norm=None,
                 act=None,
                 padding='same', padding_algorithm="reflect"):
        super(Conv2D_ReflectPad, self).__init__()
        # act / norm are assigned first so that, when they are modules,
        # nn.Module registers them ahead of the conv layers.
        self.act = act
        self.norm = norm
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.strides = stride
        self.use_bias = use_bias
        self.use_scale = use_scale
        self.padding_algorithm = padding_algorithm
        self.atrous_rate = 1  # not used by this class

        wants_same = padding == 'same'
        if wants_same:
            # kernel_size // 2 keeps the spatial size for odd kernels at stride 1;
            # strided convolutions get no padding at all.
            resolved_padding = kernel_size // 2 if stride == 1 else 0
            self.pad_flag = True
        else:
            resolved_padding = padding
        self.padding = resolved_padding

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
            padding=resolved_padding,
            padding_mode=padding_algorithm,
        )
        self.conv_aftermath = ConvAftermath(
            in_channels=out_channels,
            out_channels=out_channels,
            use_bias=use_bias,
            use_scale=use_scale,
            norm=norm,
            act=act,
        )

    def forward(self, input):
        return self.conv_aftermath(self.conv(input))

参考：
+ padding_mode: https://blog.csdn.net/weixin_42211626/article/details/122542323

# 模块

## DCB

<div align=center>
<img src="./img/DCB.png" />
</div>

In [None]:
class AtrousBlockPad2(nn.Module):
    """Dilated convolution block (DCB): parallel dilated branches, a fusion conv, and a residual.

    The input is optionally projected to ``out_channels`` by a 1x1 conv. It is then fed,
    in parallel, to one ``AtrousConv2D_ReflectPad`` branch per dilation rate in
    ``atrousBlock``; each branch outputs ``out_channels // 2`` channels. The branch
    outputs are concatenated on dim 1, fused back to ``out_channels`` by ``conv1``, and
    added to the (projected) input.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels of the block output (and of every branch input).
        kernel_size: Kernel size of the branch convs and of ``conv1``.
        stride: Stride of the projection, every branch and ``conv1``. With stride > 1
            every layer downsamples, so the residual shapes do not line up; only
            stride=1 is used in this file.
        use_bias: Forwarded to the conv layers.
        use_scale: Forwarded to the branch convs and ``conv1``.
        activation: Activation applied by every conv layer.
        needs_projection: Add the 1x1 projection ``in_channels -> out_channels``.
            Required whenever the two differ, otherwise the residual add fails.
        atrousBlock: Dilation rate of each branch; the number of branches equals
            ``len(atrousBlock)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, use_bias, use_scale, activation,
                 needs_projection=False, atrousBlock=[1, 2, 4, 8]):
        super(AtrousBlockPad2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.use_bias = use_bias
        self.use_scale = use_scale
        self.activation = activation
        self.atrousBlocks = atrousBlock
        # NOTE: despite its name this is True when the channel counts DIFFER.
        # Kept for attribute compatibility; it is not used by this class.
        self.dims_match = self.in_channels != self.out_channels
        self.needs_projection = needs_projection

        if self.needs_projection:
            self.projection = Conv2D_ReflectPad(in_channels=self.in_channels,
                                                out_channels=self.out_channels,
                                                kernel_size=1,
                                                stride=self.stride,
                                                use_bias=self.use_bias,
                                                act=self.activation)

        # One branch per dilation rate (previously hard-coded to 4, which raised
        # IndexError for shorter lists and silently ignored extra rates).
        branch_channels = self.out_channels // 2
        num_branches = len(self.atrousBlocks)
        # Kept as nn.Sequential for state_dict/attribute compatibility, but the branches
        # are run in parallel in forward(), never chained.
        self.atrous_layers = nn.Sequential(*[
            AtrousConv2D_ReflectPad(in_channels=self.out_channels,
                                    out_channels=branch_channels,
                                    kernel_size=self.kernel_size,
                                    stride=self.stride,
                                    dilation=rate,
                                    use_bias=self.use_bias,
                                    use_scale=self.use_scale,
                                    act=self.activation)
            for rate in self.atrousBlocks
        ])

        # Fusion conv input = concatenation of all branches. This equals out_channels * 2
        # for 4 branches and even out_channels, and stays correct otherwise.
        self.conv1 = Conv2D_ReflectPad(in_channels=branch_channels * num_branches,
                                       out_channels=self.out_channels,
                                       kernel_size=self.kernel_size,
                                       stride=self.stride,
                                       use_bias=self.use_bias,
                                       use_scale=self.use_scale,
                                       act=self.activation
                                       )

    def forward(self, input):
        if self.needs_projection:
            input = self.projection(input)

        # Every branch sees the same input; outputs are concatenated along channels.
        x = torch.cat([layer(input) for layer in self.atrous_layers], 1)
        x5 = self.conv1(x)

        # Residual connection around the dilated branches.
        return input + x5

## 小波变换模块

<div align=center>
<img src="./img/WRRM.png" />
</div>

In [None]:
def dwt_init(x):
    """Single-level 2-D Haar DWT with the four sub-bands stacked along dim 0.

    Args:
        x: 4-D tensor indexed as ``[batch, channel, row, col]``. Row and column
            counts are presumably expected to be even; otherwise the even/odd
            slices below differ in size. TODO confirm with callers.

    Returns:
        ``torch.cat((LL, HL, LH, HH), 0)``. Each sub-band has half the rows and
        columns of ``x``, so the result has four times the batch size.
    """
    even_rows = x[:, :, 0::2, :] / 2
    odd_rows = x[:, :, 1::2, :] / 2

    # The four polyphase components of every 2x2 neighbourhood (pre-scaled by 1/2).
    top_left = even_rows[:, :, :, 0::2]
    bottom_left = odd_rows[:, :, :, 0::2]
    top_right = even_rows[:, :, :, 1::2]
    bottom_right = odd_rows[:, :, :, 1::2]

    sub_bands = (
        top_left + bottom_left + top_right + bottom_right,    # LL
        -top_left - bottom_left + top_right + bottom_right,   # HL
        -top_left + bottom_left - top_right + bottom_right,   # LH
        top_left - bottom_left - top_right + bottom_right,    # HH
    )
    return torch.cat(sub_bands, 0)


# Inverse single-level 2-D discrete Haar wavelet transform.
def iwt_init(x):
    """Inverse single-level 2-D Haar transform (the inverse of ``dwt_init``).

    Args:
        x: 4-D tensor ``[batch, channel, height, width]`` whose dim 0 stacks four
            equally sized sub-band groups in the order LL, HL, LH, HH (the layout
            ``dwt_init`` produces). ``batch`` is expected to be a multiple of 4;
            any remainder is ignored by the slicing below.

    Returns:
        Tensor of shape ``[batch // 4, channel, 2 * height, 2 * width]`` on the same
        device as ``x``, with the dtype of ``x / 2``.
    """
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    out_batch = in_batch // r ** 2
    out_channel = in_channel
    out_height, out_width = r * in_height, r * in_width

    x1 = x[0:out_batch, :, :, :] / 2
    x2 = x[out_batch:out_batch * 2, :, :, :] / 2
    x3 = x[out_batch * 2:out_batch * 3, :, :, :] / 2
    x4 = x[out_batch * 3:out_batch * 4, :, :, :] / 2

    # Allocate on the input's device and dtype. The former
    # `torch.zeros(...).float().cuda()` crashed on CPU-only machines and forced
    # float32 regardless of the input dtype.
    h = x1.new_zeros((out_batch, out_channel, out_height, out_width))

    # Recombine the sub-bands into the four positions of each 2x2 neighbourhood.
    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

    return h

# 2-D discrete (Haar) wavelet transform wrapped as an nn.Module.
class DWT(nn.Module):
    """Module wrapper around ``dwt_init``; it has no learnable parameters."""

    def __init__(self):
        super().__init__()
        # Plain attribute only: it does not change autograd behaviour (nn.Module's
        # switch is the method ``requires_grad_``). The module has no parameters anyway.
        self.requires_grad = False

    def forward(self, x):
        coefficients = dwt_init(x)
        return coefficients


# Inverse 2-D discrete (Haar) wavelet transform wrapped as an nn.Module.
class IWT(nn.Module):
    """Module wrapper around ``iwt_init``; it has no learnable parameters."""

    def __init__(self):
        super().__init__()
        # Plain attribute only; it does not affect autograd.
        self.requires_grad = False

    def forward(self, x):
        reconstruction = iwt_init(x)
        return reconstruction
class SRCNN(nn.Module):
    """SRCNN-style three-conv head applied in the Haar wavelet domain.

    ``forward`` runs ``DWT`` (sub-bands stacked on dim 0), then
    conv9x9 (-> 64) + ReLU, conv5x5 (-> 32) + ReLU, conv5x5 (-> out_channels),
    then ``IWT`` back to the spatial domain. Every conv uses zero padding of
    ``kernel_size // 2``, so the sub-band spatial size is preserved.

    Args:
        num_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
    """

    def __init__(self, num_channels, out_channels):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, out_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)
        self.DWT = DWT()
        self.IDWT = IWT()

    def forward(self, x):
        bands = self.DWT(x)
        features = self.relu(self.conv1(bands))
        features = self.relu(self.conv2(features))
        return self.IDWT(self.conv3(features))

小波变换的作用：将特征分解为低频子带（LL）与高频子带（HL、LH、HH），以便获取高频纹理信息。

## efficientattention


<div align=center>
<img src="./img/efficientattention.png" width="1000" />
</div>


In [None]:
class EfficientAttention(nn.Module):
    """Multi-head efficient (linear-complexity) attention with a residual connection.

    For each head, keys are softmax-normalised over spatial positions and queries
    over channels. The head first forms a ``key @ value^T`` context of size
    ``head_key_channels x head_value_channels`` and then applies it to the
    queries, so the cost is linear in ``h * w``. The concatenated heads are
    projected back to ``in_channels`` by a 1x1 conv and added to the input.

    Args:
        in_channels: Channels of the input (and of the output).
        key_channels: Total key/query channels, split evenly across heads. Any
            remainder after integer division by ``head_count`` is ignored.
        head_count: Number of heads.
        value_channels: Total value channels, split evenly across heads. Must be
            divisible by ``head_count``; otherwise the concatenated heads do not
            match the input channels of ``reprojection``.
    """

    def __init__(self, in_channels, key_channels, head_count, value_channels):
        super().__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels

        # 1x1 projections (creation order fixed: keys, queries, values, reprojection).
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, input_):
        batch, _, height, width = input_.size()
        positions = height * width
        all_keys = self.keys(input_).reshape(batch, self.key_channels, positions)
        all_queries = self.queries(input_).reshape(batch, self.key_channels, positions)
        all_values = self.values(input_).reshape(batch, self.value_channels, positions)
        per_head_k = self.key_channels // self.head_count
        per_head_v = self.value_channels // self.head_count

        heads = []
        for head in range(self.head_count):
            k_span = slice(head * per_head_k, (head + 1) * per_head_k)
            v_span = slice(head * per_head_v, (head + 1) * per_head_v)
            key = f.softmax(all_keys[:, k_span, :], dim=2)        # normalise over positions
            query = f.softmax(all_queries[:, k_span, :], dim=1)   # normalise over channels
            value = all_values[:, v_span, :]
            context = key @ value.transpose(1, 2)                 # [batch, per_head_k, per_head_v]
            head_out = context.transpose(1, 2) @ query
            heads.append(head_out.reshape(batch, per_head_v, height, width))

        return self.reprojection(torch.cat(heads, dim=1)) + input_

高效注意力（Efficient Attention）是一种注意力机制，它大幅降低了内存与计算开销（复杂度关于像素数 h·w 为线性），同时保持了与传统点积注意力相当的表达能力：在采用缩放归一化时二者在数学上等价；本实现采用 softmax 归一化，是对点积注意力的近似。上图比较了这两种注意力。优点：1. 以更少的资源达到相同的准确率；2. 在相同资源下取得更高的性能。

参考文献：
+ https://github.com/cmsflash/efficient-attention
+ https://www.bilibili.com/video/BV1Gt4y1Y7E3
+ Efficient Attention: Attention with Linear Complexities

## ContextBlock

<div align=center>
<img src="./img/Contextblock.png" width="500" />
</div>


In [None]:
class ContextBlock(nn.Module):
    """Global context (GC) block from GCNet.

    A per-sample context vector ``[N, C, 1, 1]`` is pooled in one of two ways:
    - attention pooling: a 1x1 conv gives one logit per position, softmax runs over
      positions, and the result is a weighted sum of features;
    - global average pooling.

    The context is then fused into the input through a bottleneck transform
    (1x1 conv -> ReLU -> 1x1 conv, ``inplanes -> int(inplanes * ratio) -> inplanes``;
    the LayerNorm of the reference GCNet transform is disabled here):
    - ``channel_mul`` rescales the input by ``sigmoid(transform(context))``;
    - ``channel_add`` adds ``transform(context)``.
    When both are enabled, the multiplication is applied first.

    Args:
        inplanes: Channels of the input.
        ratio: Bottleneck ratio for the transform's hidden width.
        pooling_type: ``'att'`` (attention pooling) or ``'avg'``.
        fusion_types: Non-empty list/tuple drawn from ``'channel_add'`` and
            ``'channel_mul'``.

    Raises:
        AssertionError: on an invalid ``pooling_type`` or ``fusion_types``.
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        allowed_fusions = ('channel_add', 'channel_mul')
        assert all(kind in allowed_fusions for kind in fusion_types)
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Creation order (add, then mul) is kept so parameter initialisation is unchanged.
        self.channel_add_conv = self._make_transform() if 'channel_add' in fusion_types else None
        self.channel_mul_conv = self._make_transform() if 'channel_mul' in fusion_types else None
        self.reset_parameters()

    def _make_transform(self):
        """Build the bottleneck transform: 1x1 conv -> ReLU -> 1x1 conv."""
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        for transform in (self.channel_add_conv, self.channel_mul_conv):
            if transform is not None:
                last_zero_init(transform)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type != 'att':
            return self.avg_pool(x)  # [N, C, 1, 1]

        # Features as [N, 1, C, H * W].
        features = x.view(batch, channel, height * width).unsqueeze(1)
        # Attention logits [N, 1, H * W] -> softmax over positions -> [N, 1, H * W, 1].
        logits = self.conv_mask(x).view(batch, 1, height * width)
        weights = self.softmax(logits).unsqueeze(-1)
        # Weighted sum over positions: [N, 1, C, 1] -> [N, C, 1, 1].
        return torch.matmul(features, weights).view(batch, channel, 1, 1)

    def forward(self, x):
        context = self.spatial_pool(x)  # [N, C, 1, 1]

        fused = x
        if self.channel_mul_conv is not None:
            fused = fused * torch.sigmoid(self.channel_mul_conv(context))
        if self.channel_add_conv is not None:
            fused = fused + self.channel_add_conv(context)
        return fused


GCNet 充分结合了 Non-local 网络全局上下文建模能力强与 SENet 计算量小的优点。

参考:
+ https://github.com/xvjiarui/GCNet
+ GCNet: Non-Local Networks Meet Squeeze-Excitation Networks and Beyond

# Model

## basic model

<div align=center>
<img src="./img/netstructure.png" />
</div>

In [None]:
class AtrousNet(nn.Module):
    """Encoder / dilated-block / decoder network with a global residual connection.

    Channel flow (in units of ``d_mult``):
      * downsampling_layers: 7x7 conv (in_channels -> 1) at full resolution, then a
        3x3 stride-2 conv (1 -> 2). Both outputs are kept as skip connections.
      * blocks: ``num_blocks`` dilated conv blocks (``AtrousBlockPad2``). The first
        projects 2 -> 4 and the others keep 4; all use the dilation rates ``atrousDim``.
      * upsampling: before each upsampling layer, the next skip connection (deepest
        first) is concatenated. The transposed conv therefore maps 4 + 2 = 6 -> 2,
        back at full resolution, and the ELU is applied to 2 + 1 = 3 channels.
      * output_layer: 3x3 conv 3 -> 1, then 3x3 conv 1 -> out_channels (no activation).
    ``forward`` returns ``input_data + net``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the network output. Should equal ``in_channels``
            (or broadcast against it) because of the global residual sum.
        num_blocks: Number of dilated conv blocks. Must be at least 1; otherwise the
            transposed conv receives (2 + 2) * d_mult instead of 6 * d_mult channels.
        max_global_stride: Stored only; not used by this class.
        pad_to_fit_global_stride: Stored only; not used by this class.
        d_mult: Base channel width.
        activation: Activation module shared by all conv layers. The default
            ``nn.ELU`` is created once, when the class is defined, and is shared by
            every instance that uses the default.
        atrousDim: Dilation rates passed to every ``AtrousBlockPad2``.

    NOTE(review): input height/width presumably need to be even. Otherwise the
    stride-2 conv followed by the stride-2 transposed conv (output_padding=1)
    returns one extra row/column and the skip concatenation fails.
    """
    def __init__(self, in_channels, out_channels, num_blocks=10, max_global_stride=8, pad_to_fit_global_stride=True,
                 d_mult=16,
                 activation=nn.ELU(alpha=1.0, inplace=True),
                 atrousDim=[1, 2, 4, 8]):
        super(AtrousNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.max_global_stride = max_global_stride
        self.pad_to_fit_global_stride = pad_to_fit_global_stride
        self.d_mult = d_mult
        self.activation = activation

        # Encoder: full-resolution 7x7 conv, then a 3x3 stride-2 conv (half resolution).
        self.downsampling_layers = []
        self.downsampling_layers.append(Conv2D_ReflectPad(in_channels=self.in_channels,
                                                          out_channels=self.d_mult,
                                                          kernel_size=7,
                                                          stride=1,
                                                          use_bias=True,
                                                          use_scale=True,
                                                          padding="same",
                                                          act=self.activation))
        self.downsampling_layers.append(Conv2D_ReflectPad(in_channels=self.d_mult,
                                                          out_channels=self.d_mult * 2,
                                                          kernel_size=3,
                                                          stride=2,
                                                          use_scale=True,
                                                          use_bias=True,
                                                          padding=1,
                                                          act=self.activation))
        self.downsampling_layers = nn.Sequential(*self.downsampling_layers)

        # Dilated conv blocks at half resolution; only the first one changes channels (2 -> 4).
        self.blocks = []
        for x in range(num_blocks):
            if x == 0:
                self.blocks.append(AtrousBlockPad2(in_channels=self.d_mult * 2,
                                                   out_channels=self.d_mult * 4,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   activation=self.activation,
                                                   atrousBlock=atrousDim,
                                                   needs_projection=True))
            else:
                self.blocks.append(AtrousBlockPad2(in_channels=self.d_mult * 4,
                                                   out_channels=self.d_mult * 4,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   activation=self.activation,
                                                   atrousBlock=atrousDim))

        self.blocks = nn.Sequential(*self.blocks)

        # Decoder: each layer here consumes one skip connection (see forward()).
        self.upsampling_layers = []
        self.upsampling_layers.append(nn.ConvTranspose2d(in_channels=self.d_mult * 6,  # 6 = 4 (blocks) + 2 (skip concatenated in forward before this layer)
                                                         out_channels=self.d_mult * 2,
                                                         kernel_size=3,
                                                         stride=2,
                                                         padding=1,
                                                         output_padding=1,
                                                         bias=True))
        self.upsampling_layers.append(nn.ELU(alpha=1.0, inplace=True))
        self.upsampling_layers = nn.Sequential(*self.upsampling_layers)

        # Output head: 3 = 2 (transposed conv) + 1 (full-resolution skip).
        self.output_layer = []
        self.output_layer.append(Conv2D_ReflectPad(in_channels=self.d_mult * 3,
                                                   out_channels=self.d_mult,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   padding='same',
                                                   act=self.activation))
        self.output_layer.append(Conv2D_ReflectPad(in_channels=self.d_mult,
                                                   out_channels=self.out_channels,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   padding='same',
                                                   act=None))
        self.output_layer = nn.Sequential(*self.output_layer)

    def forward(self, input_data):
        # Encoder: keep every intermediate output as a skip connection.
        downs = []
        net = input_data
        for x in range(len(self.downsampling_layers)):
            net = self.downsampling_layers[x](net)
            downs.append(net)

        for x in range(len(self.blocks)):
            net = self.blocks[x](net)

        # Decoder: for each upsampling layer, concatenate the next skip (deepest first)
        # on the channel dim *before* applying that layer.
        for x in range(len(self.upsampling_layers)):
            idx = len(downs) - x - 1
            net = torch.cat((net, downs[idx]), 1)
            net = self.upsampling_layers[x](net)

        for x in range(len(self.output_layer)):
            net = self.output_layer[x](net)

        # Global residual connection around the whole network.
        return input_data + net

In [None]:
class AtrousNet_SRCNN_tail(nn.Module):
    """AtrousNet variant with a bilinear decoder, optional attention and a wavelet SRCNN tail.

    Channel flow (in units of ``d_mult``):
      * downsampling_layers: 7x7 conv (in_channels -> 1) at full resolution, then a
        3x3 stride-2 conv (1 -> 2). Both outputs are kept as skip connections.
      * blocks: ``num_blocks`` dilated conv blocks (DCB, ``AtrousBlockPad2``).
        - The first projects 2 -> 4 with dilation rates ``atrousDim[0]``.
        - Middle blocks keep 4 channels with ``atrousDim[0]``.
        - The last block uses ``atrousDim[1]``, but only when ``num_blocks > 1``.
      * upsampling: before each upsampling layer, the next skip (deepest first) is
        concatenated: 4 + 2 = 6 -> bilinear x2 upsample, then 6 + 1 = 7 -> ELU.
      * SRCNN tail: ``SRCNN(7 -> out_channels)`` applied to the 7-channel features.
      * output_layer: 3x3 conv 7 -> 1, optional attention on 1, 3x3 conv 1 -> out_channels.
    ``forward`` returns ``input_data + output_layer(net) + SRCNN(net)``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output. Should equal ``in_channels`` (or broadcast
            against it) because of the global residual sum.
        num_blocks: Number of DCBs. Must be at least 1 for the channel counts above to hold.
        max_global_stride: Stored only; not used by this class.
        pad_to_fit_global_stride: Stored only; not used by this class.
        d_mult: Base channel width.
        activation: Activation module shared by all conv layers. The default ``nn.ELU``
            is created once, when the class is defined, and is shared by every instance.
        atrousDim: Two dilation-rate lists:
            ``[rates for all blocks but the last, rates for the last block]``.
        efficientattention: Insert ``EfficientAttention`` (4 heads) in the output layer.
        gcattention: Insert a GCNet ``ContextBlock`` (ratio 0.25) in the output layer.
            Ignored when ``efficientattention`` is also true.

    NOTE(review): input height/width presumably need to be even. The stride-2 conv
    followed by the x2 upsample otherwise yields one extra row/column, so the skip
    concatenation fails. The SRCNN tail's Haar DWT also splits rows/columns into
    even/odd halves.
    """
    def __init__(self,
                in_channels,
                out_channels,
                num_blocks=10,
                max_global_stride=8,
                pad_to_fit_global_stride=True,
                d_mult=32,
                activation=nn.ELU(alpha=1.0, inplace=True),
                atrousDim=[[1, 2, 4, 8],[1, 3, 5, 7]],
                efficientattention=False,
                gcattention=False,
                ):
        super(AtrousNet_SRCNN_tail, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.max_global_stride = max_global_stride
        self.pad_to_fit_global_stride = pad_to_fit_global_stride
        self.d_mult = d_mult
        self.activation = activation
        self.efficientattention = efficientattention
        self.gcattention = gcattention

        # Encoder: full-resolution 7x7 conv, then a 3x3 stride-2 conv (half resolution).
        self.downsampling_layers = []
        self.downsampling_layers.append(Conv2D_ReflectPad(in_channels=self.in_channels,
                                                          out_channels=self.d_mult,
                                                          kernel_size=7,
                                                          stride=1,
                                                          use_bias=True,
                                                          use_scale=True,
                                                          padding="same",
                                                          act=self.activation))
        self.downsampling_layers.append(Conv2D_ReflectPad(in_channels=self.d_mult,
                                                          out_channels=self.d_mult * 2,
                                                          kernel_size=3,
                                                          stride=2,
                                                          use_scale=True,
                                                          use_bias=True,
                                                          padding=1,
                                                          act=self.activation))
        self.downsampling_layers = nn.Sequential(*self.downsampling_layers)
        # Wavelet-domain SRCNN tail on the 7 * d_mult decoder features.
        self.SRCNN = SRCNN(self.d_mult*7, self.out_channels)
        # Dilated conv blocks: first projects 2 -> 4, last one uses the second dilation set.
        self.blocks = []
        for x in range(num_blocks):
            if x == 0:
                self.blocks.append(AtrousBlockPad2(in_channels=self.d_mult * 2,
                                                   out_channels=self.d_mult * 4,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   activation=self.activation,
                                                   atrousBlock=atrousDim[0],
                                                   needs_projection=True))
            elif x != num_blocks-1:
                self.blocks.append(AtrousBlockPad2(in_channels=self.d_mult * 4,
                                                   out_channels=self.d_mult * 4,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   activation=self.activation,
                                                   atrousBlock=atrousDim[0]))
            else:
                self.blocks.append(AtrousBlockPad2(in_channels=self.d_mult * 4,
                                                   out_channels=self.d_mult * 4,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   activation=self.activation,
                                                   atrousBlock=atrousDim[1]))

        self.blocks = nn.Sequential(*self.blocks)

        # Decoder: parameter-free. Each layer consumes one skip connection in forward(),
        # so channels grow 4 -> 6 (upsample) -> 7 (ELU) instead of being reduced.
        self.upsampling_layers = []
        self.upsampling_layers.append(nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=True))
        self.upsampling_layers.append(nn.ELU(alpha=1.0, inplace=True))
        self.upsampling_layers = nn.Sequential(*self.upsampling_layers)

        self.output_layer = []
        # 7 * d_mult = 4 (blocks) + 2 + 1 (both skip connections concatenated in forward).
        self.output_layer.append(Conv2D_ReflectPad(in_channels=self.d_mult * 7,
                                                   out_channels=self.d_mult,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   padding='same',
                                                   act=self.activation))
        # Optional attention; efficient attention takes precedence when both flags are set.
        if self.efficientattention:
            self.output_layer.append(EfficientAttention(in_channels=self.d_mult, key_channels=self.d_mult, head_count=4, value_channels=self.d_mult))
        elif self.gcattention:
            self.output_layer.append(ContextBlock(inplanes=self.d_mult, ratio=0.25))
        self.output_layer.append(Conv2D_ReflectPad(in_channels=self.d_mult,
                                                   out_channels=self.out_channels,
                                                   kernel_size=3,
                                                   stride=1,
                                                   use_bias=True,
                                                   use_scale=True,
                                                   padding='same',
                                                   act=None))
        self.output_layer = nn.Sequential(*self.output_layer)

    def forward(self, input_data):
        # Encoder: keep every intermediate output as a skip connection.
        downs = []
        net = input_data
        for x in range(len(self.downsampling_layers)):
            net = self.downsampling_layers[x](net)
            downs.append(net)

        for x in range(len(self.blocks)):
            net = self.blocks[x](net)

        # Decoder: for each upsampling layer, concatenate the next skip (deepest first)
        # on the channel dim *before* applying that layer.
        for x in range(len(self.upsampling_layers)):
            idx = len(downs) - x - 1
            net = torch.cat((net, downs[idx]), 1)
            net = self.upsampling_layers[x](net)

        # Two heads on the same 7 * d_mult features: the wavelet SRCNN tail and the conv head.
        SRCNN_net = self.SRCNN(net)
        for x in range(len(self.output_layer)):
            net = self.output_layer[x](net)

        # Global residual: input plus both head outputs.
        return input_data + net + SRCNN_net
