In [1]:
import math
import os
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from IPython.display import Image, display
from PIL import Image  # NOTE: rebinds `Image`, shadowing IPython.display.Image imported above

# import onnx

# Shared torchvision converters: tensor -> PIL image, and PIL image -> tensor.
to_pil, to_tensor = (
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.ToTensor(),
)

# Prefer Apple's Metal (MPS) backend when this PyTorch build and machine
# support it; otherwise fall back to CUDA, then CPU, so that `.to(device)`
# in later cells does not fail on machines without MPS.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class CReLU(nn.Module):
    """Concatenated ReLU.

    Stacks ``x`` and ``-x`` along dim 1 and rectifies the result, so the
    output has twice as many entries along dim 1 as the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Concatenate first, then rectify once; element-wise this is the same
        # as cat(relu(x), relu(-x)) along dim 1.
        mirrored = torch.cat([x, x.neg()], dim=1)
        return F.relu(mirrored)

def convert(c, iter, doswap=False):
    """Fill a convolution's weight and bias in place from a stream of regex matches.

    Values are consumed from ``iter`` in this order: for every group of up to
    four output channels, all weights of that group (looping over groups of up
    to four input channels, then kernel column, then kernel row, then input
    channel, then output channel), followed by that group's biases.

    Args:
        c: Module exposing ``weight`` of shape (out, in, kernel_h, kernel_w)
            and ``bias`` of shape (out,), e.g. ``nn.Conv2d``.
        iter: Iterator of ``re.Match`` objects whose ``group(0)`` parses as a
            float. The name shadows the builtin; it is kept so existing
            callers are unaffected.
        doswap: When true, output channels 1 and 2 inside each group are
            exchanged while writing. With fewer than three output channels the
            swap table indexes past the tensor.

    Raises:
        StopIteration: If ``iter`` is exhausted before every value is filled.
        IndexError: If a channel count above four is not a multiple of four
            (the last group then indexes past the tensor).
    """
    swap = [0, 2, 1, 3]
    # Conv2d stores weights as (out, in, kernel_h, kernel_w). Unpacking the
    # last two dims in that order keeps the [h, w] indexing below in range for
    # non-square kernels; for square kernels the read order is unchanged.
    out_chan, in_chan, kernel_h, kernel_w = c.weight.shape
    in_block = min(4, in_chan)
    out_block = min(4, out_chan)
    for to in range(math.ceil(out_chan / 4)):
        for ti in range(math.ceil(in_chan / 4)):
            for w in range(kernel_w):
                for h in range(kernel_h):
                    for i in range(in_block):
                        for o in range(out_block):
                            o = swap[o] if doswap else o
                            c.weight.data[to * 4 + o, ti * 4 + i, h, w] = float(next(iter).group(0))
        # The biases of this output group follow its weights in the stream.
        for o in range(out_block):
            o = swap[o] if doswap else o
            c.bias.data[to * 4 + o] = float(next(iter).group(0))

def debug(*args, **kwargs):
    """Forward all arguments to ``print``; the switch below is hard-wired on."""
    enabled = True
    if not enabled:
        return
    print(*args, **kwargs)

In [2]:
class Anime4kGan(nn.Module):
    """Anime4K GAN upscaler network whose weights are read from shader text.

    Layout:
      * ``first``: 3x3 convolution from 3 input channels to ``channel``.
      * ``block_nums`` blocks. Each block applies CReLU to the running
        features ``x0`` (giving ``2 * channel`` channels), runs two 3x3
        convolutions with 4 outputs each (``body_a`` and ``body_b``, both
        followed by CReLU, i.e. 8 channels), and merges everything with a
        1x1 convolution ``body_c``. The final block has no ``body_a``; it
        reuses the ``body_a`` output of the previous block.
      * Bilinear upsampling by ``scale``, CReLU, then ``last`` (and
        ``second_last`` when ``second_last_channel > 0``), plus a residual
        connection to the bilinearly upsampled input.

    Args:
        block_nums: Number of blocks.
        channel: Output channels of ``first`` and of every non-final ``body_c``.
        last_channel: Output channels of the final ``body_c``.
        second_last_channel: 0 makes ``last`` map straight to 3 channels;
            otherwise ``last`` is a dilated 3x3 convolution producing this many
            channels and ``second_last`` maps them to 3.
        scale: Upsampling factor passed to ``F.interpolate``.
    """

    def __init__(self, block_nums=5, channel=4, last_channel=8, second_last_channel=0, scale=2):
        super().__init__()
        self.scale = scale
        self.block_nums = block_nums
        self.second_last_channel = second_last_channel
        self.first = nn.Conv2d(3, channel, kernel_size=3, padding=1)

        body_a, body_b, body_c = [], [], []
        for i in range(block_nums):
            is_final = i == block_nums - 1
            if not is_final:
                body_a.append(nn.Conv2d(channel * 2, 4, kernel_size=3, padding=1))
            body_b.append(nn.Conv2d(channel * 2, 4, kernel_size=3, padding=1))
            # body_c input = CReLU(x0) (2 * channel) + x2 (8) + 8 per collected
            # x1, of which there are i + 1  ->  2 * channel + 16 + 8 * i.
            body_c.append(
                nn.Conv2d(
                    channel * 2 + 16 + i * 8,
                    last_channel if is_final else channel,
                    kernel_size=1,
                    padding=0,
                )
            )
        self.body_a = nn.ModuleList(body_a)
        self.body_b = nn.ModuleList(body_b)
        self.body_c = nn.ModuleList(body_c)

        if second_last_channel == 0:
            self.last = nn.Conv2d(last_channel * 2, 3, kernel_size=3, padding=1)
        else:
            self.last = nn.Conv2d(last_channel * 2, second_last_channel, kernel_size=3, padding=2, dilation=2)
            self.second_last = nn.Conv2d(second_last_channel * 2, 3, kernel_size=3, padding=1)
        self.crelu = CReLU()

    def forward(self, x):
        """Upscale ``x`` by ``self.scale``.

        Args:
            x: Image batch of shape (N, 3, H, W); ``first`` requires 3 channels.

        Returns:
            Tensor with 3 channels and spatial size scaled by ``self.scale``.

        Raises:
            RuntimeError: If ``block_nums == 1``, because the final block needs
                the ``body_a`` output of a preceding block.
        """
        x0 = self.first(x)
        accumulate = []
        x2 = None
        final_block = self.block_nums - 1
        for i in range(self.block_nums):
            x0 = self.crelu(x0)
            if i != final_block:
                x2 = self.crelu(self.body_a[i](x0))
            elif x2 is None:
                # The final block has no body_a and carries x2 over from the
                # previous block; with a single block there is nothing to carry.
                raise RuntimeError(
                    "Anime4kGan with block_nums == 1 is unsupported: the final "
                    "block reuses the previous block's body_a output"
                )
            x1 = self.crelu(self.body_b[i](x0))
            accumulate.append(x1)
            x0 = self.body_c[i](torch.cat([x0, x2, *accumulate], 1))

        upsampled = F.interpolate(x0, scale_factor=self.scale, mode="bilinear")
        out = self.last(self.crelu(upsampled))
        if self.second_last_channel > 0:
            out = self.second_last(self.crelu(out))
        # Residual connection: the network output is added to the plainly upsampled input.
        return out + F.interpolate(x, scale_factor=self.scale, mode="bilinear")

    def import_param(self, filename):
        """Load every weight and bias from the numeric literals in ``filename``.

        All parameters are first frozen (``requires_grad = False``). The file
        text is scanned for numbers with at least two fractional digits (and an
        optional exponent); shorter literals such as ``1.0`` are not matched.
        The matches are handed to ``convert`` layer by layer: ``first``, then
        per block ``body_a`` (absent on the final block), ``body_b``,
        ``body_c``, then ``last`` and, if present, ``second_last``.

        Prints ``pass`` when no match is left over afterwards, otherwise a
        failure line showing the first unused match.

        Raises:
            ValueError: If the file holds fewer matching numbers than the
                model has values to fill.
        """
        for param in self.parameters():
            param.requires_grad = False
        with open(filename) as f:
            text = f.read()
        pattern = r"-?\d+(\.\d{2,})(e-?\d+)?"
        matches = re.finditer(pattern, text)
        try:
            convert(self.first, matches)
            for i in range(self.block_nums):
                if i != self.block_nums - 1:
                    convert(self.body_a[i], matches)
                convert(self.body_b[i], matches)
                convert(self.body_c[i], matches)
            convert(self.last, matches)
            if self.second_last_channel > 0:
                convert(self.second_last, matches)
        except StopIteration:
            # A bare StopIteration gives no hint about what went wrong.
            raise ValueError(
                f"{filename}: fewer numeric values than the model has parameters"
            ) from None
        leftover = next(matches, None)
        if leftover is None:
            print("pass")
        else:
            print("---failed---\n", leftover)


# model = Anime4kGan(block_nums=5, channel=8, last_channel=12, scale=2)
# model.import_param("tmp/Anime4K_Upscale_GAN_x2_M.glsl")

# model = Anime4kGan(block_nums=5, channel=12, last_channel=12, second_last_channel=8, scale=3)
# model.import_param("tmp/Anime4K_Upscale_GAN_x3_L.glsl")

# model = Anime4kGan(block_nums=10, channel=16, last_channel=16, second_last_channel=12, scale=4)
# model.import_param("tmp/Anime4K_Upscale_GAN_x4_UL.glsl")

# model = Anime4kGan(block_nums=9, channel=24, last_channel=24, second_last_channel=24, scale=4)
# model.import_param("tmp/Anime4K_Upscale_GAN_x4_UUL.glsl")
# model.to(device).half()

# image2 = Image.open("/Users/khoi.ho/Downloads/Screenshot 2023-12-18 18-32-07.png").convert("RGB")
# image2 = to_tensor(image2).unsqueeze(0).to(device).half()
# out = model(image2)[0]
# # clamp out to 0,1 
# out = torch.clamp(out, min=0, max=1)
# display(to_pil(out))

In [3]:
# Example input used for the export: one 3-channel frame in NCHW layout
# (720 rows x 1280 columns), in half precision on the selected device.
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()

# (weights file, export name, block_nums, channel, last_channel, second_last_channel, scale)
upscaleModel = [
    ("tmp/Anime4K_Upscale_GAN_x2_S.glsl", "Anime4K_Upscale_GAN_x2_S", 5, 4, 8, 0, 2),
    ("tmp/Anime4K_Upscale_GAN_x2_M.glsl", "Anime4K_Upscale_GAN_x2_M", 5, 8, 12, 0, 2),
    ("tmp/Anime4K_Upscale_GAN_x3_L.glsl", "Anime4K_Upscale_GAN_x3_L", 5, 12, 12, 8, 3),
    ("tmp/Anime4K_Upscale_GAN_x3_VL.glsl", "Anime4K_Upscale_GAN_x3_VL", 8, 12, 16, 12, 3),
    ("tmp/Anime4K_Upscale_GAN_x4_UL.glsl", "Anime4K_Upscale_GAN_x4_UL", 10, 16, 16, 12, 4),
    ("tmp/Anime4K_Upscale_GAN_x4_UUL.glsl", "Anime4K_Upscale_GAN_x4_UUL", 9, 24, 24, 24, 4),
]

for filename, name, block_nums, channel, last_channel, second_last_channel, scale in upscaleModel:
    onnx_path = f"onnxModel/{name}.onnx"
    # Make sure the output directory exists before exporting.
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
    model = Anime4kGan(block_nums, channel, last_channel, second_last_channel, scale)
    model.import_param(filename)
    model.to(device).half()
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        # NCHW: axis 2 is height and axis 3 is width. The output is `scale`
        # times larger than the input, so its axes get their own symbolic
        # names instead of reusing the input's.
        dynamic_axes={
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "out_height", 3: "out_width"},
        },
    )

pass
pass
pass
pass
pass
pass


In [None]:
from fvcore.nn import FlopCountAnalysis

# Report the fvcore operation count of every variant on the shared dummy
# input, in units of 1e9. The models here use freshly initialised weights.
for filename, name, *arch in upscaleModel:
    model = Anime4kGan(*arch).to(device).half()
    flops = FlopCountAnalysis(model, dummy_input)
    giga = flops.total() / 10**9
    print(name, f': {giga:.3f}G')

In [None]:
# Sanity check for the x3_L variant: the model's parameter count is expected
# to equal the number of numeric literals the loader reads from its shader file.
model = Anime4kGan(block_nums=5, channel=12, last_channel=12, second_last_channel=8, scale=3).to(device).half()
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()
# Smoke-test the forward pass. no_grad avoids building an autograd graph for
# a full 720p frame, and calling the module (not .forward) is the usual entry point.
with torch.no_grad():
    output_tensor = model(dummy_input)


# Total number of trainable parameters of the model.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total number of parameters: {total_params:,}")

with open("tmp/Anime4K_Upscale_GAN_x3_L.glsl") as f:
    text = f.read()

# Same pattern as Anime4kGan.import_param (at least two fractional digits),
# so this count describes exactly the numbers the loader would consume.
pattern = r'-?\d+(\.\d{2,})(e-?\d+)?'
# Total number of pattern matches in the text.
total_match_pattern = len(re.findall(pattern, text))
print(f"Total number of match pattern: {total_match_pattern:,}")
