# SRCNN: Super-Resolution CNN with PyTorch  
## Training notebook (IPython) for the General-100 dataset.  
An exercise toward building a product. 
  1. Initialize the environment for running the SRCNN model  
  2. Load the General-100 data  
  3. Load the model  
  4. Run the training loop and write its outputs  

Ref.  
Chao Dong, Chen Change Loy, Kaiming He, Xiaoou Tang. Learning a Deep Convolutional Network for Image Super-Resolution, in Proceedings of European Conference on Computer Vision (ECCV), 2014 

stdout, stderr, weights, and sample images are saved into the result directory.  

In [None]:
import os, sys
from pathlib import Path

import argparse
# ---------------------------------------------------------------------------
# Run configuration: command-line options, output directories, log redirection.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='predictionCNN Example')
# Boolean switches, all off unless given on the command line.
for _switch in ('--cuda', '--naive', '--c11', '--con'):
    parser.add_argument(_switch, action='store_true', default=False)
parser.add_argument('--epoch', type=int, default=50*1000)
parser.add_argument('--snap', type=int, default=500)

# Inside a Jupyter kernel sys.argv belongs to the kernel launcher, so parse an
# empty argument list there; edit that list (e.g. ["--cuda"], ["--naive"],
# ["--c11"] or ["--snap", "5"]) to set options from a notebook.
# On the console (args=None) argparse reads sys.argv as usual.
_on_ipython = Path(sys.argv[0]).stem == 'ipykernel_launcher'
opt = parser.parse_args(args=[] if _on_ipython else None)

# Logs go to files unless --con asks for console output.
result_fileout = not opt.con

# The result directory name records which variant was trained.
result_dir = Path('result/SRCNN')
if opt.naive:
    result_dir = result_dir.with_name(result_dir.name + '-naive')
if opt.c11:
    result_dir = result_dir.with_name(result_dir.name + '-c11')

sample_dir = result_dir / 'sample'
weight_dir = result_dir / 'weights'
for _directory in (sample_dir, weight_dir, result_dir):
    os.makedirs(str(_directory), exist_ok=True)
resultdB = result_dir / 'dBhistory.npy'

# Keep the original streams so the final cell can put them back.
backup_stdout = backup_stderr = None
if result_fileout:
    backup_stdout, backup_stderr = sys.stdout, sys.stderr
    sys.stdout = open(str(result_dir) + '/log.stdout', 'w')
    sys.stderr = open(str(result_dir) + '/log.stderr', 'w')

print("Result on {} file : {}".format(result_dir, resultdB))
print(" CUDA  : {}".format(opt.cuda))
print(" AUG   : {}".format(not opt.naive))
print(" C11   : {}".format(opt.c11))

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.utils import save_image

import numpy as np
from math import log10

# from model import SRCNN
from torch.nn.functional import relu
from torch.nn import MSELoss

import torch.utils.data as data
from torchvision import transforms
from torchvision.transforms import ToTensor, RandomCrop
from PIL import Image, ImageOps
from pathlib import Path
import random
from pdb import set_trace

## データローダ定義  
学習用画像のミニバッチをロード  
評価用画像を1枚ずつロード  
イテレータを戻す  

### データ拡張  
基本的なデータ拡張を行うことにより、1dB程度の精度改善  
flip  
mirror  
rotate  

In [None]:
class DatasetLoader4Train(data.Dataset):
    """Training dataset of (degraded input, target) patch pairs.

    Each item is a random ``patch_size`` x ``patch_size`` RGB crop of one image
    (the target) together with the same crop bicubically downscaled by
    ``scale_factor`` and upscaled back to ``patch_size`` (the input).

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
        patch_size: Side length, in pixels, of the square training patch.
        scale_factor: Integer down/up-scaling factor used to build the input.
        data_augmentation: When true, each patch is independently flipped,
            mirrored and rotated by 180 degrees with probability 0.5 each.
    """

    # File extensions treated as images (compared case-insensitively).
    IMAGE_SUFFIXES = ('.bmp', '.jpg', '.png')

    def __init__(self, image_dir, patch_size, scale_factor, data_augmentation=True):
        super(DatasetLoader4Train, self).__init__()
        self.filenames = self._list_image_files(image_dir)
        self.patch_size = patch_size
        self.scale_factor = scale_factor
        self.data_augmentation = data_augmentation
        self.crop = RandomCrop(self.patch_size)
        # Built once instead of twice per sample.
        self.to_tensor = ToTensor()

    @classmethod
    def _list_image_files(cls, image_dir):
        """Return the image file paths directly inside *image_dir*, sorted.

        Sorting makes the index -> file mapping independent of the order the
        filesystem happens to list entries in; directories are skipped and the
        extension is matched case-insensitively.
        """
        return sorted(
            str(path)
            for path in Path(image_dir).glob('*')
            if path.suffix.lower() in cls.IMAGE_SUFFIXES and path.is_file()
        )

    def __getitem__(self, index):
        """Return ``(input_tensor, target_tensor)`` for the image at *index*."""
        # The context manager releases the file handle as soon as the pixels
        # have been copied into the RGB image.
        with Image.open(self.filenames[index]) as source:
            target_img = self.crop(source.convert('RGB'))

        if self.data_augmentation:     # Data Augmentation
            if random.random() < 0.5:
                target_img = ImageOps.flip(target_img)
            if random.random() < 0.5:
                target_img = ImageOps.mirror(target_img)
            if random.random() < 0.5:
                target_img = target_img.rotate(180)

        # Degrade: shrink by scale_factor, then enlarge back to the patch size.
        low_res_size = (self.patch_size // self.scale_factor,) * 2
        input_img = target_img.resize(low_res_size, Image.BICUBIC)
        input_img = input_img.resize((self.patch_size,) * 2, Image.BICUBIC)

        return self.to_tensor(input_img), self.to_tensor(target_img)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.filenames)

class DatasetLoader4Eval(data.Dataset):
    """Evaluation dataset of full-size (degraded input, target, name) triples.

    Each item is a whole RGB image (the target), the same image bicubically
    downscaled by ``scale_factor`` and upscaled back to the original size (the
    input), and the file name without its extension.

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
        scale_factor: Integer down/up-scaling factor used to build the input.
    """

    # File extensions treated as images (compared case-insensitively).
    IMAGE_SUFFIXES = ('.bmp', '.jpg', '.png')

    def __init__(self, image_dir, scale_factor):
        super(DatasetLoader4Eval, self).__init__()
        self.filenames = self._list_image_files(image_dir)
        self.scale_factor = scale_factor
        # Built once instead of twice per sample.
        self.to_tensor = ToTensor()

    @classmethod
    def _list_image_files(cls, image_dir):
        """Return the image file paths directly inside *image_dir*, sorted.

        Sorting gives a reproducible evaluation order regardless of how the
        filesystem lists entries; directories are skipped and the extension is
        matched case-insensitively.
        """
        return sorted(
            str(path)
            for path in Path(image_dir).glob('*')
            if path.suffix.lower() in cls.IMAGE_SUFFIXES and path.is_file()
        )

    def __getitem__(self, index):
        """Return ``(input_tensor, target_tensor, stem)`` for image *index*."""
        filename = self.filenames[index]
        # The context manager releases the file handle as soon as the pixels
        # have been copied into the RGB image.
        with Image.open(filename) as source:
            target_img = source.convert('RGB')

        # Degrade: shrink each side by scale_factor, then restore the size.
        width, height = target_img.size
        low_res_size = (width // self.scale_factor, height // self.scale_factor)
        input_img = target_img.resize(low_res_size, Image.BICUBIC)
        input_img = input_img.resize(target_img.size, Image.BICUBIC)

        return self.to_tensor(input_img), self.to_tensor(target_img), Path(filename).stem

    def __len__(self):
        """Return the number of image files found."""
        return len(self.filenames)

データ読み込み   

In [None]:
# Datasets: random 96x96 patches for training, whole images for validation,
# both degraded with a x4 scale factor. --naive switches augmentation off.
train_set = DatasetLoader4Train(
    image_dir='./data/General-100/train',
    patch_size=96,
    scale_factor=4,
    data_augmentation=(not opt.naive),
)
val_set = DatasetLoader4Eval(
    image_dir='./data/General-100/val',
    scale_factor=4,
)

# Loaders: shuffled mini-batches of 10 for training; validation images are
# served one at a time in dataset order.
train_loader = DataLoader(dataset=train_set, batch_size=10, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

## SRCNNモデル定義  
論文に従いモデルを定義   

In [None]:
class SRCNN(nn.Module):
    """Three-layer SRCNN operating on 3-channel images.

    Layers: 9x9 conv (3 -> 64), 1x1 conv (64 -> 32), 5x5 conv (32 -> 3), with
    ReLU after the first two. Each convolution is padded by half its kernel
    size, so the output has the same height and width as the input.

    Attributes:
        params: Optimizer parameter groups, one per layer. The last layer's
            group carries its own learning rate of 1e-5; the others take the
            optimizer's default.
    """

    def __init__(self):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=2)

        # Weights ~ N(0, 0.001^2), biases zero, for every convolution.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0, std=0.001)
                nn.init.constant_(m.bias, val=0)

        # Built once, after initialisation (it used to be rebuilt on every
        # loop iteration). The parameters are materialised into lists because
        # ``parameters()`` returns a one-shot generator: a second reader would
        # see empty groups and the module could not be deep-copied or pickled.
        self.params = [
            {'params': list(self.conv1.parameters())},
            {'params': list(self.conv2.parameters())},
            {'params': list(self.conv3.parameters()), 'lr': 1e-5},
        ]

    def forward(self, x):
        """Map a batch of images to same-sized restored images."""
        x = self.conv1(x)
        x = relu(x)
        x = self.conv2(x)
        x = relu(x)
        x = self.conv3(x)
        return x

## SRCNNモデル改  
原論文では、カーネルサイズ９x９の特徴抽出層が第１層に当たる  
これにカーネルサイズ１１x１１の層を追加し広範囲の特徴抽出を行う  

In [None]:
class SRCNN11(nn.Module):
    """SRCNN variant with two parallel first-layer branches (9x9 and 11x11).

    The input goes through a 9x9 conv (3 -> 32) and an 11x11 conv (3 -> 32)
    side by side; their ReLU outputs are concatenated along the channel axis
    (64 channels) and fed to the 1x1 conv (64 -> 32) and 5x5 conv (32 -> 3).
    Each convolution is padded by half its kernel size, so the output has the
    same height and width as the input.

    Attributes:
        params: Optimizer parameter groups, one per layer. The last layer's
            group carries its own learning rate of 1e-5; the others take the
            optimizer's default.
    """

    def __init__(self):
        super(SRCNN11, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=9, padding=4)
        self.conv11 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=11, padding=5)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=2)

        # Weights ~ N(0, 0.001^2), biases zero, for every convolution.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0, std=0.001)
                nn.init.constant_(m.bias, val=0)

        # Materialised into lists because ``parameters()`` returns a one-shot
        # generator: a second reader would see empty groups and the module
        # could not be deep-copied or pickled.
        self.params = [
            {'params': list(self.conv1.parameters())},
            {'params': list(self.conv11.parameters())},
            {'params': list(self.conv2.parameters())},
            {'params': list(self.conv3.parameters()), 'lr': 1e-5},
        ]

    def forward(self, x):
        """Map a batch of images to same-sized restored images."""
        x9 = self.conv1(x)
        x9 = relu(x9)
        x11 = self.conv11(x)
        x11 = relu(x11)
        # Stack both feature maps along the channel dimension: 32 + 32 = 64.
        x = torch.cat((x9, x11), dim=1)
        x = self.conv2(x)
        x = relu(x)
        x = self.conv3(x)
        return x

## 使用モデルのインスタンス  
モデルインスタンス  
オプティマイザ定義  

In [None]:
# Pick the network variant. The fallback to the plain SRCNN applies only when
# the option object is missing (e.g. this cell is run before the argument
# parsing cell); a genuine failure while building the model is no longer
# swallowed and silently replaced by a different architecture.
try:
    use_c11 = opt.c11
except (NameError, AttributeError):
    use_c11 = False
model = SRCNN11() if use_c11 else SRCNN()

criterion = MSELoss()
if opt.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# model.params supplies per-layer parameter groups; lr=1e-4 is the default for
# groups that do not set their own learning rate.
optimizer = optim.Adam(model.params, lr=1e-4)

## Training Loop with MiniBatch  
The number of epochs and the snapshot interval can be specified with the command-line arguments `--epoch` and `--snap`.  

In [None]:
from pdb import set_trace

# Read the schedule from the parsed options; fall back to the defaults only
# when the option object is missing, not on arbitrary errors.
try:
    epochs = opt.epoch
    snaps  = opt.snap
except (NameError, AttributeError):
    epochs = 5*10000
    snaps  = 5*  100


def _psnr_db(mse):
    """Return the PSNR in dB for a mean squared error, using a peak value of 1.

    A zero error yields ``inf`` rather than raising ZeroDivisionError.
    """
    return float('inf') if mse <= 0 else 10 * log10(1 / mse)


progressPSNR = []   # validation PSNR of the model at each snapshot
progressLoss = []   # validation loss of the model at each snapshot
for epoch in range(epochs):
    model.train() # Training Phase
    epoch_loss, epoch_psnr = 0.0, 0.0
    for batch in train_loader:
        inputs, targets = batch[0], batch[1]
        if opt.cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()

        optimizer.zero_grad()
        prediction = model(inputs)
        loss = criterion(prediction, targets)
        # .item() yields a plain Python float, so the running sums hold no
        # tensors (and no device memory).
        batch_mse = loss.item()
        epoch_loss += batch_mse
        epoch_psnr += _psnr_db(batch_mse)

        loss.backward()
        optimizer.step()

    print('[Epoch {}] Loss: {:.4f}, PSNR: {:.4f} dB'.format(epoch + 1, epoch_loss / len(train_loader), epoch_psnr / len(train_loader)))
    sys.stdout.flush()

    # Validate and save only on every `snaps`-th epoch.
    if (epoch + 1) % snaps != 0:
        continue

    model.eval()  # Validation Phase
    val_loss, val_psnr = 0.0, 0.0     # model prediction vs. target
    val_loss0, val_psnr0 = 0.0, 0.0   # untouched network input vs. target (baseline)
    with torch.no_grad():
        for batch in val_loader:
            inputs, targets = batch[0], batch[1]
            if opt.cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            prediction = model(inputs)
            pred_mse = criterion(prediction, targets).item()
            val_loss += pred_mse
            val_psnr += _psnr_db(pred_mse)

            base_mse = criterion(inputs, targets).item()
            val_loss0 += base_mse
            val_psnr0 += _psnr_db(base_mse)

            # batch[2] holds the image names; the batch size is 1, so take [0].
            pred_file   = sample_dir / '{}_epoch{:05}.png'.format(batch[2][0], epoch + 1)
            target_file = sample_dir / '{}_epoch{:05}.png'.format(batch[2][0], 0)
            save_image(prediction, pred_file, nrow=1)
            # The target is written once, under the pseudo-epoch 00000.
            if not target_file.exists():
                save_image(targets, target_file, nrow=1)

    avrg_loss0 = val_loss0 / len(val_loader) # For Validation (baseline)
    avrg_psnr0 = val_psnr0 / len(val_loader)
    avrg_loss = val_loss / len(val_loader) # For Prediction
    avrg_psnr = val_psnr / len(val_loader)
    progressPSNR.append(avrg_psnr)
    progressLoss.append(avrg_loss)
    print("===> Avrg Loss: {:.4f} PSNR: {:.4f} dB [ VAL {:.4f} / {:.4f} dB ]".format(avrg_loss, avrg_psnr, avrg_loss0, avrg_psnr0))
    np.save(str(resultdB), progressPSNR)

    torch.save(model.state_dict(), str(result_dir / 'latest_weight.pth')) # Save Latest Weight
    torch.save(model.state_dict(), str(weight_dir / 'weight_epoch{:05}.pth'.format(epoch + 1)))

後処理

In [None]:
# Restore the original stdout/stderr and close the log files that replaced them.
if result_fileout:
    log_stdout, log_stderr = sys.stdout, sys.stderr
    sys.stdout = backup_stdout if backup_stdout else None
    sys.stderr = backup_stderr if backup_stderr else None
    # Close only streams that were actually swapped out, so re-running this
    # cell (when the originals are already back) never closes the real
    # console streams.
    for log_stream, restored, interpreter_default in (
        (log_stdout, sys.stdout, sys.__stdout__),
        (log_stderr, sys.stderr, sys.__stderr__),
    ):
        if (log_stream is not None
                and log_stream is not restored
                and log_stream is not interpreter_default):
            log_stream.close()  # also flushes buffered log output