In [1]:
import os
import glob
import argparse
import logging
import time

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import h5py
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from skimage.measure import compare_psnr, compare_ssim
from sklearn.feature_extraction.image import extract_patches_2d

import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transforms

from torchdiffeq import odeint_adjoint as odeint
from models.DnCNN import DnCNN
from models.NODE import NODEDenoiser
from utils.dataloaders import *

# Parameters and Settings

In [2]:
# Preferred CUDA device index. Falls back to cuda:0 when fewer GPUs are visible,
# and to the CPU when CUDA is unavailable, so the notebook still runs anywhere.
gpu_index = 3

# Standard deviation of the additive Gaussian noise, on the 0-255 intensity scale.
noise_level = 25
# True selects NODEDenoiser, False selects DnCNN.
NODE_network = False

num_epochs = 180
batch_size = 128
learning_rate = 1e-3

# Each training image is resized by every factor below before patches are sampled.
scales = [1, 0.9, 0.8, 0.7]
patch_per_image = 150   # patches sampled per image, per scale
patch_size = 40
train_file_path = './data/train/'
test_file_path = './data/Set68/'

# True: extract all patches into memory up front.
# False: cache patches in HDF5 files under /scratch and stream them from disk
#        (avoids holding every patch in RAM).
realtime_patch = True

# Seed the global NumPy and torch RNGs so patch sampling and training noise are
# repeatable (cudnn.benchmark below can still introduce nondeterminism on GPU).
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Use CUDA when available
if torch.cuda.is_available():
    device = 'cuda:{}'.format(gpu_index if torch.cuda.device_count() > gpu_index else 0)
    dtype = torch.cuda.FloatTensor
    torch.backends.cudnn.benchmark = True
else:
    device = 'cpu'
    dtype = torch.FloatTensor

# Data Loading

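This section builds the training patches and the evaluation images. With `realtime_patch = True`, the patches are extracted once and held in memory. Otherwise, they are written once to HDF5 files and streamed from disk during training. This avoids repeating patch extraction, which would throttle the CPU during data loading.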

In [3]:
def _to_unit_float(pil_image):
    """Convert a PIL image to a float32 NumPy array scaled from [0, 255] to [0, 1]."""
    return (np.array(pil_image) / 255.).astype(np.float32)


def _iter_train_chunks(files):
    """Yield, for each training file and each factor in `scales`, an array of patches.

    Each yielded array holds up to `patch_per_image` randomly sampled
    `patch_size` x `patch_size` patches with an axis inserted at position 1
    (i.e. (n, 1, patch_size, patch_size) for single-channel images -- assumes
    grayscale inputs; TODO confirm for other datasets).
    """
    for i, path in enumerate(files):
        with Image.open(path) as image:
            width, height = image.size  # PIL reports (width, height)
            resized_images = [image.resize((int(width * scale), int(height * scale)), Image.BICUBIC)
                              for scale in scales]
        for resized in resized_images:
            patches = extract_patches_2d(_to_unit_float(resized), (patch_size, patch_size),
                                         max_patches=patch_per_image)
            yield np.expand_dims(patches, axis=1)
        print('Train Image Iter {} of {}'.format(i + 1, len(files)), end='\r')


def _iter_test_samples(files, as_patches):
    """Yield one evaluation sample per test file, loaded at its original size.

    With `as_patches`, the sample is up to `patch_per_image` random patches laid
    out like the training chunks; otherwise it is the whole image with a leading
    axis added by expand_dims.
    """
    for i, path in enumerate(files):
        with Image.open(path) as image:
            im = _to_unit_float(image)
        if as_patches:
            yield np.expand_dims(extract_patches_2d(im, (patch_size, patch_size),
                                                    max_patches=patch_per_image), axis=1)
        else:
            yield np.expand_dims(im, axis=0)
        print('Val Image Iter {} of {}'.format(i + 1, len(files)), end='\r')


def _stack_or_object(arrays):
    """Return np.array(arrays) when all shapes match, else a 1-D object array of them.

    Evaluation images may differ in size (e.g. portrait vs. landscape), and NumPy
    cannot stack ragged arrays into one numeric array (newer versions raise).
    NOTE(review): assumes ImageDataset only indexes/len()s this container -- confirm.
    """
    if len({a.shape for a in arrays}) <= 1:
        return np.array(arrays)
    ragged = np.empty(len(arrays), dtype=object)
    for k, a in enumerate(arrays):
        ragged[k] = a
    return ragged


def _write_hdf5(final_path, samples):
    """Store each sample as float32 dataset '0', '1', ... in a new HDF5 file.

    Writes to a '.partial' file that is closed (via `with`) and only then renamed,
    so an interrupted run never leaves a truncated cache that the
    os.path.exists() checks in this cell would later treat as complete.
    """
    partial_path = final_path + '.partial'
    with h5py.File(partial_path, 'w') as f:
        for key, sample in enumerate(samples):
            f.create_dataset(str(key), data=sample, dtype='f4')
    os.replace(partial_path, final_path)


train_files = sorted(glob.glob(os.path.join(train_file_path, '*.png')))
test_files = sorted(glob.glob(os.path.join(test_file_path, '*.png')))
train_h5_path = "/scratch/NODE-Denoiser/denoising_train.hdf5"
test_h5_path = "/scratch/NODE-Denoiser/denoising_test.hdf5"

if realtime_patch:
    if not train_files:
        raise FileNotFoundError('No .png training images found in {}'.format(train_file_path))
    # Merge every (image, scale) chunk along the patch axis.
    train_patches = np.concatenate(list(_iter_train_chunks(train_files)), axis=0)
    test_images = _stack_or_object(list(_iter_test_samples(test_files, as_patches=NODE_network)))
else:
    if not os.path.exists(train_h5_path):
        _write_hdf5(train_h5_path,
                    (patch for chunk in _iter_train_chunks(train_files) for patch in chunk))
    if not os.path.exists(test_h5_path):
        # The HDF5 cache always stores whole test images, regardless of NODE_network.
        _write_hdf5(test_h5_path, _iter_test_samples(test_files, as_patches=False))

Val Image Iter 67 of 68 400

In [7]:
# Choose the dataset backend, then wrap both splits in DataLoaders.
# In-memory arrays can feed 4 workers; the HDF5-backed datasets use a single worker.
if realtime_patch:
    train_set = ImageDataset(train_patches, augment=True)
    test_set = ImageDataset(test_images, augment=False)
    loader_workers = 4
else:
    train_set = h5pyDataset('/scratch/NODE-Denoiser/denoising_train.hdf5', augment=True)
    test_set = h5pyDataset('/scratch/NODE-Denoiser/denoising_test.hdf5')
    loader_workers = 1

# Evaluation uses batch_size=1, presumably because whole test images may differ in size.
trainloader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=loader_workers)
testloader = data.DataLoader(test_set, batch_size=1, shuffle=False,
                             num_workers=loader_workers)

# Model Setup

In [None]:
# Build the selected denoiser and move it to the configured device.
model = (NODEDenoiser() if NODE_network else DnCNN()).to(device)

# Adam optimiser; MultiStepLR multiplies the learning rate by 0.2 at epochs 30, 60 and 90.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, gamma=0.2, milestones=[30, 60, 90])

# Squared error summed over all elements (reduction='sum'), not averaged.
criterion = nn.MSELoss(reduction='sum').to(device)

# Train Model

In [None]:
for epoch in range(num_epochs):
    model.train()
    for batch_count, batch_data in enumerate(trainloader):
        optimizer.zero_grad()

        img_clean = batch_data.to(device)
        # Fresh Gaussian noise every batch; noise_level is on the 0-255 scale,
        # images are in [0, 1], hence the /255.
        noise = torch.randn(img_clean.size(), device=device) * (noise_level / 255.)

        loss = 0.5 * criterion(model(img_clean + noise), img_clean)
        print('Train iter: {} of {}, loss: {}'.format(batch_count + 1, len(trainloader), loss.item()), end='\r')
        loss.backward()
        optimizer.step()

    # Advance the LR schedule once per epoch, after optimizer.step(). This gives the
    # same per-epoch learning rates as the deprecated scheduler.step(epoch) call at
    # the start of each epoch.
    scheduler.step()

    model.eval()
    val_psnr = []
    val_ssim = []
    # Fixed noise realisation on every validation pass so metrics are comparable
    # across epochs (the CPU generator is independent of the training RNG stream).
    val_noise_gen = torch.Generator().manual_seed(0)
    with torch.no_grad():
        for batch_count, batch_data in enumerate(testloader):
            img_clean = batch_data.to(device)
            noise = torch.randn(img_clean.size(), generator=val_noise_gen) * (noise_level / 255.)
            noise = noise.to(device)

            img_est = torch.clamp(model(img_clean + noise), 0., 1.).cpu().numpy()
            img_clean = torch.clamp(img_clean, 0., 1.).cpu().numpy()

            # skimage takes (reference, estimate). data_range=1 because images are in
            # [0, 1]; otherwise, for float inputs, compare_ssim infers the range from
            # the dtype (-1..1, i.e. 2), which biases SSIM.
            val_psnr.append(compare_psnr(img_clean, img_est, data_range=1.))
            val_ssim.append(compare_ssim(np.squeeze(img_clean), np.squeeze(img_est), data_range=1.))

    print('Epoch {} of {}, Val PSNR: {:.4f}, Val SSIM: {:.4f}'.format(
        epoch + 1, num_epochs, np.mean(val_psnr), np.mean(val_ssim)))
