In [1]:
from dataloader.BSD500 import BSD500
import torch
from R_network_relax import RNet
import json
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

In [2]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def compute_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = compare_psnr(Iclean[0,:,:,:], Img[0,:,:,:], data_range=data_range)
    return PSNR
    

In [3]:
test_dataset = BSD500('data/test.h5')

In [4]:
sigma = 25  


val_data = pd.read_csv(f"{'denoise_val_results'}/validation_scores_{'sigma_25_relaxed_valid'}.csv").reset_index(drop=True)
p1 = val_data.loc[val_data["psnr"].idxmax()]["p1"]


exp = "/home/pourya/mm_final/exps/sigma25/64_2_7_1e-3_relaxed"
path_ckp = exp + "/checkpoints/checkpoint_best_epoch.pth"
path_config = exp + "/config.json"
device = 'cuda:0'
config = json.load(open(path_config))
ckp = torch.load(path_ckp, map_location={'cuda:0':device,'cuda:1':device,'cuda:2':device,'cuda:3':device})

model = RNet(config['model_params'])
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()


model.eigenimage = model.eigenimage.to(device)

model.W1.W1.weight.data = model.zeromean(model.W1.W1.weight)
for i, _ in enumerate(model.W1.W1s): 
    model.W1.W1s[i].weight.data = model.zeromean(model.W1.W1s[i].weight)

In [5]:
i = 0
psnrs_all = list()

with torch.no_grad():
    for runs in range(5):
        psnrs = list()
        for img in test_dataset:
            i = i + 1
            if True:
                gt = img.to(device)[None, :, :, :]
                noisy_image = (img.to(device) + (25.0/255.0)*torch.randn(img.shape, device=device))[None, :, :, :]
                denoised = model.solve_majorize_minimize(noisy_image)
                psnr = np.round(compute_PSNR(denoised, gt, 1), 2)
                psnrs.append(psnr)
                print(i, psnr)
        psnrs_all.append(psnrs)

1 29.98
2 28.63
3 30.75
4 29.87
5 30.02
6 26.51
7 28.98
8 28.64
9 29.83
10 29.02
11 28.25
12 26.18
13 29.52
14 26.6
15 28.13
16 30.86
17 28.4
18 29.33
19 31.47
20 29.63
21 33.17
22 27.39
23 26.67
24 24.41
25 27.81
26 28.84
27 27.98
28 33.07
29 26.1
30 30.37
31 27.91
32 32.11
33 30.05
34 23.97
35 28.66
36 28.28
37 29.17
38 27.55
39 33.66
40 26.31
41 30.93
42 33.01
43 31.06
44 30.36
45 27.13
46 29.72
47 28.96
48 26.9
49 25.83
50 38.47
51 27.52
52 33.49
53 33.67
54 30.16
55 29.11
56 27.24
57 26.28
58 27.55
59 24.51
60 26.55
61 28.57
62 28.03
63 31.05
64 24.67
65 28.92
66 33.22
67 33.0
68 29.47
69 30.04
70 28.6
71 30.73
72 29.9
73 29.96
74 26.5
75 28.96
76 28.67
77 29.82
78 29.02
79 28.26
80 26.23
81 29.51
82 26.6
83 28.22
84 30.9
85 28.4
86 29.36
87 31.46
88 29.63
89 33.08
90 27.4
91 26.69
92 24.37
93 27.88
94 28.9
95 27.98
96 33.07
97 26.06
98 30.39
99 27.88
100 32.06
101 30.07
102 23.98
103 28.6
104 28.28
105 29.1
106 27.5
107 33.67
108 26.32
109 30.93
110 32.92
111 30.94
112 30.45
113 

In [6]:
psnr_mat = np.zeros((5, 68))

for i in range(5):
    psnr_mat[i, :] = np.array(psnrs_all[i])

std_vec = np.std(psnr_mat, axis=0)
avg_vec = np.mean(psnr_mat, axis=0)

print('std mean:', np.mean(std_vec))
print('mean mean:', np.mean(avg_vec))

std mean: 0.02736070853474293
mean mean: 29.105147058823526
