# Training a Denoising Variational Lossy Autoencoder<br>
We will create a ladder variational autoencoder (<code>lvae</code>), an autoregressive decoder (<code>noise_model</code>) and a deterministic decoder (<code>s_decoder</code>). Our goal is to remove the noise from our data, $\mathbf{x}$, revealing an estimate of the underlying clean signal, $\mathbf{s}$. In this example, the data is corrupted by noise that is correlated along rows.


The <code>lvae</code> and autoregressive decoder will work together to train a latent variable model of the distribution over the noisy data: the <code>lvae</code> produces latent variables, and the autoregressive decoder decodes them into a model of the data distribution. The autoregressive decoder will have a 1-dimensional receptive field, oriented horizontally, allowing it to accurately model the noise component of the data distribution, hence the name <code>noise_model</code>, but not the signal component. The <code>lvae</code> will therefore produce latent variables containing only signal content.

The deterministic decoder will learn to take these latent variables and map them back into image space, hence the name <code>s_decoder</code>.

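To make the idea of a horizontal, 1-dimensional receptive field concrete, here is a minimal pure-Python sketch. It is not the <code>PixelCNN</code> used below; the function <code>causal_row_conv</code> and its strictly-left masking are assumptions made for illustration. Each output pixel is a weighted sum of pixels to its left in the same row, so it can capture row-correlated noise but sees nothing of the structure in other rows.

```python
def causal_row_conv(image, weights):
    """Toy horizontal causal filter (hypothetical, for illustration only).

    image: list of rows, each a list of floats.
    weights: weights[0] multiplies the pixel 1 step to the left,
             weights[1] the pixel 2 steps to the left, and so on.
    """
    out = []
    for row in image:
        out_row = []
        for j in range(len(row)):
            total = 0.0
            for step, w in enumerate(weights, start=1):
                if j - step >= 0:  # pixels left of the row start count as zero
                    total += w * row[j - step]
            out_row.append(total)
        out.append(out_row)
    return out


image = [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0]]
weights = [1.0, 0.5]
filtered = causal_row_conv(image, weights)

# Changing another row, or a pixel at or to the right of the target column,
# leaves the output at row 0, column 2 untouched.
other_row_changed = causal_row_conv([image[0], [0.0, 0.0, 0.0, 0.0]], weights)
right_changed = causal_row_conv([[1.0, 2.0, 99.0, 99.0], image[1]], weights)
left_changed = causal_row_conv([[1.0, 99.0, 3.0, 4.0], image[1]], weights)
```

Because the output at a pixel ignores every other row, a decoder built from such filters has to get any 2D image structure from the latent variables it is conditioned on.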


In [None]:
import os
import urllib.request

import torch
from torchvision import transforms
import tifffile
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

from lvae.models.lvae import LadderVAE
from noise_model.pixelcnn import PixelCNN
from s_decoder import SDecoder
from dvlae import DVLAE

In [None]:
use_cuda = torch.cuda.is_available()

### Download data

We will be using the Convallaria (C. majalis) dataset, first published in: <br>
Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020, April. Removing structured noise with self-supervised blind-spot networks. In 2020 IEEE 17th International Symposium on Biomedical Imaging (ISBI) (pp. 159-163). IEEE.

In [None]:
# create a folder for our data (does nothing if it already exists)
os.makedirs("data", exist_ok=True)

# check if data has been downloaded already
data_path = "data/flower.tif"
if not os.path.exists(data_path):
    urllib.request.urlretrieve("https://download.fht.org/jug/n2v/flower.tif", data_path)

# load the data
low_snr = tifffile.imread(data_path).astype(float)
low_snr = torch.from_numpy(low_snr).to(torch.float32)[:, None]

### Create training and validation dataloaders

In [None]:
class TrainDatasetUnsupervised(torch.utils.data.Dataset):
    def __init__(self, images, n_iters=1, transform=None):
        self.images = images
        self.n_images = len(images)
        self.n_iters = n_iters
        self.transform = transform

    def __len__(self):
        return self.n_images * self.n_iters

    def __getitem__(self, idx):
        idx = idx % self.n_images
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image

<code>batch_size</code> Number of images in a training batch <br>
<code>crop_size</code> The data will be randomly cropped during training. This specifies the size of that crop. Reduce if images are smaller than 256x256.

In [None]:
batch_size = 4
crop_size = 256
n_iters = (low_snr[0].shape[-1] * low_snr[0].shape[-2]) // crop_size**2
transform = transforms.RandomCrop(crop_size)

low_snr = low_snr[torch.randperm(len(low_snr))]
train_set = low_snr[: int(len(low_snr) * 0.9)]
val_set = low_snr[int(len(low_snr) * 0.9) :]

train_set = TrainDatasetUnsupervised(train_set, n_iters=n_iters, transform=transform)
val_set = TrainDatasetUnsupervised(val_set, n_iters=n_iters, transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True, pin_memory=True
)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=batch_size, shuffle=False, pin_memory=True
)

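The arithmetic behind <code>n_iters</code> and the dataset length can be checked by hand. The sketch below uses a hypothetical stack of 100 images of 1024x1024 pixels; those numbers are assumptions for illustration, not properties read from the downloaded file.

```python
# Hypothetical data shape, chosen only to illustrate the arithmetic.
n_images, height, width = 100, 1024, 1024
crop_size = 256
batch_size = 4

# Number of non-overlapping crops that would tile one image.
n_iters = (height * width) // crop_size**2

# 90% of the images go to training, as in the split above.
n_train = int(n_images * 0.9)

# The dataset reports a "virtual" length so that one epoch draws
# n_iters random crops from every training image.
virtual_length = n_train * n_iters
batches_per_epoch = virtual_length // batch_size

# Indices beyond n_train wrap around to a real image.
wrapped_index = 95 % n_train
```

With these assumed shapes, each epoch samples 16 random crops per image, which is roughly one full pass over every pixel.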
### Create the models
<code>lvae</code> The ladder variational autoencoder that will output latent variables.<br>
* <code>s_code_channels</code> Number of channels in the output latent variable. Set to 64 for reduced memory consumption.
* <code>n_layers</code> Number of levels in the hierarchical VAE. Set to 6 for reduced memory consumption.
* <code>z_dims</code> The number of latent space dimensions at each level of the hierarchy.

<code>noise_model</code> The autoregressive decoder that will decode the latent variables into a distribution over the input.<br>
* <code>kernel_size</code> Length of 1D convolutional kernels.
* <code>RF_shape</code> Whether the receptive field should be oriented "horizontal" or "vertical", to match the orientation of the noise.
* <code>n_filters</code> Number of feature channels.
* <code>n_out_layers</code> Number of final 1x1 convolutions.
* <code>n_gaussians</code> Number of components in Gaussian mixture used to model data.

<code>s_decoder</code> A decoder that will map the latent variables into image space. <br>
<code>dvlae</code> The backbone that will unify and train the above three models.
* <code>n_grad_batches</code> Number of batches to accumulate gradients for before updating weights of all models.

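As a reminder of what <code>n_gaussians</code> controls, here is a minimal sketch of the log-likelihood of one pixel value under a Gaussian mixture, using the standard log-sum-exp form. How <code>PixelCNN</code> parameterises and evaluates its mixture is not shown in this notebook, so the function names and example parameters below are hypothetical.

```python
import math


def gaussian_log_pdf(x, mean, std):
    """Log density of a univariate Gaussian."""
    return (-0.5 * math.log(2 * math.pi)
            - math.log(std)
            - 0.5 * ((x - mean) / std) ** 2)


def mixture_log_likelihood(x, weights, means, stds):
    """log sum_k w_k N(x | mean_k, std_k), computed stably with log-sum-exp."""
    log_terms = [
        math.log(w) + gaussian_log_pdf(x, m, s)
        for w, m, s in zip(weights, means, stds)
    ]
    peak = max(log_terms)
    return peak + math.log(sum(math.exp(t - peak) for t in log_terms))


# A 3-component mixture (matching n_gaussians=3) with made-up parameters.
weights = [0.5, 0.3, 0.2]
means = [0.0, 1.0, -1.0]
stds = [1.0, 0.5, 2.0]
log_lik = mixture_log_likelihood(0.2, weights, means, stds)
```

More components let the model represent skewed or heavy-tailed noise at the cost of more output parameters per pixel.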
In [None]:
s_code_channels = 128

n_layers = 14
z_dims = [s_code_channels // 2] * n_layers
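if n_layers == 14:
    downsampling = [0, 1] * (n_layers // 2)
elif n_layers <= 7:
    downsampling = [1] * n_layers
else:
    # no schedule is defined for other depths, so fail early with a clear message
    raise ValueError("n_layers must be 14 or at most 7")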

lvae = LadderVAE(
    colour_channels=low_snr.shape[1],
    img_shape=(crop_size, crop_size),
    s_code_channels=s_code_channels,
    n_filters=s_code_channels,
    z_dims=z_dims,
    downsampling=downsampling,
)

noise_model = PixelCNN(
    colour_channels=low_snr.shape[1],
    s_code_channels=s_code_channels,
    kernel_size=5,
    RF_shape="horizontal",
    n_filters=64,
    n_layers=4,
    n_out_layers=1,
    n_gaussians=3,
)

s_decoder = SDecoder(
    colour_channels=low_snr.shape[1], s_code_channels=s_code_channels, n_filters=s_code_channels
)

dvlae = DVLAE(
    vae=lvae,
    noise_model=noise_model,
    s_decoder=s_decoder,
    data_mean=low_snr.mean(),
    data_std=low_snr.std(),
    n_grad_batches=4,
)

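The two downsampling schedules above can be compared with a short sketch. It assumes that each entry of <code>downsampling</code> is the number of times the feature map is halved at that level; this is an assumption about <code>LadderVAE</code>, and the helper names are hypothetical.

```python
def downsampling_schedule(n_layers):
    """Reproduce the schedule chosen in the cell above."""
    if n_layers == 14:
        return [0, 1] * (n_layers // 2)
    if n_layers <= 7:
        return [1] * n_layers
    raise ValueError("n_layers must be 14 or at most 7")


def latent_resolutions(crop_size, downsampling):
    """Spatial size at each level, assuming each step halves the resolution."""
    sizes = []
    size = crop_size
    for steps in downsampling:
        size //= 2**steps
        sizes.append(size)
    return sizes


deep = latent_resolutions(256, downsampling_schedule(14))
shallow = latent_resolutions(256, downsampling_schedule(6))
```

Under this assumption the 14-level model halves the resolution at every second level and reaches a 2x2 top level from a 256x256 crop, while the 6-level model stops at 4x4. With <code>n_grad_batches=4</code> and <code>batch_size=4</code>, each weight update accumulates gradients from 16 crops.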
### Train

In [None]:
model_name = "convallaria"
checkpoint_path = os.path.join("checkpoints", model_name)

trainer = pl.Trainer(
    default_root_dir=checkpoint_path,
    accelerator="gpu" if use_cuda else "cpu",
    devices=1,
    max_epochs=1000,
    log_every_n_steps=len(train_set) // batch_size,
    callbacks=[EarlyStopping(patience=100, monitor="val/sd_loss")],
)

trainer.fit(dvlae, train_loader, val_loader)
trainer.save_checkpoint(os.path.join(checkpoint_path, "final_model.ckpt"))
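
The early stopping rule used above can be sketched in a few lines. This is a simplified stand-in for the <code>EarlyStopping</code> callback, assuming a metric that should decrease and ignoring options such as a minimum improvement threshold; the function name is hypothetical.

```python
def stopping_epoch(val_losses, patience):
    """Return the epoch at which training would stop, or None if it never does.

    Training stops once the loss has failed to improve on its best value
    for `patience` consecutive validation checks.
    """
    best = float("inf")
    checks_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return epoch
    return None


# Made-up validation losses: the best value is reached at epoch 1.
losses = [1.0, 0.9, 0.95, 0.97, 0.96]
stop_at = stopping_epoch(losses, patience=3)
never = stopping_epoch([1.0, 0.9, 0.8, 0.7], patience=3)
```

With <code>patience=100</code>, training continues for 100 validation checks past the best value of <code>val/sd_loss</code> before stopping, or until <code>max_epochs</code> is reached.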