### Import Modules

In [None]:
import os
import nibabel as nib
import numpy as np
import glob
import torch
import time
import logging
import matplotlib.pyplot as plt
import sys
from sklearn.model_selection import train_test_split
from monai.data import create_test_image_3d, Dataset, DataLoader, decollate_batch
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityd,
    RandRotate90d,
    RandFlipd,
    Activations,
    AsDiscrete
)
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.utils import first, set_determinism
from monai.visualize.utils import blend_images
%env CUDA_VISIBLE_DEVICES=0

### Define Functions and Classes

In [None]:
def create_data(data_dir, sim_dim=(64, 64, 64), num_images=40):
    """Simulate 3D image/label pairs and write them to disk as NIfTI files.

    Images go to ``<data_dir>/Image`` and labels to ``<data_dir>/Label``,
    named ``000.nii.gz``, ``001.nii.gz``, ... Both folders are created if
    they do not exist.

    Args:
        data_dir: Root folder that receives the ``Image`` and ``Label``
            sub-folders.
        sim_dim: Three-element sequence with the size of each simulated volume.
        num_images: Number of image/label pairs to simulate.

    Returns:
        A list of ``{"image": path, "label": path}`` dicts, one per simulated
        pair, in index order.
    """
    image_dir = os.path.join(data_dir, "Image")
    label_dir = os.path.join(data_dir, "Label")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    # A single generator is shared by every iteration. Re-creating
    # RandomState(42) inside the loop would yield the exact same volume
    # num_images times, making train and validation data identical.
    rng = np.random.RandomState(42)
    affine = np.eye(4)

    # The returned list is built from the paths written here rather than from
    # a glob, so stale files from an earlier (larger) run are not picked up.
    data_dicts = []
    for i in range(num_images):
        img, lbl = create_test_image_3d(
            sim_dim[0],
            sim_dim[1],
            sim_dim[2],
            rad_min=3,
            rad_max=6,
            num_seg_classes=1,
            random_state=rng,
        )
        image_path = os.path.join(image_dir, f"{i:03d}.nii.gz")
        label_path = os.path.join(label_dir, f"{i:03d}.nii.gz")
        nib.save(nib.Nifti1Image(img, affine), image_path)
        nib.save(nib.Nifti1Image(lbl, affine), label_path)
        data_dicts.append({"image": image_path, "label": label_path})

    print(f"{num_images} images and labels with {sim_dim[0]} \u00D7 {sim_dim[1]} \u00D7 {sim_dim[2]} dimensions simulated")
    return data_dicts

def load_data(data_dicts, batch_size, test_size=10):
    """Split the data and wrap both parts in MONAI data loaders.

    Training samples are loaded, made channel-first, intensity-scaled and then
    randomly rotated by 90 degrees and flipped; validation samples receive
    only the deterministic steps.

    Args:
        data_dicts: List of ``{"image": path, "label": path}`` dicts.
        batch_size: Batch size used by both loaders.
        test_size: Passed to ``train_test_split`` as the validation share.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    pair_keys = ["image", "label"]

    def _deterministic_steps():
        # Fresh transform instances on every call so the two pipelines
        # never share objects.
        return [
            LoadImaged(keys=pair_keys, image_only=True),
            EnsureChannelFirstd(keys=pair_keys),
            ScaleIntensityd(keys="image"),
        ]

    augmentations = [
        RandRotate90d(keys=pair_keys, prob=0.8, spatial_axes=[0, 2]),
        RandFlipd(keys=pair_keys, spatial_axis=0, prob=0.5),
    ]
    train_transforms = Compose(_deterministic_steps() + augmentations)
    val_transforms = Compose(_deterministic_steps())

    train_files, val_files = train_test_split(
        data_dicts, test_size=test_size, random_state=42
    )

    pin = torch.cuda.is_available()
    train_loader = DataLoader(
        Dataset(data=train_files, transform=train_transforms),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=pin,
    )
    val_loader = DataLoader(
        Dataset(data=val_files, transform=val_transforms),
        batch_size=batch_size,
        num_workers=0,
        pin_memory=pin,
    )
    return train_loader, val_loader

def train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric, post_pred):
    """Run one training epoch and report the mean loss and training metric.

    Args:
        model: Network to train; switched to train mode here.
        device: Device the batches are moved to.
        train_loader: Iterable of batches with ``"image"`` and ``"label"``
            entries.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        scaler: ``torch.cuda.amp.GradScaler`` to enable mixed precision, or
            ``None`` to train in full precision.
        metric: Cumulative metric object; reset at the start of the epoch.
        post_pred: Callable turning one raw network output into a prediction
            suitable for ``metric``.

    Returns:
        Tuple ``(mean_loss_per_batch, aggregated_metric)``.
    """
    model.train()
    use_amp = scaler is not None
    epoch_loss = 0.0
    metric.reset()
    for batch_data in train_loader:
        images = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        # The forward pass must sit inside autocast as well; wrapping only the
        # loss leaves the whole network running in full precision.
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        epoch_loss += loss.item()
        # Metric bookkeeping needs no gradients. Cast back to float32 so that
        # voxel counts are not rounded when the outputs are half precision.
        with torch.no_grad():
            predictions = [
                post_pred(i) for i in decollate_batch(outputs.detach().float())
            ]
            metric(y_pred=predictions, y=labels)
    epoch_metric = metric.aggregate().item()
    return epoch_loss / len(train_loader), epoch_metric

def validate_one_epoch(model, device, val_loader, metric, post_pred):
    """Evaluate the model on the validation loader.

    The model is put in eval mode, the metric is reset, and every batch is
    scored without tracking gradients.

    Args:
        model: Network to evaluate.
        device: Device the batches are moved to.
        val_loader: Iterable of batches with ``"image"`` and ``"label"``
            entries.
        metric: Cumulative metric object collecting per-batch results.
        post_pred: Callable applied to each decollated network output.

    Returns:
        The aggregated metric as a Python float.
    """
    model.eval()
    metric.reset()
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["image"].to(device)
            targets = batch["label"].to(device)
            raw_outputs = model(inputs)
            predictions = list(map(post_pred, decollate_batch(raw_outputs)))
            metric(y_pred=predictions, y=targets)
    score = metric.aggregate()
    return score.item()

def train_model(model_dir, model, device, train_loader, val_loader, logger,
        criterion, metric, post_pred, max_epochs=100, learning_rate=1e-4, weight_decay=1e-5, val_interval=1):
    """Train the model, validate periodically and keep the best weights.

    Args:
        model_dir: Folder receiving ``BestMetricModel.pth``; created if
            missing.
        model: Network to train (already placed on ``device``).
        device: Device used for training and validation.
        train_loader: Loader yielding training batches.
        val_loader: Loader yielding validation batches.
        logger: Logger receiving per-epoch progress messages.
        criterion: Loss function.
        metric: Cumulative metric shared by training and validation.
        post_pred: Post-processing applied to network outputs before scoring.
        max_epochs: Number of epochs to run.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        val_interval: Validate every ``val_interval`` epochs.

    Returns:
        Tuple ``(model, epoch_loss_values, epoch_metric_values,
        metric_values)``, where ``model`` carries the weights of the best
        validation epoch (if any validation ran).
    """
    import copy  # local import: only needed to snapshot the best weights

    os.makedirs(model_dir, exist_ok=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Gradient scaling is only meaningful when training really runs on CUDA;
    # a CUDA-capable machine may still be asked to train on the CPU.
    on_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler() if on_cuda else None
    start_time = time.time()
    best_metric = -1
    best_metric_epoch = -1
    best_model_state = None
    epoch_loss_values, epoch_metric_values, metric_values = [], [], []
    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        # Reset each epoch so the log never reports a stale validation score
        # (and never hits an undefined name when val_interval > 1).
        val_metric = None
        epoch_loss, epoch_metric = train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric, post_pred)
        epoch_loss_values.append(epoch_loss)
        epoch_metric_values.append(epoch_metric)
        if (epoch + 1) % val_interval == 0:
            val_metric = validate_one_epoch(model, device, val_loader, metric, post_pred)
            metric_values.append(val_metric)
            if val_metric > best_metric:
                best_metric = val_metric
                best_metric_epoch = epoch + 1
                # state_dict() returns references to the live parameters, so
                # a deep copy is required to freeze the best weights.
                best_model_state = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), os.path.join(model_dir, "BestMetricModel.pth"))
                logger.info(f"Best DSC: {best_metric:.4f} at epoch {best_metric_epoch}")
        epoch_end_time = time.time()
        val_text = f"{val_metric:.4f}" if val_metric is not None else "n/a"
        logger.info(f"Epoch {epoch + 1} completed for {(epoch_end_time - epoch_start_time)/60:.2f} mins - Training loss: {epoch_loss:.4f}, Training DSC: {epoch_metric:.4f}, Validation DSC: {val_text}")
        sys.stdout.write(f"\rEpoch {epoch + 1}/{max_epochs} completed")
        sys.stdout.flush()
    end_time = time.time()
    total_time = end_time - start_time
    logger.info(f"Best DSC: {best_metric:.3f} at epoch {best_metric_epoch}; Total time consumed: {total_time/60:.2f} mins")
    print(f"Best DSC: {best_metric:.3f} at epoch {best_metric_epoch}; Total time consumed: {total_time/60:.2f} mins")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, epoch_loss_values, epoch_metric_values, metric_values

def plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval=1):
    """Plot the training curves and save them as ``Performance.png``.

    The left panel shows the training loss per epoch; the right panel shows
    training and validation DSC, with validation points placed at multiples
    of ``val_interval``.

    Args:
        model_dir: Folder the figure is written to.
        epoch_loss_values: Training loss, one value per epoch.
        epoch_metric_values: Training DSC, one value per epoch.
        metric_values: Validation DSC, one value per validation run.
        val_interval: Number of epochs between two validation runs.
    """
    loss_epochs = list(range(1, len(epoch_loss_values) + 1))
    dsc_epochs = list(range(1, len(epoch_metric_values) + 1))
    val_epochs = [val_interval * step for step in range(1, len(metric_values) + 1)]

    _, (loss_ax, dsc_ax) = plt.subplots(1, 2, figsize=(8, 5))

    # Left panel: training loss.
    loss_ax.plot(loss_epochs, epoch_loss_values, label='Training Loss', color='red')
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')

    # Right panel: training vs. validation DSC.
    dsc_ax.plot(dsc_epochs, epoch_metric_values, label='Training DSC', color='red')
    dsc_ax.plot(val_epochs, metric_values, label='Validation DSC', color='blue')
    dsc_ax.set_title('Training DSC vs. Validation DSC')
    dsc_ax.set_xlabel('Epoch')
    dsc_ax.set_ylabel('DSC')
    dsc_ax.legend()

    plt.tight_layout()
    figure_path = os.path.join(model_dir, "Performance.png")
    plt.savefig(figure_path, dpi=300)


### Prepare Inputs

In [None]:
# Locations: simulated data folder and prefix of the model output folder.
data_dir = "Demo"
model_dir_prefix = "Demo"

# Simulation / data split settings.
sim_dim = (64, 64, 64)
num_images, test_size, batch_size = 40, 10, 5

# Optimisation settings.
max_epochs, val_interval = 100, 1
learning_rate, weight_decay = 1e-4, 1e-5

# Output folder for checkpoints, figures and the log file.
model_dir = model_dir_prefix + "_UNet"
os.makedirs(model_dir, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print('Device:', device)

# Route messages from the root logger into a plain-text file in model_dir.
log_file = os.path.join(model_dir, "Prediction.log")
logging.basicConfig(level=logging.INFO, format="%(message)s", filename=log_file)
logger = logging.getLogger()

### Read Data

In [None]:
# Fix the random seeds before any data is simulated or loaded.
set_determinism(seed=0)
data_dicts = create_data(data_dir, sim_dim, num_images)
train_loader, val_loader = load_data(data_dicts, batch_size, test_size)

# Check data shape: print the shape of every entry in the first batch of each
# loader, followed by the number of batches that loader yields.
tr = first(train_loader)
print('\nData shape for training:')
for key in tr:
    print(f'\u2022 {key}: {list(tr[key].shape)} \u00D7 {len(train_loader)}')
vl = first(val_loader)
print('\nData shape for validation:')
for key in vl:
    print(f'\u2022 {key}: {list(vl[key].shape)} \u00D7 {len(val_loader)}')

# Visualize data
# Show the middle slice (along the last spatial axis) of the first training
# sample: the image, its label, and the label blended over the image.
_, axs = plt.subplots(1, 3, figsize=(12, 5))
image = tr["image"][0].detach().cpu()
label = tr["label"][0].detach().cpu()
middle_index = image.shape[3] // 2
image_slice = image[0, :, :, middle_index]
label_slice = label[0, :, :, middle_index]
blended = blend_images(image, label, alpha=0.5)
# blend_images returns an RGB volume (3 colour channels first). Keep all three
# channels and move them last so imshow renders the coloured overlay; taking
# only channel 0 would display just the red component in grayscale.
blended_slice = blended[:, :, :, middle_index].permute(1, 2, 0).clamp(0.0, 1.0)
axs[0].imshow(image_slice, cmap='gray')
axs[0].set_title("Image")
axs[0].axis('off')
axs[1].imshow(label_slice, cmap='gray')
axs[1].set_title("Label")
axs[1].axis('off')
axs[2].imshow(blended_slice)
axs[2].set_title("Label-overlaid Image")
axs[2].axis('off')
plt.tight_layout()
plt.show()

### Train Model

In [None]:
# Build the network on the selected device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = model.to(device)

# Loss, metric and the post-processing that turns logits into a binary mask.
criterion = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, sigmoid=True)
metric = DiceMetric(include_background=True, reduction="mean")
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# Train, then plot the recorded loss / DSC curves.
model, epoch_loss_values, epoch_metric_values, metric_values = train_model(
    model_dir=model_dir,
    model=model,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    logger=logger,
    criterion=criterion,
    metric=metric,
    post_pred=post_pred,
    max_epochs=max_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    val_interval=val_interval,
)
plot_metric_values(
    model_dir=model_dir,
    epoch_loss_values=epoch_loss_values,
    epoch_metric_values=epoch_metric_values,
    metric_values=metric_values,
    val_interval=val_interval,
)

# Visualize outcome
# Predict one validation sample, score it, and show one slice of the image,
# the reference label and the predicted label side by side.
sample_index = 5
slice_index = 16
model.eval()
metric.reset()
with torch.no_grad():
    # Fetch the sample once; indexing the dataset re-runs the transforms.
    sample = val_loader.dataset[sample_index]
    image = sample["image"].to(device)
    label = sample["label"].to(device)
    output = model(image.unsqueeze(0))
    output = post_pred(output.squeeze(0))
    # DiceMetric expects batch-first tensors. Without the batch axis the
    # channel axis is read as the batch and the first spatial axis as classes,
    # so the score becomes a mean of per-plane values, not the volume DSC.
    metric(y_pred=output.unsqueeze(0), y=label.unsqueeze(0))
metric_value = metric.aggregate().item()
_, axs = plt.subplots(1, 3, figsize=(12, 5))
axs[0].imshow(image[0,:, :, slice_index].detach().cpu(), cmap="gray")
axs[0].set_title("Image")
axs[0].axis('off')
axs[1].imshow(label[0,:, :, slice_index].detach().cpu(), cmap="gray")
axs[1].set_title("Label")
axs[1].axis('off')
axs[2].imshow(output[0,:, :, slice_index].detach().cpu(), cmap="gray")
axs[2].set_title(f"Predicted Label: DSC = {metric_value:.3f}")
axs[2].axis('off')
plt.tight_layout()
plt.show()