In [1]:
import os
import json
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import monai
from monai.networks.layers import Norm

# MONAI imports (optional but recommended)
from monai.transforms import (
    Activations, AsDiscrete, Compose, LoadImaged, EnsureChannelFirstd, Spacingd,
    Orientationd, ScaleIntensityRanged, CropForegroundd,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, RandShiftIntensityd,
    EnsureTyped, ToTensord, NormalizeIntensityd
)
from monai.networks.nets import UNet
from monai.losses import DiceCELoss, DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
# from monai.data import decollate_batch

In [2]:
def load_dataset_json(json_path):
    with open(json_path, 'r') as f:
        dataset = json.load(f)
    return dataset

data_folder_dataset = "/Users/chufal/Downloads/DHAI_capstone_project/data_brainMRI_Segmentation/dataset.json"
dataset_info = load_dataset_json(data_folder_dataset)
base_dir = "/Users/chufal/Downloads/DHAI_capstone_project/data_brainMRI_Segmentation/"

training_files = dataset_info["training"] # This is a list of {"image": "path", "label": "path"}
test_files_paths = dataset_info["test"] # This is a list of image paths

# For inference, create a list of dictionaries similar to training_files
test_files = [{"image": path} for path in test_files_paths]

# Split training_files into actual training and validation sets
train_list, val_list = train_test_split(training_files, test_size=0.2, random_state=42)

print(f"Number of training samples: {len(train_list)}")
print(f"Number of validation samples: {len(val_list)}")
print(f"Number of test samples: {len(test_files)}")

Number of training samples: 387
Number of validation samples: 97
Number of test samples: 266


In [3]:
class BratsDataset(Dataset):
    def __init__(self, data_list, base_dir=base_dir, transform=None):
        """
        Args:
            data_list (list of dicts): List of {"image": path, "label": path} or {"image": path} for test.
            base_dir (str): Base directory where image/label paths are relative to.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_list = data_list
        self.base_dir = base_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        item = self.data_list[idx]
        img_path = os.path.join(self.base_dir, item["image"])

        # Load image
        img_nib = nib.load(img_path)
        img_data = img_nib.get_fdata(dtype=np.float32) # (H, W, D, Modalities)

        # Permute to (Modalities, H, W, D) - adjust if your data loads differently
        img_data = np.moveaxis(img_data, -1, 0) # Now (Modalities, H, W, D)

        sample = {"image": img_data}

        if "label" in item:
            label_path = os.path.join(self.base_dir, item["label"])
            label_nib = nib.load(label_path)
            label_data = label_nib.get_fdata(dtype=np.uint8) # (H, W, D)
            # Add channel dimension for label, making it (1, H, W, D) for MONAI transforms
            label_data = np.expand_dims(label_data, axis=0)
            sample["label"] = label_data

        if self.transform:
            sample = self.transform(sample)

        return sample

In [4]:
# Define target patch size and other parameters
TARGET_SPACING = (1.0, 1.0, 1.0) # Example, adjust based on your data
PATCH_SIZE = (128, 128, 128)    # Example, adjust based on GPU memory
NUM_SAMPLES_PER_IMAGE = 4       # For RandCropByPosNegLabeld

# Define distinct keys for MONAI transforms
image_keys = ["image"]
label_keys = ["label"]
all_keys = ["image", "label"]

# Example: Placeholder values - you need to calculate these from your training data
# means = [mean_flair, mean_t1w, mean_t1gd, mean_t2w]
# stds = [std_flair, std_t1w, std_t1gd, std_t2w]

# Example calculated means and stds (replace with your actual dataset's values)
# These would typically be calculated on the intensity values *after* any initial clipping or
# foreground cropping if you want to normalize the relevant tissue intensities.
# For simplicity, let's assume you've calculated them.
means = [100.0, 150.0, 120.0, 130.0] # Replace with actual means for FLAIR, T1w, t1gd, T2w
stds = [50.0, 60.0, 55.0, 65.0]    # Replace with actual stds

# Then, in your transforms:
train_transforms = Compose([
    LoadImaged(keys=all_keys, reader="NibabelReader", image_only=False),
    EnsureChannelFirstd(keys=all_keys),
    # Replace ScaleIntensityRanged with NormalizeIntensityd
    NormalizeIntensityd(keys="image", subtrahend=means, divisor=stds, channel_wise=True),
    Spacingd(keys=all_keys, pixdim=TARGET_SPACING, mode=("bilinear", "nearest")),
    Orientationd(keys=all_keys, axcodes="RAS"),
    CropForegroundd(keys=all_keys, source_key="image", margin=10),
    RandCropByPosNegLabeld(
        keys=all_keys,
        label_key="label",
        spatial_size=PATCH_SIZE,
        pos=1, neg=1,
        num_samples=NUM_SAMPLES_PER_IMAGE,
        image_key="image",
        image_threshold=0,
    ),
    RandFlipd(keys=all_keys, prob=0.5, spatial_axis=0),
    RandFlipd(keys=all_keys, prob=0.5, spatial_axis=1),
    RandFlipd(keys=all_keys, prob=0.5, spatial_axis=2),
    RandRotate90d(keys=all_keys, prob=0.5, max_k=3, spatial_axes=(0, 1)),
    RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5), # This might be less effective after Z-score
    EnsureTyped(keys=all_keys, data_type="tensor"),
])

val_transforms = Compose([
    LoadImaged(keys=all_keys, reader="NibabelReader", image_only=False),
    EnsureChannelFirstd(keys=all_keys),
    NormalizeIntensityd(keys="image", subtrahend=means, divisor=stds, channel_wise=True),
    Spacingd(keys=all_keys, pixdim=TARGET_SPACING, mode=("bilinear", "nearest")),
    Orientationd(keys=all_keys, axcodes="RAS"),
    CropForegroundd(keys=all_keys, source_key="image", margin=10),
    EnsureTyped(keys=all_keys, data_type="tensor"),
])

# If using the custom BratsDataset with nibabel:
# from monai.data import Dataset as MonaiDataset # To use MONAI's Dataset with custom loader
# train_ds = MonaiDataset(data=train_list, transform=train_transforms) # MONAI transforms expect dicts
# val_ds = MonaiDataset(data=val_list, transform=val_transforms)

# If using MONAI's built-in way of handling file paths (recommended):
# Create dictionaries with absolute paths for MONAI's LoadImaged
def create_data_list_for_monai(file_list, base_dir=base_dir):
    monai_list = []
    for item in file_list:
        entry = {
            "image": os.path.join(base_dir, item["image"]),
        }
        if "label" in item: # For train/val
            entry["label"] = os.path.join(base_dir, item["label"])
        monai_list.append(entry)
    return monai_list

# Assuming 'base_dir' is the root directory where 'imagesTr', 'labelsTr', 'imagesTs' are.
# If your paths in dataset.json are like "./imagesTr/...", then base_dir should be where dataset.json is,
# or adjust paths accordingly. For this example, let's assume paths are relative to current dir.
monai_train_list = create_data_list_for_monai(train_list, base_dir=base_dir)
monai_val_list = create_data_list_for_monai(val_list, base_dir=base_dir)

train_ds = monai.data.Dataset(data=monai_train_list, transform=train_transforms)
val_ds = monai.data.Dataset(data=monai_val_list, transform=val_transforms)

# DataLoaders
# If RandCropByPosNegLabeld is used, it creates multiple samples from one image,
# so the effective batch size increases.
# If RandCropByPosNegLabeld returns num_samples, the DataLoader batch_size is effectively 1
# and then these samples are collated.
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)



In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BraTS has 4 input modalities and 4 output classes (0:bg, 1:edema, 2:non-enhancing, 3:enhancing)
in_channels = 4 # FLAIR, T1w, t1gd, T2w
out_channels = 4 # Number of segmentation classes

model = UNet(
    spatial_dims=3,
    in_channels=in_channels,
    out_channels=out_channels,
    channels=(16, 32, 64, 128, 256), # Channel sequence for encoder
    strides=(2, 2, 2, 2),             # Strides for downsampling
    num_res_units=2,                  # Number of residual units in each block
    norm=Norm.BATCH,                  # Normalization type
).to(device)

# print(model) # To see the model structure

In [6]:
# Loss function: DiceCELoss combines Dice and CrossEntropy
# It expects raw logits from the model (N, C, D, H, W)
# and integer labels (N, 1, D, H, W) or (N, D, H, W)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, include_background=True)
# If your labels are already (N, 1, D, H, W) and you don't want one-hot for CE part, adjust.
# For Dice, labels can be (N, 1, D, H, W).
# For CE in DiceCELoss, it expects class indices, so (N, 1, D, H, W) is fine and will be squeezed.

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# (Optional) Learning rate scheduler
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

# Quick Test (if validation works or not)
### Uncomment below cell for quick testing

In [7]:
# import torch
# from monai.metrics import DiceMetric
# from monai.inferers import sliding_window_inference
# from monai.transforms import Activations, AsDiscrete, Compose
# # from monai.data.utils import decollate_batch # Kept for reference

# # --- Assumed to be defined elsewhere ---
# # model: your neural network model
# # device: torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # train_loader: DataLoader for training data
# # val_loader: DataLoader for validation data
# # optimizer: your optimizer (e.g., torch.optim.AdamW)
# # loss_function: your loss criterion (e.g., DiceCELoss)
# # PATCH_SIZE: tuple, e.g., (128, 128, 128), for sliding_window_inference
# # out_channels: int, number of output classes for segmentation
# # NUM_SAMPLES_PER_IMAGE: int, (not directly used in loop but good for context)
# # -----------------------------------------

# # --- TEMPORARY CHANGES FOR QUICK TESTING ---
# num_epochs = 1     # Run only one epoch for testing
# val_interval = 1   # Validate after every epoch (so, after this 1st test epoch)
# max_train_steps_for_test = 5  # Process only a few batches in training
# max_val_steps_for_test = 3    # Process only a few batches in validation
# # --- END OF TEMPORARY CHANGES ---

# best_metric = -1.0 # Initialize with a float
# best_metric_epoch = -1

# # For validation metrics
# # Using "mean_batch" to get per-class scores first, then we will average them.
# dice_metric = DiceMetric(include_background=True, reduction="mean_batch")

# # Define post-processing for validation metrics
# post_pred_validation = Compose([Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=out_channels)])
# post_label_validation = Compose([AsDiscrete(to_onehot=out_channels)])

# for epoch in range(num_epochs): # This loop will run once for epoch = 0
#     print("-" * 10)
#     print(f"Epoch {epoch + 1}/{num_epochs} (QUICK TEST RUN)")
#     model.train()
#     epoch_loss_accumulator = 0.0
#     train_step_count = 0

#     for batch_data in train_loader:
#         train_step_count += 1
        
#         if isinstance(batch_data, list):
#             inputs = torch.cat([sample_dict["image"] for sample_dict in batch_data], dim=0).to(device)
#             labels = torch.cat([sample_dict["label"] for sample_dict in batch_data], dim=0).to(device)
#         elif isinstance(batch_data, dict):
#             inputs = batch_data["image"].to(device)
#             labels = batch_data["label"].to(device)
#         else:
#             raise TypeError(f"train_loader produced an unsupported batch_data type: {type(batch_data)}")

#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = loss_function(outputs, labels)
#         loss.backward()
#         optimizer.step()
#         epoch_loss_accumulator += loss.item()
#         # Updated print statement to reflect potential capping for test
#         print(f"{train_step_count}/{len(train_loader) if max_train_steps_for_test is None else max_train_steps_for_test}, Train_loss: {loss.item():.4f}")

#         # --- TEMPORARY: Break after a few training steps for quick test ---
#         if train_step_count >= max_train_steps_for_test:
#             print(f"DEBUG: Stopping training early after {max_train_steps_for_test} steps for quick test.")
#             break
#         # ---

#     if train_step_count > 0:
#         avg_epoch_loss = epoch_loss_accumulator / train_step_count
#         print(f"Epoch {epoch + 1} average training loss (quick test): {avg_epoch_loss:.4f}")
#     else:
#         print(f"Epoch {epoch + 1} training (quick test): No data processed.")

#     # Validation (will run because num_epochs=1 and val_interval=1)
#     if (epoch + 1) % val_interval == 0:
#         model.eval()
#         with torch.no_grad():
#             val_epoch_loss_accumulator = 0.0
#             val_step_count = 0
#             dice_metric.reset() # Reset metric at the start of each validation epoch's accumulation

#             for val_batch_data in val_loader:
#                 val_step_count += 1
                
#                 if isinstance(val_batch_data, list):
#                     val_inputs = torch.cat([sample_dict["image"] for sample_dict in val_batch_data], dim=0).to(device)
#                     val_labels = torch.cat([sample_dict["label"] for sample_dict in val_batch_data], dim=0).to(device)
#                 elif isinstance(val_batch_data, dict):
#                     val_inputs = val_batch_data["image"].to(device)
#                     val_labels = val_batch_data["label"].to(device)
#                 else:
#                     raise TypeError(f"val_loader produced an unsupported val_batch_data type: {type(val_batch_data)}")
        
#                 val_outputs = sliding_window_inference(val_inputs, PATCH_SIZE, 4, model, overlap=0.5)
        
#                 v_loss = loss_function(val_outputs, val_labels)
#                 val_epoch_loss_accumulator += v_loss.item()
        
#                 processed_val_outputs_for_metric = []
#                 processed_val_labels_for_metric = []
        
#                 for i in range(val_outputs.shape[0]): # Iterate over batch dimension
#                     output_item = val_outputs[i] 
#                     label_item = val_labels[i]   
        
#                     processed_output_item = post_pred_validation(output_item)
#                     processed_label_item = post_label_validation(label_item)
                    
#                     processed_val_outputs_for_metric.append(processed_output_item)
#                     processed_val_labels_for_metric.append(processed_label_item)
        
#                 dice_metric(y_pred=processed_val_outputs_for_metric, y=processed_val_labels_for_metric)

#                 # --- TEMPORARY: Break after a few validation steps for quick test ---
#                 if val_step_count >= max_val_steps_for_test:
#                     print(f"DEBUG: Stopping validation early after {max_val_steps_for_test} steps for quick test.")
#                     break
#                 # ---
            
#             if val_step_count > 0:
#                 per_class_dice_scores = dice_metric.aggregate() # This will be a tensor with 4 elements
                
#                 mean_dice = 0.0 # Default if no scores
#                 if per_class_dice_scores.numel() > 0:
#                     # Option 2: Calculate the mean of these per-class scores for a single metric
#                     mean_dice = torch.mean(per_class_dice_scores).item()
#                     # You can also print the per_class_dice_scores if you want to see them:
#                     print(f"DEBUG: Per-class Dice scores (quick test): {per_class_dice_scores.tolist()}")
#                 else:
#                     print("Warning: per_class_dice_scores tensor was empty during validation (quick test).")

#                 avg_val_epoch_loss = val_epoch_loss_accumulator / val_step_count
#                 print(f"Epoch {epoch + 1} Validation (quick test) avg loss: {avg_val_epoch_loss:.4f}, Mean Dice: {mean_dice:.4f}")

#                 # During a quick test, you might want to skip saving the model or save to a different file
#                 if mean_dice > best_metric: # This logic still applies for tracking improvement during test
#                     best_metric = mean_dice
#                     best_metric_epoch = epoch + 1
#                     # torch.save(model.state_dict(), "quick_test_best_model.pth") # Optional: save to a test file
#                     print("DEBUG: New best metric in quick test run.")
                
#                 print(f"DEBUG: Current best Dice in quick test: {best_metric:.4f} at epoch {best_metric_epoch}")
#             else:
#                 print(f"Epoch {epoch + 1} Validation (quick test): No data processed.")

# # For the quick test, this line might not reflect a full training
# # print(f"Training completed. Best Mean Dice: {best_metric:.4f} at epoch {best_metric_epoch}")
# print(f"Quick test run completed. Last Mean Dice: {mean_dice if 'mean_dice' in locals() else 'N/A'}")

# Final Training

In [None]:
# --- Assumed to be defined elsewhere ---
# model: your neural network model
# device: torch.device("cuda" if torch.cuda.is_available() else "cpu")
# train_loader: DataLoader for training data
# val_loader: DataLoader for validation data
# optimizer: your optimizer (e.g., torch.optim.AdamW)
# loss_function: your loss criterion (e.g., DiceCELoss)
# PATCH_SIZE: tuple, e.g., (128, 128, 128), for sliding_window_inference
# out_channels: int, number of output classes for segmentation
# NUM_SAMPLES_PER_IMAGE: int, if used with RandCropByPosNegLabeld (for logging/understanding batch sizes)
# -----------------------------------------

num_epochs = 100  # Adjust as needed
val_interval = 2  # Validate every 2 epochs
best_metric = -1.0 # Initialize with a float
best_metric_epoch = -1

# For validation metrics
dice_metric = DiceMetric(include_background=True, reduction="mean_batch")

# Define post-processing for validation metrics
post_pred_validation = Compose([Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=out_channels)])
post_label_validation = Compose([AsDiscrete(to_onehot=out_channels)])

for epoch in range(num_epochs):
    print("-" * 10)
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()
    epoch_loss_accumulator = 0.0
    train_step_count = 0

    for batch_data in train_loader:
        train_step_count += 1
        
        if isinstance(batch_data, list):
            inputs = torch.cat([sample_dict["image"] for sample_dict in batch_data], dim=0).to(device)
            labels = torch.cat([sample_dict["label"] for sample_dict in batch_data], dim=0).to(device)
        elif isinstance(batch_data, dict):
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
        else:
            raise TypeError(f"train_loader produced an unsupported batch_data type: {type(batch_data)}")

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss_accumulator += loss.item()
        print(f"{train_step_count}/{len(train_loader)}, Train_loss: {loss.item():.4f}")

    if train_step_count > 0:
        avg_epoch_loss = epoch_loss_accumulator / train_step_count
        print(f"Epoch {epoch + 1} average training loss: {avg_epoch_loss:.4f}")
    else:
        print(f"Epoch {epoch + 1} training: No data processed.")

    # Validation
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_epoch_loss_accumulator = 0.0  # Correctly initialized here
            val_step_count = 0                # Correctly initialized here
            dice_metric.reset() # Reset metric at the start of each validation epoch's accumulation

            for val_batch_data in val_loader:
                val_step_count += 1
                
                if isinstance(val_batch_data, list):
                    val_inputs = torch.cat([sample_dict["image"] for sample_dict in val_batch_data], dim=0).to(device)
                    val_labels = torch.cat([sample_dict["label"] for sample_dict in val_batch_data], dim=0).to(device)
                elif isinstance(val_batch_data, dict):
                    val_inputs = val_batch_data["image"].to(device)
                    val_labels = val_batch_data["label"].to(device)
                else:
                    raise TypeError(f"val_loader produced an unsupported val_batch_data type: {type(val_batch_data)}")
        
                # Ensure PATCH_SIZE is defined correctly for your model and data
                val_outputs = sliding_window_inference(val_inputs, PATCH_SIZE, 4, model, overlap=0.5)
        
                v_loss = loss_function(val_outputs, val_labels)
                val_epoch_loss_accumulator += v_loss.item()
        
                processed_val_outputs_for_metric = []
                processed_val_labels_for_metric = []
        
                for i in range(val_outputs.shape[0]): # Iterate over batch dimension
                    output_item = val_outputs[i] 
                    label_item = val_labels[i]   
        
                    processed_output_item = post_pred_validation(output_item)
                    processed_label_item = post_label_validation(label_item)
                    
                    processed_val_outputs_for_metric.append(processed_output_item)
                    processed_val_labels_for_metric.append(processed_label_item)
        
                dice_metric(y_pred=processed_val_outputs_for_metric, y=processed_val_labels_for_metric)
            
            # Aggregate metrics and loss after iterating through all validation batches
            if val_step_count > 0:
                per_class_dice_scores = dice_metric.aggregate() # This is your tensor with 4 elements
                
                # Calculate the mean of these per-class scores
                if per_class_dice_scores.numel() > 0 : # Check if tensor is not empty
                    # If you want the mean of all 4 class Dice scores:
                    mean_dice = torch.mean(per_class_dice_scores).item()
                    
                    # Alternatively, if you want to average only foreground classes (e.g., classes 1, 2, 3)
                    # Assuming class 0 is background and per_class_dice_scores has scores in order [bg, c1, c2, c3]
                    # if per_class_dice_scores.shape[0] == out_channels and out_channels > 1: # Ensure it has expected num classes
                    #     mean_dice = torch.mean(per_class_dice_scores[1:]).item() # Mean of elements from index 1 onwards
                    # else:
                    #     mean_dice = torch.mean(per_class_dice_scores).item() # Fallback or if only one class
                else:
                    mean_dice = 0.0 # Or some other default if no scores
                    print("Warning: per_class_dice_scores tensor was empty.")

                # dice_metric.reset() # Resetting here is also fine if aggregate() finalizes the epoch's metric.
                                      # Moved reset to the beginning of validation for clarity per epoch.

                avg_val_epoch_loss = val_epoch_loss_accumulator / val_step_count
                print(f"Epoch {epoch + 1} Validation avg loss: {avg_val_epoch_loss:.4f}, Mean Dice: {mean_dice:.4f}")

                if mean_dice > best_metric:
                    best_metric = mean_dice
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model.pth")
                    print("Saved new best metric model")
                
                print(f"Current best Dice: {best_metric:.4f} at epoch {best_metric_epoch}")
            else:
                print(f"Epoch {epoch + 1} Validation: No data processed.")

print(f"Training completed. Best Mean Dice: {best_metric:.4f} at epoch {best_metric_epoch}")

----------
Epoch 1/100
1/387, Train_loss: 2.6404
2/387, Train_loss: 2.6204
3/387, Train_loss: 2.6420
4/387, Train_loss: 2.6041
5/387, Train_loss: 2.5515
6/387, Train_loss: 2.5326
7/387, Train_loss: 2.5343
8/387, Train_loss: 2.5600
9/387, Train_loss: 2.5182
10/387, Train_loss: 2.4517
11/387, Train_loss: 2.4365
12/387, Train_loss: 2.5157
13/387, Train_loss: 2.4632
14/387, Train_loss: 2.3988
15/387, Train_loss: 2.4129
16/387, Train_loss: 2.3719
17/387, Train_loss: 2.4206
18/387, Train_loss: 2.3678
19/387, Train_loss: 2.4401
20/387, Train_loss: 2.3865
21/387, Train_loss: 2.3194
22/387, Train_loss: 2.4153
23/387, Train_loss: 2.3475
24/387, Train_loss: 2.3175
25/387, Train_loss: 2.3015
26/387, Train_loss: 2.3451
27/387, Train_loss: 2.3304
28/387, Train_loss: 2.2899
29/387, Train_loss: 2.2727
30/387, Train_loss: 2.2860
31/387, Train_loss: 2.2981
32/387, Train_loss: 2.2523
33/387, Train_loss: 2.2655
34/387, Train_loss: 2.2735
35/387, Train_loss: 2.2360
36/387, Train_loss: 2.3165
37/387, Train_