In [1]:
from validate_coarse_to_fine import * 
from dataloader.BSD500 import BSD500
import torch
from R_network_mm import RNet
import json
from matplotlib import pyplot as plt
from torchmetrics.functional import peak_signal_noise_ratio as psnr
from torchmetrics.functional import structural_similarity_index_measure as ssim

In [2]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def compute_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = compare_psnr(Iclean[0,:,:,:], Img[0,:,:,:], data_range=data_range)
    return PSNR
    

In [3]:
test_dataset = BSD500('data/test.h5')

In [4]:
sigma = 15.0  

val_data = pd.read_csv(f"{'denoise_val_results'}/validation_scores_{'sigma_15_mm_valid'}.csv").reset_index(drop=True)
p1 = val_data.loc[val_data["psnr"].idxmax()]["p1"]

exp = "/home/pourya/mm_final/exps/sigma15/64_2_7_1e-3_mm"
path_ckp = exp + "/checkpoints/checkpoint_best_epoch.pth"
path_config = exp + "/config.json"
device = 'cuda:1'
config = json.load(open(path_config))
ckp = torch.load(path_ckp, map_location={'cuda:0':device,'cuda:1':device,'cuda:2':device,'cuda:3':device})

model = RNet(config['model_params'])
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()

model.W2.W1.weight.data = model.sumtoone(model.W2.W1.weight)
for i, _ in enumerate(model.W2.W1s): 
    model.W2.W1s[i].weight.data = model.sumtoone(model.W2.W1s[i].weight)


model.W1.W1.weight.data = model.zeromean(model.W1.W1.weight)
for i, _ in enumerate(model.W1.W1s): 
    model.W1.W1s[i].weight.data = model.zeromean(model.W1.W1s[i].weight)


model.eigenimage = model.eigenimage.to(device)
print(model.lmbda, p1)
lip, model.eigenimage = model.W1.cal_lip(model.eigenimage, 1000)
model.alpha = 1 / lip
model.lmbda = torch.nn.Parameter(torch.ones(1, 1).to(device) * p1)


Parameter containing:
tensor([[0.0006]], device='cuda:1', requires_grad=True) 0.0005889984895475


In [5]:
def mmr_denoise(y):
    s = 1e-3
    n = 5
    e = 1e-5
    q = (e/s) ** (1/n)
    tols = [s * q**(i) for i in range(n+1)]
    n_out = 10
    for _ in range(n_out-(n+1)):
        tols.append(e)
    with torch.no_grad(): 
        n_in = 100  
        c_k = model.prox_denoise_no_mask(y, y, n_in, tols[0], check_tol=True)
        for it in range(n_out - 1):   
            model.cal_mask(c_k)
            c_k = model.prox_denoise_with_mask(y, c_k, n_in, tols[it+1], check_tol=True)

    return c_k

In [6]:
i = 0
psnrs_all = list()

with torch.no_grad():
    for runs in range(5):
        psnrs = list()
        for img in test_dataset:
            i = i + 1
            if True:
                gt = img.to(device)[None, :, :, :]
                noisy_image = (img.to(device) + (15/255.0)*torch.randn(img.shape, device=device))[None, :, :, :]
                denoised =mmr_denoise(noisy_image)
                psnr = np.round(compute_PSNR(denoised, gt, 1), 2)
                psnrs.append(psnr)
                print(i, psnr)
                
        psnrs_all.append(psnrs)

1 31.95
2 30.51
3 32.84
4 31.83
5 32.0
6 29.12
7 31.02
8 30.9
9 32.13
10 31.15
11 30.2
12 28.69
13 31.67
14 29.07
15 30.08
16 32.66
17 30.52
18 30.84
19 32.73
20 31.25
21 34.71
22 27.45
23 28.71
24 27.24
25 30.36
26 30.96
27 30.11
28 34.76
29 28.74
30 31.62
31 29.98
32 33.11
33 31.55
34 26.72
35 30.82
36 30.4
37 31.14
38 29.91
39 34.63
40 28.69
41 32.12
42 34.09
43 32.09
44 32.37
45 29.29
46 31.8
47 31.0
48 29.09
49 28.04
50 39.64
51 29.85
52 34.9
53 35.12
54 32.25
55 31.15
56 29.54
57 28.78
58 29.65
59 27.21
60 29.02
61 30.19
62 30.12
63 32.06
64 27.43
65 31.26
66 34.87
67 34.25
68 31.84
69 31.95
70 30.54
71 32.74
72 31.77
73 32.04
74 29.13
75 30.99
76 30.85
77 32.12
78 31.13
79 30.14
80 28.75
81 31.63
82 29.11
83 30.09
84 32.65
85 30.54
86 30.88
87 32.68
88 31.18
89 34.76
90 27.44
91 28.74
92 27.28
93 30.42
94 30.84
95 30.06
96 34.77
97 28.73
98 31.6
99 29.93
100 33.08
101 31.55
102 26.72
103 30.83
104 30.46
105 31.15
106 29.91
107 34.67
108 28.66
109 32.15
110 33.91
111 32.12
112 32

In [8]:
psnr_mat = np.zeros((5, 68))

for i in range(5):
    psnr_mat[i, :] = np.array(psnrs_all[i])

std_vec = np.std(psnr_mat, axis=0)
avg_vec = np.mean(psnr_mat, axis=0)

print('std mean:', np.mean(std_vec))
print('mean mean:', np.mean(avg_vec))

std mean: 0.025895517156123194
mean mean: 31.055058823529407
