In [1]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import math

import torchvision
from IPython.display import Image, display
from PIL import Image

# import onnx

to_pil = torchvision.transforms.ToPILImage()
to_tensor = torchvision.transforms.ToTensor()

device = torch.device("cuda")

class CReLU(nn.Module):
    def __init__(self):
        super(CReLU, self).__init__()
    def forward(self, x):
        # return torch.cat((F.relu(x), F.relu(-x)), 1)
        return F.relu(torch.cat((x, -x), 1))

def convert(c, iter, doswap=False):
    swap = [0,2,1,3]
    out_chan, in_chan, width, height = c.weight.shape
    for to in range(math.ceil(out_chan/4)):
        for ti in range(math.ceil(in_chan/4)):
            for w in range(width):
                for h in range(height):
                    for i in range(min(4, in_chan)):
                        for o in range(min(4, out_chan)):
                            o = swap[o] if doswap else o
                            c.weight.data[to*4+o, ti*4+i, w, h] = float(next(iter).group(0))
        for o in range(min(4, out_chan)):
            o = swap[o] if doswap else o
            c.bias.data[to*4+o] = float(next(iter).group(0))
            
def debug(*args, **kwargs):
    if False:
        print(*args, **kwargs)


In [2]:
class Anime4kRestore(nn.Module):
    def __init__(self, block_depth=8, block_stack=5, channel=12):
        super(Anime4kRestore, self).__init__()
        self.conv_head = nn.Conv2d(3, channel, kernel_size=3, padding=1)
        self.conv_mid = nn.ModuleList([nn.Conv2d(
            channel*2, channel, kernel_size=3, padding=1) for _ in range(block_depth-1)])
        if block_stack != 1:
            self.conv_tail = nn.Conv2d(
                2*channel*block_stack, 3, kernel_size=1, padding=0)
        else:
            self.conv_tail = nn.Conv2d(2*channel, 3, kernel_size=3, padding=1)
        self.crelu = CReLU()
        self.block_no_stack = block_depth - block_stack

    def forward(self, x):
        out = self.crelu(self.conv_head(x))
        debug(self.conv_head.weight.shape)
        if self.block_no_stack == 0:
            depth_list = [out]
            debug(0, "----")
        else:
            depth_list = []
            debug(0, "a")
        for i, conv in enumerate(self.conv_mid):
            out = self.crelu(conv(out))
            debug(conv.weight.shape)
            if i >= self.block_no_stack - 1:
                depth_list.append(out)
                debug(i+1, "----")
            else:
                debug(i+1, "a")
        out = self.conv_tail(torch.cat(depth_list, 1))
        debug(self.conv_tail.weight.shape)
        debug("out")
        # return torch.clamp(out + x, max=1.0, min=0.0)
        return out + x

    def import_param(self, filename):
        for param in self.parameters():
            param.requires_grad = False
        with open(filename) as f:
            text = f.read()
        pattern = r'-?\d+(\.\d{2,})(e-?\d+)?'
        iter = re.finditer(pattern, text)
        convert(self.conv_head, iter)
        for conv in self.conv_mid:
            convert(conv, iter)
        convert(self.conv_tail, iter)
        check = next(iter, None)
        if check == None:
            debug("pass")
        else:
            debug("---failed---\n", check)

# image2 = Image.open(
#     "C://Users/khoi/Videos/Screenshot 2023-08-21 14-04-25.png").convert("RGB")
# image2 = to_tensor(image2).unsqueeze(0).to(device)
# out = model(image2)[0]
# display(to_pil(out))

In [2]:
class Anime4kUpscale(nn.Module):
    def __init__(self, block_depth=8, block_stack=5, channel=12):
        super(Anime4kUpscale, self).__init__()
        self.conv_head = nn.Conv2d(3, channel, kernel_size=3, padding=1)
        self.conv_mid = nn.ModuleList([nn.Conv2d(
            channel*2, channel, kernel_size=3, padding=1) for _ in range(block_depth-1)])
        if block_stack != 1:
            self.conv_tail = nn.Conv2d(
                2*channel*block_stack, 12, kernel_size=1, padding=0)
        else:
            self.conv_tail = nn.Conv2d(2*channel, 12, kernel_size=3, padding=1)
        self.crelu = CReLU()
        self.block_no_stack = block_depth - block_stack
        self.ps = nn.PixelShuffle(2)

    def forward(self, x):
        out = self.crelu(self.conv_head(x))
        debug(self.conv_head.weight.shape)
        if self.block_no_stack == 0:
            depth_list = [out]
            debug(0, "----")
        else:
            depth_list = []
            debug(0, "a")
        for i, conv in enumerate(self.conv_mid):
            out = self.crelu(conv(out))
            debug(conv.weight.shape)
            if i >= self.block_no_stack - 1:
                depth_list.append(out)
                debug(i+1, "----")
            else:
                debug(i+1, "a")
        out = self.conv_tail(torch.cat(depth_list, 1))
        debug(self.conv_tail.weight.shape)
        debug("out")
        out = self.ps(out) + F.interpolate(x, scale_factor=2, mode='bilinear')
        # return torch.clamp(out, max=1.0, min=0.0)
        return out

    def import_param(self, filename):
        for param in self.parameters():
            param.requires_grad = False
        with open(filename) as f:
            text = f.read()
        pattern = r'-?\d+(\.\d{4,})(e-?\d+)?'
        iter = re.finditer(pattern, text)
        convert(self.conv_head, iter)
        for conv in self.conv_mid:
            convert(conv, iter)
        convert(self.conv_tail, iter, True)
        check = next(iter, None)
        if check == None:
            debug("pass")
        else:
            debug("---failed---\n", check)

In [3]:
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()

upscaleModel = [
    ("tmp/Anime4K_Upscale_CNN_x2_L.glsl", "Upscale_L", 3, 1, 8),
    ("tmp/Anime4K_Upscale_CNN_x2_VL.glsl", "Upscale_VL", 7, 7, 8),
    ("tmp/Anime4K_Upscale_CNN_x2_UL.glsl", "Upscale_UL", 7, 5, 12),
    ("tmp/Anime4K_Upscale_Denoise_CNN_x2_L.glsl", "Upscale_Denoise_L", 3, 1, 8),
    ("tmp/Anime4K_Upscale_Denoise_CNN_x2_VL.glsl", "Upscale_Denoise_VL", 7, 7, 8),
    ("tmp/Anime4K_Upscale_Denoise_CNN_x2_UL.glsl", "Upscale_Denoise_UL", 7, 5, 12),
]
restoreModel = [
    ("tmp/Anime4K_Restore_CNN_S.glsl", "Restore_S", 3, 1, 4),
    ("tmp/Anime4K_Restore_CNN_M.glsl", "Restore_M", 7, 7, 4),
    ("tmp/Anime4K_Restore_CNN_L.glsl", "Restore_L", 4, 1, 8),
    ("tmp/Anime4K_Restore_CNN_VL.glsl", "Restore_VL", 8, 7, 8),
    ("tmp/Anime4K_Restore_CNN_UL.glsl", "Restore_UL", 8, 5, 12),
    ("tmp/Anime4K_Restore_CNN_Soft_S.glsl", "Restore_Soft_S", 3, 1, 4),
    ("tmp/Anime4K_Restore_CNN_Soft_M.glsl", "Restore_Soft_M", 7, 7, 4),
    ("tmp/Anime4K_Restore_CNN_Soft_L.glsl", "Restore_Soft_L", 4, 1, 8),
    ("tmp/Anime4K_Restore_CNN_Soft_VL.glsl", "Restore_Soft_VL", 8, 7, 8),
    ("tmp/Anime4K_Restore_CNN_Soft_UL.glsl", "Restore_Soft_UL", 8, 5, 12),
]

In [None]:

for filename, name, a, b, c in upscaleModel:
    onnx_path = f"onnxModel/{name}.onnx"
    model = Anime4kUpscale(a, b, c).to(device).half()
    model.import_param(filename)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 2: "width", 3: "height"},
            "output": {0: "batch_size", 2: "width", 3: "height"},
        },
    )
    # torch.onnx.export(model, dummy_input, onnx_path, verbose=True)

for filename, name, a, b, c in restoreModel:
    onnx_path = f"onnxModel/{name}.onnx"
    model = Anime4kRestore(a, b, c).to(device).half()
    model.import_param(filename)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        },
        opset_version=10,
    )
    # torch.onnx.export(model, dummy_input, onnx_path, verbose=True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Anime4kUpscale2(nn.Module):
    def __init__(self, block_depth=7, block_stack=5, channel=12):
        super(Anime4kUpscale2, self).__init__()
        self.conv_head = nn.Conv2d(3, channel, kernel_size=3, padding=1)
        self.conv_mid = nn.ModuleList([nn.Conv2d(
            channel, channel, kernel_size=3, padding=1) for _ in range(block_depth-1)])
        if block_stack != 1:
            self.conv_tail = nn.Conv2d(
                channel*block_stack, 12, kernel_size=1, padding=0)
        else:
            self.conv_tail = nn.Conv2d(channel, 12, kernel_size=3, padding=1)
        self.prelu = nn.ModuleList([nn.PReLU() for i in range(block_depth)])
        self.block_no_stack = block_depth - block_stack
        self.ps = nn.PixelShuffle(2)

    def forward(self, x):
        out = self.prelu[0](self.conv_head(x))
        print(self.conv_head.weight.shape)
        if self.block_no_stack == 0:
            depth_list = [out]
        else:
            depth_list = []
        for i, conv in enumerate(self.conv_mid):
            out = self.prelu[i+1](conv(out))
            print(conv.weight.shape)
            if i >= self.block_no_stack - 1:
                depth_list.append(out)
        out = self.conv_tail(torch.cat(depth_list, 1))
        out = self.ps(out) + F.interpolate(x, scale_factor=2, mode='bilinear')
        # return torch.clamp(out, max=1.0, min=0.0)
        return out
    
device = torch.device("cuda")
model = Anime4kUpscale2().to(device).half()
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()
onnx_path = "onnxModel/test.onnx"
torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        },
        opset_version=11,
    )

In [None]:
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()

from fvcore.nn import FlopCountAnalysis
for filename, name, a, b, c in upscaleModel:
    model = Anime4kUpscale(a, b, c).to(device).half()
    flops = FlopCountAnalysis(model, dummy_input)
    print(name, f': {flops.total() / 10**9:.3f}G')
for filename, name, a, b, c in restoreModel:
    model = Anime4kRestore(a, b, c).to(device).half()
    flops = FlopCountAnalysis(model, dummy_input)
    print(name, f': {flops.total() / 10**9:.3f}G')

'''
720p
Restore_Soft_S : 0.829G
Restore_Soft_M : 1.847G
Restore_Soft_L : 3.782G
Restore_Soft_VL : 7.941G
Restore_Soft_UL : 17.352G
Upscale_L : 3.959G
Upscale_VL : 7.852G
Upscale_UL : 16.003G

GAN_x2_S : 5.187G
GAN_x2_M : 9.592G
GAN_x3_L : 28.975G
GAN_x3_VL : 53.283G
GAN_x4_UL : 94.593G
GAN_x4_UUL : 220.366G

janaiSUC : 41.218G
janaiUC : 279.765G
janaiC : 551.555G
1400.891G

1080p
Restore_S : 1.866G
Restore_Soft_M : 4.155G
Restore_Soft_L : 8.510G
Restore_Soft_VL : 17.866G
Restore_Soft_UL : 39.042G
Upscale_L : 8.908G
Upscale_VL : 17.667G
Upscale_UL : 36.006G
janaiSUC : 92.740G
janaiUC : 629.470G
janaiC : 1241.000G

'''

In [None]:
from fvcore.nn import FlopCountAnalysis

dummy_input = torch.randn(1, 3, 1080, 1920).to(device).half()
# dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()

class compact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure, which performs upsampling in the last layer and no convolution is
    conducted on the HR feature space.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu', **kwargs):
        super(compact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace=True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch *
                         upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out

model = compact(num_feat=64, num_conv=16).to(device).half()
flops = FlopCountAnalysis(model, dummy_input)
print(f': {flops.total() / 10**9:.3f}G')


In [12]:
## eval
import os
import torch
from torchvision.transforms import ToTensor
from PIL import Image
from torch.utils.data import DataLoader, Dataset


# Define a custom dataset
class CustomDataset(Dataset):
    def __init__(self, low_res_folder, gt_folder, transform=None):
        self.low_res_folder = low_res_folder
        self.gt_folder = gt_folder
        self.transform = transform

        self.gt_files = os.listdir(gt_folder)

    def __len__(self):
        return len(self.gt_files)

    def __getitem__(self, idx):
        low_res_path = os.path.join(self.low_res_folder, self.gt_files[idx])
        gt_path = os.path.join(self.gt_folder, self.gt_files[idx])

        low_res_image = Image.open(low_res_path).convert("RGB")
        gt_image = Image.open(gt_path).convert("RGB")

        if self.transform:
            low_res_image = self.transform(low_res_image)
            gt_image = self.transform(gt_image)

        return low_res_image.to(device).half(), gt_image.to(device).half()

# Create a DataLoader for the dataset
transform = ToTensor()  # Convert images to tensors
dataset = CustomDataset("R:/lr/", "R:/hr/", transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
(filename, name, a, b, c) = upscaleModel[0]
model = Anime4kUpscale(a, b, c).to(device).half()
model.import_param(filename)
model.eval()

# Define a loss function (e.g., Mean Squared Error)
criterion = nn.MSELoss()

# Initialize a variable to store the total loss
total_loss = 0.0

# Evaluate the model
with torch.no_grad():
    for low_res, gt in dataloader:
        # Upscale the low-resolution image using the model
        upscaled_image = model(low_res)

        # Calculate the loss between the upscaled image and ground truth
        loss = criterion(upscaled_image, gt)

        # Add the loss to the total loss
        total_loss += loss.item()

# Calculate the average loss over the dataset
average_loss = total_loss / len(dataset)

print(f"Average Loss: {average_loss:.4f}")

Average Loss: 0.0004


In [5]:
for filename, name, a, b, c in upscaleModel:
    onnx_path = f"onnxModel/{name}.onnx"
    model = Anime4kUpscale(a, b, c).to(device).half()
    model.import_param(filename)
    torch.save(model.state_dict(), f'{name}.pth')

In [3]:
model = Anime4kUpscale(a, b, c).to(device).half()

NameError: name 'a' is not defined