In [1]:
from dataloader.BSD500 import BSD500
import torch
from R_network_relax import RNet
import json
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

In [2]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def compute_PSNR(img, imclean, data_range):
    """Compute the PSNR between a predicted image and its clean reference.

    Both tensors are detached, moved to the CPU and cast to float32 before being
    passed to ``skimage.metrics.peak_signal_noise_ratio``.

    Args:
        img: predicted (e.g. denoised) image tensor. For a batched (>= 4-D) tensor
            only the first element ``img[0]`` is scored; a lower-rank tensor is
            scored as a single unbatched image.
        imclean: reference (clean) image tensor with the same layout as ``img``.
        data_range: dynamic range of the pixel values (e.g. 1 for images in [0, 1]).

    Returns:
        The PSNR value reported by scikit-image (in dB).
    """
    # .detach() is the supported way to get a graph-free tensor; .data is legacy.
    pred = img.detach().cpu().numpy().astype(np.float32)
    clean = imclean.detach().cpu().numpy().astype(np.float32)
    if pred.ndim >= 4:
        # Batched input: only the first element of the batch is compared
        # (the evaluation loop passes batches of one).
        pred, clean = pred[0], clean[0]
    return compare_psnr(clean, pred, data_range=data_range)
    

In [3]:
# Test images, wrapped by the project's BSD500 dataset class.
# The path is relative to the notebook's working directory.
test_h5_path = 'data/test.h5'
test_dataset = BSD500(test_h5_path)

In [4]:
# Noise level on the 0-255 intensity scale; drives file selection below and the
# noise added in the evaluation loop.
sigma = 15

# Validation scores for this noise level: take `p1` from the row with the best "psnr".
val_scores_path = f"denoise_val_results/validation_scores_sigma_{sigma}_relaxed_valid.csv"
val_data = pd.read_csv(val_scores_path).reset_index(drop=True)
p1 = val_data.loc[val_data["psnr"].idxmax(), "p1"]
# NOTE(review): p1 is selected here but never passed to the model in this notebook;
# confirm whether RNet / solve_majorize_minimize is expected to use it.

# Experiment directory (machine-specific absolute path; adjust when running elsewhere).
exp = f"/home/pourya/mm_final/exps/sigma{sigma}/64_2_7_1e-3_relaxed"
path_ckp = exp + "/checkpoints/checkpoint_best_epoch.pth"
path_config = exp + "/config.json"
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

with open(path_config) as config_file:
    config = json.load(config_file)
# Remap every stored tensor onto `device`, whichever GPU it was saved from.
# torch.load unpickles the file: only load trusted checkpoints.
ckp = torch.load(path_ckp, map_location=device)

model = RNet(config['model_params'])
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()

# Moved explicitly, presumably because `eigenimage` is a plain tensor attribute that
# `model.to(device)` does not move (TODO confirm in RNet).
model.eigenimage = model.eigenimage.to(device)

# Overwrite the W1 filter weights with model.zeromean(...) of themselves
# (presumably a zero-mean projection - see RNet.zeromean).
model.W1.W1.weight.data = model.zeromean(model.W1.W1.weight)
for conv in model.W1.W1s:
    conv.weight.data = model.zeromean(conv.weight)

In [5]:
# Denoise every test image corrupted with Gaussian noise of std sigma/255 and
# record its PSNR against the clean image.
noise_seed = 0  # fixed seed so the synthetic noise, and hence the scores, are reproducible
noise_gen = torch.Generator(device=device).manual_seed(noise_seed)
psnrs = []

with torch.no_grad():
    for i, img in enumerate(test_dataset, start=1):
        gt = img.to(device)[None, :, :, :]  # add a leading batch dimension
        noise = (sigma / 255.0) * torch.randn(gt.shape, generator=noise_gen, device=device)
        noisy_image = gt + noise
        denoised = model.solve_majorize_minimize(noisy_image)
        # Keep full precision for the summary statistics; round only for display.
        psnr = compute_PSNR(denoised, gt, 1)
        psnrs.append(psnr)
        print(i, np.round(psnr, 2))

1 32.43
2 30.95
3 33.26
4 32.39
5 32.49
6 29.42
7 31.8
8 31.34
9 32.78
10 31.5
11 30.85
12 28.96
13 32.03
14 29.36
15 30.74
16 33.29
17 30.9
18 31.58
19 33.4
20 32.07
21 35.34
22 29.15
23 29.01
24 27.5
25 30.62
26 31.52
27 30.8
28 35.2
29 29.03
30 32.28
31 30.45
32 33.98
33 32.39
34 26.94
35 31.23
36 30.76
37 31.67
38 30.17
39 35.74
40 29.03
41 32.62
42 35.38
43 32.99
44 33.07
45 29.77
46 32.15
47 31.4
48 29.4
49 28.29
50 40.71
51 30.33
52 35.32
53 36.21
54 32.66
55 31.66
56 29.89
57 29.29
58 30.12
59 27.44
60 29.27
61 30.76
62 30.45
63 33.06
64 27.53
65 31.77
66 35.96
67 34.96
68 32.19


In [6]:
# Test-set summary: mean and population (ddof=0) standard deviation of the
# per-image PSNR values, each rounded to two decimals.
psnr_values = np.asarray(psnrs)
print(np.round(psnr_values.mean(), 2), np.round(psnr_values.std(), 2))

31.6 2.42
