# Exercise 3. Generating new images with COSDD

As mentioned in the training.ipynb notebook, COSDD is a deep generative model that captures the structure and characteristics of our data. In this notebook, we'll see how accurately it represents our training data, in both the signal and the noise. We'll do this by using the model to generate entirely new images: images that look like the ones in our training data but don't actually exist. This is similar in spirit to how models like DALL-E generate new images.

In [None]:
import os
import yaml

import torch
import matplotlib.pyplot as plt
import numpy as np

import utils
from models.get_models import get_models
from models.hub import Hub

%matplotlib inline

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

### 3.1. Load test data
The noisy images are loaded here. We'll pass one of them through the model and also use them as a reference when judging the generated images.

In [None]:
# Load data
low_snr, original_sizes = utils.load_data(paths="./data",
                          patterns="actin-confocal-lowsnr.tif",
                          axes="SYX",
                          n_dimensions=2)
print(f"Noisy data size: {low_snr.size()}")

### 3.2. Load trained model

In the cell below, we initialise all the model components again. The parameters of the model trained in the previous notebook are loaded by setting `model_name`.

In [None]:
model_name = "actin-confocal"
checkpoint_path = os.path.join("checkpoints", model_name)
with open(os.path.join(checkpoint_path, "training-config.yaml")) as f:
    train_cfg = yaml.load(f, Loader=yaml.FullLoader)

In [None]:
lvae, ar_decoder, s_decoder, direct_denoiser = get_models(train_cfg, low_snr.shape[1])

In [None]:
hub = Hub.load_from_checkpoint(
    os.path.join(checkpoint_path, "final_model.ckpt"),
    vae=lvae,
    ar_decoder=ar_decoder,
    s_decoder=s_decoder,
    direct_denoiser=direct_denoiser,
).to(device)

### 3.3. Generating new noise for a real noisy image

First, we'll pass a noisy image to the VAE and generate a random sample from the AR decoder. This will give us another noisy image with the same underlying clean signal but a different random sample of noise.

`inp_image` (torch.Tensor): A real noisy image. The model will generate a new noise sample for its underlying signal.<br>
`denoised` (torch.Tensor): The denoised version of `inp_image`.<br>
`noisy` (torch.Tensor): The same underlying signal as `inp_image` but a different sample of noise.
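
To make the relationship between these tensors concrete, here is a toy sketch in NumPy. It does not use COSDD at all: the clean signal is a made-up smooth pattern, and the noise model is a hypothetical row-wise autoregressive process (`sample_row_ar_noise`, `alpha` and `sigma` are illustrative assumptions) standing in for the AR decoder. If the signal is the same and only the noise is redrawn, the original and generated noisy images differ only by noise, so their difference should be uncorrelated with the signal and have roughly twice the noise variance.

```python
import numpy as np


def sample_row_ar_noise(shape, alpha, sigma, rng):
    """Hypothetical noise model: n[i, j] = alpha * n[i, j - 1] + eps[i, j] along each row."""
    eps = rng.normal(0.0, sigma, size=shape)
    noise = np.empty(shape)
    noise[:, 0] = eps[:, 0]
    for j in range(1, shape[1]):
        noise[:, j] = alpha * noise[:, j - 1] + eps[:, j]
    return noise


rng = np.random.default_rng(0)
yy, xx = np.mgrid[0:128, 0:128]
signal = 5.0 * np.sin(xx / 10.0) * np.cos(yy / 15.0)  # made-up clean image

# Two independent noise draws on top of the same clean signal
original_noisy = signal + sample_row_ar_noise(signal.shape, 0.6, 1.0, rng)
generated_noisy = signal + sample_row_ar_noise(signal.shape, 0.6, 1.0, rng)

generated_noise = generated_noisy - signal      # analogous to noisy - denoised
difference = original_noisy - generated_noisy   # the signal cancels, only noise remains

print(np.corrcoef(difference.ravel(), signal.ravel())[0, 1])  # should be near 0
print(difference.var() / generated_noise.var())               # should be near 2
```

With a real model the inferred signal is only an estimate, so any structure visible in `inp_image - noisy` hints at signal the model has missed or noise it has misattributed.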

In [None]:
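# Inference only, so disable gradient tracking; use bfloat16 autocast on the GPU
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_cuda):
    inp_image = low_snr[:1, :, :512, :512].to(device)  # first image, top-left 512x512 crop
    reconstructions = hub.reconstruct(inp_image)
    # Cast to float32 before moving to the CPU: NumPy and Matplotlib cannot handle bfloat16
    denoised = reconstructions["s_hat"].float().cpu()
    noisy = reconstructions["x_hat"].float().cpu()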

In [None]:
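# Display range from the 0.1st and 99.9th percentiles of the input,
# so a few extreme pixels don't wash out the contrast
inp_np = inp_image.cpu().numpy()
vmin = np.percentile(inp_np, 0.1)
vmax = np.percentile(inp_np, 99.9)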

Now we'll compare the original noisy image, the generated noisy image and the denoised image. Adjust `top`, `bottom`, `left` and `right` (within 0 to 512) to view different crops.

In [None]:
top = 0
bottom = 512
left = 0
right = 512

crop = (0, slice(top, bottom), slice(left, right))  # channel 0, then rows and columns

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].imshow(inp_image[0][crop].cpu(), vmin=vmin, vmax=vmax)
ax[0].set_title("Original noisy image")
ax[1].imshow(noisy[0][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Generated noisy image")
ax[2].imshow(denoised[0][crop], vmin=vmin, vmax=vmax)
ax[2].set_title("Denoised image")

plt.show()

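The spatial correlation of the generated noise can be compared with that of the real noise to get an idea of how accurately the model has learned the noise. Since we have the denoised version of the generated image, we can get a noise sample by subtracting it from the generated noisy image. For the real noise, we take a crop of the original noisy image from a region assumed to contain only background, so its pixel values are approximately noise alone.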
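
`utils.autocorrelation` is not shown in this notebook, so as an assumption, here is one common way to compute a normalised 2D autocorrelation map for lags up to `max_lag` in each direction. `autocorrelation_2d` is a hypothetical stand-in, not the project's implementation. Rows of the output index the vertical lag and columns the horizontal lag, matching the axis labels in the plots. The centre is zero lag and equals 1; for noise that is independent between pixels every other entry should be close to 0, while spatially correlated noise shows non-zero values at small lags.

```python
import numpy as np


def autocorrelation_2d(img, max_lag):
    """Normalised autocorrelation for vertical/horizontal lags in [-max_lag, max_lag]."""
    x = np.asarray(img, dtype=np.float64)
    x = x - x.mean()
    var = np.mean(x * x)
    h, w = x.shape
    size = 2 * max_lag + 1
    ac = np.zeros((size, size))
    for dy in range(-max_lag, max_lag + 1):
        for dx in range(-max_lag, max_lag + 1):
            # Overlapping parts of the image and its copy shifted by (dy, dx)
            a = x[max(0, dy):h + min(0, dy), max(0, dx):w + min(0, dx)]
            b = x[max(0, -dy):h + min(0, -dy), max(0, -dx):w + min(0, -dx)]
            ac[dy + max_lag, dx + max_lag] = np.mean(a * b) / var
    return ac


rng = np.random.default_rng(0)
white = rng.normal(size=(256, 257))
# Each pixel shares one random draw with its right-hand neighbour,
# giving horizontally correlated noise with no vertical correlation
correlated = white[:, :-1] + white[:, 1:]

ac = autocorrelation_2d(correlated, max_lag=3)
print(np.round(ac[3, 2:5], 2))  # horizontal lags -1, 0, +1
print(np.round(ac[2:5, 3], 2))  # vertical lags -1, 0, +1
```

For this construction the theoretical values are 0.5 at horizontal lag ±1 and 0 at vertical lag ±1, so the printed values should land close to those, up to sampling error.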

In [None]:
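real_noise = low_snr[0, 0, 300:500, :200]  # background region of the first image, assumed signal-free
generated_noise = noisy[0, 0] - denoised[0, 0]  # generated noisy image minus its clean signal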

real_ac = utils.autocorrelation(real_noise, max_lag=25)
generated_ac = utils.autocorrelation(generated_noise, max_lag=25)

fig, ax = plt.subplots(1, 2, figsize=(12, 5))
ac1 = ax[0].imshow(real_ac, cmap="seismic", vmin=-1, vmax=1)
ax[0].set_title("Autocorrelation of real noise")
ax[0].set_xlabel("Horizontal lag")
ax[0].set_ylabel("Vertical lag")
ac2 = ax[1].imshow(generated_ac, cmap="seismic", vmin=-1, vmax=1)
ax[1].set_title("Autocorrelation of generated noise")
ax[1].set_xlabel("Horizontal lag")
ax[1].set_ylabel("Vertical lag")

fig.colorbar(ac2, fraction=0.045)
plt.show()

### 3.4. Generating new images

This time, we'll draw a sample from the VAE's prior and pass it through the two decoders to produce a brand-new clean image and its noisy counterpart.
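
Conceptually, sampling from the prior is ancestral sampling: draw latent variables from the prior, then decode them. The toy sketch below mimics that two-decoder structure with made-up components. The latent is a short vector drawn from a standard normal; `toy_signal_decoder` maps it to a non-negative image through fixed random basis patterns; and `toy_noise_decoder` adds noise one pixel at a time in raster order, each pixel's noise depending on the previous one in its row. All of these names are hypothetical, and none of them are COSDD's actual networks, which are hierarchical and learned. They only illustrate why `s` and `x` come out as a matched clean/noisy pair.

```python
import numpy as np

rng = np.random.default_rng(1)
H, W, LATENT_DIM = 32, 32, 8

# Hypothetical fixed "decoder weights": one basis image per latent dimension
basis = rng.normal(size=(LATENT_DIM, H, W))


def toy_signal_decoder(z):
    """Map a latent vector to a clean, non-negative image (softplus keeps it >= 0)."""
    logits = np.tensordot(z, basis, axes=1)  # weighted sum of basis images, shape (H, W)
    return np.log1p(np.exp(logits))


def toy_noise_decoder(s, alpha=0.5, sigma=0.3):
    """Sample a noisy image pixel by pixel in raster order.

    Each pixel's noise depends on the noise of the pixel to its left,
    a crude stand-in for an autoregressive noise model.
    """
    x = np.empty_like(s)
    for i in range(s.shape[0]):
        prev = 0.0
        for j in range(s.shape[1]):
            prev = alpha * prev + rng.normal(0.0, sigma)
            x[i, j] = s[i, j] + prev
    return x


def toy_sample_prior(n_imgs):
    z = rng.standard_normal(size=(n_imgs, LATENT_DIM))  # draw from the prior N(0, I)
    s = np.stack([toy_signal_decoder(zi) for zi in z])   # clean images
    x = np.stack([toy_noise_decoder(si) for si in s])    # matching noisy images
    return {"s": s, "x": x}


samples = toy_sample_prior(n_imgs=2)
print(samples["s"].shape, samples["x"].shape)
```

The `"s"` and `"x"` keys mirror the ones returned by `hub.sample_prior` below; because both images come from the same latent draw, `x - s` is exactly the sampled noise.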

In [None]:
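n_imgs = 1
with torch.no_grad():  # sampling only, no gradients needed
    samples = hub.sample_prior(n_imgs=n_imgs)
# Cast to float32 in case the model returns reduced-precision tensors
denoised = samples["s"].float().cpu()
noisy = samples["x"].float().cpu()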

In [None]:
top = 0
bottom = 256
left = 0
right = 256

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(noisy[0][crop], vmin=vmin, vmax=vmax)
ax[0].set_title("Generated noisy image")
ax[1].imshow(denoised[0][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Generated clean image")

plt.show()