# Training Notebook

## Import of all needed scripts

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys

sys.path.insert(0, "../../src")

import gc
import time

import numpy as np
import torch
import torch.distributed as dist

from juart.dl.checkpoint.manager import CheckpointManager
from juart.dl.data.training import DatasetTraining
from juart.dl.loss.loss import JointLoss
from juart.dl.model.unrollnet import (
    LookaheadModel,
    UnrolledNet,
)
from juart.dl.operation.modules import training, validation
from juart.dl.utils.dist import GradientAccumulator

## Defining shuffle function
When training a model, the slices and subjects of the datasets should not be presented in order. Therefore, this function creates a random order for every single epoch, guaranteeing that every slice is used once and only once in every epoch. A seeded random generator can be passed to the function so that the random order is the same in every training run. This improves the comparability of the models, because the accuracy of a model can vary slightly with the order of the slices and subjects during training.

In [2]:
def shuffled_indices(num_samples, num_epochs, rng):
    """Return sample indices for all epochs, shuffled independently per epoch.

    The result is the concatenation of ``num_epochs`` permutations of
    ``0 .. num_samples - 1``, so every sample occurs exactly once per epoch.

    Parameters
    ----------
    num_samples : int
        Number of distinct samples; indices run from 0 to ``num_samples - 1``.
    num_epochs : int
        Number of epochs, i.e. the number of permutations that are concatenated.
    rng : numpy.random.Generator
        Generator used for shuffling. Passing a freshly seeded generator makes
        the returned order reproducible.

    Returns
    -------
    numpy.ndarray
        1-D integer array of length ``num_samples * num_epochs``.

    Raises
    ------
    ValueError
        If ``num_samples`` or ``num_epochs`` is negative.
    """
    if num_samples < 0 or num_epochs < 0:
        raise ValueError("num_samples and num_epochs must be non-negative")

    # Nothing to shuffle. Returning early also avoids splitting an array into
    # zero sections further down, which would fail for num_epochs == 0.
    if num_samples == 0 or num_epochs == 0:
        return np.arange(0)

    # One column per epoch; every column initially holds 0 .. num_samples - 1.
    indices = np.repeat(np.arange(num_samples), num_epochs)
    indices = indices.reshape((num_samples, num_epochs))
    # Shuffle every column (epoch) independently of the others.
    indices = rng.permuted(indices, axis=0)
    # Transpose so that the entries of one epoch are contiguous after ravel().
    indices = indices.T.ravel()

    # Check if each sample is used once and only once in every epoch
    assert indices.size == num_samples * num_epochs
    for epoch_indices in indices.reshape((num_epochs, num_samples)):
        assert np.unique(epoch_indices).size == num_samples

    return indices

## Define important variables

In [3]:
# --- Dataset options -------------------------------------------------------
nX = 256  # Number of pixels in x-direction
nY = 256  # Number of pixels in y-direction
nZ = 1  # Number of pixels in z-direction
nTI = 2  # Number of measurements during the T1 decay
nTE = 2  # Number of measurements during the T2 decay
shape = (nX, nY, nZ, nTI, nTE)  # Shape that is later passed to the model
num_spokes = 64  # Number of spokes used for training
nD = 1  # Number of subjects
nS = 160  # Number of slices per subject

# --- Device options --------------------------------------------------------
device = "cpu"  # Train on the CPU or on a GPU
group = None
group_rank = 0
group_index = 0
num_groups = 1

# --- CheckpointManager options ---------------------------------------------
load_model_state = True  # Load the last saved model state
load_averaged_model_state = True  # Load the last saved averaged model state
load_optim_state = True  # Load the last saved optimizer state
load_metrics = True  # Load the last saved metrics (iterations and loss)
directory = "model_test"  # Name of the save directory of the model
root_dir = "/home/jovyan/models"  # Path of the model directory
backend = "local"  # Backend of the model directory

# --- Training loop options -------------------------------------------------
num_epochs = 1  # Number of training epochs
model_training = True  # Activate training mode
model_validation = False  # Activate validation mode
save_checkpoint = True  # Write save files holding the current training state
checkpoint_frequency = 10  # Number of iterations between two save files
single_epoch = True  # Create a separate save file after every single epoch
batch_size = 1  # Number of slices used for training per batch

# --- Derived values --------------------------------------------------------
batch_size_local = batch_size // num_groups  # Share of a batch per group
num_iterations = nD * nS * num_epochs  # Complete number of iterations

## Randomize indices
The shuffle function gets used to shuffle the indices of the training and validation dataset.

In [4]:
# A single generator with a fixed seed drives both shuffles, so the training
# order (drawn first) and the validation order (drawn second) are reproducible.
rng = np.random.default_rng(0)

# Flat per-epoch permutations, reshaped to (batch, sample within batch, group).
training_indices = shuffled_indices(nD * nS, num_epochs, rng)
training_indices_batched = np.reshape(
    training_indices, (-1, batch_size_local, num_groups)
)

validation_indices = shuffled_indices(nD * nS, num_epochs, rng)
validation_indices_batched = np.reshape(
    validation_indices, (-1, batch_size_local, num_groups)
)

In [5]:
# Create a single-process "gloo" group (world_size=1, rank=0) on localhost.
# The guard keeps this cell re-runnable: initializing the default process
# group a second time in the same kernel raises an error.
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:23456", world_size=1, rank=0
    )

[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0


## Initializing the model
In this cell, the model is initialized and set up for the specific dataset that is used.

In [6]:
# Build the unrolled network for the data shape and device chosen in the
# configuration cell.
model = UnrolledNet(
    shape,
    device=device,
    # Network settings (exact meaning is defined by UnrolledNet)
    features=64,
    num_unroll_blocks=10,
    CG_Iter=10,
    activation="ReLU",
    # Diagnostics are switched off / set to their lowest level
    timing_level=0,
    validation_level=0,
    disable_progress_bar=True,
)

## Defining the loss function
The loss function takes the prediction of the model and the correct result to compute the resulting loss. The loss will be separated in different parts (kspace, ispace, wavelet,...) and then be weighted with the matching weight.

In [7]:
# Joint loss over several terms. Only the k-space term carries a non-zero
# weight here; every other term is weighted with 0.0.
loss_fn = JointLoss(
    shape,
    (3, 3),
    group=group,
    device=device,
    normalized_loss=True,
    # Per-term weights
    weights_kspace_loss=(0.5, 0.5),
    weights_ispace_loss=(0.0, 0.0),
    weights_wavelet_loss=(0.0, 0.0),
    weights_hankel_loss=(0.0, 0.0),
    weights_casorati_loss=(0.0, 0.0),
    # Diagnostics
    timing_level=0,
    validation_level=0,
)

## Setting up the optimizer
The optimizer is used for adapting the model parameters so that the prediction of the model will get closer to the correct result.

In [8]:
# Adam over all model parameters with a learning rate of 1e-4 and no weight
# decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=[0.9, 0.999],
    eps=1e-8,
    weight_decay=0.0,
)

In [9]:
# Gradient accumulation over the local share of a batch
# (batch_size_local = batch_size // num_groups).
accumulator = GradientAccumulator(
    model,
    max_norm=1.0,
    normalized_gradient=False,
    accumulation_steps=batch_size_local,
)

## The average model
The averaged model uses the floating average of the original model. This approach is more robust than using the original model directly for the reconstruction.

In [10]:
# Wrap the model in its averaged counterpart; the training loop updates it
# after every training step.
averaged_model = LookaheadModel(model, k=5, alpha=0.5)

## Initializing the CheckpointManager
The CheckpointManager allows saving the model to a specific location or loading a model from a specific location. Here, the location in which the CheckpointManager operates is defined.

In [11]:
# Checkpoints are read from and written to `directory` below `root_dir`,
# using the storage backend selected in the configuration cell.
checkpoint_manager = CheckpointManager(
    backend=backend, root_dir=root_dir, directory=directory
)

## Loading save files with the CheckpointManager if available
### Load saved model
The following cell checks whether there is already a saved model state. If so, it loads that state, which ensures that the training process continues from the point where it was last saved.

In [12]:
if load_model_state:
    print("Loading model state ...")
    checkpoint = checkpoint_manager.load(["model_state"], map_location=device)
    # Restore only if every requested entry came back with a truthy value;
    # otherwise keep the freshly initialized model.
    if not all(checkpoint.values()):
        print("Could not load model state.")
    else:
        model.load_state_dict(checkpoint["model_state"])

Loading model state ...
Could not load model state.


### Load saved averaged model
The next cell will do the same thing but just for the averaged model and not for the original model.

In [13]:
if load_averaged_model_state:
    print("Loading averaged model state ...")
    checkpoint = checkpoint_manager.load(["averaged_model_state"], map_location=device)
    # Restore only if every requested entry came back with a truthy value;
    # otherwise keep the freshly created averaged model.
    if not all(checkpoint.values()):
        print("Could not load averaged model state.")
    else:
        averaged_model.load_state_dict(checkpoint["averaged_model_state"])

Loading averaged model state ...
Could not load averaged model state.


### Load saved optimizer
This cell loads the last saved state of the optimizer of the saved model, so that the optimizer parameters which were obtained through previous iterations still have their impact.

In [14]:
if load_optim_state:
    print("Loading optim state ...")
    checkpoint = checkpoint_manager.load(["optim_state"], map_location=device)
    # Restore only if every requested entry came back with a truthy value;
    # otherwise keep the freshly created optimizer.
    if not all(checkpoint.values()):
        print("Could not load optim state.")
    else:
        optimizer.load_state_dict(checkpoint["optim_state"])

Loading optim state ...
Could not load optim state.


### Load saved metrics
In the following cells the last saved loss and the last saved iteration of the saved model will be loaded from the save files.

In [15]:
# Defaults for a fresh run: empty loss histories, starting at iteration 0.
total_trn_loss = []
total_val_loss = []
iteration = 0

In [16]:
if load_metrics:
    print("Loading metrics ...")
    checkpoint = checkpoint_manager.load(["trn_loss", "val_loss", "iteration"])
    # All three entries must be truthy before any of them is taken over.
    # NOTE(review): a stored iteration of 0 or an empty loss list is falsy and
    # is therefore treated like a missing entry.
    if not all(checkpoint.values()):
        print("Could not load metrics.")
    else:
        total_trn_loss = list(checkpoint["trn_loss"])
        total_val_loss = list(checkpoint["val_loss"])
        iteration = checkpoint["iteration"]

Loading metrics ...
Could not load metrics.


This cell prints the last saved iteration of the model. The training will continue at this iteration.

In [17]:
# Report the iteration the training loop starts from: 0 for a fresh run, or
# the value taken from the loaded metrics.
print(f"Continue with iteration {iteration} ...")

Continue with iteration 0 ...


## Loading Dataset for Training and Validation
The dataset can be varied in the number of slices, number of spokes, split_fractions and mode (training/validation).

In [18]:
# Training dataset, read from the S3 backend.
# The slice list is derived from nS (slices per subject) instead of repeating
# the literal 160, so the dataset stays in sync with the shuffled indices,
# which are also generated from nS.
training_data = DatasetTraining(
    "qrage/sessions/%s/preproc.zarr/preproc.zarr",
    ["7T1566"],
    np.arange(nS),
    num_spokes,
    [0.0, 0.5, 0.5],
    mode="training",
    group_rank=group_rank,
    endpoint_url="https://s3.fz-juelich.de",
    backend="s3",
)

In [19]:
# Validation dataset: same source and settings as the training dataset, but
# opened in "validation" mode.
# The slice list is derived from nS (slices per subject) instead of repeating
# the literal 160, so the dataset stays in sync with the shuffled indices,
# which are also generated from nS.
validation_data = DatasetTraining(
    "qrage/sessions/%s/preproc.zarr/preproc.zarr",
    ["7T1566"],
    np.arange(nS),
    num_spokes,
    [0.0, 0.5, 0.5],
    mode="validation",
    group_rank=group_rank,
    endpoint_url="https://s3.fz-juelich.de",
    backend="s3",
)

## Main training loop

In [20]:
# Main loop. Each pass handles one batch: train, optionally validate, append
# the losses to the running histories, write checkpoints when due, and log
# timing and losses. `iteration` advances by `batch_size` per pass.
while iteration < num_iterations:
    tic = time.time()

    # Reset the seed so that training can be resumed
    np.random.seed(iteration)
    torch.manual_seed(iteration)

    # Row `iteration // batch_size` of the pre-shuffled index arrays holds the
    # current batch; the last axis selects the column of this group.
    training_index = training_indices_batched[
        iteration // batch_size,
        :,
        group_index,
    ].tolist()
    validation_index = validation_indices_batched[
        iteration // batch_size, :, group_index
    ].tolist()

    # TRAINING
    if model_training:
        print(f"Training index {training_index} ...")

        trn_loss = training(
            training_index,
            training_data,
            model,
            loss_fn,
            optimizer,
            accumulator,
            group=group,
            device=device,
        )

        # Let the averaged model follow the freshly updated weights.
        averaged_model.update_parameters(
            model,
        )

        torch.cuda.empty_cache()
        gc.collect()

    else:
        # Placeholder so the loss history still grows by one entry per sample.
        trn_loss = [0] * batch_size

    # VALIDATION
    if model_validation:
        print(f"Validation index {validation_index} ...")

        # Validation runs on the averaged model, not on the raw model.
        val_loss = validation(
            validation_index,
            validation_data,
            averaged_model,
            loss_fn,
            group=group,
            device=device,
        )
        torch.cuda.empty_cache()
        gc.collect()

    else:
        # Placeholder so the loss history still grows by one entry per sample.
        val_loss = [0] * batch_size

    total_trn_loss += trn_loss
    total_val_loss += val_loss

    # SAVING
    # `iteration` is only advanced at the end of the pass, so checkpoints store
    # the value the next pass starts from.
    next_iteration = iteration + batch_size
    epoch_completed = np.mod(next_iteration, nD * nS) == 0
    checkpoint_due = np.mod(next_iteration, checkpoint_frequency) == 0

    if save_checkpoint and (epoch_completed or checkpoint_due):
        # A completed epoch takes precedence over an intermediate checkpoint.
        if epoch_completed:
            print("Creating tagged checkpoint ...")
        else:
            print("Creating untagged checkpoint ...")

        checkpoint = {
            "iteration": next_iteration,
            "model_state": model.state_dict(),
            "averaged_model_state": averaged_model.state_dict(),
            "optim_state": optimizer.state_dict(),
            "trn_loss": total_trn_loss,
            "val_loss": total_val_loss,
        }

        if epoch_completed:
            # Completed epoch
            epoch = next_iteration // (nD * nS)
            checkpoint_manager.save(checkpoint, tag=f"_epoch_{epoch}")

            if single_epoch:
                # Also save the checkpoint as untagged checkpoint
                # Otherwise, training will be stuck in endless loop
                checkpoint_manager.save(checkpoint)
                checkpoint_manager.release()
                # Advance the counter before leaving the loop so that the
                # in-kernel state matches the checkpoint just written.
                # Without this, re-running the cell would repeat the last
                # batch and save the same epoch again.
                iteration = next_iteration
                break
        else:
            # Intermediate checkpoint
            checkpoint_manager.save(checkpoint, block=False)

    toc = time.time() - tic

    print(
        (
            f"Iteration: {iteration} - "
            + f"Elapsed time: {toc:.0f} - "
            + f"Training loss: {[f'{loss:.3f}' for loss in trn_loss]} - "
            + f"Validation loss: {[f'{loss:.3f}' for loss in val_loss]}"
        )
    )

    torch.cuda.empty_cache()
    gc.collect()

    iteration += batch_size

Training index [10] ...
Data - Started loading Dataset 7T1566 - Slice 10 ...
Data - Completed loading dataset in 3.7 seconds.
Rank 0 - Data - Started creating masks [0.0, 0.5, 0.5] ...
Rank 0 - Data - Source fractions 0.5.
Rank 0 - Data - Target fractions 0.5.
Rank 0 - Data - Completed creating mask torch.Size([3, 1, 16384, 2, 2]) in 0.1 seconds.
Rank 0 - Data - Started regridding dataset ...
Rank 0 - Data - Completed regridding dataset torch.Size([256, 256, 1, 2, 2]) in 0.3 seconds.
Layer 0 done
Layer 1 done
Layer 2 done
Layer 3 done
Layer 4 done
Layer 5 done
Layer 6 done
Layer 7 done
Layer 8 done
Layer 9 done
Layer 10 done
Layer 11 done
Layer 12 done
Layer 13 done
Layer 14 done
Layer 0 done
Layer 1 done
Layer 2 done
Layer 3 done
Layer 4 done
Layer 5 done
Layer 6 done
Layer 7 done
Layer 8 done
Layer 9 done
Layer 10 done
Layer 11 done
Layer 12 done
Layer 13 done
Layer 14 done
Layer 0 done
Layer 1 done
Layer 2 done
Layer 3 done
Layer 4 done
Layer 5 done
Layer 6 done
Layer 7 done
Layer 8

Process SpawnProcess-1:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.13/multiprocessing/process.py", line 313, in _bootstrap
    self.run()
    ~~~~~~~~^^
  File "/opt/conda/lib/python3.13/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/juart/examples/dl/../../src/juart/dl/checkpoint/manager.py", line 89, in save_checkpoint_process
    self.save_buffer_to_filesystem(*self.save_queue.get())
                                    ~~~~~~~~~~~~~~~~~~~^^
  File "/opt/conda/lib/python3.13/multiprocessing/queues.py", line 101, in get
    res = self._recv_bytes()
  File "/opt/conda/lib/python3.13/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/opt/conda/lib/python3.13/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
  File "/opt/conda/lib/python3.13/multiprocessing/connection.py", 

KeyboardInterrupt: 