```{eval-rst}
.. role:: nge-green
```
{nge-green}`Training a SKOOTS Model`
===================================

This guide is for people unfamiliar with YACS and the command line. The SKOOTS library provides the pre-written functions needed to train a segmentation model easily. A training run is defined by a configuration file and launched from the command line. We will first show you how to prepare your data, then how to construct a configuration file, and finally how to train using the command line. For details on the training script, how it works, and how to hack it, please see the detailed training example.

The built-in training scripts use PyTorch's DistributedDataParallel by default with the NCCL communication backend, so they unfortunately require an NVIDIA GPU.
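Because the default training path depends on CUDA and NCCL, it can be worth checking a machine before launching a run. The sketch below is not part of SKOOTS. `check_ddp_requirements` is a hypothetical helper that only uses standard PyTorch queries (`torch.cuda.is_available`, `torch.cuda.device_count`, `torch.distributed.is_nccl_available`), and it degrades gracefully if PyTorch is not installed.

```python
from typing import Dict, Union


def check_ddp_requirements() -> Dict[str, Union[bool, int]]:
    """Hypothetical helper: report whether this machine looks able to run NCCL-backed DDP."""
    report: Dict[str, Union[bool, int]] = {
        "torch_installed": False,
        "cuda_available": False,
        "nccl_available": False,
        "num_gpus": 0,
    }
    try:
        import torch
        import torch.distributed as dist
    except ImportError:
        return report  # without PyTorch nothing else can be checked

    report["torch_installed"] = True
    report["cuda_available"] = torch.cuda.is_available()
    report["num_gpus"] = torch.cuda.device_count()
    # is_nccl_available() is only meaningful when the distributed package is built in
    report["nccl_available"] = dist.is_available() and dist.is_nccl_available()
    return report


if __name__ == "__main__":
    print(check_ddp_requirements())
```

If `num_gpus` reports fewer devices than the `_C.SYSTEM.NUM_GPUS` value shown in the configuration section, you will likely need to lower that setting.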

## Prepare Your Data
We start by preparing our data. SKOOTS expects training images to be large TIFF images with associated instance masks. SKOOTS associates each mask with its image by filename and tag. Images may be named whatever you'd like, for example: ```training_data.tif```. The associated labels must then be named ```training_data.labels.tif```. The background of each label image must be zero.
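The naming convention and the zero-background rule can be expressed in a few lines of Python. This is an illustrative sketch only. `label_path_for` and `check_mask` are hypothetical helpers, not SKOOTS functions. Apart from the rules stated above, the checks also assume that instance labels are non-negative integers.

```python
from pathlib import Path
from typing import List

import numpy as np

LABEL_SUFFIX = ".labels.tif"  # tag used to associate a mask with its image


def label_path_for(image_path: Path) -> Path:
    """Map 'training_data.tif' to 'training_data.labels.tif' in the same folder."""
    return image_path.with_name(image_path.stem + LABEL_SUFFIX)


def check_mask(mask: np.ndarray) -> List[str]:
    """Return a list of problems with an instance mask (an empty list means it looks OK)."""
    problems = []
    if not np.issubdtype(mask.dtype, np.integer):
        problems.append(f"dtype {mask.dtype} is not an integer type")
    if mask.min() < 0:
        problems.append("mask contains negative label values")
    if not (mask == 0).any():
        problems.append("mask has no background (zero-valued) pixels")
    return problems


if __name__ == "__main__":
    print(label_path_for(Path("train/training_data.tif")))
```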

## Precompute Ground Truth Skeletons
Once our training data is in an appropriate place, we must precompute the ground truth skeletons. This generates a separate file for each mask and only needs to happen once. SKOOTS provides the utility functions for creating the skeletons. However, you must write a short script for your own data, such as the one below.

```python
import glob
from typing import Dict

import numpy as np
import skimage.io as io
import torch

from skoots.train.generate_skeletons import calculate_skeletons

training_directory = './train'  # base directory containing all our data.

# Sometimes skeletons look 'weird' due to anisotropy.
# Scaling the image by a predetermined amount can help with this.
# This may need trial and error to get skeletons that are easily predictable.
scale_factors = torch.tensor([1, 1, 0.5])

# Loop over all the mask files
for f in glob.glob(training_directory + '/*.labels.tif'):
    masks: np.ndarray = io.imread(f)  # will read in as a uint16 numpy array with shape [Z, X, Y]
    masks = torch.from_numpy(masks.astype(np.int32))  # PyTorch (at least older versions) cannot convert uint16, so use 32-bit int instead.
    masks = masks.permute(1, 2, 0)  # the script expects the tensor to be [X, Y, Z]
    skeletons: Dict[int, torch.Tensor] = calculate_skeletons(masks, scale_factors)  # calculate the skeletons

    f = f[:-len('.labels.tif')]  # get rid of '.labels.tif'
    torch.save(skeletons, f + '.skeletons.trch')  # IMPORTANT! skeletons must be saved with this extension and tag!!!
```

We now have three files for each training image: the image ```train_image.tif```, the instance masks ```train_image.labels.tif```, and the precomputed skeletons ```train_image.skeletons.trch```. Precomputing the skeletons can be expensive, so doing it once up front saves training time. All three files must be present in the same folder for training! In this case, let's put them in ```./train```. We may do the same for validation images and put them in ```./validate```. You may also want to provide a set of background images that help make the model robust to regions containing no objects. These images have no masks and therefore no skeletons. We'll put these in the folder ```./background```.
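Before training, it can help to confirm that every image in a training or validation folder has both companion files. The sketch below uses a hypothetical helper, not part of SKOOTS. It assumes that images end in `.tif` and follow the `.labels.tif` / `.skeletons.trch` naming described above. Background folders are expected to fail this check, since they have no masks or skeletons.

```python
import tempfile
from pathlib import Path
from typing import Dict, List

IMAGE_EXT = ".tif"
LABEL_SUFFIX = ".labels.tif"
SKELETON_SUFFIX = ".skeletons.trch"


def audit_training_dir(directory: str) -> Dict[str, List[str]]:
    """Return the base names of images that are missing a mask or a skeleton file."""
    root = Path(directory)
    missing: Dict[str, List[str]] = {"labels": [], "skeletons": []}
    for image in sorted(root.glob("*" + IMAGE_EXT)):
        if image.name.endswith(LABEL_SUFFIX):
            continue  # this is a mask, not a raw image
        base = image.name[: -len(IMAGE_EXT)]
        if not (root / (base + LABEL_SUFFIX)).exists():
            missing["labels"].append(base)
        if not (root / (base + SKELETON_SUFFIX)).exists():
            missing["skeletons"].append(base)
    return missing


if __name__ == "__main__":
    # Demo on a throwaway folder: 'a' is complete, 'b' lacks skeletons, 'c' lacks both.
    with tempfile.TemporaryDirectory() as tmp:
        for name in ["a.tif", "a.labels.tif", "a.skeletons.trch", "b.tif", "b.labels.tif", "c.tif"]:
            (Path(tmp) / name).touch()
        print(audit_training_dir(tmp))  # {'labels': ['c'], 'skeletons': ['b', 'c']}
```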

We can also do this through the SKOOTS CLI! Simply put all training masks in a single folder and run this in the terminal:
```bash
skoots --skeletonize_train_data "path/to/training/data" --downscaleXY 1 --downscaleZ 0.5
```

This command will create one ```*.skeletons.trch``` file for each training mask, with the same base filename as that mask.

## Configure your training

SKOOTS uses YACS to configure training of our models. YACS is a Python library that lets us define every variable that might influence training in a single text file: ```config.yaml```. SKOOTS has a default value for everything, defined in the file ```skoots/config.py```. YACS uses these defaults to fill in anything you don't set yourself. Each entry in config.py maps to a key in the YAML file. For instance, ```_C.TRAIN.NUM_EPOCHS``` in config.py would be set in a YAML file that looks like this:
```
TRAIN:
    NUM_EPOCHS: 100
```
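To make the correspondence concrete, the sketch below flattens a parsed YAML mapping into the dotted names used in config.py. It assumes that the YAML has already been parsed into nested dictionaries (for example, by PyYAML's `yaml.safe_load`). `flatten_config` is a hypothetical illustration, not how YACS works internally.

```python
from typing import Any, Dict


def flatten_config(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Turn nested YAML-style mappings into dotted keys, e.g. TRAIN.NUM_EPOCHS."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        dotted = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_config(value, dotted))  # descend into a group like TRAIN
        else:
            flat[dotted] = value
    return flat


# The YAML snippet above, as it would look after parsing.
parsed = {"TRAIN": {"NUM_EPOCHS": 100}}
print(flatten_config(parsed))  # {'TRAIN.NUM_EPOCHS': 100}
```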

There are other training parameters too. Say we also want to set the training batch size, defined in config.py as ```_C.TRAIN.TRAIN_BATCH_SIZE```. We could set both values as:

```
TRAIN:
    NUM_EPOCHS: 100
    TRAIN_BATCH_SIZE: 2
```

Notice how the YAML file groups related settings together under a common heading. A minimal training run might use a config file like this:

```
TRAIN:
    PRETRAINED_MODEL_PATH: ['path/to/pretrained_model.trch']
    TRAIN_DATA_DIR: ['path/to/train/data', '/path/to/more/train/data']  # can specify multiple sources of train data
    VALIDATION_DATA_DIR: ['path/to/validation/data', '/path/to/more/validation/data']  # similarly, validation data
```
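How YACS fills in defaults can be modelled with plain dictionaries. This is a simplified, hypothetical sketch of the behaviour, and YACS additionally checks that value types match the defaults. Keys you set override the defaults, everything else is inherited, and a key that does not exist in the defaults (such as a misspelled name) is rejected. In SKOOTS itself this merge is done by YACS, not by this function.

```python
import copy
from typing import Any, Dict


def merge_with_defaults(defaults: Dict[str, Any], overrides: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Return a copy of `defaults` with `overrides` applied; unknown keys raise KeyError."""
    merged = copy.deepcopy(defaults)  # never mutate the defaults themselves
    for key, value in overrides.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if key not in merged:
            raise KeyError(f"Non-existent config key: {full_key}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_with_defaults(merged[key], value, full_key)  # recurse into groups like TRAIN
        else:
            merged[key] = value
    return merged


# A small excerpt of the defaults from config.py, written as nested dicts.
defaults = {"TRAIN": {"NUM_EPOCHS": 10000, "TRAIN_BATCH_SIZE": 2, "LEARNING_RATE": 5e-4}}
overrides = {"TRAIN": {"NUM_EPOCHS": 100}}  # what a short config.yaml might contain
cfg = merge_with_defaults(defaults, overrides)
print(cfg)  # {'TRAIN': {'NUM_EPOCHS': 100, 'TRAIN_BATCH_SIZE': 2, 'LEARNING_RATE': 0.0005}}
```

Rejecting unknown keys is what catches typos such as `TRAN_BATCH_SIZE` before a long training run starts.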

For reference, the entirety of the config.py file looks like this:

```python
from yacs.config import CfgNode as CN

# -----------------------------------------------------------------------------
# Training config definition
# -----------------------------------------------------------------------------
_C = CN()

# -----------------------------------------------------------------------------
# System
# -----------------------------------------------------------------------------
_C.SYSTEM = CN()

_C.SYSTEM.NUM_GPUS = 2  # number of available NVIDIA GPUs to train on
_C.SYSTEM.NUM_CPUS = 1  # How many CPUs do you have? Might be more than 1 when doing distributed training

# Define a BISM Model
_C.MODEL = CN()
_C.MODEL.ARCHITECTURE = 'bism_unext'  # name of bism model
_C.MODEL.IN_CHANNELS = 1  # number of input color channels, 1 for grayscale, 3 for rgb
_C.MODEL.OUT_CHANNELS = 32  # output channels of the model backbone, not of the final SKOOTS output
_C.MODEL.DIMS = [32, 64, 128, 64, 32]   # number of channels at each level of a unet
_C.MODEL.DEPTHS = [2, 2, 2, 2, 2]  # number of computational blocks at each level of the unet
_C.MODEL.KERNEL_SIZE = 7  # kernel size of each convolution in unet
_C.MODEL.DROP_PATH_RATE = 0.0
_C.MODEL.LAYER_SCALE_INIT_VALUE = 1.
_C.MODEL.ACTIVATION = 'gelu'
_C.MODEL.BLOCK = 'block3d'  # computational blocks, see bism for all available
_C.MODEL.CONCAT_BLOCK = 'concatconv3d'  # concatenation blocks for skip connections
_C.MODEL.UPSAMPLE_BLOCK = 'upsamplelayer3d'  # upsample operation
_C.MODEL.NORMALIZATION = 'layernorm'  # normalization layer, could be batch norm...

# Training Configurations
_C.TRAIN = CN()
_C.TRAIN.DISTRIBUTED = True  # do we use PyTorch DistributedDataParallel?
_C.TRAIN.PRETRAINED_MODEL_PATH = [
    '/home/chris/Dropbox (Partners HealthCare)/trainMitochondriaSegmentation/models/Oct20_11-54-51_CHRISUBUNTU.trch'
]  # path to pretrained model

# embedding loss function and their constructor keywords
_C.TRAIN.LOSS_EMBED = 'tversky'  # could also be "soft_cldice"
_C.TRAIN.LOSS_EMBED_KEYWORDS = ['alpha', 'beta', 'eps']  # kwargs for the loss function
_C.TRAIN.LOSS_EMBED_VALUES = [0.25, 0.75, 1e-8]  # values for each kwarg

# Semantic mask loss function and their constructor keywords
_C.TRAIN.LOSS_PROBABILITY = 'tversky'
_C.TRAIN.LOSS_PROBABILITY_KEYWORDS = ['alpha', 'beta', 'eps']
_C.TRAIN.LOSS_PROBABILITY_VALUES = [0.5, 0.5, 1e-8]

# Skeleton mask loss function and their constructor keywords
_C.TRAIN.LOSS_SKELETON = 'tversky'
_C.TRAIN.LOSS_SKELETON_KEYWORDS = ['alpha', 'beta', 'eps']
_C.TRAIN.LOSS_SKELETON_VALUES = [0.5, 1.5, 1e-8]

# We sum the loss values of each part together, scaled by some factor set here
_C.TRAIN.LOSS_EMBED_RELATIVE_WEIGHT = 1.0
_C.TRAIN.LOSS_PROBABILITY_RELATIVE_WEIGHT = 1.0
_C.TRAIN.LOSS_SKELETON_RELATIVE_WEIGHT = 1.0

# We may not consider each loss until a certain epoch. This may be
# because some tasks are hard to learn at the start, and must only be considered later
# roughly does this:
# loss_embed = loss_embed if cfg.TRAIN.LOSS_EMBED_START_EPOCH < epoch else torch.tensor(0)
# ...same for skeleton and semantic mask...
_C.TRAIN.LOSS_EMBED_START_EPOCH = -1
_C.TRAIN.LOSS_PROBABILITY_START_EPOCH = -1  # sets the epoch after which the semantic mask loss is added to the total
_C.TRAIN.LOSS_SKELETON_START_EPOCH = 10

# Train, Validation, and Background data share similar syntax
# *_DATA_DIR is a list of locations to look for training data, can have multiple sources
# *_SAMPLE_PER_IMAGE is a list of the number of times to sample each image in an epoch, must be the same size as *_DATA_DIR
# *_BATCH_SIZE is the batch size
# BACKGROUND are just images of nothing, and do not need associated label files
_C.TRAIN.TRAIN_DATA_DIR = [
    '/home/chris/Dropbox (Partners HealthCare)/trainMitochondriaSegmentation/data/unscaled/train']
_C.TRAIN.TRAIN_SAMPLE_PER_IMAGE = [32]
_C.TRAIN.TRAIN_BATCH_SIZE = 2
_C.TRAIN.VALIDATION_DATA_DIR = [
    '/home/chris/Dropbox (Partners HealthCare)/trainMitochondriaSegmentation/data/unscaled/validate']
_C.TRAIN.VALIDATION_SAMPLE_PER_IMAGE = [6]
_C.TRAIN.VALIDATION_BATCH_SIZE = 1
_C.TRAIN.BACKGROUND_DATA_DIR = [
    '/home/chris/Dropbox (Partners HealthCare)/trainMitochondriaSegmentation/data/background']
_C.TRAIN.BACKGROUND_SAMPLE_PER_IMAGE = [8]

_C.TRAIN.STORE_DATA_ON_GPU = False  # Sends all training data to GPU for faster access

# Sigma sets the distance penalty for embedding. Each number is in units of pixels
# See the manuscript methods section for more detail
_C.TRAIN.INITIAL_SIGMA = [20., 20., 20.]  # [X, Y, Z]
_C.TRAIN.SIGMA_DECAY = [[0.66, 200], [0.66, 800], [0.66, 1500], [0.5, 20000], [0.5, 20000]]  # List of sigma decays [[fraction, decay_epoch], ...], i.e. sigma *= fraction if epoch > decay_epoch
_C.TRAIN.NUM_EPOCHS = 10000  # total number of epochs to train
_C.TRAIN.LEARNING_RATE = 5e-4  # optimizer learning rate
_C.TRAIN.WEIGHT_DECAY = 1e-6  # optimizer weight decay
_C.TRAIN.OPTIMIZER = 'adamw'  # Train optimizer. Valid are: 'Adam', 'AdamW', 'SGD'
_C.TRAIN.OPTIMIZER_EPS = 1e-8  # Optimizer eps
_C.TRAIN.SCHEDULER = 'cosine_annealing_warm_restarts'  # learning rate scheduler, currently the only one implemented
_C.TRAIN.SCHEDULER_T0 = 10000 + 1  # period of learning rate scheduler
_C.TRAIN.MIXED_PRECISION = True  # train using PyTorch automatic mixed precision (AMP)
_C.TRAIN.N_WARMUP = 1500  # number of times to train on an initial example to warm up the model
_C.TRAIN.SAVE_PATH = '/home/chris/Dropbox (Partners HealthCare)/trainMitochondriaSegmentation/models'  # where do we save the model, and intermediate files?
_C.TRAIN.SAVE_INTERVAL = 100  # saves a snapshot of the model every SAVE_INTERVAL number of epochs
_C.TRAIN.CUDNN_BENCHMARK = True  # sets torch.backends.cudnn.benchmark
_C.TRAIN.AUTOGRAD_PROFILE = False  # sets torch.autograd.profiler.profile
_C.TRAIN.AUTOGRAD_EMIT_NVTX = False  # sets torch.autograd.profiler.emit_nvtx(enabled= * )
_C.TRAIN.AUTOGRAD_DETECT_ANOMALY = False  # sets torch.autograd.set_detect_anomaly( * )

# Augmentation Configuration
# these are self-explanatory and set the valid image augmentations for training
# for more detail on usage, see skoots/train/merge_transform.py
_C.AUGMENTATION = CN()
_C.AUGMENTATION.CROP_WIDTH = 300
_C.AUGMENTATION.CROP_HEIGHT = 300
_C.AUGMENTATION.CROP_DEPTH = 20
_C.AUGMENTATION.FLIP_RATE = 0.5
_C.AUGMENTATION.BRIGHTNESS_RATE = 0.4
_C.AUGMENTATION.BRIGHTNESS_RANGE = [-0.1, 0.1]
_C.AUGMENTATION.NOISE_GAMMA = 0.1
_C.AUGMENTATION.NOISE_RATE = 0.2
_C.AUGMENTATION.CONTRAST_RATE = 0.33
_C.AUGMENTATION.CONTRAST_RANGE = [0.75, 2.0]
_C.AUGMENTATION.AFFINE_RATE = 0.66
_C.AUGMENTATION.AFFINE_SCALE = [0.85, 1.1]
_C.AUGMENTATION.AFFINE_YAW = [-180, 180]
_C.AUGMENTATION.AFFINE_SHEAR = [-7, 7]
_C.AUGMENTATION.SMOOTH_SKELETON_KERNEL_SIZE = (3, 3, 1)
_C.AUGMENTATION.BAKE_SKELETON_ANISOTROPY = (1.0, 1.0, 3.0)  # does not necessarily have to reflect data anisotropy
_C.AUGMENTATION.N_SKELETON_MASK_DILATE = 1


# Skoots Generics
_C.SKOOTS = CN()
# This sets the vector scaling of your data, and should roughly equate to the
# max distance of any pixel to a skeleton.
_C.SKOOTS.VECTOR_SCALING = (60, 60, 60 // 5)

# this sets the anisotropy of your data.
_C.SKOOTS.ANISOTROPY = (1.0, 1.0, 3.0)


def get_cfg_defaults():
    r"""Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()
```
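Two of the scheduling options above can be illustrated with plain Python. This is a hypothetical sketch, not SKOOTS code. `sigma_at_epoch` assumes that each `[fraction, decay_epoch]` entry in `SIGMA_DECAY` is applied once, as soon as `epoch > decay_epoch`. `total_loss` follows the commented pseudocode for the `*_START_EPOCH` settings and sums each enabled loss scaled by its `*_RELATIVE_WEIGHT`. The loss values used below are made up.

```python
from typing import Dict, List, Sequence


def sigma_at_epoch(initial_sigma: Sequence[float], decays: Sequence[Sequence[float]], epoch: int) -> List[float]:
    """Assumed reading of SIGMA_DECAY: multiply sigma by `fraction` once per passed decay_epoch."""
    sigma = list(initial_sigma)
    for fraction, decay_epoch in decays:
        if epoch > decay_epoch:
            sigma = [s * fraction for s in sigma]
    return sigma


def total_loss(epoch: int, losses: Dict[str, float], weights: Dict[str, float], start_epochs: Dict[str, int]) -> float:
    """Weighted sum of the loss terms whose start epoch has already been passed."""
    total = 0.0
    for name, value in losses.items():
        if start_epochs[name] < epoch:  # mirrors `START_EPOCH < epoch` in the config comments
            total += weights[name] * value
    return total


# Values copied from the defaults above.
INITIAL_SIGMA = [20., 20., 20.]
SIGMA_DECAY = [[0.66, 200], [0.66, 800], [0.66, 1500], [0.5, 20000], [0.5, 20000]]
START_EPOCHS = {"embed": -1, "probability": -1, "skeleton": 10}
WEIGHTS = {"embed": 1.0, "probability": 1.0, "skeleton": 1.0}

print(sigma_at_epoch(INITIAL_SIGMA, SIGMA_DECAY, epoch=1000))  # sigma after two decays
losses = {"embed": 0.5, "probability": 0.25, "skeleton": 1.0}  # made-up loss values
print(total_loss(5, losses, WEIGHTS, START_EPOCHS))   # 0.75, skeleton loss not yet active
print(total_loss(11, losses, WEIGHTS, START_EPOCHS))  # 1.75
```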

## Train your model

We are now able to train a model. Make sure all data is where it should be and that your configuration file is set and given an informative name. Then run the following in the terminal:

```bash
skoots-train --config-file "my_config.yaml"
```

SKOOTS will log intermediate steps to TensorBoard for diagnosis, and save a file in ```_C.TRAIN.SAVE_PATH``` named after the log directory. All information about training is saved in this file, including the configuration, in case you ever forget which model was trained on which dataset. To evaluate this model, type in the terminal:

```bash
skoots --pretrained_checkpoint "mymodel.trch" --image "inference_image.tif"
```

For more details on how training works and how to extend it, see the detailed training example.