In [36]:
import torch
from torch import nn
import torch.nn.functional as F

def conv_layer(in_channels, out_channels, kernel_size=3, bias=True):
    """Build an ``nn.Conv2d`` whose padding is half of the kernel width.

    The padding is ``(kernel_size - 1) / 2`` truncated toward zero. Stride and
    dilation are left at the ``nn.Conv2d`` defaults, so an odd ``kernel_size``
    yields an output with the same spatial size as the input.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Integer side length of the square kernel.
        bias: Whether the convolution has a learnable bias.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    half_width = int((kernel_size - 1) / 2)
    conv_kwargs = dict(kernel_size=kernel_size, padding=half_width, bias=bias)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)

class CReLU(nn.Module):
    """Concatenated ReLU: ReLU applied to the input joined with its negation.

    The input and its negation are concatenated along dim 1, so the output
    has twice as many entries along that dimension as the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Stack the positive and negated halves first, then rectify once.
        mirrored = torch.cat([x, x.neg()], dim=1)
        return F.relu(mirrored)

class PFAB(nn.Module):
    """Two-convolution block whose features are scaled by a learned gate.

    The first convolution followed by a per-channel PReLU produces the
    features. The second convolution, passed through a sigmoid shifted by
    -0.5, produces a gate with values between -0.5 and 0.5 that multiplies
    those features element-wise.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the features and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = conv_layer(in_channels, out_channels)
        self.conv2 = conv_layer(out_channels, out_channels)
        self.act = nn.PReLU(out_channels)

    def forward(self, x):
        features = self.conv1(x)
        features = self.act(features)
        # Shifted sigmoid: the gate is centred on zero instead of 0.5.
        gate = self.conv2(features).sigmoid() - 0.5
        return features * gate

class AnimeA(nn.Module):
    """Small convolutional upscaler that fuses selected intermediate features.

    A 3x3 stem convolution lifts the 3-channel input to ``channel`` features,
    which then pass through ``block_depth`` PFAB blocks in sequence. The
    feature maps named by ``stack_list`` are concatenated along dim 1,
    projected to ``3 * scale * scale`` channels and rearranged by
    ``nn.PixelShuffle(scale)``. The result is added to a bilinear upsampling
    of the input.

    Args:
        channel: Number of feature channels produced by the stem and by every
            PFAB block.
        block_depth: Number of PFAB blocks in the body.
        scale: Upscaling factor used by both the pixel shuffle and the
            bilinear skip path.
        stack_list: Indices of the feature maps to concatenate. ``0`` is the
            stem output and ``k`` (1-based) is the output of the k-th block.
            ``None`` selects every index from 0 to ``block_depth``. Indices
            outside that range and repeated indices are ignored.

    Raises:
        ValueError: If ``stack_list`` contains no index in ``0..block_depth``.
    """

    def __init__(self, channel=8, block_depth=6, scale=2, stack_list=None):
        super(AnimeA, self).__init__()
        self.scale = scale
        self.first = nn.Conv2d(3, channel, kernel_size=3, padding=1)
        self.body = nn.ModuleList([PFAB(channel, channel) for _ in range(block_depth)])
        if stack_list is None:
            stack_list = range(block_depth + 1)
        # Tap indices are bounded by the number of blocks, not by the channel
        # width. forward() contributes exactly one feature map per distinct
        # valid index, so out-of-range and repeated entries are dropped here to
        # keep the fusion conv's input width equal to what forward() stacks.
        # dict.fromkeys de-duplicates while keeping the caller's order.
        self.stack_list = list(dict.fromkeys(
            s for s in stack_list if 0 <= s <= block_depth))
        if not self.stack_list:
            raise ValueError(
                "stack_list must contain at least one index between 0 and "
                f"block_depth ({block_depth})")
        print(self.stack_list)
        self.last = nn.Sequential(
            conv_layer(channel * len(self.stack_list), 3 * scale * scale),
            nn.PixelShuffle(scale))

    def forward(self, x):
        stack = []
        out = self.first(x)
        if 0 in self.stack_list:
            stack.append(out)
        # Blocks are numbered from 1 so they line up with stack_list indices.
        for index, block in enumerate(self.body, start=1):
            out = block(out)
            if index in self.stack_list:
                stack.append(out)
        out = self.last(torch.cat(stack, 1))
        # align_corners=False is what bilinear mode already uses when the
        # argument is omitted; it is spelled out to make the choice explicit.
        skip = F.interpolate(x, scale_factor=self.scale, mode='bilinear',
                             align_corners=False)
        return out + skip


In [41]:
# Count the FLOPs of one half-precision forward pass on a 1x3x720x1280 input.
from fvcore.nn import FlopCountAnalysis

device = torch.device("mps")

# The random input is created first, then moved to the device and cast.
dummy_input = torch.randn(1, 3, 720, 1280)
dummy_input = dummy_input.to(device).half()

model = AnimeA(channel=12, block_depth=6, scale=2, stack_list=[0, 1, 5, 6])
model = model.to(device).half()

flops = FlopCountAnalysis(model, dummy_input)
print(f'{flops.total() / 10**9:.3f}G')

Unsupported operator aten::prelu encountered 6 time(s)
Unsupported operator aten::sigmoid encountered 6 time(s)
Unsupported operator aten::sub encountered 6 time(s)
Unsupported operator aten::mul encountered 6 time(s)
Unsupported operator aten::pixel_shuffle encountered 1 time(s)
Unsupported operator aten::add encountered 1 time(s)


[0, 1, 5, 6]
19.453G


In [None]:
'''
channel=8, block_depth=6, scale=2, stack_list=[0, 1, 2, 3, 4, 5, 6]
12.187G
channel=8, block_depth=6, scale=2, stack_list=[0, 1, 5, 6]
9.798G
channel=12, block_depth=6, scale=2, stack_list=[0, 1, 5, 6]
19.453G
channel=12, block_depth=7, scale=2, stack_list=[0, 1, 6, 7]
21.842G
'''