# 2. Inference with COSDD

In this notebook, we load a trained model and use it to denoise the low signal-to-noise data. We'll then use reference high signal-to-noise data to evaluate its performance.

In [None]:
import os
import logging

import tifffile
import torch
import pytorch_lightning as pl
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as PSNR

import utils
from dvlae import DVLAE

logger = logging.getLogger('pytorch_lightning')
logger.setLevel(logging.WARNING)
%matplotlib inline

In [None]:
use_cuda = torch.cuda.is_available()

### Part 1. Load test data
The images that we want to denoise are loaded here. Since high signal-to-noise ratio reference images are available for this dataset, we'll load those too.<br>
For the Mito-Confocal dataset, we follow the original authors and use the last 8 images as the test set.

In [None]:
low_snr_path = "/group/dl4miacourse/image_regression/mito-confocal/mito-confocal-lowsnr.tif"
high_snr_path = "/group/dl4miacourse/image_regression/mito-confocal/mito-confocal-highsnr.tif"

low_snr = tifffile.imread(low_snr_path)[-8:, None]
low_snr = torch.from_numpy(low_snr.astype(np.float32))
high_snr = tifffile.imread(high_snr_path)[-8:, None]
print(low_snr.shape)

# The high snr reference images for this dataset are on a different scale to the low snr images.
# We will scale and shift the reference images to match the noisy images.
high_snr = utils.minimise_mse(high_snr, low_snr.numpy())
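
The scale-and-shift step above can be pictured as an ordinary least-squares fit. The following is a minimal sketch of that idea, not the actual implementation of `utils.minimise_mse`, which may differ (for example, it could fit each image separately). The function name `match_scale_and_shift` and the synthetic arrays are hypothetical and used only for illustration.

```python
import numpy as np


def match_scale_and_shift(reference, target):
    """Return a * reference + b, with (a, b) chosen to minimise the MSE to target.

    Hypothetical stand-in for illustration; `utils.minimise_mse` may differ.
    """
    x = reference.astype(np.float64).ravel()
    y = target.astype(np.float64).ravel()
    # Closed-form least-squares solution for y ≈ a * x + b.
    a = np.cov(x, y, bias=True)[0, 1] / np.var(x)
    b = y.mean() - a * x.mean()
    return a * reference + b


# Synthetic data: the "noisy" stack lives on a different scale to the "clean" one.
rng = np.random.default_rng(0)
clean = rng.random((2, 1, 8, 8))
noisy = 3.0 * clean + 5.0 + rng.normal(0.0, 0.1, clean.shape)

rescaled = match_scale_and_shift(clean, noisy)
mse_before = np.mean((clean - noisy) ** 2)
mse_after = np.mean((rescaled - noisy) ** 2)
print(mse_before, mse_after)
```

Without this step, a PSNR computed between images on different intensity scales would mostly measure the scale mismatch rather than the denoising quality.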

As with training, data should be a `torch.Tensor` with dimensions: [Number of images, Channels, Height, Width] with data type float32.
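
As a rough illustration of that layout, the sketch below builds a placeholder image stack with NumPy, adds the channel axis, and casts to float32. The array `stack` is made up for the example; `torch.from_numpy` would keep the resulting shape and data type.

```python
import numpy as np

# Placeholder stack with dimensions [Number of images, Height, Width].
stack = np.zeros((8, 64, 64), dtype=np.uint16)

# Indexing with None inserts the Channels axis; astype converts to float32.
data = stack[:, None].astype(np.float32)

print(data.shape, data.dtype)  # (8, 1, 64, 64) float32
```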

### Part 2. Create prediction dataloader

`predict_batch_size`: the number of images to denoise at a time.

In [None]:
predict_batch_size = 1

predict_set = utils.PredictDataset(low_snr)
predict_loader = torch.utils.data.DataLoader(
    predict_set,
    batch_size=predict_batch_size,
    shuffle=False,
    pin_memory=True,
)

### Part 3. Load trained model

<div class="alert alert-info">

### Task 1

To load the model, we need the name it was given during training. Look at Part 5 of the training.ipynb notebook and find the value to use for `model_name`.
</div>
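
Before loading, it can help to confirm that the checkpoint file exists at the expected location. This is a small sketch assuming the `checkpoints/<model_name>/final_model.ckpt` layout used in the next cell. The helper `checkpoint_file` and the name `"my-model"` are placeholders, not part of the COSDD code.

```python
import os


def checkpoint_file(model_name, root="checkpoints", filename="final_model.ckpt"):
    """Build the path where the final checkpoint is expected (hypothetical helper)."""
    return os.path.join(root, model_name, filename)


path = checkpoint_file("my-model")  # "my-model" is a placeholder name
if not os.path.isfile(path):
    print(f"No checkpoint at {path}; check model_name against the training notebook.")
```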

In [None]:
model_name = ...   ### Insert a string here
# model_name = "mito-confocal"
checkpoint_path = os.path.join("checkpoints", model_name)

dvlae = DVLAE.load_from_checkpoint(os.path.join(checkpoint_path, "final_model.ckpt"))

trainer = pl.Trainer(
    accelerator="gpu" if use_cuda else "cpu",
    devices=1,
    enable_progress_bar=False,
)

### Part 4. Denoise
In this section, we will look at how COSDD does inference. <br>

The model denoises images stochastically: each pass draws a random sample, so the output differs every time. First, we will compare several randomly sampled denoised images for the same noisy image. Then, we will produce a single consensus estimate by averaging 100 randomly sampled denoised images. Finally, if the direct denoiser was trained in the previous step, we will see how it can be used to estimate this average in a single pass.

### 4.1 Random sampling 
First, we will denoise each image six times and look at the difference between each estimate. The output of the model is stored in the `samples` variable. This has dimensions [Number of images, Sample index, Channels, Height, Width] where different denoised samples for the same image are stored along sample index.
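
To make the layout concrete, here is a small sketch that uses random NumPy arrays in place of model outputs (the names `per_pass` and `pixel_std` are illustrative). Stacking the results of each pass along axis 1 gives the sample index, and reducing along that axis summarises how much the samples disagree at each pixel.

```python
import numpy as np

rng = np.random.default_rng(0)
n_images, n_samples = 8, 6

# One array per pass over the dataset, each [Number of images, Channels, Height, Width].
per_pass = [rng.random((n_images, 1, 16, 16)) for _ in range(n_samples)]

# Stack along a new axis 1 so that samples[j, i] is sample i of image j.
samples = np.stack(per_pass, axis=1)
print(samples.shape)  # (8, 6, 1, 16, 16)

# Per-pixel spread across the samples of each image.
pixel_std = samples.std(axis=1)
print(pixel_std.shape)  # (8, 1, 16, 16)
```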

In [None]:
use_direct_denoiser = False
n_samples = 6

dvlae.direct_pred = use_direct_denoiser

samples = []
for _ in tqdm(range(n_samples)):
    out = trainer.predict(dvlae, predict_loader)
    out = torch.cat(out, dim=0)
    samples.append(out)

samples = torch.stack(samples, dim=1)

<div class="alert alert-info">

### Task 2

Here, we look at the original noisy image, the six denoised estimates and the reference high snr image. Change the value for `img_idx` to look at different images and change values for `top`, `bottom`, `left` and `right` to adjust the crop.
</div>
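
The `crop` tuple used in the plotting cell is simply a stored index. The sketch below, on a made-up array, shows that indexing with the tuple matches the usual bracket syntax, and that slice bounds beyond the image size are clipped rather than raising an error.

```python
import numpy as np

image = np.arange(36).reshape(1, 6, 6)  # toy [Channels, Height, Width] image
top, bottom, left, right = 1, 4, 2, 5

crop = (0, slice(top, bottom), slice(left, right))

# The tuple is equivalent to image[0, top:bottom, left:right].
same = np.array_equal(image[crop], image[0, top:bottom, left:right])
print(same, image[crop].shape)  # True (3, 3)

# Out-of-range slice bounds are clipped to the image size.
full = image[(0, slice(0, 1024), slice(0, 1024))]
print(full.shape)  # (6, 6)
```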

In [None]:
vmin = np.percentile(low_snr.numpy(), 1)
vmax = np.percentile(low_snr.numpy(), 99)

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(2, 4, figsize=(16, 8))
ax[0, 0].imshow(low_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0, 0].set_title("Input")
for i in range(n_samples):
    ax[(i + 1) // 4, (i + 1) % 4].imshow(
        samples[img_idx][i][crop], vmin=vmin, vmax=vmax
    )
    ax[(i + 1) // 4, (i + 1) % 4].set_title(f"Sample {i+1}")
ax[1, 3].imshow(high_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[1, 3].set_title("Reference")

plt.show()

The six sampled denoised images have subtle differences that express the uncertainty involved in this denoising problem. We can use the reference high snr data to compare their Peak Signal-to-Noise Ratio (PSNR).
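
For reference, PSNR is ten times the base-10 logarithm of the squared data range divided by the mean squared error. The sketch below is a plain NumPy version written for illustration; the notebook itself uses `skimage.metrics.peak_signal_noise_ratio`, and the toy arrays here are invented.

```python
import numpy as np


def psnr(reference, estimate, data_range):
    """PSNR in decibels: 10 * log10(data_range**2 / MSE). Illustrative only."""
    reference = reference.astype(np.float64)
    estimate = estimate.astype(np.float64)
    mse = np.mean((reference - estimate) ** 2)
    return 10.0 * np.log10(data_range**2 / mse)


# Toy reference with a data range of 1 and an estimate that is off by 0.1 everywhere.
reference = np.zeros((4, 4))
reference[0, 0] = 1.0
estimate = reference + 0.1

data_range = reference.max() - reference.min()
value = psnr(reference, estimate, data_range)
print(value)  # approximately 20 dB, since MSE is about 0.01
```

Higher values mean the estimate is closer to the reference. Because the score depends on `data_range`, values are only comparable when the range is computed the same way, as in the cells of this notebook.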

In [None]:
for i in range(n_samples):
    psnrs = []
    for j in range(len(low_snr)):
        gt = high_snr[j].squeeze()
        test = samples[j, i].numpy().squeeze()

        data_range = np.max(gt) - np.min(gt)

        psnrs.append(PSNR(gt, test, data_range=data_range.item()))

    print(f"PSNR sample {i + 1}: {np.mean(psnrs)}")

### 4.2 MMSE estimate

In the next cell, we sample many denoised images and average them to approximate the minimum mean squared error (MMSE) estimate. The averaged images will be stored in the `MMSEs` variable, which has the same dimensions as `low_snr`.
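
The effect of averaging can be seen in a toy setting. In the sketch below, each "sample" is a made-up ground truth plus independent Gaussian noise, which is a strong simplification: samples from the real model are centred on the posterior mean, not on the ground truth. The point is only that the error of the average shrinks as more samples are included.

```python
import numpy as np

rng = np.random.default_rng(0)
truth = rng.random((32, 32))

# 100 toy "samples": the truth plus independent noise (an assumption of this sketch).
samples = truth + rng.normal(0.0, 0.2, size=(100, 32, 32))


def mse(a, b):
    return float(np.mean((a - b) ** 2))


mse_single = mse(samples[0], truth)
mse_mean_10 = mse(samples[:10].mean(axis=0), truth)
mse_mean_100 = mse(samples.mean(axis=0), truth)
print(mse_single, mse_mean_10, mse_mean_100)
```

In this toy case the error falls roughly in proportion to the number of samples averaged, which is why a larger `n_samples` gives a better estimate at the cost of more compute.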

<div class="alert alert-info">

### Task 3
Set `n_samples` to 100 to average 100 images, or a different value to average a different number. Then visually inspect the results.
</div>

In [None]:
use_direct_denoiser = False
n_samples = ...   ### Insert an integer here
# n_samples = 100

dvlae.direct_pred = use_direct_denoiser

samples = []
for _ in tqdm(range(n_samples)):
    out = trainer.predict(dvlae, predict_loader)
    out = torch.cat(out, dim=0)
    samples.append(out)

samples = torch.stack(samples, dim=1)
MMSEs = torch.mean(samples, dim=1)

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 4, figsize=(16, 4))
ax[0].imshow(low_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0].set_title("Input")
ax[1].imshow(samples[img_idx][0][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Sample")
ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)
ax[2].set_title("MMSE")
ax[3].imshow(high_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[3].set_title("Reference")

plt.show()

The MMSE will usually be closer to the reference than an individual sample and would score a higher PSNR, although it will also be blurrier.
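
One way to see why, as a sketch under the assumption that the samples $s$ are drawn from a distribution with mean $\mu$ (which the average of many samples approximates) and that the reference $x$ is held fixed, is the following decomposition of the expected squared error of a single sample:

```latex
\mathbb{E}_s\left[\lVert s - x \rVert^2\right]
  = \lVert \mu - x \rVert^2 + \mathbb{E}_s\left[\lVert s - \mu \rVert^2\right]
```

The cross term vanishes because $\mathbb{E}_s[s - \mu] = 0$. The second term on the right is the total variance of the samples and is never negative, so on average a single sample has at least as much squared error as the mean. The same averaging suppresses the details that vary between samples, which is where the blur comes from.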

In [None]:
psnrs = []
for j in range(len(low_snr)):
    gt = high_snr[j].squeeze()
    test = MMSEs[j].numpy().squeeze()

    data_range = np.max(gt) - np.min(gt)

    psnrs.append(PSNR(gt, test, data_range=data_range.item()))

print(f"PSNR MMSE: {np.mean(psnrs)}")

### 4.3 Direct denoising
Sampling 100 images and averaging them is very time-consuming. If the direct denoiser was trained in a previous step, it can be used to directly output what the average denoised image would be for a given noisy image.

<div class="alert alert-info">

### Task 4

Set `use_direct_denoiser` to `True` to use the Direct Denoiser for inference instead of taking random samples, then visually inspect the results.
</div>

In [None]:
use_direct_denoiser = ...   ### Insert a boolean here
# use_direct_denoiser = True
dvlae.direct_pred = use_direct_denoiser

direct = trainer.predict(dvlae, predict_loader)
direct = torch.cat(direct, dim=0)

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 4, figsize=(16, 4))
ax[0].imshow(low_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0].set_title("Input")
ax[1].imshow(direct[img_idx][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Direct")
ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)
ax[2].set_title("MMSE")
ax[3].imshow(high_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[3].set_title("Reference")

plt.show()

The PSNR of the direct estimate is often higher than the PSNR of the average of 100 samples.

In [None]:
psnrs = []
for j in range(len(low_snr)):
    gt = high_snr[j].squeeze()
    test = direct[j].numpy().squeeze()

    data_range = np.max(gt) - np.min(gt)

    psnrs.append(PSNR(gt, test, data_range=data_range.item()))

print(f"PSNR direct: {np.mean(psnrs)}")

<div class="alert alert-success">

### Checkpoint 2
Continue to the next notebook, generation.ipynb

</div>