In [1]:
from dataloader.BSD500 import BSD500
import torch
from R_network_relax import RNet
import json
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

In [2]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def compute_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = compare_psnr(Iclean[0,:,:,:], Img[0,:,:,:], data_range=data_range)
    return PSNR
    

In [3]:
test_dataset = BSD500('data/test.h5')

In [4]:
sigma = 5  


val_data = pd.read_csv(f"{'denoise_val_results'}/validation_scores_{'sigma_5_relaxed_valid'}.csv").reset_index(drop=True)
p1 = val_data.loc[val_data["psnr"].idxmax()]["p1"]


exp = "/home/pourya/mm_final/exps/sigma5/64_2_7_1e-3_relaxed"
path_ckp = exp + "/checkpoints/checkpoint_best_epoch.pth"
path_config = exp + "/config.json"
device = 'cuda:3'
config = json.load(open(path_config))
ckp = torch.load(path_ckp, map_location={'cuda:0':device,'cuda:1':device,'cuda:2':device,'cuda:3':device})

model = RNet(config['model_params'])
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()


model.eigenimage = model.eigenimage.to(device)

model.W1.W1.weight.data = model.zeromean(model.W1.W1.weight)
for i, _ in enumerate(model.W1.W1s): 
    model.W1.W1s[i].weight.data = model.zeromean(model.W1.W1s[i].weight)

In [5]:
i = 0
psnrs_all = list()

with torch.no_grad():
    for runs in range(5):
        psnrs = list()
        for img in test_dataset:
            i = i + 1
            if True:
                gt = img.to(device)[None, :, :, :]
                noisy_image = (img.to(device) + (5.0/255.0)*torch.randn(img.shape, device=device))[None, :, :, :]
                denoised = model.solve_majorize_minimize(noisy_image)
                psnr = np.round(compute_PSNR(denoised, gt, 1), 2)
                psnrs.append(psnr)
                print(i, psnr)
        psnrs_all.append(psnrs)

1 38.42
2 37.14
3 39.55
4 38.27
5 38.52
6 36.45
7 38.21
8 37.87
9 39.26
10 37.99
11 37.33
12 36.08
13 38.22
14 36.29
15 37.52
16 39.1
17 37.04
18 37.21
19 38.66
20 38.11
21 41.0
22 35.42
23 35.55
24 35.51
25 37.75
26 38.05
27 37.33
28 40.77
29 36.72
30 37.96
31 37.14
32 39.18
33 38.27
34 34.82
35 37.64
36 37.59
37 37.7
38 37.03
39 40.61
40 35.85
41 37.81
42 41.11
43 38.52
44 39.25
45 36.3
46 38.33
47 37.72
48 36.18
49 35.27
50 45.34
51 37.19
52 39.8
53 41.62
54 39.04
55 37.84
56 36.62
57 36.41
58 36.67
59 34.96
60 36.17
61 37.21
62 37.01
63 38.28
64 35.04
65 38.45
66 41.52
67 40.3
68 39.06
69 38.39
70 37.15
71 39.57
72 38.29
73 38.52
74 36.45
75 38.19
76 37.86
77 39.21
78 37.9
79 37.36
80 36.09
81 38.24
82 36.32
83 37.5
84 39.06
85 37.03
86 37.2
87 38.65
88 38.12
89 41.04
90 35.39
91 35.58
92 35.52
93 37.73
94 38.08
95 37.34
96 40.7
97 36.7
98 37.95
99 37.2
100 39.14
101 38.25
102 34.8
103 37.62
104 37.57
105 37.73
106 37.04
107 40.73
108 35.86
109 37.87
110 41.12
111 38.52
112 39.25
1

In [6]:
psnr_mat = np.zeros((5, 68))

for i in range(5):
    psnr_mat[i, :] = np.array(psnrs_all[i])

std_vec = np.std(psnr_mat, axis=0)
avg_vec = np.mean(psnr_mat, axis=0)

print('std mean:', np.mean(std_vec))
print('mean mean:', np.mean(avg_vec))

std mean: 0.017912930522000697
mean mean: 37.91538235294118
