In [12]:
import numpy as np
import torch
import torch.nn as nn
import os, time
from skimage.measure import compare_psnr, compare_ssim
from skimage.io import imread, imsave, imshow
import cv2

In [19]:
def save_result(result, path):
    """Write ``result`` to ``path`` as a text table or as an image.

    Parameters
    ----------
    result : array-like
        Data to store. For image output it is clipped to ``[0, 1]`` before
        being handed to ``imsave``.
    path : str
        Destination file. When the file name has no extension, ``.png`` is
        appended. ``.txt`` and ``.dlm`` are written with ``np.savetxt``
        (format ``%2.4f``); every other extension goes through ``imsave``.
    """
    # Inspect only the final path component: a dot inside a directory name
    # (e.g. "./results/img" or "run.v1/img") must not count as an extension.
    ext = os.path.splitext(path)[-1]
    if not ext:
        ext = '.png'
        path = path + ext

    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        imsave(path, np.clip(result, 0, 1))

def save_residual(r, path):
    """Save a residual map as an image after a fixed contrast remap.

    Parameters
    ----------
    r : numpy.ndarray
        Residual values. They are remapped with ``2*(r+0.4)-0.3`` (i.e.
        ``2*r + 0.5``, so a zero residual lands on mid-gray) and then clipped
        to ``[0, 1]``.
    path : str
        Destination file; ``.png`` is appended when the file name has no
        extension.
    """
    # Only the final path component decides whether an extension is present,
    # so dotted directories such as "./results" do not suppress the default.
    if not os.path.splitext(path)[-1]:
        path = path + '.png'

    r = 2*(r+0.4)-0.3
    imsave(path, np.clip(r, 0, 1))

def save_structure(s, path):
    """Save a structure map as an image after a fixed contrast remap.

    Parameters
    ----------
    s : numpy.ndarray
        Values to visualise. They are remapped with ``2*(s+0.3)-0.3`` (i.e.
        ``2*s + 0.3``) and then clipped to ``[0, 1]``.
    path : str
        Destination file; ``.png`` is appended when the file name has no
        extension.
    """
    # Only the final path component decides whether an extension is present,
    # so dotted directories such as "./results" do not suppress the default.
    if not os.path.splitext(path)[-1]:
        path = path + '.png'

    # Alternative remaps that were tried previously:
    #   s = 4*(s+0.3)-0.7
    #   s = 1.8*(s+0.7)-0.8
    s = 2*(s+0.3)-0.3
    imsave(path, np.clip(s, 0, 1))

def show(x, title=None, cbar=False, figsize=None):
    """Display ``x`` as a grayscale image in a new matplotlib figure.

    Parameters
    ----------
    x : array-like
        Image data passed straight to ``imshow``.
    title : str, optional
        Figure title; omitted when falsy.
    cbar : bool, optional
        Draw a colorbar next to the image when true.
    figsize : tuple, optional
        Forwarded to ``plt.figure``.
    """
    # Imported lazily so the module can be loaded without matplotlib.
    from matplotlib import pyplot

    pyplot.figure(figsize=figsize)
    # Nearest-neighbour interpolation keeps individual pixels crisp.
    pyplot.imshow(x, cmap='gray', interpolation='nearest')

    if title:
        pyplot.title(title)
    if cbar:
        pyplot.colorbar()

    pyplot.show()

In [31]:
# ---- Configuration ----------------------------------------------------------
set_dir = 'data/Test'              # root folder holding one sub-folder per test set
set_names = ['Set68', 'Set12']     # test sets to process
sigma = 25                         # noise std-dev on the 0-255 intensity scale
model_dir = 'models'
model_name = 'model25m6d17/model.pth'
dnet = 'dncnn25(17).pth'
result_dir = 'results'
d = 100                            # half-width (in frequency bins) of the low-pass window
# Use the GPU when one is present, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
use_cuda = device.startswith("cuda")

# ---- Load networks ----------------------------------------------------------
# NOTE(security): torch.load unpickles arbitrary objects; only load checkpoint
# files that come from a trusted source.
# map_location lets checkpoints that were saved on a GPU load on a CPU-only host.
dncnn = torch.load(os.path.join(model_dir, dnet), map_location=device)
model = torch.load(os.path.join(model_dir, model_name), map_location=device)
print(model)

model.eval()
dncnn.eval()

model = model.to(device)
dncnn = dncnn.to(device)

# ---- Run the DnCNN residual estimate + low-pass filtering on each test set ---
for set_cur in set_names:

    # makedirs also creates `result_dir` itself when it is missing and does
    # not fail if the folder is already there.
    os.makedirs(os.path.join(result_dir, set_cur), exist_ok=True)
    # Metric accumulators; currently left empty because the PSNR/SSIM
    # evaluation needs a denoised image `x_`, which this cell does not compute.
    psnrs = []
    ssims = []

    # Sorted so images are processed in a deterministic order.
    for im in sorted(os.listdir(os.path.join(set_dir, set_cur))):
        if not im.endswith((".jpg", ".bmp", ".png")):
            continue

        x = np.array(imread(os.path.join(set_dir, set_cur, im)), dtype=np.float32)/255.0

        np.random.seed(seed=0)  # for reproducibility
        y = x + np.random.normal(0, sigma/255.0, x.shape)  # Add Gaussian noise without clipping
        y = y.astype(np.float32)

        # Reshape to (batch=1, channels, H, W); assumes a single-channel image.
        y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1])

        if use_cuda:
            torch.cuda.synchronize()  # make sure pending GPU work does not skew timing
        start_time = time.time()
        y_ = y_.to(device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            r = dncnn(y_)

        r = r.view(y.shape[0], y.shape[1])
        r = r.cpu().numpy().astype(np.float32)

        # Centre the spectrum so low frequencies sit in the middle.
        f = np.fft.fft2(r)
        fshift = np.fft.fftshift(f)

        rows, cols = r.shape
        crow, ccol = rows // 2, cols // 2

        # Square low-pass mask around the spectrum centre. The lower bounds are
        # clamped at 0: a negative start index would wrap around and select the
        # wrong region for images smaller than 2*d.
        f_filter = np.zeros_like(r)
        f_filter[max(crow-d, 0):crow+d, max(ccol-d, 0):ccol+d] = 1
        fshift = fshift*f_filter

        f_ishift = np.fft.ifftshift(fshift)
        r_back = np.fft.ifft2(f_ishift)
        # NOTE(review): np.abs discards the sign of the filtered values; if the
        # residual is meant to stay signed, np.real may be the intended call.
        r_back = np.abs(r_back)

        if use_cuda:
            torch.cuda.synchronize()
        elapsed_time = time.time() - start_time
        print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))

        name, ext = os.path.splitext(im)
        save_structure(r, path=os.path.join(result_dir, set_cur, name+'_residual'+ext))  # save the network output (residual)
        save_structure(r_back, path=os.path.join(result_dir, set_cur, name+'_fft'+ext))  # save the low-pass filtered residual

DataParallel(
  (module): Model(
    (net): FFTConv(
      (net1): Sequential(
        (0): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(2, eps=0.0001, momentum=0.95, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(2, eps=0.0001, momentum=0.95, affine=True, track_running_stats=True)
        (5): ReLU(inplace)
        (6): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
)
     Set68 : test001.png : 0.0490 second
     Set68 : test002.png : 0.0648 second
     Set68 : test003.png : 0.0728 second
     Set68 : test004.png : 0.0698 second
     Set68 : test005.png : 0.0658 second
     Set68 : test006.png : 0.0708 second
     Set68 : test007.png : 0.0705 second
     Set68 : test008.png : 0.0715 second
     Set68 : test009.png : 0.0708 second
     Set68 : test010.pn