```{eval-rst}
.. role:: nge-yellow
```
{nge-yellow}`Detailed Training With a Custom Training Function`
===================================

For more fine-grained control of the training process, we must define our own train function. This requires a bit more work, but allows much more flexibility. The training function must perform all initialization, data loading, and hyperparameter tuning. We will break down each step outside the function, then compile it all together at the end.

## Imports
We must first import each package necessary for training. SKOOTS tries to take a functional approach to training. It is not exactly in line with functional programming best practices, but it spares you from a hell of inheritance.

In [None]:
# Python standard library
from functools import partial
import os.path
from typing import Tuple, Callable, Dict

# Pytorch imports
import torch.optim.lr_scheduler
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import Tensor

# Skoots imports
from skoots.train.dataloader import dataset, MultiDataset, skeleton_colate
from skoots.train.sigma import Sigma
from skoots.train.loss import tversky
from skoots.train.merged_transform import merged_transform_3D, background_transform_3D
from skoots.train.engine import engine
from skoots.train.setup import setup_process, cleanup, find_free_port

## Define a Training Function
We need to define 3 mandatory inputs: ```rank```, ```port```, and ```world_size```. Starting in reverse, ```world_size``` is the total number of devices to run distributed training on. If you have two GPUs in one machine, then your world size would be 2. ```port``` is the local network port that the processes use to coordinate distributed training. ```rank``` is the index of the process. So for a ```world_size``` of 2, we would get two processes, one with ```rank=0``` and one with ```rank=1```.

In [None]:
def train(rank: int,
          port: str,
          world_size: int,
          model: nn.Module
          ) -> None:
    pass

For example, the first of two processes might receive the following:

In [None]:
rank = 0  # integer process index, compared against 0 later on
port = '51234'
world_size = 2

Let's also set up some other constants necessary for training, namely the anisotropy and vector scaling parameters.

In [None]:
anisotropy = (1.0, 1.0, 5.0)
vector_scale = (60, 60, 12)

## DDP Initialization
To run this in DDP, we need to set up the process, send the model to the GPU, and wrap it in a DDP wrapper.

In [None]:
setup_process(rank, world_size, port, backend='nccl')

device = f'cuda:{rank}'
model = model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model)

## Data Loading
We now need to load our data. Using the SKOOTS dataloader and skeleton collate function, we can ensure our data is properly handled. First, let's define our augmentations. We can either write our own or use the pre-built augmentations (which is recommended). See our tutorial on augmentations for more details.

In [None]:
augmentations: Callable[[Dict[str, Tensor]], Dict[str, Tensor]] = partial(merged_transform_3D,
                                                                          bake_skeleton_anisotropy=anisotropy,
                                                                          device=device)
augmentations_background: Callable[[Dict[str, Tensor]], Dict[str, Tensor]] = partial(background_transform_3D, device=device)

We can now load our data from our training and validation datasets. To do this we use the SKOOTS dataloaders. See the tutorial on dataloading for more details. If you have multiple datasets and therefore multiple dataloaders, you can use the ```MultiDataset``` class to merge two datasets.

In [None]:
train_data = dataset(path='./train', transforms=augmentations, sample_per_image=32, device=device, pad_size=100)
background_data = dataset(path='./background', transforms=augmentations_background, sample_per_image=32, device=device, pad_size=100)
merged_train_data = MultiDataset(train_data, background_data)

validation_data = dataset(path='./validation', transforms=augmentations, sample_per_image=32, device=device, pad_size=100)

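Now we need to create a distributed sampler for each dataset so that every process draws its own share of the samples. PyTorch requires this for distributed training. Each sampler must be built from the same dataset that is later passed to its dataloader.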

In [None]:
train_sampler = torch.utils.data.distributed.DistributedSampler(merged_train_data)  # sample from the merged dataset used by the dataloader
validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_data)
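To make the role of the sampler concrete, here is a minimal pure-Python sketch of how a distributed sampler can split dataset indices across processes. It mirrors the unshuffled, non-dropping behaviour of PyTorch's ```DistributedSampler``` as we understand it (pad the index list until it divides evenly, then give each rank every ```world_size```-th index). The function name ```shard_indices``` is hypothetical and only used for illustration; the real sampler also shuffles by default.

```python
import math
from typing import List


def shard_indices(dataset_len: int, rank: int, world_size: int) -> List[int]:
    """Return the dataset indices one process would see (illustrative, no shuffling)."""
    indices = list(range(dataset_len))

    # Every rank must receive the same number of samples.
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size

    # Pad by repeating indices from the start until the list divides evenly.
    padding = total_size - len(indices)
    if padding > 0:
        repeats = math.ceil(padding / len(indices))
        indices += (indices * repeats)[:padding]

    # Each rank takes every world_size-th index, offset by its rank.
    return indices[rank:total_size:world_size]


if __name__ == "__main__":
    for r in range(2):
        print(r, shard_indices(5, rank=r, world_size=2))
```

With five samples and two processes, one index is repeated so that both ranks receive three samples each.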

We now wrap our datasets in PyTorch dataloaders to allow for automatic batching and collation! We must use the DDP samplers defined above and the SKOOTS collate function, ```skeleton_colate```.

In [None]:
train_dataloader = DataLoader(merged_train_data, num_workers=0, batch_size=2, sampler=train_sampler, collate_fn=skeleton_colate)
validation_dataloader = DataLoader(validation_data, num_workers=0, batch_size=2, sampler=validation_sampler, collate_fn=skeleton_colate)

## Embedding Distance Penalty
To calculate the embedding loss, SKOOTS needs a value reflecting the distance penalty. We term this value sigma, and there are separate values for x, y, and z. We use the ```skoots.train.sigma``` module to construct an object that decays sigma by a multiplier at set epochs. See the API reference for more details.

In [None]:
initial_sigma = torch.tensor([20., 20., 20.], device=device)
a = {'multiplier': 0.66, 'epoch': 200}
b = {'multiplier': 0.66, 'epoch': 800}
c = {'multiplier': 0.66, 'epoch': 1500}
d = {'multiplier': 0.5, 'epoch': 20000}
f = {'multiplier': 0.5, 'epoch': 20000}
sigma = Sigma([a, b, c, d, f], initial_sigma, device)
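As a rough illustration of what such a schedule does, the sketch below applies each multiplier once its epoch has been reached. This is an assumption about the semantics, not the ```Sigma``` implementation: the helper ```sigma_at_epoch``` is hypothetical, and we assume multipliers accumulate and take effect at (not after) the listed epoch. Check the API reference for the actual behaviour.

```python
from typing import Dict, List, Sequence, Union

Adjustment = Dict[str, Union[int, float]]


def sigma_at_epoch(initial: Sequence[float],
                   adjustments: Sequence[Adjustment],
                   epoch: int) -> List[float]:
    """Apply every multiplier whose epoch has been reached (assumed semantics)."""
    values = list(initial)
    for adj in sorted(adjustments, key=lambda a: a['epoch']):
        if epoch >= adj['epoch']:
            values = [v * adj['multiplier'] for v in values]
    return values


if __name__ == "__main__":
    schedule = [
        {'multiplier': 0.66, 'epoch': 200},
        {'multiplier': 0.66, 'epoch': 800},
        {'multiplier': 0.66, 'epoch': 1500},
        {'multiplier': 0.5, 'epoch': 20000},
        {'multiplier': 0.5, 'epoch': 20000},
    ]
    for e in (0, 200, 800, 1500, 10000):
        print(e, sigma_at_epoch([20., 20., 20.], schedule, e))
```

Under this reading, sigma would shrink from 20 to roughly 5.75 by epoch 1500, and any adjustment scheduled for an epoch beyond the total number of training epochs would never be applied.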

## Hyperparameters
We now must define a dictionary of hyperparameters, which will be passed as keyword arguments to the training engine that handles loss calculation and backpropagation. These hyperparameters may be saved along with the model weights, ensuring replicability. Every key in this dictionary must be present (even if its value is ```None```) and spelled exactly as shown.

In [None]:
constants = {
    'model': model,  # UNet model
    'vector_scale': vector_scale,
    'anisotropy': anisotropy,
    'lr': 5e-4,  # learning rate
    'wd': 1e-6,  # optimizer weight decay
    'optimizer': partial(torch.optim.AdamW, eps=1e-16),
    'scheduler': partial(torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, T_0=10000 + 1),
    'sigma': sigma,
    'loss_embed': tversky(alpha=0.25, beta=0.75, eps=1e-8, device=device), # Loss functions, see API reference for more details
    'loss_prob': tversky(alpha=0.5, beta=0.5, eps=1e-8, device=device),
    'loss_skele': tversky(alpha=0.5, beta=1.5, eps=1e-8, device=device),
    'epochs': 10000,  # total number of training epochs
    'device': device,
    'train_data': train_dataloader,
    'val_data': validation_dataloader,
    'train_sampler': train_sampler,
    'test_sampler': validation_sampler,
    'distributed': True,
    'mixed_precision': True,  # can use automatic mixed precision which may speed up training
    'rank': rank,
    'savepath': './models',  # where to save the model at the end
}
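The three loss entries above are Tversky losses with different ```alpha``` and ```beta``` values. For intuition, here is a minimal pure-Python sketch of the commonly used Tversky loss on flat lists of probabilities. It is not the SKOOTS implementation: the function name ```tversky_loss``` is hypothetical, and we assume the common convention in which ```alpha``` weights false positives and ```beta``` weights false negatives. Whether ```skoots.train.loss.tversky``` follows the same convention should be checked in the API reference.

```python
from typing import Sequence


def tversky_loss(pred: Sequence[float],
                 target: Sequence[float],
                 alpha: float,
                 beta: float,
                 eps: float = 1e-8) -> float:
    """1 - Tversky index, with alpha weighting false positives and beta false negatives."""
    tp = sum(p * t for p, t in zip(pred, target))        # soft true positives
    fp = sum(p * (1 - t) for p, t in zip(pred, target))  # soft false positives
    fn = sum((1 - p) * t for p, t in zip(pred, target))  # soft false negatives
    index = (tp + eps) / (tp + alpha * fp + beta * fn + eps)
    return 1.0 - index


if __name__ == "__main__":
    pred = [1.0, 1.0, 1.0, 0.0]
    target = [1.0, 0.0, 0.0, 0.0]
    print(tversky_loss(pred, target, alpha=0.5, beta=0.5))    # equivalent to a Dice loss
    print(tversky_loss(pred, target, alpha=0.25, beta=0.75))  # false positives penalized less
```

Under this convention, setting ```alpha``` equal to ```beta``` at 0.5 recovers the Dice loss, while lowering ```alpha``` and raising ```beta``` makes the loss more tolerant of false positives and stricter about false negatives.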

## Training Engine
We are now ready to pass the hyperparameters to the training engine, which handles the forward and backward passes. We can optionally provide a tensorboard writer to track the training. The training engine will run for the predetermined number of epochs then return the model state dict, optimizer state dict, and loss at each epoch.

In [None]:
writer = SummaryWriter() if rank == 0 else None
model_state_dict, optimizer_state_dict, avg_loss = engine(writer=writer, verbose=True, force=True, **constants)

## Saving
Now that training is done, we save the model and its hyperparams. We only need to do this for one process, as the model weights are shared via DDP.

In [None]:
if rank == 0:
    # Some hyperparams cannot be saved as is, so we simply store a string representation. This is usually good enough.
    for k in constants:
        if k in ['model', 'train_data', 'val_data', 'train_sampler', 'test_sampler', 'loss_embed', 'loss_prob']:
            constants[k] = str(constants[k])

    # Save the weights of the model and optimizer
    constants['model_state_dict'] = model_state_dict
    constants['optimizer_state_dict'] = optimizer_state_dict

    # Save the model to a file! Only the rank 0 process has a writer, so this stays inside the if block.
    torch.save(constants, f'./models/{os.path.split(writer.log_dir)[-1]}.trch')

## Cleanup
Finally, every process must run a cleanup step, as PyTorch DDP requires.

In [None]:
cleanup(rank)

## All Together
And now we are done! This should train your entire model. The entire script is below:

In [None]:
def train(rank: int,
          port: str,
          world_size: int,
          model: nn.Module
          ) -> None:
    # setup
    setup_process(rank, world_size, port, backend='nccl')

    device = f'cuda:{rank}'
    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model)

    # Augmentations
    augmentations: Callable[[Dict[str, Tensor]], Dict[str, Tensor]] = partial(merged_transform_3D,
                                                                              bake_skeleton_anisotropy=anisotropy,
                                                                              device=device)
    augmentations_background: Callable[[Dict[str, Tensor]], Dict[str, Tensor]] = partial(background_transform_3D, device=device)

    # Load data and place in dataloader
    train_data = dataset(path='./train', transforms=augmentations, sample_per_image=32, device=device, pad_size=100)
    background_data = dataset(path='./background', transforms=augmentations_background, sample_per_image=32, device=device, pad_size=100)
    merged_train_data = MultiDataset(train_data, background_data)

    validation_data = dataset(path='./validation', transforms=augmentations, sample_per_image=32, device=device, pad_size=100)

    train_sampler = torch.utils.data.distributed.DistributedSampler(merged_train_data)  # sample from the merged dataset used by the dataloader
    validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_data)
    train_dataloader = DataLoader(merged_train_data, num_workers=0, batch_size=2, sampler=train_sampler, collate_fn=skeleton_colate)
    validation_dataloader = DataLoader(validation_data, num_workers=0, batch_size=2, sampler=validation_sampler, collate_fn=skeleton_colate)

    # Define embedding distance penalty (sigma)
    initial_sigma = torch.tensor([20., 20., 20.], device=device)
    a = {'multiplier': 0.66, 'epoch': 200}
    b = {'multiplier': 0.66, 'epoch': 800}
    c = {'multiplier': 0.66, 'epoch': 1500}
    d = {'multiplier': 0.5, 'epoch': 20000}
    f = {'multiplier': 0.5, 'epoch': 20000}
    sigma = Sigma([a, b, c, d, f], initial_sigma, device)

    # Define constants for training engine
    constants = {
        'model': model,  # UNet model
        'vector_scale': vector_scale,
        'anisotropy': anisotropy,
        'lr': 5e-4,  # learning rate
        'wd': 1e-6,  # optimizer weight decay
        'optimizer': partial(torch.optim.AdamW, eps=1e-16),
        'scheduler': partial(torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, T_0=10000 + 1),
        'sigma': sigma,
        'loss_embed': tversky(alpha=0.25, beta=0.75, eps=1e-8, device=device), # Loss functions, see API reference for more details
        'loss_prob': tversky(alpha=0.5, beta=0.5, eps=1e-8, device=device),
        'loss_skele': tversky(alpha=0.5, beta=1.5, eps=1e-8, device=device),
        'epochs': 10000,  # total number of training epochs
        'device': device,
        'train_data': train_dataloader,
        'val_data': validation_dataloader,
        'train_sampler': train_sampler,
        'test_sampler': validation_sampler,
        'distributed': True,
        'mixed_precision': True,  # can use automatic mixed precision which may speed up training
        'rank': rank,
        'savepath': './models',  # where to save the model at the end
    }

    # tensorboard logging
    writer = SummaryWriter() if rank == 0 else None

    # train model from hyperparams
    model_state_dict, optimizer_state_dict, avg_loss = engine(writer=writer, verbose=True, force=True, **constants)

    # Convert hyperparams to string for saving
    if rank == 0:
        # Some hyperparams cannot be saved as is, so we simply store a string representation. This is usually good enough.
        for k in constants:
            if k in ['model', 'train_data', 'val_data', 'train_sampler', 'test_sampler', 'loss_embed', 'loss_prob']:
                constants[k] = str(constants[k])

        # Save the weights of the model and optimizer to constants dict
        constants['model_state_dict'] = model_state_dict
        constants['optimizer_state_dict'] = optimizer_state_dict

        # Save the model to a file! Only the rank 0 process has a writer, so this stays inside the if block.
        torch.save(constants, f'./models/{os.path.split(writer.log_dir)[-1]}.trch')

    # cleanup the DDP process
    cleanup(rank)
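
The ```train``` function still has to be started once per device. Below is a minimal sketch of one way to do that with ```torch.multiprocessing.spawn```, which calls the function with the process index as its first argument followed by the entries of ```args```. The helpers ```pick_free_port``` and ```launch``` are hypothetical names for illustration; SKOOTS ships its own ```find_free_port```, whose signature we do not assume here, so the port is chosen with the standard library instead.

```python
import socket
from typing import Callable


def pick_free_port() -> str:
    """Ask the OS for an unused local TCP port (hypothetical stand-in for find_free_port)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('127.0.0.1', 0))  # port 0 lets the OS choose a free port
        return str(s.getsockname()[1])


def launch(train_fn: Callable, model, world_size: int) -> None:
    """Start one training process per device and wait for all of them to finish."""
    # Imported here so that pick_free_port can be used without PyTorch installed.
    import torch.multiprocessing as mp

    port = pick_free_port()
    # spawn calls train_fn(rank, port, world_size, model) for rank in 0..world_size-1
    mp.spawn(train_fn, args=(port, world_size, model), nprocs=world_size, join=True)


# Example usage (assumes `train` from above and an already constructed `model`):
# if __name__ == '__main__':
#     launch(train, model, world_size=2)
```

Because the processes are spawned, the call to ```launch``` should sit under an ```if __name__ == '__main__':``` guard, as in the commented usage.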