# Dual-Domain-Training
## Import needed Scripts

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys

sys.path.insert(0, "../../src")

import gc
import time

import numpy as np
import torch
import torch.distributed as dist

from juart.dl.checkpoint.manager import CheckpointManager
from juart.dl.data.training import DatasetTraining
from juart.dl.loss.loss import JointLoss
from juart.dl.model.unrollnet import (
    ExponentialMovingAverageModel,
    LookaheadModel,
    SingleContrastUnrolledNet,
    UnrolledNet,
)
from juart.dl.operation.modules import training, validation
from juart.dl.utils.dist import GradientAccumulator

import os

if os.getenv("ZS_SSL_RECON_SOFTWARE_DIR") is not None:
    sys.path.insert(0, os.getenv("ZS_SSL_RECON_SOFTWARE_DIR"))

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

from juart.dl.utils.parser import options_parser
from juart.dl.train.train import train_loop_per_worker

## Define all necessary variables

In [2]:
# Model parameters
num_unroll_blocks = 10
num_res_blocks = 15
CG_Iter = 10
activation = "ReLU"
features = 512
directory = (
    "model_test_checkpoint2"  # Name that is used for the save directory of the model
)
root_dir = "/home/jovyan/models"  # path of the model directory
backend = "local"  # backend of the model directory

# Loss function parameters
weight_kspace_loss = [0.5, 0.5]
weight_ispace_loss = [0.1, 0.1]
weight_hankel_loss = [0.0, 0.01]
weight_casorati_loss = [0.0, 0.0]
weight_wavelet_loss = [0.0, 0.0]
normalized_loss = True

# Training parameters
epochs = 25
model_training = True
model_validation = False
ema_decay = 0.9
fractions = [0.0, 0.5, 0.5]

optimizer = "Adam"
normalized_gradient = False

averaged_model = "Lookahead"

save_checkpoint = True
checkpoint_frequency = 10

load_model_state = True
load_averaged_model_state = True
load_optim_state = True
load_metrics = True

disable_progress_bar = True
timing_level = 0
validation_level = 0

num_threads = 24
num_cpu_per_worker = 24
num_gpu_per_worker = 0
num_workers = 1
group_size = 1
use_gpu = True
device = "cuda:3"

data_dir = ""
data_backend = "local"
model_dir = ""
model_backend = "local"
image_dir = ""
image_backend = "local"
endpoint_url = "https://s3.fz-juelich.de"

datasets = []
slices = []
start = 0
stop = 3
step = 1
shape = 256, 256, 256, 2, 2
num_spokes = 8
batch_size = 1
groups = 1

## Initializing process group

In [3]:
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:23456", world_size=1, rank=0
)

## Trainloop_per_worker

In [4]:
def shuffled_indices(num_samples, num_epochs, rng):
    indices = np.repeat(np.arange(num_samples), num_epochs)
    indices = indices.reshape((num_samples, num_epochs))
    indices = rng.permuted(indices, axis=0)
    indices = indices.T.ravel()

    # Check if each sample is used once and only once in every epoch
    assert indices.size == num_samples * num_epochs
    for i in np.split(indices, num_epochs):
        assert np.unique(i).size == num_samples

    return indices


np.random.seed(0)
torch.manual_seed(0)

torch.set_num_threads(num_threads)
#torch.set_num_interop_threads(num_threads)

global_rank = int(dist.get_rank())
world_size = int(dist.get_world_size())

print(f"Rank {global_rank} - Intialize local groups ...")
dist.barrier()

dist.barrier()

for rank in range(0, world_size, group_size):
    ranks = list(range(rank, rank + group_size, 1))
    if global_rank in ranks:
        print(f"Rank {global_rank} is in group {ranks} ...")
        group = dist.new_group(ranks, backend="gloo")
    dist.barrier()

group_rank = dist.get_group_rank(group, global_rank)
group_index = global_rank // group_size
num_groups = world_size // group_size

print(f"Rank {global_rank} is local rank {group_rank} ...")

dist.barrier()

if use_gpu and torch.cuda.is_available():
    num_devices = torch.cuda.device_count()
    device_rank = np.mod(global_rank, torch.cuda.device_count())
    device = f"cuda:{device_rank}"
    print(
        f"Rank {global_rank} - Using CUDA device {device_rank} of {num_devices} ..."
    )
else:
    device = "cpu"

print(f"Rank {global_rank} is using device {device} ...")

dist.barrier()

nD = len(datasets)
nS = len(slices)
nX, nY, nZ, nTI, nTE = shape

num_epochs = epochs

# The number of batches that are computed serially via gradient accumulation
batch_size = batch_size
batch_size_local = batch_size // num_groups

num_iterations = nD * nS * num_epochs

rng = np.random.default_rng(seed=0)

training_indices = shuffled_indices(nD * nS, num_epochs, rng)
training_indices_batched = training_indices.reshape(
    (-1, batch_size_local, num_groups)
)

validation_indices = shuffled_indices(nD * nS, num_epochs, rng)
validation_indices_batched = validation_indices.reshape(
    (-1, batch_size_local, num_groups)
)

# Prepare models and optimizer

if groups == 1:
    model = UnrolledNet(
        shape,
        features=features,
        CG_Iter=CG_Iter,
        num_unroll_blocks=num_unroll_blocks,
        # weight_standardization=options["weight_standardization"],
        # spectral_normalization=options["spectral_normalization"],
        activation=activation,
        disable_progress_bar=disable_progress_bar,
        timing_level=timing_level,
        validation_level=validation_level,
        device=device,
    )
else:
    model = SingleContrastUnrolledNet(
        shape,
        features=features,
        CG_Iter=CG_Iter,
        num_unroll_blocks=num_unroll_blocks,
        # weight_standardization=options["weight_standardization"],
        # spectral_normalization=options["spectral_normalization"],
        activation=activation,
        disable_progress_bar=disable_progress_bar,
        timing_level=timing_level,
        validation_level=validation_level,
        device=device,
    )

loss_fn = JointLoss(
    shape,
    (3, 3),
    weights_kspace_loss=weight_kspace_loss,
    weights_ispace_loss=weight_ispace_loss,
    weights_wavelet_loss=weight_wavelet_loss,
    weights_hankel_loss=weight_hankel_loss,
    weights_casorati_loss=weight_casorati_loss,
    normalized_loss=normalized_loss,
    timing_level=timing_level,
    validation_level=validation_level,
    group=group,
    device=device,
)

if optimizer == "Adam":
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.0001,
        betas=[0.9, 0.999],
        eps=1.0e-8,
        weight_decay=0.0,
    )
elif optimizer == "AdamW":
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=0.0001,
        betas=[0.9, 0.999],
        eps=1.0e-8,
        weight_decay=0.0,
    )
elif optimizer == "RAdam":
    optimizer = torch.optim.RAdam(
        model.parameters(),
        lr=0.0001,
        betas=[0.9, 0.999],
        eps=1.0e-8,
        weight_decay=0.0,
    )

accumulator = GradientAccumulator(
    model,
    accumulation_steps=batch_size_local,
    max_norm=1.0,
    normalized_gradient=normalized_gradient
)

if averaged_model == "EMA":
    print(f"Rank {global_rank} - ExponentialMovingAverageModel")
    averaged_model = ExponentialMovingAverageModel(
        model,
        decay=ema_decay,
    )
elif averaged_model == "Lookahead":
    print(f"Rank {global_rank} - LookaheadModel")
    averaged_model = LookaheadModel(
        model,
        alpha=0.5,
        k=5,
    )

checkpoint_manager = CheckpointManager(
    model_dir,
    root_dir=root_dir,
    endpoint_url=endpoint_url,
    backend=model_backend,
)

if load_model_state:
    print(f"Rank {global_rank} - Loading model state ...")
    checkpoint = checkpoint_manager.load(["model_state"], map_location=device)
    if all(checkpoint.values()):
        model.load_state_dict(checkpoint["model_state"])
    else:
        print(f"Rank {global_rank} - Could not load model state.")

if load_averaged_model_state:
    print(f"Rank {global_rank} - Loading averaged model state ...")
    checkpoint = checkpoint_manager.load(
        ["averaged_model_state"], map_location=device
    )
    if all(checkpoint.values()):
        averaged_model.load_state_dict(checkpoint["averaged_model_state"])
    else:
        print(f"Rank {global_rank} - Could not load averaged model state.")

if load_optim_state:
    print(f"Rank {global_rank} - Loading optim state ...")
    checkpoint = checkpoint_manager.load(["optim_state"], map_location=device)
    if all(checkpoint.values()):
        optimizer.load_state_dict(checkpoint["optim_state"])
    else:
        print(f"Rank {global_rank} - Could not load optim state.")

    total_trn_loss = list()
    total_val_loss = list()
    iteration = 0

if load_metrics:
    print(f"Rank {global_rank} - Loading metrics ...")
    checkpoint = checkpoint_manager.load(["trn_loss", "val_loss", "iteration"])
    if all(checkpoint.values()):
        total_trn_loss = list(checkpoint["trn_loss"])
        total_val_loss = list(checkpoint["val_loss"])
        iteration = checkpoint["iteration"]
    else:
        print(f"Rank {global_rank} - Could not load metrics.")

print(f"Rank {global_rank} - Continue with iteration {iteration} ...")

training_data = DatasetTraining(
    data_dir,
    datasets,
    slices,
    num_spokes,
    fractions,
    mode="training",
    group_rank=group_rank,
    root_dir=root_dir,
    endpoint_url=endpoint_url,
    backend=data_backend,
)

while iteration < num_iterations:
    tic = time.time()

    # Reset the seed so that training can be resumed
    np.random.seed(iteration)
    torch.manual_seed(iteration)

    training_index = training_indices_batched[
        iteration // batch_size,
        :,
        group_index,
    ].tolist()
    validation_index = validation_indices_batched[
        iteration // batch_size, :, group_index
    ].tolist()

    if options["model_training"]:
        print(f"Rank {global_rank} - Training index {training_index} ...")

        trn_loss = training(
            training_index,
            training_data,
            model,
            loss_fn,
            optimizer,
            accumulator,
            group=group,
            device=device,
        )

        averaged_model.update_parameters(
            model,
        )

        torch.cuda.empty_cache()
        gc.collect()

    else:
        trn_loss = [0] * batch_size

    if options["model_validation"]:
        print(f"Rank {global_rank} - Validation index {validation_index} ...")

        val_loss = validation(
            validation_index,
            validation_data,
            averaged_model,
            loss_fn,
            group=group,
            device=device,
        )
        torch.cuda.empty_cache()
        gc.collect()

    else:
        val_loss = [0] * batch_size

    total_trn_loss += trn_loss
    total_val_loss += val_loss

    if global_rank == 0:
        # Completed epoch
        if (
            options["save_checkpoint"]
            and np.mod(iteration + batch_size, nD * nS) == 0
        ):
            print("Creating tagged checkpoint ...")

            checkpoint = {
                "iteration": iteration + batch_size,
                "model_state": model.state_dict(),
                "averaged_model_state": averaged_model.state_dict(),
                "optim_state": optimizer.state_dict(),
                "trn_loss": total_trn_loss,
                "val_loss": total_val_loss,
            }

            epoch = (iteration + batch_size) // (nD * nS)
            checkpoint_manager.save(checkpoint, tag=f"_epoch_{epoch}")

            if options["single_epoch"]:
                # Also save the checkpoint as untagged checkpoint
                # Otherwise, training will be stuck in endless loop
                checkpoint_manager.save(checkpoint)
                checkpoint_manager.release()
                break

        # Intermediate checkpoint
        elif (
            options["save_checkpoint"]
            and np.mod(iteration + batch_size, options["checkpoint_frequency"]) == 0
        ):
            print("Creating untagged checkpoint ...")

            checkpoint = {
                "iteration": iteration + batch_size,
                "model_state": model.state_dict(),
                "averaged_model_state": averaged_model.state_dict(),
                "optim_state": optimizer.state_dict(),
                "trn_loss": total_trn_loss,
                "val_loss": total_val_loss,
            }

            checkpoint_manager.save(checkpoint, block=False)

        toc = time.time() - tic

        print(
            (
                f"Iteration: {iteration} - "
                + f"Elapsed time: {toc:.0f} - "
                + f"Training loss: {[f'{loss:.3f}' for loss in trn_loss]} - "
                + f"Validation loss: {[f'{loss:.3f}' for loss in val_loss]}"
            )
        )

    torch.cuda.empty_cache()
    gc.collect()

    iteration += batch_size

Rank 0 - Intialize local groups ...
Rank 0 is in group [0] ...
Rank 0 is local rank 0 ...
Rank 0 - Using CUDA device 0 of 4 ...
Rank 0 is using device cuda:0 ...
Rank 0 - LookaheadModel
Rank 0 - Loading model state ...
Rank 0 - Could not load model state.
Rank 0 - Loading averaged model state ...
Rank 0 - Could not load averaged model state.
Rank 0 - Loading optim state ...
Rank 0 - Could not load optim state.
Rank 0 - Loading metrics ...
Rank 0 - Could not load metrics.
Rank 0 - Continue with iteration 0 ...


## Ray-Script

In [5]:
def main():
    #options = options_parser()

    scaling_config = ScalingConfig(
        num_workers=num_workers,
        trainer_resources={"CPU": 0, "GPU": 2},
        resources_per_worker={
            "CPU": num_cpu_per_worker,
            "GPU": num_cpu_per_worker,
        },
        use_gpu=use_gpu,
    )

    # Initialize the Trainer.
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        #train_loop_config=options,
        scaling_config=scaling_config,
    )

    # Train the model.
    trainer.fit()


if __name__ == "__main__":
    main()


2025-09-12 13:54:24,662	INFO worker.py:1942 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2025-09-12 13:54:29,074	INFO tune.py:253 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `<FrameworkTrainer>(...)`.
2025-09-12 13:54:29,076	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949
2025-09-12 13:54:29,081	INFO tensorboardx.py:193 -- pip install "ray[tune]" to see TensorBoard files.


== Status ==
Current time: 2025-09-12 13:54:29 (running for 00:00:00.12)
Using FIFO scheduling algorithm.
Logical resource usage: 0/128 CPUs, 0/4 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-12_13-54-19_022763_1719359/artifacts/2025-09-12_13-54-29/TorchTrainer_2025-09-12_13-54-19/driver_artifacts
Number of trials: 1/1 (1 PENDING)


[36m(autoscaler +19s)[0m Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
[33m(autoscaler +19s)[0m Error: No available node types can fulfill resource request {'CPU': 24.0, 'GPU': 26.0}. Add suitable node types to this cluster to resolve this issue.
== Status ==
Current time: 2025-09-12 13:54:34 (running for 00:00:05.14)
Using FIFO scheduling algorithm.
Logical resource usage: 0/128 CPUs, 0/4 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-12_13-54-19_022763_1719359/artifacts/2025-09-12_13-54-29/TorchTrainer_2025-09-12_13-54-19/driver

Process SpawnProcess-1:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.13/multiprocessing/process.py", line 313, in _bootstrap
    self.run()
    ~~~~~~~~^^
  File "/opt/conda/lib/python3.13/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/juart/examples/dl/../../src/juart/dl/checkpoint/manager.py", line 89, in save_checkpoint_process
    self.save_buffer_to_filesystem(*self.save_queue.get())
                                    ~~~~~~~~~~~~~~~~~~~^^
  File "/opt/conda/lib/python3.13/multiprocessing/queues.py", line 101, in get
    res = self._recv_bytes()
  File "/opt/conda/lib/python3.13/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/opt/conda/lib/python3.13/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
  File "/opt/conda/lib/python3.13/multiprocessing/connection.py", 

== Status ==
Current time: 2025-09-12 13:54:56 (running for 00:00:27.25)
Using FIFO scheduling algorithm.
Logical resource usage: 0/128 CPUs, 0/4 GPUs (0.0/1.0 accelerator_type:A100)
Result logdir: /tmp/ray/session_2025-09-12_13-54-19_022763_1719359/artifacts/2025-09-12_13-54-29/TorchTrainer_2025-09-12_13-54-19/driver_artifacts
Number of trials: 1/1 (1 PENDING)


