Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# 3D regression example based on DenseNet

This tutorial shows an example of 3D regression task based on DenseNet and array format transforms.

Here, the task is given to predict the ages of subjects from MR imagee.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_regression/densenet_training_array.ipynb)

## Setup environment

In [1]:
# !python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"

## Setup imports

In [2]:
import logging
import os
import sys
import shutil
import tempfile

import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import (
    EnsureChannelFirst,
    EnsureChannelFirstd,
    Compose,
    RandRotate90,
    Resize,
    ScaleIntensity,
)
from monai.networks.nets import Regressor

pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
print_config()

  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.3.2
Numpy version: 1.26.0
Pytorch version: 2.3.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /simurgh/u/<username>/miniconda/envs/monai/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.4.0
Nibabel version: 5.2.1
scikit-image version: 0.23.2
scipy version: 1.14.0
Pillow version: 10.4.0
Tensorboard version: 2.17.0
gdown version: 5.2.0
TorchVision version: 0.18.1+cu121
tqdm version: 4.66.4
lmdb version: 1.5.1
psutil version: 6.0.0
pandas version: 2.2.2
einops version: 0.8.0
transformers version: 4.40.2
mlflow version: 2.14.3
pynrrd version: 1.0.0
clearml version: 1.16.3rc2

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## Setup data directory

In [3]:
# # Set data directory
# directory = os.environ.get("MONAI_DATA_DIRECTORY")
# if directory is not None:
#     os.makedirs(directory, exist_ok=True)
# root_dir = "/simurgh/group/tmp"
# print(root_dir)

In [4]:
# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
# images = [
#     os.sep.join([root_dir, "ixi", "IXI314-IOP-0889-T1.nii.gz"]),
#     os.sep.join([root_dir, "ixi", "IXI249-Guys-1072-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI609-HH-2600-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI173-HH-1590-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI020-Guys-0700-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI342-Guys-0909-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI134-Guys-0780-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI577-HH-2661-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI066-Guys-0731-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI130-HH-1528-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI607-Guys-1097-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI175-HH-1570-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI385-HH-2078-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI344-Guys-0905-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI409-Guys-0960-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI584-Guys-1129-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI253-HH-1694-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI092-HH-1436-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI574-IOP-1156-T1.nii.gz"]),
    # os.sep.join([root_dir, "ixi", "IXI585-Guys-1130-T1.nii.gz"]),
# ]
# print(images)
# # ages of subjects
# ages = np.array(
# #     [
#         45.86,
#         68.27,
#         # 29.0,
#         # 29.57,
#         # 39.47,
#         # 48.68,
#         # 47.35,
#         # 64.19,
#         # 46.17,
#         # 38.77,
#         # 83.81,
#         # 72.27,
#         # 64.65,
#         # 62.09,
#         # 70.95,
#         # 41.33,
#         # 24.0,
#         # 33.24,
#         # 50.57,
        # 28.12,
#     ]
# )

In [5]:
# if not os.path.isfile(images[0]):
#     resource = "http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar"
#     md5 = "34901a0593b41dd19c1a1f746eac2d58"

#     dataset_dir = os.path.join(root_dir, "ixi")
#     tarfile_name = f"{dataset_dir}.tar"

#     download_and_extract(resource, tarfile_name, dataset_dir, md5)

In [6]:
import pickle
import numpy as np
dataset_names=["/simurgh/u/fangruih/monai-tutorials/generative/3d_ldm/metadata/abcd/paths_and_info_flexpath.pkl",
               "/simurgh/u/fangruih/monai-tutorials/generative/3d_ldm/metadata/adni_t1/paths_and_info_flexpath.pkl",
               "/simurgh/u/fangruih/monai-tutorials/generative/3d_ldm/metadata/hcp_ag_t1/paths_and_info_flexpath.pkl",
               "/simurgh/u/fangruih/monai-tutorials/generative/3d_ldm/metadata/hcp_dev_t1/paths_and_info_flexpath.pkl",
               "/simurgh/u/fangruih/monai-tutorials/generative/3d_ldm/metadata/hcp_ya_mpr1/paths_and_info_flexpath.pkl",
               "/simurgh/u/fangruih/monai-tutorials/generative/3d_ldm/metadata/ppmi_t1/paths_and_info_flexpath.pkl"]
train_images=[]
train_ages=[]
val_images=[]
val_ages=[]
for dataset_name in dataset_names:
    with open(dataset_name, 'rb') as file:
        data = pickle.load(file)
        
        # Convert paths and ages to lists if they are NumPy arrays
        train_new_images = data['train']['paths'].tolist() if isinstance(data['train']['paths'], np.ndarray) else data['train']['paths']
        train_new_ages = data['train']['age'].tolist() if isinstance(data['train']['age'], np.ndarray) else data['train']['age']
        
        val_new_images = data['val']['paths'].tolist() if isinstance(data['val']['paths'], np.ndarray) else data['val']['paths']
        val_new_ages = data['val']['age'].tolist() if isinstance(data['val']['age'], np.ndarray) else data['val']['age']
        
        # Append new data to existing lists
        if not train_images:  # More Pythonic way to check if the list is empty
            # Direct assignment for the first file
            train_images = train_new_images
            train_ages = train_new_ages
            val_images = val_new_images
            val_ages = val_new_ages
        else:
            # Concatenation for subsequent files
            train_images += train_new_images
            train_ages += train_new_ages
            val_images += val_new_images
            val_ages += val_new_ages
        
        # Debug output to check the results
        print(train_images[-1])  # Print the last path
        
prefix = "/scr/fangruih/stru/"
train_images = [prefix + train_image for train_image in train_images]
val_images = [prefix + val_image for val_image in val_images]

print(len(train_images))  # Print the total number of paths loaded
print(len(train_ages))  # Print the total number of paths loaded

print(len(val_images))  # Print the total number of paths loaded
print(len(val_ages))  # Print the total number of paths loaded


t1/abcd/sub-NDARINVJLFDX0WJ/ses-2YearFollowUpYArm1/anat/sub-NDARINVJLFDX0WJ_ses-2YearFollowUpYArm1_run-01_T1w_nrm_crp.npy
t1/adni/035_S_0292/MT1__GradWarp__N3m/2013-05-01_14_31_36.0/I371451/ADNI_035_S_0292_MR_MT1__GradWarp__N3m_Br_20130507153329708_S188641_I371451_nrm_crp.npy
t1/hcp_aging/HCA6633069_V1_MR/T1w_nrm_crp.npy
t1/hcp_dev/HCD0797877_V1_MR/T1w_nrm_crp.npy
t1/hcp_ya_mpr1/169343/169343_3T_T1w_MPR1_nrm_crp.npy
t1/ppmi/51689/T1-anatomical/2015-12-08_14_41_38.0/I696900/PPMI_51689_MR_T1-anatomical_Br_20160429191726566_S360889_I696900_nrm_crp.npy
21051
21051
6015
6015


In [7]:
import numpy as np

# Path to the .npy file
file_path = "/scr/fangruih/stru/t1/hcp_ya_mpr1/169343/169343_3T_T1w_MPR1_nrm_crp.npy"
# Load the numpy file
data = np.load(file_path)

# Print the dimensions of the loaded data
print("Dimensions of the loaded data:", data.shape)

Dimensions of the loaded data: (1, 160, 192, 176)


## Create data loaders

In [8]:
# Define transforms
# train_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(channel_dim=0), Resize((148,180,148)), RandRotate90()])
train_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(channel_dim=0), Resize((148,180,148))])
val_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(channel_dim=0),Resize((148,180,148))])

# train_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize((96, 96, 96)), RandRotate90()])
# val_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize((96, 96, 96))])

# Define nifti dataset, data loader
check_ds = ImageDataset(image_files=train_images, labels=train_ages, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=3, num_workers=2, pin_memory=pin_memory)

im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label, label.shape)

# create a training data loader
train_ds = ImageDataset(image_files=train_images, labels=train_ages, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=pin_memory)

# create a validation data loader
val_ds = ImageDataset(image_files=val_images, labels=val_ages, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=pin_memory)

<class 'monai.data.meta_tensor.MetaTensor'> torch.Size([3, 1, 148, 180, 148]) tensor([13.4167, 11.2500,  9.5000], dtype=torch.float64) torch.Size([3])


## Create model and train

In [None]:
import sys
import torch
import wandb
import numpy as np
from torch.nn import MSELoss
from torch.optim import Adam
# from torch.utils.tensorboard import SummaryWriter
# from model import Regressor  # Assuming this is your model's import statement.# Initialize wandb
wandb.init(project="age-regressor")

# Setup the model
model = Regressor(in_shape=[1,148,180,148], out_shape=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2))
if torch.cuda.is_available():
    device = torch.device("cuda")
    model.to(device)
else:
    device = torch.device("cpu")

# Loss and optimizer
loss_function = MSELoss()
optimizer = Adam(model.parameters(), 1e-4)

# Training settings
val_interval = 2
max_epochs = 100
best_metric = sys.float_info.max
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        if step %100==0:
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")

        # Log training loss every 100 stepsif step % 100 == 0:
    wandb.log({"train_loss": loss.item(), "step": epoch * len(train_loader) + step})

    # Average loss for epoch
    epoch_loss /= len(train_loader)
    wandb.log({"epoch_loss": epoch_loss, "epoch": epoch})

    # Validationif (epoch + 1) % val_interval == 0:
    model.eval()
    with torch.no_grad():
        val_losses = []
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = model(val_images)
            val_loss = loss_function(val_outputs, val_labels.float())
            val_losses.append(val_loss.item())
        
        avg_val_loss = sum(val_losses) / len(val_losses)
        wandb.log({"val_loss": avg_val_loss, "epoch": epoch})

        # Check if this is the best modelif avg_val_loss < best_metric:
        best_metric = avg_val_loss
        torch.save(model.state_dict(), "best_metric_model.pth")
        print("Saved new best model with loss:", best_metric)

print("Training completed. Best validation loss:", best_metric)
wandb.finish()


[34m[1mwandb[0m: Currently logged in as: [33mfangruih[0m ([33mfangruih-Stanford University[0m). Use [1m`wandb login --relogin`[0m to force relogin


----------
epoch 1/100


Using a target size (torch.Size([2])) that is different to the input size (torch.Size([2, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


100/10525, train_loss: 1849.5593
200/10525, train_loss: 261.1571
300/10525, train_loss: 191.5856
400/10525, train_loss: 817.9634
500/10525, train_loss: 1177.3065
600/10525, train_loss: 194.0171
700/10525, train_loss: 518.8301
800/10525, train_loss: 1796.3740
900/10525, train_loss: 9.4253
1000/10525, train_loss: 35.8746
1100/10525, train_loss: 1414.2993
1200/10525, train_loss: 151.4894
1300/10525, train_loss: 1063.0049
1400/10525, train_loss: 115.1751


## Cleanup data directory

Remove directory if a temporary was used.

In [33]:
model = Regressor(in_shape=[1,148,180,148], out_shape=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2))
if torch.cuda.is_available():
    model.cuda()
# It is important that we use nn.MSELoss for regression.
loss_function = torch.nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), 1e-4)

# start a typical PyTorch training
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
writer = SummaryWriter()
max_epochs = 5

lowest_rmse = sys.float_info.max
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        if step %100==0:
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        all_labels = []
        all_val_outputs = []
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            all_labels.extend(val_labels.cpu().detach().numpy())
            with torch.no_grad():
                val_outputs = model(val_images)
                flattened_val_outputs = [val for sublist in val_outputs.cpu().detach().numpy() for val in sublist]
                all_val_outputs.extend(flattened_val_outputs)

        mse = np.square(np.subtract(all_labels, all_val_outputs)).mean()
        rmse = np.sqrt(mse)

        if rmse < lowest_rmse:
            lowest_rmse = rmse
            lowest_rmse_epoch = epoch + 1
            torch.save(model.state_dict(), "best_metric_model_classification3d_array.pth")
            print("saved new best metric model")

        print(f"Current epoch: {epoch+1} current RMSE: {rmse:.4f} ")
        print(f"Best RMSE: {lowest_rmse:.4f} at epoch {lowest_rmse_epoch}")
        writer.add_scalar("val_rmse", rmse, epoch + 1)

print(f"Training completed, lowest_rmse: {lowest_rmse:.4f} at epoch: {lowest_rmse_epoch}")
writer.close()

----------
epoch 1/5


Using a target size (torch.Size([2])) that is different to the input size (torch.Size([2, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


100/10525, train_loss: 1300.1602
200/10525, train_loss: 91.2725


KeyboardInterrupt: 

In [None]:
# if directory is None:
#     shutil.rmtree(root_dir)