In [1]:
import json
import torch
import numpy as np
from matplotlib import pyplot as plt

from utils.utilities import *
from dataloader.CBSD68_test import CBSD68_test

from deal import DEAL

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Dataset iterated over by the evaluation loop further down.
test_dataset = CBSD68_test()


path_ckp = "trained_models/deal_color.pth"
# Passing a single target device remaps every stored tensor onto `device`,
# whatever device the checkpoint was saved from (the earlier dict form only
# covered cuda:0 .. cuda:3).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider `weights_only=True` once the checkpoint contents
# are confirmed to be plain tensors/containers.
ckp = torch.load(path_ckp, map_location=device)

model = DEAL(color=True)
model.to(device)
model.load_state_dict(ckp['state_dict'])
# Kept as the last expression of the cell: switches the model to evaluation
# mode and displays the module structure as the cell output.
model.eval()

  ckp = torch.load(path_ckp, map_location={'cuda:0':device,'cuda:1':device,'cuda:2':device,'cuda:3':device})


DEAL(
  (W1): MultiConv2d(
    (conv_layers): ModuleList(
      (0): ParametrizedConv2d(
        3, 12, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): ZeroMean()
          )
        )
      )
      (1): Conv2d(12, 24, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
      (2): Conv2d(24, 128, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
    )
  )
  (M1): MultiConv2d(
    (conv_layers): ModuleList(
      (0): ParametrizedConv2d(
        3, 12, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): ZeroMean()
          )
        )
      )
      (1): Conv2d(12, 24, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
      (2): Conv2d(24, 128, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
    )
  )
  (M2): Conv2d(128

In [3]:
# Noise levels to evaluate. They are divided by 255 when the noise is drawn,
# i.e. they are expressed on a 0-255 intensity scale.
sigmas = [5., 15, 25]
# Passed twice to `solve_inverse_problem` below
# (presumably as forward operator and its adjoint for plain denoising - TODO confirm).
identity = lambda x: x
# Forwarded to the solver as `eps_in` / `eps_out`
# (presumably inner/outer stopping tolerances - TODO confirm).
eps_in = 1e-6
eps_out = 1e-5

# Unrounded per-image PSNRs for every noise level, keyed by sigma, so results
# of earlier noise levels are not lost when `psnrs` is reset.
psnrs_by_sigma = {}

with torch.no_grad():
    for sigma in sigmas:
        # Per-image PSNRs (unrounded) for the current noise level only.
        psnrs = []
        model.cal_lambda(torch.tensor(([[sigma]])).to(device))
        lmbda = model.lmbda.item()
        print('sigma : ', sigma, 'lambda : ', lmbda)

        for i, img in enumerate(test_dataset):
            #img = crop_center(img, 256, 256)

            # Re-seeding before every draw makes each image's noise depend only
            # on its shape and on sigma, not on the iteration order.
            np.random.seed(seed=0)
            noise = np.random.normal(0, sigma / 255., img.shape)
            noisy_im = img + noise

            # Move the last axis to the front and add a leading batch axis
            # (assumes img is laid out as (H, W, C) - TODO confirm).
            gt = torch.Tensor(img).transpose(0, 2).transpose(1, 2).to(device)[None, ...]
            noisy_image = torch.Tensor(noisy_im).transpose(0, 2).transpose(1, 2).to(device)[None, ...].float()

            denoised = model.solve_inverse_problem(noisy_image, identity, identity, sigma, lmbda, eps_in=eps_in, eps_out=eps_out)
            psnr_exact = compute_PSNR(denoised, gt, 1)
            # Round for display only; the summary statistics below use the
            # exact values so they are not biased by per-image rounding.
            psnr = np.round(psnr_exact, 2)
            psnrs.append(float(psnr_exact))
            print('Image', i+1, 'PSNR', psnr)

        psnrs_by_sigma[sigma] = psnrs
        print('PSNR (mean and std): ', np.round(np.mean(np.array(psnrs)), 2), np.round(np.std(np.array(psnrs)), 2))

sigma :  5.0 lambda :  8.082879066467285
Image 1 PSNR 40.74
Image 2 PSNR 40.1
Image 3 PSNR 40.07
Image 4 PSNR 38.92
Image 5 PSNR 40.51
Image 6 PSNR 40.76
Image 7 PSNR 40.03
Image 8 PSNR 42.45
Image 9 PSNR 40.77
Image 10 PSNR 41.36
Image 11 PSNR 42.51
Image 12 PSNR 40.05
Image 13 PSNR 40.55
Image 14 PSNR 39.98
Image 15 PSNR 39.83
Image 16 PSNR 40.96
Image 17 PSNR 38.76
Image 18 PSNR 39.1
Image 19 PSNR 39.51
Image 20 PSNR 40.7
Image 21 PSNR 39.71
Image 22 PSNR 39.85
Image 23 PSNR 41.17
Image 24 PSNR 40.54
Image 25 PSNR 41.12
Image 26 PSNR 39.53
Image 27 PSNR 39.89
Image 28 PSNR 42.68
Image 29 PSNR 39.08
Image 30 PSNR 42.02
Image 31 PSNR 38.86
Image 32 PSNR 39.87
Image 33 PSNR 38.0
Image 34 PSNR 39.96
Image 35 PSNR 39.77
Image 36 PSNR 39.01
Image 37 PSNR 41.65
Image 38 PSNR 40.38
Image 39 PSNR 39.72
Image 40 PSNR 40.04
Image 41 PSNR 43.11
Image 42 PSNR 40.9
Image 43 PSNR 38.88
Image 44 PSNR 39.64
Image 45 PSNR 39.11
Image 46 PSNR 40.83
Image 47 PSNR 41.45
Image 48 PSNR 42.93
Image 49 PSNR