# Training a Hierarchical DivNoising network with a Direct Denoiser for Convallaria data which is intrinsically noisy
In this notebook, we train a Hierarchical DivNoising Ladder VAE alongside a Direct Denoising network on intrinsically noisy data. This requires a noise model (a model of the imaging noise), which can either be measured from calibration data or bootstrapped from the raw noisy images themselves. If you haven't done so already, please first run '1-train_noise_model.ipynb', which downloads the data and creates a noise model.

In [None]:
import os

import torch
from torch.utils.data import TensorDataset, DataLoader
import tifffile
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from hdn.lib import utils
from hdn.models.lvae import LadderVAE
from unet import UNet
from hdn.lib.gaussianMixtureNoiseModel import GaussianMixtureNoiseModel
from backbone import Backbone

In [None]:
use_cuda = torch.cuda.is_available()

### Specify ```path``` to load training data
Your data should be stored in the directory indicated by ```path```.

In [None]:
path = "./data/Convallaria_diaphragm/"
observation = tifffile.imread(path + "20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif")

# Training Data Preparation

Before training, the data must be preprocessed.

We first split the data into training and validation sets, with 85% of the images allocated to the training set and the rest to the validation set. We then augment the training data 8-fold using 90-degree rotations and flips.

In [None]:
train_data = observation[: int(0.85 * observation.shape[0])]
val_data = observation[int(0.85 * observation.shape[0]) :]
print(
    "Shape of training images:",
    train_data.shape,
    "Shape of validation images:",
    val_data.shape,
)
train_data = utils.augment_data(
    train_data
)  ### 8-fold data augmentation (rotations and flips); skip this call for faster training
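
The 8-fold augmentation described above can be pictured as four rotations, each with and without a mirror flip. The following is a minimal sketch, assuming square images stored as an `(N, H, W)` array. `augment_8fold` is a hypothetical helper written for illustration and is not necessarily how `utils.augment_data` is implemented.

```python
import numpy as np

def augment_8fold(images):
    """Hypothetical stand-in for utils.augment_data: 4 rotations x 2 flips."""
    # Rotate every image by 0, 90, 180 and 270 degrees in the (H, W) plane.
    # Stacking the results requires square images (H == W).
    rotations = [np.rot90(images, k=k, axes=(1, 2)) for k in range(4)]
    rotated = np.concatenate(rotations, axis=0)
    # Add a mirrored copy of every rotated image.
    flipped = np.flip(rotated, axis=2)
    return np.concatenate([rotated, flipped], axis=0)

demo = np.arange(2 * 4 * 4, dtype=np.float32).reshape(2, 4, 4)
augmented = augment_8fold(demo)
print(augmented.shape)  # (16, 4, 4)
```

Two input images become sixteen, so the training set grows by a factor of eight.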

In [None]:
### We extract overlapping patches of size ```patch_size x patch_size``` from training and validation images.
### Usually 64x64 patches work well for most microscopy datasets
patch_size = 64

In [None]:
img_width = observation.shape[2]
img_height = observation.shape[1]
num_patches = (img_width * img_height) // patch_size**2  # patches needed to tile one image
train_images = utils.extract_patches(train_data, patch_size, num_patches)
val_images = utils.extract_patches(val_data, patch_size, num_patches)
val_images = val_images[:1000]  # Optional: cap validation patches at 1000 to speed up training
img_shape = (train_images.shape[1], train_images.shape[2])
print(
    "Shape of training images:",
    train_images.shape,
    "Shape of validation images:",
    val_images.shape,
)
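
To make the patch extraction step concrete, here is a minimal sketch that crops random patches from each image. It assumes that `num_patches` is counted per image and that crops are placed uniformly at random. `extract_random_patches` is a hypothetical helper and may differ from `utils.extract_patches`.

```python
import numpy as np

def extract_random_patches(images, patch_size, num_patches, seed=0):
    """Hypothetical stand-in for utils.extract_patches (random crops per image)."""
    rng = np.random.default_rng(seed)
    _, height, width = images.shape
    patches = []
    for image in images:
        # Top-left corners are drawn so that every patch fits inside the image.
        ys = rng.integers(0, height - patch_size + 1, size=num_patches)
        xs = rng.integers(0, width - patch_size + 1, size=num_patches)
        for y, x in zip(ys, xs):
            patches.append(image[y : y + patch_size, x : x + patch_size])
    return np.stack(patches)

demo_images = np.random.default_rng(1).normal(size=(3, 128, 128)).astype(np.float32)
demo_patches = extract_random_patches(demo_images, patch_size=64, num_patches=4)
print(demo_patches.shape)  # (12, 64, 64)
```

Because the crops are random, neighbouring patches may overlap, which matches the "overlapping patches" wording above.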

<code>batch_size</code>: Number of patches for which loss will be calculated before updating weights.<br>
<code>virtual_batch</code>: Number of patches that will be passed through the network at a time. Increase to save time, decrease to save memory.

In [None]:
### We create PyTorch dataloaders for training and validation data
batch_size = 64 
virtual_batch = 8
n_grad_batches = batch_size // virtual_batch

train_images = torch.from_numpy(train_images[:, np.newaxis]).float()
val_images = torch.from_numpy(val_images[:, np.newaxis]).float()
train_dataset = TensorDataset(train_images)
val_dataset = TensorDataset(val_images)
train_loader = DataLoader(train_dataset, batch_size=virtual_batch, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=virtual_batch, shuffle=False)  # validation order does not need shuffling
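
The relationship between `batch_size` and `virtual_batch` is gradient accumulation: gradients from several small passes are averaged before a single weight update. How `Backbone` does this internally is not shown here, so the sketch below uses a toy linear model in NumPy purely to illustrate the principle. All names other than `batch_size`, `virtual_batch` and `n_grad_batches` are made up for the example.

```python
import numpy as np

batch_size = 64      # samples per weight update
virtual_batch = 8    # samples per forward/backward pass
n_grad_batches = batch_size // virtual_batch

rng = np.random.default_rng(0)
x = rng.normal(size=(batch_size, 5))
y = rng.normal(size=batch_size)
w = rng.normal(size=5)

def mse_gradient(x_part, y_part, weights):
    """Gradient of mean((x @ w - y)^2) with respect to w."""
    residual = x_part @ weights - y_part
    return 2.0 * x_part.T @ residual / len(y_part)

# Gradient computed on the full batch in one pass.
full_gradient = mse_gradient(x, y, w)

# Same gradient accumulated over n_grad_batches smaller passes.
accumulated = np.zeros_like(w)
for i in range(n_grad_batches):
    part = slice(i * virtual_batch, (i + 1) * virtual_batch)
    accumulated += mse_gradient(x[part], y[part], w) / n_grad_batches

print(np.allclose(full_gradient, accumulated))  # True
```

For a real network with batch normalization the equivalence is only approximate, since normalization statistics are computed per pass rather than over all 64 patches.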

# Configure Hierarchical DivNoising model and the Direct Denoiser

<code>model_name</code> specifies the name under which the model weights will be saved and later loaded for prediction.<br>
<code>checkpoint_path</code> specifies the directory where the model weights will be saved. <br>
<code>gaussian_noise_std</code> is only applicable if dataset is synthetically corrupted with Gaussian noise of known std. For real datasets, it should be set to ```None```.<br>
<code>noiseModel</code> specifies a noise model for training. If noisy data is generated synthetically using Gaussian noise, set it to None. Else set it to the GMM based noise model (.npz file)  generated from '1-CreateNoiseModel.ipynb'.<br>
<code>lr</code> specifies the learning rate.<br>
<code>max_epochs</code> specifies the total number of training epochs. Around $150-200$ epochs work well generally.<br>
<code>steps_per_epoch</code> specifies how many steps to take per epoch of training. Around $400-500$ steps work well for most datasets.<br>
<code>num_latents</code> specifies the number of stochastic layers. The default setting of $6$ works well for most datasets, but good results can also be obtained with as few as $4$ layers. More stochastic layers may improve performance for some datasets at the cost of increased training time.<br>
<code>z_dims</code> specifies the number of bottleneck dimensions (latent space dimensions) at each stochastic layer per pixel. The default setting of $32$ works well for most datasets.<br>
<code>blocks_per_layer</code> specifies how many residual blocks to use per stochastic layer. Usually, setting it to be $4$ or more works well. However, more residual blocks improve performance at the cost of increased training time.<br>
<code>batchnorm</code> specifies whether batch normalization is used. Setting it to ```True``` is recommended.<br>
<code>free_bits</code> specifies the threshold below which the KL loss is not optimized further. This prevents the [KL-collapse problem](https://arxiv.org/pdf/1511.06349.pdf). The default setting of $1.0$ works well for most datasets.<br>
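
As an illustration of the <code>free_bits</code> idea, a common formulation clamps each KL term from below at the threshold, so that terms already under the threshold contribute a constant and receive no gradient. The sketch below assumes the clamp is applied per stochastic layer. The exact granularity used inside `LadderVAE` is not shown in this notebook, and `kl_with_free_bits` is a hypothetical helper.

```python
import numpy as np

def kl_with_free_bits(kl_per_layer, free_bits):
    """Sum per-layer KL terms, clamping each one from below at `free_bits`."""
    kl_per_layer = np.asarray(kl_per_layer, dtype=np.float64)
    return float(np.maximum(kl_per_layer, free_bits).sum())

demo_kl = [0.2, 0.9, 3.5]  # made-up KL values for three stochastic layers
print(kl_with_free_bits(demo_kl, free_bits=1.0))  # 5.5
```

The first two layers are below the threshold and are each counted as $1.0$, so only the third layer's KL term is actually being minimized.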

In [None]:
model_name = "convallaria"
checkpoint_path = os.path.join("checkpoints", model_name)

# Load trained noise model
gaussian_noise_std = None
noise_model_params = np.load(
    "./data/Convallaria_diaphragm/GMMNoiseModel_convallaria_3_2_calibration.npz"
)
noiseModel = GaussianMixtureNoiseModel(params=noise_model_params)

# Training specific
lr = 3e-4
max_epochs = 500
steps_per_epoch = 400
limit_train_batches = steps_per_epoch * n_grad_batches

# VAE specific
num_latents = 6
z_dims = [32]*int(num_latents)
blocks_per_layer = 5
batchnorm = True
free_bits = 1.0
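
The training-specific settings above imply the following per-epoch workload. This is plain arithmetic on the values used in this notebook, assuming each weight update consumes `n_grad_batches` loader batches. `patches_per_epoch` is a name introduced only for this example.

```python
batch_size = 64
virtual_batch = 8
steps_per_epoch = 400

n_grad_batches = batch_size // virtual_batch             # passes per weight update
limit_train_batches = steps_per_epoch * n_grad_batches   # loader batches per epoch
patches_per_epoch = limit_train_batches * virtual_batch  # patches seen per epoch

print(n_grad_batches, limit_train_batches, patches_per_epoch)  # 8 3200 25600
```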

<code>lvae</code>: The traditional Hierarchical DivNoising model which uses the <code>noiseModel</code>, $p(\text{observation}|\text{signal})$, to train the approximate posterior, $q(\text{signal}|\text{observation})$.<br>
<code>direct_denoiser</code>: A deterministic network that is trained by DivNoising to estimate $\mathbb{E}_{q(\text{signal}|\text{observation})}[\text{signal}]$.<br>
<code>backbone</code>: The <code>lvae</code> and <code>direct_denoiser</code> models are trained simultaneously, each with its own optimizer. This module handles their co-training.

In [None]:
data_mean = observation.mean()
data_std = observation.std()

lvae = LadderVAE(
    z_dims=z_dims,
    blocks_per_layer=blocks_per_layer,
    data_mean=data_mean,
    data_std=data_std,
    noiseModel=noiseModel,
    batchnorm=batchnorm,
    free_bits=free_bits,
    img_shape=img_shape,
)

direct_denoiser = UNet(depth=4, start_filters=32)

backbone = Backbone(
    lvae,
    direct_denoiser,
    data_mean=data_mean,
    data_std=data_std,
    gaussian_noise_std=gaussian_noise_std,
    n_grad_batches=n_grad_batches,
    lr=lr,
)
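
Why can a deterministic network estimate the posterior mean? A regressor trained with squared error against random samples converges to the mean of those samples. The scalar toy example below illustrates this. It assumes the direct denoiser is trained with a squared-error loss against samples drawn from the <code>lvae</code>, which is an assumption since the internals of `Backbone` are not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for samples from q(signal | observation) at a single pixel.
samples = rng.normal(loc=3.0, scale=1.0, size=2000)

estimate = 0.0
for step, sample in enumerate(samples, start=1):
    # One SGD step on (estimate - sample)^2 with learning rate 1 / (2 * step).
    gradient = 2.0 * (estimate - sample)
    estimate -= gradient / (2.0 * step)

print(np.isclose(estimate, samples.mean()))  # True
```

With this decaying learning rate the estimate is exactly the running mean of the samples seen so far, so a single forward pass of a trained direct denoiser can stand in for averaging many samples.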

### Train networks
Training can be monitored in Tensorboard. Enter `tensorboard --logdir <path-to-this-directory>/checkpoints/convallaria` to run it.

In [None]:
logger = TensorBoardLogger(checkpoint_path, name=model_name)

trainer = pl.Trainer(
    default_root_dir=checkpoint_path,
    accelerator="gpu" if use_cuda else "cpu",
    devices=1,
    limit_train_batches=limit_train_batches,
    max_epochs=max_epochs,
    callbacks=[EarlyStopping(patience=30, monitor="val/elbo")],
    logger=logger,
)

trainer.fit(backbone, train_loader, val_loader)
trainer.save_checkpoint(os.path.join(checkpoint_path, "final_model.ckpt"))
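
The `EarlyStopping` callback above ends training once the monitored value has failed to improve for `patience` consecutive validation checks. The following is a simplified sketch of that logic, assuming that lower values are better and ignoring options such as `min_delta`. `epochs_until_stop` and the loss values are made up for illustration.

```python
def epochs_until_stop(val_losses, patience):
    """Return the 1-based epoch at which training stops, or None if it never does."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return epoch
    return None

demo_losses = [5.0, 4.0, 3.5, 3.6, 3.7, 3.55, 3.8]
print(epochs_until_stop(demo_losses, patience=3))  # 6
```

With `patience=30` and `max_epochs=500`, training will usually end well before the epoch limit if the validation metric plateaus.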