In [1]:
import glob
import logging
import os
from pathlib import Path
import shutil
import sys
import tempfile
import nibabel as nb
import pandas as pd
from tqdm import tqdm
import numpy as np

from monai.config import print_config
from monai.data import ArrayDataset, DataLoader, CacheDataset, ThreadDataLoader, decollate_batch
from monai.handlers import (
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
)
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    RandSpatialCrop,
    RandShiftIntensity,
    RandFlip,
    RandRotate90,
    RandCropByPosNegLabel,
    LoadImage,
    ScaleIntensityRange,
    CropForeground,
    Orientation,
    Spacing,
    EnsureType
)
from monai.utils import first

import ignite
from ignite.handlers import EarlyStopping
from ignite.engine import Events
from ignite.metrics import Loss
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


# All paths are resolved relative to the directory the notebook is run from.
root_dir = os.getcwd()
data_dir = os.path.join(root_dir, "data")

# The pre-cropped volumes and their segmentation masks sit side by side in
# ``cropped_data``. Both listings are sorted so that entry i of one list is
# paired with entry i of the other (assumes every volume has a segmentation
# file carrying the same numeric suffix).
_cropped_dir = os.path.join(root_dir, "cropped_data")

cropped_images = glob.glob(os.path.join(_cropped_dir, "volume-*.nii"))
cropped_images.sort()

cropped_segs = glob.glob(os.path.join(_cropped_dir, "segmentation-*.nii"))
cropped_segs.sort()

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [2]:
# def find_range():
#     amax = -1000
#     amin = 1000
#     for idx in range(len(cropped_images)):
#         img = nb.load(cropped_images[idx])
#         img = torch.from_numpy(img.get_fdata())
#         min_value = torch.min(img)
#         max_value = torch.max(img)
#         if min_value < amin:
#             amin = min_value
#         if max_value > amax:
#             amax = max_value
#     return amax.item(), amin.item()

# amax, amin = find_range()
# print(amax, amin)

In [None]:
# from skimage import measure

# # function used to find the minimum heights for the tumors
# def find_heights():
#     results = {}
#     for idx in range(len(segs)):
#         seg = nb.load(segs[idx])
#         seg = torch.from_numpy(seg.get_fdata())
#         labeled_mask_3d = measure.label(seg, connectivity=3)
#         regions_3d = measure.regionprops(labeled_mask_3d)

#         lowest_z = 1000
#         highest_z = 0
#         for region in regions_3d:
#             low_z = region.bbox[2]
#             high_z = region.bbox[5]
#             if lowest_z > low_z:
#                 lowest_z = low_z
#             if highest_z < high_z:
#                 highest_z = high_z
#         results[idx] = (lowest_z, highest_z)
#     return results

# results = find_heights()
# print(results)

In [None]:
# def CropZ(idx):
#     #for idx in range(len(segs)):
#     img = nb.load(images[idx])
#     img_data = img.get_fdata()
#     seg = nb.load(segs[idx])
#     seg_data = seg.get_fdata()
#     low, high = results[idx]
#     if (high - low) < 96:   
#         mid = high - low
#         high = mid + 48 # 48 is half the size of the randspatialcrop
#         low = mid - 48
#         # condition if low goes out of bounds
#         if low < 0:
#             high = high - low
#             low = 0
#         img_data = img_data[:, :,low:high]
#         seg_data = seg_data[:, :,low:high]
#         cropped_img = nb.Nifti1Image(img_data, img.affine)
#         cropped_seg = nb.Nifti1Image(seg_data, seg.affine)        
#         nb.save(cropped_img, os.path.join(root_dir, "cropped_data", f"volume-{idx}.nii"))
#         nb.save(cropped_seg, os.path.join(root_dir, "cropped_data", f"segmentation-{idx}.nii"))  
#         print(f"Img Shape: {img_data.shape} | Seg Shape: {seg_data.shape}, | Low: {low} | High: {high}")
#     else:
#         img_data = img_data[:, :,low:high]
#         seg_data = seg_data[:, :,low:high]
#         cropped_img = nb.Nifti1Image(img_data, img.affine)
#         cropped_seg = nb.Nifti1Image(seg_data, seg.affine)
#         nb.save(cropped_img, os.path.join(root_dir, "cropped_data", f"volume-{idx}.nii"))
#         nb.save(cropped_seg, os.path.join(root_dir, "cropped_data", f"segmentation-{idx}.nii"))
#         print(f"Img Shape: {img_data.shape} | Seg Shape: {seg_data.shape}")

In [2]:
# Intensity window that ScaleIntensityRange maps (with clipping) onto [0, 1].
# NOTE(review): the commented-out find_range() helper reported the full data
# range as amax, amin = 3071.0, -3024.0. The narrower window below does not
# come from that helper, so it appears to be hand-picked -- confirm the values.
amin = -22.18
amax = 450.0

# Edge length (in voxels, after resampling) of the cubic training/validation patch.
_PATCH_SIZE = (64, 64, 64)
# Target voxel spacing passed to Spacing.
_TARGET_SPACING = (1.0, 1.0, 1.0)


def _intensity_window():
    """Return a fresh transform that clips to [amin, amax] and rescales to [0, 1]."""
    return ScaleIntensityRange(
        a_min=amin,
        a_max=amax,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    )


def _random_patch():
    """Return a fresh fixed-size random crop of ``_PATCH_SIZE`` voxels."""
    return RandSpatialCrop(_PATCH_SIZE, random_size=False)


def _spatial_augmentations():
    """Return fresh random flip / 90-degree rotation transforms.

    A new list of new instances is built on every call: the image and the
    segmentation pipeline must each own their transforms so that ArrayDataset
    can seed both pipelines identically. The two pipelines must also list
    their random transforms in the same order for the draws to line up.
    """
    return [
        RandFlip(spatial_axis=[0], prob=0.10),
        RandFlip(spatial_axis=[1], prob=0.10),
        RandFlip(spatial_axis=[2], prob=0.10),
        RandRotate90(prob=0.10, max_k=3),
    ]


# --- Training pipelines -----------------------------------------------------
# Image: load -> intensity window -> reorient -> resample -> random patch ->
# spatial augmentation -> intensity jitter.
train_image_transforms = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        _intensity_window(),
        Orientation(axcodes="RAS"),
        Spacing(pixdim=_TARGET_SPACING),
        _random_patch(),
        *_spatial_augmentations(),
        # Intensity jitter is the LAST random transform so that the random
        # transforms before it stay aligned with the segmentation pipeline.
        RandShiftIntensity(offsets=0.10, prob=0.50),
    ]
)

# Segmentation: same geometry as the image, but
#   * resampled with nearest-neighbour interpolation so label values are never
#     blended into fractional values, and
#   * WITHOUT RandShiftIntensity, which would add an offset to the label
#     values themselves and corrupt the training target.
train_seg_transforms = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        Orientation(axcodes="RAS"),
        Spacing(pixdim=_TARGET_SPACING, mode="nearest"),
        _random_patch(),
        *_spatial_augmentations(),
    ]
)

# --- Validation / test pipelines ---------------------------------------------
# NOTE(review): these still take a *random* 64^3 patch, so validation and test
# metrics are computed on a different region every epoch (possibly one without
# any foreground). Consider a deterministic crop or sliding-window inference.
val_image_transforms = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        _intensity_window(),
        Orientation(axcodes="RAS"),
        Spacing(pixdim=_TARGET_SPACING),
        _random_patch(),
    ]
)

val_seg_transforms = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        Orientation(axcodes="RAS"),
        # Nearest-neighbour keeps the mask's original label values.
        Spacing(pixdim=_TARGET_SPACING, mode="nearest"),
        _random_patch(),
    ]
)

In [3]:
def clear_folder_contents(folder_path):
    """Empty ``folder_path`` without removing the folder itself.

    Regular files and symbolic links directly inside the folder are unlinked;
    sub-directories are removed together with everything below them. A short
    status message is printed in either case.

    Args:
        folder_path: Path of the folder whose contents should be deleted.

    Returns:
        None. If the path does not exist, nothing is deleted and only a
        message is printed.
    """
    if os.path.exists(folder_path):
        for entry_name in os.listdir(folder_path):
            target = os.path.join(folder_path, entry_name)

            # Links are tested explicitly so that a link pointing at a
            # directory is unlinked rather than followed and recursively wiped.
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
                continue

            if os.path.isdir(target):
                shutil.rmtree(target)

        print(f"Contents of the folder '{folder_path}' have been deleted.")
    else:
        print(f"The folder {folder_path} does not exist.")

In [25]:
# Start every run from an empty "logs" folder so that checkpoints and
# TensorBoard event files from earlier runs do not mix with the new ones.
clear_folder_contents(folder_path="logs")

Contents of the folder 'logs' have been deleted.


In [26]:
# train_dict = []
# val_dict = []
# test_dict = []

# for idx in range(len(cropped_images)):
#     if idx < 20:
#         temp = {"image": cropped_images[idx], "seg": cropped_segs[idx]}
#         train_dict.append(temp)
#     elif idx < 24:
#         temp = {"image": cropped_images[idx], "seg": cropped_segs[idx]}
#         val_dict.append(temp)
#     else:
#         temp = {"image": cropped_images[idx], "seg": cropped_segs[idx]}
#         test_dict.append(temp)        

In [27]:
# Create UNet, DiceLoss and Adam optimizer
net = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    #channels=(32, 64, 128, 256, 512),
    #channels=(64,128,256,512,1024),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

loss = DiceLoss()
lr = 1e-3
opt = torch.optim.Adam(net.parameters(), lr)
scheduler = ReduceLROnPlateau(opt, mode='min', patience=6, factor=0.1)

# create a training data loader
train, seg_train = cropped_images[:20], cropped_segs[:20]
val, seg_val = cropped_images[20:24], cropped_segs[20:24]
test, seg_test = cropped_images[24:], cropped_segs[24:]

train_ds = ArrayDataset(img=train, img_transform=val_image_transforms, seg=seg_train, seg_transform=val_seg_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=5,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

val_ds = ArrayDataset(img=val, img_transform=val_image_transforms, seg=seg_val, seg_transform=val_seg_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

test_ds = ArrayDataset(img=test, img_transform=val_image_transforms, seg=seg_test, seg_transform=val_seg_transforms)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=2, pin_memory=torch.cuda.is_available())

In [28]:
# Create trainer
trainer = ignite.engine.create_supervised_trainer(net, opt, loss, device, False)

# optional section for checkpoint and tensorboard logging
# adding checkpoint handler to save models (network
# params and optimizer stats) during training
log_dir = os.path.join(root_dir, "logs")
checkpoint_handler = ignite.handlers.ModelCheckpoint(log_dir, "net", n_saved=10, require_empty=False)
trainer.add_event_handler(
    event_name=Events.EPOCH_COMPLETED,
    handler=checkpoint_handler,
    to_save={"net": net, "opt": opt},
)

# StatsHandler prints loss at every iteration
# user can also customize print functions and can use output_transform to convert
# engine.state.output if it's not a loss value
train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration
train_tensorboard_stats_handler = TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: x)
train_tensorboard_stats_handler.attach(trainer)

# optional section for model validation during training
validation_every_n_epochs = 1

# Set parameters for validation
metric_name = "Mean_Dice"
metric_loss = "Loss"

# add evaluation metric to the evaluator engine
val_metrics = {
    metric_name: MeanDice(),
    metric_loss: Loss(loss)
    }
post_pred = Compose([Activations(), AsDiscrete(threshold=0.5)])
post_label = Compose([AsDiscrete(threshold=0.5)])

# Ignite evaluator expects batch=(img, seg) and
# returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = ignite.engine.create_supervised_evaluator(
    net,
    val_metrics,
    device,
    True,
    output_transform=lambda x, y, y_pred: (
        [post_pred(i) for i in decollate_batch(y_pred)],
        [post_label(i) for i in decollate_batch(y)],
    ),
)

@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)

# Add stats event handler to print validation stats via evaluator
val_stats_handler = StatsHandler(
    name="evaluator",
    # no need to print loss value, so disable per iteration output
    output_transform=lambda x: None,
    # fetch global epoch number from trainer
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_stats_handler.attach(evaluator)

# add handler to record metrics to TensorBoard at every validation epoch
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    log_dir=log_dir,
    # no need to plot loss value, so disable per iteration output
    output_transform=lambda x: None,
    # fetch global epoch number from trainer
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_tensorboard_stats_handler.attach(evaluator)

# add handler to draw the first image and the corresponding
# label and model output in the last batch
# here we draw the 3D output as GIF format along Depth
# axis, at every validation epoch
val_tensorboard_image_handler = TensorBoardImageHandler(
    log_dir=log_dir,
    batch_transform=lambda batch: (batch[0], batch[1]),
    output_transform=lambda output: output[0],
    global_iter_transform=lambda x: trainer.state.epoch,
)
evaluator.add_event_handler(
    event_name=Events.EPOCH_COMPLETED,
    handler=val_tensorboard_image_handler,
)

# adding a scheduler to update learning rate
@trainer.on(Events.EPOCH_COMPLETED)
def update_sheduler(engine):
    val_loss = evaluator.state.metrics[metric_loss]
    scheduler.step(val_loss)


def score_function(engine):
    # Return the metric you want to monitor 
    return -engine.state.metrics[metric_loss]  

# adding earlystop handler
early_stopping = EarlyStopping(
    patience=10,  # Number of epochs with no improvement
    score_function=score_function,
    trainer=trainer  # Trainer to stop when condition is met
)

# Attach early stopping to the evaluator (usually validation evaluator)
evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stopping)

# Define a test evaluator
test_evaluator = ignite.engine.create_supervised_evaluator(
    net,
    metrics={
        metric_name: MeanDice(),
        metric_loss: Loss(loss)
    },
    device=device,
)

@trainer.on(Events.COMPLETED)
def evaluate_test_set(engine):
    test_evaluator.run(test_loader)
    metrics = test_evaluator.state.metrics
    print(f"Test Mean Dice: {metrics[metric_name]:.4f} | Test Loss: {metrics[metric_loss]:.4f}")

save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists

@trainer.on(Events.COMPLETED)
def save_model(engine):
    model_path = os.path.join(save_dir, "model.pth")  # Create full path
    torch.save(net.state_dict(), model_path)  # Save model weights
    print(f"Model saved to {model_path}")

In [29]:
# Upper bound on training length; the EarlyStopping handler attached to the
# evaluator may terminate the run before this many epochs have elapsed.
max_epochs = 32

# `state` is the object returned by trainer.run.
state = trainer.run(train_loader, max_epochs=max_epochs)

2024-11-28 16:13:07,424 - INFO - Epoch: 1/32, Iter: 1/4 -- Loss: 0.8883 
2024-11-28 16:13:17,364 - INFO - Epoch: 1/32, Iter: 2/4 -- Loss: 0.9974 
2024-11-28 16:13:36,301 - INFO - Epoch: 1/32, Iter: 3/4 -- Loss: 0.8953 
2024-11-28 16:13:50,436 - INFO - Epoch: 1/32, Iter: 4/4 -- Loss: 0.8943 
2024-11-28 16:14:01,313 - INFO - Epoch[1] Metrics -- Loss: 0.8541 Mean_Dice: 0.2917 
2024-11-28 16:14:38,153 - INFO - Epoch: 2/32, Iter: 1/4 -- Loss: 0.9902 
2024-11-28 16:14:38,182 - INFO - Epoch: 2/32, Iter: 2/4 -- Loss: 0.8244 
2024-11-28 16:15:05,015 - INFO - Epoch: 2/32, Iter: 3/4 -- Loss: 0.8343 
2024-11-28 16:15:05,044 - INFO - Epoch: 2/32, Iter: 4/4 -- Loss: 0.9706 
2024-11-28 16:15:15,617 - INFO - Epoch[2] Metrics -- Loss: 1.0000 Mean_Dice: 0.0000 
2024-11-28 16:15:56,307 - INFO - Epoch: 3/32, Iter: 1/4 -- Loss: 0.7483 
2024-11-28 16:15:56,337 - INFO - Epoch: 3/32, Iter: 2/4 -- Loss: 1.0000 
2024-11-28 16:16:20,251 - INFO - Epoch: 3/32, Iter: 3/4 -- Loss: 0.9337 
2024-11-28 16:16:20,278 - I

2024-11-28 16:32:05,417 ignite.handlers.early_stopping.EarlyStopping INFO: EarlyStopping: Stop training


Test Mean Dice: 0.0000 | Test Loss: 0.7679
Model saved to saved_models/model.pth


In [31]:
# Load the TensorBoard notebook extension.
# NOTE(review): this cell only loads the extension; no `%tensorboard --logdir ...`
# call is visible here, so the dashboard is presumably launched elsewhere.
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
