In [1]:
from dataloader.BSD500 import BSD500
import torch
from R_network_relax import RNet
import json
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

In [2]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def compute_PSNR(img, imclean, data_range):
    """Return the PSNR of the first image of a batch against its clean reference.

    Only batch index 0 is evaluated; any further entries along the first
    dimension are ignored.

    Args:
        img: Tensor holding the image under test. It is indexed as
            ``[0, :, :, :]``, so it must have four dimensions.
        imclean: Reference tensor, indexed the same way as ``img``.
        data_range: Value range of the images, forwarded unchanged to
            ``compare_psnr``.

    Returns:
        Whatever ``compare_psnr`` returns for the first batch element.
    """
    # detach() is the supported replacement for the legacy `.data` accessor; it yields the
    # same values without autograd tracking so the tensor can be converted to NumPy.
    test_np = img.detach().cpu().numpy().astype(np.float32)
    clean_np = imclean.detach().cpu().numpy().astype(np.float32)
    # The clean reference is passed first and the image under test second.
    return compare_psnr(clean_np[0, :, :, :], test_np[0, :, :, :], data_range=data_range)
    

In [3]:
# Test split used for evaluation, built by the project's BSD500 dataset class from
# 'data/test.h5' (presumably an HDF5 file, going by the extension — confirm).
# The evaluation loop iterates over this object one image at a time.
test_dataset = BSD500('data/test.h5')

In [4]:
# Noise level of this experiment; presumably on a 0-255 intensity scale, since the
# evaluation loop divides the noise level by 255 — confirm.
sigma = 15

# Select the p1 value of the validation row with the highest PSNR.
# NOTE(review): p1 is not used anywhere else in the visible cells — confirm whether it was
# meant to be applied to the model before evaluation.
val_data = pd.read_csv("denoise_val_results/validation_scores_sigma_15_relaxed_valid.csv").reset_index(drop=True)
p1 = val_data.loc[val_data["psnr"].idxmax()]["p1"]

# Trained experiment to evaluate. NOTE(review): absolute, machine-specific path.
exp = "/home/pourya/mm_final/exps/sigma15/64_2_7_1e-3_relaxed"
path_ckp = exp + "/checkpoints/checkpoint_best_epoch.pth"
path_config = exp + "/config.json"
device = 'cuda:3'

# Read the config inside a context manager so the file handle is closed immediately
# instead of being left for the garbage collector.
with open(path_config) as config_file:
    config = json.load(config_file)

# Tensors saved on any of cuda:0..cuda:3 are remapped onto the evaluation device.
ckp = torch.load(path_ckp, map_location={f'cuda:{gpu_idx}': device for gpu_idx in range(4)})

# Build the network from the stored hyper-parameters and restore the trained weights.
model = RNet(config['model_params'])
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()

# eigenimage is moved by hand after model.to(device); presumably it is a plain attribute
# that .to() does not move — confirm.
model.eigenimage = model.eigenimage.to(device)

# Overwrite the stored W1 weights with the output of model.zeromean, for the main layer
# and for every layer in W1s.
model.W1.W1.weight.data = model.zeromean(model.W1.W1.weight)
for w1_layer in model.W1.W1s:
    w1_layer.weight.data = model.zeromean(w1_layer.weight)

In [5]:
# Seed torch's random number generators so that the noise realisations, and therefore the
# reported PSNR values, are the same after every kernel restart.
torch.manual_seed(0)

n_runs = 5                  # number of independent noise realisations per image
noise_std = sigma / 255.0   # noise level from the setup cell, divided by 255

i = 0                       # running counter over all (run, image) pairs, used only for printing
psnrs_all = list()          # one list of per-image PSNR values per run

with torch.no_grad():
    for run in range(n_runs):
        psnrs = list()
        for img in test_dataset:
            i = i + 1
            # Move the image to the device once and add a leading batch dimension.
            gt = img.to(device)[None, :, :, :]
            noisy_image = gt + noise_std * torch.randn(gt.shape, device=device)
            denoised = model.solve_majorize_minimize(noisy_image)
            # PSNR is rounded to two decimals before being stored, so the statistics
            # computed later are based on the rounded values.
            psnr = np.round(compute_PSNR(denoised, gt, 1), 2)
            psnrs.append(psnr)
            print(i, psnr)
        psnrs_all.append(psnrs)

1 32.41
2 30.99
3 33.26
4 32.41
5 32.5
6 29.37
7 31.73
8 31.35
9 32.78
10 31.55
11 30.84
12 28.99
13 32.08
14 29.39
15 30.7
16 33.24
17 30.94
18 31.55
19 33.37
20 32.06
21 35.42
22 29.15
23 28.95
24 27.54
25 30.65
26 31.49
27 30.83
28 35.2
29 29.06
30 32.28
31 30.42
32 33.97
33 32.39
34 26.91
35 31.23
36 30.72
37 31.67
38 30.14
39 35.69
40 29.01
41 32.64
42 35.39
43 32.99
44 33.06
45 29.76
46 32.23
47 31.34
48 29.39
49 28.36
50 40.86
51 30.35
52 35.37
53 36.16
54 32.65
55 31.65
56 29.89
57 29.35
58 30.17
59 27.41
60 29.3
61 30.7
62 30.45
63 33.09
64 27.56
65 31.77
66 36.03
67 34.98
68 32.16
69 32.45
70 30.98
71 33.33
72 32.38
73 32.5
74 29.4
75 31.82
76 31.4
77 32.8
78 31.53
79 30.84
80 28.99
81 32.03
82 29.36
83 30.73
84 33.22
85 30.9
86 31.58
87 33.41
88 32.04
89 35.38
90 29.15
91 29.01
92 27.52
93 30.64
94 31.53
95 30.78
96 35.16
97 29.05
98 32.28
99 30.42
100 33.95
101 32.44
102 26.9
103 31.27
104 30.73
105 31.7
106 30.16
107 35.81
108 29.05
109 32.61
110 35.31
111 33.01
112 33.09


In [6]:
# Stack the per-run PSNR lists into a (runs, images) matrix. The shape comes from the data
# itself, so this cell keeps working if the number of runs or test images changes.
psnr_mat = np.array(psnrs_all, dtype=float)
assert psnr_mat.ndim == 2, "expected one equally long PSNR list per run"

# Per-image statistics across runs (axis 0 is the run axis).
std_vec = np.std(psnr_mat, axis=0)
avg_vec = np.mean(psnr_mat, axis=0)

# Average the per-image standard deviation and the per-image mean over all images.
print('std mean:', np.mean(std_vec))
print('mean mean:', np.mean(avg_vec))

std mean: 0.023407753118581395
mean mean: 31.608499999999996
