In [1]:
import json
import torch
import numpy as np
from matplotlib import pyplot as plt

from utils.utilities import *
from dataloader.BSD68_test import BSD68_test

from deal import DEAL

Load Data and Model

In [2]:
# Pick the evaluation device: keep the original choice of the second GPU when it
# exists, but fall back to the first GPU or the CPU so the notebook still runs on
# machines with fewer GPUs (a hard-coded 'cuda:1' fails there).
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# BSD68 test set, read from an HDF5 file through the project dataloader.
test_dataset = BSD68_test('data/test.h5')

path_ckp = "trained_models/deal_gray.pth"
# map_location=device remaps every storage in the checkpoint onto `device`, whatever
# device it was saved from; the previous dict only covered cuda:0..cuda:3.
# NOTE(review): this call emits a warning in recent torch versions (likely the
# `weights_only` FutureWarning) — consider weights_only=True if the checkpoint holds
# only tensors and plain containers; confirm before changing.
ckp = torch.load(path_ckp, map_location=device)

model = DEAL(color=False)  # grayscale variant of the model
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()  # last expression: displays the model architecture

  ckp = torch.load(path_ckp, map_location={'cuda:0':device,'cuda:1':device,'cuda:2':device,'cuda:3':device})


DEAL(
  (W1): MultiConv2d(
    (conv_layers): ModuleList(
      (0): ParametrizedConv2d(
        1, 4, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): ZeroMean()
          )
        )
      )
      (1): Conv2d(4, 8, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
      (2): Conv2d(8, 128, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
    )
  )
  (M1): MultiConv2d(
    (conv_layers): ModuleList(
      (0): ParametrizedConv2d(
        1, 4, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): ZeroMean()
          )
        )
      )
      (1): Conv2d(4, 8, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
      (2): Conv2d(8, 128, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
    )
  )
  (M2): Conv2d(128, 128, k

In [3]:
# Evaluate the model as a denoiser on the test set for several noise levels.
# For each sigma: derive lambda from the model, corrupt every test image with seeded
# Gaussian noise, solve the inverse problem, and report per-image and summary PSNR.
sigmas = [5., 15., 25.]  # noise std on the [0, 255] scale (divided by 255 below); all floats
identity = lambda x: x   # passed as both operator arguments of solve_inverse_problem
                         # (presumably forward model and adjoint -> plain denoising)
eps_in = 1e-6            # forwarded to solve_inverse_problem as eps_in (presumably inner tolerance)
eps_out = 1e-5           # forwarded to solve_inverse_problem as eps_out (presumably outer tolerance)

with torch.no_grad():
    for sigma in sigmas:
        psnrs = []
        # Build the noise-level tensor explicitly as float32 on the target device;
        # torch.tensor([[15]]) would otherwise produce an int64 tensor for integer sigmas.
        model.cal_lambda(torch.tensor([[float(sigma)]], dtype=torch.float32, device=device))
        lmbda = model.lmbda.item()
        print('sigma : ', sigma, 'lambda : ', lmbda)

        for i, img in enumerate(test_dataset):
            # Reseeding with 0 before every image makes each image's noise reproducible
            # independently of iteration order (images of equal shape get identical noise).
            # The legacy global RNG is kept so the noise matches the original runs.
            np.random.seed(seed=0)
            noise = np.random.normal(0, sigma / 255., img.shape).astype(np.float32)
            noisy_im = img + torch.from_numpy(noise)

            # Add a leading batch dimension for the model.
            gt = img.to(device)[None, :, :, :]
            noisy_image = noisy_im.to(device)[None, :, :, :]

            denoised = model.solve_inverse_problem(noisy_image, identity, identity, sigma, lmbda, eps_in=eps_in, eps_out=eps_out)

            # Keep full precision for the statistics and round only for display, so the
            # reported mean/std are not distorted by per-image rounding.
            psnr = float(compute_PSNR(denoised, gt, 1))
            psnrs.append(psnr)
            print('Image', i+1, 'PSNR', np.round(psnr, 2))

        print('PSNR (mean and std): ', np.round(np.mean(psnrs), 2), np.round(np.std(psnrs), 2))



sigma :  5.0 lambda :  8.379857063293457
Image 1 PSNR 38.34
Image 2 PSNR 37.07
Image 3 PSNR 39.41
Image 4 PSNR 38.24
Image 5 PSNR 38.42
Image 6 PSNR 36.4
Image 7 PSNR 38.19
Image 8 PSNR 37.78
Image 9 PSNR 39.15
Image 10 PSNR 37.87
Image 11 PSNR 37.29
Image 12 PSNR 36.02
Image 13 PSNR 38.16
Image 14 PSNR 36.28
Image 15 PSNR 37.47
Image 16 PSNR 38.98
Image 17 PSNR 37.03
Image 18 PSNR 37.18
Image 19 PSNR 38.62
Image 20 PSNR 38.02
Image 21 PSNR 40.88
Image 22 PSNR 35.4
Image 23 PSNR 35.56
Image 24 PSNR 35.47
Image 25 PSNR 37.68
Image 26 PSNR 38.07
Image 27 PSNR 37.34
Image 28 PSNR 40.66
Image 29 PSNR 36.65
Image 30 PSNR 37.83
Image 31 PSNR 37.08
Image 32 PSNR 39.1
Image 33 PSNR 38.21
Image 34 PSNR 34.81
Image 35 PSNR 37.58
Image 36 PSNR 37.54
Image 37 PSNR 37.67
Image 38 PSNR 37.03
Image 39 PSNR 40.63
Image 40 PSNR 35.8
Image 41 PSNR 37.84
Image 42 PSNR 40.84
Image 43 PSNR 38.42
Image 44 PSNR 39.2
Image 45 PSNR 36.27
Image 46 PSNR 38.26
Image 47 PSNR 37.69
Image 48 PSNR 36.16
Image 49 PSNR