Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# Training a 3D ControlNet for Generating 3D Images Based on Input Masks 

![Generated image examples and input mask](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_synthetic_data.png)

In this notebook, we detail the procedure for training a 3D ControlNet to generate high-dimensional 3D medical images. Because generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater) can cause out-of-memory issues on most GPUs, we have structured the training process into three primary steps: 1) training data preparation, 2) training config preparation, and 3) training the 3D ControlNet. The subsequent sections demonstrate the entire process using a simulated dataset. We also provide the real preprocessed dataset used in the finetuning config `environment_maisi_controlnet_train.json`. More instructions on how to preprocess real data can be found in the [README](./data/README.md) in the `data` folder.

`[Release Note (March 2025)]:` We are excited to announce the new MAISI Version `'rflow-ct'`. Compared with the previous version `'ddpm-ct'`, it accelerates latent diffusion model inference by 33x. See the "Set up the MAISI version" section below for the detailed differences.

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"

## Setup imports

In [None]:
import copy
import json
import os
import subprocess
import tempfile

import nibabel as nib
import numpy as np
from monai.config import print_config
from monai.data import create_test_image_3d
from scripts.diff_model_setting import setup_logging
from IPython.display import Image, display

from scripts.train_controlnet import train_controlnet
from scripts.infer_controlnet import infer_controlnet
from scripts.download_model_data import download_model_data

print_config()

logger = setup_logging("notebook")

## Set up the MAISI version

Choose among `'ddpm-ct'`, `'rflow-ct'`, and `'rflow-mr'`. The differences are:
- The versions `'ddpm-ct'` and `'rflow-ct'` generate CT images, while `'rflow-mr'` generates MR images.
- The version `'ddpm-ct'` uses the basic DDPM noise scheduler. `'rflow-ct'` and `'rflow-mr'` use the Rectified Flow scheduler, which can be 33 times faster during inference.
- The version `'ddpm-ct'` requires training images to be labeled with body region (`"top_region_index"` and `"bottom_region_index"`), while `'rflow-ct'` has no such requirement, so its training data is easier to prepare.
- For the released model weights, compared with `'ddpm-ct'`, `'rflow-ct'` generates better-quality images for the head region and for small output volumes, and images of comparable quality in other cases.

In [None]:
generate_version = "rflow-ct"
if generate_version == "ddpm-ct":
    model_def_path = "./configs/config_network_ddpm.json"
elif generate_version == "rflow-ct" or generate_version == "rflow-mr":
    model_def_path = "./configs/config_network_rflow.json"
else:
    raise ValueError(f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got {generate_version}.")
with open(model_def_path, "r") as f:
    model_def = json.load(f)
include_body_region = model_def["include_body_region"]
logger.info(f"Models are {generate_version}, whether to use body_region is {include_body_region}")

## Step 1: Training Data Preparation


### Simulate a special dataset

It is widely recognized that training AI models is a time-intensive process. In this instance, we will simulate a small dataset and conduct training over multiple epochs. While the performance may not reach optimal levels due to the abbreviated training duration, the entire pipeline will be completed within minutes.

`sim_datalist` describes the simulated dataset. It lists 3 image embedding/label pairs; with the current settings, the item with fold 0 is used for validation and the other two for training. The image dimensions are defined by `sim_dim`.

The diffusion model and ControlNet utilize a U-shaped convolutional neural network architecture, requiring matching input and output dimensions. Therefore, it is advisable to resample the input image dimensions to be multiples of 2 for compatibility. In this case, we have chosen dimensions that are multiples of 128.
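
As a minimal sketch of the sizing rule above, the hypothetical helper below rounds each spatial dimension up to the nearest multiple of a chosen base (128 here, matching this tutorial). The name `round_up_to_multiple` is an assumption for illustration and is not part of the MAISI scripts; how a volume is actually resampled or padded to reach the target size is left to your preprocessing.

```python
def round_up_to_multiple(dims, base=128):
    """Round each entry of dims up to the nearest multiple of base."""
    if base <= 0:
        raise ValueError("base must be positive")
    return [((d + base - 1) // base) * base for d in dims]


# Example: a 500 x 500 x 130 volume would map to a 512 x 512 x 256 target grid.
target_dim = round_up_to_multiple([500, 500, 130])
print(target_dim)
```

Dimensions that are already multiples of the base, such as `[256, 256, 128]`, are returned unchanged.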

The training workflow requires one JSON file to specify the image embedding and segmentation pairs. In addition, the diffusion model used in ControlNet requires further input attributes, including output dimension, output spacing, and top/bottom body region. The dimensions and spacing can be extracted from the header information of the training images. The pseudo whole-body segmentation mask and the top/bottom body region inputs can be determined through manual examination or by using segmentation masks from tools such as [TotalSegmentator](https://github.com/wasserth/TotalSegmentator) or [MONAI VISTA](https://github.com/Project-MONAI/VISTA). The body regions are formatted as 4-dimensional one-hot vectors: the head and neck region is represented by [1,0,0,0], the chest region by [0,1,0,0], the abdomen region by [0,0,1,0], and the lower body region (below the abdomen) by [0,0,0,1].
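
The one-hot encoding described above can be sketched with a small lookup. The region names in `BODY_REGIONS` and the function `region_one_hot` are hypothetical and chosen only for this illustration; only the vector layout comes from the text.

```python
# Order follows the one-hot layout described above.
# The string names are assumptions made for this sketch.
BODY_REGIONS = ("head_neck", "chest", "abdomen", "lower_body")


def region_one_hot(region):
    """Return the 4-dimensional one-hot vector for a body region name."""
    if region not in BODY_REGIONS:
        raise ValueError(f"region must be one of {BODY_REGIONS}, got {region!r}")
    return [1 if name == region else 0 for name in BODY_REGIONS]


# A scan that spans from the chest down to the lower body:
top_region_index = region_one_hot("chest")
bottom_region_index = region_one_hot("lower_body")
print(top_region_index, bottom_region_index)
```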

To train the ControlNet/diffusion U-Net, we first store the latent features (image embeddings) produced by the autoencoder's encoder in local storage. This allows the latent diffusion model to use these features directly, saving both time and GPU memory during training. We also provide a script for multi-GPU processing that saves latent features from all training images, which significantly accelerates the creation of the entire training set. Please check "Step 1 Create Training Data" in [maisi_diff_unet_training_tutorial](./maisi_diff_unet_training_tutorial.ipynb) and [diff_model_create_training_data.py](./scripts/diff_model_create_training_data.py) for how to encode images and save them as image embeddings.

The JSON file used in ControlNet training has the following structure:

In [None]:
sim_dim = [256, 256, 128]
if "mr" in generate_version:
    sim_modality = "mri_t2"
else:
    sim_modality = "ct"

with open("./configs/modality_mapping.json", "r") as f:
    modality_mapping = json.load(f)
if sim_modality not in modality_mapping.keys():
    raise ValueError(f"sim_modality has to be chosen from {modality_mapping.keys()}. Yet got sim_modality={sim_modality}.")

sim_datalist = {
    "training": [
        {
            "image": "tr_image_001_emb.nii.gz",  # relative path to the image embedding file
            # relative path to the combined label (pseudo whole-body segmentation mask + ROI mask) file
            "label": "tr_label_001.nii.gz",
            "fold": 0,  # fold index for cross validation. If the dataset item's fold value is equal to the fold value
            # in config_maisi_controlnet_train.json, then it is used for validation. Otherwise, it is used for training.
            # In the current parameters, fold 0 is for validation.
            "dim": sim_dim,  # the dimension of image
            "spacing": [1.5, 1.5, 1.5],  # the spacing of image,
            "modality": sim_modality
        },
        {
            "image": "tr_image_002_emb.nii.gz",
            "label": "tr_label_002.nii.gz",
            "fold": 1,
            "dim": sim_dim,
            "spacing": [1.5, 1.5, 1.5],
            "modality": sim_modality
        },
        {
            "image": "tr_image_003_emb.nii.gz",
            "label": "tr_label_003.nii.gz",
            "fold": 1,
            "dim": sim_dim,
            "spacing": [1.5, 1.5, 1.5],
            "modality": sim_modality
        },
    ]
}
if include_body_region:
    for i in range(len(sim_datalist["training"])):
        # body region index
        sim_datalist["training"][i]["top_region_index"] = [0, 1, 0, 0]  # the top region index of the image
        sim_datalist["training"][i]["bottom_region_index"] = [0, 0, 0, 1]  # the bottom region index of the image
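
The fold convention noted in the comments above can be illustrated with a short sketch. The helper `split_by_fold` is hypothetical and only mirrors the rule "items whose fold equals the configured fold are used for validation"; the training script performs its own split.

```python
def split_by_fold(datalist, val_fold=0):
    """Split datalist['training'] into (train_items, val_items) by fold index."""
    train_items, val_items = [], []
    for item in datalist["training"]:
        if item["fold"] == val_fold:
            val_items.append(item)
        else:
            train_items.append(item)
    return train_items, val_items


# A toy datalist with the same fold layout as the simulated one.
example_datalist = {
    "training": [
        {"image": "tr_image_001_emb.nii.gz", "fold": 0},
        {"image": "tr_image_002_emb.nii.gz", "fold": 1},
        {"image": "tr_image_003_emb.nii.gz", "fold": 1},
    ]
}
train_items, val_items = split_by_fold(example_datalist, val_fold=0)
print(len(train_items), len(val_items))
```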

### Generate simulated images and labels

Now we can use MONAI's `create_test_image_3d` function and nibabel's `nib.Nifti1Image` class to generate the simulated 3D images under `work_dir`.

In [None]:
os.environ["MONAI_DATA_DIRECTORY"] = "./temp_work_dir"
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
    os.makedirs(directory, exist_ok=True)
work_dir = tempfile.mkdtemp() if directory is None else directory

download_model_data(generate_version, work_dir, model_only=True)

dataroot_dir = os.path.join(work_dir, "sim_controlnet_datasets")
if not os.path.isdir(dataroot_dir):
    os.makedirs(dataroot_dir)

datalist_file = os.path.join(work_dir, "sim_controlnet_datalist.json")
with open(datalist_file, "w") as f:
    json.dump(sim_datalist, f, indent=4)
logger.info(f"Save data list json file to {datalist_file}")

for d in sim_datalist["training"]:
    # The autoencoder downsamples the image twice (2x each time),
    # so each spatial axis of the embedding is 1/4 of the image size.
    img_emb, _ = create_test_image_3d(
        sim_dim[0] // 4,
        sim_dim[1] // 4,
        sim_dim[2] // 4,
        rad_max=10,
        num_seg_classes=25,
        random_state=np.random.RandomState(42),
    )
    # The label has the same shape as the original image.
    _, label = create_test_image_3d(
        sim_dim[0], sim_dim[1], sim_dim[2], rad_max=10, num_seg_classes=25, random_state=np.random.RandomState(42)
    )

    image_fpath = os.path.join(dataroot_dir, d["image"])
    # We repeat the volume 4 times to simulate the channel dimension of latent features.
    img_emb = np.stack([img_emb] * 4, axis=3)
    nib.save(nib.Nifti1Image(img_emb, affine=np.eye(4)), image_fpath)
    label_fpath = os.path.join(dataroot_dir, d["label"])
    nib.save(nib.Nifti1Image(label, affine=np.eye(4)), label_fpath)

logger.info("Generated simulated images.")
logger.info(f"img_emb shape: {img_emb.shape}")
logger.info(f"label shape: {label.shape}")
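
As a quick sanity check on the relationship used above, the sketch below computes the embedding shape expected for a given image size. The helper name and its defaults (a downsampling factor of 4 per axis and 4 latent channels) are assumptions taken from the simulation code in this notebook, not values read from the model definition.

```python
def expected_embedding_shape(image_dim, downsample_factor=4, latent_channels=4):
    """Return the (x, y, z, channels) shape of the simulated embedding."""
    if any(d % downsample_factor != 0 for d in image_dim):
        raise ValueError(f"every dimension must be divisible by {downsample_factor}")
    spatial = tuple(d // downsample_factor for d in image_dim)
    return spatial + (latent_channels,)


print(expected_embedding_shape([256, 256, 128]))
```

For `sim_dim = [256, 256, 128]` this gives `(64, 64, 32, 4)`, which should match the `img_emb` shape logged above, while the label keeps the full `(256, 256, 128)` shape.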

## Step 2: Training Config Preparation

### Set up directories and configurations

To keep the demonstration short, we have set the number of training epochs to 2. Additionally, we changed the `num_splits` parameter in [AutoencoderKlMaisi](https://github.com/Project-MONAI/MONAI/blob/dev/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py#L881) from its default value of 16 to 4. This reduces the spatial splitting of feature maps in convolutions, which suits the smaller input size used here. (Splitting converts each convolution into a for-loop over smaller chunks, which conserves GPU memory.)
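
The idea of trading one large convolution for a loop over smaller pieces can be illustrated with a toy 1D example. This is only a sketch under stated assumptions: the function names are hypothetical, it uses plain Python lists rather than tensors, and it is not the implementation used by `AutoencoderKlMaisi`. Each chunk is extended by a halo of `len(kernel) // 2` samples so that the concatenated result matches the unsplit result.

```python
def conv1d_same(x, kernel):
    """Zero-padded 1D sliding-window filter; output has the same length as x."""
    if len(kernel) % 2 == 0:
        raise ValueError("this sketch assumes an odd kernel length")
    pad = len(kernel) // 2
    padded = [0.0] * pad + list(x) + [0.0] * pad
    return [sum(padded[i + j] * kernel[j] for j in range(len(kernel))) for i in range(len(x))]


def conv1d_same_split(x, kernel, num_splits):
    """Same result as conv1d_same, computed chunk by chunk in a for-loop."""
    if len(kernel) % 2 == 0:
        raise ValueError("this sketch assumes an odd kernel length")
    if num_splits < 1:
        raise ValueError("num_splits must be at least 1")
    pad = len(kernel) // 2
    padded = [0.0] * pad + list(x) + [0.0] * pad
    n = len(x)
    chunk = max(1, -(-n // num_splits))  # ceiling division
    out = []
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        # Take the chunk plus a halo of `pad` samples on each side.
        piece = padded[start : stop + 2 * pad]
        out.extend(
            sum(piece[i + j] * kernel[j] for j in range(len(kernel)))
            for i in range(stop - start)
        )
    return out


signal = [float(v) for v in range(16)]
kernel = [1.0, 2.0, 1.0]
assert conv1d_same_split(signal, kernel, num_splits=4) == conv1d_same(signal, kernel)
```

Only one chunk (plus its halo) is processed at a time, so peak working memory shrinks as `num_splits` grows, at the cost of more loop iterations. For small inputs, fewer splits are sufficient.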

In [None]:
env_config_path = f"./configs/environment_maisi_controlnet_train_{generate_version}.json"
train_config_path = f"./configs/config_maisi_controlnet_train_{generate_version}.json"

final_env_config_path = os.path.join(work_dir, "environment_maisi_controlnet_train.json")
final_train_config_path = os.path.join(work_dir, "config_maisi_controlnet_train.json")

# Load environment configuration, model configuration and model definition
with open(env_config_path, "r") as f:
    env_config = json.load(f)

with open(train_config_path, "r") as f:
    train_config = json.load(f)

with open(model_def_path, "r") as f:
    model_def = json.load(f)

env_config_out = copy.deepcopy(env_config)
train_config_out = copy.deepcopy(train_config)
model_def_out = copy.deepcopy(model_def)

# Set up directories based on configurations
env_config_out["data_base_dir"] = dataroot_dir
env_config_out["json_data_list"] = datalist_file
env_config_out["model_dir"] = os.path.join(work_dir, env_config_out["model_dir"])
env_config_out["output_dir"] = os.path.join(work_dir, env_config_out["output_dir"])
env_config_out["tfevent_path"] = os.path.join(work_dir, env_config_out["tfevent_path"])

env_config_out["exp_name"] = "tutorial_training_example"
env_config_out["trained_controlnet_path"] = f"{env_config_out['model_dir']}/{env_config_out['exp_name']}_current.pt"

# Create necessary directories
os.makedirs(env_config_out["model_dir"], exist_ok=True)
os.makedirs(env_config_out["output_dir"], exist_ok=True)
os.makedirs(env_config_out["tfevent_path"], exist_ok=True)

env_config_filepath = final_env_config_path
with open(env_config_filepath, "w") as f:
    json.dump(env_config_out, f, sort_keys=True, indent=4)

# Update training configuration for demo
max_epochs = 2
train_config_out["controlnet_train"]["n_epochs"] = max_epochs
# weighted_loss applies more penalty to regions of interest (e.g., tumors).
# We disable it for the dummy data: with weighted_loss=1,
# all regions are treated equally in the loss computation.
train_config_out["controlnet_train"]["weighted_loss"] = 1
# weighted_loss_label lists the label indices that receive the extra penalty
# during training. We set it to [None], i.e., no label is singled out.
train_config_out["controlnet_train"]["weighted_loss_label"] = [None]
# Use a single inference step to keep the demo fast
train_config_out["controlnet_infer"]["num_inference_steps"] = 1

model_config_filepath = final_train_config_path
with open(model_config_filepath, "w") as f:
    json.dump(train_config_out, f, sort_keys=True, indent=4)

# Update model definition for demo
model_def_out["autoencoder_def"]["num_splits"] = 4
model_def_filepath = os.path.join(work_dir, "config_maisi.json")
with open(model_def_filepath, "w") as f:
    json.dump(model_def_out, f, sort_keys=True, indent=4)

# Print files and folders under work_dir
logger.info(f"files and folders under work_dir: {os.listdir(work_dir)}.")

# Adjust based on the number of GPUs you want to use
num_gpus = 1
logger.info(f"number of GPUs: {num_gpus}.")

## Step 3: Train the Model

After all latent feature/mask pairs have been created, we will initiate the multi-GPU script to train ControlNet.

With `'ddpm-ct'`, the image generation process uses the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps, while `'rflow-ct'` and `'rflow-mr'` use the Rectified Flow scheduler. The ControlNet is optimized using an L1 loss and a decaying learning rate schedule. The batch size for this process is set to 1.
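
To make the role of `weighted_loss` and `weighted_loss_label` concrete, here is a minimal sketch of a weighted L1 loss over flattened voxel lists. It is an illustration only: the function name and the use of plain Python lists are assumptions, and the actual training script operates on tensors and may differ in detail.

```python
def weighted_l1_loss(pred, target, labels, weighted_loss=1.0, weighted_loss_label=(None,)):
    """Mean absolute error where voxels with a listed label are weighted more."""
    if not (len(pred) == len(target) == len(labels)):
        raise ValueError("pred, target and labels must have the same length")
    if len(pred) == 0:
        raise ValueError("inputs must not be empty")
    # Ignore None entries, so [None] selects no label at all.
    roi_labels = {lab for lab in weighted_loss_label if lab is not None}
    total = 0.0
    for p, t, lab in zip(pred, target, labels):
        weight = weighted_loss if lab in roi_labels else 1.0
        total += weight * abs(p - t)
    return total / len(pred)


pred = [0.0, 1.0, 2.0, 3.0]
target = [0.5, 1.0, 1.0, 3.0]
labels = [0, 0, 5, 5]  # label 5 stands in for a region of interest

# Demo settings from this notebook: every voxel counts equally.
print(weighted_l1_loss(pred, target, labels, weighted_loss=1, weighted_loss_label=[None]))
# Emphasize label 5 by a factor of 100.
print(weighted_l1_loss(pred, target, labels, weighted_loss=100, weighted_loss_label=[5]))
```

With the demo settings the result is the plain mean absolute error (0.375 here); with the second call, the single error inside label 5 dominates the loss (25.125).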

In [None]:
logger.info("Training the model...")

# Define the arguments for torchrun
module_args = {
    "env_config_path": env_config_filepath,
    "model_config_path": model_config_filepath,
    "model_def_path": model_def_filepath,
    "num_gpus": num_gpus
}
logger.info(module_args)

train_controlnet(**module_args)

## Step 4: Model Inference

Upon completing the training of the ControlNet, we can employ the multi-GPU script to perform inference.
By integrating the autoencoder, diffusion model, and ControlNet, this process will generate 3D images with specified top/bottom body regions, spacing, and dimensions based on input masks.

In [None]:
logger.info("Inference...")

# Define the arguments for torchrun
module_args = {
    "env_config_path": env_config_filepath,
    "model_config_path": model_config_filepath,
    "model_def_path": model_def_filepath,
    "num_gpus": num_gpus
}
logger.info(module_args)

infer_controlnet(**module_args)

Upon completing the full training with actual CT datasets, users can expect output images similar to the examples below. Each sub-figure displays images generated by ControlNet, with synthetic tumors highlighted by overlaying tumor masks on top of the generated images. The real input condition of ControlNet consists of segmentation masks containing 127 anatomical structures. The specific output depends on the input segmentation mask. For a thorough evaluation of the entire volume, it is recommended to use visualization tools such as [ITK-SNAP](http://www.itksnap.org/pmwiki/pmwiki.php) or [3D Slicer](https://www.slicer.org/).

In [None]:
display(Image(filename="./figures/maisi_control-train-examples.png"))