# Dual Domain Training for 3D Datasets
This notebook is an upgraded version of the existing Training_3D.ipynb. The images reconstructed by models trained with the normal training appear to be very noisy. Dual-domain training can hopefully remove the noise better. The structure of the network stays the same, but Ray and the torch distributed functions are used to run the training twice on two different GPUs. Furthermore, the ispace loss is used to weight the difference between the kspace data of the two models and to average their gradients before doing the optimizer step. In this way both trainings run separately, but the optimizer steps they take are exactly the same, because the gradients used during the optimizer step are the average of the gradients calculated by both. Setting weight_ispace_loss to 0 results in the normal training again (except for the fact that two models are trained but only one is saved at the end).

In [None]:
# Import of all necessary functions and classes
import gc
import os
import time

import h5py
import numpy as np
import ray
import torch
import torch.distributed as dist
from ray.air.config import RunConfig
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Restrict CUDA in this process to the physical GPUs with index 2 and 3.
# NOTE(review): presumably this has to happen before CUDA is first
# initialised; whether the Ray workers inherit it is not visible here.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

import sys

# Put ../../src at the front of the module search path so that the juart
# imports below are resolved from there first.
sys.path.insert(0, "../../src")
import logging

from juart.conopt.functional.fourier import (
    nonuniform_fourier_transform_adjoint,
)
from juart.dl.checkpoint.manager import CheckpointManager
from juart.dl.loss.loss import JointLoss
from juart.dl.model.unrollnet import LookaheadModel, UnrolledNet
from juart.dl.operation.modules import training
from juart.dl.utils.dist import GradientAccumulator

# Configure the root logger at INFO level (original note: activates the
# terminal output for print commands in ray).
logging.basicConfig(level=logging.INFO)


# Training function that is later passed to the TorchTrainer
def train_func():
    """Run the dual-domain training loop on one distributed worker.

    Each worker builds its own ``UnrolledNet``, loss, optimizer and gradient
    accumulator, optionally restores the latest checkpoint, and then iterates
    until ``num_iterations`` is reached. In every iteration a random binary
    k-space mask is drawn (seeded by ``iteration % nP``); rank 0 reconstructs
    from the mask, rank 1 from its complement. Rank 0 writes checkpoints.

    The function takes no arguments; all settings are local constants. It
    requires an initialised ``torch.distributed`` process group (provided by
    the ``TorchTrainer`` that calls it).

    Returns:
        list: the accumulated training losses, one entry per iteration as
        returned by ``training``; when metrics were restored from a
        checkpoint the restored entries come first.

    Raises:
        RuntimeError: if the worker's global rank is neither 0 nor 1, since
            input data is only defined for those two ranks.
    """
    # define variables
    shape = (156, 156, 156, 2, 1)
    nX, nY, nZ, nTI, nTE = shape
    weight_kspace_loss = [0.5, 0.5]  # weight the difference in k space
    weight_ispace_loss = [
        0.1,
        0.1,
    ]  # weight the difference of the two images (dual domain) and average their gradients
    weight_hankel_loss = [0.0, 0.0]
    weight_casorati_loss = [0.0, 0.0]
    weight_wavelet_loss = [0.0, 0.0]  # weight the loss in wavelet domain
    normalized_loss = True

    batch_size = 1  # number of datapoints used per batch iteration
    nD = 1  # number of datasets
    nP = 10  # number of permutations per epoch
    cgiter = 10  # number of dc iterations
    num_epochs = 10  # number of epochs

    global_rank = int(dist.get_rank())
    world_size = int(dist.get_world_size())
    group_size = 2
    model_dir = f"DD_01i_{cgiter}DC_{nP}P"
    root_dir = "/home/jovyan/models"
    endpoint_url = "https://s3.fz-juelich.de"
    model_backend = "local"

    # if true (and checkpoints are saved) the loop stops after the first
    # completed epoch
    single_epoch = False
    save_checkpoint = True  # enables checkpoint saving
    checkpoint_frequency = 5  # number of iterations between the save files
    load_model_state = (
        True  # if true the latest model state will be loaded if available
    )
    load_averaged_model_state = True  # latest averaged model state will be loaded
    load_optim_state = True  # latest optimizer state will be loaded
    load_metrics = True  # the latest metrics (loss, iterations) will be loaded

    num_groups = 1
    batch_size_local = batch_size // num_groups
    num_iterations = nD * nP * num_epochs

    ################################################################
    # Setting the rank for each worker
    device = f"cuda:{global_rank}"
    group = None
    for first_rank in range(0, world_size, group_size):
        ranks = list(range(first_rank, first_rank + group_size, 1))
        # torch.distributed requires every process to enter new_group for
        # every group, including the groups it is not a member of.
        candidate_group = dist.new_group(ranks, backend="gloo")
        if global_rank in ranks:
            print(f"Rank {global_rank} is in group {ranks} ...")
            group = candidate_group

    ################################################################
    # reading and shaping data
    data_path = "/home/jovyan/juart/examples/data/3DLiss_vd_preproc.h5"
    with h5py.File(data_path, "r") as f:
        # A trailing singleton axis is appended to trajectory and signal.
        k = torch.from_numpy(f["k"][:])[..., None]
        C = torch.from_numpy(f["coilsens"][:])
        d = torch.from_numpy(f["d"][:])[..., None]

        print(f"Coilsensitivity shape {C.shape}")
        print(f"Trajectory shape {k.shape}")
        print(f"Signal shape {d.shape}")

    # Rescale the trajectory so that its largest value becomes 0.5.
    k /= 2 * k.max()

    ################################################################
    # Defining the neural network

    model = UnrolledNet(
        shape,
        CG_Iter=cgiter,
        num_unroll_blocks=10,
        num_of_resblocks=15,
        features=[16, 32, 64],
        kernel_size=(3, 3, 3),
        pad_to=160,
        activation="ReLU",
        regularizer="UNet",
        Checkpoints=True,
    ).to(device)

    loss_fn = JointLoss(
        (3, 3),
        weights_kspace_loss=weight_kspace_loss,
        weights_ispace_loss=weight_ispace_loss,
        weights_hankel_loss=weight_hankel_loss,
        weights_casorati_loss=weight_casorati_loss,
        weights_wavelet_loss=weight_wavelet_loss,
        normalized_loss=normalized_loss,
        group=group,
        device=device,
    )

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.0001,
        betas=[0.9, 0.999],
        eps=1.0e-8,
        weight_decay=0.0,
    )

    accumulator = GradientAccumulator(
        model,
        accumulation_steps=batch_size_local,
        max_norm=1.0,
        normalized_gradient=False,
    )

    # NOTE(review): averaged_model is only loaded and saved in this function;
    # no explicit update call is visible here - confirm whether LookaheadModel
    # tracks the model on its own.
    averaged_model = LookaheadModel(
        model,
        alpha=0.5,
        k=5,
    )

    dist.barrier()

    checkpoint_manager = CheckpointManager(
        model_dir,
        root_dir=root_dir,
        endpoint_url=endpoint_url,
        backend=model_backend,
    )

    dist.barrier()

    ################################################################
    # LOADING CURRENT MODEL STATE
    if load_model_state:
        print(f"Rank {global_rank} - Loading model state ...")
        checkpoint = checkpoint_manager.load(["model_state"], map_location=device)
        if all(checkpoint.values()):
            model.load_state_dict(checkpoint["model_state"])
        else:
            print(f"Rank {global_rank} - Could not load model state.")

    if load_averaged_model_state:
        print(f"Rank {global_rank} - Loading averaged model state ...")
        checkpoint = checkpoint_manager.load(
            ["averaged_model_state"], map_location=device
        )
        if all(checkpoint.values()):
            averaged_model.load_state_dict(checkpoint["averaged_model_state"])
        else:
            print(f"Rank {global_rank} - Could not load averaged model state.")

    if load_optim_state:
        print(f"Rank {global_rank} - Loading optim state ...")
        checkpoint = checkpoint_manager.load(["optim_state"], map_location=device)
        if all(checkpoint.values()):
            optimizer.load_state_dict(checkpoint["optim_state"])
        else:
            print(f"Rank {global_rank} - Could not load optim state.")

    # Defaults for a fresh run. They are set unconditionally so that they
    # exist even when the optimizer state is not loaded or no metrics are
    # available.
    total_trn_loss = list()
    total_val_loss = list()
    iteration = 0

    if load_metrics:
        print(f"Rank {global_rank} - Loading metrics ...")
        checkpoint = checkpoint_manager.load(["trn_loss", "val_loss", "iteration"])
        # NOTE(review): all() is a truthiness test, so an empty "val_loss"
        # list (total_val_loss is never appended to in this function) or an
        # iteration of 0 is treated as "not loaded". Confirm what
        # CheckpointManager.load returns for missing entries before
        # tightening this check.
        if all(checkpoint.values()):
            total_trn_loss = list(checkpoint["trn_loss"])
            total_val_loss = list(checkpoint["val_loss"])
            iteration = checkpoint["iteration"]
        else:
            print(f"Rank {global_rank} - Could not load metrics.")

    print(f"Rank {global_rank} - Continue with iteration {iteration} ...")

    dist.barrier()

    def build_batch(own_mask, other_mask):
        # Regrid the signal masked with this worker's own mask and combine
        # the coils with the conjugated sensitivities (sum over axis 0).
        d_masked = d * own_mask
        AHd = nonuniform_fourier_transform_adjoint(k, d_masked, (nX, nY, nZ))
        AHd = torch.sum(torch.conj(C[..., None, None]) * AHd, dim=0)
        return [
            {
                "images_regridded": AHd,
                "kspace_trajectory": k,
                "sensitivity_maps": C,
                "kspace_mask_source": other_mask,
                "kspace_mask_target": own_mask,
                "kspace_data": d,
            }
        ]

    def build_checkpoint(next_iteration):
        # Snapshot of everything needed to resume the training later.
        return {
            "iteration": next_iteration,
            "model_state": model.state_dict(),
            "averaged_model_state": averaged_model.state_dict(),
            "optim_state": optimizer.state_dict(),
            "trn_loss": total_trn_loss,
            "val_loss": total_val_loss,
        }

    ################################################################
    # ACTUAL TRAINING LOOP
    # The loop continues from the restored iteration and metrics; they are
    # deliberately not reset here, otherwise resuming would be impossible.
    generator = torch.Generator()

    while iteration < num_iterations:
        tic = time.time()
        # Same seed on every worker, so both draw the identical mask; the
        # masks repeat every nP iterations.
        generator.manual_seed(iteration % nP)

        kspace_mask_worker0 = torch.randint(
            0, 2, (1, d.shape[1], 2, 1), generator=generator
        )
        kspace_mask_worker1 = 1 - kspace_mask_worker0

        if global_rank == 0:
            # Defining data for worker 0
            data = build_batch(kspace_mask_worker0, kspace_mask_worker1)
        elif global_rank == 1:
            # Defining data for worker 1
            data = build_batch(kspace_mask_worker1, kspace_mask_worker0)
        else:
            raise RuntimeError(
                f"Rank {global_rank} - dual domain training defines data "
                "for ranks 0 and 1 only."
            )

        trn_loss = training(
            [0],
            data,
            model,
            loss_fn,
            optimizer,
            accumulator,
            group=group,
            device=device,
        )

        val_loss = [0] * batch_size
        total_trn_loss.append(trn_loss)

        next_iteration = iteration + batch_size
        epoch_completed = np.mod(next_iteration, nD * nP) == 0

        ################################################################
        # SAVING DATA
        if global_rank == 0:
            # Completed epoch
            if save_checkpoint and epoch_completed:
                print("Creating tagged checkpoint ...")

                checkpoint = build_checkpoint(next_iteration)

                epoch = next_iteration // (nD * nP)
                checkpoint_manager.save(checkpoint, tag=f"_epoch_{epoch}")

                if single_epoch:
                    # Also save the checkpoint as untagged checkpoint
                    # Otherwise, training will be stuck in endless loop
                    checkpoint_manager.save(checkpoint)
                    checkpoint_manager.release()

            # Intermediate checkpoint
            elif (
                save_checkpoint
                and np.mod(next_iteration, checkpoint_frequency) == 0
            ):
                print("Creating untagged checkpoint ...")

                checkpoint = build_checkpoint(next_iteration)

                checkpoint_manager.save(checkpoint, block=False)

            toc = time.time() - tic

            print(
                (
                    f"Iteration: {iteration} - "
                    + f"Elapsed time: {toc:.0f} - "
                    + f"Training loss: {[f'{loss:.3f}' for loss in trn_loss]} - "
                    + f"Validation loss: {[f'{loss:.3f}' for loss in val_loss]}"
                )
            )

        # The stop condition depends only on values that are identical on
        # every rank, so all workers leave the loop together. Breaking on
        # rank 0 alone would leave the other worker waiting for its partner
        # (presumably inside the next group-wide operation of training()).
        if single_epoch and save_checkpoint and epoch_completed:
            break

        torch.cuda.empty_cache()
        gc.collect()

        iteration = next_iteration

    # Return the accumulated training losses
    return total_trn_loss


################################################################
# main function that initializes needed classes and runs the train function
def main():
    """Start Ray and run ``train_func`` on two GPU workers via a TorchTrainer.

    Prints a completion message once ``fit()`` has returned.
    """
    # runtime_env working_dir points at the juart source tree
    ray.init(runtime_env={"working_dir": "/home/jovyan/juart/src"})

    # Two workers, each reserving 24 CPUs and one GPU
    per_worker = {"CPU": 24, "GPU": 1}
    scaling = ScalingConfig(
        num_workers=2,
        use_gpu=True,
        resources_per_worker=per_worker,
    )

    # Run name and output verbosity
    run_settings = RunConfig(name="torch_trainer_example", verbose=1)

    # Build the trainer and run the function that was passed to it
    TorchTrainer(
        train_func,
        scaling_config=scaling,
        run_config=run_settings,
    ).fit()
    print("Training complete!")


# Start the training only when this file/cell is executed as the main program
if __name__ == "__main__":
    main()

100%|██████████████████████████████████████████| 10/10 [00:22<00:00,  2.24s/it]
 90%|██████████████████████████████████████▋    | 9/10 [00:20<00:02,  2.18s/it] [repeated 5x across cluster]
100%|██████████████████████████████████████████| 10/10 [00:22<00:00,  2.25s/it]


(RayTrainWorker pid=1677971) Rank 1 - model initialization done -> loss fn initialization
