# SRCNN
- 참고논문: [Image Super-Resolution Using Deep Convolutional Networks](https://arxiv.org/pdf/1501.00092)

## 1. 필요 라이브러리 불러오기

In [112]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

import time
import PIL.Image as pil_image
import numpy as np
import os

## 3. 하이퍼 파라미터 정의

In [113]:
# Filesystem locations used by the notebook.
ckpt_dir = './checkpoint'       # checkpoints (model + optimizer state) are written here
train_db_dir = './DB/train'     # root directory of the training image database
log_dir = './log'               # directory for TensorBoard logs

# Optimisation settings.
lr = 1e-4
batch_size = 32
num_epochs = 50
num_workers = 0                 # forwarded to the training DataLoader

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 4. 파라미터 저장/불러오기 함수

In [114]:
def save(ckpt_dir, net, optim, postfix):
    """Write the network and optimizer state to a checkpoint file.

    The file is stored as ``<ckpt_dir>/model_epoch<postfix>.pth`` and holds a
    dict with two entries: ``'net'`` (``net.state_dict()``) and ``'optim'``
    (``optim.state_dict()``).

    Args:
        ckpt_dir: Directory that receives the checkpoint; created if missing.
            Both relative and absolute paths are supported.
        net: Object exposing ``state_dict()`` (the network).
        optim: Object exposing ``state_dict()`` (the optimizer).
        postfix: Value appended to ``model_epoch`` in the file name, e.g. the
            epoch number; converted with ``str()``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(ckpt_dir, exist_ok=True)

    # Join instead of prefixing './' so absolute directories are not turned
    # into paths relative to the current working directory.
    ckpt_path = os.path.join(ckpt_dir, 'model_epoch%s.pth' % str(postfix))

    torch.save({'net': net.state_dict(),        # network parameters/buffers
                'optim': optim.state_dict()},   # optimizer state
               ckpt_path)

def load(filename, net, optim, map_location='cpu'):
    """Restore network and optimizer state from a checkpoint file.

    The file is expected to contain a dict with the keys ``'net'`` and
    ``'optim'``, each holding a state dict.

    Args:
        filename: Path of the checkpoint file to read.
        net: Network whose ``load_state_dict`` receives the ``'net'`` entry.
        optim: Optimizer whose ``load_state_dict`` receives the ``'optim'``
            entry.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved from a CUDA device can also be opened on a
            machine without CUDA; ``load_state_dict`` then copies the values
            into the existing parameters.

    Returns:
        The same ``(net, optim)`` objects, updated in place.
    """
    dict_model = torch.load(filename, map_location=map_location)

    net.load_state_dict(dict_model['net'])
    optim.load_state_dict(dict_model['optim'])

    return net, optim

## 5. 네트워크 및 손실함수 정의

In [115]:
class SRCNN(nn.Module):
    """Three-layer convolutional network for image super-resolution.

    The layers are 9x9 (``num_channels`` -> 64), 5x5 (64 -> 32) and
    5x5 (32 -> ``num_channels``) convolutions. Every convolution pads by
    ``kernel_size // 2``, so the output has the same height and width as the
    input. ReLU follows the first two convolutions only.

    Args:
        num_channels: Number of channels of both the input and the output.
    """

    def __init__(self, num_channels=3):
        super().__init__()

        # Padding values are kernel_size // 2 (9 -> 4, 5 -> 2).
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the three convolutions on ``x`` and return the result."""
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        # No activation after the last layer.
        return self.conv3(hidden)

# Instantiate the network and move its parameters to the selected device.
model = SRCNN()
model = model.to(device)

# Mean-squared-error loss, minimised with Adam at learning rate `lr`.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

## 6. 정확도 측정 함수 (PSNR) 선언
PSNR: [wiki](https://ko.wikipedia.org/wiki/%EC%B5%9C%EB%8C%80_%EC%8B%A0%ED%98%B8_%EB%8C%80_%EC%9E%A1%EC%9D%8C%EB%B9%84)

$\textrm{PSNR}=10\cdot\log_{10}{\left(\frac{\textrm{MAX}^2}{\textrm{MSE}}\right)}$

$\textrm{MAX}$: tensor가 가질 수 있는 최대값

$\textrm{MSE}$: Mean square error

In [116]:
def calculate_psnr(image1, image2, max_val=1.):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    Computes ``10 * log10(max_val ** 2 / MSE)`` where MSE is the mean squared
    difference over all elements of the two tensors.

    Args:
        image1: First tensor.
        image2: Second tensor, broadcastable against ``image1``.
        max_val: Largest value the data can take. Defaults to ``1.``; pass
            ``255`` for 8-bit data.

    Returns:
        The PSNR as a Python float. Identical inputs give ``inf`` because the
        MSE is zero.
    """
    # Integer tensors (e.g. uint8) wrap around on subtraction and are rejected
    # by torch.mean, so promote them to floating point first.
    if not torch.is_floating_point(image1):
        image1 = image1.float()
    if not torch.is_floating_point(image2):
        image2 = image2.float()

    mse = torch.mean((image1 - image2) ** 2)
    psnr = 10. * torch.log10((max_val ** 2) / mse)
    return psnr.item()

## 7. Dataset, data loader 선언

In [117]:
class SRDataset(Dataset):
    """Paired low-/high-resolution image dataset.

    Image paths are read from the ``lr/`` (inputs) and ``hr/`` (labels)
    sub-directories of ``image_dir`` and sorted, so the i-th entry of each
    list forms a pair. The first 90% of the sorted files (rounded, based on
    the ``hr/`` count) make up the ``'train'`` split and the rest the
    ``'val'`` split; any other ``db_type`` keeps every file.

    Args:
        image_dir: Directory containing the ``lr/`` and ``hr/`` folders.
        input_transform: Optional callable applied to each opened lr image.
        target_transform: Optional callable applied to each opened hr image.
        db_type: ``'train'`` or ``'val'``. In ``'train'`` mode every sample is
            additionally cut to a random ``crop_size`` x ``crop_size`` patch.
    """

    def __init__(self,
                 image_dir,
                 input_transform=None,
                 target_transform=None,
                 db_type='train'):
        super().__init__()

        lr_dir = os.path.join(image_dir, "lr/")     # low-resolution image (input)
        hr_dir = os.path.join(image_dir, "hr/")     # high-resolution image (label)

        # Sorting both lists aligns inputs and labels by file name order.
        self.lr_list = sorted(
            os.path.join(lr_dir, f) for f in os.listdir(lr_dir))
        self.hr_list = sorted(
            os.path.join(hr_dir, f) for f in os.listdir(hr_dir))

        train_len = round(len(self.hr_list) * 0.9)

        self.input_transform = input_transform
        self.target_transform = target_transform

        self.db_type = db_type

        # Compare by value: `is` tests object identity and is not guaranteed
        # to be true for equal strings.
        if self.db_type == 'train':
            self.lr_list = self.lr_list[:train_len]
            self.hr_list = self.hr_list[:train_len]
        elif self.db_type == 'val':
            self.lr_list = self.lr_list[train_len:]
            self.hr_list = self.hr_list[train_len:]

        self.crop_size = 33     # side length of the random training patch

    def __getitem__(self, idx):
        """Return the ``(lr_image, hr_image)`` pair at position ``idx``."""
        lr_image = pil_image.open(self.lr_list[idx])
        hr_image = pil_image.open(self.hr_list[idx])

        # Transform images
        if self.input_transform is not None:
            lr_image = self.input_transform(lr_image)

        if self.target_transform is not None:
            hr_image = self.target_transform(hr_image)

        if self.db_type == 'train':
            # Random crop (same window for input and label)
            lr_image, hr_image = self.random_crop(lr_image, hr_image)

        return lr_image, hr_image

    def random_crop(self, input, target):
        """Cut the same random square window out of ``input`` and ``target``.

        Both arguments are indexed as ``[:, rows, cols]`` and the window
        position is derived from the last two dimensions of ``input``.

        Returns:
            Tuple ``(input_patch, target_patch)`` whose last two dimensions
            are ``crop_size`` x ``crop_size``.
        """
        h = input.size(-2)
        w = input.size(-1)

        # randint's upper bound is exclusive, hence the +1: offsets range over
        # 0 .. size - crop_size inclusive, and an image of exactly crop_size
        # is accepted.
        top = torch.randint(0, h - self.crop_size + 1, (1,)).item()
        left = torch.randint(0, w - self.crop_size + 1, (1,)).item()

        rows = slice(top, top + self.crop_size)
        cols = slice(left, left + self.crop_size)

        return input[:, rows, cols], target[:, rows, cols]

    def __len__(self):
        """Return the number of image pairs in the selected split."""
        return len(self.hr_list)


# Inputs and labels each go through a pipeline consisting only of ToTensor.
input_transform = transforms.Compose([transforms.ToTensor()])
target_transform = transforms.Compose([transforms.ToTensor()])

# Training split: shuffled batches; the final incomplete batch is dropped.
train_dataset = SRDataset(
    train_db_dir,
    input_transform=input_transform,
    target_transform=target_transform,
    db_type='train',
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

# Validation split: read from the same directory, one image per batch, in order.
val_dataset = SRDataset(
    train_db_dir,
    input_transform=input_transform,
    target_transform=target_transform,
    db_type='val',
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
)

## 8. Training & validation

In [118]:
# Best validation result seen so far; best_epoch is 1-based so it matches the
# number used in the per-epoch checkpoint file names.
best_epoch = 0
best_psnr = 0.0

# Number of full batches per epoch (used only for progress output).
num_train = len(train_dataset) // batch_size

writer = SummaryWriter(log_dir=log_dir)

start_t = time.time()
try:
    for epoch in range(num_epochs):
        # ---- Training ----
        model.train()

        loss_arr = []

        for batch_idx, (inputs, labels) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)     # net.forward(input)

            loss = criterion(preds, labels)

            optimizer.zero_grad()   # reset accumulated gradients
            loss.backward()         # compute gradients
            optimizer.step()        # update parameters using the gradients

            loss_arr.append(loss.item())

            if batch_idx % 10 == 0:
                # Time elapsed since the previous progress line.
                elapsed_time = time.time() - start_t
                print('TRAIN(Elapsed: %fs): EPOCH %d/%d | BATCH %d/%d | LOSS: %f' %
                      (elapsed_time, epoch + 1, num_epochs,
                       batch_idx + 1, num_train, np.mean(loss_arr)))
                start_t = time.time()

        save(ckpt_dir=ckpt_dir, net=model, optim=optimizer, postfix='{}'.format(epoch + 1))

        # ---- Validation ----
        model.eval()
        psnr_arr = []

        with torch.no_grad():
            for val_input, val_label in val_dataloader:
                val_input = val_input.to(device)
                val_label = val_label.to(device)

                # Clamp to [0, 1] before measuring PSNR.
                preds = model(val_input).clamp(0.0, 1.0)

                psnr_arr.append(calculate_psnr(preds, val_label))

        epoch_psnr = np.mean(psnr_arr)
        print('PSNR(epoch: %d): %f' % (epoch + 1, epoch_psnr))

        if epoch_psnr > best_psnr:
            best_psnr = epoch_psnr
            best_epoch = epoch + 1
            save(ckpt_dir=ckpt_dir, net=model, optim=optimizer, postfix='_best')

        # Log on tensorboard
        writer.add_scalar('loss', np.mean(loss_arr), epoch)
        writer.add_scalar('psnr', epoch_psnr, epoch)
finally:
    # Close the writer even when training is interrupted.
    writer.close()

# The '_best' checkpoint is written only when validation PSNR improves. It is
# deliberately NOT re-saved here, since that would overwrite the best weights
# with those of the last epoch (which already has its own per-epoch file).
print('best epoch: {}, psnr: {}'.format(best_epoch, best_psnr))
    

TRAIN: EPOCH 0/100 | BATCH 1/346 | LOSS: 0.265382
TRAIN: EPOCH 0/100 | BATCH 11/346 | LOSS: 0.181070
TRAIN: EPOCH 0/100 | BATCH 21/346 | LOSS: 0.119643
TRAIN: EPOCH 0/100 | BATCH 31/346 | LOSS: 0.091432
TRAIN: EPOCH 0/100 | BATCH 41/346 | LOSS: 0.075563
TRAIN: EPOCH 0/100 | BATCH 51/346 | LOSS: 0.064887
TRAIN: EPOCH 0/100 | BATCH 61/346 | LOSS: 0.057650
TRAIN: EPOCH 0/100 | BATCH 71/346 | LOSS: 0.051923
TRAIN: EPOCH 0/100 | BATCH 81/346 | LOSS: 0.047460
TRAIN: EPOCH 0/100 | BATCH 91/346 | LOSS: 0.043726
TRAIN: EPOCH 0/100 | BATCH 101/346 | LOSS: 0.040757
TRAIN: EPOCH 0/100 | BATCH 111/346 | LOSS: 0.038043
TRAIN: EPOCH 0/100 | BATCH 121/346 | LOSS: 0.035734
TRAIN: EPOCH 0/100 | BATCH 131/346 | LOSS: 0.033609
TRAIN: EPOCH 0/100 | BATCH 141/346 | LOSS: 0.031787
TRAIN: EPOCH 0/100 | BATCH 151/346 | LOSS: 0.030169
TRAIN: EPOCH 0/100 | BATCH 161/346 | LOSS: 0.028690
TRAIN: EPOCH 0/100 | BATCH 171/346 | LOSS: 0.027427
TRAIN: EPOCH 0/100 | BATCH 181/346 | LOSS: 0.026259
TRAIN: EPOCH 0/100 | BA

KeyboardInterrupt: 