```{eval-rst}
.. role:: nge-red
```
{nge-red}`Define a Custom Training Engine`
===================================

For the most control over training, you can define your own training engine, which handles the forward and backward passes. This section assumes familiarity with deep learning techniques and PyTorch, so it offers less explanation than the previous tutorials.
:::{note}
For consistency with previous tutorials, we recommend not changing the names of any keyword arguments.
:::


## Imports

In [None]:
import os.path
import numpy as np
import skimage.io as io
import skoots.train.loss
from skoots.train.utils import sum_loss, show_box_pred
from skoots.train.sigma import Sigma
from skoots.lib.vector_to_embedding import vector_to_embedding
from skoots.lib.embedding_to_prob import baked_embed_to_prob

from skoots.train.utils import mask_overlay, write_progress

from typing import List, Tuple, Callable, Union, OrderedDict, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models.detection import FasterRCNN

from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from tqdm import trange
from torch.cuda.amp import GradScaler, autocast
from statistics import mean
from torchvision.utils import flow_to_image, draw_keypoints, make_grid
import matplotlib.pyplot as plt
import torch.optim.swa_utils
Dataset = Union[Dataset, DataLoader]

## Function Definition and Keyword Arguments
For a detailed understanding of each keyword argument, see the detailed training guide. These parameters should ideally be passed via a kwarg dict.

In [None]:
def engine(
        model: nn.Module,
        lr: float,
        wd: float,
        vector_scale: Tuple[int, int, int],
        epochs: int,
        optimizer: Optimizer,
        scheduler,
        sigma: Sigma,
        loss_embed,
        loss_prob,
        loss_skele,
        device: str,
        savepath: str,
        train_data: Dataset,
        rank: int,
        val_data: Optional[Dataset] = None,
        train_sampler=None,
        test_sampler=None,
        writer=None,
        verbose=False,
        distributed=True,
        mixed_precision=False,
        n_warmup: int = 100,
        force=False,
        **kwargs,
) -> Tuple[OrderedDict, OrderedDict, List[float]]:
    pass
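
As a minimal sketch of the kwarg-dict pattern, the snippet below uses a stand-in function, `toy_engine`, which is hypothetical and not part of skoots. It borrows a few of the keyword names from the signature above; every value is a placeholder chosen for illustration only.

```python
from typing import Any, Dict, Tuple


def toy_engine(lr: float, wd: float, epochs: int,
               vector_scale: Tuple[int, int, int],
               verbose: bool = False, **kwargs) -> Dict[str, Any]:
    """Hypothetical stand-in that mimics how engine() receives its arguments."""
    # Unrecognised keys are collected by **kwargs instead of raising a TypeError.
    return {'lr': lr, 'wd': wd, 'epochs': epochs,
            'vector_scale': vector_scale, 'ignored': sorted(kwargs)}


# Collect every setting in one dict so a run can be logged or re-used as a unit.
train_kwargs = {
    'lr': 5e-4,                   # placeholder value
    'wd': 1e-6,                   # placeholder value
    'epochs': 10,                 # placeholder value
    'vector_scale': (60, 60, 12), # placeholder value
    'experiment_name': 'demo',    # extra key, swallowed by **kwargs
}

result = toy_engine(**train_kwargs)
print(result)
```

Because the real `engine` also accepts `**kwargs`, a dict that carries extra bookkeeping keys can be unpacked into it without editing the dict first.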

## Initialization
Here we initialize the optimizer, scheduler, and gradient scaler for automatic mixed precision.

In [None]:
optimizer = optimizer(model.parameters(), lr=lr, weight_decay=wd)
scheduler = scheduler(optimizer)
scaler = GradScaler(enabled=mixed_precision)

# vector_scale is passed in as a tuple, so convert it to a tensor on the target device
num = torch.tensor(vector_scale, device=device)
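
Because the cell above calls `optimizer(model.parameters(), lr=lr, weight_decay=wd)` and `scheduler(optimizer)`, both arguments are expected to be callables rather than already-constructed objects. The following is a minimal sketch of how such callables could be prepared. AdamW, cosine annealing, and `T_max=50` are assumptions made for illustration; the source does not say which optimizer or scheduler is used.

```python
from functools import partial

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # tiny placeholder model

# Pass the optimizer *class*; the engine supplies the parameters, lr and wd.
optimizer_factory = torch.optim.AdamW

# The scheduler is called with the optimizer only, so bind everything else now.
scheduler_factory = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=50)

# What the initialization cell then does with them:
optimizer = optimizer_factory(model.parameters(), lr=5e-4, weight_decay=1e-6)
scheduler = scheduler_factory(optimizer)

print(type(optimizer).__name__, type(scheduler).__name__)
```

Binding the scheduler's extra arguments with `functools.partial` keeps the engine itself free of scheduler-specific settings.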

## Stochastic Weight Averaging
We may want to use Stochastic Weight Averaging (SWA) to improve the generalizability of the model. I am not convinced it does much, however.

In [None]:
swa_model = torch.optim.swa_utils.AveragedModel(model)
swa_start = 100
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
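
For reference, here is a small self-contained sketch of how `AveragedModel` and `SWALR` are typically combined. The training loop in this tutorial creates `swa_scheduler` but never steps it; the sketch steps it once SWA begins, following the usage pattern in the PyTorch documentation. The model, data, `StepLR` schedule, and the small `swa_start` are placeholders so the example runs quickly.

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

swa_model = AveragedModel(model)            # keeps a running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
swa_start = 3                               # small value so the demo reaches it

x, y = torch.randn(16, 3), torch.randn(16, 1)
for epoch in range(6):
    optimizer.zero_grad(set_to_none=True)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    if epoch > swa_start:
        swa_model.update_parameters(model)  # fold the current weights into the average
        swa_scheduler.step()                # anneal towards the constant SWA learning rate
    else:
        scheduler.step()                    # regular schedule before SWA begins

print(int(swa_model.n_averaged))            # number of averaging updates performed
```

If the model contains batch-norm layers, `torch.optim.swa_utils.update_bn` can be used afterwards to recompute their statistics for the averaged weights.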

## Saving the Loss
We can save the loss values of all training and validation outputs for writing to TensorBoard later. This is optional and only needed if you want logging.

In [None]:
# Save each loss value in a list...
avg_epoch_loss = []
avg_epoch_embed_loss = []
avg_epoch_prob_loss = []
avg_epoch_skele_loss = []

avg_val_loss = []
avg_val_embed_loss = []
avg_val_prob_loss = []
avg_val_skele_loss = []

## Warmup
This is critical and worth more discussion. These models are incredibly hard to train from scratch. It is therefore useful to over-train a model on a single input, which warms up the weights and leaves the model better positioned for general training. To do this, we run a mini version of our actual training loop as outlined below. The API reference for training can be found in the 'API Flow Guide' section of the documentation.

In [None]:
# Warmup... Loop through train_data so that the last batch it yields is kept for warmup
for images, masks, skeleton, skele_masks, baked in train_data:
    pass

warmup_range = trange(n_warmup, desc='Warmup: {}')
for w in warmup_range:
    optimizer.zero_grad(set_to_none=True)

    with autocast(enabled=mixed_precision):  # Saves Memory!
        out: Tensor = model(images)

        # break the singular model output tensor into its respective components
        probability_map: Tensor = out[:, [-1], ...]
        vector: Tensor = out[:, 0:3:1, ...]
        predicted_skeleton: Tensor = out[:, [-2], ...]

        embedding: Tensor = vector_to_embedding(num, vector)  # num is the vector_scale tensor
        out: Tensor = baked_embed_to_prob(embedding, baked, sigma(0))

        # Loss functions are nn.Modules which have their own forward pass.
        _loss_embed = loss_embed(out, masks.gt(0).float())  # out = [B, 2/3, X, Y, Z?]
        _loss_prob = loss_prob(probability_map, masks.gt(0).float())
        _loss_skeleton = loss_skele(predicted_skeleton, skele_masks.gt(0).float())

        # We may want to weight one loss more than another. Therefore we can multiply by a set amount. 1-1-1 is usually good though.
        loss = _loss_embed + (1 * _loss_prob) + (1 * _loss_skeleton)

        warmup_range.desc = f'{loss.item()}'

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
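
The idea of warming up on a single batch can be seen in isolation with a toy problem. This is a sketch only: the small MLP, the random data, the MSE loss, and the Adam optimizer are stand-ins and have nothing to do with the skoots model or its loss functions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

# One fixed batch, reused for every warmup step.
images = torch.randn(4, 8)
targets = torch.randn(4, 1)

n_warmup = 100
history = []
for _ in range(n_warmup):
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(images), targets)
    loss.backward()
    optimizer.step()
    history.append(loss.item())

print(f'first step: {history[0]:.4f}, last step: {history[-1]:.4f}')
```

Repeating the same batch lets the loss fall quickly, which is the over-training effect the warmup relies on before general training starts.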

## Model Outputs to Loss
It is worth discussing how we go from the embedding vectors, which lie between -1 and 1, to the embedding loss. Our vectors $V$ for dimension $l$ at locations $ijk$ must be scaled by our vector scaling parameter $\Gamma$ for dimension $l$. Formally, the scaled vector is $\Gamma_{l} V_l^{ijk}$ for $l \in \{x, y, z\}$.

We can now apply these scaled vectors to their own index $ijk$ to determine our embedding $E_{ijk}$. Formally: $E_{ijk} = (i + \Gamma_x V_x^{ijk},\ j + \Gamma_y V_y^{ijk},\ k + \Gamma_z V_z^{ijk})$. This operation is performed by ```skoots.lib.vector_to_embedding.vector_to_embedding```. From here we can calculate a score based on a baked skeleton tensor. A baked skeleton tensor $S$ at $ijk$ is defined as the closest skeleton of an instance.
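
To make the scale-then-offset step concrete, here is an illustrative re-implementation. It is not the skoots source: `toy_vector_to_embedding` is a hypothetical name, "scaled" is read as element-wise multiplication by $\Gamma$, and the tensor layout `[B, 3, X, Y, Z]` is an assumption.

```python
import torch
from torch import Tensor


def toy_vector_to_embedding(scale: Tensor, vector: Tensor) -> Tensor:
    """Illustrative sketch; not the skoots implementation.

    scale:  shape [3], the scaling parameter (Gamma_x, Gamma_y, Gamma_z)
    vector: shape [B, 3, X, Y, Z], values in [-1, 1]
    """
    _, _, x, y, z = vector.shape
    # Index grid: mesh[:, i, j, k] == (i, j, k)
    axes = [torch.arange(n, dtype=vector.dtype, device=vector.device) for n in (x, y, z)]
    mesh = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=0)  # [3, X, Y, Z]
    # Scale each vector component by its Gamma, then offset by the voxel index.
    return mesh.unsqueeze(0) + vector * scale.to(vector).view(1, 3, 1, 1, 1)


scale = torch.tensor([10.0, 10.0, 2.0])   # placeholder Gamma values
vector = torch.zeros(1, 3, 4, 4, 2)
vector[0, 0, 1, 2, 0] = 0.5               # voxel (1, 2, 0) points half a unit along +x

embedding = toy_vector_to_embedding(scale, vector)
print(embedding[0, :, 1, 2, 0])           # x becomes 1 + 0.5 * 10; y and z stay at the index
```

Voxels whose vector is zero embed to their own index, so only voxels with a non-zero prediction move towards a skeleton.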

In [None]:
# Train Step...
epoch_range = trange(epochs, desc=f'Loss = {1.0000000}') if rank == 0 else range(epochs)
for e in epoch_range:
    _loss, _embed, _prob, _skele = [], [], [], []

    # Necessary for random sampling with DDP
    if distributed:
        train_sampler.set_epoch(e)

    for images, masks, skeleton, skele_masks, baked in train_data:
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=mixed_precision):  # Saves Memory!
            out: Tensor = model(images)

            probability_map: Tensor = out[:, [-1], ...]
            vector: Tensor = out[:, 0:3:1, ...]
            predicted_skeleton: Tensor = out[:, [-2], ...]

            embedding: Tensor = vector_to_embedding(num, vector)
            out: Tensor = baked_embed_to_prob(embedding, baked, sigma(e))

            _loss_embed = loss_embed(out, masks.gt(0).float())
            _loss_prob = loss_prob(probability_map, masks.gt(0).float())
            _loss_skeleton = loss_skele(predicted_skeleton, skele_masks.gt(0).float())

            # It sometimes is hard to learn all features at once. It can be beneficial to let the model learn the
            # vectors/semantic mask first and then the skeleton
            loss = _loss_embed + (1 * _loss_prob) + ((1 if e > 10 else 0) * _loss_skeleton)


        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # For stochastic weight averaging
        if e > swa_start:
            swa_model.update_parameters(model)

        _loss.append(loss.item())
        _embed.append(_loss_embed.item())
        _prob.append(_loss_prob.item())
        _skele.append(_loss_skeleton.item())

    # Avg epoch loss
    avg_epoch_loss.append(mean(_loss))
    avg_epoch_embed_loss.append(mean(_embed))
    avg_epoch_prob_loss.append(mean(_prob))
    avg_epoch_skele_loss.append(mean(_skele))

    # update the learning rate
    scheduler.step()

    # tensorboard writing
    if writer and (rank == 0):
        writer.add_scalar('lr', scheduler.get_last_lr()[-1], e)
        writer.add_scalar('Loss/train', avg_epoch_loss[-1], e)
        writer.add_scalar('Loss/embed', avg_epoch_embed_loss[-1], e)
        writer.add_scalar('Loss/prob', avg_epoch_prob_loss[-1], e)
        writer.add_scalar('Loss/skele-mask', avg_epoch_skele_loss[-1], e)
        write_progress(writer=writer, tag='Train', epoch=e, images=images, masks=masks,
                       probability_map=probability_map,
                       vector=vector, out=out, skeleton=skeleton,
                       predicted_skeleton=predicted_skeleton, gt_skeleton=skele_masks)
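
The channel slicing used in these loops is easy to misread, so here is a small sketch of what each index returns. The five-channel layout (three vector channels, then skeleton, then probability) is an assumption inferred from the indices above, and the tensor is filled with placeholder values.

```python
import torch

# Placeholder for a model output with 5 channels: [B, C=5, X, Y, Z]
out = torch.arange(2 * 5 * 4 * 4 * 2, dtype=torch.float32).reshape(2, 5, 4, 4, 2)

vector = out[:, 0:3, ...]                # channels 0-2: x, y, z vector components
predicted_skeleton = out[:, [-2], ...]   # list index keeps the channel dim (size 1)
probability_map = out[:, [-1], ...]

# An integer index would drop the channel dimension instead.
squeezed = out[:, -1, ...]

print(vector.shape, predicted_skeleton.shape, probability_map.shape, squeezed.shape)
```

Keeping the singleton channel dimension means the predictions stay shaped `[B, 1, X, Y, Z]`, matching masks that carry their own channel axis.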

## Validation Step
The validation step solely exists to write to TensorBoard, so there is no backpropagation. For speed, we only do this once every 10 epochs.

In [None]:
# # Validation Step
# for e in trange(epochs): (FROM ABOVE!!!)
    if e % 10 == 0 and val_data:
        _loss, _embed, _prob, _skele = [], [], [], []
        for images, masks, skeleton, skele_masks, baked in val_data:
            with autocast(enabled=mixed_precision):  # Saves Memory!
                with torch.no_grad():
                    out: Tensor = swa_model(images)

                    probability_map: Tensor = out[:, [-1], ...]
                    predicted_skeleton: Tensor = out[:, [-2], ...]
                    vector: Tensor = out[:, 0:3:1, ...]

                    embedding: Tensor = vector_to_embedding(num, vector)
                    out: Tensor = baked_embed_to_prob(embedding, baked, sigma(e))

                    _loss_embed = loss_embed(out, masks.gt(0).float())
                    _loss_prob = loss_prob(probability_map, masks.gt(0).float())
                    _loss_skeleton = loss_skele(predicted_skeleton, skele_masks.gt(0).float())

                    loss = (1 * _loss_embed) + (1 * _loss_prob) + _loss_skeleton

            _loss.append(loss.item())
            _embed.append(_loss_embed.item())
            _prob.append(_loss_prob.item())
            _skele.append(_loss_skeleton.item())

        avg_val_loss.append(mean(_loss))
        avg_val_embed_loss.append(mean(_embed))
        avg_val_prob_loss.append(mean(_prob))
        avg_val_skele_loss.append(mean(_skele))

        if writer and (rank == 0):
            writer.add_scalar('Validation/train', avg_val_loss[-1], e)
            writer.add_scalar('Validation/embed', avg_val_embed_loss[-1], e)
            writer.add_scalar('Validation/prob', avg_val_prob_loss[-1], e)
            write_progress(writer=writer, tag='Validation', epoch=e, images=images, masks=masks,
                           probability_map=probability_map,
                           vector=vector, out=out, skeleton=skeleton,
                           predicted_skeleton=predicted_skeleton, gt_skeleton=skele_masks)
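
The shape of a periodic, gradient-free validation pass can be sketched on its own. Everything here is a stand-in: the linear model, the list of random batches used in place of a `DataLoader`, and the `validate` helper are hypothetical. Switching to `model.eval()` is also an assumption; the loop in this tutorial does not toggle the training mode.

```python
import torch
import torch.nn as nn
from statistics import mean

torch.manual_seed(0)
model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
val_batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]  # stand-in loader


def validate(model: nn.Module) -> float:
    """Average loss over the validation batches without building a graph."""
    was_training = model.training
    model.eval()                      # assumption: not done in the tutorial loop
    losses = []
    with torch.no_grad():             # no autograd bookkeeping, so no backward pass
        for x, y in val_batches:
            losses.append(loss_fn(model(x), y).item())
    model.train(was_training)         # restore the previous mode
    return mean(losses)


avg_val_loss = []
for e in range(25):
    # ... the training step would go here ...
    if e % 10 == 0:                   # epochs 0, 10 and 20
        avg_val_loss.append(validate(model))

print(len(avg_val_loss))
```

Since nothing is backpropagated, the gradient scaler plays no role in this pass.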

## TQDM writing
We use tqdm to get estimates of training speed and a quick glance at the model loss.

In [None]:
# for e in trange(epochs): (FROM ABOVE)
    if rank == 0:
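        # avg_val_loss is empty when val_data is None, so fall back to a placeholder
        val_str = f'{avg_val_loss[-1]:.5f}' if avg_val_loss else 'n/a'
        epoch_range.desc = f'lr={scheduler.get_last_lr()[-1]:.3e}, Loss (train | val): {avg_epoch_loss[-1]:.5f} | {val_str}'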
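
One thing to watch is that `avg_val_loss` stays empty when no validation data is supplied, so the description needs a fallback value. Below is a sketch of a small helper that tolerates an empty list. `format_desc` is a hypothetical name, the loss values are placeholders, and `set_description` is tqdm's own method for updating the text.

```python
from typing import List

from tqdm import trange


def format_desc(lr: float, train_losses: List[float], val_losses: List[float]) -> str:
    """Hypothetical helper: build the progress-bar text, tolerating missing validation."""
    val = f'{val_losses[-1]:.5f}' if val_losses else 'n/a'
    return f'lr={lr:.3e}, Loss (train | val): {train_losses[-1]:.5f} | {val}'


train_losses, val_losses = [], []
epoch_range = trange(3, desc='Loss = n/a')
for e in epoch_range:
    train_losses.append(1.0 / (e + 1))       # placeholder for the real epoch loss
    if e % 2 == 0:
        val_losses.append(2.0 / (e + 1))     # placeholder for the real validation loss
    epoch_range.set_description(format_desc(1e-3, train_losses, val_losses))
```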

## Return the Model
The training engine should return the model state dict, the optimizer state dict, and the average validation loss to be compatible with other scripts.

In [None]:
# for e in trange(epochs):
    # if rank == 0:
        state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
        if e % 100 == 0:
            torch.save(state_dict, savepath + f'/test_{e}.trch')

return state_dict, optimizer.state_dict(), avg_val_loss
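
The `hasattr(model, 'module')` check exists because wrappers such as `DistributedDataParallel` store the real model under `.module` and prefix every key in their own state dict. The sketch below shows the effect with a hypothetical `Wrapper` class standing in for DDP, and writes to a temporary directory rather than a real `savepath`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for a wrapper like DDP that exposes `.module`."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


def unwrap_state_dict(model: nn.Module):
    """Return the underlying weights whether or not the model is wrapped."""
    return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()


model = nn.Linear(4, 2)
wrapped = Wrapper(model)
state_dict = unwrap_state_dict(wrapped)

with tempfile.TemporaryDirectory() as savepath:
    path = os.path.join(savepath, 'test_0.trch')
    torch.save(state_dict, path)
    restored = torch.load(path, map_location='cpu')

fresh = nn.Linear(4, 2)
fresh.load_state_dict(restored)   # keys match because the wrapper prefix is absent
print(sorted(restored))
```

Saving the unwrapped weights means the checkpoint can later be loaded into a plain, non-distributed model.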

## All Together
Note: this is simply the entirety of ```skoots.train.engine```.

In [None]:
def engine(
        model: nn.Module,
        lr: float,
        wd: float,
        vector_scale: Tuple[int, int, int],
        epochs: int,
        optimizer: Optimizer,
        scheduler,
        sigma: Sigma,
        loss_embed,
        loss_prob,
        loss_skele,
        device: str,
        savepath: str,
        train_data: Dataset,
        rank: int,
        val_data: Optional[Dataset] = None,
        train_sampler=None,
        test_sampler=None,
        writer=None,
        verbose=False,
        distributed=True,
        mixed_precision=False,
        n_warmup: int = 100,
        force=False,
        **kwargs,
) -> Tuple[OrderedDict, OrderedDict, List[float]]:

    # Print out each kwarg to std out
    if verbose and rank == 0:
        print('Initiating Training Run', flush=False)
        vars = locals()
        for k in vars:
            if k != 'model':
                print(f'\t> {k}: {vars[k]}', flush=False)
        print('', flush=True)

    num = torch.tensor(vector_scale, device=device)


    optimizer = optimizer(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = scheduler(optimizer)
    scaler = GradScaler(enabled=mixed_precision)

    swa_model = torch.optim.swa_utils.AveragedModel(model)
    swa_start = 100
    swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)

    # Save each loss value in a list...
    avg_epoch_loss = []
    avg_epoch_embed_loss = []
    avg_epoch_prob_loss = []
    avg_epoch_skele_loss = []

    avg_val_loss = []
    avg_val_embed_loss = []
    avg_val_prob_loss = []
    avg_val_skele_loss = []

    # skel_crossover_loss = skoots.train.loss.split(n_iter=3, alpha=2)

    # Warmup... Get the first from train_data
    for images, masks, skeleton, skele_masks, baked in train_data:
        pass

    warmup_range = trange(n_warmup, desc='Warmup: {}')
    for w in warmup_range:
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=mixed_precision):  # Saves Memory!
            out: Tensor = model(images)

            probability_map: Tensor = out[:, [-1], ...]
            vector: Tensor = out[:, 0:3:1, ...]
            predicted_skeleton: Tensor = out[:, [-2], ...]

            embedding: Tensor = vector_to_embedding(num, vector)
            out: Tensor = baked_embed_to_prob(embedding, baked, sigma(0))

            _loss_embed = loss_embed(out, masks.gt(0).float())  # out = [B, 2/3, X, Y, Z?]
            _loss_prob = loss_prob(probability_map, masks.gt(0).float())
            _loss_skeleton = loss_skele(predicted_skeleton, skele_masks.gt(
                0).float())  # + skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float())
            loss = _loss_embed + (1 * _loss_prob) + (1 * _loss_skeleton)

            # print('All Skeleton Loss: ', _loss_skeleton.item())
            # print('Skeleton Loss of just crossover: ',
            #       skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float()))

            warmup_range.desc = f'{loss.item()}'

            if torch.isnan(loss):
                print(f'Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}')
                print(f'\t{torch.any(torch.isnan(vector))}')
                print(f'\t{torch.any(torch.isnan(embedding))}')
                continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # Train Step...
    epoch_range = trange(epochs, desc=f'Loss = {1.0000000}') if rank == 0 else range(epochs)
    for e in epoch_range:
        _loss, _embed, _prob, _skele = [], [], [], []

        if distributed:
            train_sampler.set_epoch(e)

        for images, masks, skeleton, skele_masks, baked in train_data:
            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=mixed_precision):  # Saves Memory!
                out: Tensor = model(images)

                probability_map: Tensor = out[:, [-1], ...]
                vector: Tensor = out[:, 0:3:1, ...]
                predicted_skeleton: Tensor = out[:, [-2], ...]

                embedding: Tensor = vector_to_embedding(num, vector)
                out: Tensor = baked_embed_to_prob(embedding, baked, sigma(e))

                _loss_embed = loss_embed(out, masks.gt(0).float())  # out = [B, 2/3, X, Y, Z?]
                _loss_prob = loss_prob(probability_map, masks.gt(0).float())
                _loss_skeleton = loss_skele(predicted_skeleton, skele_masks.gt(
                    0).float())  # + skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float())

                loss = _loss_embed + (1 * _loss_prob) + ((1 if e > 10 else 0) * _loss_skeleton)



                if torch.isnan(loss):
                    print(f'Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}')
                    print(f'\t{torch.any(torch.isnan(vector))}')
                    print(f'\t{torch.any(torch.isnan(embedding))}')
                    continue
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if e > swa_start:
                swa_model.update_parameters(model)

            _loss.append(loss.item())
            _embed.append(_loss_embed.item())
            _prob.append(_loss_prob.item())
            _skele.append(_loss_skeleton.item())

        avg_epoch_loss.append(mean(_loss))
        avg_epoch_embed_loss.append(mean(_embed))
        avg_epoch_prob_loss.append(mean(_prob))
        avg_epoch_skele_loss.append(mean(_skele))
        scheduler.step()

        if writer and (rank == 0):
            writer.add_scalar('lr', scheduler.get_last_lr()[-1], e)
            writer.add_scalar('Loss/train', avg_epoch_loss[-1], e)
            writer.add_scalar('Loss/embed', avg_epoch_embed_loss[-1], e)
            writer.add_scalar('Loss/prob', avg_epoch_prob_loss[-1], e)
            writer.add_scalar('Loss/skele-mask', avg_epoch_skele_loss[-1], e)
            write_progress(writer=writer, tag='Train', epoch=e, images=images, masks=masks,
                           probability_map=probability_map,
                           vector=vector, out=out, skeleton=skeleton,
                           predicted_skeleton=predicted_skeleton, gt_skeleton=skele_masks)

        # # Validation Step
        if e % 10 == 0 and val_data:
            _loss, _embed, _prob, _skele = [], [], [], []
            for images, masks, skeleton, skele_masks, baked in val_data:
                with autocast(enabled=mixed_precision):  # Saves Memory!
                    with torch.no_grad():
                        out: Tensor = swa_model(images)

                        probability_map: Tensor = out[:, [-1], ...]
                        predicted_skeleton: Tensor = out[:, [-2], ...]
                        vector: Tensor = out[:, 0:3:1, ...]

                        embedding: Tensor = vector_to_embedding(num, vector)
                        out: Tensor = baked_embed_to_prob(embedding, baked, sigma(e))

                        _loss_embed = loss_embed(out, masks.gt(0).float())
                        _loss_prob = loss_prob(probability_map, masks.gt(0).float())
                        _loss_skeleton = loss_skele(predicted_skeleton, skele_masks.gt(0).float())

                        loss = (2 * _loss_embed) + (2 * _loss_prob) + _loss_skeleton

                        if torch.isnan(loss):
                            print(
                                f'Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}')
                            print(f'\t{torch.any(torch.isnan(vector))}')
                            print(f'\t{torch.any(torch.isnan(embedding))}')
                            continue

                _loss.append(loss.item())
                _embed.append(_loss_embed.item())
                _prob.append(_loss_prob.item())
                _skele.append(_loss_skeleton.item())

            avg_val_loss.append(mean(_loss))
            avg_val_embed_loss.append(mean(_embed))
            avg_val_prob_loss.append(mean(_prob))
            avg_val_skele_loss.append(mean(_skele))

            if writer and (rank == 0):
                writer.add_scalar('Validation/train', avg_val_loss[-1], e)
                writer.add_scalar('Validation/embed', avg_val_embed_loss[-1], e)
                writer.add_scalar('Validation/prob', avg_val_prob_loss[-1], e)
                write_progress(writer=writer, tag='Validation', epoch=e, images=images, masks=masks,
                               probability_map=probability_map,
                               vector=vector, out=out, skeleton=skeleton,
                               predicted_skeleton=predicted_skeleton, gt_skeleton=skele_masks)

        if rank == 0:
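            # avg_val_loss is empty when val_data is None, so fall back to a placeholder
            val_str = f'{avg_val_loss[-1]:.5f}' if avg_val_loss else 'n/a'
            epoch_range.desc = f'lr={scheduler.get_last_lr()[-1]:.3e}, Loss (train | val): {avg_epoch_loss[-1]:.5f} | {val_str}'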

        state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
        if e % 100 == 0:
            torch.save(state_dict, savepath + f'/test_{e}.trch')

    # Return only after every epoch has finished (outside the epoch loop)
    return state_dict, optimizer.state_dict(), avg_val_loss
