In [14]:
import os
import glob
import torch
import monai
from monai.transforms import (
    LoadImaged,
    EnsureChannelFirstD,
    ScaleIntensityD,
    CropForegroundD,
    Activations,
    AsDiscrete,
    ToTensorD,
    Compose,
)
from monai.networks.nets import UNet, SegResNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.data import Dataset, DataLoader
from monai.utils import set_determinism

# Set deterministic training for reproducibility
set_determinism(seed=42)

In [2]:
# Define paths
train_files = sorted(
    [f for f in glob.glob("./Data/train/**/**.nii.gz")]
)
val_files = sorted(
    [f for f in glob.glob("./Data/val/**/**.nii.gz")]
)
test_files = sorted(
    [f for f in glob.glob("./Data/test/**/**.nii.gz")]
)

train_images, train_labels = [f for f in train_files if "seg" not in f], [f for f in train_files if "seg" in f]
val_images, val_labels = [f for f in val_files if "seg" not in f], [f for f in val_files if "seg" in f]
test_images, test_labels = [f for f in test_files if "seg" not in f], [f for f in test_files if "seg" in f]

train_data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
val_data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(val_images, val_labels)]
test_data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(test_images, test_labels)]

In [3]:
# Define transforms
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),  # Load images and labels
        EnsureChannelFirstD(keys=["image", "label"]),  # Ensure channel dimension is first
        ScaleIntensityD(keys=["image"]),  # Scale intensities for the image
        # CropForegroundD(keys=["image", "label"], source_key="image"),  # Crop foreground based on the image
        ToTensorD(keys=["image", "label"]),  # Convert to tensors
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),  # Load images and labels
        EnsureChannelFirstD(keys=["image", "label"]),  # Ensure channel dimension is first
        ScaleIntensityD(keys=["image"]),  # Scale intensities for the image
        # CropForegroundD(keys=["image", "label"], source_key="image"),  # Crop foreground based on the image
        ToTensorD(keys=["image", "label"]),  # Convert to tensors
    ]
)

In [4]:
# Prepare datasets and dataloaders
train_dataset = Dataset(
    data=train_data_dicts,
    transform=train_transforms,
)
val_dataset = Dataset(
    data=val_data_dicts,
    transform=val_transforms,
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

In [5]:
for batch_data in train_loader:
    inputs, labels = batch_data["image"], batch_data["label"]
    print(inputs.shape, labels.shape)
    break

torch.Size([4, 1, 256, 128, 256]) torch.Size([4, 1, 256, 128, 256])


In [16]:
# Define the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the model
# model = UNet(
#     spatial_dims=3,  # Use 3 for 3D images, 2 for 2D images
#     in_channels=1,   # Number of input channels (e.g., grayscale images = 1)
#     out_channels=4,  # Number of output channels (e.g., 2 for binary segmentation)
#     channels=(16, 32, 64, 128, 256),  # Number of features at each layer
#     strides=(2, 2, 2, 2),  # Strides for down-sampling
#     num_res_units=2,  # Number of residual units
# ).to(device)

model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=1,
    out_channels=1,
    dropout_prob=0.2,
).to(device)

# Loss function and optimizer
# loss_function = DiceLoss(to_onehot_y=True, softmax=True)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

# Define metrics
dice_metric = DiceMetric(include_background=True, reduction="mean")

In [17]:
outputs.shape
labels.shape

torch.Size([4, 1, 256, 128, 256])

In [19]:
# Training loop
max_epochs = 50
val_interval = 2

for epoch in range(max_epochs):
    print(f"Epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0

    for batch_data in train_loader:
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss /= len(train_loader)
    print(f"Average Training Loss: {epoch_loss:.4f}")

    # Validation loop
    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_dice = []
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)
                val_outputs = model(val_inputs)
                dice = dice_metric(val_outputs, val_labels)
                mean_dice = dice.mean().item()  # Compute mean Dice across classes
                val_dice.append(mean_dice)
        print(f"Validation Dice: {torch.mean(torch.tensor(val_dice)):.4f}")

# Testing loop
# model.eval()
# test_outputs = []
# with torch.no_grad():
#     for test_data in test_loader:
#         test_inputs = test_data["image"].to(device)
#         test_outputs.append(model(test_inputs))

# print("Testing complete!")

Epoch 1/50


Average Training Loss: 0.5833
Epoch 2/50
Average Training Loss: 0.5702
Validation Dice: 0.0000
Epoch 3/50
Average Training Loss: 0.5563
Epoch 4/50
Average Training Loss: 0.5520
Validation Dice: 0.0000
Epoch 5/50
Average Training Loss: 0.5458
Epoch 6/50
Average Training Loss: 0.5372
Validation Dice: 0.0000
Epoch 7/50
Average Training Loss: 0.5260
Epoch 8/50
Average Training Loss: 0.5154
Validation Dice: 0.0000
Epoch 9/50
Average Training Loss: 0.5079
Epoch 10/50
Average Training Loss: 0.5054
Validation Dice: 0.0000
Epoch 11/50
Average Training Loss: 0.5019
Epoch 12/50
Average Training Loss: 0.5000
Validation Dice: 0.0000
Epoch 13/50


KeyboardInterrupt: 