# SRCNN
- 참고논문: [Image Super-Resolution Using Deep Convolutional Networks](https://arxiv.org/pdf/1501.00092)

In [None]:
# Mount Google Drive when running on Colab. On a local PC the google.colab
# package does not exist, so skip mounting instead of crashing the notebook.
try:
    from google.colab import drive
except ImportError:
    drive = None
else:
    drive.mount('/content/drive')

Mounted at /content/drive


## 1. 필요 라이브러리 불러오기

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from torchvision import transforms
from torchvision.utils import save_image

import time
import PIL.Image as pil_image
import numpy as np
import os

## 3. 하이퍼 파라미터 정의

In [None]:
# Project root on Google Drive; change to root = './' when running on a local PC.
root = '/content/drive/MyDrive/ai_intro/SRCNN'

ckpt_dir = os.path.join(root, 'checkpoint')  # folder holding saved model checkpoints
test_db_dir = os.path.join(root, 'DB/test')  # folder holding the test images
img_dir = os.path.join(root, 'results')      # folder where output images are written

num_workers = 0  # intended DataLoader worker count (0 = load in the main process)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 4. 파라미터 저장/불러오기 함수

In [None]:
def load(filename, net, optim=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        filename: Path of the checkpoint to read. The file must hold a dict
            with a ``'net'`` entry, plus an ``'optim'`` entry when ``optim``
            is given.
        net: Module whose ``load_state_dict`` receives the ``'net'`` entry.
        optim: Optional optimizer whose ``load_state_dict`` receives the
            ``'optim'`` entry. Skipped when ``None``.

    Returns:
        The ``(net, optim)`` pair that was passed in, updated in place.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Tensors are remapped onto the global `device`.
    checkpoint = torch.load(filename, map_location=device)

    net.load_state_dict(checkpoint['net'])

    if optim is None:
        return net, optim

    optim.load_state_dict(checkpoint['optim'])
    return net, optim

## 5. 네트워크 및 손실함수 정의

In [None]:
class SRCNN(nn.Module):
    """Three-layer convolutional network for image super-resolution.

    The layers use 9x9, 5x5 and 5x5 kernels with 64, 32 and ``num_channels``
    output maps. Each convolution pads by half its kernel size, so the output
    has the same height and width as the input.

    Args:
        num_channels: Number of channels of both the input and the output.
    """

    def __init__(self, num_channels=3):
        super().__init__()

        # Attribute names and creation order are kept as conv1/conv2/conv3 so
        # that state_dict keys match previously saved checkpoints.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 and return the result."""
        features = self.relu(self.conv1(x))
        mapped = self.relu(self.conv2(features))
        # No activation after the last layer.
        return self.conv3(mapped)

# Build the network with its default channel count, then move it to `device`.
model = SRCNN()
model = model.to(device)

## 7. Dataset, data loader 선언

In [None]:
class SRDataset(Dataset):
    """Dataset of low-/high-resolution image pairs for super-resolution.

    ``image_dir`` must contain an ``lr/`` folder (inputs) and, unless
    ``db_type == 'test'``, an ``hr/`` folder (labels). Files from the two
    folders are paired by their position in sorted filename order.

    Args:
        image_dir: Root folder holding the ``lr/`` (and ``hr/``) sub-folders.
        input_transform: Optional callable applied to each LR image.
        target_transform: Optional callable applied to each HR image.
        db_type: ``'train'`` keeps the first 90% of the sorted files, ``'val'``
            keeps the rest, ``'test'`` keeps every LR file and loads no
            labels. Any other value keeps the full LR and HR lists.
    """

    def __init__(self,
                 image_dir,
                 input_transform=None,
                 target_transform=None,
                 db_type='train'):
        super().__init__()

        self.db_type = db_type
        self.input_transform = input_transform
        self.target_transform = target_transform
        self.crop_size = 33  # side length of the square patch cut in training

        # Low-resolution images (network input).
        self.lr_list = self._list_files(os.path.join(image_dir, "lr/"))
        if self.db_type != 'test':
            # High-resolution images (labels); the test split has none.
            self.hr_list = self._list_files(os.path.join(image_dir, "hr/"))

        train_len = round(len(self.lr_list) * 0.9)

        # Strings are compared with ==. `is` tests object identity, which only
        # works by accident of interning and warns on Python 3.8+.
        if self.db_type == 'train':
            self.lr_list = self.lr_list[:train_len]
            self.hr_list = self.hr_list[:train_len]
        elif self.db_type == 'val':
            self.lr_list = self.lr_list[train_len:]
            self.hr_list = self.hr_list[train_len:]

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths of every entry found in ``directory``."""
        return sorted(os.path.join(directory, f) for f in os.listdir(directory))

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            The transformed LR image alone for ``'test'``; otherwise an
            ``(lr_image, hr_image)`` pair, randomly cropped to
            ``crop_size`` x ``crop_size`` for ``'train'``.
        """
        is_test = self.db_type == 'test'

        lr_image = pil_image.open(self.lr_list[idx])
        if self.input_transform is not None:
            lr_image = self.input_transform(lr_image)

        if is_test:
            return lr_image

        hr_image = pil_image.open(self.hr_list[idx])
        if self.target_transform is not None:
            hr_image = self.target_transform(hr_image)

        if self.db_type == 'train':
            # random_crop uses .size() and tensor slicing, so for training the
            # transforms are expected to produce (C, H, W) tensors.
            lr_image, hr_image = self.random_crop(lr_image, hr_image)

        return lr_image, hr_image

    def random_crop(self, input, target):
        """Cut the same random ``crop_size`` square out of both tensors.

        Args:
            input: 3-D tensor indexed as ``[:, rows, cols]``.
            target: Tensor cropped with the same window as ``input``
                (assumed to have the same height and width — not checked).

        Returns:
            The ``(input, target)`` crops.

        Raises:
            ValueError: If ``input`` is smaller than ``crop_size`` on a side.
        """
        size = self.crop_size
        h = input.size(-2)
        w = input.size(-1)
        if h < size or w < size:
            raise ValueError(
                'image of size %dx%d is smaller than crop size %d' % (h, w, size))

        # randint's upper bound is exclusive, so +1 makes the last valid
        # offset reachable and lets an image of exactly crop_size through.
        top = torch.randint(0, h - size + 1, (1,)).item()
        left = torch.randint(0, w - size + 1, (1,)).item()

        input = input[:, top:top + size, left:left + size]
        target = target[:, top:top + size, left:left + size]

        return input, target

    def __len__(self):
        """Return the number of LR images in the selected split."""
        return len(self.lr_list)


# Both transforms only convert the loaded PIL image to a tensor.
input_transform = transforms.Compose([transforms.ToTensor()])
target_transform = transforms.Compose([transforms.ToTensor()])

# Test split: only the LR inputs are loaded, so target_transform is not applied.
test_dataset = SRDataset(test_db_dir,
                         input_transform=input_transform,
                         target_transform=target_transform,
                         db_type='test')

# One image per batch, in file order, so result indices follow the sorted
# file list. num_workers comes from the hyperparameter section, where it was
# defined but previously never passed on.
test_dataloader = DataLoader(dataset=test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=num_workers)

## 8. Test (inference)

In [None]:
# Restore the weights of the best checkpoint into the model.
model, _ = load(os.path.join(ckpt_dir, 'model_epoch_best.pth'), net=model)

model.eval()
psnr_arr = []  # kept for compatibility; this cell does not fill it

# exist_ok avoids the check-then-create race, and makedirs also creates any
# missing parent folders.
os.makedirs(img_dir, exist_ok=True)

# Inference only: no gradients are needed for any batch.
with torch.no_grad():
    for n, lr_batch in enumerate(test_dataloader):
        lr_batch = lr_batch.to(device)

        # Clamp to [0, 1] before writing the image file.
        preds = model(lr_batch).clamp(0.0, 1.0)

        # Save the output image.
        save_image(preds, os.path.join(img_dir, 'result%d.png' % (n)))

