# 2. Inference with COSDD

In this notebook, we load a trained model and use it to denoise the low signal-to-noise data.

In [None]:
import os
import logging
import yaml

import torch
import pytorch_lightning as pl
from pytorch_lightning.plugins.environments import LightningEnvironment
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import tifffile

import utils
from models.get_models import get_models
from models.hub import Hub

# Only let warnings and errors through from the 'pytorch_lightning' logger;
# lower-severity messages are dropped.
logger = logging.getLogger('pytorch_lightning')
logger.setLevel(logging.WARNING)
# Render matplotlib figures inline, below the cell that creates them.
%matplotlib inline

In [None]:
# The predictor built later in this notebook requests accelerator="gpu",
# so stop here if no CUDA device is visible.
assert torch.cuda.is_available()

### 2.1. Load test data
The images that we want to denoise are loaded here.

In [None]:
# Load data
# `axes` and the returned `original_sizes` are reused in the final cell so that
# the denoised results can be written back in the layout of the input file.
# NOTE(review): presumably S = sample and Y/X = spatial axes, with `n_dimensions`
# being the number of spatial dimensions — confirm against utils.load_data.
n_dimensions = 2
axes = "SYX"
low_snr, original_sizes = utils.load_data(
    paths="./data",
    patterns="actin-confocal-lowsnr.tif",
    axes=axes,
    n_dimensions=n_dimensions,
)
# Quick sanity check of what was loaded.
print(f"Noisy data size: {low_snr.size()}")

### 2.2. Create prediction dataloader

`predict_batch_size` (int): Number of denoised images to produce at a time.

In [None]:
# Number of images passed through the model per prediction step.
predict_batch_size = 1

# Wrap the noisy images so they can be fed to the predictor batch by batch.
predict_set = utils.PredictDataset(low_snr)
predict_loader = torch.utils.data.DataLoader(
    predict_set,
    batch_size=predict_batch_size,
    # Keep dataset order: later cells concatenate the batch outputs and index
    # the result with the same `img_idx` that is used for `low_snr`.
    shuffle=False,
    pin_memory=True,
)

### 2.3. Load trained model

In the cell below, we initialise all the model components again. The parameters of the model trained in the previous notebook are loaded by setting `model_name`.

In [None]:
# Name of the trained model; its files live under checkpoints/<model_name>.
model_name = "actin-confocal"
checkpoint_path = os.path.join("checkpoints", model_name)

# Configuration file stored next to the checkpoint.
config_file = os.path.join(checkpoint_path, "training-config.yaml")

# NOTE(review): yaml.FullLoader can build more than plain data types, so only
# load configuration files you trust. yaml.safe_load is the stricter option if
# the config is known to contain plain YAML only.
with open(config_file) as cfg_stream:
    train_cfg = yaml.load(cfg_stream, Loader=yaml.FullLoader)

In [None]:
# Build the model components from the training configuration; their trained
# parameters are loaded from the checkpoint in the next cell.
# NOTE(review): `low_snr.shape[1]` is presumably the channel count (the data
# appear to be laid out as [S, C, Z | Y | X]) — confirm against get_models.
lvae, ar_decoder, s_decoder, direct_denoiser = get_models(train_cfg, low_snr.shape[1])

In [None]:
# Load the trained model from the final checkpoint, passing in the components
# built in the previous cell.
hub = Hub.load_from_checkpoint(
    os.path.join(checkpoint_path, "final_model.ckpt"),
    vae=lvae,
    ar_decoder=ar_decoder,
    s_decoder=s_decoder,
    direct_denoiser=direct_denoiser,
)

# Index of the GPU(s) to run inference on.
gpu_idx = [0]
# Trainer used purely for prediction: progress bar, checkpointing and logging
# are all switched off.
predictor = pl.Trainer(
    accelerator="gpu",
    devices=gpu_idx,
    enable_progress_bar=False,
    enable_checkpointing=False,
    logger=False,
    # NOTE(review): presumably the reason later cells check for bfloat16 outputs
    # and cast them to float32 — confirm.
    precision="bf16-mixed",
    # NOTE(review): presumably forces Lightning's default environment instead of
    # an auto-detected cluster environment — confirm.
    plugins=[LightningEnvironment()],
)

### 2.4. Denoise
In this section, we will look at how COSDD does inference. <br>

The model denoises images randomly, giving us a different output each time. First, we will compare seven randomly sampled denoised images for the same noisy image. Then, we will produce a single consensus estimate by averaging 100 randomly sampled denoised images. Finally, if the direct denoiser was trained in the previous step, we will see how it can be used to estimate this average in a single pass.

### 2.4.1 Random sampling 
First, we will denoise each image seven times and look at the difference between each estimate. The output of the model is stored in the `samples` variable. This has dimensions [Number of images, Sample index, Channels, Z | Y | X] where different denoised samples for the same image are stored along sample index.

In [None]:
use_direct_denoiser = False
n_samples = 7

# With the direct denoiser switched off, every prediction pass is expected to
# return a different randomly sampled denoised version of each image.
hub.direct_pred = use_direct_denoiser

# One full pass over the loader per draw; the per-batch outputs of a pass are
# joined along the image axis.
sample_draws = [
    torch.cat(predictor.predict(hub, predict_loader), dim=0)
    for _ in tqdm(range(n_samples))
]

# Axis 1 indexes the repeated draws for each image.
samples = torch.stack(sample_draws, dim=1)
if samples.dtype == torch.bfloat16:
    # bfloat16 can't be plotted, so switches to float32
    samples = samples.to(torch.float32)

Here, we'll look at the original noisy image and the seven denoised estimates. Change the value for `img_idx` to look at different images and change values for `top`, `bottom`, `left` and `right` to adjust the crop.

In [None]:
# Display range shared by every figure below: the 1st and 99th percentiles of
# the noisy data, so inputs and outputs are shown on the same intensity scale.
vmin, vmax = np.percentile(low_snr.numpy(), [1, 99])

In [None]:
# Image to display and the crop window (in pixels) applied to every panel.
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

# Channel 0, restricted to the chosen window.
crop = (0, slice(top, bottom), slice(left, right))

# Take the number of panels from `samples` itself rather than from `n_samples`:
# a later cell reassigns `n_samples`, and a fixed 2x4 grid raises an IndexError
# as soon as more than seven samples are available.
n_shown = samples.shape[1]
n_cols = 4
# One extra panel for the input image, rounded up to whole rows.
n_rows = (n_shown + 1 + n_cols - 1) // n_cols

# squeeze=False keeps `ax` two-dimensional even when there is a single row.
fig, ax = plt.subplots(
    n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
)
# Row-major flattening, so panel k sits at row k // n_cols, column k % n_cols.
panels = ax.ravel()

panels[0].imshow(low_snr[img_idx][crop], vmin=vmin, vmax=vmax)
panels[0].set_title("Input")
for i in range(n_shown):
    panels[i + 1].imshow(samples[img_idx][i][crop], vmin=vmin, vmax=vmax)
    panels[i + 1].set_title(f"Sample {i+1}")

# Hide any panels left over in the last row.
for unused in panels[n_shown + 1:]:
    unused.axis("off")

plt.show()

The seven sampled denoised images have subtle differences that express the uncertainty involved in this denoising problem.

### 2.4.2 MMSE estimate

In the next cell, we sample many denoised images and average them to obtain the minimum mean squared error (MMSE) estimate. The averaged images will be stored in the `MMSEs` variable, which has the same dimensions as `low_snr`.

In [None]:
use_direct_denoiser = False
n_samples = 100

# With the direct denoiser switched off, every prediction pass is expected to
# return a different randomly sampled denoised version of each image.
hub.direct_pred = use_direct_denoiser

# One full pass over the loader per draw; the per-batch outputs of a pass are
# joined along the image axis.
sample_draws = [
    torch.cat(predictor.predict(hub, predict_loader), dim=0)
    for _ in tqdm(range(n_samples))
]

# Axis 1 indexes the repeated draws for each image.
samples = torch.stack(sample_draws, dim=1)
if samples.dtype == torch.bfloat16:
    # bfloat16 can't be plotted, so switches to float32
    samples = samples.to(torch.float32)

# Averaging over the sample axis gives one consensus image per input image.
MMSEs = samples.mean(dim=1)

In [None]:
# Image to display and the crop window (in pixels) applied to every panel.
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

# Channel 0, restricted to the chosen window.
crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 3, figsize=(12, 4))

# Noisy input, a single random sample, and the average over all samples,
# drawn left to right on the shared intensity scale.
panels = (
    ("Input", low_snr[img_idx][crop]),
    ("Sample", samples[img_idx][0][crop]),
    ("MMSE", MMSEs[img_idx][crop]),
)
for axis, (title, image) in zip(ax, panels):
    axis.imshow(image, vmin=vmin, vmax=vmax)
    axis.set_title(title)

plt.show()

### 2.4.3 Direct denoising
Sampling 100 images and averaging them is very time consuming. If the direct denoiser was trained in a previous step, it can be used to directly output what the average denoised image would be for a given noisy image.

The following cell uses the direct denoiser for inference and saves the result to the `direct` variable.

In [None]:
use_direct_denoiser = True
# Switch the model over to its direct denoiser for this prediction pass.
hub.direct_pred = use_direct_denoiser

# A single pass over the loader; the per-batch outputs are joined along the
# image axis.
direct = torch.cat(predictor.predict(hub, predict_loader), dim=0)
if direct.dtype == torch.bfloat16:
    # bfloat16 can't be plotted, so switches to float32
    direct = direct.to(torch.float32)

In [None]:
# Image to display and the crop window (in pixels) applied to every panel.
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

# Channel 0, restricted to the chosen window.
crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 3, figsize=(12, 4))

# Noisy input, the direct denoiser's output, and the sampled average,
# drawn left to right on the shared intensity scale.
panels = (
    ("Input", low_snr[img_idx][crop]),
    ("Direct", direct[img_idx][crop]),
    ("MMSE", MMSEs[img_idx][crop]),
)
for axis, (title, image) in zip(ax, panels):
    axis.imshow(image, vmin=vmin, vmax=vmax)
    axis.set_title(title)

plt.show()

In the following cell, choose a directory to save your results by setting `save_dir`.

Choose which images you want to save out of `direct` or `MMSEs` by adding or removing them from the `to_save` list.

In [None]:
# Output directory and the names of the results to write into it.
save_dir = "results"
to_save = ["direct", "MMSEs"]

# exist_ok=True creates the directory if needed and does nothing if it is
# already there, without a separate existence check.
os.makedirs(save_dir, exist_ok=True)

# Look the results up lazily so that only the variables named in `to_save`
# have to exist. This lets e.g. `to_save = ["MMSEs"]` work when the direct
# denoiser cell was never run. An unknown name still raises a KeyError.
available = {"direct": lambda: direct, "MMSEs": lambda: MMSEs}
results = {name: available[name]() for name in to_save}

for name, result in results.items():
    save_path = os.path.join(save_dir, name + ".tif")
    if result.dtype == torch.bfloat16:
        # bfloat16 can't be saved as tiff, so switches to float32
        result = result.float()
    # Converts from the PyTorch layout [S, C, Z | Y | X] back to the axes and
    # sizes that the data had when it was loaded.
    result = utils.SCZYX_to_axes(
        result.numpy(), original_axes=axes, original_sizes=original_sizes
    )
    tifffile.imwrite(save_path, result)