In [1]:
from dataloader.BSD500 import BSD500
import torch
from R_network_relax import RNet
import json
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

In [2]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def compute_PSNR(img, imclean, data_range):
    """Return the PSNR (in dB) between the first image of two batched tensors.

    Only element 0 of the leading (batch) axis is compared; any further batch
    elements are ignored. Both tensors are detached, copied to host memory and
    cast to float32 before calling skimage's ``peak_signal_noise_ratio``.

    Args:
        img: estimated image tensor, indexed as 4-D ``[0, :, :, :]``
            (presumably (N, C, H, W) -- confirm against callers).
        imclean: reference (ground-truth) image tensor with the same layout.
        data_range: dynamic range of the pixel values (e.g. 1 for images in [0, 1]).

    Returns:
        The PSNR value computed by ``compare_psnr``.
    """
    estimate, reference = (
        t.detach().cpu().numpy().astype(np.float32)[0, :, :, :]
        for t in (img, imclean)
    )
    # skimage's signature is (image_true, image_test): the clean image goes first.
    return compare_psnr(reference, estimate, data_range=data_range)
    

In [3]:
# Location of the held-out test split (an .h5 file) consumed by the BSD500 loader.
TEST_DATA_PATH = 'data/test.h5'
test_dataset = BSD500(TEST_DATA_PATH)

In [4]:
# Noise level on the 0-255 intensity scale; it selects the validation scores below
# and sets the noise std used in the evaluation loop.
sigma = 25

# Validation sweep for this noise level: pick the row with the highest PSNR.
val_scores_path = f"denoise_val_results/validation_scores_sigma_{sigma}_relaxed_valid.csv"
val_data = pd.read_csv(val_scores_path).reset_index(drop=True)
p1 = val_data.loc[val_data["psnr"].idxmax(), "p1"]
# TODO(review): p1 is selected here but never used afterwards (it is not passed to the
# model before evaluation) -- confirm whether it should be applied to RNet.


# Experiment directory for the model trained at this noise level.
# NOTE(review): absolute, user-specific root; the "64_2_7_1e-3_relaxed" suffix is
# experiment-specific -- consider moving the root to a configurable constant.
exp = f"/home/pourya/mm_final/exps/sigma{sigma}/64_2_7_1e-3_relaxed"
path_ckp = exp + "/checkpoints/checkpoint_best_epoch.pth"
path_config = exp + "/config.json"
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Context manager closes the config file deterministically.
with open(path_config) as config_file:
    config = json.load(config_file)

# map_location=device remaps every saved storage onto `device`, whichever GPU
# (or CPU) the checkpoint was written from.
# NOTE(review): newer torch versions default to weights_only=True; if the checkpoint
# holds non-tensor objects, loading may need weights_only=False (trusted files only).
ckp = torch.load(path_ckp, map_location=device)

model = RNet(config['model_params'])
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()


# NOTE(review): eigenimage is moved explicitly, presumably because it is a plain tensor
# attribute rather than a registered buffer/parameter, so model.to(device) does not
# move it -- confirm in RNet.
model.eigenimage = model.eigenimage.to(device)

# Re-apply model.zeromean to the W1 filter bank (the main conv and every conv in W1s)
# and overwrite the stored weights with the result. no_grad avoids building an
# autograd graph for this one-off weight rewrite.
with torch.no_grad():
    model.W1.W1.weight.data = model.zeromean(model.W1.W1.weight)
    for conv in model.W1.W1s:
        conv.weight.data = model.zeromean(conv.weight)

In [5]:
# Evaluate on the test set: add Gaussian noise with std sigma/255 to each clean image,
# denoise it, and record the PSNR against the clean image.
# NOTE(review): data_range=1 assumes images are scaled to [0, 1] -- confirm in BSD500.
NOISE_SEED = 0  # fixed seed so the noisy inputs (and hence the PSNRs) are reproducible
noise_std = sigma / 255.0
noise_gen = torch.Generator(device=device).manual_seed(NOISE_SEED)
psnrs = list()

with torch.no_grad():
    for i, img in enumerate(test_dataset, start=1):
        # img is indexed as 3-D here; [None] adds a leading batch axis of size 1.
        gt = img.to(device)[None, :, :, :]
        noise = noise_std * torch.randn(gt.shape, device=device, generator=noise_gen)
        noisy_image = gt + noise
        denoised = model.solve_majorize_minimize(noisy_image)
        # Store full precision so the mean/std are not biased by rounding;
        # round only for the progress printout.
        psnr = compute_PSNR(denoised, gt, 1)
        psnrs.append(psnr)
        print(i, np.round(psnr, 2))

1 30.09
2 28.64
3 30.67
4 29.87
5 30.08
6 26.48
7 28.93
8 28.62
9 29.77
10 29.0
11 28.3
12 26.25
13 29.55
14 26.63
15 28.19
16 30.88
17 28.35
18 29.38
19 31.52
20 29.6
21 33.22
22 27.42
23 26.7
24 24.34
25 27.87
26 28.92
27 27.97
28 33.01
29 26.09
30 30.37
31 27.91
32 32.09
33 30.02
34 23.97
35 28.61
36 28.31
37 29.2
38 27.48
39 33.64
40 26.32
41 30.9
42 33.16
43 31.02
44 30.28
45 27.11
46 29.8
47 28.93
48 26.89
49 25.82
50 38.29
51 27.49
52 33.53
53 33.55
54 30.22
55 29.06
56 27.22
57 26.3
58 27.56
59 24.47
60 26.57
61 28.54
62 27.98
63 31.06
64 24.68
65 28.88
66 33.11
67 32.97
68 29.41


In [6]:
# Test-set summary: mean and population (ddof=0) std of the per-image PSNRs, 2 decimals.
print(*(np.round(stat(np.array(psnrs)), 2) for stat in (np.mean, np.std)))

29.1 2.59
