In [None]:
import math
import torch
from torch import nn as nn

def make_layer(basic_block, num_basic_block, **kwarg):
    """Build a sequential stack of identically configured blocks.

    Args:
        basic_block (nn.module): Block class (or any callable returning an
            ``nn.Module``) that is instantiated once per stacked layer.
        num_basic_block (int): How many block instances to stack.
        **kwarg: Keyword arguments forwarded unchanged to every
            ``basic_block`` constructor call.

    Returns:
        nn.Sequential: The freshly created blocks, in creation order.
    """
    # Each iteration constructs a new instance, so no parameters are shared.
    stacked = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*stacked)

class MDBN(nn.Module):
    """Super-resolution network: conv stem, residual body, upsampler, conv head.

    Args:
        num_in_ch (int): Channels of the input image.
        num_out_ch (int): Channels of the output image.
        num_feat (int): Channels of the intermediate feature maps.
        num_block (int): Number of ``ResidualBlock`` instances in the body.
        upscale (int): Scale factor handed to ``Upsample``.
        res_scale (float): Residual scaling handed to every ``ResidualBlock``.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, upscale=2, res_scale=1.0):
        super().__init__()

        # Shallow feature extraction.
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)

        # Deep feature extraction: stacked residual blocks plus a trailing conv.
        self.body = make_layer(ResidualBlock, num_block, num_feat=num_feat, res_scale=res_scale)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)

        # Reconstruction: upsample the features, then project to image channels.
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return the upscaled image computed from the input batch ``x``."""
        shallow = self.conv_first(x)

        deep = self.conv_after_body(self.body(shallow))
        deep += shallow  # global skip connection around the whole body

        return self.conv_last(self.upsample(deep))

class ResidualBlock(nn.Module):
    """Two chained ``BaseBlock`` modules wrapped in a scaled skip connection.

    Args:
        num_feat (int): Channel count passed to both ``BaseBlock`` instances.
        res_scale (float): Multiplier applied to the branch output before it
            is added back to the block input.
    """

    def __init__(self, num_feat=64, res_scale=1):
        super().__init__()
        self.res_scale = res_scale
        self.baseblock1 = BaseBlock(num_feat)
        self.baseblock2 = BaseBlock(num_feat)

    def forward(self, x):
        """Return ``x + res_scale * baseblock2(baseblock1(x))``."""
        branch = self.baseblock2(self.baseblock1(x))
        return x + branch * self.res_scale

class BaseBlock(nn.Module):
    """Dual-branch convolutional unit with GELU activations.

    One branch applies conv -> GELU -> conv (``uconv1``/``uconv2``), the other
    a single conv (``dconv``); the two results are summed and passed through
    GELU. Every convolution is 3x3 with stride 1 and padding 1, so the spatial
    size and channel count of the input are preserved.

    Args:
        num_feat (int): Number of input and output channels.
    """

    def __init__(self, num_feat):
        super().__init__()
        self.uconv1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        self.uconv2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        self.dconv = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1)
        self.act = nn.GELU()

    def forward(self, x):
        """Return ``GELU(uconv2(GELU(uconv1(x))) + dconv(x))``."""
        two_conv_branch = self.uconv1(x)
        two_conv_branch = self.act(two_conv_branch)
        two_conv_branch = self.uconv2(two_conv_branch)

        one_conv_branch = self.dconv(x)

        return self.act(two_conv_branch + one_conv_branch)

class Upsample(nn.Sequential):
    """Upsample module.

    Builds a chain of ``Conv2d`` + ``PixelShuffle`` stages: one x2 stage per
    factor of two for power-of-two scales, or a single x3 stage for scale 3.
    A scale of 1 (2^0) yields an empty, identity-like ``nn.Sequential``.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # ``scale > 0`` guards the bit trick: 0 also satisfies
        # ``scale & (scale - 1) == 0`` but is not a power of two.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length() - 1 is the exact base-2 logarithm of a power of two,
            # avoiding truncation of a floating-point math.log result.
            num_stages = int(scale).bit_length() - 1
            for _ in range(num_stages):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


In [None]:
# Export the pretrained x4 MDBN checkpoint to ONNX with dynamic batch and
# spatial dimensions.

# Use the GPU when present, otherwise fall back to CPU so the export still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dummy_input = torch.randn(1, 3, 256, 256).to(device)
model = MDBN(upscale=4).to(device)
# The checkpoint stores the weights under the 'params' key.
model.load_state_dict(torch.load("R:/MDBN_x4.pth", map_location=device)['params'])
model.eval()  # export the inference graph

# from fvcore.nn import FlopCountAnalysis
# dummy_input = torch.randn(1, 3, 720, 1280).to(device)
# flops = FlopCountAnalysis(model, dummy_input)
# print(f': {flops.total() / 10**9:.3f}G')

torch.onnx.export(
    model,
    dummy_input,
    "R:/MDBN_x4.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},
        # The output is 4x larger than the input, so its spatial axes need their
        # own symbolic names; reusing "height"/"width" would declare them equal
        # to the input dimensions.
        "output": {0: "batch_size", 2: "out_height", 3: "out_width"},
    },
    opset_version=11,
)

In [11]:
from torch import nn as nn
import torch
class bpp(nn.Module):
    """Small convolutional super-resolution network.

    A stride-2 convolution halves the spatial size, two 3x3 convolutions with
    a skip connection refine the 32-channel features, and a final convolution
    followed by ``PixelShuffle(2 * sr_rate)`` expands them. For even input
    sizes the output is therefore ``sr_rate`` times larger than the input.
    All convolutions are bias-free.

    Args:
        sr_rate (int): Overall upscaling factor. Defaults to 3, the value that
            was previously hard-coded.

    Raises:
        ValueError: If ``sr_rate`` is smaller than 1.
    """

    def __init__(self, sr_rate=3):
        super(bpp, self).__init__()
        if sr_rate < 1:
            raise ValueError(f'sr_rate must be >= 1, got {sr_rate}.')
        # The stride-2 stem halves the resolution, so the pixel shuffle has to
        # expand by twice the requested rate.
        shuffle_factor = 2 * sr_rate
        self.conv0 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv1 = nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False)
        self.conv_out = nn.Conv2d(32, shuffle_factor ** 2 * 3, kernel_size=3, padding=1, bias=False)
        self.Depth2Space = nn.PixelShuffle(shuffle_factor)
        self.act = nn.LeakyReLU(inplace=True, negative_slope=0.1)

    def forward(self, x):
        """Return the upscaled 3-channel image for the 3-channel input ``x``."""
        x0 = self.conv0(x)
        x0 = self.act(x0)
        x1 = self.conv1(x0)
        x1 = self.act(x1)
        x2 = self.conv2(x1)
        # Skip connection from the activated stem features.
        x2 = self.act(x2) + x0
        y = self.conv_out(x2)
        y = self.Depth2Space(y)
        return y
    
# Measure the FLOPs of the half-precision bpp model on one 1x3x720x1280 input.
# NOTE(review): the device is hard-coded to CUDA, so this cell needs a GPU.
device = torch.device("cuda")
model = bpp().to(device).half()

# fvcore is a third-party dependency used only for the FLOP count.
from fvcore.nn import FlopCountAnalysis
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()
flops = FlopCountAnalysis(model, dummy_input)
# Print the total in units of 10**9. The recorded cell output shows
# leaky_relu_, add and pixel_shuffle reported as unsupported operators, so
# presumably they are not included in this total — TODO confirm.
print(f': {flops.total() / 10**9:.3f}G')


Unsupported operator aten::leaky_relu_ encountered 3 time(s)
Unsupported operator aten::add encountered 1 time(s)
Unsupported operator aten::pixel_shuffle encountered 1 time(s)


: 11.612G


In [12]:
# Total number of elements across all parameters of `model` that require gradients.
sum(p.numel() for p in model.parameters()  if p.requires_grad)

50400