In this notebook, we use the trained noise model to guide the training of a VAE for denoising.

In [1]:
import sys
import os

import torch
from torchvision import transforms
from tifffile import imread
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

sys.path.append("../")
from noise_model.PixelCNN import PixelCNN
from HDN.models.lvae import LadderVAE
from utils.dataloaders import create_dn_loader

Load noisy images
These should be NumPy ndarrays of shape [Number, Channels, Height, Width] or [Number, Height, Width].</br>
If working with 1-dimensional signals, the shape should be [Number, Channels, 1, Width] or [Number, 1, Width].


In [2]:
observation_location = "../data/conv/observation.tif"
observation = imread(observation_location)
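
As a quick sanity check before building the loaders, the shape of the loaded array can be compared against the layouts listed above. The sketch below is illustrative only: `describe_layout` is a hypothetical helper, not part of this repository, and it works on a plain shape tuple such as `observation.shape`.

```python
def describe_layout(shape):
    """Map a shape tuple to one of the layouts accepted by this notebook.

    Hypothetical helper for illustration; not part of the repository.
    """
    if len(shape) == 4:
        layout = "[Number, Channels, Height, Width]"
    elif len(shape) == 3:
        layout = "[Number, Height, Width]"
    else:
        raise ValueError(f"Expected a 3- or 4-dimensional array, got shape {shape}")
    # 1-dimensional signals are stored with a singleton height axis.
    is_1d_signal = shape[-2] == 1
    return layout, is_1d_signal


# Example with made-up shapes; in the notebook, pass observation.shape instead.
print(describe_layout((100, 256, 256)))
print(describe_layout((100, 1, 1, 512)))
```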

Load trained noise model and disable gradients


In [3]:
noise_model_location = "../nm_checkpoint/conv/final_params.ckpt"
noise_model = PixelCNN.load_from_checkpoint(noise_model_location).eval()

for param in noise_model.parameters():
    param.requires_grad = False

Create data loaders and get the shape, mean and standard deviation of the noisy images.</br>
Use the `transform` argument to apply a torchvision transformation to images as they are loaded, e.g. `transform = transforms.RandomCrop(64)`.

In [4]:
transform = None
dn_train_loader, dn_val_loader, img_shape, data_mean, data_std = create_dn_loader(
    observation, batch_size=8, split=0.8, transform=transform
)
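
To make the `split` and `batch_size` arguments concrete, the sketch below estimates how many batches each loader yields per epoch. It assumes the training set receives the `split` fraction of the images, rounded down, and that a final partial batch is kept. `create_dn_loader` may round or batch differently, so treat the numbers as approximate. The image count of 100 is made up.

```python
import math


def batches_per_epoch(n_images, batch_size=8, split=0.8):
    """Estimate train/validation batch counts under the assumptions stated above."""
    n_train = int(n_images * split)  # assumption: round down
    n_val = n_images - n_train
    # Assumption: a final partial batch is kept rather than dropped.
    return math.ceil(n_train / batch_size), math.ceil(n_val / batch_size)


# Hypothetical dataset of 100 images.
print(batches_per_epoch(100))
```

The trainer below sets `log_every_n_steps=len(dn_train_loader)`, which means metrics are logged roughly once per training epoch.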

Set denoiser checkpoint directory


In [5]:
dn_checkpoint_path = "../dn_checkpoint/conv"

Initialise the trainer and the denoising VAE.</br>


The default hyperparameters should work for most cases, but if training takes too long or an out-of-memory error is encountered, `num_latents` can be decreased to `4` to reduce the size of the network while still getting good results. Alternatively, better performance could be achieved by increasing `num_latents` to `8` and `z_dims` to `[64] * num_latents`.</br>
Sometimes, increasing `dropout` to `0.1` or `0.2` can help when working with a limited amount of training data.</br>
The `free_bits` value sets the minimum amount of information expressed by the latent variables, information that will therefore not be modelled by the decoder/noise model. Since our decoder/noise model is pretrained and frozen, the information it can model is predetermined, so a `free_bits` value greater than zero should not be necessary. However, if `kl_loss` drops very quickly in TensorBoard and plateaus below ~1e-2, increasing `free_bits` to between `0.5` and `1.0` can help prevent the objective from getting stuck in this undesirable equilibrium.
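
The effect of `free_bits` can be illustrated with a common formulation, in which the KL term of each latent layer is clamped from below so that the optimiser gains nothing by pushing it under the threshold. This is a sketch of that general idea, assuming per-layer clamping; the `LadderVAE` implementation used here may apply it differently (for example per channel, or after batch averaging). The KL values are made up.

```python
def kl_with_free_bits(kl_per_layer, free_bits=0.0):
    """Sum per-layer KL values after clamping each one to at least `free_bits`.

    Illustrative only: assumes clamping is applied per latent layer.
    """
    return sum(max(kl, free_bits) for kl in kl_per_layer)


# Hypothetical KL values for six latent layers, two of them nearly collapsed.
kl_values = [2.0, 1.2, 0.8, 0.3, 0.005, 0.001]

print(kl_with_free_bits(kl_values, free_bits=0.0))  # plain sum
print(kl_with_free_bits(kl_values, free_bits=0.5))  # small terms raised to 0.5
```

With a threshold of 0.5, the three smallest terms contribute a constant to the loss, so the encoder is no longer rewarded for shrinking them further.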


In [None]:
use_cuda = torch.cuda.is_available()
trainer = pl.Trainer(
    default_root_dir=dn_checkpoint_path,
    accelerator="gpu" if use_cuda else "cpu",
    devices=[1] if use_cuda else 1,  # GPU index 1; use [0] on a single-GPU machine
    max_epochs=500,
    logger=TensorBoardLogger(dn_checkpoint_path),
    log_every_n_steps=len(dn_train_loader),
    callbacks=[LearningRateMonitor(logging_interval="epoch")],
)

num_latents = 6
z_dims = [32] * num_latents
vae = LadderVAE(
    z_dims=z_dims,
    noiseModel=noise_model,
    img_shape=img_shape,
    gaussian_noise_std=None,
    use_uncond_mode_at=[],
    dropout=0.0,
    free_bits=0.0,
    data_mean=data_mean,
    data_std=data_std,
)
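
The alternative settings described above can be collected in one place. The sketch below builds the `z_dims` list for each of the three configurations mentioned in the text. The preset names and the `make_z_dims` helper are hypothetical and exist only for this illustration.

```python
# Preset names are hypothetical; the values are those suggested in the text above.
PRESETS = {
    "small": {"num_latents": 4, "z_dim": 32},    # faster, lower memory use
    "default": {"num_latents": 6, "z_dim": 32},  # values used in this notebook
    "large": {"num_latents": 8, "z_dim": 64},    # potentially better results
}


def make_z_dims(preset):
    """Return one latent dimension per hierarchy level for the given preset."""
    cfg = PRESETS[preset]
    return [cfg["z_dim"]] * cfg["num_latents"]


for name in PRESETS:
    print(name, make_z_dims(name))
```

The resulting list can be passed as the `z_dims` argument when constructing the `LadderVAE`.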

Train and save final parameters</br>
Training logs can be monitored in TensorBoard. Run the two cells below to activate it in the notebook. Alternatively, open a terminal, activate an environment with TensorBoard installed and enter `tensorboard --logdir path/to/autonoise/dn_checkpoint/`, then open a browser and go to `localhost:6006`.

The main metric to monitor here is the validation reconstruction loss, `val/reconstruction_loss`. This should go down sharply at first, then level off. The KL divergence, `kl_loss`, may go either up or down. The evidence lower bound, `elbo`, is the sum of these two losses, and training can be stopped once both have plateaued.
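
One simple way to judge whether a logged loss has plateaued is to compare its average over the most recent epochs with the average over the epochs just before. The sketch below does this on a plain list of per-epoch values. `has_plateaued`, the window size and the tolerance are all hypothetical choices, not something provided by this repository or by Lightning.

```python
def has_plateaued(values, window=20, rel_tol=0.01):
    """Return True if the mean of the last `window` values differs from the
    mean of the preceding `window` values by at most `rel_tol` (relative)."""
    if len(values) < 2 * window:
        return False  # not enough history to compare two windows
    recent = sum(values[-window:]) / window
    previous = sum(values[-2 * window:-window]) / window
    return abs(recent - previous) <= rel_tol * abs(previous)


# Made-up curves: one still falling, one that has levelled off.
falling = [100.0 / (epoch + 1) for epoch in range(40)]
flat = [1.0 + 0.001 * (epoch % 2) for epoch in range(40)]

print(has_plateaued(falling))
print(has_plateaued(flat))
```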

In [7]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir ../dn_checkpoint

In [9]:
trainer.fit(vae, dn_train_loader, dn_val_loader)
trainer.save_checkpoint(os.path.join(dn_checkpoint_path, "final_params.ckpt"))

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | first_bottom_up  | Sequential           | 75.8 K
1 | top_down_layers  | ModuleList           | 4.1 M 
2 | bottom_up_layers | ModuleList           | 2.7 M 
3 | final_top_down   | Sequential           | 412 K 
4 | likelihood       | NoiseModelLikelihood | 5.3 M 
----------------------------------------------------------
7.3 M     Trainable params
5.3 M     Non-trainable params
12.6 M    Total params
50.546    Total estimated model params size (MB)


Epoch 197: 100%|██████████| 6/6 [00:02<00:00,  2.82it/s, v_num=1]          Epoch 00198: reducing learning rate of group 0 to 1.5000e-04.
Epoch 214: 100%|██████████| 6/6 [00:02<00:00,  2.82it/s, v_num=1]Epoch 00215: reducing learning rate of group 0 to 7.5000e-05.
Epoch 231: 100%|██████████| 6/6 [00:02<00:00,  2.83it/s, v_num=1]Epoch 00232: reducing learning rate of group 0 to 3.7500e-05.
Epoch 262: 100%|██████████| 6/6 [00:02<00:00,  2.82it/s, v_num=1]Epoch 00263: reducing learning rate of group 0 to 1.8750e-05.
Epoch 279: 100%|██████████| 6/6 [00:02<00:00,  2.82it/s, v_num=1]Epoch 00280: reducing learning rate of group 0 to 9.3750e-06.
Epoch 290: 100%|██████████| 6/6 [00:02<00:00,  2.83it/s, v_num=1]Epoch 00291: reducing learning rate of group 0 to 4.6875e-06.
Epoch 301: 100%|██████████| 6/6 [00:02<00:00,  2.81it/s, v_num=1]Epoch 00302: reducing learning rate of group 0 to 2.3437e-06.
Epoch 312: 100%|██████████| 6/6 [00:02<00:00,  2.82it/s, v_num=1]Epoch 00313: reducing learning rate 

`Trainer.fit` stopped: `max_epochs=500` reached.


Epoch 499: 100%|██████████| 6/6 [00:02<00:00,  2.35it/s, v_num=1]
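
Each learning-rate message in the log above reports a value half the size of the previous one, which is consistent with a reduce-on-plateau scheduler using a factor of 0.5. The sketch below simply reproduces that sequence by repeated halving, starting from the first value shown in the log. The factor is inferred from the output rather than read from the model's configuration.

```python
def halving_schedule(first_lr, n_steps, factor=0.5):
    """Return `n_steps` learning rates, each `factor` times the previous one."""
    return [first_lr * factor**i for i in range(n_steps)]


# 1.5e-4 is the first reduced value shown in the log.
for lr in halving_schedule(1.5e-4, 7):
    print(f"{lr:.4e}")
```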
