In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from IPython.display import display
from PIL import Image

def conv_layer(in_channels, out_channels, kernel_size):
    """Build a stride-1 ``nn.Conv2d`` whose padding is ``int((kernel_size - 1) / 2)``.

    For odd ``kernel_size`` this is "same" padding: the spatial size of the
    input is preserved (e.g. 3 -> padding 1, 1 -> padding 0).
    """
    half_width = int((kernel_size - 1) / 2)
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_width,
    )

class CReLU(nn.Module):
    """Concatenated ReLU: ``relu(cat([x, -x], dim=1))``.

    The output has twice as many entries along dimension 1 as the input, so
    both the positive and the negative part of every activation survive the
    ReLU. Has no learnable parameters.
    """

    def forward(self, x):
        mirrored = torch.cat([x, x.neg()], dim=1)
        return F.relu(mirrored)

class anime4k(nn.Module):
    """2x upscaler: a plain conv stack whose selected layer outputs feed a PixelShuffle tail.

    The network predicts a 12-channel map that ``PixelShuffle(2)`` turns into a
    3-channel residual at twice the input resolution; that residual is added to
    a bilinear 2x upsample of the input and the result is clamped to [0, 1].

    Args:
        block_depth: number of 3x3 conv layers in the body
            (``conv_head`` plus ``block_depth - 1`` layers in ``conv_mid``).
        stack_list: which body-layer outputs are concatenated and fed to the
            tail. An ``int`` ``k`` selects the last ``k`` layers; otherwise an
            iterable of indices into the list of body outputs (negative indices
            count from the end).
        num_feat: number of output channels of every body conv.
        act: ``"crelu"`` (CReLU, doubles the channels seen by the next conv)
            or ``"prelu"`` (a single ``nn.PReLU`` with ``num_feat`` parameters,
            shared by all body layers).
        last: if True the tail conv is 3x3, otherwise 1x1.

    Raises:
        ValueError: if ``act`` is not one of the supported names, or if
            ``stack_list`` is empty or refers to a body layer that does not exist.
    """

    def __init__(self, block_depth=7, stack_list=5, num_feat=12, act="crelu", last=False):
        super(anime4k, self).__init__()
        if act == "crelu":
            factor = 2  # CReLU concatenates x and -x, doubling the channel count
            self.act = CReLU()
        elif act == "prelu":
            factor = 1
            self.act = nn.PReLU(num_parameters=num_feat)
        else:
            # Without this, `factor` stays unbound and construction fails with an
            # unhelpful UnboundLocalError further down.
            raise ValueError(f"unknown act {act!r}; expected 'crelu' or 'prelu'")
        if isinstance(stack_list, int):
            stack_list = list(range(-stack_list, 0))
        else:
            stack_list = list(stack_list)
        self.stack_list = stack_list
        self.ps = nn.PixelShuffle(2)

        self.conv_head = conv_layer(3, num_feat, kernel_size=3)
        self.conv_mid = nn.ModuleList(
            [
                conv_layer(num_feat * factor, num_feat, kernel_size=3)
                for _ in range(block_depth - 1)
            ]
        )

        # forward() collects one output per body conv (head + mid); validate the
        # selection now instead of failing with IndexError on the first forward.
        num_outputs = 1 + len(self.conv_mid)
        if not stack_list:
            raise ValueError("stack_list must select at least one layer")
        bad = [i for i in stack_list if not -num_outputs <= i < num_outputs]
        if bad:
            raise ValueError(
                f"stack_list indices {bad} out of range for {num_outputs} body layers"
            )

        # 12 output channels = 3 (RGB) * 2**2, consumed by PixelShuffle(2).
        tail_kernel = 3 if last else 1
        self.conv_tail = conv_layer(factor * num_feat * len(stack_list), 12, kernel_size=tail_kernel)

    def forward(self, x):
        """Upscale ``x`` by a factor of 2.

        ``x`` must have 3 channels in dim 1 (fixed by ``conv_head``). It is
        presumably an (N, 3, H, W) image batch in [0, 1] — TODO confirm; the
        output clamp assumes that range. Returns an (N, 3, 2H, 2W) tensor
        clamped to [0, 1].
        """
        out = self.act(self.conv_head(x))
        depth_list = [out]
        for conv in self.conv_mid:
            out = self.act(conv(out))
            depth_list.append(out)
        # Concatenate the selected intermediate features along the channel axis.
        out = self.conv_tail(torch.cat([depth_list[i] for i in self.stack_list], 1))
        # Residual on top of a plain bilinear 2x upsample of the input.
        out = self.ps(out) + F.interpolate(x, scale_factor=2, mode="bilinear")
        return torch.clamp(out, min=0.0, max=1.0)
    

# Image <-> tensor converters used for the inference demo below.
to_pil = torchvision.transforms.ToPILImage()
to_tensor = torchvision.transforms.ToTensor()

# Inference inputs — edit these for your machine.
CHECKPOINT_PATH = "E:/project/neosr/experiments/amakano_a14-7-r_b2_li-la/models/net_g_latest.pth"
IMAGE_PATH = "E:/Dataset/val/sora.png"

# Use the GPU in half precision when available. Otherwise fall back to CPU in
# float32 (half-precision conv support on CPU is limited) instead of crashing
# on torch.device("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

model = anime4k(block_depth=14, stack_list=7, num_feat=12).to(device=device, dtype=dtype)
model.eval()
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# The checkpoint dict holds the model weights under the "params" key.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint["params"])

# Context manager closes the file handle; convert() already loads the pixels.
with Image.open(IMAGE_PATH) as img:
    image2 = to_tensor(img.convert("RGB")).unsqueeze(0).to(device=device, dtype=dtype)

# Pure inference: skip building the autograd graph to save memory.
with torch.no_grad():
    out = model(image2)[0]
print(out.shape)
display(to_pil(out.float().cpu()))

In [4]:
def count_trainable_params(module):
    """Return the total number of elements in ``module``'s parameters that have requires_grad=True."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Model size for each (block_depth, stack_list) configuration tried, all with num_feat=12.
# A throwaway instance is used so the trained `model` loaded above is not overwritten.
for cfg_depth, cfg_stack in [(7, 5), (14, 7), (17, 9)]:
    print(count_trainable_params(anime4k(block_depth=cfg_depth, stack_list=cfg_stack, num_feat=12)))

17412
36216
44604
