# Libraries

In [None]:
import os
import easydict
import time

import numpy as np
import cv2
import math

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim

from torch.utils import data as D
from torch.utils.data import DataLoader

from tqdm import tqdm
import glob

# Mount Google Drive in the Colab environment
The drive will be mounted at /content/gdrive.

In [None]:
# Mount Google Drive at /content/gdrive; the dataset and checkpoints used
# below are read from and written under this mount point.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


# Check the mounted Google Drive and its contents
Make sure your data is saved inside Google Drive.

In [None]:
# Show the current working directory.
!pwd  # same as the Linux "print working directory" command
#!ls -l
#!rm -rf /content/div2k_100/

# Data files live under /content/gdrive/MyDrive/srcnn

/content


In [None]:
# Quietly (-q) extract the dataset archive into div2k_100-2 on Drive.
!unzip -q "/content/gdrive/My Drive/srcnn/div2k_100.zip" -d "/content/gdrive/My Drive/srcnn/div2k_100-2"

In [None]:
# Confirm the extracted layout (checkpoint_dir, test_images, train/valid patch folders).
!ls -l "/content/gdrive/My Drive/srcnn/div2k_100-2"

total 16
drwx------ 2 root root 4096 Dec 23 09:13 checkpoint_dir
drwx------ 4 root root 4096 Dec 23 09:19 test_images
drwx------ 4 root root 4096 Dec 23 09:21 train_patches_x4lr64
drwx------ 4 root root 4096 Dec 23 09:19 valid_patches_x4lr64


In [None]:
# Report the Python and PyTorch versions and whether a CUDA device is visible.
!python --version
print(torch.__version__)
print(torch.cuda.is_available())

Python 3.7.12
1.10.0+cu111
True


# Set your data directory into div2k_dir

In [None]:
# Root of the extracted dataset; the options cell below derives its data paths from it.
div2k_dir = "/content/gdrive/My Drive/srcnn/div2k_100-2"

# Options that control the overall program behavior

In [None]:
# Options object shared by every cell below. EasyDict gives attribute access
# (opt.lr) like an argparse Namespace, without parsing a command line.
opt = easydict.EasyDict({
    # run control
    "resume": True,       # restore weights from checkpoint_dir before training
    "resume_best": True,  # on resume, take the lowest-loss checkpoint instead of the latest
    "use_npy": True,      # read lr.npy / hr.npy instead of lr/ and hr/ image folders

    # hardware
    "multi_gpu": True,    # wrap the model in nn.DataParallel when several GPUs are visible
    "use_cuda": True,
    "device": 'cuda',

    # optimisation
    "n_epochs": 100,      # upper bound of the epoch loop in run_train
    "batch_size": 100,    # samples per mini-batch
    "start_epoch": 1,
    "lr": 1e-4,           # Adam learning rate
    "b1": 0.9,            # Adam beta1 (decay rate of the first-moment estimate)
    "b2": 0.999,          # Adam beta2 (decay rate of the second-moment estimate)

    # paths, filled in below from div2k_dir
    "checkpoint_dir": None,  # where model checkpoints and the training log are saved
    "data_dir": None,
    "train_dir": None,
    "valid_dir": None,
    "test_dir": None,
    "test_result_dir": None,

    # geometry
    "lr_img": 64,     # side length of a low-resolution input patch
    "res_scale": 4,   # upscaling factor, so HR patches are 64 * 4 = 256
    "n_channels": 3,
})

opt.data_dir = div2k_dir
opt.checkpoint_dir = os.path.join(opt.data_dir, "checkpoint_dir")
opt.train_dir = os.path.join(opt.data_dir, "train_patches_x{}lr{}".format(opt.res_scale, opt.lr_img))
opt.valid_dir = os.path.join(opt.data_dir, "valid_patches_x{}lr{}".format(opt.res_scale, opt.lr_img))
opt.test_dir = os.path.join(opt.data_dir, "test_images")
opt.test_result_dir = os.path.join(opt.data_dir, "test_result")

print(opt)

{'resume': True, 'resume_best': True, 'use_npy': True, 'multi_gpu': True, 'use_cuda': True, 'device': 'cuda', 'n_epochs': 100, 'batch_size': 100, 'start_epoch': 1, 'lr': 0.0001, 'b1': 0.9, 'b2': 0.999, 'checkpoint_dir': '/content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir', 'data_dir': '/content/gdrive/My Drive/srcnn/div2k_100-2', 'train_dir': '/content/gdrive/My Drive/srcnn/div2k_100-2/train_patches_x4lr64', 'valid_dir': '/content/gdrive/My Drive/srcnn/div2k_100-2/valid_patches_x4lr64', 'test_dir': '/content/gdrive/My Drive/srcnn/div2k_100-2/test_images', 'test_result_dir': '/content/gdrive/My Drive/srcnn/div2k_100-2/test_result', 'lr_img': 64, 'res_scale': 4, 'n_channels': 3}


# Normalization of image
This function rescales pixel values from the range 0 ~ 255 to 0 ~ 1.0 and converts the result to float32.

In [None]:
def normalize_img(img):
    """Return ``img`` scaled from 0-255 pixel values to 0-1, as float32.

    The division happens first (at whatever precision numpy picks for
    ``img / 255.``) and only the result is narrowed to float32.
    """
    scaled = np.divide(img, 255.)
    return scaled.astype(np.float32)

# Define your dataset
This dataset class can be loaded directly with PyTorch's `DataLoader`.

In [None]:
class DatasetFromFolder(data.Dataset):
    """Paired low-/high-resolution patch dataset for SRCNN training.

    Each item is ``(input, target, file_name)``:
    ``input`` is the LR patch bicubic-upscaled to the HR patch size, and
    ``target`` is the HR patch, both float32 tensors in channel-first layout
    with values scaled by 1/255. ``file_name`` is the name of the ``idx``-th
    file (in sorted order) of the ``hr/`` folder.

    Args:
        data_dir: directory holding either ``lr.npy`` and ``hr.npy``
            (``use_npy=True``) or ``lr/`` and ``hr/`` image folders. An ``hr/``
            folder is required in both modes because file names come from it.
        use_npy: read pre-packed numpy arrays instead of individual images.
    """

    def __init__(self, data_dir, use_npy):
        super(DatasetFromFolder, self).__init__()

        self.use_npy = use_npy
        hr_folder = os.path.join(data_dir, 'hr')
        self.dsets = {}

        if self.use_npy:
            self.dsets['lr'] = np.load(os.path.join(data_dir, 'lr.npy'))
            self.dsets['hr'] = np.load(os.path.join(data_dir, 'hr.npy'))
        else:
            lr_folder = os.path.join(data_dir, 'lr')
            self.dsets['lr'] = [os.path.join(lr_folder, x) for x in sorted(os.listdir(lr_folder))]
            self.dsets['hr'] = [os.path.join(hr_folder, x) for x in sorted(os.listdir(hr_folder))]

        # os.listdir order is arbitrary, so sort to keep file_name[idx] aligned
        # with the sorted hr list. In npy mode this presumes hr.npy rows were
        # written in sorted-name order too -- TODO confirm with the packing script.
        self.dsets['file_name'] = sorted(os.listdir(hr_folder))

    def __getitem__(self, idx):
        if self.use_npy:
            lr_img = self.dsets['lr'][idx]
            hr_img = self.dsets['hr'][idx]
        else:
            lr_path = self.dsets['lr'][idx]
            hr_path = self.dsets['hr'][idx]
            lr_img = cv2.imread(lr_path)
            hr_img = cv2.imread(hr_path)
            # cv2.imread reports failure by returning None instead of raising.
            if lr_img is None or hr_img is None:
                raise FileNotFoundError(
                    "could not read image pair: {}, {}".format(lr_path, hr_path))

        file_name = self.dsets['file_name'][idx]

        # Bicubic-upscale the LR patch (e.g. 64x64 -> 256x256) to the HR size
        # before it enters the network. cv2.resize takes (width, height).
        lr_img = cv2.resize(lr_img, (hr_img.shape[1], hr_img.shape[0]),
                            interpolation=cv2.INTER_CUBIC)

        # (height, width, channel) -> PyTorch convention (channel, height, width)
        lr_chw = np.transpose(normalize_img(lr_img), (2, 0, 1))
        hr_chw = np.transpose(normalize_img(hr_img), (2, 0, 1))

        lr_tensor = torch.from_numpy(lr_chw).type(torch.FloatTensor)
        hr_tensor = torch.from_numpy(hr_chw).type(torch.FloatTensor)

        return lr_tensor, hr_tensor, file_name

    def __len__(self):
        return len(self.dsets['lr'])

# Shift function
This layer shifts each RGB channel of an image from the range 0 ~ 1.0 to -mean ~ (1 - mean), which makes the network easier to train.

In [None]:
# mean set to zero
class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False

# Define your model
See the [paper](https://arxiv.org/abs/1501.00092) for details.

In [None]:
class SRCNN(nn.Module):
    """Three-layer SRCNN (arXiv:1501.00092) with 9-5-5 filter sizes.

    The input is expected to be already bicubic-upscaled to the target size;
    each convolution is padded so the spatial size never changes. The input is
    mean-shifted before the convolutions and shifted back at the end.

    Args:
        opt: options object; only ``opt.n_channels`` is read.
    """

    def __init__(self, opt):
        super(SRCNN, self).__init__()
        rgb_range = 1.0  # pixel values are scaled to [0, 1] by normalize_img

        # Attribute names and creation order are kept stable: load_model restores
        # checkpoints through load_state_dict, which matches on these names.
        self.sub_mean = MeanShift(rgb_range)
        self.add_mean = MeanShift(rgb_range, sign=1)

        channels = opt.n_channels
        self.conv1 = nn.Conv2d(channels, 64, kernel_size=9, padding=4)   # patch extraction
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)         # non-linear mapping
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, channels, kernel_size=5, padding=2)   # reconstruction

    def forward(self, x):
        features = self.relu1(self.conv1(self.sub_mean(x)))
        features = self.relu2(self.conv2(features))
        return self.add_mean(self.conv3(features))

# Prepare for unexpected system or program shutdowns
This function saves the training progress after each epoch, so you can later resume from the last saved model. Resuming also requires a function that loads the model, which is defined below.

In [None]:
def save_checkpoint(srcnn, epoch, loss):
    """Save ``srcnn`` under ``opt.checkpoint_dir`` as models_epoch_XXXX_loss_Y.pth.

    The file holds ``{"epoch": epoch, "srcnn": <nn.Module>}``, which is the whole
    pickled module rather than only its state_dict. load_model sorts these
    files by name and parses the loss back out of it, so keep the name format
    in sync with it.

    Args:
        srcnn: model to save. A DataParallel / DistributedDataParallel
            wrapper is unwrapped so the checkpoint always holds the bare model.
        epoch: epoch number, zero-padded to 4 digits in the file name.
        loss: loss value, written with 6 decimal places in the file name.
    """
    checkpoint_dir = os.path.abspath(opt.checkpoint_dir)
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, "models_epoch_%04d_loss_%.6f.pth" % (epoch, loss))

    # Decide from the object itself. The previous GPU-count + opt.multi_gpu test
    # raised AttributeError (or saved the wrapper) whenever it disagreed with
    # how the caller had actually wrapped the model.
    if isinstance(srcnn, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        srcnn = srcnn.module

    state = {"epoch": epoch, "srcnn": srcnn}
    torch.save(state, checkpoint_path)
    print("Checkpoint saved to {}".format(checkpoint_path))

In [None]:
def load_model(checkpoint_dir):
    checkpoint_list = glob.glob(os.path.join(checkpoint_dir, "*.pth"))
    checkpoint_list.sort()

    if opt.resume_best:
        loss_list = list(map(lambda x: float(os.path.basename(x).split('_')[4][:-4]), checkpoint_list))
        best_loss_idx = loss_list.index(min(loss_list))
        checkpoint_path = checkpoint_list[best_loss_idx]
    else:
        checkpoint_path = checkpoint_list[len(checkpoint_list) - 1]

    srcnn = SRCNN(opt)

    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        
        n_epoch = checkpoint['epoch']
        srcnn.load_state_dict(checkpoint['srcnn'].state_dict())
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(checkpoint_path, n_epoch))
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_path))
        n_epoch = 0

    return n_epoch + 1, srcnn

# Define your train step

In [None]:
def train(opt, model, optimizer, data_loader, loss_criterion):
    """Run one training epoch over ``data_loader``, updating ``model`` in place.

    Args:
        opt: options; reads use_cuda, device, epoch_num and n_epochs.
        model: network to train (plain or nn.DataParallel-wrapped).
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of batches whose first two items are
            (input, target); any further items (file names) are ignored.
        loss_criterion: loss module. PSNR is derived per batch as
            10*log10(1/loss), which is only meaningful for an MSE loss on
            images scaled to [0, 1].

    Returns:
        (mean batch loss, mean of per-batch PSNR).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    print("===> Training")
    start_time = time.time()
    model.train()  # training mode for any dropout / batch-norm layers

    total_loss = 0.0
    total_psnr = 0.0
    n_batches = 0

    for batch in tqdm(data_loader):
        x, target = batch[0], batch[1]  # batch = (input, target, file_name)
        if opt.use_cuda:
            x = x.to(opt.device)
            target = target.to(opt.device)

        loss = loss_criterion(model(x), target)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        if batch_loss > 0:
            total_psnr += 10 * math.log10(1. / batch_loss)
        else:
            # A perfect batch has unbounded PSNR; log10(1/0) would raise.
            total_psnr += float('inf')
        n_batches += 1

    # Fail with a clear message instead of an unbound-variable error.
    if n_batches == 0:
        raise ValueError("train(): data_loader yielded no batches")

    total_loss /= n_batches
    total_psnr /= n_batches

    print("***Training %.2fs => Epoch[%d/%d]: Loss: %.5f PSNR: %.5f" %
          (time.time() - start_time, opt.epoch_num, opt.n_epochs, total_loss, total_psnr))

    return (total_loss, total_psnr)

# Define your evaluation method for validation dataset

In [None]:
def evaluate(opt, model, data_loader, loss_criterion):
    """Score ``model`` on ``data_loader`` without updating its weights.

    The model is switched to eval mode for the pass and then returned to
    whichever train/eval mode it was in before the call.

    Args:
        opt: options; reads use_cuda, device, epoch_num and n_epochs.
        model: network to evaluate (plain or nn.DataParallel-wrapped).
        data_loader: iterable of batches whose first two items are
            (input, target).
        loss_criterion: loss module. PSNR is derived per batch as
            10*log10(1/loss), which assumes an MSE loss on [0, 1] images.

    Returns:
        (mean batch loss, mean of per-batch PSNR).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    print("===> Validation")
    start_time = time.time()

    total_loss = 0.0
    total_psnr = 0.0
    n_batches = 0

    was_training = model.training
    model.eval()  # inference behaviour for any dropout / batch-norm layers
    try:
        with torch.no_grad():  # no backpropagation needed, unlike training
            for batch in tqdm(data_loader):
                x, target = batch[0], batch[1]
                if opt.use_cuda:
                    x = x.to(opt.device)
                    target = target.to(opt.device)

                batch_loss = loss_criterion(model(x), target).item()
                total_loss += batch_loss
                if batch_loss > 0:
                    total_psnr += 10 * math.log10(1. / batch_loss)
                else:
                    # A perfect batch has unbounded PSNR; log10(1/0) would raise.
                    total_psnr += float('inf')
                n_batches += 1
    finally:
        model.train(was_training)

    if n_batches == 0:
        raise ValueError("evaluate(): data_loader yielded no batches")

    total_loss /= n_batches
    total_psnr /= n_batches

    print("***Validation %.2fs => Epoch[%d/%d]: Loss: %.5f PSNR: %.5f" %
          (time.time() - start_time, opt.epoch_num, opt.n_epochs, total_loss, total_psnr))

    return (total_loss, total_psnr)

# Define whole training process

In [None]:
def run_train(opt, training_data_loader, validation_data_loader):
    """Train SRCNN, validating and checkpointing after every epoch.

    Runs epochs ``opt.start_epoch`` through ``opt.n_epochs`` inclusive. With
    ``opt.resume`` the weights and the start epoch come from load_model. After
    each epoch one row (epoch, train_loss, train_psnr, valid_loss, valid_psnr)
    is appended to ``<checkpoint_dir>/srcnn_log.csv``. A checkpoint named
    after the validation loss is then saved via save_checkpoint.

    Mutates ``opt``: start_epoch (on resume), use_cuda, device and epoch_num.
    """
    os.makedirs(opt.checkpoint_dir, exist_ok=True)
    log_file = os.path.join(opt.checkpoint_dir, "srcnn_log.csv")

    print('[Initialize networks for training]')
    srcnn = SRCNN(opt)
    l2_criterion = nn.MSELoss()
    print(srcnn)

    # optionally resume from a checkpoint
    if opt.resume:
        opt.start_epoch, srcnn = load_model(opt.checkpoint_dir)

    # The header lists the same five columns written each epoch below. A resumed
    # run keeps appending to its existing log, but a brand-new log file still
    # gets a header.
    if not opt.resume or not os.path.exists(log_file):
        with open(log_file, mode='w') as f:
            f.write("epoch,train_loss,train_psnr,valid_loss,valid_psnr\n")

    print("===> Setting GPU")
    print("CUDA Available: ", torch.cuda.is_available())
    if opt.use_cuda and torch.cuda.is_available():
        opt.use_cuda = True
        opt.device = 'cuda'
    else:
        opt.use_cuda = False
        opt.device = 'cpu'

    if torch.cuda.device_count() > 1 and opt.multi_gpu:
        print("Use " + str(torch.cuda.device_count()) + " GPUs")
        srcnn = nn.DataParallel(srcnn)

    if opt.use_cuda:
        srcnn = srcnn.to(opt.device)
        l2_criterion = l2_criterion.to(opt.device)

    print("===> Setting Optimizer")
    # NOTE(review): optimizer state is not checkpointed, so Adam's moment
    # estimates start from scratch whenever training is resumed.
    optimizer = torch.optim.Adam(srcnn.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # n_epochs is inclusive; range(start, n_epochs) stopped one epoch early.
    for epoch in range(opt.start_epoch, opt.n_epochs + 1):
        opt.epoch_num = epoch
        train_loss, train_psnr = train(opt, srcnn, optimizer, training_data_loader, loss_criterion=l2_criterion)
        valid_loss, valid_psnr = evaluate(opt, srcnn, validation_data_loader, loss_criterion=l2_criterion)

        with open(log_file, mode='a') as f:
            f.write("%d,%08f,%08f,%08f,%08f\n" % (
                epoch,
                train_loss,
                train_psnr,
                valid_loss,
                valid_psnr
            ))
        save_checkpoint(srcnn, epoch, valid_loss)

# Now it's time to train the SRCNN model

In [None]:
# Build the training / validation loaders from the patch folders in opt, then train.
train_dir, valid_dir = opt.train_dir, opt.valid_dir
print("train_dir is: {}".format(train_dir))
print("valid_dir is: {}".format(valid_dir))

# Full-resolution patch size (height, width); not passed to run_train below.
target_size = (opt.lr_img * opt.res_scale,) * 2

train_dataset = DatasetFromFolder(train_dir, opt.use_npy)
valid_dataset = DatasetFromFolder(valid_dir, opt.use_npy)

# Only the training set is shuffled; validation order stays fixed across epochs.
training_data_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)
validation_data_loader = DataLoader(valid_dataset, batch_size=opt.batch_size, shuffle=False)

run_train(opt, training_data_loader, validation_data_loader)

train_dir is: /content/gdrive/My Drive/srcnn/div2k_100-2/train_patches_x4lr64
valid_dir is: /content/gdrive/My Drive/srcnn/div2k_100-2/valid_patches_x4lr64
[Initialize networks for training]
SRCNN(
  (sub_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (add_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (relu1): ReLU()
  (conv2): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu2): ReLU()
  (conv3): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
)
=> loading checkpoint '/content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0077_loss_0.002023.pth'
=> loaded checkpoint '/content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0077_loss_0.002023.pth' (epoch 77)
===> Setting GPU
CUDA Available:  True
===> Setting Optimizer
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.87s/it]


***Training 103.44s => Epoch[78/100]: Loss: 0.00323 PSNR: 25.02582
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]


***Validation 8.11s => Epoch[78/100]: Loss: 0.00206 PSNR: 27.57876
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0078_loss_0.002064.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.86s/it]


***Training 103.05s => Epoch[79/100]: Loss: 0.00284 PSNR: 25.51125
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]


***Validation 8.08s => Epoch[79/100]: Loss: 0.00202 PSNR: 27.68944
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0079_loss_0.002023.pth
===> Training


100%|██████████| 36/36 [01:42<00:00,  2.85s/it]


***Training 102.77s => Epoch[80/100]: Loss: 0.00278 PSNR: 25.63033
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.36s/it]


***Validation 8.16s => Epoch[80/100]: Loss: 0.00202 PSNR: 27.69527
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0080_loss_0.002021.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.87s/it]


***Training 103.35s => Epoch[81/100]: Loss: 0.00283 PSNR: 25.52735
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]


***Validation 8.12s => Epoch[81/100]: Loss: 0.00202 PSNR: 27.69967
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0081_loss_0.002019.pth
===> Training


100%|██████████| 36/36 [01:42<00:00,  2.85s/it]


***Training 102.72s => Epoch[82/100]: Loss: 0.00280 PSNR: 25.56370
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]


***Validation 8.11s => Epoch[82/100]: Loss: 0.00202 PSNR: 27.69993
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0082_loss_0.002019.pth
===> Training


100%|██████████| 36/36 [01:42<00:00,  2.85s/it]


***Training 102.76s => Epoch[83/100]: Loss: 0.00280 PSNR: 25.56104
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]


***Validation 8.13s => Epoch[83/100]: Loss: 0.00202 PSNR: 27.70247
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0083_loss_0.002019.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.87s/it]


***Training 103.50s => Epoch[84/100]: Loss: 0.00279 PSNR: 25.57473
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.38s/it]


***Validation 8.27s => Epoch[84/100]: Loss: 0.00202 PSNR: 27.70629
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0084_loss_0.002017.pth
===> Training


100%|██████████| 36/36 [01:44<00:00,  2.90s/it]


***Training 104.28s => Epoch[85/100]: Loss: 0.00279 PSNR: 25.59021
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.37s/it]


***Validation 8.25s => Epoch[85/100]: Loss: 0.00202 PSNR: 27.70865
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0085_loss_0.002016.pth
===> Training


100%|██████████| 36/36 [01:44<00:00,  2.90s/it]


***Training 104.32s => Epoch[86/100]: Loss: 0.00283 PSNR: 25.53359
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.37s/it]


***Validation 8.23s => Epoch[86/100]: Loss: 0.00202 PSNR: 27.70819
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0086_loss_0.002016.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.88s/it]


***Training 103.65s => Epoch[87/100]: Loss: 0.00277 PSNR: 25.64756
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]


***Validation 8.09s => Epoch[87/100]: Loss: 0.00201 PSNR: 27.71298
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0087_loss_0.002014.pth
===> Training


100%|██████████| 36/36 [01:42<00:00,  2.86s/it]


***Training 102.87s => Epoch[88/100]: Loss: 0.00282 PSNR: 25.53806
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]


***Validation 8.11s => Epoch[88/100]: Loss: 0.00201 PSNR: 27.71248
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0088_loss_0.002014.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.86s/it]


***Training 103.06s => Epoch[89/100]: Loss: 0.00279 PSNR: 25.56840
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.36s/it]


***Validation 8.16s => Epoch[89/100]: Loss: 0.00202 PSNR: 27.70148
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0089_loss_0.002018.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.87s/it]


***Training 103.19s => Epoch[90/100]: Loss: 0.00276 PSNR: 25.69850
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.36s/it]


***Validation 8.15s => Epoch[90/100]: Loss: 0.00201 PSNR: 27.71644
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0090_loss_0.002013.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.86s/it]


***Training 103.08s => Epoch[91/100]: Loss: 0.00278 PSNR: 25.60505
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.36s/it]


***Validation 8.17s => Epoch[91/100]: Loss: 0.00201 PSNR: 27.70838
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0091_loss_0.002015.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.86s/it]


***Training 103.14s => Epoch[92/100]: Loss: 0.00284 PSNR: 25.52320
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]


***Validation 8.14s => Epoch[92/100]: Loss: 0.00202 PSNR: 27.70385
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0092_loss_0.002015.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.87s/it]


***Training 103.35s => Epoch[93/100]: Loss: 0.00282 PSNR: 25.55185
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.36s/it]


***Validation 8.17s => Epoch[93/100]: Loss: 0.00201 PSNR: 27.71382
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0093_loss_0.002013.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.86s/it]


***Training 103.10s => Epoch[94/100]: Loss: 0.00281 PSNR: 25.54979
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.36s/it]


***Validation 8.15s => Epoch[94/100]: Loss: 0.00203 PSNR: 27.66942
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0094_loss_0.002029.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.86s/it]


***Training 103.15s => Epoch[95/100]: Loss: 0.00281 PSNR: 25.53649
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.37s/it]


***Validation 8.24s => Epoch[95/100]: Loss: 0.00201 PSNR: 27.72298
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0095_loss_0.002010.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.87s/it]


***Training 103.26s => Epoch[96/100]: Loss: 0.00282 PSNR: 25.54259
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.36s/it]


***Validation 8.19s => Epoch[96/100]: Loss: 0.00201 PSNR: 27.73420
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0096_loss_0.002006.pth
===> Training


100%|██████████| 36/36 [01:42<00:00,  2.86s/it]


***Training 102.94s => Epoch[97/100]: Loss: 0.00275 PSNR: 25.70132
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.37s/it]


***Validation 8.21s => Epoch[97/100]: Loss: 0.00203 PSNR: 27.67043
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0097_loss_0.002029.pth
===> Training


100%|██████████| 36/36 [01:43<00:00,  2.86s/it]


***Training 103.15s => Epoch[98/100]: Loss: 0.00279 PSNR: 25.58925
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.36s/it]


***Validation 8.18s => Epoch[98/100]: Loss: 0.00200 PSNR: 27.73910
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0098_loss_0.002004.pth
===> Training


100%|██████████| 36/36 [01:42<00:00,  2.86s/it]


***Training 102.86s => Epoch[99/100]: Loss: 0.00276 PSNR: 25.63481
===> Validation


100%|██████████| 6/6 [00:08<00:00,  1.35s/it]

***Validation 8.12s => Epoch[99/100]: Loss: 0.00200 PSNR: 27.74316
Checkpoint saved to /content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0099_loss_0.002003.pth





# Test model
We will use the PSNR and SSIM metrics to measure how well the model is trained. Both metrics are easy to compute with the skimage library.

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Define test dataset
This time, we will not use PyTorch's DataLoader. Instead, we will load each image from the test data directory.

In [None]:
def get_test_dataset(data_dir, res_scale):
    """Collect paired test image paths from ``data_dir``.

    Reads ``data_dir/lrx<res_scale>/`` (low-resolution inputs) and
    ``data_dir/hr/`` (references). Images are paired by sorted file-name
    order, so both folders are expected to hold the same set of names.

    Returns:
        dict with 'lr' and 'hr' (lists of full paths) and 'file_name' (the hr
        base names), all in the same sorted order.
    """
    lr_dir = os.path.join(data_dir, 'lrx' + str(res_scale))
    hr_dir = os.path.join(data_dir, 'hr')

    lr_names = sorted(os.listdir(lr_dir))
    hr_names = sorted(os.listdir(hr_dir))

    return {
        'lr': [os.path.join(lr_dir, x) for x in lr_names],
        'hr': [os.path.join(hr_dir, x) for x in hr_names],
        # Must follow the same sorted order as 'hr'. A fresh os.listdir has an
        # arbitrary order, which mislabelled the saved result images.
        'file_name': list(hr_names),
    }

# Define your test method of the model

In [None]:
def _to_uint8_image(img):
    """Convert a float image in [0, 1] to uint8, rounding rather than truncating."""
    return np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)


def _rgb_ssim(image, reference):
    """Mean of per-channel SSIM for HxWxC float images in [0, 1].

    This is the same per-channel averaging skimage's ``multichannel=True`` does,
    but with an explicit ``data_range=1.0``. Without it, older skimage derives
    the range from the float dtype (-1..1, i.e. 2), which inflates SSIM for
    [0, 1] images. Newer releases reject float input that lacks data_range.
    """
    return float(np.mean([
        ssim(image[..., c], reference[..., c], data_range=1.0)
        for c in range(image.shape[-1])
    ]))


def test_model(opt, test_dataset):
    """Evaluate the best checkpoint on full-size test images and save the results.

    For each (lr, hr) pair the LR image is bicubic-upscaled to the HR size and
    fed to SRCNN. PSNR and SSIM of the bicubic baseline and of the SR output,
    both against the HR reference, are printed per image. Means and standard
    deviations follow at the end. Images are written under
    ``opt.test_result_dir``:
    ``bc/`` (bicubic input), ``sr/`` (model output) and ``compare/``
    (bicubic | SR | HR side by side).

    Args:
        opt: options; reads test_result_dir, checkpoint_dir, multi_gpu, use_cuda.
        test_dataset: dict with 'lr' and 'hr' path lists as built by
            get_test_dataset. Output files are named after the hr file itself.

    Side effects: sets ``opt.resume_best = True`` (read by load_model) and
    updates ``opt.use_cuda`` / ``opt.device``.
    """
    sr_compare_dir = os.path.join(opt.test_result_dir, "compare")  # bicubic | SR | ground truth
    sr_result_dir = os.path.join(opt.test_result_dir, "sr")
    sr_input_dir = os.path.join(opt.test_result_dir, "bc")  # bicubic-upscaled inputs
    for folder in (sr_result_dir, sr_compare_dir, sr_input_dir):
        os.makedirs(folder, exist_ok=True)

    opt.resume_best = True  # use the checkpoint with the best validation loss
    _, srcnn = load_model(opt.checkpoint_dir)

    if torch.cuda.device_count() > 1 and opt.multi_gpu:
        print("Use " + str(torch.cuda.device_count()) + " GPUs")
        srcnn = nn.DataParallel(srcnn)

    if opt.use_cuda and torch.cuda.is_available():
        opt.use_cuda = True
        opt.device = 'cuda'
    else:
        opt.use_cuda = False
        opt.device = 'cpu'

    srcnn = srcnn.to(opt.device)
    srcnn.eval()

    metrics = {'bicubic_psnr': [], 'bicubic_ssim': [], 'sr_psnr': [], 'sr_ssim': []}
    start_time = time.time()

    with torch.no_grad():
        for input_path, target_path in zip(test_dataset['lr'], test_dataset['hr']):
            file_name = os.path.basename(target_path)

            # Unlike training, load whole images (not patches) one at a time.
            lr_img = cv2.imread(input_path)
            hr_img = cv2.imread(target_path)
            if lr_img is None or hr_img is None:
                raise FileNotFoundError(
                    "could not read test pair: {}, {}".format(input_path, target_path))

            bicubic = cv2.resize(lr_img, (hr_img.shape[1], hr_img.shape[0]),
                                 interpolation=cv2.INTER_CUBIC)
            bicubic_arr = normalize_img(bicubic)   # H x W x C in [0, 1]
            target_arr = normalize_img(hr_img)

            # H x W x C -> 1 x C x H x W (a batch of one image)
            x = torch.from_numpy(np.ascontiguousarray(np.transpose(bicubic_arr, (2, 0, 1))))
            x = x.unsqueeze(0).to(opt.device)

            sr_arr = srcnn(x)[0].cpu().numpy().transpose(1, 2, 0)
            # The network output is unconstrained; clamp to the valid pixel range.
            sr_arr = np.clip(sr_arr, 0.0, 1.0)

            scores = {
                'bicubic_psnr': psnr(target_arr, bicubic_arr, data_range=1.0),
                'sr_psnr': psnr(target_arr, sr_arr, data_range=1.0),
                'bicubic_ssim': _rgb_ssim(bicubic_arr, target_arr),
                'sr_ssim': _rgb_ssim(sr_arr, target_arr),
            }
            for key, value in scores.items():
                metrics[key].append(value)

            compare_img = np.concatenate((bicubic_arr, sr_arr, target_arr), axis=1)
            cv2.imwrite(os.path.join(sr_compare_dir, file_name), _to_uint8_image(compare_img))
            cv2.imwrite(os.path.join(sr_result_dir, file_name), _to_uint8_image(sr_arr))
            cv2.imwrite(os.path.join(sr_input_dir, file_name), _to_uint8_image(bicubic_arr))

            print(file_name)
            print("Bicubic PSNR: {:.8f}, Bicubic SSIM: {:.8f}, SR PSNR: {:.8f}, SR SSIM: {:.8f}".format(
                scores['bicubic_psnr'], scores['bicubic_ssim'], scores['sr_psnr'], scores['sr_ssim']))

    # Previously an empty test set crashed with ZeroDivisionError here.
    if not metrics['sr_psnr']:
        print("No test images found; nothing to report.")
        return

    print("Time: {:.2f}".format(time.time() - start_time))
    summary = (
        ("Bicubic PSNR", 'bicubic_psnr'),
        ("Bicubic SSIM", 'bicubic_ssim'),
        ("SR PSNR", 'sr_psnr'),
        ("SR SSIM", 'sr_ssim'),
    )
    for label, key in summary:
        print("{}: {:.8f}".format(label, np.mean(metrics[key])))
        print("{} STD: {:.8f}".format(label, np.std(metrics[key])))

In [None]:
# Evaluate the best checkpoint on the full-size test images.
test_dir, res_scale = opt.test_dir, opt.res_scale
print("test_dir is: {}".format(test_dir))
print("test_result_dir is: {}".format(opt.test_result_dir))

test_dataset = get_test_dataset(test_dir, res_scale)
test_model(opt, test_dataset)

test_dir is: /content/gdrive/My Drive/srcnn/div2k_100-2/test_images
test_result_dir is: /content/gdrive/My Drive/srcnn/div2k_100-2/test_result
=> loading checkpoint '/content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0099_loss_0.002003.pth'
=> loaded checkpoint '/content/gdrive/My Drive/srcnn/div2k_100-2/checkpoint_dir/models_epoch_0099_loss_0.002003.pth' (epoch 99)
0830.png
Bicubic PSNR: 23.65051715, Bicubic SSIM: 0.77931095, SR PSNR: 23.99733887, SR SSIM: 0.79907723
0883.png
Bicubic PSNR: 25.04664727, Bicubic SSIM: 0.79283191, SR PSNR: 25.54673404, SR SSIM: 0.82017736
0884.png
Bicubic PSNR: 24.20699310, Bicubic SSIM: 0.77717060, SR PSNR: 24.65827121, SR SSIM: 0.79984994
0886.png
Bicubic PSNR: 33.58014167, Bicubic SSIM: 0.96164323, SR PSNR: 34.23868006, SR SSIM: 0.96358766
0891.png
Bicubic PSNR: 23.90350388, Bicubic SSIM: 0.81219866, SR PSNR: 24.48856528, SR SSIM: 0.83068392
Time: 24.21
Bicubic PSNR: 26.07756061
Bicubic PSNR STD: 3.78067693
Bicubic SSIM: 0.82463107