# DivNoising - Training
This notebook contains an example of how to train a DivNoising VAE. This requires a noise model (a model of the imaging noise), which can either be measured from calibration data or estimated from the raw noisy images themselves. If you haven't done so, please first run 'Convallaria-CreateNoiseModel.ipynb', which will download the data and create a noise model.

In [1]:
# We import all our dependencies.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import Dataset, DataLoader
from torch.nn import init
import os
import glob
from tifffile import imread
from matplotlib import pyplot as plt

import sys
sys.path.append('../../')
from divnoising import dataLoader
from divnoising import utils
from divnoising import training_multi
from nets import model_multi
from divnoising import histNoiseModel
from divnoising.gaussianMixtureNoiseModel import GaussianMixtureNoiseModel

import urllib.request  # urlretrieve lives in the urllib.request submodule
import zipfile

from tqdm import tqdm

In [2]:
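# Use the first GPU if one is available; otherwise fall back to the CPU (training will be much slower).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")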

### Specify ```path``` to load data
Your data should be stored in the directory indicated by ```path```.

In [3]:
# create a folder for our data
if not os.path.isdir('./data'):
    os.mkdir('data')

# check if data has been downloaded already
zipPath="data/BSD68_reproducibility.zip"
if not os.path.exists(zipPath):
    #download and unzip data
    data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/pbj89sV6n6SyM29/download', zipPath)
    with zipfile.ZipFile(zipPath, 'r') as zip_ref:
        zip_ref.extractall("data")

In [4]:
signal = np.load('data/BSD68_reproducibility_data/test/bsd68_groundtruth.npy', allow_pickle=True)
train_images_ = signal[:int(0.85*signal.shape[0])]
val_images_ = signal[int(0.85*signal.shape[0]):]

In [5]:
patch_size = 128
img_width = 321
img_height = 481
num_patches = int(float(img_width*img_height)/float(patch_size**2)*2)

x_train_crops = utils.extract_patches(train_images_, patch_size, num_patches)
x_val_crops = utils.extract_patches(val_images_, patch_size, num_patches)
train_images = x_train_crops
val_images = x_val_crops



100%|██████████| 57/57 [00:00<00:00, 758.78it/s]
100%|██████████| 11/11 [00:00<00:00, 870.83it/s]


In [6]:
gaussian_noise_std = 25
train_images = train_images+np.random.randn(train_images.shape[0], train_images.shape[1], train_images.shape[2])*gaussian_noise_std
val_images = val_images+np.random.randn(val_images.shape[0], val_images.shape[1], val_images.shape[2])*gaussian_noise_std

In [7]:
print("Shape of training images:", train_images.shape, "Shape of validation images:", val_images.shape)
train_images = utils.augment_data(train_images)

Shape of training images: (1026, 128, 128) Shape of validation images: (198, 128, 128)
Raw image size after augmentation (8208, 128, 128)


# Training Data Preparation

Before training, we run a few preprocessing steps to prepare the data.

### Data preprocessing
We first divide the data into training and validation sets, with 85% of the images allocated to the training set and the rest to the validation set. Then we augment the training data 8-fold using 90 degree rotations and flips.
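
As a rough illustration of the split and the 8-fold augmentation described above, here is a minimal NumPy sketch. The helper names `split_train_val` and `augment_8fold` are hypothetical and are not part of the `divnoising` package; `utils.augment_data` may order or build the augmented copies differently.

```python
import numpy as np

def split_train_val(images, train_fraction=0.85):
    """Split along the first axis; the first `train_fraction` goes to training."""
    n_train = int(train_fraction * len(images))
    return images[:n_train], images[n_train:]

def augment_8fold(patches):
    """Return all 4 rotations of square patches plus a flipped copy of each.

    patches: array of shape (N, S, S).
    """
    rotated = np.concatenate(
        [np.rot90(patches, k, axes=(1, 2)) for k in range(4)], axis=0)
    flipped = np.flip(rotated, axis=1)
    return np.concatenate([rotated, flipped], axis=0)

# Toy data standing in for the real images (assumption: 68 square images).
toy_images = np.random.rand(68, 16, 16)
train, val = split_train_val(toy_images)
print(len(train), len(val))        # 57 11
print(augment_8fold(train).shape)  # (456, 16, 16)
```

With 68 images the split yields 57 training and 11 validation images, which matches the progress bars printed by the patch extraction cell.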

We extract overlapping patches of size ```patch_size x patch_size``` from training and validation images. Specify the parameter ```patch_size```. The number of patches to be extracted is automatically determined depending on the size of images.
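
The patch count used in this notebook can be checked by hand. The sketch below reproduces the `num_patches` formula from the cell above and pairs it with a simple random-crop extractor. The extractor is an assumption about how such a helper might work, not the actual code of `utils.extract_patches`.

```python
import numpy as np

def patches_per_image(img_height, img_width, patch_size):
    """Same formula as in the notebook: twice the image-to-patch area ratio."""
    return int(img_height * img_width / patch_size**2 * 2)

def extract_random_patches(images, patch_size, num_patches, seed=0):
    """Crop `num_patches` random square patches from every 2D image."""
    rng = np.random.default_rng(seed)
    crops = []
    for img in images:
        h, w = img.shape
        for _ in range(num_patches):
            top = rng.integers(0, h - patch_size + 1)
            left = rng.integers(0, w - patch_size + 1)
            crops.append(img[top:top + patch_size, left:left + patch_size])
    return np.stack(crops)

n = patches_per_image(481, 321, 128)
print(n)  # 18

# Toy images with the BSD68 dimensions (random content, for illustration only).
toy_images = [np.random.rand(481, 321), np.random.rand(321, 481)]
print(extract_random_patches(toy_images, 128, n).shape)  # (36, 128, 128)
```

With 18 patches per image, 57 training images give 1026 patches and 11 validation images give 198, which agrees with the shapes printed in the notebook.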

Finally, we compute the mean and standard deviation of our combined train and validation sets and do some additional preprocessing.
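
The statistics step can be sketched as follows, assuming the mean and standard deviation are taken over all pixels of both sets together. The functions `mean_std_combined` and `normalize` are hypothetical names used for illustration; the real `utils.getMeanStdData` may differ in detail.

```python
import numpy as np

def mean_std_combined(train, val):
    """Mean and standard deviation over all pixels of both sets."""
    combined = np.concatenate([train, val], axis=0)
    return combined.mean(), combined.std()

def normalize(x, mean, std):
    """Shift and scale data to zero mean and unit variance."""
    return (x - mean) / std

# Toy patches standing in for the noisy training and validation data.
toy_train = np.random.rand(20, 8, 8) * 255
toy_val = np.random.rand(5, 8, 8) * 255
mean, std = mean_std_combined(toy_train, toy_val)

# Cast to float32 and add a channel axis, as done in the next cell.
x = normalize(toy_train, mean, std).astype(np.float32)[:, np.newaxis]
print(x.shape)  # (20, 1, 8, 8)
```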

In [8]:
data_mean, data_std = utils.getMeanStdData(train_images, val_images)

x_train, x_val = utils.convertToFloat32(train_images, val_images)
x_train_extra_axis = x_train[:,np.newaxis]
x_val_extra_axis = x_val[:,np.newaxis]

x_train_tensor = utils.convertNumpyToTensor(x_train_extra_axis)
x_val_tensor = utils.convertNumpyToTensor(x_val_extra_axis)

In [9]:
print("Shape of training tensor:", x_train_tensor.shape)

Shape of training tensor: torch.Size([8208, 1, 128, 128])


# Configure DivNoising model

Here we specify some parameters of our DivNoising network needed for training. <br> 
The parameter <code>z_dim</code> specifies the size of the bottleneck dimension corresponding to each pixel. <br> 
The parameter <code>in_channels</code> specifies the number of input channels which for this dataset is 1.<br> 
Currently only single-channel input is supported, but this may be extended to an arbitrary number of channels in the future. <br>
The parameter <code>init_filters</code> specifies the number of filters in the first layer of the network. <br>
The parameter <code>n_depth</code> specifies the depth of the network. <br>
The parameter <code>batch_size</code> specifies the batch size used for training. <br>
The parameter <code>n_filters_per_depth</code> specifies the number of convolutions per depth. <br>
The parameter <code>directory_path</code> specifies the directory where the model will be saved. <br>
The parameter <code>n_epochs</code> specifies the number of training epochs. <br>
The parameter <code>lr</code> specifies the learning rate. <br>
The parameter <code>val_loss_patience</code> specifies the number of epochs after which training will be terminated if the validation loss does not decrease by a factor of 1e-6. <br>
The parameter <code>noiseModel</code> is the noise model you want to use. Run the notebook ```Convallaria-CreateNoiseModel.ipynb``` if you have not yet generated the noise model for this dataset. If set to ```None```, a Gaussian noise model is used.<br>
The parameter <code>gaussian_noise_std</code> is the standard deviation of the Gaussian noise model. This should only be set if <code>noiseModel</code> is ```None```. Otherwise, if you have created a noise model already, set it to ```None```. <br>
The parameter <code>model_name</code> specifies the name of the model with which the weights will be saved for prediction later.

__Note:__ We observed good performance of the DivNoising network for most datasets with the default settings in the next cell. However, we also observed that sensibly adjusting the parameters can further improve performance.
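
To make the role of <code>val_loss_patience</code> more concrete, here is a minimal sketch of patience-based early stopping. The class `PatienceTracker` and its `min_delta` argument are hypothetical; how `training_multi.trainNetwork` applies the 1e-6 threshold internally is not shown in this notebook, so treat the comparison below as one plausible reading.

```python
class PatienceTracker:
    """Counts epochs without improvement of the validation loss."""

    def __init__(self, patience, min_delta=1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.counter = 0

    def update(self, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            # New best validation loss: reset the counter.
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

tracker = PatienceTracker(patience=2)
stops = [tracker.update(loss) for loss in [1.0, 0.9, 0.95, 0.96]]
print(stops)  # [False, False, False, True]
```

In the training log, the `Patience` value plays the role of `counter` here: it resets to 0 whenever a new minimum validation loss is reached.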

In [10]:
in_channels = 1
init_filters = 10
batch_size=32
directory_path = "./"
n_epochs = int(22000000/(x_train_tensor.shape[0])) # A heuristic to set the number of epochs (about 22 million training patches seen in total)
lr=0.001
val_loss_patience = 100
gaussian_noise_std = gaussian_noise_std # reuse the noise level set when the images were corrupted above
noise_model_params= None
noiseModel = None
model_name = "bsd68-10filters-"

# Train network

__Note:__ We observed that for certain datasets the KL loss goes towards 0. This phenomenon is called ```posterior collapse``` and is undesirable.
We prevent it by aborting and restarting the training once the KL loss drops below a threshold (```kl_min```).
An alternative approach is a technique called *KL annealing*, where we gradually increase the weight on the KL divergence loss term from 0 to 1 over a number of steps.
This can be activated by setting the parameter ```kl_annealing``` to ```True```. <br>
The parameter ```kl_start``` specifies the epoch when KL annealing will start. <br>
The parameter ```kl_annealtime``` specifies until which epoch KL annealing will be operational. <br>
If the parameter ```kl_annealing``` is set to ```False```, the values of ```kl_start``` and ```kl_annealtime``` are ignored.
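
The following sketch shows one way such a KL weight schedule could look. It assumes a linear ramp and treats ```kl_annealtime``` as the number of epochs over which the weight rises after ```kl_start```; both are assumptions, and the function names are hypothetical rather than taken from `training_multi`.

```python
def kl_weight(epoch, kl_annealing, kl_start, kl_annealtime):
    """Weight on the KL term for a given epoch (linear ramp, clamped to [0, 1])."""
    if not kl_annealing:
        return 1.0
    weight = (epoch - kl_start) / kl_annealtime
    return min(max(weight, 0.0), 1.0)

def total_loss(recon_loss, kl_loss, weight):
    """Training loss as reconstruction loss plus weighted KL loss."""
    return recon_loss + weight * kl_loss

def has_collapsed(kl_loss, kl_min=1e-5):
    """Flag posterior collapse when the KL loss falls below the threshold."""
    return kl_loss < kl_min

for epoch in range(5):
    print(epoch, kl_weight(epoch, True, kl_start=0, kl_annealtime=3))

# Without annealing the weight is always 1, as in the log of this notebook.
print(round(total_loss(0.993, 0.186, kl_weight(1, False, 0, 3)), 3))  # 1.179
```

The last line reproduces the first epoch of the training log, where the training loss (1.179) is the sum of the reconstruction loss (0.993) and the KL loss (0.186) at `kl_weight: 1.0`.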

In [None]:
train_dataset = dataLoader.MyDataset(x_train_tensor,x_train_tensor)
val_dataset = dataLoader.MyDataset(x_val_tensor,x_val_tensor)

trainHist, reconHist, klHist, valHist = None, None, None, None
attempts=0
while trainHist is None:
    attempts+=1
    print('start training: attempt '+ str(attempts))
    vae = model_multi.MultiVAE(in_channels=in_channels,start_filts=init_filters,depth=3)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    trainHist, reconHist, klHist, valHist = training_multi.trainNetwork(net=vae, train_loader=train_loader, 
                                                                     val_loader=val_loader,
                                                                     device=device,directory_path=directory_path,
                                                                     model_name=model_name,
                                                                     n_epochs=n_epochs, batch_size=batch_size,lr=lr,
                                                                     val_loss_patience = val_loss_patience,
                                                                     kl_annealing = False,
                                                                     kl_start = 0, 
                                                                     kl_annealtime = 3,
                                                                     kl_min=1e-5,
                                                                     data_mean =data_mean,data_std=data_std, 
                                                                     noiseModel = noiseModel,
                                                                     gaussian_noise_std = gaussian_noise_std)
    

start training: attempt 1
Epoch[1/2680] Training Loss: 1.179 Reconstruction Loss: 0.993 KL Loss: 0.186
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
saving ./bsd68-10filters-best_vae.net
Patience: 0 Validation Loss: 0.9292399287223816 Min validation loss: 0.9292399287223816
Time for epoch: 21seconds
Est remaining time: 15:37:39 or 56259 seconds
----------------------------------------
Epoch[2/2680] Training Loss: 0.829 Reconstruction Loss: 0.605 KL Loss: 0.224
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
saving ./bsd68-10filters-best_vae.net
Patience: 0 Validation Loss: 0.8060247302055359 Min validation loss: 0.8060247302055359
Time for epoch: 22seconds
Est remaining time: 16:21:56 or 58916 seconds
----------------------------------------
Epoch[3/2680] Training Loss: 0.744 Reconstruction Loss: 0.473 KL Loss: 0.271
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
saving ./bsd68-10filters-best_vae.net
Patience: 0 Validation Loss: 0.7238110899925232 Min validation loss:

Epoch[24/2680] Training Loss: 0.507 Reconstruction Loss: 0.174 KL Loss: 0.333
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 1 Validation Loss: 0.5186618566513062 Min validation loss: 0.5168175101280212
Time for epoch: 22seconds
Est remaining time: 16:13:52 or 58432 seconds
----------------------------------------
Epoch[25/2680] Training Loss: 0.507 Reconstruction Loss: 0.175 KL Loss: 0.333
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 2 Validation Loss: 0.5202991366386414 Min validation loss: 0.5168175101280212
Time for epoch: 23seconds
Est remaining time: 16:57:45 or 61065 seconds
----------------------------------------
Epoch[26/2680] Training Loss: 0.507 Reconstruction Loss: 0.174 KL Loss: 0.333
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 3 Validation Loss: 0.5207587480545044 Min validation loss: 0.5168175101280212
Time for epoch: 22seconds
Est remaining time: 16:13:08 or 58388 seconds
----------------------------------------
Epoc

saving ./bsd68-10filters-best_vae.net
Patience: 0 Validation Loss: 0.5113505721092224 Min validation loss: 0.5113505721092224
Time for epoch: 22seconds
Est remaining time: 16:05:04 or 57904 seconds
----------------------------------------
Epoch[49/2680] Training Loss: 0.502 Reconstruction Loss: 0.172 KL Loss: 0.330
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 1 Validation Loss: 0.5123568177223206 Min validation loss: 0.5113505721092224
Time for epoch: 22seconds
Est remaining time: 16:04:42 or 57882 seconds
----------------------------------------
Epoch[50/2680] Training Loss: 0.502 Reconstruction Loss: 0.172 KL Loss: 0.330
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
saving ./bsd68-10filters-best_vae.net
Patience: 0 Validation Loss: 0.5111391544342041 Min validation loss: 0.5111391544342041
Time for epoch: 22seconds
Est remaining time: 16:04:20 or 57860 seconds
----------------------------------------
Epoch[51/2680] Training Loss: 0.502 Reconstruction Loss: 0.

Epoch[73/2680] Training Loss: 0.501 Reconstruction Loss: 0.171 KL Loss: 0.330
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 11 Validation Loss: 0.5120495557785034 Min validation loss: 0.5092954635620117
Time for epoch: 22seconds
Est remaining time: 15:55:54 or 57354 seconds
----------------------------------------
Epoch[74/2680] Training Loss: 0.500 Reconstruction Loss: 0.171 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 12 Validation Loss: 0.5132355093955994 Min validation loss: 0.5092954635620117
Time for epoch: 22seconds
Est remaining time: 15:55:32 or 57332 seconds
----------------------------------------
Epoch[75/2680] Training Loss: 0.501 Reconstruction Loss: 0.172 KL Loss: 0.330
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 13 Validation Loss: 0.517033576965332 Min validation loss: 0.5092954635620117
Time for epoch: 22seconds
Est remaining time: 15:55:10 or 57310 seconds
----------------------------------------
Ep

Epoch[98/2680] Training Loss: 0.500 Reconstruction Loss: 0.171 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 1 Validation Loss: 0.508832573890686 Min validation loss: 0.5081331729888916
Time for epoch: 22seconds
Est remaining time: 15:46:44 or 56804 seconds
----------------------------------------
Epoch[99/2680] Training Loss: 0.500 Reconstruction Loss: 0.171 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 2 Validation Loss: 0.5152464509010315 Min validation loss: 0.5081331729888916
Time for epoch: 22seconds
Est remaining time: 15:46:22 or 56782 seconds
----------------------------------------
Epoch[100/2680] Training Loss: 0.500 Reconstruction Loss: 0.171 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 3 Validation Loss: 0.5091010928153992 Min validation loss: 0.5081331729888916
Time for epoch: 22seconds
Est remaining time: 15:46:00 or 56760 seconds
----------------------------------------
Epoc

Patience: 7 Validation Loss: 0.509434700012207 Min validation loss: 0.5075125694274902
Time for epoch: 22seconds
Est remaining time: 15:37:56 or 56276 seconds
----------------------------------------
Epoch[123/2680] Training Loss: 0.499 Reconstruction Loss: 0.170 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 8 Validation Loss: 0.5115122199058533 Min validation loss: 0.5075125694274902
Time for epoch: 22seconds
Est remaining time: 15:37:34 or 56254 seconds
----------------------------------------
Epoch[124/2680] Training Loss: 0.499 Reconstruction Loss: 0.170 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 9 Validation Loss: 0.5131334066390991 Min validation loss: 0.5075125694274902
Time for epoch: 22seconds
Est remaining time: 15:37:12 or 56232 seconds
----------------------------------------
Epoch[125/2680] Training Loss: 0.499 Reconstruction Loss: 0.170 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Pa

Epoch[147/2680] Training Loss: 0.499 Reconstruction Loss: 0.170 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 21 Validation Loss: 0.5093660950660706 Min validation loss: 0.5069770216941833
Time for epoch: 22seconds
Est remaining time: 15:28:46 or 55726 seconds
----------------------------------------
Epoch[148/2680] Training Loss: 0.499 Reconstruction Loss: 0.170 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 22 Validation Loss: 0.509712815284729 Min validation loss: 0.5069770216941833
Time for epoch: 22seconds
Est remaining time: 15:28:24 or 55704 seconds
----------------------------------------
Epoch[149/2680] Training Loss: 0.499 Reconstruction Loss: 0.170 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 23 Validation Loss: 0.5073207020759583 Min validation loss: 0.5069770216941833
Time for epoch: 22seconds
Est remaining time: 15:28:02 or 55682 seconds
----------------------------------------

Patience: 9 Validation Loss: 0.5100373029708862 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:19:58 or 55198 seconds
----------------------------------------
Epoch[172/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 10 Validation Loss: 0.5126010775566101 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:19:36 or 55176 seconds
----------------------------------------
Epoch[173/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 11 Validation Loss: 0.5111594796180725 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:19:14 or 55154 seconds
----------------------------------------
Epoch[174/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net

Epoch[196/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 34 Validation Loss: 0.5115024447441101 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:10:48 or 54648 seconds
----------------------------------------
Epoch[197/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 35 Validation Loss: 0.5097360610961914 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:10:26 or 54626 seconds
----------------------------------------
Epoch[198/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 36 Validation Loss: 0.5105729103088379 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:10:04 or 54604 seconds
---------------------------------------

Epoch[221/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 59 Validation Loss: 0.5114241242408752 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:01:38 or 54098 seconds
----------------------------------------
Epoch[222/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 60 Validation Loss: 0.5081404447555542 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:01:16 or 54076 seconds
----------------------------------------
Epoch[223/2680] Training Loss: 0.498 Reconstruction Loss: 0.169 KL Loss: 0.329
kl_weight: 1.0
saving ./bsd68-10filters-last_vae.net
Patience: 61 Validation Loss: 0.5094484686851501 Min validation loss: 0.5055508613586426
Time for epoch: 22seconds
Est remaining time: 15:00:54 or 54054 seconds
---------------------------------------

# Plotting losses

In [None]:
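# Load the loss histories saved during training.
trainHist = np.load(os.path.join(directory_path, "train_loss.npy"))
reconHist = np.load(os.path.join(directory_path, "train_reco_loss.npy"))
klHist = np.load(os.path.join(directory_path, "train_kl_loss.npy"))
valHist = np.load(os.path.join(directory_path, "val_loss.npy"))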

In [None]:
plt.figure(figsize=(18, 3))
plt.subplot(1,3,1)
plt.plot(trainHist,label='training')
plt.plot(valHist,label='validation')
plt.xlabel("epochs")
plt.ylabel("loss")
plt.legend()

plt.subplot(1,3,2)
plt.plot(reconHist,label='training')
plt.xlabel("epochs")
plt.ylabel("reconstruction loss")
plt.legend()

plt.subplot(1,3,3)
plt.plot(klHist,label='training')
plt.xlabel("epochs")
plt.ylabel("KL loss")
plt.legend()
plt.show()