# Dual Domain Training for 3D Datasets
This Notebook is an upgraded version of the already existing Training_3D.ipynb. The reconstructed images of the models trained by the normal training seem to be very noisy. The dual domain training can hopefully remove the noise better. The structure of the Network stays the same but ray and the torch distributor function are used to run the training twice on two different gpus. Furthermore the ispace loss is used to weight the difference between the kspace data of the 2 models and to average their gradients before doing the optimizer step. In this way both trainings are running seperately but the optimizer step they do are the exact same because the greadients used during the optimizers step are the average of the both calculated gradients. Setting the weight_ispace_loss on 0 results in the normal Training again (except of the fact that there are 2 models getting trained but only one is saved at the end).

In [1]:
# Import of all necessary functions and classes
import os
import torch
import gc
import time
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from ray.air.config import RunConfig
from typing import Tuple
import numpy as np
import torch.distributed as dist
import ray
import h5py
import zarr as z
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

import sys
sys.path.insert(0, "../../src")
from juart.dl.loss.loss import JointLoss
from juart.dl.operation.modules import training
from juart.dl.utils.dist import GradientAccumulator
from juart.dl.model.unrollnet import LookaheadModel, UnrolledNet
from juart.dl.checkpoint.manager import CheckpointManager
from juart.conopt.functional.fourier import (
    fourier_transform_adjoint,
    fourier_transform_forward,
    nonuniform_fourier_transform_adjoint,
)

# activates the terminal output for print commands in ray
import logging
logging.basicConfig(level=logging.INFO)

# Training function that is later passed to the TorchTrainer
def train_func():

    global_rank = int(dist.get_rank())
    world_size = int(dist.get_world_size())
    group_size = 2
    
    ################################################################
    # Setting the rank for each worker
    for rank in range(0, world_size, group_size):
        ranks = list(range(rank, rank + group_size, 1))
        device = f"cuda:{global_rank}" if torch.cuda.is_available() else "cpu"
        print(f"Rank {global_rank} is using devide {device} ...")
        if global_rank in ranks:
            print(f"Rank {global_rank} is in group {ranks} ...")
            group = dist.new_group(ranks, backend="gloo")
    
    ################################################################
    
    # define variables
    shape = (128,128,128,1,1)
    nX, nY, nZ, nTI, nTE = shape
    weight_kspace_loss = [0.5, 0.5] # weight the difference in k space
    weight_ispace_loss = [0.0, 0.0] # weight the difference of the two images (dual domain) and average their gradients
    weight_hankel_loss = [0.0, 0.0]
    weight_casorati_loss = [0.0, 0.0]
    weight_wavelet_loss = [0.0, 0.0] # weight the loss in wavelet domain
    normalized_loss = True

    batch_size = 1 # number of datapoints used per batch iteration
    nD = 1 # number of datasets
    nP = 2 # number of permutations per epoch
    cgiter = 2 # number of dc iterations
    num_epochs = 2 # number of epochs
    
    model_dir = f'nummodel{weight_ispace_loss[0]}i_{nP}P_{cgiter}DC_{num_epochs}E_R1'
    root_dir ="/home/jovyan/models"
    endpoint_url = "https://s3.fz-juelich.de"
    model_backend = 'local'

    single_epoch = False # if its true the script will stop after 1 epoch
    save_checkpoint = True # enables checkpoint saving
    checkpoint_frequency = 5 # number of iterations between the save files
    load_model_state = True # if true the latest model state will be loaded if available
    load_averaged_model_state = True # latest averaged model state will be loaded
    load_optim_state = True # latest optimizer state will be loaded
    load_metrics = True # the latest metrics (lost, iterations) will be loaded

    num_groups = 1
    batch_size_local = batch_size // num_groups
    num_iterations = nD * nP * num_epochs
    

    # reading and shaping data
    store = z.open("/home/jovyan/datasets/num_phantom_128_R1")

    C = torch.from_numpy(np.array(store["C"])).to(device)
    k = torch.from_numpy(np.array(store["k"]))[...,None,None].to(device)
    d = torch.from_numpy(np.array(store["d"]))[...,None,None].to(device)

    # --- shaping data ---
    k_scaled = (k / (2 * k.max())).to(device)

    generator = torch.Generator()

    ################################################################
    # Defining the neural network
    
    model = UnrolledNet(shape,
                      CG_Iter = cgiter,
                      num_unroll_blocks = 10,
                      num_res_blocks = 15,
                      features = 32,
                      axes = (1,2,3),
                      kernel_size = (3,3,3),
                      activation = 'ReLU',
                      ResNetCheckpoints = True).to(device)

    loss_fn = JointLoss(
        shape,
        (3, 3),
        weights_kspace_loss = weight_kspace_loss,
        weights_ispace_loss = weight_ispace_loss,
        weights_hankel_loss = weight_hankel_loss,
        weights_casorati_loss = weight_casorati_loss,
        weights_wavelet_loss = weight_wavelet_loss,
        normalized_loss=normalized_loss,
        group = group,
        device=device,
    )
    
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.0001,
        betas=[0.9, 0.999],
        eps=1.0e-8,
        weight_decay=0.0,
    )

    accumulator = GradientAccumulator(
        model,
        accumulation_steps=batch_size_local,
        max_norm=1.0,
        normalized_gradient=False,
    )

    averaged_model = LookaheadModel(
        model,
        alpha=0.5,
        k=5,
    )

    dist.barrier()
    
    checkpoint_manager = CheckpointManager(
        model_dir,
        root_dir=root_dir,
        endpoint_url=endpoint_url,
        backend=model_backend,
    )

    dist.barrier()

    ################################################################
    # LOADING CURRENT MODEL STATE
    if load_model_state:
        print(f"Rank {global_rank} - Loading model state ...")
        checkpoint = checkpoint_manager.load(["model_state"], map_location=device)
        if all(checkpoint.values()):
            model.load_state_dict(checkpoint["model_state"])
        else:
            print(f"Rank {global_rank} - Could not load model state.")
    
    if load_averaged_model_state:
        print(f"Rank {global_rank} - Loading averaged model state ...")
        checkpoint = checkpoint_manager.load(
            ["averaged_model_state"], map_location=device
        )
        if all(checkpoint.values()):
            averaged_model.load_state_dict(checkpoint["averaged_model_state"])
        else:
            print(f"Rank {global_rank} - Could not load averaged model state.")
    
    if load_optim_state:
        print(f"Rank {global_rank} - Loading optim state ...")
        checkpoint = checkpoint_manager.load(["optim_state"], map_location=device)
        if all(checkpoint.values()):
            optimizer.load_state_dict(checkpoint["optim_state"])
        else:
            print(f"Rank {global_rank} - Could not load optim state.")
    
        total_trn_loss = list()
        total_val_loss = list()
        iteration = 0
    
    if load_metrics:
        print(f"Rank {global_rank} - Loading metrics ...")
        checkpoint = checkpoint_manager.load(["trn_loss", "val_loss", "iteration"])
        if all(checkpoint.values()):
            total_trn_loss = list(checkpoint["trn_loss"])
            total_val_loss = list(checkpoint["val_loss"])
            iteration = checkpoint["iteration"]
        else:
            print(f"Rank {global_rank} - Could not load metrics.")

    print(f"Rank {global_rank} - Continue with iteration {iteration} ...")

    dist.barrier()

    ################################################################
    # ACTUAL TRAINING LOOP
    total_trn_loss = list()
    total_val_loss = list()
    iteration = 0

    while iteration < num_iterations:
        tic = time.time()
        generator.manual_seed(iteration%nP)
    
        kspace_mask_worker0 = torch.randint(0, 2, (1, k_scaled.shape[1], 1, 1), generator=generator)
        kspace_mask_worker1 = 1 - kspace_mask_worker0

        # Defining data for worker 0
        if global_rank == 0:
            k_scaled_masked = (k_scaled * kspace_mask_worker0).to(device)
            AHd = nonuniform_fourier_transform_adjoint(k_scaled_masked, d, (nX, nY, nZ)).to(device)
            AHd = torch.sum(torch.conj(C[...,None,None]) * AHd, dim=0).to(device)
        
            data = [
               {
                   "images_regridded": AHd.to(device),
                   "kspace_trajectory": k_scaled.to(device),
                   "sensitivity_maps": C.to(device),
                   "kspace_mask_source": kspace_mask_worker0.to(device),
                   "kspace_mask_target": kspace_mask_worker1.to(device),
                   "kspace_data": d.to(device),
               }
            ]

        # Defining data for worker 1
        elif global_rank == 1:
            k_scaled_masked = (k_scaled * kspace_mask_worker1).to(device)
            AHd = nonuniform_fourier_transform_adjoint(k_scaled_masked, d, (nX, nY, nZ)).to(device)
            AHd = torch.sum(torch.conj(C[...,None,None]) * AHd, dim=0).to(device)
        
            data = [
               {
                   "images_regridded": AHd.to(device),
                   "kspace_trajectory": k_scaled.to(device),
                   "sensitivity_maps": C.to(device),
                   "kspace_mask_source": kspace_mask_worker1.to(device),
                   "kspace_mask_target": kspace_mask_worker0.to(device),
                   "kspace_data": d.to(device),
               }
            ]
    
        trn_loss = training(
           [0],
           data,
           model.to(device),
           loss_fn.to(device),
           optimizer.to(device),
           accumulator.to(device),
           group=group,
           device=device,
        )

        val_loss = [0] * batch_size

    ################################################################
    # SAVING DATA
        if global_rank == 0:
            # Completed epoch
            if (
                save_checkpoint
                and np.mod(iteration + batch_size, nD * nP) == 0
            ):
                print("Creating tagged checkpoint ...")
    
                checkpoint = {
                    "iteration": iteration + batch_size,
                    "model_state": model.state_dict(),
                    "averaged_model_state": averaged_model.state_dict(),
                    "optim_state": optimizer.state_dict(),
                    "trn_loss": total_trn_loss,
                    "val_loss": total_val_loss,
                }
    
                epoch = (iteration + batch_size) // (nD * nP)
                checkpoint_manager.save(checkpoint, tag=f"_epoch_{epoch}")
    
                if single_epoch:
                    # Also save the checkpoint as untagged checkpoint
                    # Otherwise, training will be stuck in endless loop
                    checkpoint_manager.save(checkpoint)
                    checkpoint_manager.release()
                    break
    
            # Intermediate checkpoint
            elif (
                save_checkpoint
                and np.mod(iteration + batch_size, checkpoint_frequency) == 0
            ):
                print("Creating untagged checkpoint ...")
    
                checkpoint = {
                    "iteration": iteration + batch_size,
                    "model_state": model.state_dict(),
                    "averaged_model_state": averaged_model.state_dict(),
                    "optim_state": optimizer.state_dict(),
                    "trn_loss": total_trn_loss,
                    "val_loss": total_val_loss,
                }
    
                checkpoint_manager.save(checkpoint, block=False)
    
            toc = time.time() - tic
    
            print(
                (
                    f"Iteration: {iteration} - "
                    + f"Elapsed time: {toc:.0f} - "
                    + f"Training loss: {[f'{loss:.3f}' for loss in trn_loss]} - "
                    + f"Validation loss: {[f'{loss:.3f}' for loss in val_loss]}"
                )
            )
    
        torch.cuda.empty_cache()
        gc.collect()
    
        iteration += batch_size

    # Return the trained model
    return {"model": model.parameters()}

################################################################
# main function that initializes needed classes and runs the train function
def main():

    dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:23456", world_size = 1, rank=0
    )
    
    ray.init(runtime_env={"working_dir": "/home/jovyan/juart/src"})
    scaling_config = ScalingConfig(
        num_workers=2, # number of workers that should be initialized
        use_gpu=True,  # should gpu be used?
        resources_per_worker={"CPU": 24, "GPU": 1},
    )

    # Define the run configuration
    run_config = RunConfig(
        name="torch_trainer_example", # name of the log file
        verbose=1, # detail of the ouput
    )

    # Create the TorchTrainer
    trainer = TorchTrainer(
        train_func,
        scaling_config=scaling_config,
        run_config=run_config,
    )

    # Run the training
    result = trainer.fit() # runs the function we passed to the trainer
    print("Training complete!")

if __name__ == "__main__":
    main()


2025-09-24 09:19:58,590	INFO worker.py:1942 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-09-24 09:19:58,689	INFO packaging.py:588 -- Creating a file package for local module '/home/jovyan/juart/src'.
2025-09-24 09:19:58,985	INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_ebf3d0d5f95cf163.zip' (70.09MiB) to Ray cluster...
2025-09-24 09:19:59,479	INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_ebf3d0d5f95cf163.zip'.
2025-09-24 09:20:04,544	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949
2025-09-24 09:20:04,548	INFO tensorboardx.py:193 -- pip install "ray[tune]" to see TensorBoard files.


== Status ==
Current time: 2025-09-24 09:20:04 (running for 00:00:00.12)
Using FIFO scheduling algorithm.
Logical resource usage: 0/128 CPUs, 0/2 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-24_09-19-53_758497_4165116/artifacts/2025-09-24_09-20-04/torch_trainer_example/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-24 09:20:09 (running for 00:00:05.18)
Using FIFO scheduling algorithm.
Logical resource usage: 49.0/128 CPUs, 2.0/2 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-24_09-19-53_758497_4165116/artifacts/2025-09-24_09-20-04/torch_trainer_example/driver_artifacts
Number of trials: 1/1 (1 PENDING)






== Status ==
Current time: 2025-09-24 09:20:14 (running for 00:00:10.21)
Using FIFO scheduling algorithm.
Logical resource usage: 49.0/128 CPUs, 2.0/2 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-24_09-19-53_758497_4165116/artifacts/2025-09-24_09-20-04/torch_trainer_example/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




[36m(RayTrainWorker pid=4178069)[0m Setting up process group for: env:// [rank=0, world_size=2]
[36m(TorchTrainer pid=4175346)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=4175346)[0m - (node_id=f21ca5a7638fd180f35278a2c123d9ee76196dd84493b25f8de7ecf7, ip=10.1.65.9, pid=4178069) world_rank=0, local_rank=0, node_rank=0
[36m(TorchTrainer pid=4175346)[0m - (node_id=f21ca5a7638fd180f35278a2c123d9ee76196dd84493b25f8de7ecf7, ip=10.1.65.9, pid=4178068) world_rank=1, local_rank=1, node_rank=0


== Status ==
Current time: 2025-09-24 09:20:19 (running for 00:00:15.23)
Using FIFO scheduling algorithm.
Logical resource usage: 49.0/128 CPUs, 2.0/2 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-24_09-19-53_758497_4165116/artifacts/2025-09-24_09-20-04/torch_trainer_example/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


[36m(RayTrainWorker pid=4178069)[0m Rank 0 is using devide cuda:0 ...
[36m(RayTrainWorker pid=4178069)[0m Rank 0 is in group [0, 1] ...
== Status ==
Current time: 2025-09-24 09:20:24 (running for 00:00:20.25)
Using FIFO scheduling algorithm.
Logical resource usage: 49.0/128 CPUs, 2.0/2 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-24_09-19-53_758497_4165116/artifacts/2025-09-24_09-20-04/torch_trainer_example/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




[36m(RayTrainWorker pid=4178068)[0m [rank1]:[W924 09:20:25.884361458 ProcessGroupNCCL.cpp:4718] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.


[36m(RayTrainWorker pid=4178069)[0m Rank 0 - Loading model state ...
[36m(RayTrainWorker pid=4178069)[0m Rank 0 - Loading averaged model state ...
[36m(RayTrainWorker pid=4178069)[0m Rank 0 - Loading optim state ...
[36m(RayTrainWorker pid=4178068)[0m Rank 1 is using devide cuda:1 ...
[36m(RayTrainWorker pid=4178068)[0m Rank 1 is in group [0, 1] ...
== Status ==
Current time: 2025-09-24 09:20:29 (running for 00:00:25.27)
Using FIFO scheduling algorithm.
Logical resource usage: 49.0/128 CPUs, 2.0/2 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-24_09-19-53_758497_4165116/artifacts/2025-09-24_09-20-04/torch_trainer_example/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


[36m(RayTrainWorker pid=4178069)[0m Rank 0 - Loading metrics ...
[36m(RayTrainWorker pid=4178069)[0m Rank 0 - Could not load metrics.
[36m(RayTrainWorker pid=4178069)[0m Rank 0 - Continue with iteration 0 ...


2025-09-24 09:20:30,279	ERROR tune_controller.py:1331 -- Trial task failed for trial TorchTrainer_a7bd8_00000
Traceback (most recent call last):
  File "/opt/conda/lib/python3.13/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/opt/conda/lib/python3.13/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.13/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.13/site-packages/ray/_private/worker.py", line 2882, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.13/site-packages/ray/_private/worker.py", line 968, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(Run

== Status ==
Current time: 2025-09-24 09:20:30 (running for 00:00:25.73)
Using FIFO scheduling algorithm.
Logical resource usage: 49.0/128 CPUs, 2.0/2 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-24_09-19-53_758497_4165116/artifacts/2025-09-24_09-20-04/torch_trainer_example/driver_artifacts
Number of trials: 1/1 (1 ERROR)
Number of errored trials: 1
+--------------------------+--------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name               |   # failures | error file                                                                                                                                                                        |
|--------------------------+--------------+-----------------------------------------------------------------------------------------------------------------------------

TrainingFailedError: The Ray Train run failed. Please inspect the previous error messages for a cause. After fixing the issue (assuming that the error is not caused by your own application logic, but rather an error such as OOM), you can restart the run from scratch or continue this run.
To continue this run, you can use: `trainer = TorchTrainer.restore("/home/jovyan/ray_results/torch_trainer_example")`.
To start a new run that will retry on training failures, set `train.RunConfig(failure_config=train.FailureConfig(max_failures))` in the Trainer's `run_config` with `max_failures > 0`, or `max_failures = -1` for unlimited retries.