## SR Model 정의

In [None]:
# ESA_SAPEON_bigger_B3_x2

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1):
    """Build a biased ``nn.Conv2d`` whose zero padding is derived from the kernel.

    The padding is ``int((kernel_size - 1) / 2) * dilation``; for an odd
    ``kernel_size`` at stride 1 this leaves the spatial size unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: (integer) size of the square kernel.
        stride: convolution stride.
        dilation: kernel dilation; also scales the padding.
        groups: number of channel groups.

    Returns:
        The configured ``nn.Conv2d`` (bias always enabled).
    """
    half_kernel = int((kernel_size - 1) / 2)
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=half_kernel * dilation,
        dilation=dilation,
        groups=groups,
        bias=True,
    )


def norm(norm_type, nc):
    """Return a 2-D normalization layer over ``nc`` channels.

    ``norm_type`` is matched case-insensitively: ``'batch'`` yields an affine
    ``nn.BatchNorm2d`` and ``'instance'`` a non-affine ``nn.InstanceNorm2d``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    kind = norm_type.lower()
    if kind == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if kind == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    raise NotImplementedError('normalization layer [{:s}] is not found'.format(kind))


def pad(pad_type, padding):
    """Return an explicit padding layer, or ``None`` when ``padding`` is 0.

    ``pad_type`` is lower-cased before anything else, so it must be a string
    even when ``padding`` is 0. Supported modes are ``'reflect'``
    (``nn.ReflectionPad2d``) and ``'replicate'`` (``nn.ReplicationPad2d``).

    Raises:
        NotImplementedError: for any other mode with a non-zero ``padding``.
    """
    mode = pad_type.lower()
    if padding == 0:
        return None
    if mode == 'reflect':
        return nn.ReflectionPad2d(padding)
    if mode == 'replicate':
        return nn.ReplicationPad2d(padding)
    raise NotImplementedError('padding layer [{:s}] is not implemented'.format(mode))


def get_valid_padding(kernel_size, dilation):
    """Return the per-side padding for a dilated kernel.

    The dilated kernel spans ``kernel_size + (kernel_size - 1) * (dilation - 1)``
    positions; the result is half of that span minus one, rounded down.
    """
    effective_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (effective_size - 1) // 2


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu'):
    """Assemble an optional-pad -> conv -> optional-norm -> optional-activation stack.

    Args:
        in_nc: input channels of the convolution.
        out_nc: output channels of the convolution (and of the norm layer).
        kernel_size, stride, dilation, groups, bias: forwarded to ``nn.Conv2d``.
        pad_type: ``'zero'`` pads inside the convolution; any other non-empty
            value is delegated to ``pad`` and placed in front of the
            convolution, which then uses no padding of its own. A falsy value
            disables padding altogether.
        norm_type: passed to ``norm`` when truthy; no norm layer otherwise.
        act_type: passed to ``activation`` when truthy; no activation otherwise.

    Returns:
        Whatever ``sequential`` builds from the pieces above.
    """
    same_pad = get_valid_padding(kernel_size, dilation)
    use_zero_pad = pad_type == 'zero'

    # Non-zero padding modes get an explicit layer ahead of the convolution.
    pad_layer = None
    if pad_type and not use_zero_pad:
        pad_layer = pad(pad_type, same_pad)

    conv = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=same_pad if use_zero_pad else 0,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    act_layer = activation(act_type) if act_type else None
    norm_layer = norm(norm_type, out_nc) if norm_type else None
    return sequential(pad_layer, conv, norm_layer, act_layer)


def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
    """Return an activation layer selected by name (case-insensitive).

    Args:
        act_type: ``'relu'``, ``'lrelu'`` or ``'prelu'``.
        inplace: forwarded to ``nn.ReLU`` / ``nn.LeakyReLU``.
        neg_slope: negative slope for ``'lrelu'``; initial slope for ``'prelu'``.
        n_prelu: number of learnable slopes for ``'prelu'``.

    Raises:
        NotImplementedError: for any other ``act_type``.
    """
    kind = act_type.lower()
    if kind == 'relu':
        return nn.ReLU(inplace)
    if kind == 'lrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    raise NotImplementedError('activation layer [{:s}] is not found'.format(kind))
    
def pixelshuffle_block(in_channels, out_channels, upscale_factor=3, kernel_size=3, stride=1):
    """Build a conv + ``nn.PixelShuffle`` upsampling stage.

    The convolution produces ``out_channels * upscale_factor ** 2`` channels,
    which ``nn.PixelShuffle(upscale_factor)`` rearranges into space.

    Note that when ``upscale_factor == 3`` the convolution kernel is fixed to
    3 and ``kernel_size`` is ignored; for every other factor ``kernel_size``
    is honored.
    """
    if upscale_factor == 3:
        conv_out_channels, conv_kernel = 9 * out_channels, 3
    else:
        conv_out_channels, conv_kernel = out_channels * (upscale_factor ** 2), kernel_size
    return sequential(
        conv_layer(in_channels, conv_out_channels, conv_kernel, stride),
        nn.PixelShuffle(upscale_factor),
    )

def mean_channels(F):
    """Average a 4-D tensor over its last two dimensions, keeping them as size 1.

    Note: the parameter name ``F`` is a tensor here, not ``torch.nn.functional``.
    """
    assert F.dim() == 4
    pixel_count = F.size(2) * F.size(3)
    # Reduce the last axis first, then the one before it (both kept as size 1).
    total = F.sum(3, keepdim=True).sum(2, keepdim=True)
    return total / pixel_count

def stdv_channels(F):
    """Return the standard deviation of a 4-D tensor over its last two dimensions.

    The squared deviations from ``mean_channels(F)`` are averaged over all
    positions (divided by the element count, not count - 1) and the square
    root is taken. The reduced dimensions are kept as size 1.
    """
    assert F.dim() == 4
    centered = F - mean_channels(F)
    pixel_count = F.size(2) * F.size(3)
    variance = centered.pow(2).sum(3, keepdim=True).sum(2, keepdim=True) / pixel_count
    return variance.pow(0.5)

def sequential(*args):
    """Chain modules into one ``nn.Sequential``, flattening nested ones.

    Args:
        *args: modules to chain. ``nn.Sequential`` arguments contribute their
            children individually; arguments that are not ``nn.Module``
            instances (for example ``None`` placeholders) are skipped.

    Returns:
        The single argument itself when exactly one is given (no wrapping),
        otherwise a new ``nn.Sequential`` of the collected modules.

    Raises:
        NotImplementedError: if the single argument is an ``OrderedDict``.
    """
    # Imported here so the single-argument path works even when the
    # surrounding module or cell never imported OrderedDict itself.
    from collections import OrderedDict

    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]

    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            modules.extend(module.children())
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)



class ESA(nn.Module):
    """Attention block that rescales its input with a learned, ReLU6-clipped mask.

    The mask is computed on ``n_feats // 4`` channels: a 1x1 reduction, a
    strided 3x3 convolution plus 4x4 max pooling to shrink the map, three 3x3
    convolutions on the small map, bicubic resizing back to the input size,
    a 1x1 skip branch from the reduced features, and a final 1x1 expansion
    back to ``n_feats`` channels.

    Args:
        n_feats: channel count of the input (and of the output).
        conv: callable used to construct every convolution; it is called as
            ``conv(in, out, kernel_size=..., [stride=...], [padding=...])``.
    """

    def __init__(self, n_feats, conv):
        super(ESA, self).__init__()
        reduced = n_feats // 4
        # Attribute names and registration order are kept as they are so the
        # state_dict keys stay unchanged.
        self.conv1 = conv(n_feats, reduced, kernel_size=1)
        self.conv_f = conv(reduced, reduced, kernel_size=1)
        self.conv_max = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv2 = conv(reduced, reduced, kernel_size=3, stride=2, padding=1)
        self.conv3 = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv3_ = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv4 = conv(reduced, n_feats, kernel_size=1)
        self.clip_with = nn.ReLU6()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the attention mask."""
        squeezed = self.conv1(x)

        # Shrink the map: stride-2 convolution, then 4x4 max pooling.
        pooled = F.max_pool2d(self.conv2(squeezed), kernel_size=4, stride=4)

        coarse = self.relu(self.conv_max(pooled))
        coarse = self.relu(self.conv3(coarse))
        coarse = self.conv3_(coarse)

        # NOTE: a pixel-shuffle upsampler appeared to cause problems at x2,
        # so the map is temporarily resized with interpolation instead.
        target_hw = (x.size(2), x.size(3))
        upsampled = F.interpolate(coarse, target_hw, mode='bicubic', align_corners=False)

        skip = self.conv_f(squeezed)
        mask = self.clip_with(self.conv4(upsampled + skip))
        return x * mask


class RFDB(nn.Module):
    """Feature block with two residual 3x3 stages, 1x1 side branches and ESA.

    Two 1x1 "distilled" branches (``in_channels // 2`` channels each) are taken
    from the input and from the first residual stage. The output of the last
    3x3 convolution is concatenated twice with them, fused by a 1x1
    convolution back to ``in_channels`` and passed through ``ESA``.

    Args:
        in_channels: channel count of the input and of the output.
        distillation_rate: accepted but not used by this implementation.
    """

    def __init__(self, in_channels, distillation_rate=0.25):
        super(RFDB, self).__init__()
        half = in_channels // 2
        self.dc = self.distilled_channels = half
        self.rc = self.remaining_channels = in_channels

        # Attribute names and registration order are kept as they are so the
        # state_dict keys stay unchanged.
        self.c1_d = conv_layer(in_channels, self.dc, 1)
        self.c1_r = conv_layer(in_channels, self.rc, 3)
        self.c2_d = conv_layer(self.remaining_channels, self.dc, 1)
        self.c2_r = conv_layer(self.remaining_channels, self.rc, 3)
        self.c4 = conv_layer(self.remaining_channels, self.dc, 3)
        # Plain ReLU; neg_slope is only consulted by the leaky/PReLU variants.
        self.act = activation('relu', neg_slope=0.05)
        self.c5 = conv_layer(self.dc * 4, in_channels, 1)
        self.esa = ESA(in_channels, nn.Conv2d)

    def forward(self, input):
        """Run the block on ``input`` and return the ESA-weighted fused features."""
        act = self.act

        distilled_c1 = act(self.c1_d(input))
        stage1 = act(self.c1_r(input)) + input

        distilled_c2 = act(self.c2_d(stage1))
        stage2 = act(self.c2_r(stage1)) + stage1

        tail = act(self.c4(stage2))

        # The tail features are concatenated twice, which is what makes the
        # fused tensor dc * 4 channels wide, as c5 expects.
        fused = torch.cat([distilled_c1, distilled_c2, tail, tail], dim=1)
        return self.esa(act(self.c5(fused)))


class RFDN64(nn.Module):
    """Super-resolution network: feature conv, chained RFDB blocks, pixel-shuffle upsampler.

    The outputs of all RFDB blocks are concatenated and fused by ``self.c``;
    the fused features, the initial features and the last block's output are
    summed, upsampled and clamped to ``[0, 1]``.

    Args:
        in_nc: number of input image channels.
        nf: number of feature channels inside the network.
        num_modules: number of chained RFDB blocks, registered as ``B1`` ..
            ``B{num_modules}``. The default of 3 yields exactly the
            attributes ``B1``, ``B2``, ``B3``.
        out_nc: number of output image channels.
        upscale: upscaling factor passed to the pixel-shuffle block.

    Raises:
        ValueError: if ``num_modules`` is smaller than 1.
    """

    def __init__(self, in_nc=3, nf=64, num_modules=3, out_nc=3, upscale=3):
        super(RFDN64, self).__init__()
        if num_modules < 1:
            raise ValueError('num_modules must be at least 1, got {}'.format(num_modules))

        self.fea_conv = conv_layer(in_nc, nf, kernel_size=3)

        # Register B1..Bn by name so state_dict keys match checkpoints saved
        # with the fixed B1/B2/B3 attributes when num_modules == 3.
        for idx in range(1, num_modules + 1):
            setattr(self, 'B{}'.format(idx), RFDB(in_channels=nf))

        # Fuses the concatenated outputs of all blocks back to nf channels.
        self.c = conv_block(nf * num_modules, nf, kernel_size=3, act_type='relu')

        upsample_block = pixelshuffle_block
        self.upsampler = upsample_block(nf, out_nc, upscale_factor=upscale)
        self.scale_idx = 0
        self.init = 0

        self.act = nn.ReLU(inplace=True)

    def _blocks(self):
        """Yield the registered RFDB blocks B1, B2, ... in order."""
        idx = 1
        while hasattr(self, 'B{}'.format(idx)):
            yield getattr(self, 'B{}'.format(idx))
            idx += 1

    def forward(self, input):
        """Upscale ``input`` and return the result clamped to ``[0, 1]``."""
        out_fea = self.act(self.fea_conv(input))

        # Each block consumes the previous block's output; every intermediate
        # output is kept for the fusion convolution.
        block_outputs = []
        features = out_fea
        for block in self._blocks():
            features = block(features)
            block_outputs.append(features)

        out_B = self.c(torch.cat(block_outputs, dim=1))

        out_lr = out_B + out_fea + block_outputs[-1]

        output = self.upsampler(out_lr)
        output = torch.clamp(output, min=0, max=1)

        return output

    def set_scale(self, scale_idx):
        """Store ``scale_idx`` on the model; ``forward`` does not read it."""
        self.scale_idx = scale_idx

## 모델이 RGB input, PixelShuffle upsample일때만 동작 가능.

In [None]:
from __future__ import print_function
import argparse
import torch
from PIL import Image
from torchvision.transforms import ToTensor
import torch.nn as nn
import numpy as np
from collections import OrderedDict

# Load the sample image as RGB and resize it to the low-resolution input size.
img = Image.open('kaimedia_sample_3220.png').convert('RGB')
img = img.resize((960,540),Image.BICUBIC)  # not needed if the image is already small
img.save('kaimedia_sample_after_bic_3220.png')

print('input image size {:d}x{:d}'.format(img.size[0], img.size[1]))

model = RFDN64(upscale=2)

# Trained weights are required for inference. Loading onto the CPU first keeps
# this step independent of the device the checkpoint was saved from.
state_dict = torch.load('kaimedia_star_model.pth', map_location='cpu')

# Checkpoints saved from an nn.DataParallel model prefix every key with
# 'module.'; strip it only where it is actually present.
prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len(prefix):] if k.startswith(prefix) else k
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

img_to_tensor = ToTensor()
# PIL reports size as (width, height); the tensor is laid out as (1, C, H, W).
input = img_to_tensor(img).view(1, -1, img.size[1], img.size[0])

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# NOTE(review): the original note says the model must be wrapped in
# DataParallel at inference because it was trained that way. The 'module.'
# prefix is already stripped above, so the wrapper may be unnecessary;
# confirm before removing it.
model = nn.DataParallel(model)
model.eval()
input = input.to(device)

print(input.size())

# No gradients are needed for inference; this avoids building the autograd graph.
with torch.no_grad():
    out = model(input)
out = out.cpu()

print(out.size())

# Convert the [0, 1] float output to 8-bit pixels. Round to the nearest level
# instead of truncating, and work on a copy so `out` itself is not modified.
out_img = out[0].detach().numpy()
out_img = np.rint(out_img * 255.0).clip(0, 255)
out_img = np.transpose(out_img,(1,2,0))
out_img = out_img.astype(np.uint8)

# A (H, W, 3) uint8 array is interpreted as RGB without an explicit mode.
out_img = Image.fromarray(out_img)
out_img.save('kaimedia_sample_result_3220.png')

print('output image saved to ', 'kaimedia_sample_result_3220.png')

input image size 960x540
torch.Size([1, 3, 540, 960])
torch.Size([1, 3, 1080, 1920])
output image saved to  kaimedia_sample_result_3220.png
