In [2]:
import os
from monai.apps import download_and_extract


* 'schema_extra' has been renamed to 'json_schema_extra'


In [3]:
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task06_Lung.tar"

root_dir = "."
compressed_file = os.path.join(root_dir, "Task06_Lung.tar")
data_dir = os.path.join(root_dir, "Task06_Lung")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir)

Task06_Lung.tar:   0%|          | 41.1M/8.53G [00:08<29:56, 5.08MB/s]    


KeyboardInterrupt: 

In [4]:
import nibabel as nib
import numpy as np
import torch
import monai
from monai.data import DataLoader, CacheDataset, ArrayDataset, load_decathlon_datalist, decollate_batch
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, 
    ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld, RandFlipd, 
    RandRotate90d, RandShiftIntensityd, EnsureTyped, Resized, AsDiscrete
)

# Path to the dataset
# data_dir = "../Task06_Lung"
split_json = os.path.join(data_dir, "dataset.json")

# Loading the dataset split (JSON file with train, validation split)
train_files = load_decathlon_datalist(split_json, data_list_key="training")
test_files = load_decathlon_datalist(split_json, data_list_key="test")
np.random.shuffle(train_files)
split_point = int(0.8 * len(train_files))
train_files, val_files = train_files[:split_point], train_files[split_point:]

# Define transforms for training and validation
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    ScaleIntensityRanged(keys=["image"], a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label", spatial_size=(128, 128, 128), pos=1, neg=1, num_samples=4, image_key="image", image_threshold=0
    ),
    RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.5),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
    EnsureTyped(keys=["image", "label"]),
])

val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    ScaleIntensityRanged(keys=["image"], a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    EnsureTyped(keys=["image", "label"]),
    Resized(keys=["image", "label"], spatial_size=(128, 128, 128))
])

# Create DataLoader for training and validation

train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
# train_ds = ArrayDataset(img=train_files, img_transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
# val_ds = ArrayDataset(img=val_files, img_transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1)

print(f"train: {len(train_ds)}, val: {len(val_ds)}")


Loading dataset: 100%|██████████| 50/50 [04:33<00:00,  5.47s/it]
Loading dataset: 100%|██████████| 13/13 [01:09<00:00,  5.34s/it]

train: 50, val: 13





In [10]:
from monai.networks.layers import Norm
from monai.networks.nets import UNet, SwinUNETR
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from torch.optim import Adam
from torchsummary import summary

# Define the model
# model = UNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=2,  # Number of segmentation classes
#     channels=(16, 32, 64, 128, 256),
#     strides=(2, 2, 2, 2),
#     norm=Norm.BATCH,
#     num_res_units=2,
#     dropout=0.1
# ).cuda()

model = SwinUNETR(
    img_size=(128, 128, 128),
    spatial_dims=3,
    in_channels=1,
    out_channels=2,  # Number of segmentation classes
    feature_size=24
).cuda()

# Define loss function and optimizer
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), lr=1e-4)

# Dice Metric for validation
dice_metric = DiceMetric(reduction="mean")

summary(model, (1, 128, 128, 128), batch_size=2)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.63 GiB. GPU 0 has a total capacty of 8.00 GiB of which 0 bytes is free. Of the allocated memory 20.02 GiB is allocated by PyTorch, and 523.17 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [9]:
from tqdm import tqdm
import torch.nn.functional as F

# Training and validation loop
max_epochs = 150
val_interval = 2
best_metric = -1
best_metric_epoch = -1
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

for epoch in range(max_epochs):
    print(f"Epoch {epoch + 1}/{max_epochs}")

    # Training
    model.train()
    losses = []
    for batch_data in train_loader:
        inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    print(f"Epoch {epoch + 1}, Loss avg: {np.average(losses):.4f} | std: {np.std(losses):.4f}")

    # Validation
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            dice_metric.reset()
            for val_data in val_loader:
                val_inputs, val_labels = val_data["image"].cuda(), val_data["label"].cuda()
                val_outputs = sliding_window_inference(val_inputs, (128, 128, 128), 4, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            dice_metric.reset()

            print(f"Validation Dice: {metric:.4f}")
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model.pth")
                print("Saved new best metric model")

print(f"Training complete. Best validation dice: {best_metric:.4f} at epoch {best_metric_epoch}")


Epoch 1/150


OutOfMemoryError: CUDA out of memory. Tried to allocate 10.52 GiB. GPU 0 has a total capacty of 8.00 GiB of which 0 bytes is free. Of the allocated memory 13.73 GiB is allocated by PyTorch, and 1.19 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Load the best model
model.load_state_dict(torch.load("best_metric_model.pth"))
model.eval()

# Example inference
with torch.no_grad():
    for i, test_data in enumerate(val_loader):
        test_inputs = test_data["image"].cuda()
        roi_size = (128, 128, 128)
        sw_batch_size = 4
        test_outputs = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
        test_outputs = torch.argmax(test_outputs, dim=1)

        print(test_outputs)
        # Save or visualize test_outputs as needed
