# Training of StarDist3D

In [1]:
import sys, os
import yaml
import argparse

from pathlib import Path
import h5py
from skimage import io
import numpy as np
import time
import warnings

from tqdm import tqdm

import pandas as pd
import matplotlib.pyplot as plt

import torch

from stardist_tools import calculate_extents, Rays_GoldenSpiral
from stardist_tools.matching import matching, matching_dataset
from stardist_tools.csbdeep_utils import download_and_extract_zip_file


from src.training import train
from src.data.stardist_dataset import get_train_val_dataloaders
from src.utils import seed_all, prepare_conf, plot_img_label
from src.models.config import ConfigBase, Config3D
from src.models.stardist3d import StarDist3D

### Define training parameters

In [2]:
# Directory under which checkpoints, logs and results are written.
output_dir = '/directory/to/export/models/'

# All training options live in this single Config3D instance; edit the values
# below before running the training cell.
conf = Config3D(
    # Name to give to the model
    name                           = 'Stadist3D_v.....',

    # Random seed to use for reproducibility
    random_seed                    = 42,

    # ========================= Networks configurations ==================
    init_type                       = 'normal',
    init_gain                       = 0.02,

    backbone                        = 'resnet',

    # Subsampling factors (must be powers of 2) for each of the axes.
    # Model will predict on a subsampled grid for increased efficiency and larger field of view.
    grid                            = 'auto',

    # Anisotropy of your imaging dataset
    anisotropy                      = [4,1,1],

    # Number of image channels
    n_channel_in                   = 1,

    # Kernel size to use in neural network (use [3,3,3] for 3D)
    kernel_size                    = [3,3,3],

    # Number of ResNet blocks to use
    resnet_n_blocks                = 4,

    # None if grid = 'auto'. Otherwise, it has to match the grid.
    # NOTE(review): '' is used throughout this cell as the "unset" placeholder;
    # presumably it is normalised downstream (e.g. by prepare_conf) - confirm.
    resnet_n_downs                 = '',

    # Number of filters in the convolution layer before the final prediction layer.
    n_filter_of_conv_after_resnet  = 128,

    # Number of filters to use in the first convolution layer
    resnet_n_filter_base           = 32,

    # Number of convolution layers to use in each ResNet block.
    resnet_n_conv_per_block        = 3,
    use_batch_norm                  = False,

    #======================================================================

    # ========================= dataset ==================================
    # Path to the datasets directory. It must contain a subdirectory named 'train',
    # which itself contains two subdirectories named 'images' and 'masks'.
    # Note that if `evaluate` is set to True, it will validate during training using the data provided under datasets/val/
    data_dir                       = r'/path/to/datasets/' ,

    # Fraction (0...1) of data from the `train` folder
    # to use as validation set when the `val` folder doesn't exist
    val_size                       = 0.15,

    # Number of rays to use in the star-convex representation of nuclei shape
    n_rays                         = 96,

    # Fraction (0..1) of patches that will only be sampled from regions that contain foreground pixels.
    foreground_prob                = 0.9,

    n_classes                      = None,

    # Size of the patches to crop from the original images
    patch_size                     = [10,32,32],

    # Set True to store indices of valid patches in RAM
    cache_sample_ind               = True,

    # Set True to store training data in RAM
    cache_data                     = True,

    # Size of batches
    batch_size                     = 2,

    # Number of subprocesses to use for data loading.
    num_workers                    = 0,

    # Type of augmentation to do on training data.
    # available augmentations: none|flip|randintensity
    # You can use multiple of them, e.g. flip_randintensity
    preprocess                     = "flip_randintensity",

    # Type of augmentation to do on validation data.
    preprocess_val                 = "none",

    # Range from which to sample the weight used to multiply image intensities.
    # Associated to `randintensity` augmentation.
    intensity_factor_range         = [0.6, 2.],

    # Range from which to sample the bias to add to image intensities.
    # Associated to `randintensity` augmentation.
    intensity_bias_range           = [-0.2, 0.2],

    #======================================================================


    # ========================= Training ==================================
    # Whether to use GPU (enabled automatically when CUDA is available)
    use_gpu                        = torch.cuda.is_available(),
    #gpu_ids                       = [0],

    # Whether to use Automatic Mixed Precision
    use_amp                        = True,

    # Whether to initialize model in training mode (set optimizers, schedulers ...)
    isTrain                        = True,

    # Whether to perform evaluation during training.
    evaluate                       = False,

    # If not None, it will load state corresponding to epoch `load_epoch` and continue training from there
    load_epoch                     = '',

    # Number of training epochs
    n_epochs                       = 400,

    # Number of weight updates per epoch
    n_steps_per_epoch              = 100,

    # Epoch saving frequency
    save_epoch_freq                = 50,

    # Epoch after which to start to save the best model
    start_saving_best_after_epoch  = 5,

    lambda_prob                    = 1.0,
    lambda_dist                    = 0.2,
    lambda_reg                     = 0.0001,
    lambda_prob_class              = 1.0,

    #======================================================================


    # ========================= Optimizers ================================
    # Learning rate
    lr                             = 0.0002,

    # Parameters for Adam optimizer
    beta1                          = 0.9,
    beta2                          = 0.999,

    # Learning rate scheduler policy
    # Possible values:
    #       "none" -> keep the same learning rate for all epochs
    #       "plateau" -> Pytorch ReduceLROnPlateau scheduler
    #       "linear_decay" -> linearly decay learning rate from `lr` to 0
    #       "linear" -> linearly increase learning rate from 0 to `lr` during the first `lr_linear_n_epochs` and use `lr` for the remaining epochs
    #       "step" -> reduce learning rate by 10 every `lr_step_n_epochs`
    #       "cosine" -> Pytorch CosineAnnealingLR
    lr_policy                      = "plateau",

    # Parameters when lr_policy = "plateau"
    lr_plateau_factor              = 0.5,
    lr_plateau_threshold           = 0.0000001,
    lr_plateau_patience            = 40,
    min_lr                         = 0.00000001,

    # Parameters when lr_policy = "linear"
    lr_linear_n_epochs             = '',

    # Parameters when lr_policy = "linear_decay"
    lr_decay_iters                 = '',

    # T_max parameter of Pytorch CosineAnnealingLR when `lr_policy` = "cosine"
    T_max                          = ''
    )

# Build the output locations with os.path.join so that a trailing separator on
# `output_dir` does not produce paths containing a doubled '/'.
conf.checkpoints_dir = os.path.join(output_dir, 'checkpoints')
conf.log_dir = os.path.join(output_dir, 'logs')
conf.result_dir = os.path.join(output_dir, 'results')

### Training of StarDist3D

In [None]:
# Seed the random number generators before anything stochastic is created.
seed_all(conf.random_seed)

# Turn the user-facing configuration into the options object used by the model.
opt = prepare_conf(conf)

model = StarDist3D(opt)

print(model)

print("Total number of epochs".ljust(25), ":", model.opt.n_epochs)

# Largest receptive-field extent reported by the model for each entry.
fov = np.array( [max(r) for r in model._compute_receptive_field()] )
object_median_size = opt.extents

print("Median object size".ljust(25), ":", object_median_size)
print("Network field of view".ljust(25), ":", fov)

if any(object_median_size > fov):
    warnings.warn("WARNING: median object size larger than field of view of the neural network.")

rays = Rays_GoldenSpiral(opt.n_rays, anisotropy=opt.anisotropy)

train_dataloader, val_dataloader = get_train_val_dataloaders(opt, rays)

# `val_dataloader` may be None; every access to it below is guarded.
nb_samples_train = len(train_dataloader.dataset)
nb_samples_val = len(val_dataloader.dataset) if val_dataloader is not None else 0
total_nb_samples = nb_samples_train + nb_samples_val

print("Total nb samples: ".ljust(40), total_nb_samples)
print("Train nb samples: ".ljust(40), nb_samples_train)
print("Val nb samples: ".ljust(40), nb_samples_val)

print("Train augmentation".ljust(25), ":",  train_dataloader.dataset.opt.preprocess)
if val_dataloader is not None:
    print("Val augmentation".ljust(25), ":", val_dataloader.dataset.opt.preprocess)
else:
    # Without this guard the cell would fail with AttributeError before training starts.
    print("Val augmentation".ljust(25), ":", "n/a (no validation dataloader)")

train(model, train_dataloader, val_dataloader)

### NMS threshold optimization
This can be performed on either the training or the validation dataset:
- [Optimize based on the training dataset](#train_dataset_optimization).
- [Optimize based on the validation dataset](#val_dataset_optimization).

#### Optimization based on the training dataset <a id='train_dataset_optimization'></a>

In [None]:
# Select which checkpoint you would like to load
# Set model_epoch_to_load = 'best' or 'last'
model_epoch_to_load = "last"

# Fail fast on an unsupported value: otherwise neither branch below would run
# and `conf.load_epoch` would silently keep whatever value it had before.
if model_epoch_to_load not in ('best', 'last'):
    raise ValueError(
        f"model_epoch_to_load must be 'best' or 'last', got {model_epoch_to_load!r}"
    )

print("Using training dataset for NMS threshold optimization:")
X, Y = train_dataloader.dataset.get_all_data()

if model_epoch_to_load == 'best':
    print("Optimizing thresholds for best model...")
    conf.load_epoch = "best"
else:
    # 'last' is addressed through the configured number of training epochs.
    print("Optimizing thresholds for last model...")
    conf.load_epoch = conf.n_epochs

# Re-create the model so that it is built with the selected `load_epoch`.
model = StarDist3D(conf)

model.optimize_thresholds(X, Y)

#### Optimization based on the validation dataset <a id='val_dataset_optimization'></a>

In [None]:
# Select which checkpoint you would like to load
# Set model_epoch_to_load = 'best' or 'last'
model_epoch_to_load = "best"

# Fail fast on an unsupported value: otherwise neither branch below would run
# and `conf.load_epoch` would silently keep whatever value it had before.
if model_epoch_to_load not in ('best', 'last'):
    raise ValueError(
        f"model_epoch_to_load must be 'best' or 'last', got {model_epoch_to_load!r}"
    )

# The training cell treats `val_dataloader` as possibly None; report that
# explicitly instead of failing with an AttributeError on `.dataset`.
if val_dataloader is None:
    raise RuntimeError(
        "No validation dataloader is available; "
        "use the training dataset for NMS threshold optimization instead."
    )

print("Using validation dataset for NMS threshold optimization:")
X, Y = val_dataloader.dataset.get_all_data()

if model_epoch_to_load == 'best':
    print("Optimizing thresholds for best model...")
    conf.load_epoch = "best"
else:
    # 'last' is addressed through the configured number of training epochs.
    print("Optimizing thresholds for last model...")
    conf.load_epoch = conf.n_epochs

# Re-create the model so that it is built with the selected `load_epoch`.
model = StarDist3D(conf)