```{eval-rst}
.. role:: nge-green
```
{nge-green}`Training a SKOOTS Model`
===================================

The SKOOTS library provides pre-written functions that make it easy to train a segmentation model. Training runs are typically defined by configuration files and launched from the command line. Here, we first show you how to prepare your data, then walk through precomputing skeletons, defining a model, and launching training. For details on the training script, how it works, and how to hack it, please see the detailed training tutorial.
The built-in training script uses PyTorch's DistributedDataParallel by default with the NCCL (`'nccl'`) communication backend, so unfortunately it requires an NVIDIA GPU.

## Prepare Your Data
We start by preparing our data. SKOOTS expects training images to be large TIFF images with associated instance masks, and it pairs each mask with its image by filename and tag. Images may be named whatever you'd like, for example ```training_data.tif```. The associated labels must then be named ```training_data.labels.tif```. The background of each label image must be zero.
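
To make the naming convention concrete, here is a minimal sketch that derives the label and skeleton filenames SKOOTS expects next to an image. The helper `companion_paths` is hypothetical and not part of SKOOTS; it only encodes the `<name>.tif` → `<name>.labels.tif` / `<name>.skeletons.trch` convention used in this tutorial.

```python
from pathlib import Path
from typing import Tuple


def companion_paths(image_path: str) -> Tuple[Path, Path]:
    """Return the (labels, skeletons) paths paired with an image.

    Hypothetical helper, not part of SKOOTS: it only encodes the
    '<name>.tif' -> '<name>.labels.tif' / '<name>.skeletons.trch' convention.
    """
    image = Path(image_path)
    if image.suffix != '.tif':
        raise ValueError(f'expected a .tif image, got {image.name}')
    stem = image.name[:-len('.tif')]  # e.g. 'training_data'
    labels = image.with_name(stem + '.labels.tif')
    skeletons = image.with_name(stem + '.skeletons.trch')
    return labels, skeletons


labels, skeletons = companion_paths('./train/training_data.tif')
print(labels.name)     # training_data.labels.tif
print(skeletons.name)  # training_data.skeletons.trch
```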

## Precompute Ground Truth Skeletons
Once our training data is in an appropriate place, we must precompute the ground truth skeletons. This generates a separate file for each mask and needs to happen only once. SKOOTS provides the necessary utility functions for creating the skeletons, but you need to write a short script for your own data:

```python
import glob
from typing import Dict

import numpy as np
import skimage.io as io
import torch

from skoots.train.generate_skeletons import calculate_skeletons

training_directory = './train'  # base directory containing all our data

# Sometimes skeletons look 'weird' due to anisotropy.
# Scaling the image by a predetermined amount can help with this.
# Finding skeletons which are easily predictable may take some trial and error.
scale_factors = torch.tensor([1, 1, 0.5])

# Loop over all the mask files
for f in glob.glob(training_directory + '/*.labels.tif'):
    masks: np.ndarray = io.imread(f)  # reads in as a uint16 numpy array with shape [Z, X, Y]
    masks = torch.from_numpy(masks.astype(np.int32))  # PyTorch has limited uint16 support, so convert to a 32-bit int
    masks = masks.permute(1, 2, 0)  # SKOOTS expects the tensor to be [X, Y, Z]
    skeletons: Dict[int, torch.Tensor] = calculate_skeletons(masks, scale_factors)  # calculate the skeletons

    base = f[:-len('.labels.tif')]  # strip the '.labels.tif' suffix
    torch.save(skeletons, base + '.skeletons.trch')  # IMPORTANT! Skeletons must be saved with this extension and tag!
```

We now have three files for each training image: the image ```train_image.tif```, the instance masks ```train_image.labels.tif```, and the precomputed skeletons ```train_image.skeletons.trch```. Computing skeletons can be expensive, so doing it once up front saves training time. All three files must be present in the same folder for training! In this case, let's put them in ```./train```. We can do the same for validation images and put them in ```./validate```. You may also want to provide a set of background images to train the model to be robust against. These images have no masks and therefore no skeletons. We'll put these in the folder ```./background```.

We can also do this through the skoots CLI! We simply put all training image masks in a single folder, and run this in the terminal:
```bash
skoots --skeletonize_train_data "path/to/training/data" --downscaleXY 1 --downscaleZ 0.5
```

This command creates one ```*.skeletons.trch``` file for each training mask, with the same base filename as the mask.
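
Before training, it can help to confirm that every image has its mask and skeleton file. The sketch below is a hypothetical helper, not part of SKOOTS; it assumes every `*.tif` in the folder that is not a `*.labels.tif` is a training image. It is meant for ```./train``` and ```./validate```, not ```./background```, whose images intentionally have no masks or skeletons.

```python
from pathlib import Path
from typing import Dict, List


def check_training_folder(folder: str) -> Dict[str, List[str]]:
    """Map each image name to the companion files it is missing.

    Hypothetical helper, not part of SKOOTS. Assumes every '*.tif'
    that is not a '*.labels.tif' is a training image.
    """
    missing: Dict[str, List[str]] = {}
    for image in sorted(Path(folder).glob('*.tif')):
        if image.name.endswith('.labels.tif'):
            continue  # masks are checked through their image
        stem = image.name[:-len('.tif')]
        needed = [stem + '.labels.tif', stem + '.skeletons.trch']
        absent = [name for name in needed if not (image.parent / name).exists()]
        if absent:
            missing[image.name] = absent
    return missing


# Prints nothing when every image in ./train is complete (or the folder is empty).
for image_name, absent in check_training_folder('./train').items():
    print(f'{image_name}: missing {", ".join(absent)}')
```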

## Imports
Now we are ready to start training. First, let's import the necessary functions:

```python
# Necessary imports from the Python standard library
from functools import partial
from typing import Tuple, Callable, Dict
import os.path

# Necessary imports from PyTorch
import torch
import torch.nn as nn
import torch.multiprocessing as mp

# Everything we need from SKOOTS
from skoots.train.distributed import train
from skoots.lib.mp_utils import setup_process, cleanup, find_free_port
```

## Define a Model
Once imported, we need to define a UNet-like model for training. This model must accept a 5D tensor of shape $(B_{in}, C_{in}, X_{in}, Y_{in}, Z_{in})$ and must return a tensor of shape $(B_{in}, C_{out}=5, X_{in}, Y_{in}, Z_{in})$, i.e. with the same batch size and spatial dimensions as the input. Special attention should be paid to the activation of the output channels. The SKOOTS training script expects the 5 output channels to be a concatenation of the embedding vectors $E_{x,y,z} \in [-1, 1]$, the semantic probability map $P \in [0, 1]$, and the skeleton probability map $S \in [0, 1]$. The channels are therefore $[E_x, E_y, E_z, P, S]$, so it is sensible to pass channels 0-2 through a tanh activation and channels 3 and 4 through a sigmoid. Ideally, the implementation should be TorchScript-able and contain no parameters that are unused in computing the loss (DistributedDataParallel raises an error on unused parameters by default). Here is a toy example:

```python
class skoots_model(nn.Module):
    def __init__(self, in_channels: int = 1):
        super().__init__()
        # stride=1 with padding=1 keeps the output spatial shape equal to the input
        self.backbone = nn.Conv3d(in_channels, 5, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        x = torch.cat((
            self.tanh(x[:, 0:3, ...]),     # Embedding vectors E_x, E_y, E_z
            self.sigmoid(x[:, [3], ...]),  # Semantic probability map P
            self.sigmoid(x[:, [4], ...])   # Skeleton probability map S
        ), dim=1)
        return x
```
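
The output contract can also be checked without a GPU. The sketch below uses NumPy rather than PyTorch, and the function name `apply_skoots_activations` is hypothetical; it mirrors the toy model's `forward` (tanh on channels 0-2, sigmoid on channels 3 and 4) and asserts the shapes and value ranges described above.

```python
import numpy as np


def apply_skoots_activations(logits: np.ndarray) -> np.ndarray:
    """Map raw [B, 5, X, Y, Z] outputs to [E_x, E_y, E_z, P, S].

    Hypothetical NumPy mirror of the toy model's forward pass.
    """
    if logits.ndim != 5 or logits.shape[1] != 5:
        raise ValueError(f'expected shape [B, 5, X, Y, Z], got {logits.shape}')
    embedding = np.tanh(logits[:, 0:3])                    # E_x, E_y, E_z in [-1, 1]
    probabilities = 1.0 / (1.0 + np.exp(-logits[:, 3:5]))  # P and S in [0, 1]
    return np.concatenate([embedding, probabilities], axis=1)


rng = np.random.default_rng(0)
out = apply_skoots_activations(rng.normal(scale=5.0, size=(2, 5, 8, 8, 4)))
assert out.shape == (2, 5, 8, 8, 4)  # same batch and spatial shape as the input
assert np.all(np.abs(out[:, 0:3]) <= 1.0)
assert np.all((out[:, 3:5] >= 0.0) & (out[:, 3:5] <= 1.0))
```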

Once we have a model, we are ready to start training. If you have a previously trained checkpoint, now is the time to load the weights.

```python
model = skoots_model()

# Optional: resume from a previously trained checkpoint
checkpoint = torch.load('checkpoint.trch', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
```

We can optionally provide a hyperparameter dict, which is passed to a TensorBoard writer. You can put anything you want in it, as long as each key is a string and each value is an int, float, or string.

```python
hyperparams = {'training_name': 'test_training_run'}
```
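
If you build this dict programmatically, a quick type check can catch mistakes before they reach the TensorBoard writer. This is a hypothetical helper, not part of SKOOTS, and it simply enforces the str-key / int, float, or str-value rule stated above.

```python
from typing import Any, Dict


def validate_hyperparams(hyperparams: Dict[Any, Any]) -> Dict[str, Any]:
    """Raise TypeError unless keys are str and values are int, float, or str.

    Hypothetical helper, not part of SKOOTS.
    """
    for key, value in hyperparams.items():
        if not isinstance(key, str):
            raise TypeError(f'hyperparameter key {key!r} must be a str')
        if not isinstance(value, (int, float, str)):
            raise TypeError(f'hyperparameter {key!r} has unsupported type {type(value).__name__}')
    return hyperparams


hyperparams = validate_hyperparams({'training_name': 'test_training_run', 'notes': 'first run'})
```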

## Anisotropy
SKOOTS needs to know the anisotropy of your data, as well as how much it should scale the embedding vectors. In our data, voxels are roughly 5 times larger in Z than in X and Y, and the maximum cross section of a mitochondrion is around 60 pixels in X and Y and 12 pixels in Z. We therefore define the following:

```python
vector_scale = (60.0, 60.0, 12.0)
anisotropy = (1.0, 1.0, 3.0)
```
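
Object sizes like these can be measured from your own masks. The sketch below is a hypothetical helper, not part of SKOOTS: it reports the largest per-axis bounding-box size over all labeled objects, assuming background is 0 and axes are ordered [X, Y, Z] as after the permute in the skeleton script. A bounding box only approximates the maximum cross section, and you may prefer a high percentile over the maximum to ignore outlier objects.

```python
from typing import Tuple

import numpy as np


def max_object_extent(masks: np.ndarray) -> Tuple[int, ...]:
    """Largest per-axis bounding-box size over all labeled objects.

    Hypothetical helper, not part of SKOOTS. Assumes background is 0.
    """
    extents = np.zeros(masks.ndim, dtype=int)
    for label in np.unique(masks):
        if label == 0:
            continue  # skip background
        coords = np.nonzero(masks == label)
        size = np.array([c.max() - c.min() + 1 for c in coords])
        extents = np.maximum(extents, size)
    return tuple(int(e) for e in extents)


masks = np.zeros((100, 100, 20), dtype=np.int32)
masks[10:70, 20:40, 5:17] = 1  # 60 x 20 x 12 voxels
masks[80:90, 50:95, 0:4] = 2   # 10 x 45 x 4 voxels
print(max_object_extent(masks))  # (60, 45, 12)
```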

SKOOTS training uses PyTorch's DistributedDataParallel, even for single-GPU training. We therefore define a few more values:

```python
port = find_free_port()  # Assumes multiple GPUs on a single machine. Change accordingly for multi-machine training.
world_size = 2  # Number of devices to run on. I have two GPUs, so this is 2. If you only have one GPU, set this to 1.
```
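
For intuition about what a free-port helper does, here is a common standard-library approach: bind a socket to port 0 and let the operating system choose an unused port. This is a hypothetical stand-in; the actual implementation of `skoots.lib.mp_utils.find_free_port` may differ.

```python
import socket


def find_free_port_sketch() -> int:
    """Ask the OS for a currently unused TCP port on this machine.

    Hypothetical stand-in for skoots.lib.mp_utils.find_free_port.
    Binding to port 0 tells the OS to pick any free port.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('', 0))
        return sock.getsockname()[1]


print(find_free_port_sketch())  # some free port number; varies between runs
```

The port is released when the socket closes, so another process could in principle claim it before training starts. Rather than hard-coding `world_size`, you could also set it from `torch.cuda.device_count()`.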

## Launch the Training Function
Now we call into PyTorch's multiprocessing library to launch one instance of our training engine per device. We pass the model, hyperparameters, the training, validation, and background data locations, the vector scaling, and the anisotropy. The training engine should handle the rest!

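```python
train_dir, validation_dir, background_dir = './train', './validate', './background'

if __name__ == '__main__':  # required when child processes are started with the 'spawn' method
    # mp.spawn calls train(rank, *args), prepending each process's rank to args.
    # The args MUST be passed in this order.
    mp.spawn(train,
             args=(port, world_size, model, hyperparams, train_dir, validation_dir, background_dir, vector_scale, anisotropy),
             nprocs=world_size, join=True)
```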
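
`torch.multiprocessing.spawn` calls the target function as `fn(i, *args)`, where `i` is the process index, which is why the order of `args` matters. The sketch below imitates that calling convention sequentially so the pattern is visible; `run_like_spawn` and `toy_train` are hypothetical illustrations, no processes are started, and the return values are collected only for display (the real `spawn` does not return them).

```python
from typing import Any, Callable, List, Tuple


def run_like_spawn(fn: Callable[..., Any], args: Tuple[Any, ...], nprocs: int) -> List[Any]:
    """Sequentially mimic how mp.spawn invokes fn: fn(rank, *args).

    Hypothetical illustration only: no processes are created.
    """
    return [fn(rank, *args) for rank in range(nprocs)]


def toy_train(rank: int, port: int, world_size: int) -> str:
    # A real training function would set up its process group and train here.
    return f'rank {rank} of {world_size} on port {port}'


print(run_like_spawn(toy_train, args=(12355, 2), nprocs=2))
# ['rank 0 of 2 on port 12355', 'rank 1 of 2 on port 12355']
```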

The model will now train for 10000 epochs and log its outputs to a TensorBoard folder, ```./runs```. Trained models are saved in ```./models``` under the same name as the TensorBoard logdir. For more detailed control over the training process, including setting your own learning rate, optimizer, etc., please see the detailed training scripts!