In [1]:
import os
import glob
import argparse
import logging
import time

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import h5py
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from skimage.measure.simple_metrics import compare_psnr
from sklearn.feature_extraction.image import extract_patches_2d

import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transforms

from torchdiffeq import odeint_adjoint as odeint
from models.DnCNN import DnCNN
from models.NODE import NODEDenoiser

# Parameters, Settings

In [2]:
# GPU device
device = 'cuda:0'

num_epochs = 100

# standard dev for noise
noise_level = 25
NODE_network = True

batch_size = 16
learning_rate = 1e-4

scales = [1, 0.9, 0.8, 0.7]
patch_per_image = 1000
patch_size = 10
train_file_path = './data/train/'
test_file_path = './data/Set68/'

# Create patches realtime (may throttle CPU/not utilize GPU)
realtime_patch = True

# Use CUDA when available
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
    torch.backends.cudnn.benchmark=True
else:
    dtype = torch.FloatTensor

# Data Loading

This section is dedicated to constructing a method of generating a database for patches to prevent CPU throttling during data loading.

In [3]:
if realtime_patch:
    files = glob.glob(os.path.join(train_file_path, '*.png'))
    files.sort()
    train_patches = []
    for i in range(len(files)):
        image = Image.open(files[i])
        height, width = image.size
        for scale in scales:
            im_re = image.resize((int(height*scale), int(width*scale)), Image.BICUBIC)
            im = np.float32(np.array(im_re) / 255., axis=0)
            patches = np.expand_dims(extract_patches_2d(im, (patch_size,patch_size), max_patches=1), axis=1)
            train_patches.append(patches)

        print('Train Image Iter {} of {}'.format(i, len(files)), end='\r')    
    train_patches = np.array(train_patches)
    train_patches = train_patches.reshape((-1,) + train_patches.shape[-3:])
    
    files = glob.glob(os.path.join(test_file_path, '*.png'))
    files.sort()
    test_images = []
    for i in range(len(files)):
        image = Image.open(files[i])
        im = np.float32(np.array(im_re) / 255., axis=0)
        image = np.expand_dims(im, axis=0)
        test_images.append(image)

        print('Val Image Iter {} of {}'.format(i, len(files)), end='\r')
    test_images = np.array(test_images)
else:
    if not os.path.exists("/scratch/NODE-Denoiser/denoising_train.hdf5"):
        f = h5py.File("/scratch/NODE-Denoiser/denoising_train.hdf5", "w")

        files = glob.glob(os.path.join(train_file_path, '*.png'))
        files.sort()
        train_num = 0
        for i in range(len(files)):
            image = Image.open(files[i])
            height, width = image.size
            for scale in scales:
                im_re = image.resize((int(height*scale), int(width*scale)), Image.BICUBIC)
                im = np.float32(np.array(im_re) / 255., axis=0)
                patches = np.expand_dims(extract_patches_2d(im, (patch_size,patch_size), max_patches=patch_per_image), axis=1)
                for j in range(patches.shape[0]):
                    f.create_dataset(str(train_num), data=patches[j,:], dtype='f4')
                    train_num += 1

            print('Train Image Iter {} of {}'.format(i, len(files)), end='\r')

    if not os.path.exists("/scratch/NODE-Denoiser/denoising_test.hdf5"):
        f = h5py.File("/scratch/NODE-Denoiser/denoising_test.hdf5", "w")

        files = glob.glob(os.path.join(test_file_path, '*.png'))
        files.sort()
        train_num = 0
        for i in range(len(files)):
            image = Image.open(files[i])
            im = np.float32(np.array(im_re) / 255., axis=0)
            patches = np.expand_dims(im, axis=0)
            f.create_dataset(str(train_num), data=patches, dtype='f4')
            train_num += 1

            print('Val Image Iter {} of {}'.format(i, len(files)), end='\r')

Val Image Iter 67 of 68 400

In [4]:
class RandomPatchDataset(data.Dataset):
    'Characterizes a dataset for loading patches for denoising'
    
    def __init__(self, filepath, patch_size=0, train_data=True, transform=None):
        'Initialization'
        super(RandomPatchDataset, self).__init__()
        self.patch_size = patch_size
        self.transform = transform
        self.files = glob.glob(os.path.join(filepath, '*.png'))
        self.train_data = train_data
                        
        if self.train_data:
            self.length = len(self.files) * len(scales) * patch_per_image
        else:
            self.length = len(self.files)
            
    def __len__(self):
        'Denotes the total number of samples'
        return self.length

    def __getitem__(self, index):
        'Generates one single image patch'
        file_idx = (index // (len(scales) * patch_per_image)) % len(self.files)
        scale_idx = (index // patch_per_image) % len(scales)
        patch_idx = index % patch_per_image
        
        if not self.train_data:
            image = Image.open(self.files[index])
        else:
            image = Image.open(self.files[file_idx])
            if self.transform is not None:        
                image = self.transform(image)
        if self.train_data:
            height, width = image.size
            image = image.resize((int(height*scales[scale_idx]), int(width*scales[scale_idx])), Image.BICUBIC)
        patch = image = np.float32(np.array(image) / 255., axis=0)
        if self.train_data:
            patch = np.expand_dims(extract_patches_2d(image, (self.patch_size, self.patch_size), 
                                                  max_patches=patch_per_image), axis=1)[patch_idx,:]  
        return torch.Tensor(patch)
    
class h5pyDataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, file_name, augment=False):
        'Initialization'
        super(h5pyDataset, self).__init__()
        self.file_name = file_name
        self.augment = augment
        with h5py.File(self.file_name, 'r') as db:
            self.length = len(db.keys())

    def __len__(self):
        'Denotes the total number of samples'
        return self.length

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        with h5py.File(self.file_name, 'r') as db:
            image = np.array(db[str(index)])
    
        if self.augment:
            data_augmentation(image)
            
        return torch.Tensor(image)
    
class ImageDataset(data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, images, augment=False):
        'Initialization'
        super(ImageDataset, self).__init__()
        self.images = images
        self.images = torch.Tensor(self.images)
        self.augment = augment

    def __len__(self):
        'Denotes the total number of samples'
        return self.images.shape[0]

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        image = self.images[index, :]
    
        if self.augment:
            data_augmentation(image)
            
        return image
    
def data_augmentation(image):
    out = np.transpose(image, (1,2,0))
    mode = np.random.randint(0,8)
    if mode == 0:
        # original
        out = out
    elif mode == 1:
        # flip up and down
        out = np.flipud(out)
    elif mode == 2:
        # rotate counterwise 90 degree
        out = np.rot90(out)
    elif mode == 3:
        # rotate 90 degree and flip up and down
        out = np.rot90(out)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degree
        out = np.rot90(out, k=2)
    elif mode == 5:
        # rotate 180 degree and flip
        out = np.rot90(out, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degree
        out = np.rot90(out, k=3)
    elif mode == 7:
        # rotate 270 degree and flip
        out = np.rot90(out, k=3)
        out = np.flipud(out)
    return np.transpose(out, (2,0,1))
        
if realtime_patch:
    #trainloader = data.DataLoader(RandomPatchDataset(train_file_path, patch_size=patch_size, transform=patch_transform), 
    #                              batch_size=batch_size, shuffle=True, num_workers=4)
    #testloader = data.DataLoader(RandomPatchDataset(test_file_path, train_data=False), 
    #                             batch_size=1, shuffle=False, num_workers=4)
    trainloader = data.DataLoader(ImageDataset(train_patches, augment=True), 
                                  batch_size=batch_size, shuffle=True, num_workers=4)
    testloader = data.DataLoader(ImageDataset(test_images, augment=False),
                                batch_size=1, shuffle=False, num_workers=4)
else:
    trainloader = data.DataLoader(h5pyDataset('/scratch/NODE-Denoiser/denoising_train.hdf5', augment=True), 
                                  batch_size=batch_size, shuffle=True, num_workers=1)
    testloader = data.DataLoader(h5pyDataset('/scratch/NODE-Denoiser/denoising_test.hdf5'), 
                                 batch_size=1, shuffle=False, num_workers=1)   

# Model Setup

In [9]:
if NODE_network:
    model = NODEDenoiser()
else:
    model = DnCNN(1)
model = model.to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss().to(device)

# Train Model

In [None]:
for epoch in range(num_epochs):
    model.train()
    for batch_count, batch_data in enumerate(trainloader):
        optimizer.zero_grad()
        
        image_clean = batch_data
        noise = torch.FloatTensor(image_clean.size()).normal_(mean=0, std=noise_level/255.)
        image_clean, noise = image_clean.to(device), noise.to(device)
        
        noise_est = model(image_clean + noise)
        
        loss = criterion(noise_est, noise)
        print('Train iter: {} of {}, loss: {}'.format(batch_count+1, trainloader.__len__(), loss.item()), end='\r')
        loss.backward()
        optimizer.step()
        
    model.eval()
    val_psnr = []
    with torch.no_grad():
        for batch_count, batch_data in enumerate(testloader):
            image_clean = batch_data
            noise = torch.FloatTensor(image_clean.size()).normal_(mean=0, std=noise_level/255.)
            image_clean, noise = image_clean.to(device), noise.to(device)
        
            noise_est = model(image_clean + noise)
            img_est = torch.clamp(image_clean+noise-noise_est, 0., 1.)
            
            val_psnr.append(compare_psnr(img_est.detach().cpu().numpy(), image_clean.detach().cpu().numpy(), 1.))
    
    print('Epoch {} of {}, Val PSNR: {:4f}'.format(epoch+1, num_epochs, val_psnr[-1]))


Train iter: 100 of 100, loss: 0.0053953551687300205