In [1]:
import numpy as np
import torch
import torch.nn as nn
import os, time
from skimage.measure import compare_psnr, compare_ssim
from skimage.io import imread, imsave, imshow
import cv2

In [4]:
def save_result(result, path):
    """Save an array either as a text matrix or as an image file.

    If ``path`` has no file extension, ``'.png'`` is appended. Paths ending in
    ``.txt`` or ``.dlm`` are written with ``np.savetxt`` (4 decimal places);
    any other extension is written with ``imsave`` after clipping values to
    [0, 1].

    Args:
        result: array to save (image values presumably in [0, 1]; for the
            text branch it must be something ``np.savetxt`` accepts).
        path: destination path, with or without an extension.
    """
    # Look only at the file-name extension: a '.' in a directory component
    # (e.g. './out/img') must not suppress the '.png' default.
    ext = os.path.splitext(path)[1]
    if not ext:
        path += '.png'
        ext = '.png'
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        imsave(path, np.clip(result, 0, 1))

def save_residual(r, path):
    """Save a residual map as an image, contrast-stretched for viewing.

    Values are mapped by ``2*(r+0.4)-0.3`` (equivalently ``2*r + 0.5``) and
    clipped to [0, 1] before being written with ``imsave``. If ``path`` has no
    file extension, ``'.png'`` is appended.

    Args:
        r: residual array (NOTE(review): presumably small values centred
            near 0 — confirm against the producer).
        path: destination path, with or without an extension.
    """
    # Only the file-name extension matters; dots in directory names are ignored.
    if not os.path.splitext(path)[1]:
        path += '.png'
    # Affine stretch: a zero residual lands on mid-gray 0.5 and the slope of 2
    # doubles contrast. The constants look hand-tuned for display.
    stretched = 2*(r+0.4)-0.3
    imsave(path, np.clip(stretched, 0, 1))

def save_structure(s, path):
    """Save a structure map as an image, contrast-stretched for viewing.

    Values are mapped by ``1.8*(s+0.7)-0.8`` (equivalently ``1.8*s + 0.46``)
    and clipped to [0, 1] before being written with ``imsave``. If ``path``
    has no file extension, ``'.png'`` is appended.

    Args:
        s: structure array to visualise.
        path: destination path, with or without an extension.
    """
    # Only the file-name extension matters; dots in directory names are ignored.
    if not os.path.splitext(path)[1]:
        path += '.png'
    # Display stretch; the constants appear to be hand-tuned for viewing.
    stretched = 1.8*(s+0.7)-0.8
    imsave(path, np.clip(stretched, 0, 1))

def show(x, title=None, cbar=False, figsize=None):
    """Display an array as a grayscale image in a new matplotlib figure.

    Args:
        x: image array passed to ``imshow`` (rendered with the 'gray' colormap
            and nearest-neighbour interpolation).
        title: optional figure title; skipped when falsy (None or '').
        cbar: when true, a colorbar is attached to the figure.
        figsize: optional (width, height) forwarded to ``plt.figure``.
    """
    # Imported lazily so the rest of the notebook does not require matplotlib.
    import matplotlib.pyplot as pyplot

    pyplot.figure(figsize=figsize)
    pyplot.imshow(x, cmap='gray', interpolation='nearest')
    if title:
        pyplot.title(title)
    if cbar:
        pyplot.colorbar()
    pyplot.show()

In [5]:
# Evaluation configuration.
set_dir = 'data/Test'                   # root holding one sub-directory per test set
set_names = ['Set68', 'Set12', 'test']  # sub-directories of set_dir to evaluate
sigma = 50                              # std-dev of the added Gaussian noise, on the 0-255 scale
model_dir = 'result'
model_name = 'model_obj.pt'
result_dir = 'result/image'             # outputs go to result_dir/<set name>/
# Fall back to the CPU so the cell also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
use_cuda = device.startswith("cuda")

# SECURITY: torch.load unpickles a whole model object, which can execute
# arbitrary code — only load checkpoint files from a trusted source.
# map_location lets a checkpoint saved from a GPU be loaded on a CPU-only machine.
model = torch.load(os.path.join(model_dir, model_name), map_location=device)
model.eval()
model = model.to(device)

for set_cur in set_names:
    out_dir = os.path.join(result_dir, set_cur)
    # makedirs also creates result_dir itself when missing (os.mkdir would raise).
    os.makedirs(out_dir, exist_ok=True)
    psnrs = []
    ssims = []

    # sorted() gives a deterministic processing/print order; os.listdir's is arbitrary.
    for im in sorted(os.listdir(os.path.join(set_dir, set_cur))):
        if not im.endswith(('.jpg', '.bmp', '.png')):
            continue

        # Clean reference image scaled to [0, 1].
        x = np.array(imread(os.path.join(set_dir, set_cur, im)), dtype=np.float32)/255.0
        if x.ndim != 2:
            # The (H, W) reshapes below only make sense for single-channel images.
            raise ValueError('%s: expected a single-channel (grayscale) image, got shape %s'
                             % (im, x.shape))

        np.random.seed(seed=0)  # for reproducibility
        y = x + np.random.normal(0, sigma/255.0, x.shape)  # Add Gaussian noise without clipping
        y = y.astype(np.float32)

        # (H, W) -> (1, 1, H, W): add batch and channel axes for the network.
        y_ = torch.from_numpy(y).view(1, 1, y.shape[0], y.shape[1])

        if use_cuda:
            torch.cuda.synchronize()  # keep earlier queued GPU work out of the timing
        start_time = time.perf_counter()
        with torch.no_grad():  # inference only: no autograd graph needed
            r = model(y_.to(device))
        # The network output is used directly as the denoised estimate.
        r = r.view(y.shape[0], y.shape[1]).cpu().numpy().astype(np.float32)
        if use_cuda:
            torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - start_time
        print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))

        # Images live in [0, 1]. Pass data_range explicitly; otherwise skimage derives
        # it from the float dtype's [-1, 1] range, and SSIM then uses a range of 2.
        psnr_x_ = compare_psnr(x, r, data_range=1.0)
        ssim_x_ = compare_ssim(x, r, data_range=1.0)

        name, ext = os.path.splitext(im)
        save_result(r, path=os.path.join(out_dir, name+'_denoised'+ext))  # save the denoised image
        save_result(y, path=os.path.join(out_dir, name+'_noise'+ext))     # save the noisy input

        psnrs.append(psnr_x_)
        ssims.append(ssim_x_)

    if not psnrs:
        # Avoid np.mean([]) -> NaN plus a RuntimeWarning for sets with no images.
        print('Dataset: {0:10s} \n  no images found'.format(set_cur))
        continue
    psnr_avg = np.mean(psnrs)
    ssim_avg = np.mean(ssims)
    # The per-set average is appended as the last entry of each list.
    psnrs.append(psnr_avg)
    ssims.append(ssim_avg)

    print('Dataset: {0:10s} \n  PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))

     Set68 : test001.png : 0.0240 second
     Set68 : test002.png : 0.0229 second
     Set68 : test003.png : 0.0239 second
     Set68 : test004.png : 0.0230 second
     Set68 : test005.png : 0.0239 second
     Set68 : test006.png : 0.0240 second
     Set68 : test007.png : 0.0229 second
     Set68 : test008.png : 0.0239 second
     Set68 : test009.png : 0.0239 second
     Set68 : test010.png : 0.0239 second
     Set68 : test011.png : 0.0239 second
     Set68 : test012.png : 0.0259 second
     Set68 : test013.png : 0.0239 second
     Set68 : test014.png : 0.0239 second
     Set68 : test015.png : 0.0239 second
     Set68 : test016.png : 0.0239 second
     Set68 : test017.png : 0.0249 second
     Set68 : test018.png : 0.0249 second
     Set68 : test019.png : 0.0239 second
     Set68 : test020.png : 0.0240 second
     Set68 : test021.png : 0.0229 second
     Set68 : test022.png : 0.0230 second
     Set68 : test023.png : 0.0239 second
     Set68 : test024.png : 0.0239 second
     Set68 : tes

In [4]:
import cv2

# Build an extra test image: read a source image as single-channel grayscale,
# keep its top-left 512x512 region (smaller if the image is smaller), and write it out.
CROP_SRC_PATH = "data/Test/0003.png"
CROP_DST_PATH = "data/Test/test003_gray.png"

img = cv2.imread(CROP_SRC_PATH, cv2.IMREAD_GRAYSCALE)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError("could not read image: " + CROP_SRC_PATH)
cv2.imwrite(CROP_DST_PATH, img[:512, :512])

True