In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from IPython.display import display
from PIL import Image
import re
import math

def conv_layer(in_channels, out_channels, kernel_size):
    """Build an ``nn.Conv2d`` padded by ``int((kernel_size - 1) / 2)`` on each side.

    With the default stride of 1 and an odd ``kernel_size`` this preserves the
    spatial size of the input.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    same_pad = int((kernel_size - 1) / 2)
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=same_pad,
    )

class CReLU(nn.Module):
    """Concatenated ReLU: ReLU applied to ``x`` and ``-x`` stacked along dim 1.

    The output is twice as large as the input along dim 1 (the channel dim for
    NCHW input). The module has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        mirrored = torch.cat([x, x.neg()], dim=1)
        return F.relu(mirrored)

class anime4k(nn.Module):
    """Anime4K-style CNN: stacked 3x3 convs with CReLU, a tail conv and a skip path.

    Args:
        block_depth: number of 3x3 feature convs (one head conv plus
            ``block_depth - 1`` mid convs). Each conv's activated output is
            recorded for the tail.
        stack_list: which recorded outputs feed the tail conv. An int ``n``
            selects the last ``n`` layers (indices ``-n .. -1``); a list gives
            explicit indices into the recorded outputs.
        num_feat: output channels of every feature conv (CReLU doubles this).
        last: if True the tail conv uses a 3x3 kernel, otherwise 1x1.
        scale: when != 1, the tail output goes through ``PixelShuffle(scale)``
            and is added to the input resized by ``F.interpolate``; when == 1
            the input is added directly.
        single_tail: if True the tail emits 4 channels instead of
            ``3 * scale * scale``. NOTE(review): only the default path is used
            in this notebook — confirm shapes for the 4-channel variant.
        upscale_mode: ``F.interpolate`` mode used on the skip path.

    Raises:
        ValueError: if ``stack_list`` selects no layers, or any of its indices
            does not address a recorded layer output.
    """

    def __init__(self, block_depth=7, stack_list=5, num_feat=12, last=False, scale=2, single_tail=False, upscale_mode="bilinear"):
        super().__init__()
        self.act = CReLU()
        if isinstance(stack_list, int):
            # An int n means "the last n layers", i.e. indices -n .. -1.
            stack_list = list(range(-stack_list, 0))
        self.stack_list = stack_list
        self.scale = scale
        self.ps = nn.PixelShuffle(self.scale)
        self.conv_head = conv_layer(3, num_feat, kernel_size=3)
        self.conv_mid = nn.ModuleList(
            [
                conv_layer(num_feat * 2, num_feat, kernel_size=3)
                for _ in range(block_depth - 1)
            ]
        )
        # forward() records one output for the head plus one per mid conv;
        # reject selections that would otherwise only fail inside forward().
        num_layers = len(self.conv_mid) + 1
        if len(stack_list) == 0:
            raise ValueError("stack_list must select at least one layer")
        out_of_range = [i for i in stack_list if not -num_layers <= i < num_layers]
        if out_of_range:
            raise ValueError(
                f"stack_list indices {out_of_range} out of range for {num_layers} recorded layers"
            )
        tail_out_c = 4 if single_tail else 3 * scale * scale
        # Each selected output has 2 * num_feat channels after CReLU.
        tail_in_c = 2 * num_feat * len(stack_list)
        if last:
            self.conv_tail = conv_layer(tail_in_c, tail_out_c, kernel_size=3)
        else:
            self.conv_tail = conv_layer(tail_in_c, tail_out_c, kernel_size=1)
        self.upscale_mode = upscale_mode

    def forward(self, x):
        """Run the network and clamp the result to [0, 1].

        Expects a batched tensor whose dim 1 holds 3 channels (the head conv has
        3 input channels and features are concatenated on dim 1), presumably
        (N, 3, H, W) images in [0, 1] — TODO confirm value range with callers.
        """
        out = self.act(self.conv_head(x))
        depth_list = [out]
        for conv in self.conv_mid:
            out = self.act(conv(out))
            depth_list.append(out)
        # Concatenate the selected layer outputs along the channel dim for the tail.
        out = self.conv_tail(torch.cat([depth_list[i] for i in self.stack_list], 1))
        if self.scale != 1:
            # Sub-pixel shuffle the tail output and add it to the resized input.
            out = self.ps(out) + F.interpolate(x, scale_factor=self.scale, mode=self.upscale_mode)
        else:
            out += x
        return torch.clamp(out, max=1.0, min=0.0)

    def import_param(self, filename):
        """Load weights from a shader text file (e.g. ``Anime4K_*.glsl``) and freeze all parameters.

        Every number in the file matching the pattern below (a decimal point
        followed by at least four digits, optional exponent) is consumed in
        order: head conv, each mid conv, then the tail conv with output
        channels permuted via ``convert(..., doswap=True)``. Prints "pass" when
        no matching numbers remain afterwards, otherwise "---failed---" and the
        first unconsumed match.
        """
        for param in self.parameters():
            param.requires_grad = False
        with open(filename) as f:
            text = f.read()
        pattern = r'-?\d+(\.\d{4,})(e-?\d+)?'
        # Named `numbers` (not `iter`) so the builtin is not shadowed.
        numbers = re.finditer(pattern, text)
        convert(self.conv_head, numbers)
        for conv in self.conv_mid:
            convert(conv, numbers)
        convert(self.conv_tail, numbers, True)
        check = next(numbers, None)
        if check is None:
            print("pass")
        else:
            print("---failed---\n", check)


def convert(c, iter, doswap=False):
    """Fill a conv layer's ``weight`` and ``bias`` in place from a stream of regex matches.

    Values are consumed in this order: for each group of 4 output channels,
    for each group of 4 input channels, for each kernel position (dim 2 outer,
    dim 3 inner), input channel outer / output channel inner; the group's
    biases follow after all of its weights.

    Args:
        c: module with ``weight`` of shape (out, in, k2, k3) and ``bias`` of
            shape (out,), e.g. an ``nn.Conv2d``.
        iter: iterator of ``re.Match`` objects whose full text parses as a
            float. (Name shadows the builtin; kept for keyword-call compatibility.)
        doswap: if True, output channels inside each group of 4 are permuted
            as [0, 2, 1, 3]. NOTE(review): presumably matches the tail layer's
            channel packing in the shader file — confirm.

    Raises:
        ValueError: if ``iter`` is exhausted before every value is filled.
    """
    swap = [0, 2, 1, 3]
    out_chan, in_chan, k_rows, k_cols = c.weight.shape
    in_span = min(4, in_chan)
    # Destination offset (within a group of 4 outputs) for each value read.
    out_order = [swap[o] if doswap else o for o in range(min(4, out_chan))]

    def next_value():
        # Surface a truncated weight file as a clear error rather than a bare StopIteration.
        try:
            return float(next(iter).group(0))
        except StopIteration:
            raise ValueError(
                f"weight stream exhausted while filling layer with weight shape {tuple(c.weight.shape)}"
            ) from None

    for to in range(math.ceil(out_chan / 4)):
        for ti in range(math.ceil(in_chan / 4)):
            for r in range(k_rows):
                for s in range(k_cols):
                    for i in range(in_span):
                        for o in out_order:
                            c.weight.data[to * 4 + o, ti * 4 + i, r, s] = next_value()
        for o in out_order:
            c.bias.data[to * 4 + o] = next_value()
    

to_pil = torchvision.transforms.ToPILImage()
to_tensor = torchvision.transforms.ToTensor()

# Run a trained checkpoint on one validation image.
# NOTE(review): absolute local paths — move to a config cell / env var before sharing.
CHECKPOINT_PATH = "E:/project/neosr/experiments/amakano_a14-7-r_b2_li-la/models/net_g_latest.pth"
IMAGE_PATH = "E:/Dataset/val/sora.png"

device = torch.device("cuda")
model = anime4k(block_depth=14, stack_list=7, num_feat=12).to(device).half()
model.eval()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# checkpoint cannot execute arbitrary code; recent PyTorch warns without it.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(checkpoint['params'])
# Context manager guarantees the image file handle is released.
with Image.open(IMAGE_PATH) as img:
    image2 = to_tensor(img.convert("RGB")).unsqueeze(0).half().to(device)
# inference_mode: no autograd graph is recorded for this large forward pass.
with torch.inference_mode():
    out = model(image2)[0]
print(out.shape)
# ToPILImage needs a CPU tensor; convert before displaying.
# display(to_pil(out.float().cpu()))


  model.load_state_dict(torch.load("E:/project/neosr/experiments/amakano_a14-7-r_b2_li-la/models/net_g_latest.pth", map_location=device)['params'])


torch.Size([3, 1440, 2560])


In [None]:
# Export ONNX graphs, each chaining two models: a restore/soft model and the x2 upscaler, in either order
# restoreul = anime4k(8, 5, 12, False, 1).to(device).half()
# restoreul.import_param("../tmp/Anime4K_Restore_CNN_UL.glsl")
# softul = anime4k(8, 5, 12, False, 1).to(device).half()
# softul.import_param("../tmp/Anime4K_Restore_CNN_Soft_UL.glsl")
# upul = anime4k(7, 5, 12, False, 2).to(device).half()
# upul.import_param("../tmp/Anime4K_Upscale_CNN_x2_UL.glsl")
# dummy_input = torch.randn(1, 3, 720, 1280).half().to(device)
# reup = nn.Sequential(restoreul, upul)
# soup = nn.Sequential(softul, upul)
# upre = nn.Sequential(upul, restoreul)
# upso = nn.Sequential(upul, softul)

# for model, path in [(reup, "R:/RestoreUL_UpscaleUL.onnx"), (soup, "R:/SoftUL_UpscaleUL.onnx"), (upre, "R:/UpscaleUL_RestoreUL.onnx"), (upso, "R:/UpscaleUL_SoftUL.onnx")]:
#     torch.onnx.export(
#         model,
#         dummy_input,
#         path,
#         input_names=["input"],
#         output_names=["output"],
#         dynamic_axes={
#             "input": {0: "batch_size", 2: "height", 3: "width"},
#             "output": {0: "batch_size", 2: "height", 3: "width"},
#         },
#         opset_version=17,
#     )

pass
pass
pass


In [6]:
# Trainable-parameter count for each candidate (block_depth, stack_list, num_feat),
# printed one per line in list order. Uses `candidate` so the trained GPU `model`
# from the inference cell is not overwritten.
CANDIDATE_CONFIGS = [
    (7, 7, 8),
    (8, 5, 8),
    (7, 5, 12),
    (15, 5, 8),
    (14, 11, 8),
    (14, 7, 12),
    (17, 9, 12),
]
for block_depth, stack_list, num_feat in CANDIDATE_CONFIGS:
    candidate = anime4k(block_depth=block_depth, stack_list=stack_list, num_feat=num_feat)
    print(sum(p.numel() for p in candidate.parameters() if p.requires_grad))

8540
9316
17412
17436
17428
36216
44604
