In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Default super-resolution factor; used below as the default `sr_rate` of `sebica`.
upscale = 2


class SECA(nn.Module):
    """Spatial-Efficient Channel Attention (SECA).

    Scales the input by the product of two sigmoid masks:

    * a spatial mask produced by a 7x7 convolution over the channel-wise
      mean and max maps of the input, and
    * a channel mask produced by a two-layer 1x1-convolution bottleneck
      applied to the globally average-pooled input.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor of the bottleneck width, which is
            ``channel // reduction`` (the original inline hint reads
            "16 or 8(mini)").
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        squeezed = channel // reduction

        spatial_conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.spatial_attention = nn.Sequential(spatial_conv, nn.Sigmoid())

        # Squeeze (global pool) -> reduce -> ReLU -> expand -> sigmoid.
        channel_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channel, squeezed, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, kernel_size=1, bias=False),
            nn.Sigmoid(),
        ]
        self.channel_attention = nn.Sequential(*channel_layers)

    def forward(self, x):
        """Return ``x`` multiplied by its spatial mask and its channel mask.

        Assumes ``x`` is laid out as (B, C, H, W); dim 1 is reduced to build
        the spatial descriptor.
        """
        # Spatial branch: stack the per-pixel mean and max over dim 1.
        mean_map = x.mean(dim=1, keepdim=True)
        peak_map = x.max(dim=1, keepdim=True).values
        spatial_mask = self.spatial_attention(torch.cat((mean_map, peak_map), dim=1))

        # Channel branch: the pooling happens inside the sequential.
        channel_mask = self.channel_attention(x)

        gated = x * spatial_mask
        return gated * channel_mask


class CSA(nn.Module):
    """Spatial attention combined with a two-branch ("bi-directional") channel gate.

    The spatial mask is a sigmoid-activated 7x7 convolution over the
    channel-wise mean and max maps. The channel gate is the average of two
    branches with identical layout but separate weights
    (``channel_attention_forward`` / ``channel_attention_backward``).

    NOTE(review): each channel branch starts with ``nn.AdaptiveAvgPool2d(1)``
    yet is fed a 3-D ``(B, 1, C)`` tensor, which that layer interprets as an
    unbatched ``(C, H, W)`` input. The channel axis is therefore averaged
    away and every branch yields a single scalar per sample (shape
    ``(B, 1, 1, 1)`` after reshaping) instead of one weight per channel.
    Likewise both ``flip(dims=[1])`` calls act on a size-1 axis and change
    nothing. The behaviour is reproduced here unchanged -- confirm whether it
    is intended before altering it, because doing so would change the outputs
    of any already-trained model.

    Args:
        kernel_size: Kernel size of the 1-D convolution in each channel
            branch; padding is ``(kernel_size - 1) // 2``.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )
        # Bi-directional branches: same structure, independent weights.
        self.channel_attention_forward = self._channel_branch(kernel_size)
        self.channel_attention_backward = self._channel_branch(kernel_size)

    @staticmethod
    def _channel_branch(kernel_size):
        """Build one pool -> 1-D conv -> sigmoid channel branch."""
        same_padding = (kernel_size - 1) // 2
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv1d(
                1, 1, kernel_size=kernel_size, padding=same_padding, bias=False
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by the spatial mask and the channel gate."""
        # Spatial mask: (B,2,H,W) descriptor -> (B,1,H,W) weights.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        spatial_weight = self.spatial_attention(
            torch.cat((channel_mean, channel_max), dim=1)
        )

        # The values are unused; the unpacking itself requires a 4-D input.
        _b, _c, _h, _w = x.size()

        descriptor = x.mean(dim=(2, 3), keepdim=True)  # (B,C,1,1)
        sequence = descriptor.squeeze(-1).transpose(-1, -2)  # (B,1,C)

        forward_gate = self.channel_attention_forward(sequence)
        forward_gate = forward_gate.transpose(-1, -2).unsqueeze(-1)  # (B,1,1,1)

        # dim 1 of `sequence` has size 1, so this flip leaves it unchanged.
        backward_gate = self.channel_attention_backward(sequence.flip(dims=[1]))
        backward_gate = backward_gate.transpose(-1, -2).unsqueeze(-1)  # (B,1,1,1)

        channel_weight = (forward_gate + backward_gate.flip(dims=[1])) / 2
        channel_weight = channel_weight.expand_as(x)  # broadcast to x's shape

        return x * spatial_weight * channel_weight


class Conv(nn.Module):
    """Channel-preserving convolutional mixer.

    A 1x1 convolution widens the features from ``N`` to ``2 * N`` channels
    and a 3x3 convolution (padding 1) brings them back to ``N``; each
    convolution is followed by batch normalisation and ReLU.

    Args:
        N: Number of input and output channels.
    """

    def __init__(self, N):
        super().__init__()
        wide = N * 2
        expand = [
            nn.Conv2d(N, wide, 1),
            nn.BatchNorm2d(wide),
            nn.ReLU(inplace=True),
        ]
        project = [
            nn.Conv2d(wide, N, 3, padding=1),
            nn.BatchNorm2d(N),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*expand, *project)

    def forward(self, x):
        """Run ``x`` through the expand/project stack."""
        mixed = self.conv(x)
        return mixed


class FFN(nn.Module):
    """Point-wise feed-forward block with a residual connection.

    Two 1x1 convolutions (``N -> 2N -> N``), each followed by batch
    normalisation, with a GELU between them; the block input is added to the
    result.

    Args:
        N: Number of input and output channels.
    """

    def __init__(self, N):
        super().__init__()
        hidden = N * 2
        layers = [
            nn.Conv2d(N, hidden, 1),
            nn.BatchNorm2d(hidden),
            nn.GELU(),
            nn.Conv2d(hidden, N, 1),
            nn.BatchNorm2d(N),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``ffn(x) + x``."""
        update = self.ffn(x)
        return update + x


class Attn(nn.Module):
    """Residual attention block: mixer -> norm -> attention -> FFN -> norm.

    Args:
        N: Channel count used by every sub-module.
        mini: When true the attention stage is ``SECA(N, reduction=8)``;
            otherwise it is ``CSA()``.
    """

    def __init__(self, N, mini=False):
        super().__init__()
        self.pre_mixer = Conv(N)
        self.post_mixer = FFN(N)
        if mini:
            self.attn = SECA(N, reduction=8)
        else:
            self.attn = CSA()
        self.norm1 = nn.BatchNorm2d(N)
        self.norm2 = nn.BatchNorm2d(N)

    def forward(self, x):
        """Apply the five stages in sequence and add the block input."""
        stages = (
            self.pre_mixer,
            self.norm1,
            self.attn,
            self.post_mixer,
            self.norm2,
        )
        out = x
        for stage in stages:
            out = stage(out)
        # Residual connection, accumulated in place on the last stage's output.
        out += x
        return out


class sebica(nn.Module):
    """Super-resolution network built from a stack of ``Attn`` blocks.

    Pipeline: a 3x3 conv + batch norm + ReLU head lifts the 3-channel input
    to ``N`` features, the body applies ``Attn`` blocks, and the tail maps to
    ``3 * sr_rate**2`` channels with a 1x1 conv, applies dropout and pixel
    shuffle. The result is added to a bilinearly upsampled copy of the input
    that has been clamped to ``[0, 1]``.

    Args:
        sr_rate: Upscaling factor (defaults to the module-level ``upscale``).
        N: Feature width of the head and of every ``Attn`` block.
        mini: When true, 4 ``Attn`` blocks are stacked and each is built with
            ``mini=True``; otherwise 6 blocks are stacked.
        dropout: Drop probability of the ``nn.Dropout`` in the tail.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, sr_rate=upscale, N=16, mini=False, dropout=0.0, **kwargs):
        super().__init__()
        self.scale = sr_rate
        # `dropout` is handed to nn.Dropout unchanged: that layer is already
        # an identity in eval mode, so no construction-time gating on
        # `self.training` (which is always True at this point) is needed.
        self.head = nn.Sequential(
            nn.Conv2d(3, N, 3, padding=1), nn.BatchNorm2d(N), nn.ReLU(inplace=True)
        )

        n_blocks = 4 if mini else 6
        self.body = nn.Sequential(*[Attn(N, mini=mini) for _ in range(n_blocks)])

        self.tail = nn.Sequential(
            nn.Conv2d(N, 3 * sr_rate * sr_rate, 1),
            nn.Dropout(dropout),
            nn.PixelShuffle(sr_rate),
        )

    def forward(self, x):
        """Return the upscaled image: learned residual plus bilinear base.

        Args:
            x: Input batch; the head convolution expects 3 channels.

        Returns:
            Tensor whose spatial size is ``self.scale`` times that of ``x``.
        """
        # nn.Sequential applies the Attn blocks in order.
        h = self.tail(self.body(self.head(x)))
        base = torch.clamp(
            F.interpolate(
                x, scale_factor=self.scale, mode="bilinear", align_corners=False
            ),
            0,
            1,
        )
        return h + base


def sebica_mini(**kwargs):  # noqa: ARG001
    """Build the mini ``sebica`` variant (``N=8``, ``mini=True``).

    Keyword arguments are accepted but ignored; every other setting keeps
    ``sebica``'s default.
    """
    mini_config = {"N": 8, "mini": True}
    return sebica(**mini_config)


In [1]:
import torch
import torch.nn as nn

# Default upscaling factor; used below as the default `scale` of `ninasr`.
upscale = 2


class AttentionBlock(nn.Module):
    """Squeeze-and-excite style channel gate using local instead of global pooling.

    The input is average-pooled with a ``2 * stride - 1`` window, step
    ``stride`` and padding ``stride - 1`` (padded cells are excluded from the
    average). A 1x1-conv bottleneck with ReLU and a final sigmoid turns the
    pooled map into gates, which are upsampled by ``stride`` with
    nearest-neighbour interpolation and multiplied with the input.

    Args:
        n_feats: Number of input and output channels.
        reduction: Divisor of the bottleneck width (``n_feats // reduction``).
        stride: Pooling step and upsampling factor.
    """

    def __init__(self, n_feats, reduction=4, stride=16):
        super().__init__()
        hidden = n_feats // reduction
        window = 2 * stride - 1
        layers = [
            nn.AvgPool2d(
                window, stride=stride, padding=stride - 1, count_include_pad=False
            ),
            nn.Conv2d(n_feats, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, n_feats, 1, bias=True),
            nn.Sigmoid(),
            nn.Upsample(scale_factor=stride, mode="nearest"),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its locally pooled gates."""
        gates = self.body(x)
        if gates.shape != x.shape:
            # The upsampled map can overshoot the input; crop it back to
            # the input's height and width.
            rows, cols = x.shape[2], x.shape[3]
            gates = gates[:, :, :rows, :cols]
        return gates * x


class ResBlock(nn.Module):
    """Scaled residual block: conv -> ReLU -> AttentionBlock -> conv.

    The block input is multiplied by ``in_scale`` before the body, the body
    output is multiplied by ``2 * out_scale``, and the unscaled input is
    added back.

    Args:
        n_feats: Channels entering and leaving the block.
        mid_feats: Channels between the two 3x3 convolutions.
        in_scale: Multiplier applied to the input of the body.
        out_scale: Half of the multiplier applied to the output of the body.
    """

    def __init__(self, n_feats, mid_feats, in_scale, out_scale):
        super().__init__()
        self.in_scale, self.out_scale = in_scale, out_scale

        # Modules are created in the order they run.
        widen = nn.Conv2d(n_feats, mid_feats, 3, padding=1, bias=True)
        nn.init.kaiming_normal_(widen.weight)
        nn.init.zeros_(widen.bias)
        activation = nn.ReLU(inplace=True)
        gate = AttentionBlock(mid_feats)
        narrow = nn.Conv2d(mid_feats, n_feats, 3, padding=1, bias=False)
        nn.init.kaiming_normal_(narrow.weight)

        self.body = nn.Sequential(widen, activation, gate, narrow)

    def forward(self, x):
        """Return ``body(x * in_scale) * (2 * out_scale) + x``."""
        gain = 2 * self.out_scale
        res = self.body(x * self.in_scale) * gain
        res += x
        return res


class Rescale(nn.Module):
    """Add a single learnable scalar offset to the input.

    Args:
        sign: Initial value of the offset. Integer or boolean values are
            converted to the default floating-point dtype, because a
            parameter that requires gradients must be floating point or
            complex; floating-point values are stored as given.
    """

    def __init__(self, sign):
        super().__init__()
        # rgb_mean = (0.4488, 0.4371, 0.4040)
        # bias = sign * torch.Tensor(rgb_mean).reshape(1, 3, 1, 1)
        initial = torch.tensor(sign)
        if not (initial.is_floating_point() or initial.is_complex()):
            # e.g. Rescale(1): an int64 tensor cannot require gradients.
            initial = initial.to(torch.get_default_dtype())
        self.bias = nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        """Return ``x + bias`` (the scalar broadcasts over ``x``)."""
        return x + self.bias


class ninasr(nn.Module):
    """Residual super-resolution network assembled from ``ResBlock`` modules.

    Args:
        n_resblocks: Number of residual blocks in the body.
        n_feats: Feature width of the head, body and tail input.
        n_colors: Number of image channels going in and out.
        scale: Upscaling factor applied by the pixel shuffle in the tail
            (defaults to the module-level ``upscale``).
        expansion: Ratio between the inner width of a residual block and
            ``n_feats``.
    """

    def __init__(
        self, n_resblocks=26, n_feats=32, n_colors=3, scale=upscale, expansion=2.0
    ):
        super().__init__()
        self.scale = scale
        self.head = ninasr.make_head(n_colors, n_feats)
        self.body = ninasr.make_body(n_resblocks, n_feats, expansion)
        self.tail = ninasr.make_tail(n_colors, n_feats, scale)

    @staticmethod
    def make_head(n_colors, n_feats):
        """Learnable offset (initialised to -1.0) followed by a bias-free 3x3 conv."""
        return nn.Sequential(
            Rescale(-1.0),
            nn.Conv2d(n_colors, n_feats, 3, padding=1, bias=False),
        )

    @staticmethod
    def make_body(n_resblocks, n_feats, expansion):
        """Stack ``n_resblocks`` residual blocks with per-block input scaling.

        Every block receives ``out_scale = 4 / n_resblocks``. A running
        variance estimate starts at 1 and grows by ``out_scale ** 2`` after
        each block; a block's ``in_scale`` is the reciprocal square root of
        the estimate at the time it is created.
        """
        mid_feats = int(n_feats * expansion)
        out_scale = 4 / n_resblocks
        variance_step = out_scale**2
        expected_variance = torch.tensor(1.0)
        blocks = []
        for _ in range(n_resblocks):
            in_scale = 1.0 / expected_variance.sqrt()
            blocks.append(ResBlock(n_feats, mid_feats, in_scale, out_scale))
            expected_variance = expected_variance + variance_step
        return nn.Sequential(*blocks)

    @staticmethod
    def make_tail(n_colors, n_feats, scale):
        """3x3 conv to ``n_colors * scale**2`` channels, pixel shuffle, offset (+1.0)."""
        return nn.Sequential(
            nn.Conv2d(n_feats, n_colors * scale**2, 3, padding=1, bias=True),
            nn.PixelShuffle(scale),
            Rescale(1.0),
        )

    def forward(self, x):
        """Head -> body with a long skip connection -> tail."""
        shallow = self.head(x)
        deep = self.body(shallow)
        # Long skip: add the head features to the body output in place.
        deep += shallow
        return self.tail(deep)


def ninasr_b0(**kwargs):  # noqa: ARG001
    """Build the ``ninasr`` variant with 10 residual blocks and 16 features.

    Keyword arguments are accepted but ignored.
    """
    b0_config = {"n_resblocks": 10, "n_feats": 16}
    return ninasr(**b0_config)


def ninasr_b2(**kwargs):  # noqa: ARG001
    """Build the ``ninasr`` variant with 84 residual blocks and 56 features.

    Keyword arguments are accepted but ignored.
    """
    b2_config = {"n_resblocks": 84, "n_feats": 56}
    return ninasr(**b2_config)


In [5]:
def _count_trainable(net):
    """Total number of elements over the parameters that require gradients."""
    return sum(p.numel() for p in net.parameters() if p.requires_grad)


# Print one trainable-parameter count per model variant, in this order.
for build in (sebica_mini, sebica, ninasr_b0, ninasr_b2):
    model = build()
    print(_count_trainable(model))

7652
40284
100174
10038854
