# Hierarchical DivNoising - Prediction
This notebook shows how to use a previously trained Hierarchical DivNoising model to denoise images.
If you haven't done so yet, please first run the '1-train_noise_model.ipynb' and '2-train_denoisers.ipynb' notebooks.

In [1]:
import os

import torch
import tifffile
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from backbone import Backbone
from hdn.lib.utils import PSNR

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Load noisy test data
The ground-truth (GT) test data (```signal```) is created by averaging the noisy images (```observation```) pixel-wise over all frames and repeating the result once per noisy image.
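Averaging works as a pseudo ground truth because independent, zero-mean noise partly cancels across frames. The following is a minimal synthetic sketch, assuming Gaussian noise on a fixed toy image (`clean`, `frames`, `pseudo_gt` are hypothetical names; real microscopy noise is not purely Gaussian). It compares the error of one noisy frame with the error of the frame average.

```python
import numpy as np

# Toy data (assumption): one fixed "clean" 64x64 image observed n_frames times
# with independent, zero-mean Gaussian noise. Real microscopy noise differs.
rng = np.random.default_rng(0)
clean = rng.uniform(0.0, 100.0, size=(64, 64))
n_frames, noise_std = 100, 10.0
frames = clean + rng.normal(0.0, noise_std, size=(n_frames, 64, 64))

# Pseudo ground truth: pixel-wise mean over the frame axis, as done for `signal`.
pseudo_gt = frames.mean(axis=0)

# Error of a single noisy frame vs. error of the average, both against `clean`.
single_rmse = np.sqrt(np.mean((frames[0] - clean) ** 2))
mean_rmse = np.sqrt(np.mean((pseudo_gt - clean) ** 2))
print(f"RMSE of one frame: {single_rmse:.2f}, RMSE of the average: {mean_rmse:.2f}")
```

With independent noise the error of the average shrinks roughly as `noise_std / sqrt(n_frames)`, here from about 10 to about 1, which is why the averaged stack is usable as GT for evaluation.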

In [None]:
path = "./data/Convallaria_diaphragm/"

# The test data is just one quarter of the full image ([:, :512, :512]), following earlier works that used this dataset
observation = tifffile.imread(
    path + "20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif"
).astype(np.float32)[:, np.newaxis, :512, :512]  # shape: (n_images, 1, 512, 512)
# Pseudo ground truth: pixel-wise mean over all noisy frames, repeated once per frame
signal = np.mean(observation, axis=0, keepdims=True).repeat(
    observation.shape[0], axis=0
)
img_height, img_width = signal.shape[2], signal.shape[3]

plt.figure(figsize=(15, 5))
plt.imshow(signal[0, 0], cmap="magma")

# Load our model

In [None]:
model_name = "convallaria"
checkpoint_path = os.path.join("checkpoints", model_name)

backbone = Backbone.load_from_checkpoint(os.path.join(checkpoint_path, "final_model.ckpt"))

# Carry out inference

In this cell we use the traditional approach: for each observation in our inference set, we draw samples from $q(\text{signal}|\text{observation})$ and average them to estimate the minimum mean squared error (MMSE) estimate $\mathbb{E}_{q(\text{signal}|\text{observation})}[\text{signal}]$.
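The sampling approach is a Monte Carlo average, and its accuracy depends on how many samples are drawn. The following minimal sketch replaces the trained model with a hypothetical toy sampler (`toy_posterior_sampler`, a Gaussian with a known mean) so the estimate can be checked; `mmse_estimate` is also a hypothetical helper, not `Backbone.predict_vae`. It mirrors the notebook's `n_samples` / `batch_size` split by drawing samples a batch at a time and keeping a running sum.

```python
import numpy as np

rng = np.random.default_rng(0)


def toy_posterior_sampler(observation, n):
    """Hypothetical stand-in for sampling q(signal | observation).

    Draws n samples from a Gaussian centred on 0.8 * observation, so the true
    posterior mean is known and the Monte Carlo estimate can be checked.
    """
    return 0.8 * observation + rng.normal(0.0, 5.0, size=(n, *observation.shape))


def mmse_estimate(observation, n_samples=100, batch_size=10):
    """Average n_samples posterior samples, drawn at most batch_size at a time."""
    running_sum = np.zeros_like(observation, dtype=np.float64)
    drawn = 0
    while drawn < n_samples:
        n = min(batch_size, n_samples - drawn)
        running_sum += toy_posterior_sampler(observation, n).sum(axis=0)
        drawn += n
    return running_sum / n_samples


obs = rng.uniform(0.0, 100.0, size=(32, 32))
true_mean = 0.8 * obs
for n in (1, 10, 100, 1000):
    err = np.sqrt(np.mean((mmse_estimate(obs, n_samples=n) - true_mean) ** 2))
    print(f"n_samples={n:5d}  RMSE to true posterior mean: {err:.3f}")
```

The error falls roughly as one over the square root of the number of samples, so better estimates cost proportionally more model evaluations; `batch_size` only trades memory for speed and does not change the estimate.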

In [None]:
n_samples = 100  # Number of samples to average
batch_size = 1  # Number of samples to predict at a time

backbone.vae.to(device)

mmses = []
for i in tqdm(range(observation.shape[0])):
    img = torch.from_numpy(observation[i : i + 1]).to(device)

    samples = backbone.predict_vae(img, n_samples=n_samples, batch_size=batch_size)
    samples = samples.cpu().numpy()
    mmse = samples.mean(0, keepdims=True)
    mmses.append(mmse)
mmses = np.concatenate(mmses, axis=0)

In this cell, the Direct Denoiser estimates $\mathbb{E}_{q(\text{signal}|\text{observation})}[\text{signal}]$ for each observation in our inference set in a single forward pass, without drawing and averaging samples.
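One way a single network can output a posterior mean: if a predictor is trained to minimise a squared-error loss against samples from $q(\text{signal}|\text{observation})$ (an assumption about the training objective that this notebook does not state), the loss-minimising prediction is the mean of those samples. The following toy sketch shows this for one pixel by fitting a single scalar `c` with gradient descent on the mean squared error; all names and values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical samples of one pixel's signal, as a posterior sampler might produce.
samples = rng.normal(42.0, 3.0, size=500)

# Fit one scalar prediction c by gradient descent on mean((samples - c)**2).
c, lr = 0.0, 0.1
for _ in range(200):
    grad = -2.0 * np.mean(samples - c)  # derivative of the MSE with respect to c
    c -= lr * grad

print(f"MSE-optimal prediction: {c:.4f}, sample mean: {samples.mean():.4f}")
```

The fitted value matches the sample mean, which is the quantity the sampling approach approximates by averaging.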

In [None]:
backbone.direct_denoiser.to(device)

direct_estimates = []
for i in tqdm(range(observation.shape[0])):
    img = torch.from_numpy(observation[i : i + 1]).to(device)

    direct_estimate = backbone.predict_direct_denoiser(img)
    direct_estimate = direct_estimate.cpu().numpy()
    direct_estimates.append(direct_estimate)
direct_estimates = np.concatenate(direct_estimates, axis=0)

In [None]:
direct_estimates[0].shape

# Compute PSNR
The higher the PSNR, the better the denoising performance.
PSNR is computed using the formula:

```PSNR = 20 * log(rangePSNR) - 10 * log(mse)``` <br>
where ```mse = mean((gt - img)**2)```, ```gt``` is the ground-truth image and ```img``` is the prediction. All logarithms are base 10.<br>
```rangePSNR = max(gt) - min(gt)```, as used in this [paper](https://ieeexplore.ieee.org/abstract/document/9098612/).
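The formula above can be written directly in NumPy. This is a minimal sketch for checking the arithmetic; `psnr_numpy` is a hypothetical helper and is not claimed to match the implementation of `hdn.lib.utils.PSNR` used below. By default it takes the range from `gt`, and it accepts an explicit range the way the cells below pass `range_psnr`.

```python
import numpy as np


def psnr_numpy(gt, img, range_psnr=None):
    """PSNR = 20*log10(range) - 10*log10(mse), following the formula above.

    Hypothetical helper for illustration; the range defaults to max(gt) - min(gt).
    """
    gt = np.asarray(gt, dtype=np.float64)
    img = np.asarray(img, dtype=np.float64)
    if range_psnr is None:
        range_psnr = gt.max() - gt.min()
    mse = np.mean((gt - img) ** 2)
    return 20 * np.log10(range_psnr) - 10 * np.log10(mse)


gt = np.array([0.0, 1.0])
pred = np.array([0.1, 0.9])
print(psnr_numpy(gt, pred))  # range 1 and mse 0.01 give approximately 20 dB
```

Because PSNR is logarithmic, doubling the per-pixel error (multiplying the MSE by 4) lowers PSNR by about 6 dB.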

In [None]:
# PSNR of results from old approach
range_psnr = np.max(signal[0]) - np.min(signal[0])
old_psnrs = []
for i in range(len(mmses)):
    psnr = PSNR(signal[i], mmses[i], range_psnr)
    old_psnrs.append(psnr)
    print("image:", i, "PSNR:", psnr, "Mean PSNR:", np.mean(old_psnrs))

In [None]:
# PSNR of results from new approach
range_psnr = np.max(signal[0]) - np.min(signal[0])
new_psnrs = []
for i in range(len(direct_estimates)):
    psnr = PSNR(signal[i], direct_estimates[i], range_psnr)
    new_psnrs.append(psnr)
    print("image:", i, "PSNR:", psnr, "Mean PSNR:", np.mean(new_psnrs))

### Visualize results

In [None]:
idx = 0
img_patch = (0, slice(200, 300), slice(200, 300))

fig, ax = plt.subplots(2, 2)

ax[0, 0].imshow(observation[idx][img_patch], cmap="magma")
ax[0, 0].set_title("Observation")
ax[0, 0].axis("off")

ax[0, 1].imshow(signal[idx][img_patch], cmap="magma")
ax[0, 1].set_title("Ground truth")
ax[0, 1].axis("off")

ax[1, 0].imshow(mmses[idx][img_patch], cmap="magma")
ax[1, 0].set_title("Denoised (old approach)")
ax[1, 0].axis("off")

ax[1, 1].imshow(direct_estimates[idx][img_patch], cmap="magma")
ax[1, 1].set_title("Denoised (new approach)")
ax[1, 1].axis("off")

plt.tight_layout()