# 2. Inference with COSDD

In this notebook, we load a trained model and use it to denoise the low signal-to-noise data.

In [None]:
import os
import logging

import tifffile
import torch
import pytorch_lightning as pl
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

import utils
from hub import Hub

logger = logging.getLogger('pytorch_lightning')
logger.setLevel(logging.WARNING)
%matplotlib inline

In [None]:
use_cuda = torch.cuda.is_available()

### 2.1. Load test data
The images that we want to denoise are loaded here. We'll only load 10 images because denoising the entire dataset takes some time.

In [None]:
low_snr_path = "/group/dl4miacourse/image_regression/penicillium/penicillium_low_snr.tif"

low_snr = tifffile.imread(low_snr_path)[:, None]
low_snr = low_snr[500:510] # Remove this line to denoise an entire dataset
low_snr = torch.from_numpy(low_snr.astype(float)).to(torch.float)
print(low_snr.shape)

As with training, data should be a `torch.Tensor` with dimensions [Number of images, Channels, Z | Y | X] and data type float32.
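Here is a minimal sketch of that layout. It uses NumPy only, with a hypothetical all-zero stack in place of the TIFF file. The conversion adds the channel axis with `[:, None]`, as in the cell above, and casts to float32. Calling `torch.from_numpy` on the result would then give a float32 tensor directly.

```python
import numpy as np

def to_model_layout(stack: np.ndarray) -> np.ndarray:
    """Convert a stack of 2D images [N, Y, X] to [N, C=1, Y, X] float32."""
    if stack.ndim != 3:
        raise ValueError(f"expected [N, Y, X], got shape {stack.shape}")
    # Insert a singleton channel axis, as `[:, None]` does in the loading cell
    return stack[:, None].astype(np.float32)

# Hypothetical uint16 stack standing in for the TIFF file
raw = np.zeros((10, 64, 64), dtype=np.uint16)
x = to_model_layout(raw)
print(x.shape, x.dtype)  # (10, 1, 64, 64) float32
```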

### 2.2. Create prediction dataloader

`predict_batch_size` (int): Number of denoised images to produce at a time.
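Larger values of `predict_batch_size` process more images per forward pass, at the cost of more GPU memory. `trainer.predict` returns one output per batch, which is why the prediction cells join them with `torch.cat(out, dim=0)`. Because the loader uses `shuffle=False`, the joined result lines up image-for-image with `low_snr`. The sketch below mimics that batching-and-concatenation pattern in NumPy, with a hypothetical `predict_fn` standing in for the model.

```python
import numpy as np

def predict_in_batches(images, batch_size, predict_fn):
    """Run predict_fn on consecutive batches and concatenate the results along axis 0."""
    outputs = []
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size]  # no shuffling: order is preserved
        outputs.append(predict_fn(batch))
    # Same role as torch.cat(out, dim=0): one array, in the original image order
    return np.concatenate(outputs, axis=0)

images = np.arange(10, dtype=np.float32).reshape(10, 1, 1, 1)
result = predict_in_batches(images, batch_size=3, predict_fn=lambda b: b * 2)
print(result.shape)  # (10, 1, 1, 1)
```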

In [None]:
predict_batch_size = 1

predict_set = utils.PredictDataset(low_snr)
predict_loader = torch.utils.data.DataLoader(
    predict_set,
    batch_size=predict_batch_size,
    shuffle=False,
    pin_memory=True,
)

### 2.3. Load trained model

<div class="alert alert-info">

### Task 2.1.

The model we trained in the previous notebook was only trained for 20 minutes. This is long enough to get some denoising results, but a model trained for longer would do better. In the cell below, load the model trained in the previous notebook by using the value you gave for `model_name`. Then proceed through the notebook to see how well it performs.

Once you reach the end of the notebook, return to this cell and uncomment line 4 (`checkpoint_path = "pretrained_penicillium"`) to load a model that has been trained for 12 hours. Then run the notebook again to see how much difference the extra training time makes.
</div>
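If `model_name` is misspelled, the cell below fails when it tries to load the checkpoint. A small helper can check the path first. `resolve_checkpoint` below is hypothetical and not part of the course code. It rebuilds the same `checkpoints/<model_name>/final_model.ckpt` and `pretrained_penicillium/final_model.ckpt` paths that the cell uses.

```python
import os

def resolve_checkpoint(model_name, use_pretrained=False):
    """Hypothetical helper: build the checkpoint path and fail early if it is missing."""
    if use_pretrained:
        checkpoint_dir = "pretrained_penicillium"
    else:
        checkpoint_dir = os.path.join("checkpoints", model_name)
    ckpt = os.path.join(checkpoint_dir, "final_model.ckpt")
    if not os.path.isfile(ckpt):
        raise FileNotFoundError(
            f"No checkpoint found at {ckpt!r}; check the value of model_name"
        )
    return ckpt
```

With this helper, the loading line would read `hub = Hub.load_from_checkpoint(resolve_checkpoint(model_name))`.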

In [None]:
model_name = "penicillium"   ### Insert a string here
checkpoint_path = os.path.join("checkpoints", model_name)

# checkpoint_path = "pretrained_penicillium" ### Once you reach the bottom of the notebook, return here and uncomment this line to see the pretrained model

hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, "final_model.ckpt"))

trainer = pl.Trainer(
    accelerator="gpu" if use_cuda else "cpu",
    devices=1,
    enable_progress_bar=False,
)

### 2.4. Denoise
In this section, we will look at how COSDD performs inference.

COSDD denoises stochastically, so it gives a different output each time it is run on the same image. First, we will compare seven randomly sampled denoised images for the same noisy image. Then, we will produce a single consensus estimate by averaging 100 randomly sampled denoised images. Finally, if the direct denoiser was trained in the previous notebook, we will see how it can be used to estimate this average in a single pass.
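A toy example that assumes nothing about COSDD shows why averaging helps. Suppose each "denoised sample" is a hypothetical clean signal plus independent random variation. The average of many samples then converges towards their mean, and the sample-to-sample variation shrinks roughly as one over the square root of the number of samples. For real posterior samples, that mean is the MMSE estimate rather than the true clean image.

```python
import numpy as np

rng = np.random.default_rng(0)
clean = np.linspace(0.0, 1.0, 1000)  # hypothetical "true" signal

def draw_sample():
    # Stand-in for one stochastic denoiser output: signal plus random variation
    return clean + rng.normal(0.0, 0.2, size=clean.shape)

def rmse(a, b):
    return float(np.sqrt(np.mean((a - b) ** 2)))

single = draw_sample()
average = np.mean([draw_sample() for _ in range(100)], axis=0)
print(f"single sample RMSE: {rmse(single, clean):.3f}")
print(f"100-sample average RMSE: {rmse(average, clean):.3f}")
```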

### 2.4.1 Random sampling 
First, we will denoise each image seven times and look at the difference between each estimate. The output of the model is stored in the `samples` variable. This has dimensions [Number of images, Sample index, Channels, Z | Y | X] where different denoised samples for the same image are stored along sample index.
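The sketch below builds the `samples` layout from placeholder NumPy arrays to make it concrete. Here `np.stack(..., axis=1)` plays the role of `torch.stack(samples, dim=1)` in the next cell. Each placeholder pass is filled with its own pass index, so the indexing is easy to check.

```python
import numpy as np

n_images, n_samples, channels, height, width = 10, 7, 1, 32, 32

# Each prediction pass yields one estimate per image: [N, C, Y, X].
# Pass k is filled with the value k so we can see where it ends up.
passes = [
    np.full((n_images, channels, height, width), k, dtype=np.float32)
    for k in range(n_samples)
]

# Stack along axis 1 -> [N, S, C, Y, X]
samples = np.stack(passes, axis=1)
print(samples.shape)  # (10, 7, 1, 32, 32)

# samples[img_idx][i] is the i-th estimate for image img_idx: [C, Y, X]
print(samples[0][3].shape, samples[0][3].max())  # (1, 32, 32) 3.0
```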

In [None]:
use_direct_denoiser = False
n_samples = 7

hub.direct_pred = use_direct_denoiser

samples = []
for _ in tqdm(range(n_samples)):
    out = trainer.predict(hub, predict_loader)
    out = torch.cat(out, dim=0)
    samples.append(out)

samples = torch.stack(samples, dim=1)

<div class="alert alert-info">

### Task 2.2.

Here, we'll look at the original noisy image and the seven denoised estimates. Change the value of `img_idx` to look at different images, and change the values of `top`, `bottom`, `left` and `right` to adjust the crop. Use this section to really explore the results: compare high-intensity regions to low-intensity regions, zoom in and out, and spot the differences between the samples.
</div>
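The crop tuple indexes a single image of shape [C, Y, X]. `0` selects the first (here, only) channel, `top:bottom` selects rows and `left:right` selects columns, with the end values excluded. Slices that run past the edge are clipped rather than raising an error. That is why `bottom = 1024` and `right = 1024` show the whole image for images up to 1024 pixels on a side. Here is a small NumPy sketch on a hypothetical 8 × 8 image:

```python
import numpy as np

# Hypothetical single-channel image, laid out as [C, Y, X]
image = np.arange(8 * 8, dtype=np.float32).reshape(1, 8, 8)

top, bottom, left, right = 2, 6, 1, 5
crop = (0, slice(top, bottom), slice(left, right))

patch = image[crop]   # rows 2..5, columns 1..4 of channel 0
print(patch.shape)    # (4, 4)
print(patch[0, 0])    # 17.0, i.e. row 2 * 8 columns + column 1

# Out-of-range slice ends are clipped, not an error
print(image[(0, slice(0, 1024), slice(0, 1024))].shape)  # (8, 8)
```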

In [None]:
vmin = np.percentile(low_snr.numpy(), 1)
vmax = np.percentile(low_snr.numpy(), 99)

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(2, 4, figsize=(16, 8))
ax[0, 0].imshow(low_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0, 0].set_title("Input")
for i in range(n_samples):
    ax[(i + 1) // 4, (i + 1) % 4].imshow(
        samples[img_idx][i][crop], vmin=vmin, vmax=vmax
    )
    ax[(i + 1) // 4, (i + 1) % 4].set_title(f"Sample {i+1}")

plt.show()

The seven sampled denoised images have subtle differences that express the uncertainty involved in this denoising problem.
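One way to quantify where the samples disagree is the per-pixel standard deviation across the sample axis (axis 1 of `samples`). This is not part of the original notebook. In the notebook, you could pass `samples.std(dim=1)[img_idx][crop]` to `imshow`. The sketch below uses random placeholder data in which one half of the image is made to vary more between samples. Note that this spread only reflects variation between the model's own samples.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical stand-in for `samples`: [N images, S samples, C, Y, X]
samples = rng.normal(size=(2, 7, 1, 16, 16)).astype(np.float32)
# Make the right half of image 0 vary more between samples than the left half
samples[0, :, :, :, 8:] *= 5.0

# Per-pixel spread across the sample axis -> [N, C, Y, X]
spread = samples.std(axis=1)
left_mean = spread[0, 0, :, :8].mean()
right_mean = spread[0, 0, :, 8:].mean()
print(spread.shape, left_mean < right_mean)  # (2, 1, 16, 16) True
```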

### 2.4.2 MMSE estimate

In the next cell, we sample many denoised images and average them to obtain the minimum mean square error (MMSE) estimate. The averaged images will be stored in the `MMSEs` variable, which has the same dimensions as `low_snr`.
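The name comes from the fact that, among all estimates, the posterior mean minimises the expected squared error. Let $x$ be a noisy image and $p(s \mid x)$ the model's posterior over clean images. Averaging $K$ samples $s^{(k)}$ drawn from that posterior is a Monte Carlo approximation of the mean, with $K$ playing the role of `n_samples` in the cell below:

$$
\hat{s}_{\mathrm{MMSE}}(x) = \arg\min_{\hat{s}} \mathbb{E}\left[\lVert s - \hat{s} \rVert^2 \mid x\right] = \mathbb{E}[s \mid x] \approx \frac{1}{K} \sum_{k=1}^{K} s^{(k)}, \qquad s^{(k)} \sim p(s \mid x)
$$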

<div class="alert alert-info">

### Task 2.3.
Set `n_samples` to 100 to average 100 images, or to a different value to average a different number. Then visually inspect the results and examine how the MMSE result differs from a single random sample.
</div>

In [None]:
use_direct_denoiser = False
n_samples = 100   ### Insert an integer here

hub.direct_pred = use_direct_denoiser

samples = []
for _ in tqdm(range(n_samples)):
    out = trainer.predict(hub, predict_loader)
    out = torch.cat(out, dim=0)
    samples.append(out)

samples = torch.stack(samples, dim=1)
MMSEs = torch.mean(samples, dim=1)

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].imshow(low_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0].set_title("Input")
ax[1].imshow(samples[img_idx][0][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Sample")
ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)
ax[2].set_title("MMSE")

plt.show()

The MMSE will usually be closer to the reference than an individual sample and would score a higher PSNR, although it will also be blurrier.
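PSNR compares an estimate with a clean reference, so it can only be computed where one exists. No reference is loaded in this notebook. The sketch below takes the reference's own min-to-max range as the peak value; conventions differ, and `skimage.metrics.peak_signal_noise_ratio` is a library alternative. PSNR is then $10\log_{10}(\text{range}^2 / \text{MSE})$ in decibels, so halving the MSE adds about 3 dB:

```python
import numpy as np

def psnr(estimate, reference, data_range=None):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range^2 / MSE)."""
    estimate = np.asarray(estimate, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    if data_range is None:
        # Assumption: use the reference's dynamic range as the peak value
        data_range = reference.max() - reference.min()
    mse = np.mean((estimate - reference) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))

reference = np.array([0.0, 1.0, 0.0, 1.0])
estimate = np.array([0.1, 0.9, 0.1, 0.9])
print(round(psnr(estimate, reference), 2))  # 20.0
```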

### 2.4.3 Direct denoising
Sampling 100 images and averaging them is very time-consuming. If the direct denoiser was trained in the previous notebook, it can output an estimate of the average denoised image for a given noisy image in a single pass.
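This toy example suggests why a single network pass can stand in for the average. Assume the direct denoiser is trained with a squared-error loss to reproduce the model's samples; this notebook does not show its training objective, so treat that as an assumption. Under that assumption, the best single prediction it can make for a given input is the mean of those samples. A check on a skewed, hypothetical distribution of values for one pixel shows that the squared-error minimiser lands on the mean, not the median:

```python
import numpy as np

rng = np.random.default_rng(2)
# Hypothetical stochastic outputs for a single pixel of one noisy image (skewed)
pixel_samples = rng.gamma(shape=2.0, scale=1.0, size=5000)

def mean_squared_error(prediction):
    return float(np.mean((pixel_samples - prediction) ** 2))

# Brute-force search for the constant prediction with the lowest squared error
candidates = np.linspace(0.0, 5.0, 501)
best = candidates[np.argmin([mean_squared_error(c) for c in candidates])]
print(f"best: {best:.2f}, mean: {pixel_samples.mean():.2f}, "
      f"median: {np.median(pixel_samples):.2f}")
```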

<div class="alert alert-info">

### Task 2.4.

Did you enable the direct denoiser in the previous notebook? If so, set `use_direct_denoiser` to `True` to use the direct denoiser for inference. If not, go back to section 2.3 to load the pretrained model and return here.

Notice how much quicker the direct denoiser is than generating the MMSE results. Visually inspect and explore the results in the same way as before, and notice how similar the direct and MMSE estimates are.
</div>
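To put a number on "how similar", one option is to compare the two estimates pixel by pixel. This is not part of the original notebook. With the hypothetical helper below, you would call `compare_estimates(direct[img_idx].numpy(), MMSEs[img_idx].numpy())`. The mean absolute difference is in the same units as the image intensities, so it is best read relative to `vmax - vmin`. Placeholder arrays stand in for the real estimates:

```python
import numpy as np

def compare_estimates(a, b):
    """Hypothetical helper: mean absolute difference and Pearson correlation."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    mean_abs_diff = float(np.mean(np.abs(a - b)))
    correlation = float(np.corrcoef(a, b)[0, 1])
    return mean_abs_diff, correlation

rng = np.random.default_rng(3)
mmse_like = rng.normal(size=(64, 64))                          # placeholder MMSE
direct_like = mmse_like + rng.normal(scale=0.05, size=(64, 64))  # placeholder direct
diff, corr = compare_estimates(direct_like, mmse_like)
print(f"mean |difference| = {diff:.3f}, correlation = {corr:.3f}")
```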

In [None]:
use_direct_denoiser = True   ### Insert a boolean here
hub.direct_pred = use_direct_denoiser

direct = trainer.predict(hub, predict_loader)
direct = torch.cat(direct, dim=0)

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].imshow(low_snr[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0].set_title("Input")
ax[1].imshow(direct[img_idx][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Direct")
ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)
ax[2].set_title("MMSE")

plt.show()

### 2.5. Incorrect receptive field

We've now trained a model and used it to remove structured noise from our data. Before moving onto the next notebook, we'll look at what happens when a COSDD model is trained without considering the noise structures present. 

COSDD is able to separate imaging noise from clean signal because its autoregressive decoder has a receptive field that spans pixels containing correlated noise, i.e., the row or column of pixels. If its receptive field did not contain pixels with correlated noise, the decoder would not be able to model that noise, and the noise would instead be captured by the VAE's latent variables. To demonstrate this, the image below shows the direct and MMSE estimates of a denoised image where the autoregressive decoder's receptive field was incorrectly set to run vertically, leaving it unable to model horizontal noise.

<img src="./resources/penicillium_ynm.png">
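The toy example below does not implement COSDD's decoder. It only illustrates what noise correlated along rows looks like statistically. Consider hypothetical stripe noise, where every pixel in a row shares a random offset. This noise is strongly correlated between horizontal neighbours and nearly uncorrelated between vertical ones. A receptive field that only extends vertically therefore has no neighbours sharing that noise to learn from:

```python
import numpy as np

rng = np.random.default_rng(4)
height, width = 256, 256

# Hypothetical horizontally structured noise: every pixel in a row shares a
# random offset (a stripe), plus independent per-pixel noise
row_offsets = rng.normal(scale=1.0, size=(height, 1))
noise = row_offsets + rng.normal(scale=0.5, size=(height, width))

def neighbour_correlation(arr, axis):
    """Correlation between each pixel and its next neighbour along `axis`."""
    n = arr.shape[axis]
    a = np.take(arr, np.arange(n - 1), axis=axis).ravel()
    b = np.take(arr, np.arange(1, n), axis=axis).ravel()
    return float(np.corrcoef(a, b)[0, 1])

horizontal = neighbour_correlation(noise, axis=1)  # neighbours along a row
vertical = neighbour_correlation(noise, axis=0)    # neighbours along a column
print(f"horizontal: {horizontal:.2f}, vertical: {vertical:.2f}")
```

You could apply the same `neighbour_correlation` helper to a background crop of `low_snr` as a rough check of which direction its noise runs. This assumes a region with little signal can be found.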

<div class="alert alert-success">

## Checkpoint 2

We've completed the process of training and applying a COSDD model for denoising, but there's still more it can do. Continue to the next notebook, generation.ipynb, to see how the model of the data can be used to generate new clean and noisy images.

</div>