# Hierarchical DivNoising - Prediction
This notebook shows how to use a previously trained Hierarchical DivNoising (HDN) model to denoise images corrupted by pixel-wise noise and horizontally correlated structured noise.
If you haven't done so already, please first run the '1-CreateNoiseModel.ipynb' and '2-Training.ipynb' notebooks.

In [None]:
# We import all our dependencies.
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.optimizer import Optimizer

import sys
sys.path.append('../../../')

from models.lvae import LadderVAE
from boilerplate import boilerplate
import lib.utils as utils
import training

import os
import glob
import zipfile
import urllib
from tifffile import imread, imsave
from matplotlib import pyplot as plt
from tqdm import tqdm

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Load noisy test data
The GT test data (```signal```) is created by averaging the noisy images (```observation```).

In [None]:
path = "./data/Struct_Convallaria/"
# The test data is just one quarter of the full image ([:, :512, :512]), following
# earlier works that used this dataset.
observation = imread(os.path.join(path, "flower.tif"))[:, :512, :512].astype("float32")
# Pseudo ground truth: pixel-wise average over the stack of noisy frames (axis 0),
# kept as a single-image stack of shape (1, H, W).
signal = np.mean(observation, axis=0, keepdims=True)
# For a (1, H, W) array, axis 1 is the height (rows) and axis 2 is the width (columns).
img_height, img_width = signal.shape[1], signal.shape[2]

plt.figure(figsize=(15, 5))
plt.imshow(signal[0], cmap='magma')
plt.title("GT (average of noisy frames)")
plt.show()

# Load our model

In [None]:
%%capture
model = torch.load(".Trained_model/model/convallaria_last_vae.net")
model.mode_pred=True
model.eval()

# Compute PSNR
The higher the PSNR, the better the denoising performance is.
PSNR is computed using the formula: 

```PSNR = 20 * log(rangePSNR) - 10 * log(mse)``` <br> 
where ```mse = mean((gt - img)**2)```, ```gt``` is ground truth image and ```img``` is the prediction from HDN. All logarithms are with base 10.<br>
where ```rangePSNR = max(gt) - min(gt)```, as used in this [paper](https://ieeexplore.ieee.org/abstract/document/9098612/).

In [None]:
# Prediction settings (all forwarded to boilerplate.predict).
gaussian_noise_std = None  # left as None for this dataset
num_samples = 100  # number of samples averaged into the MMSE estimate
tta = False  # test-time augmentation: may improve results but takes ~8x longer

# The PSNR range comes from the GT image, as described above.
range_psnr = np.max(signal[0]) - np.min(signal[0])
psnrs = []
# Denoise each noisy frame and report its PSNR plus the running mean.
for idx, noisy_frame in enumerate(observation):
    img_mmse, samples = boilerplate.predict(noisy_frame, num_samples, model, gaussian_noise_std, device, tta)
    psnr = utils.PSNR(signal[0], img_mmse, range_psnr)
    psnrs.append(psnr)
    print("image:", idx, "PSNR:", psnr, "Mean PSNR:", np.mean(psnrs))

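# Qualitative results
Below we show a crop of the last noisy frame, the GT, the MMSE estimate, and individual samples drawn by HDN for that frame.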

In [None]:
# Region of interest shown in every panel (rows 250-399, columns 175-324).
crop = (slice(250, 400), slice(175, 325))
gt = signal[0][crop]
# Shared display range so GT, MMSE and samples are directly comparable;
# clipping at the 99th percentile keeps a few bright pixels from washing out the rest.
vmin = np.percentile(gt, 0)
vmax = np.percentile(gt, 99)

columns = 5
rows = 1
# Panels remaining after Raw, GT and MMSE are filled with individual samples.
num_sample_panels = columns * rows - 3

# NOTE: `img_mmse` and `samples` are left over from the last iteration of the PSNR loop,
# i.e. they belong to `observation[-1]` - which is why the raw panel shows that frame.
fig = plt.figure(figsize=(20, 10))
fig.add_subplot(rows, columns, 1)
plt.imshow(observation[-1][crop], cmap='magma')
plt.title("Raw")
fig.add_subplot(rows, columns, 2)
plt.imshow(gt, vmin=vmin, vmax=vmax, cmap='magma')
plt.title("GT")
fig.add_subplot(rows, columns, 3)
plt.imshow(img_mmse[crop], vmin=vmin, vmax=vmax, cmap='magma')
plt.title("MMSE")
for k in range(num_sample_panels):
    # Index the samples from 0 so that "Sample k" really displays samples[k].
    fig.add_subplot(rows, columns, 4 + k)
    plt.imshow(samples[k][crop], vmin=vmin, vmax=vmax, cmap='magma')
    plt.title("Sample " + str(k))
plt.show()