```{eval-rst}
.. role:: nge-yellow
```
{nge-yellow}`Detailed Training Protocol`
===================================

The training process is typically invoked from the command line with the ```skoots-train``` command. This calls the main function in ```skoots.train.__main__.py```, which parses the command line arguments, loads the config file and model, initializes PyTorch DistributedDataParallel, and finally calls the ```train()``` function from ```skoots.train.engine.py```. To understand how we train SKOOTS, we will walk through that function in detail. Throughout the training script, you will see references to a variable ```cfg```, which stores the user's configuration data.

## Imports
We must first import each package necessary for training. SKOOTS takes a functional approach to training. It is not strictly in line with functional programming best practices, but it saves you from deep inheritance hierarchies.

In [None]:
import os
import os.path
from functools import partial
from statistics import mean
from typing import Callable, Union, Dict

import torch
import torch.nn as nn
import torch.optim.lr_scheduler
import torch.optim.swa_utils
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from yacs.config import CfgNode

import skoots.train.loss
from skoots.lib.embedding_to_prob import baked_embed_to_prob
from skoots.lib.vector_to_embedding import vector_to_embedding
from skoots.train.dataloader import dataset, MultiDataset, skeleton_colate
from skoots.train.merged_transform import (
    transform_from_cfg,
    background_transform_from_cfg,
)
from skoots.train.setup import setup_process
from skoots.train.sigma import Sigma, init_sigma
from skoots.train.utils import write_progress

## Setup DataDistributedParallel
We need to define three mandatory inputs: ```rank```, ```port```, and ```world_size```. Starting in reverse, ```world_size``` is the total number of devices to run distributed training on; if you have two GPUs in one machine, your world size would be 2. ```port``` is the network port on the local machine through which the processes communicate during distributed training. ```rank``` is the index of the current process, so for a ```world_size``` of 2 we get two processes, one with ```rank=0``` and one with ```rank=1```. The world size is set in the configuration file by ```cfg.SYSTEM.NUM_GPUS```. This function should be called through PyTorch multiprocessing; see ```skoots.train.__main__.py```.

In [None]:
# Invoked from skoots.train.__main__.py
def train(rank: int,
          port: str,
          world_size: int,
          base_model: nn.Module,
          cfg: CfgNode
          ) -> None:
    pass
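The ```setup_process``` helper called in the next cell lives in ```skoots.train.setup``` and is not shown here. As a rough idea of what such a helper typically does with ```rank```, ```port```, and ```world_size```, here is a minimal sketch built on standard ```torch.distributed``` calls. ```setup_process_sketch```, ```find_free_port```, and ```cleanup``` are hypothetical names, and the real SKOOTS helper may differ. The sketch defaults to the CPU ```gloo``` backend so it runs without GPUs, whereas SKOOTS passes ```backend="nccl"```.

```python
import os
import socket

import torch.distributed as dist


def find_free_port() -> str:
    # Ask the OS for an unused TCP port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return str(s.getsockname()[1])


def setup_process_sketch(rank: int, world_size: int, port: str, backend: str = "gloo") -> None:
    # Hypothetical stand-in for skoots.train.setup.setup_process.
    # Every process must agree on the address and port of the rank-0 process.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def cleanup() -> None:
    # Tear down the process group once training is finished.
    dist.destroy_process_group()
```

For example, ```setup_process_sketch(rank=0, world_size=1, port=find_free_port())``` creates a one-process group on the CPU; with two GPUs, each spawned process would call it with its own ```rank``` and the same ```port```.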

Next, we set up the process group required by PyTorch DistributedDataParallel and compile the model with TorchInductor via ```torch.compile``` (available in PyTorch 2.0 and later); on older versions, the model is compiled with TorchScript instead. This lets us train on multiple GPUs and use just-in-time compiled CUDA kernels for faster training.

In [None]:
    setup_process(rank, world_size, port, backend="nccl")
    device = f"cuda:{rank}"

    base_model = base_model.to(device)
    base_model = torch.nn.parallel.DistributedDataParallel(base_model)

    if int(rank) == 0:
        print(cfg)

    if int(torch.__version__.split(".")[0]) >= 2:  # major version, e.g. "2.1.0+cu118" -> 2
        print("Compiled with Inductor")
        model = torch.compile(base_model)
    else:
        model = torch.jit.script(base_model)
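The version check above only needs PyTorch's major version. Reading just the first character of ```torch.__version__``` (as in ```int(torch.__version__[0])```) would misread a hypothetical ```"10.0.0"``` release as 1, so a slightly more defensive parser is sketched below. ```major_version``` and ```can_use_torch_compile``` are hypothetical helpers, not part of SKOOTS.

```python
import re


def major_version(version: str) -> int:
    # Parse the leading integer of a version string such as "2.1.0+cu118".
    match = re.match(r"(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1))


def can_use_torch_compile(version: str) -> bool:
    # torch.compile (TorchInductor) was introduced in PyTorch 2.0.
    return major_version(version) >= 2
```

In the training script this would read ```if can_use_torch_compile(torch.__version__): ...```.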

## Data Loading and Augmentation
Data augmentation parameters are set by the configuration file and executed as a single function from ```skoots.train.merged_transform.py```. This reduces the overhead of chaining multiple augmentation classes together, as some augmentation libraries do. There is a separate set of transformations for background data, as it does not need to process masks or skeletons.

In [None]:
    augmentations: Callable[[Dict[str, Tensor]], Dict[str, Tensor]] = partial(
        transform_from_cfg, cfg=cfg, device=device
    )
    background_agumentations: Callable[
        [Dict[str, Tensor]], Dict[str, Tensor]
    ] = partial(background_transform_from_cfg, cfg=cfg, device=device)
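To illustrate the single-function design, here is a minimal sketch of an augmentation that moves every entry of a ```data_dict``` to a device and applies one random flip to all of them together, so the image and its masks stay aligned. ```toy_transform``` and ```flip_prob``` are hypothetical stand-ins for ```transform_from_cfg``` and the augmentation settings in ```cfg```. The sketch also assumes every value is a tensor, whereas the real SKOOTS skeletons are stored as a dictionary and need their own handling.

```python
from functools import partial
from typing import Dict

import torch
from torch import Tensor


def toy_transform(data_dict: Dict[str, Tensor], flip_prob: float, device: str) -> Dict[str, Tensor]:
    # Hypothetical single-function augmentation: every key is moved to `device`
    # and flipped together, so image and masks remain spatially aligned.
    out = {k: v.to(device) for k, v in data_dict.items()}
    if torch.rand(1).item() < flip_prob:
        out = {k: torch.flip(v, dims=[-1]) for k, v in out.items()}
    return out


# Bind the configuration once; the dataset only ever calls augment(data_dict).
augment = partial(toy_transform, flip_prob=0.5, device="cpu")
```

As with ```functools.partial``` in the cell above, the bound arguments are fixed once, so the dataset does not need to know anything about the configuration.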

This function takes in a ```data_dict```, a Python dictionary containing the image, masks, and skeletons, and returns the augmented dictionary.

Next, we load our data using the ```dataset``` class from ```skoots.train.dataloader.py```. This dataset class looks in a single folder for multiple sets of three files sharing a common prefix, with the extensions ```*.tif``` (the image), ```*.label.tif``` (the masks), and ```*.skeletons.trch``` (the precomputed skeletons). Training data often consists of one very large file, too large to pass through a neural network at once, so the usual notion of an epoch (one pass over every training example) doesn't apply. Instead, SKOOTS defines an epoch as a set number of samples drawn from each image in a dataset. The right number may differ between images (you don't want to sample a small image 30 times), so SKOOTS lets the user split their datasets into multiple folders and define a sample rate for each.
This is set in the config by specifying a list of data locations, ```_C.TRAIN.TRAIN_DATA_DIR = [data_loc_1, data_loc_2, ...]```, together with the number of samples per image for each location in ```_C.TRAIN.TRAIN_SAMPLE_PER_IMAGE```. This is reflected in code here:

In [None]:
    _datasets = []  # store multiple datasets
    for path, N in zip(cfg.TRAIN.TRAIN_DATA_DIR, cfg.TRAIN.TRAIN_SAMPLE_PER_IMAGE):
        _device = device if cfg.TRAIN.STORE_DATA_ON_GPU else "cpu"
        _datasets.append(
            dataset(
                path=path,                  # where is our data
                transforms=augmentations,   # augmentation function
                sample_per_image=N,         # how many times do we sample each image?
                device=device,              # which device (CPU or GPU) the data should go to
                pad_size=10,                # zero padding added to each image
            )
            .pin_memory()                   # pins the data in RAM for faster host-to-GPU transfer
            .to(_device)                    # if your dataset is small or your GPU is large, all data can live on the GPU
        )

    merged_train = MultiDataset(*_datasets) # helper class which lets us access all datasets in one object

    train_sampler = torch.utils.data.distributed.DistributedSampler(merged_train)
    _n_workers = 0  # if _device != 'cpu' else 2

    # put this in a pytorch dataloader for automatic batching and sampling
    dataloader = DataLoader(
        merged_train,
        num_workers=_n_workers,
        batch_size=cfg.TRAIN.TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        collate_fn=skeleton_colate,
    )
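The samples-per-image notion of an epoch can be sketched with a toy dataset whose length is the number of images times ```sample_per_image``` and whose items are random crops. ```RandomCropDataset``` is hypothetical and much simpler than the SKOOTS ```dataset``` class (no masks, skeletons, or padding), and PyTorch's ```ConcatDataset``` stands in for ```MultiDataset```.

```python
from typing import List

import torch
from torch import Tensor
from torch.utils.data import ConcatDataset, Dataset


class RandomCropDataset(Dataset):
    """Hypothetical sketch: one 'epoch' is `sample_per_image` random crops per image."""

    def __init__(self, images: List[Tensor], crop: int, sample_per_image: int):
        self.images = images
        self.crop = crop
        self.sample_per_image = sample_per_image

    def __len__(self) -> int:
        return len(self.images) * self.sample_per_image

    def __getitem__(self, idx: int) -> Tensor:
        image = self.images[idx // self.sample_per_image]  # which image this index maps to
        h, w = image.shape[-2:]
        y = int(torch.randint(0, h - self.crop + 1, (1,)).item())
        x = int(torch.randint(0, w - self.crop + 1, (1,)).item())
        return image[..., y : y + self.crop, x : x + self.crop]


# One large image sampled often, two small images sampled rarely.
big = RandomCropDataset([torch.rand(1, 256, 256)], crop=64, sample_per_image=30)
small = RandomCropDataset([torch.rand(1, 80, 80)] * 2, crop=64, sample_per_image=2)
merged = ConcatDataset([big, small])  # plays the role of MultiDataset here
```

Here the large image contributes 30 crops per epoch and each small image only 2, for 34 samples in total.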

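We do the same for the background and validation datasets. Note that the background datasets are appended to the same ```_datasets``` list as the training data, so background samples are mixed into the training dataloader, while the validation data gets its own dataloader.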

In [None]:
    for path, N in zip(
        cfg.TRAIN.BACKGROUND_DATA_DIR, cfg.TRAIN.BACKGROUND_SAMPLE_PER_IMAGE
    ):
        _device = device if cfg.TRAIN.STORE_DATA_ON_GPU else "cpu"
        _datasets.append(
            dataset(
                path=path,
                transforms=background_agumentations,
                sample_per_image=N,
                device=device,
                pad_size=100,
            )
            .pin_memory()
            .to(_device)
        )

    merged_train = MultiDataset(*_datasets)

    train_sampler = torch.utils.data.distributed.DistributedSampler(merged_train)
    _n_workers = 0  # if _device != 'cpu' else 2
    dataloader = DataLoader(
        merged_train,
        num_workers=_n_workers,
        batch_size=cfg.TRAIN.TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        collate_fn=skeleton_colate,
    )

    # Validation Dataset
    _datasets = []
    for path, N in zip(
        cfg.TRAIN.VALIDATION_DATA_DIR, cfg.TRAIN.VALIDATION_SAMPLE_PER_IMAGE
    ):
        _device = device if cfg.TRAIN.STORE_DATA_ON_GPU else "cpu"
        _datasets.append(
            dataset(
                path=path,
                transforms=augmentations,
                sample_per_image=N,
                device=device,
                pad_size=10,
            )
            .pin_memory()
            .to(_device)
        )

    merged_validation = MultiDataset(*_datasets)
    test_sampler = torch.utils.data.distributed.DistributedSampler(merged_validation)
    if _datasets and cfg.TRAIN.VALIDATION_BATCH_SIZE >= 1:  # only validate if validation data exists
        _n_workers = 0  # if _device != 'cpu' else 2
        valdiation_dataloader = DataLoader(
            merged_validation,
            num_workers=_n_workers,
            batch_size=cfg.TRAIN.VALIDATION_BATCH_SIZE,
            sampler=test_sampler,
            collate_fn=skeleton_colate,
        )

    else:  # we might not want to run validation...
        valdiation_dataloader = None

## Optimizers, Schedulers, Loss
We set optimizers, learning rate schedulers, and loss functions through the config file. The constructors for each come from a set of dictionaries at the top of ```skoots.train.engine.py```:

In [None]:
    _valid_optimizers = {
        "adamw": torch.optim.AdamW,
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
        "adamax": torch.optim.Adamax,
    }

    _valid_loss_functions = {
        "soft_cldice": skoots.train.loss.soft_dice_cldice,
        "tversky": skoots.train.loss.tversky,
    }

    _valid_lr_schedulers = {
        "cosine_annealing_warm_restarts": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
    }

Within the training script, we look up the constructor for each from these valid options and call it with other arguments set by the config file. Keyword arguments for the loss functions can also be set via the configuration, as parallel lists of keywords and values. This is helpful, for example, when using the Tversky loss with different penalties for false positives and false negatives. The ```GradScaler``` used for mixed precision and a stochastic weight averaging (SWA) model are also created here.

In [None]:
    optimizer = _valid_optimizers[cfg.TRAIN.OPTIMIZER](
        model.parameters(),
        lr=cfg.TRAIN.LEARNING_RATE,
        weight_decay=cfg.TRAIN.WEIGHT_DECAY,
    )
    scheduler = _valid_lr_schedulers[cfg.TRAIN.SCHEDULER](
        optimizer, T_0=cfg.TRAIN.SCHEDULER_T0
    )
    scaler = GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)

    swa_model = torch.optim.swa_utils.AveragedModel(model)
    swa_start = 100
    swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)

    _kwarg = {
        k: v for k, v in zip(cfg.TRAIN.LOSS_EMBED_KEYWORDS, cfg.TRAIN.LOSS_EMBED_VALUES)
    }
    loss_embed: Callable = _valid_loss_functions[cfg.TRAIN.LOSS_EMBED](**_kwarg)

    _kwarg = {
        k: v
        for k, v in zip(
            cfg.TRAIN.LOSS_PROBABILITY_KEYWORDS, cfg.TRAIN.LOSS_PROBABILITY_VALUES
        )
    }
    loss_prob: Callable = _valid_loss_functions[cfg.TRAIN.LOSS_PROBABILITY](**_kwarg)

    _kwarg = {
        k: v
        for k, v in zip(
            cfg.TRAIN.LOSS_SKELETON_KEYWORDS, cfg.TRAIN.LOSS_SKELETON_VALUES
        )
    }
    loss_skele: Callable = _valid_loss_functions[cfg.TRAIN.LOSS_SKELETON](**_kwarg)

## Sigma
To evaluate embedding accuracy, SKOOTS defines a distance penalty called sigma. This is implemented in its own class in ```skoots.train.sigma.py```. Its parameters are set in the config file, and the class is constructed with the helper function ```skoots.train.sigma.init_sigma()```. This penalty decays over multiple epochs and is called like a function:

In [None]:
    sigma: Sigma = init_sigma(cfg, device)
    _ = sigma(100) # whats the sigma at epoch 100?
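The exact decay schedule used by ```Sigma``` is defined in ```skoots.train.sigma.py``` and set by the config. Purely to illustrate a penalty that decays over epochs and is called like a function, here is a sketch assuming an exponential decay toward a floor. ```ExpDecaySigma``` and its parameters are hypothetical; the real class may use a different schedule and per-axis values.

```python
import math


class ExpDecaySigma:
    """Hypothetical sigma schedule: exponential decay toward a minimum value."""

    def __init__(self, initial: float, minimum: float, decay: float):
        self.initial = initial  # sigma at epoch 0
        self.minimum = minimum  # sigma never drops below this
        self.decay = decay      # decay rate per epoch

    def __call__(self, epoch: int) -> float:
        return max(self.minimum, self.initial * math.exp(-self.decay * epoch))
```

Assuming a kernel in which probability falls off with distance divided by sigma, a large sigma early in training is lenient toward embeddings that land far from their skeleton; as sigma shrinks, the same distance yields a lower probability and therefore a larger loss.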

## Vector Scaling
Our model ultimately outputs a set of vectors whose values range from -1 to 1. These must be scaled so that they can reach across the maximum radius of any object you wish to segment. That scale is set here:

In [None]:
    vector_scale = torch.tensor(cfg.SKOOTS.VECTOR_SCALING, device=device)
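A minimal sketch of how scaled vectors can become an embedding, assuming each voxel's embedding is its own coordinate plus its vector multiplied by the per-axis scale. ```vector_to_embedding_sketch``` is hypothetical; the actual ```vector_to_embedding``` in ```skoots.lib.vector_to_embedding.py``` may differ in details such as axis order or dtype.

```python
import torch
from torch import Tensor


def vector_to_embedding_sketch(scale: Tensor, vector: Tensor) -> Tensor:
    """Hypothetical: embedding = voxel coordinate + vector * scale.

    vector: [B, 3, X, Y, Z] with values in [-1, 1]; scale: [3] (one value per axis).
    """
    _, _, x, y, z = vector.shape
    grid = torch.stack(
        torch.meshgrid(
            torch.arange(x, device=vector.device),
            torch.arange(y, device=vector.device),
            torch.arange(z, device=vector.device),
            indexing="ij",
        ),
        dim=0,
    ).to(vector.dtype)  # [3, X, Y, Z]: each voxel's own coordinates
    return grid.unsqueeze(0) + vector * scale.view(1, 3, 1, 1, 1)
```

With all-zero vectors every voxel embeds onto itself; a vector of 1 along an axis pushes the embedding a full ```scale``` voxels along that axis.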

Before training, we also set or initialize a few other things:

In [None]:
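    # cudnn.benchmark lets cuDNN pick the fastest convolution algorithms (best with fixed input sizes)
    torch.backends.cudnn.benchmark = cfg.TRAIN.CUDNN_BENCHMARK
    # anomaly detection adds extra autograd checks: useful for debugging NaNs, but slows training
    torch.autograd.set_detect_anomaly(cfg.TRAIN.AUTOGRAD_DETECT_ANOMALY)
    # torch.autograd.profiler.profile and emit_nvtx are context managers: they only take effect
    # when used as `with ...:` around the code being profiled, so assigning to or constructing
    # them here would do nothing (assigning would even overwrite the profiler class). To profile,
    # wrap the training loop, e.g.:
    #   with torch.autograd.profiler.emit_nvtx(enabled=cfg.TRAIN.AUTOGRAD_EMIT_NVTX): ...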

    # we use tensorboard for logging
    writer = SummaryWriter() if rank == 0 else None
    if writer:
        print("SUMMARY WRITER LOG DIR: ", writer.get_logdir())

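    # Store each epoch's average losses in lists. The first entry is a large placeholder so that
    # indexing with [-1] works before the first epoch (or validation run) has finished.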
    avg_epoch_loss = [9999999999.9999999999]
    avg_epoch_embed_loss = [9999999999.9999999999]
    avg_epoch_prob_loss = [9999999999.9999999999]
    avg_epoch_skele_loss = [9999999999.9999999999]

    avg_val_loss = [9999999999.9999999999]
    avg_val_embed_loss = [9999999999.9999999999]
    avg_val_prob_loss = [9999999999.9999999999]
    avg_val_skele_loss = [9999999999.9999999999]

## Calling the DataLoader and a Simple Training Iteration
Iterating over the DataLoader yields five items per batch: the image, the labeled mask, the skeleton dictionary, the skeleton masks, and the "baked" skeleton. For more on what these are, see the Training section. We use each of these to perform a training step. First, the image is passed through the model:


In [None]:
    # assume current epoch is set here:
    current_epoch = 0
    for images, masks, skeleton, skele_masks, baked in dataloader:
        out: Tensor = model(images)

The ```out``` tensor has five channels: channels 0 to 2 hold the embedding vectors (one per spatial axis), channel 3 (index -2) holds the skeleton map, and channel 4 (index -1) holds the semantic probability map. We can separate these here:

In [None]:
        probability_map: Tensor = out[:, [-1], ...]
        vector: Tensor = out[:, 0:3:1, ...]
        predicted_skeleton: Tensor = out[:, [-2], ...]
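The slicing above uses a list (```[-1]```) rather than an integer (```-1```) to pick a channel. A quick sketch on a random tensor (the spatial size 8 × 8 × 4 is arbitrary) shows the difference: list indexing keeps the channel dimension, so the single-channel outputs keep a ```[B, 1, X, Y, Z]``` layout.

```python
import torch

out = torch.rand(2, 5, 8, 8, 4)          # [B, C=5, X, Y, Z]
vector = out[:, 0:3, ...]                # [2, 3, 8, 8, 4]: the three vector channels
predicted_skeleton = out[:, [-2], ...]   # [2, 1, 8, 8, 4]: list index keeps the channel dim
probability_map = out[:, [-1], ...]      # [2, 1, 8, 8, 4]
dropped = out[:, -1, ...]                # [2, 8, 8, 4]: integer index drops the channel dim
```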

To calculate a loss, we need a skeleton embedding. To calculate the skeleton embedding, we need the vectors, the vector scale, and the function ```vector_to_embedding``` from ```skoots.lib.vector_to_embedding.py```:

In [None]:
        embedding: Tensor = vector_to_embedding(vector_scale, vector)

Once we have an embedding, we need a way to calculate a loss value. We do this by generating a probability score for each pixel based on how close its embedding lands to its "true" destination. This true destination is its closest skeleton, which is contained in the baked skeleton tensor. To calculate this probability, we call the function ```baked_embed_to_prob``` from ```skoots.lib.embedding_to_prob.py```, passing the current value of sigma:

In [None]:
        out: Tensor = baked_embed_to_prob(embedding, baked, sigma(current_epoch))
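As an illustration only, one common way to turn embedding distance into a probability is a Gaussian kernel: 1 where the embedding lands exactly on its target, decaying with distance scaled by sigma. The sketch below assumes the baked tensor holds, for every voxel, the coordinates of its target skeleton in the same ```[B, 3, X, Y, Z]``` layout as the embedding. ```embed_to_prob_sketch``` is hypothetical and is not necessarily the formula used by ```baked_embed_to_prob```.

```python
import torch
from torch import Tensor


def embed_to_prob_sketch(embedding: Tensor, target: Tensor, sigma: Tensor) -> Tensor:
    """Hypothetical Gaussian kernel between each voxel's embedding and its target.

    embedding, target: [B, 3, X, Y, Z]; sigma: [3], a per-axis length scale.
    Returns [B, 1, X, Y, Z] with values in (0, 1].
    """
    sigma = sigma.view(1, -1, 1, 1, 1)
    sq_dist = (((embedding - target) / sigma) ** 2).sum(dim=1, keepdim=True)
    return torch.exp(-0.5 * sq_dist)
```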

This probability map is just a tensor with values from 0 to 1. It is essentially a semantic map, so we can compare it against the binary foreground mask with a semantic segmentation loss, such as the Tversky loss, to generate a single loss value.

In [None]:
        _loss_embed = loss_embed(out, masks.gt(0).float())
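The Tversky loss compares a soft prediction with a binary target through true positives (TP), false positives (FP), and false negatives (FN): loss = 1 - TP / (TP + α·FP + β·FN). Below is a minimal sketch; ```tversky_loss_sketch``` is hypothetical, and ```skoots.train.loss.tversky``` may reduce over the batch differently or use different defaults. With α = β = 0.5 it reduces to the Dice loss; raising β penalizes missed foreground more heavily.

```python
import torch
from torch import Tensor


def tversky_loss_sketch(pred: Tensor, target: Tensor, alpha: float = 0.5,
                        beta: float = 0.5, eps: float = 1e-8) -> Tensor:
    """1 - Tversky index; alpha weights false positives, beta weights false negatives."""
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    return 1 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)
```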

The predicted skeleton and probability maps have targets generated directly by the dataloader (the skeleton masks and the labeled masks), so we compute their losses the same way:

In [None]:
        _loss_prob = loss_prob(probability_map, masks.gt(0).float())
        _loss_skeleton = loss_skele(
            predicted_skeleton, skele_masks.gt(0).float()
        )  # + skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float())

Finally, we let the user define the relative weight of each loss in the total loss, and the epoch after which each loss is first included. Both are defined in the configuration file and represented in code here:

In [None]:
        loss = (
            (
                cfg.TRAIN.LOSS_EMBED_RELATIVE_WEIGHT
                * (1 if current_epoch > cfg.TRAIN.LOSS_EMBED_START_EPOCH else 0)
                * _loss_embed
            )
            + (
                cfg.TRAIN.LOSS_PROBABILITY_RELATIVE_WEIGHT
                * (1 if current_epoch > cfg.TRAIN.LOSS_PROBABILITY_START_EPOCH else 0)
                * _loss_prob
            )
            + (
                cfg.TRAIN.LOSS_SKELETON_RELATIVE_WEIGHT
                * (1 if current_epoch > cfg.TRAIN.LOSS_SKELETON_START_EPOCH else 0)
                * _loss_skeleton
            )
        )
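The gated weighted sum can be factored into a small helper. This sketch (```gated_weighted_loss``` is hypothetical) mirrors the expression above, including its strict ```>``` comparison: a loss whose start epoch is 0 is not included until epoch 1.

```python
from typing import Dict


def gated_weighted_loss(losses: Dict[str, float], weights: Dict[str, float],
                        start_epochs: Dict[str, int], epoch: int) -> float:
    # A term contributes only once `epoch` is strictly greater than its start epoch.
    total = 0.0
    for name, value in losses.items():
        if epoch > start_epochs[name]:
            total += weights[name] * value
    return total
```

One small behavioral difference: multiplying a gated-off term by 0, as the original code does, still yields NaN if that term is NaN, whereas skipping the term does not.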

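Now we clear the gradients from the previous iteration, scale the loss with the ```GradScaler``` (which prevents gradient underflow under mixed precision and does nothing when ```cfg.TRAIN.MIXED_PRECISION``` is disabled), run backpropagation, and step the optimizer.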

In [None]:
        optimizer.zero_grad(set_to_none=True)  # clear gradients left over from the previous iteration
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
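Putting the update together, here is a minimal, self-contained sketch of one optimization step using the same ```autocast```/```GradScaler``` pattern, run on a toy linear regression on the CPU with mixed precision disabled (in which case ```GradScaler``` simply passes the loss and the optimizer step through). ```train_step``` is a hypothetical helper; the model, data, and loss are placeholders.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast


def train_step(model, optimizer, scaler, images, target, loss_fn, mixed_precision: bool) -> float:
    optimizer.zero_grad(set_to_none=True)    # clear gradients from the previous step
    with autocast(enabled=mixed_precision):  # forward pass in reduced precision, if enabled
        loss = loss_fn(model(images), target)
    scaler.scale(loss).backward()            # scale the loss up to avoid fp16 gradient underflow
    scaler.step(optimizer)                   # unscale gradients; skip the step on inf/NaN
    scaler.update()                          # adjust the scale factor for the next step
    return loss.item()


# Toy usage: fit y = 2x with a single linear layer.
torch.manual_seed(0)
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=False)
x = torch.linspace(-1, 1, 16).unsqueeze(1)
y = 2 * x
```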

## Warmup
We found that overfitting a randomly initialized model on a single batch for a few iterations helps the model learn the task on new data later on. The warmup performs all the steps above, but repeatedly on just one batch, for ```cfg.TRAIN.N_WARMUP``` iterations.

In [None]:
    # Warmup: grab a single batch from the training dataloader
    batch = next(iter(dataloader), None)
    assert batch is not None, f"Training dataloader is empty (len={len(dataloader)})"
    images, masks, skeleton, skele_masks, baked = batch

    warmup_range = trange(cfg.TRAIN.N_WARMUP, desc="Warmup: {}")
    for w in warmup_range:
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=cfg.TRAIN.MIXED_PRECISION):  # Saves Memory!
            out: Tensor = model(images)

            probability_map: Tensor = out[:, [-1], ...]
            vector: Tensor = out[:, 0:3:1, ...]
            predicted_skeleton: Tensor = out[:, [-2], ...]

            embedding: Tensor = vector_to_embedding(vector_scale, vector)
            out: Tensor = baked_embed_to_prob(embedding, baked, sigma(0))

            _loss_embed = loss_embed(
                out, masks.gt(0).float()
            )  # out = [B, 2/3, X, Y, Z?]
            _loss_prob = loss_prob(probability_map, masks.gt(0).float())
            _loss_skeleton = loss_skele(
                predicted_skeleton, skele_masks.gt(0).float()
            )  # + skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float())

            loss = (
                (cfg.TRAIN.LOSS_EMBED_RELATIVE_WEIGHT * _loss_embed)
                + (cfg.TRAIN.LOSS_PROBABILITY_RELATIVE_WEIGHT * _loss_prob)
                + (cfg.TRAIN.LOSS_SKELETON_RELATIVE_WEIGHT * _loss_skeleton)
            )

            warmup_range.desc = f"{loss.item()}"

            if torch.isnan(loss):
                print(
                    f"Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}"
                )
                print(f"\t{torch.any(torch.isnan(vector))}")
                print(f"\t{torch.any(torch.isnan(embedding))}")
                continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

## Main Training Loop
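We can now train our entire model. The main loop applies the training step above to every batch in the dataset, once per epoch, for ```cfg.TRAIN.NUM_EPOCHS``` epochs. In addition, it logs losses and example outputs to TensorBoard, updates the SWA model once the epoch exceeds ```swa_start```, runs validation every 10 epochs, and saves a checkpoint every 100 epochs.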

In [None]:
    # Train Step...
    epoch_range = (
        trange(cfg.TRAIN.NUM_EPOCHS, desc=f"Loss = {1.0000000}")
        if rank == 0
        else range(cfg.TRAIN.NUM_EPOCHS)
    )
    for e in epoch_range:
        _loss, _embed, _prob, _skele = [], [], [], []

        if cfg.TRAIN.DISTRIBUTED:
            train_sampler.set_epoch(e)

        for images, masks, skeleton, skele_masks, baked in dataloader:
            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=cfg.TRAIN.MIXED_PRECISION):  # Saves Memory!
                out: Tensor = model(images)

                probability_map: Tensor = out[:, [-1], ...]
                vector: Tensor = out[:, 0:3:1, ...]
                predicted_skeleton: Tensor = out[:, [-2], ...]

                embedding: Tensor = vector_to_embedding(vector_scale, vector)
                out: Tensor = baked_embed_to_prob(embedding, baked, sigma(e))

                _loss_embed = loss_embed(
                    out, masks.gt(0).float()
                )  # out = [B, 2/3, X, Y, Z?]
                _loss_prob = loss_prob(probability_map, masks.gt(0).float())
                _loss_skeleton = loss_skele(
                    predicted_skeleton, skele_masks.gt(0).float()
                )  # + skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float())

                # weighted sum of the three losses; each term is enabled only after its start epoch
                loss = (
                    (
                        cfg.TRAIN.LOSS_EMBED_RELATIVE_WEIGHT
                        * (1 if e > cfg.TRAIN.LOSS_EMBED_START_EPOCH else 0)
                        * _loss_embed
                    )
                    + (
                        cfg.TRAIN.LOSS_PROBABILITY_RELATIVE_WEIGHT
                        * (1 if e > cfg.TRAIN.LOSS_PROBABILITY_START_EPOCH else 0)
                        * _loss_prob
                    )
                    + (
                        cfg.TRAIN.LOSS_SKELETON_RELATIVE_WEIGHT
                        * (1 if e > cfg.TRAIN.LOSS_SKELETON_START_EPOCH else 0)
                        * _loss_skeleton
                    )
                )

                if torch.isnan(loss):
                    print(
                        f"Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}"
                    )
                    print(f"\t{torch.any(torch.isnan(vector))}")
                    print(f"\t{torch.any(torch.isnan(embedding))}")
                    continue

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if e > swa_start:
                swa_model.update_parameters(model)

            _loss.append(loss.item())
            _embed.append(_loss_embed.item())
            _prob.append(_loss_prob.item())
            _skele.append(_loss_skeleton.item())

        avg_epoch_loss.append(mean(_loss))
        avg_epoch_embed_loss.append(mean(_embed))
        avg_epoch_prob_loss.append(mean(_prob))
        avg_epoch_skele_loss.append(mean(_skele))
        scheduler.step()

        if writer and (rank == 0):
            writer.add_scalar("lr", scheduler.get_last_lr()[-1], e)
            writer.add_scalar("Loss/train", avg_epoch_loss[-1], e)
            writer.add_scalar("Loss/embed", avg_epoch_embed_loss[-1], e)
            writer.add_scalar("Loss/prob", avg_epoch_prob_loss[-1], e)
            writer.add_scalar("Loss/skele-mask", avg_epoch_skele_loss[-1], e)
            write_progress(
                writer=writer,
                tag="Train",
                epoch=e,
                images=images,
                masks=masks,
                probability_map=probability_map,
                vector=vector,
                out=out,
                skeleton=skeleton,
                predicted_skeleton=predicted_skeleton,
                gt_skeleton=skele_masks,
            )

        # Validation step (every 10 epochs, if validation data exists)
        if e % 10 == 0 and valdiation_dataloader:
            _loss, _embed, _prob, _skele = [], [], [], []
            for images, masks, skeleton, skele_masks, baked in valdiation_dataloader:
                with autocast(enabled=cfg.TRAIN.MIXED_PRECISION):  # Saves Memory!
                    with torch.no_grad():
                        out: Tensor = model(images)

                        probability_map: Tensor = out[:, [-1], ...]
                        predicted_skeleton: Tensor = out[:, [-2], ...]
                        vector: Tensor = out[:, 0:3:1, ...]

                        embedding: Tensor = vector_to_embedding(vector_scale, vector)
                        out: Tensor = baked_embed_to_prob(embedding, baked, sigma(e))

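                        _loss_embed = loss_embed(out, masks.gt(0).float())
                        _loss_prob = loss_prob(probability_map, masks.gt(0).float())
                        _loss_skeleton = loss_skele(
                            predicted_skeleton, skele_masks.gt(0).float()
                        )

                        # same relative weights as training (without the start-epoch gates)
                        loss = (
                            (cfg.TRAIN.LOSS_EMBED_RELATIVE_WEIGHT * _loss_embed)
                            + (cfg.TRAIN.LOSS_PROBABILITY_RELATIVE_WEIGHT * _loss_prob)
                            + (cfg.TRAIN.LOSS_SKELETON_RELATIVE_WEIGHT * _loss_skeleton)
                        )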

                        if torch.isnan(loss):
                            print(
                                f"Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}"
                            )
                            print(f"\t{torch.any(torch.isnan(vector))}")
                            print(f"\t{torch.any(torch.isnan(embedding))}")
                            continue

                _loss.append(loss.item())
                _embed.append(_loss_embed.item())
                _prob.append(_loss_prob.item())
                _skele.append(_loss_skeleton.item())

            avg_val_loss.append(mean(_loss))
            avg_val_embed_loss.append(mean(_embed))
            avg_val_prob_loss.append(mean(_prob))
            avg_val_skele_loss.append(mean(_skele))

            if writer and (rank == 0):
                writer.add_scalar("Validation/train", avg_val_loss[-1], e)
                writer.add_scalar("Validation/embed", avg_val_embed_loss[-1], e)
                writer.add_scalar("Validation/prob", avg_val_prob_loss[-1], e)
                writer.add_scalar("Validation/skele-mask", avg_val_skele_loss[-1], e)
                write_progress(
                    writer=writer,
                    tag="Validation",
                    epoch=e,
                    images=images,
                    masks=masks,
                    probability_map=probability_map,
                    vector=vector,
                    out=out,
                    skeleton=skeleton,
                    predicted_skeleton=predicted_skeleton,
                    gt_skeleton=skele_masks,
                )

        if rank == 0:
            epoch_range.desc = (
                f"lr={scheduler.get_last_lr()[-1]:.3e}, Loss (train | val): "
                + f"{avg_epoch_loss[-1]:.5f} | {avg_val_loss[-1]:.5f}"
            )

        state_dict = (
            model.module.state_dict()
            if hasattr(model, "module")
            else model.state_dict()
        )
        if e % 100 == 0:
            torch.save(state_dict, cfg.TRAIN.SAVE_PATH + f"/test_{e}.trch")

## Save the model
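Finally, each model trained by this script is saved as a dictionary containing the configuration ```cfg```, the ```model_state_dict```, the ```optimizer_state_dict```, and the training and validation loss histories. It is saved under the same name as the SummaryWriter's TensorBoard log directory, linking the two. If saving to ```cfg.TRAIN.SAVE_PATH``` fails, the file is saved to the current working directory instead.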

In [None]:
    if rank == 0:
        state_dict = (
            model.module.state_dict()
            if hasattr(model, "module")
            else model.state_dict()
        )
        constants = {
            "cfg": cfg,
            "model_state_dict": state_dict,
            "optimizer_state_dict": optimizer.state_dict(),
            "avg_epoch_loss": avg_epoch_loss,
            "avg_epoch_embed_loss": avg_epoch_embed_loss,
            "avg_epoch_prob_loss": avg_epoch_prob_loss,
            "avg_epoch_skele_loss": avg_epoch_skele_loss,
            "avg_val_loss": avg_val_loss,
            "avg_val_embed_loss": avg_val_embed_loss,
            "avg_val_prob_loss": avg_val_prob_loss,
            "avg_val_skele_loss": avg_val_skele_loss,
        }
        try:
            torch.save(
                constants,
                f"{cfg.TRAIN.SAVE_PATH}/{os.path.split(writer.log_dir)[-1]}.trch",
            )
        except Exception:
            print(
                f"Could not save at: {cfg.TRAIN.SAVE_PATH}/{os.path.split(writer.log_dir)[-1]}.trch. "
                f"Saving at {os.getcwd()}/{os.path.split(writer.log_dir)[-1]}.trch instead"
            )

            torch.save(
                constants,
                f"{os.getcwd()}/{os.path.split(writer.log_dir)[-1]}.trch",
            )
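A saved file can later be reloaded with ```torch.load```. The sketch below, with hypothetical ```save_checkpoint``` and ```load_checkpoint``` helpers, uses the same keys as above (loss histories omitted) and a plain dictionary in place of the yacs ```CfgNode```. Because the training script saves the unwrapped model's ```state_dict``` (```model.module```), it loads directly into a plain, non-DDP model. Note that from PyTorch 2.6, ```torch.load``` defaults to ```weights_only=True```; loading a file that contains a ```CfgNode``` may then require ```weights_only=False```, which should only be used for files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer, cfg: dict) -> None:
    # Same layout as the dictionary above, minus the loss histories.
    torch.save(
        {
            "cfg": cfg,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer) -> dict:
    # Restore weights and optimizer state in place; return the stored configuration.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["cfg"]
```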