In [2]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import math

import torchvision
from IPython.display import Image, display
from PIL import Image

# import onnx

to_pil = torchvision.transforms.ToPILImage()
to_tensor = torchvision.transforms.ToTensor()

device = torch.device("cuda")

class CReLU(nn.Module):
    def __init__(self):
        super(CReLU, self).__init__()
    def forward(self, x):
        # return torch.cat((F.relu(x), F.relu(-x)), 1)
        return F.relu(torch.cat((x, -x), 1))

def convert(c, iter, doswap=False):
    swap = [0,2,1,3]
    out_chan, in_chan, width, height = c.weight.shape
    for to in range(math.ceil(out_chan/4)):
        for ti in range(math.ceil(in_chan/4)):
            for w in range(width):
                for h in range(height):
                    for i in range(min(4, in_chan)):
                        for o in range(min(4, out_chan)):
                            o = swap[o] if doswap else o
                            c.weight.data[to*4+o, ti*4+i, w, h] = float(next(iter).group(0))
        for o in range(min(4, out_chan)):
            o = swap[o] if doswap else o
            c.bias.data[to*4+o] = float(next(iter).group(0))


In [8]:
class Anime4kRestore(nn.Module):
    def __init__(self, block_depth=8, block_stack=5, channel=12):
        super(Anime4kRestore, self).__init__()
        self.conv_head = nn.Conv2d(3, channel, kernel_size=3, padding=1)
        self.conv_mid = nn.ModuleList([nn.Conv2d(
            channel*2, channel, kernel_size=3, padding=1) for _ in range(block_depth-1)])
        if block_stack != 1:
            self.conv_tail = nn.Conv2d(
                2*channel*block_stack, 3, kernel_size=1, padding=0)
        else:
            self.conv_tail = nn.Conv2d(2*channel, 3, kernel_size=3, padding=1)
        self.crelu = CReLU()
        self.block_no_stack = block_depth - block_stack

    def forward(self, x):
        out = self.crelu(self.conv_head(x))
        print(self.conv_head.weight.shape)
        if self.block_no_stack == 0:
            depth_list = [out]
            print(0, "----")
        else:
            depth_list = []
            print(0, "a")
        for i, conv in enumerate(self.conv_mid):
            out = self.crelu(conv(out))
            print(conv.weight.shape)
            if i >= self.block_no_stack - 1:
                depth_list.append(out)
                print(i+1, "----")
            else:
                print(i+1, "a")
        out = self.conv_tail(torch.cat(depth_list, 1))
        print(self.conv_tail.weight.shape)
        print("out")
        # return torch.clamp(out + x, max=1.0, min=0.0)
        return out + x

    def import_param(self, filename):
        for param in self.parameters():
            param.requires_grad = False
        with open(filename) as f:
            text = f.read()
        pattern = r'-?\d+(\.\d{2,})(e-?\d+)?'
        iter = re.finditer(pattern, text)
        convert(self.conv_head, iter)
        for conv in self.conv_mid:
            convert(conv, iter)
        convert(self.conv_tail, iter)
        check = next(iter, None)
        if check == None:
            print("pass")
        else:
            print("---failed---\n", check)

# model = Anime4kRestore(8, 5, 12).to(device)
# model.import_param("tmp/Anime4K_Restore_CNN_UL.glsl")
# model = Anime4kRestore(8, 7, 8)
# model.import_param("tmp/Anime4K_Restore_CNN_VL.glsl")
# model = Anime4kRestore(4, 1, 8)
# model.import_param("tmp/Anime4K_Restore_CNN_L.glsl")
# model = Anime4kRestore(7, 7, 4)
# model.import_param("tmp/Anime4K_Restore_CNN_M.glsl")
# model = Anime4kRestore(3, 1, 4)
# model.import_param("tmp/Anime4K_Restore_CNN_S.glsl")

# image2 = Image.open(
#     "C://Users/khoi/Videos/Screenshot 2023-08-21 14-04-25.png").convert("RGB")
# image2 = to_tensor(image2).unsqueeze(0).to(device)


# out = model(image2)[0]
# display(to_pil(out))

In [10]:
class Anime4kUpscale(nn.Module):
    def __init__(self, block_depth=8, block_stack=5, channel=12):
        super(Anime4kUpscale, self).__init__()
        self.conv_head = nn.Conv2d(3, channel, kernel_size=3, padding=1)
        self.conv_mid = nn.ModuleList([nn.Conv2d(
            channel*2, channel, kernel_size=3, padding=1) for _ in range(block_depth-1)])
        if block_stack != 1:
            self.conv_tail = nn.Conv2d(
                2*channel*block_stack, 12, kernel_size=1, padding=0)
        else:
            self.conv_tail = nn.Conv2d(2*channel, 12, kernel_size=3, padding=1)
        self.crelu = CReLU()
        self.block_no_stack = block_depth - block_stack
        self.ps = nn.PixelShuffle(2)

    def forward(self, x):
        out = self.crelu(self.conv_head(x))
        print(self.conv_head.weight.shape)
        if self.block_no_stack == 0:
            depth_list = [out]
            print(0, "----")
        else:
            depth_list = []
            print(0, "a")
        for i, conv in enumerate(self.conv_mid):
            out = self.crelu(conv(out))
            print(conv.weight.shape)
            if i >= self.block_no_stack - 1:
                depth_list.append(out)
                print(i+1, "----")
            else:
                print(i+1, "a")
        out = self.conv_tail(torch.cat(depth_list, 1))
        print(self.conv_tail.weight.shape)
        print("out")
        out = self.ps(out) + F.interpolate(x, scale_factor=2, mode='bilinear')
        # return torch.clamp(out, max=1.0, min=0.0)
        return out

    def import_param(self, filename):
        for param in self.parameters():
            param.requires_grad = False
        with open(filename) as f:
            text = f.read()
        pattern = r'-?\d+(\.\d{4,})(e-?\d+)?'
        iter = re.finditer(pattern, text)
        convert(self.conv_head, iter)
        for conv in self.conv_mid:
            convert(conv, iter)
        convert(self.conv_tail, iter, True)
        check = next(iter, None)
        if check == None:
            print("pass")
        else:
            print("---failed---\n", check)


# model = Anime4kUpscale(7, 5, 12).to(device)
# model.import_param("tmp/Anime4K_Upscale_CNN_x2_UL.glsl")
# model = Anime4kUpscale(7, 7, 8)
# model.import_param("tmp/Anime4K_Upscale_CNN_x2_VL.glsl")
# model = Anime4kUpscale(3, 1, 8)
# model.import_param("tmp/Anime4K_Upscale_CNN_x2_L.glsl")

# image2 = Image.open(
#     "C://Users/khoi/Videos/Screenshot 2023-08-21 14-04-25.png").convert("RGB")
# image2 = to_tensor(image2).unsqueeze(0).to(device)
# out = model(image2)[0]
# display(to_pil(out))

In [11]:
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()

upscaleModel = [
    ("tmp/Anime4K_Upscale_CNN_x2_L.glsl", "Upscale_L", 3, 1, 8),
    ("tmp/Anime4K_Upscale_CNN_x2_VL.glsl", "Upscale_VL", 7, 7, 8),
    ("tmp/Anime4K_Upscale_CNN_x2_UL.glsl", "Upscale_UL", 7, 5, 12),
    ("tmp/Anime4K_Upscale_Denoise_CNN_x2_L.glsl", "Upscale_Denoise_L", 3, 1, 8),
    ("tmp/Anime4K_Upscale_Denoise_CNN_x2_VL.glsl", "Upscale_Denoise_VL", 7, 7, 8),
    ("tmp/Anime4K_Upscale_Denoise_CNN_x2_UL.glsl", "Upscale_Denoise_UL", 7, 5, 12),
]
restoreModel = [
    ("tmp/Anime4K_Restore_CNN_S.glsl", "Restore_S", 3, 1, 4),
    ("tmp/Anime4K_Restore_CNN_M.glsl", "Restore_M", 7, 7, 4),
    ("tmp/Anime4K_Restore_CNN_L.glsl", "Restore_L", 4, 1, 8),
    ("tmp/Anime4K_Restore_CNN_VL.glsl", "Restore_VL", 8, 7, 8),
    ("tmp/Anime4K_Restore_CNN_UL.glsl", "Restore_UL", 8, 5, 12),
    ("tmp/Anime4K_Restore_CNN_Soft_S.glsl", "Restore_S", 3, 1, 4),
    ("tmp/Anime4K_Restore_CNN_Soft_M.glsl", "Restore_Soft_M", 7, 7, 4),
    ("tmp/Anime4K_Restore_CNN_Soft_L.glsl", "Restore_Soft_L", 4, 1, 8),
    ("tmp/Anime4K_Restore_CNN_Soft_VL.glsl", "Restore_Soft_VL", 8, 7, 8),
    ("tmp/Anime4K_Restore_CNN_Soft_UL.glsl", "Restore_Soft_UL", 8, 5, 12),
]


def print(*args):
    pass


for filename, name, a, b, c in upscaleModel:
    onnx_path = f"onnxModel/{name}.onnx"
    model = Anime4kUpscale(a, b, c).to(device).half()
    model.import_param(filename)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 2: "width", 3: "height"},
            "output": {0: "batch_size", 2: "width", 3: "height"},
        },
    )
    # torch.onnx.export(model, dummy_input, onnx_path, verbose=True)

for filename, name, a, b, c in restoreModel:
    onnx_path = f"onnxModel/{name}.onnx"
    model = Anime4kRestore(a, b, c).to(device).half()
    model.import_param(filename)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        },
        opset_version=10,
    )
    # torch.onnx.export(model, dummy_input, onnx_path, verbose=True)

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Anime4kUpscale2(nn.Module):
    def __init__(self, block_depth=7, block_stack=5, channel=12):
        super(Anime4kUpscale2, self).__init__()
        self.conv_head = nn.Conv2d(3, channel, kernel_size=3, padding=1)
        self.conv_mid = nn.ModuleList([nn.Conv2d(
            channel, channel, kernel_size=3, padding=1) for _ in range(block_depth-1)])
        if block_stack != 1:
            self.conv_tail = nn.Conv2d(
                channel*block_stack, 12, kernel_size=1, padding=0)
        else:
            self.conv_tail = nn.Conv2d(channel, 12, kernel_size=3, padding=1)
        self.prelu = nn.ModuleList[nn.PReLU() for i in range(block_depth)]
        self.block_no_stack = block_depth - block_stack
        self.ps = nn.PixelShuffle(2)

    def forward(self, x):
        out = self.prelu[0](self.conv_head(x))
        print(self.conv_head.weight.shape)
        if self.block_no_stack == 0:
            depth_list = [out]
        else:
            depth_list = []
        for i, conv in enumerate(self.conv_mid):
            out = self.prelu[i+1](conv(out))
            print(conv.weight.shape)
            if i >= self.block_no_stack - 1:
                depth_list.append(out)
        out = self.conv_tail(torch.cat(depth_list, 1))
        out = self.ps(out) + F.interpolate(x, scale_factor=2, mode='bilinear')
        # return torch.clamp(out, max=1.0, min=0.0)
        return out
    
device = torch.device("cuda")
    
model = Anime4kUpscale2().to(device).half()
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()
onnx_path = "onnxModel/test.onnx"
torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        },
        opset_version=11,
    )

SyntaxError: invalid syntax (1792037378.py, line 16)