In [1]:
# Star import kept first so the explicit imports below take precedence over any
# names it re-exports. NOTE(review): replace with explicit names once known.
from validate_coarse_to_fine import *

import json
import os

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from torchmetrics.functional import peak_signal_noise_ratio as psnr
from torchmetrics.functional import structural_similarity_index_measure as ssim

from dataloader.BSD500 import BSD500
from R_network_mm import RNet

In [2]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def compute_PSNR(img, imclean, data_range):
    """Return the PSNR (dB, computed by skimage) between the first images of two batches.

    Both tensors are indexed as ``[0, :, :, :]``, so they must have at least
    four dimensions; only element 0 of the leading (batch) dimension is scored.
    Each tensor is detached, copied to host memory and cast to float32 first.

    Args:
        img: estimated image batch (torch tensor).
        imclean: reference / ground-truth image batch (torch tensor).
        data_range: dynamic range of the pixel values (e.g. 1 for [0, 1] images).

    Returns:
        The PSNR value returned by ``skimage.metrics.peak_signal_noise_ratio``.
    """
    estimate, reference = (
        t.detach().cpu().numpy().astype(np.float32)[0, :, :, :] for t in (img, imclean)
    )
    return compare_psnr(reference, estimate, data_range=data_range)
    

In [3]:
# Test split stored as an HDF5 file; the evaluation loop iterates over it image by image.
TEST_H5_PATH = "data/test.h5"
test_dataset = BSD500(TEST_H5_PATH)

In [4]:
# Noise level on the 0-255 scale: selects the validation run below and sets
# the std of the noise added to the test images (sigma / 255).
sigma = 25.0

# Pick the hyper-parameter p1 that achieved the best validation PSNR for this sigma.
# f"{25.0:g}" == "25", so this resolves to
# denoise_val_results/validation_scores_sigma_25_mm_valid.csv for the default sigma.
val_csv = f"denoise_val_results/validation_scores_sigma_{sigma:g}_mm_valid.csv"
val_data = pd.read_csv(val_csv).reset_index(drop=True)
# reset_index guarantees idxmax() returns a positional label usable with .loc;
# single .loc[row, col] avoids chained indexing through a mixed-dtype row Series.
p1 = val_data.loc[val_data["psnr"].idxmax(), "p1"]

# Experiment directory of the trained model. Set MM_EXP_DIR to point at a copy
# on another machine without editing the notebook; the default is unchanged.
exp = os.environ.get(
    "MM_EXP_DIR", "/home/pourya/mm_final/exps/sigma25/64_2_7_1e-3_majorized"
)
path_ckp = f"{exp}/checkpoints/checkpoint_best_epoch.pth"
path_config = f"{exp}/config.json"
device = 'cuda:3'

# Close the config file deterministically instead of leaking the handle.
with open(path_config) as config_file:
    config = json.load(config_file)

# map_location=device remaps every saved storage onto `device`, whichever GPU it
# was written from (the previous dict only handled cuda:0 .. cuda:3).
ckp = torch.load(path_ckp, map_location=device)

model = RNet(config['model_params'])
model.to(device)
model.load_state_dict(ckp['state_dict'])
model.eval()

def _project_conv_weights(bank, projector):
    """Overwrite ``bank.W1.weight`` and each ``bank.W1s[k].weight`` with ``projector(weight)``.

    Layers are processed in order: ``W1`` first, then ``W1s`` in sequence.
    Runs under ``torch.no_grad()`` so rewriting parameters does not build an
    autograd graph.

    Args:
        bank: object exposing a ``W1`` layer and an iterable ``W1s`` of layers,
            each with a ``.weight`` tensor.
        projector: callable mapping a weight tensor to its replacement.
    """
    with torch.no_grad():
        for layer in [bank.W1, *bank.W1s]:
            layer.weight.data = projector(layer.weight)


# Re-apply the model's weight constraints after loading the checkpoint:
# W2 weights through model.sumtoone, W1 weights through model.zeromean
# (presumably sum-to-one / zero-mean projections, as the names suggest — confirm in RNet).
_project_conv_weights(model.W2, model.sumtoone)
_project_conv_weights(model.W1, model.zeromean)


# cal_lip works on the model's eigenimage, so move it to the evaluation device first.
model.eigenimage = model.eigenimage.to(device)

# Show the checkpoint's lambda next to the validation-selected p1 before overwriting it.
print(model.lmbda, p1)

# cal_lip returns a Lipschitz estimate plus an updated eigenimage, which is kept.
# NOTE(review): 1000 is presumably the number of power iterations — confirm in RNet.
lip, model.eigenimage = model.W1.cal_lip(model.eigenimage, 1000)
model.alpha = 1 / lip  # presumably the step size of the inner solver

# Replace lambda with the best validation value, as a 1x1 parameter on `device`.
model.lmbda = torch.nn.Parameter(torch.ones(1, 1, device=device) * p1)


Parameter containing:
tensor([[0.0011]], device='cuda:3', requires_grad=True) 0.0011094470382648


In [5]:
def safi_denoise(y):
    s = 1e-3
    n = 5
    e = 1e-5
    q = (e/s) ** (1/n)
    tols = [s * q**(i) for i in range(n+1)]
    n_out = 10
    for _ in range(n_out-(n+1)):
        tols.append(e)
    with torch.no_grad(): 
        n_in = 100  
        c_k = model.prox_denoise_no_mask(y, y, n_in, tols[0], check_tol=True)
        for it in range(n_out - 1):   
            model.cal_mask(c_k)
            c_k = model.prox_denoise_with_mask(y, c_k, n_in, tols[it+1], check_tol=True)

    return c_k

In [6]:
# Evaluate on every test image:
#   1. add i.i.d. Gaussian noise with std sigma / 255 (images presumably in
#      [0, 1], consistent with data_range=1 below — confirm in BSD500);
#   2. denoise;
#   3. record the PSNR against the clean image.
# Noise is drawn from a seeded generator so re-running the cell reproduces the
# same noisy inputs and therefore the same PSNRs.
NOISE_SEED = 0
noise_gen = torch.Generator(device=device).manual_seed(NOISE_SEED)

# Unrounded per-image PSNRs; values are rounded only for display.
# (Named psnr_val, not psnr, so the imported torchmetrics `psnr` is not shadowed.)
psnrs = list()

with torch.no_grad():
    for i, img in enumerate(test_dataset, start=1):
        gt = img.to(device).unsqueeze(0)  # add a leading batch dimension
        noise = (sigma / 255.0) * torch.randn(gt.shape, device=device, generator=noise_gen)
        noisy_image = gt + noise
        denoised = safi_denoise(noisy_image)
        psnr_val = compute_PSNR(denoised, gt, 1)
        psnrs.append(psnr_val)
        print(i, np.round(psnr_val, 2))

1 29.65
2 28.2
3 30.3
4 29.38
5 29.54
6 26.29
7 28.29
8 28.23
9 29.19
10 28.57
11 27.69
12 25.84
13 29.26
14 26.39
15 27.43
16 30.26
17 27.99
18 28.78
19 31.0
20 28.94
21 32.58
22 26.71
23 26.41
24 24.12
25 27.53
26 28.32
27 27.27
28 32.54
29 25.67
30 29.76
31 27.41
32 31.53
33 29.47
34 23.68
35 28.16
36 27.86
37 28.6
38 27.26
39 32.81
40 25.89
41 30.43
42 31.9
43 30.47
44 29.85
45 26.59
46 29.35
47 28.64
48 26.52
49 25.6
50 37.56
51 27.15
52 33.05
53 32.58
54 29.73
55 28.6
56 26.99
57 25.79
58 27.01
59 24.33
60 26.35
61 27.98
62 27.65
63 30.49
64 24.56
65 28.39
66 32.05
67 32.45
68 29.17


In [7]:
# Test-set summary: mean and population standard deviation (ddof=0) of the
# per-image PSNRs, each rounded to 2 decimals.
psnr_arr = np.array(psnrs)
print(np.round(psnr_arr.mean(), 2), np.round(psnr_arr.std(), 2))

28.62 2.46
