# Exercise 3. Generating new images with COSDD

As mentioned in the training.ipynb notebook, COSDD is a deep generative model that captures the structures and characteristics of our data. In this notebook, we'll see how accurately it represents our training data, in both the signal and the noise. We'll do this by using the model to generate entirely new images: images that look like the ones in our training data but don't actually exist. This is similar to how models like DALL-E generate entirely new images.

<div class="alert alert-danger">

Set your Python kernel to <code>02_regression</code>
</div>

In [None]:
import os

import tifffile
import torch
import matplotlib.pyplot as plt
import numpy as np

import utils
from hub import Hub

%matplotlib inline

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### 3.1. Load trained model and clean and noisy data

<div class="alert alert-info">

### Task 3.1.

Load the model trained in the first notebook by entering its `model_name`, or uncomment line 4 of the cell below to load the pretrained model instead.
</div>

In [None]:
model_name = "penicillium"    # Insert a value here   
checkpoint_path = os.path.join("checkpoints", model_name)

# checkpoint_path = "pretrained_penicillium"

hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, "final_model.ckpt")).to(device)

In [None]:
low_snr_path = "/group/dl4miacourse/image_regression/penicillium/penicillium_low_snr.tif"

# Load the stack and add a channel axis: (N, H, W) -> (N, 1, H, W)
low_snr = tifffile.imread(low_snr_path)[:, np.newaxis]
low_snr = torch.from_numpy(low_snr.astype(np.float32))
print(low_snr.shape)

### 3.2. Generating new noise for a real noisy image

First, we'll pass a real noisy image through the VAE and draw a random sample from its autoregressive (AR) decoder. This gives us another noisy image with the same underlying clean signal but a different random sample of noise.

`inp_image` (torch.Tensor): The real noisy image we're going to add a different random sample of noise to.<br>
`denoised` (torch.Tensor): The denoised version of `inp_image`.<br>
`noisy` (torch.Tensor): The same underlying signal as `inp_image` but a different sample of noise.
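A hypothetical sanity check, not part of the original exercise: if `noisy` really carries a fresh, independent noise sample, the residuals `inp_image - denoised` and `noisy - denoised` should be close to uncorrelated. The helper `residual_correlation` below is our own illustrative function operating on NumPy arrays, and the demo uses synthetic data rather than the penicillium images.

```python
import numpy as np

def residual_correlation(original, resampled, clean_estimate):
    """Pearson correlation between two noise residuals that share one clean estimate.

    Hypothetical helper: values near 0 suggest independent noise samples,
    values near 1 suggest the "new" noise simply copies the original noise.
    """
    r1 = (np.asarray(original, dtype=np.float64) - clean_estimate).ravel()
    r2 = (np.asarray(resampled, dtype=np.float64) - clean_estimate).ravel()
    return np.corrcoef(r1, r2)[0, 1]

# Synthetic demo: a smooth "clean" image plus two independent noise draws
rng = np.random.default_rng(0)
yy, xx = np.mgrid[0:128, 0:128]
clean = np.sin(xx / 10.0) + np.cos(yy / 15.0)
original = clean + rng.normal(0.0, 0.3, clean.shape)
resampled = clean + rng.normal(0.0, 0.3, clean.shape)

independent = residual_correlation(original, resampled, clean)
copied = residual_correlation(original, original, clean)
print(f"independent noise: {independent:.3f}, copied noise: {copied:.3f}")
```

In the notebook you could try `residual_correlation(inp_image[0, 0].cpu().numpy(), noisy[0, 0].numpy(), denoised[0, 0].numpy())`. Keep in mind that any error in the denoised estimate appears in both residuals, so a real model will typically show a small positive correlation rather than exactly zero.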

In [None]:
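inp_image = low_snr[500:501, :, :512, :512].to(device)  # frame 500, top-left 512x512 crop
reconstructions = hub.reconstruct(inp_image)
denoised = reconstructions["s_hat"].cpu()  # estimate of the clean signal
noisy = reconstructions["x_hat"].cpu()  # same signal with a new noise sample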

In [None]:
# Display range from the 0.1th and 99.9th percentiles of the input, so outliers don't wash out the contrast
inp_np = inp_image.cpu().numpy()
vmin, vmax = np.percentile(inp_np, [0.1, 99.9])

<div class="alert alert-info">

### Task 3.2.

Now we will compare the original noisy image, the generated noisy image and the denoised image. Adjust `top`, `bottom`, `left` and `right` to view different crops of the three images.

</div>

In [None]:
top = 0
bottom = 512
left = 0
right = 512

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].imshow(inp_image[0][crop].cpu(), vmin=vmin, vmax=vmax)
ax[0].set_title("Original noisy image")
ax[1].imshow(noisy[0][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Generated noisy image")
ax[2].imshow(denoised[0][crop], vmin=vmin, vmax=vmax)
ax[2].set_title("Denoised image")

plt.show()

Comparing the spatial correlation of the generated noise with that of the real noise gives us an idea of how accurate the model's noise model is. Since we have the denoised version of the generated image, we can get a noise sample simply by subtracting it from the generated noisy version.
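`utils.autocorrelation` isn't shown in this notebook. As a rough guide to what the plots below display, here is a minimal NumPy sketch of a normalised 2D autocorrelation: the mean product of the zero-mean noise with a shifted copy of itself, divided by the variance, for every vertical and horizontal lag up to `max_lag`. The name `autocorrelation_sketch` is hypothetical, and the course's implementation may differ (for example in normalisation or by using FFTs).

```python
import numpy as np

def autocorrelation_sketch(noise, max_lag=25):
    """Normalised 2D autocorrelation for lags in [-max_lag, max_lag] (hypothetical helper).

    Row index = vertical lag + max_lag, column index = horizontal lag + max_lag,
    so the centre pixel (zero lag) is always 1.
    """
    arr = np.asarray(noise, dtype=np.float64)
    arr = arr - arr.mean()
    var = arr.var()
    h, w = arr.shape
    size = 2 * max_lag + 1
    ac = np.zeros((size, size))
    for dy in range(-max_lag, max_lag + 1):
        for dx in range(-max_lag, max_lag + 1):
            # Overlapping parts of the image and its copy shifted by (dy, dx)
            a = arr[max(0, dy):h + min(0, dy), max(0, dx):w + min(0, dx)]
            b = arr[max(0, -dy):h + min(0, -dy), max(0, -dx):w + min(0, -dx)]
            ac[dy + max_lag, dx + max_lag] = np.mean(a * b) / var
    return ac

rng = np.random.default_rng(0)
white = rng.normal(size=(200, 200))
# Horizontally correlated noise: each pixel shares a component with its left-hand neighbour
streaky = white + np.roll(white, 1, axis=1)

ac_white = autocorrelation_sketch(white, max_lag=5)
ac_streaky = autocorrelation_sketch(streaky, max_lag=5)
print(ac_white[5, 6], ac_streaky[5, 6], ac_streaky[6, 5])
```

In this sketch white noise is close to 0 at every non-zero lag, while `streaky` is about 0.5 at a horizontal lag of 1 and near 0 vertically. Spatially correlated noise therefore shows up as non-zero values around the centre of the map.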

In [None]:
real_noise = low_snr[-1, 0, :200, :200]
generated_noise = noisy[0, 0] - denoised[0, 0]

real_ac = utils.autocorrelation(real_noise, max_lag=25)
generated_ac = utils.autocorrelation(generated_noise, max_lag=25)

fig, ax = plt.subplots(1, 2, figsize=(12, 5))
ac1 = ax[0].imshow(real_ac, cmap="seismic", vmin=-1, vmax=1)
ax[0].set_title("Autocorrelation of real noise")
ax[0].set_xlabel("Horizontal lag")
ax[0].set_ylabel("Vertical lag")
ac2 = ax[1].imshow(generated_ac, cmap="seismic", vmin=-1, vmax=1)
ax[1].set_title("Autocorrelation of generated noise")
ax[1].set_xlabel("Horizontal lag")
ax[1].set_ylabel("Vertical lag")

fig.colorbar(ac2, fraction=0.045)
plt.show()

### 3.3. Generating new images

This time, we'll draw a sample from the VAE's prior. This is a latent variable encoding a brand new signal. The signal decoder converts that latent variable into a clean image, and the AR decoder takes the same latent variable and generates a noisy image: the same clean image plus a sample of noise.
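To make this generative process concrete, here is a toy NumPy sketch of sampling from a prior under stated assumptions: the latent is drawn from a standard normal prior, a fixed random linear map stands in for the signal decoder, and a first-order autoregressive process along each row stands in for the AR decoder. This is not COSDD's architecture, whose decoders are neural networks; `sample_prior_sketch` and its parameters are hypothetical, and only the returned keys `"s"` and `"x"` mirror what `hub.sample_prior` returns in the next code cell.

```python
import numpy as np

def toy_ar_noise(shape, coeff, sigma, rng):
    # First-order autoregressive noise along each row:
    # n[:, j] = coeff * n[:, j - 1] + eps[:, j], so neighbouring pixels in a row are correlated
    eps = rng.normal(0.0, sigma, size=shape)
    noise = np.empty(shape)
    noise[:, 0] = eps[:, 0]
    for j in range(1, shape[1]):
        noise[:, j] = coeff * noise[:, j - 1] + eps[:, j]
    return noise

def sample_prior_sketch(n_imgs, latent_dim=8, shape=(64, 64), coeff=0.8, sigma=0.1, seed=0):
    rng = np.random.default_rng(seed)
    # Stand-in "signal decoder": a fixed random linear map from latent space to image space
    decoder_weights = rng.normal(size=(latent_dim, *shape))
    z = rng.normal(size=(n_imgs, latent_dim))  # one latent per image, z ~ N(0, I)
    s = np.tensordot(z, decoder_weights, axes=1)  # clean images, shape (n_imgs, H, W)
    # Stand-in "AR decoder": the same clean image plus a fresh sample of correlated noise
    x = np.stack([s_i + toy_ar_noise(shape, coeff, sigma, rng) for s_i in s])
    return {"s": s, "x": x}

generations = sample_prior_sketch(n_imgs=5)
noise = generations["x"] - generations["s"]
horizontal = np.corrcoef(noise[:, :, 1:].ravel(), noise[:, :, :-1].ravel())[0, 1]
vertical = np.corrcoef(noise[:, 1:, :].ravel(), noise[:, :-1, :].ravel())[0, 1]
print(generations["s"].shape, round(horizontal, 2), round(vertical, 2))
```

The recovered noise comes out strongly correlated horizontally (close to `coeff`) and essentially uncorrelated vertically, which is the kind of structure the autocorrelation plots in section 3.2 are designed to reveal.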

<div class="alert alert-info">

### Task 3.3.

Set the `n_imgs` variable below to decide how many images to generate. If you set it too high, you'll get an out-of-memory error. Don't worry if that happens: just restart the kernel and run the notebook again with a lower value.

Explore the images you generated in the second cell below. Look at the differences between them to see what aspects of the signal the model has learned to generate.

</div>

In [None]:
n_imgs = 5
generations = hub.sample_prior(n_imgs=n_imgs)
new_denoised = generations["s"].cpu()
new_noisy = generations["x"].cpu()

In [None]:
img_idx = 0
top = 0
bottom = 256
left = 0
right = 256

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(new_noisy[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0].set_title("Generated noisy image")
ax[1].imshow(new_denoised[img_idx][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Generated clean image")

plt.show()

<div class="alert alert-success">

### Checkpoint 3

In this notebook, we saw how the model you trained in the first notebook has learned to describe the data. We first added a new sample of noise to an existing noisy image. We then generated clean and noisy images that look like they could be from the training data but don't actually exist. <br>
You can now optionally return to section 3.1 and load a model that has been trained for much longer. Otherwise, you've finished this module on COSDD.

</div>