In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from IPython.display import display
from PIL import Image
import re
import math

def conv_layer(in_channels, out_channels, kernel_size):
    """Create an ``nn.Conv2d`` padded by ``(kernel_size - 1) / 2`` (truncated).

    For odd kernel sizes this keeps the spatial size of the input unchanged.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        kernel_size: integer side length of the square kernel.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    half_span = (kernel_size - 1) / 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=int(half_span),
    )

class artCNN(nn.Module):
    """Small residual CNN that upscales a 3-channel image by 2x via PixelShuffle.

    Layout: ``conv_head`` -> ``block_depth`` x (conv + ReLU) -> ``conv_mid2``,
    then a residual add back onto the head output, then ``conv_tail`` producing
    ``3 * 2**2`` channels that ``PixelShuffle(2)`` rearranges into a 3-channel
    image twice as large in each spatial dimension. The output is clamped to
    ``[0, 1]``.

    Args:
        block_depth: number of conv + ReLU layers in the middle stack.
        num_feat: number of feature channels between head and tail.
    """

    # Numeric literals extracted from the .glsl weight file. At least two
    # decimal digits are required -- presumably so short literals such as
    # ``1.0`` that are not weights get skipped (NOTE(review): confirm against
    # the shader generator). The exponent accepts ``e``/``E`` with an optional
    # sign, so ``1.25e+03`` is not silently truncated to ``1.25``.
    _NUMBER_PATTERN = r'-?\d+\.\d{2,}(?:[eE][-+]?\d+)?'

    def __init__(self, block_depth=4, num_feat=32):
        super().__init__()
        self.act = nn.ReLU()
        self.ps = nn.PixelShuffle(2)
        io_channels = 3
        self.conv_head = conv_layer(io_channels, num_feat, kernel_size=3)
        self.conv_mid = nn.ModuleList(
            [
                conv_layer(num_feat, num_feat, kernel_size=3)
                for _ in range(block_depth)
            ]
        )
        self.conv_mid2 = conv_layer(num_feat, num_feat, kernel_size=3)
        # 4 = 2**2 sub-pixel positions per output channel for PixelShuffle(2).
        self.conv_tail = conv_layer(num_feat, io_channels * 2**2, kernel_size=3)

    def forward(self, x):
        """Upscale ``x`` by 2x; returns values clamped to ``[0, 1]``."""
        out = self.conv_head(x)
        skip = out  # residual connection around the middle stack
        for conv in self.conv_mid:
            out = self.act(conv(out))
        out = self.conv_mid2(out)
        out = skip + out
        out = self.conv_tail(out)
        out = self.ps(out)
        return torch.clamp(out, min=0.0, max=1.0)

    def import_param(self, filename):
        """Freeze all parameters and overwrite them with numbers from ``filename``.

        Numbers matching ``_NUMBER_PATTERN`` are consumed in layer order:
        head, each middle conv, ``conv_mid2``, tail. Prints ``"pass"`` when no
        matches remain afterwards; otherwise prints ``"---failed---"`` and the
        first unused match (the file holds more numbers than the model uses).
        """
        for param in self.parameters():
            param.requires_grad = False
        with open(filename) as f:
            text = f.read()
        matches = re.finditer(self._NUMBER_PATTERN, text)
        convert(self.conv_head, matches)
        for conv in self.conv_mid:
            convert(conv, matches)
        convert(self.conv_mid2, matches)
        convert(self.conv_tail, matches, False)
        leftover = next(matches, None)
        if leftover is None:
            print("pass")
        else:
            print("---failed---\n", leftover)
            
def convert(c, iter, doswap=False):
    """Fill conv layer ``c``'s weight and bias from an iterator of regex matches.

    Channels are processed in tiles of up to 4 (presumably one vec4/mat4 per
    tile in the source shader -- assumption, not visible here). For every
    output-channel tile, the file order is:

      for each input-channel tile:
        for each kernel column x:
          for each kernel row y:
            for each input channel in the tile:
              for each output channel in the tile:   <- innermost
                one weight value
      then one bias value per output channel in the tile.

    Args:
        c: an ``nn.Conv2d``-like module whose ``weight`` has shape
            ``(out_channels, in_channels, kH, kW)`` and which has a ``bias``.
        iter: iterator yielding ``re.Match`` objects whose ``group(0)`` is a
            float literal (name kept for caller compatibility; it shadows the
            builtin only inside this function).
        doswap: if True, the values within each output tile are written to
            channel offsets ``[0, 2, 1, 3]`` instead of ``[0, 1, 2, 3]``.

    Raises:
        ValueError: if ``iter`` runs out before the layer is fully populated.
    """
    swap = [0, 2, 1, 3]
    out_chan, in_chan, kernel_h, kernel_w = c.weight.shape

    def next_value():
        # Replaces a bare StopIteration with a message naming the layer shape.
        match = next(iter, None)
        if match is None:
            raise ValueError(
                f"ran out of parameters while loading conv layer with weight "
                f"shape {tuple(c.weight.shape)}"
            )
        return float(match.group(0))

    def tile_order(count):
        # Channel offsets within a tile, in the order the file lists them.
        return [swap[k] for k in range(count)] if doswap else range(count)

    for out_base in range(0, out_chan, 4):
        # The last tile may hold fewer than 4 channels.
        out_order = tile_order(min(4, out_chan - out_base))
        for in_base in range(0, in_chan, 4):
            in_count = min(4, in_chan - in_base)
            for x in range(kernel_w):
                for y in range(kernel_h):
                    for i in range(in_count):
                        for o in out_order:
                            c.weight.data[out_base + o, in_base + i, y, x] = next_value()
        for o in out_order:
            c.bias.data[out_base + o] = next_value()


In [None]:
# Preview: upscale one test frame with the RGB weights.
# NOTE(review): absolute paths on the author's machine; adjust before running elsewhere.
WEIGHTS_PATH = "R:/ArtCNN_C4F32_RGB.glsl"
IMAGE_PATH = "C:/Users/khoi/Videos/koiama.png"

to_pil = torchvision.transforms.ToPILImage()
to_tensor = torchvision.transforms.ToTensor()

# The model runs in float16 on CUDA.
device = torch.device("cuda")
model = artCNN().to(device).half()
model.import_param(WEIGHTS_PATH)
model.eval()

image = to_tensor(Image.open(IMAGE_PATH).convert("RGB")).unsqueeze(0)
image = image.to(device=device, dtype=torch.half)
# Pure inference: no autograd bookkeeping needed.
with torch.no_grad():
    out = model(image)[0]
display(to_pil(out))


In [13]:
# Export the SH_RGB weight variant to ONNX with dynamic batch / spatial dims.
# `device` is defined here too so this cell does not depend on the preview cell.
device = torch.device("cuda")
SH_WEIGHTS_PATH = "R:/ArtCNN_C4F32_SH_RGB.glsl"
onnx_path = "R:/ArtCNN_C4F32_SH_RGB.onnx"

model = artCNN(num_feat=32).to(device).half()
model.import_param(SH_WEIGHTS_PATH)
# Tracing input only; its values are irrelevant. NCHW = (batch, 3, height, width).
dummy_input = torch.randn(1, 3, 720, 1280).to(device).half()
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["output"],
    # NCHW: dim 2 is height, dim 3 is width. The output is 2x the input
    # spatially (PixelShuffle(2)), so it gets its own symbolic names; reusing
    # the input's names would declare those dimensions equal.
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "out_height", 3: "out_width"},
    },
);
