In [1]:
from dataloader.BSD500 import BSD500
import torch
from R_network_relax import RNet
import json
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

In [2]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def compute_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = compare_psnr(Iclean[0,:,:,:], Img[0,:,:,:], data_range=data_range)
    return PSNR
    

In [3]:
test_dataset = BSD500('data/test.h5')

In [4]:
sigma = 15  


val_data = pd.read_csv(f"{'denoise_val_results'}/validation_scores_{'sigma_15_relaxed_valid'}.csv").reset_index(drop=True)
p1 = val_data.loc[val_data["psnr"].idxmax()]["p1"]


exp = "/home/pourya/mm_final/exps/sigma15/64_2_7_1e-3_relaxed"
path_ckp = exp + "/checkpoints/checkpoint_best_epoch.pth"
path_config = exp + "/config.json"
device = 'cuda:2'
config = json.load(open(path_config))
ckp = torch.load(path_ckp, map_location={'cuda:0':device,'cuda:1':device,'cuda:2':device,'cuda:3':device})

model = RNet(config['model_params'])
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()


model.eigenimage = model.eigenimage.to(device)
lip, model.eigenimage = model.W1.cal_lip(model.eigenimage, 1000)
model.alpha = 1 / lip
model.lmbda = torch.nn.Parameter(torch.ones(1, 1).to(device) * p1)


model.W1.W1.weight.data = model.zeromean(model.W1.W1.weight)
for i, _ in enumerate(model.W1.W1s): 
    model.W1.W1s[i].weight.data = model.zeromean(model.W1.W1s[i].weight)

In [5]:
def safi_denoise(y):
    s = 1e-3
    n = 5
    e = 1e-5
    q = (e/s) ** (1/n)
    tols = [s * q**(i) for i in range(n+1)]
    n_out = 10
    for _ in range(n_out-(n+1)):
        tols.append(e)
    with torch.no_grad(): 
        n_in = 100  
        c_k = model.prox_denoise_no_mask(y, y, n_in, tols[0], check_tol=True)
        for it in range(n_out - 1):   
            model.cal_mask(c_k)
            c_k_new = model.prox_denoise_with_mask(y, c_k, n_in, tols[it+1], check_tol=True)
            rel_err_out = (torch.norm(c_k - c_k_new) / torch.norm(c_k)).item() 
            sc = "{:0.1e}".format(rel_err_out)
            c_k = c_k_new
    
    return c_k

In [6]:
i = 0
psnrs_all = list()

with torch.no_grad():
    for runs in range(5):
        psnrs = list()
        for img in test_dataset:
            i = i + 1
            if True:
                gt = img.to(device)[None, :, :, :]
                noisy_image = (img.to(device) + (15/255.0)*torch.randn(img.shape, device=device))[None, :, :, :]
                denoised =safi_denoise(noisy_image)
                psnr = np.round(compute_PSNR(denoised, gt, 1), 2)
                psnrs.append(psnr)
                print(i, psnr)
                
        psnrs_all.append(psnrs)

1 32.39
2 30.93
3 33.19
4 32.27
5 32.43
6 29.37
7 31.74
8 31.34
9 32.78
10 31.47
11 30.77
12 28.92
13 31.98
14 29.35
15 30.7
16 33.24
17 30.92
18 31.55
19 33.23
20 31.94
21 35.21
22 29.1
23 29.01
24 27.53
25 30.62
26 31.47
27 30.73
28 35.11
29 29.03
30 32.22
31 30.42
32 33.86
33 32.34
34 26.9
35 31.23
36 30.74
37 31.59
38 30.15
39 35.56
40 29.01
41 32.49
42 35.09
43 32.9
44 33.02
45 29.72
46 32.17
47 31.34
48 29.37
49 28.32
50 40.77
51 30.29
52 35.28
53 36.23
54 32.67
55 31.54
56 29.81
57 29.3
58 30.09
59 27.41
60 29.27
61 30.73
62 30.47
63 32.9
64 27.56
65 31.71
66 35.8
67 34.83
68 32.19
69 32.37
70 30.91
71 33.22
72 32.25
73 32.44
74 29.35
75 31.76
76 31.33
77 32.77
78 31.52
79 30.75
80 28.94
81 31.98
82 29.34
83 30.68
84 33.17
85 30.86
86 31.49
87 33.23
88 31.97
89 35.24
90 29.11
91 28.99
92 27.5
93 30.7
94 31.5
95 30.81
96 35.16
97 29.04
98 32.2
99 30.41
100 33.88
101 32.29
102 26.91
103 31.17
104 30.75
105 31.63
106 30.12
107 35.59
108 29.0
109 32.51
110 35.2
111 32.89
112 32.94
1

In [8]:
psnr_mat = np.zeros((5, 68))

for i in range(5):
    psnr_mat[i, :] = np.array(psnrs_all[i])

std_vec = np.std(psnr_mat, axis=0)
avg_vec = np.mean(psnr_mat, axis=0)

print('std mean:', np.mean(std_vec))
print('mean mean:', np.mean(avg_vec))

std mean: 0.02462980282062336
mean mean: 31.55114705882354
