In [2]:
import os
import tempfile
from glob import glob
import shutil

import matplotlib.pyplot as plt
import monai
import nibabel as nib
import numpy as np
import torch
from monai.data import DataLoader, PatchDataset
from monai.inferers import SliceInferer
from monai.metrics import DiceMetric
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    RandRotate90d,
    Resized,
    ScaleIntensityd,
    SqueezeDimd,
)
from monai.visualize import matshow3d

from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
    EnsureChannelFirstd,
    DeleteItemsd,

)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR

from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)

monai.config.print_config()
monai.utils.set_determinism(0)

MONAI version: 1.2.0+95.ga4e4894d
Numpy version: 1.24.4
Pytorch version: 2.0.1+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a4e4894dca25f5e87b9306abfc472805f92b69da
MONAI __file__: /usr/local/lib/python3.8/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.3.0
Nibabel version: 5.1.0
scikit-image version: 0.21.0
scipy version: 1.10.1
Pillow version: 10.0.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 4.7.1
TorchVision version: 0.15.2+cu117
tqdm version: 4.66.1
lmdb version: 1.4.1
psutil version: 5.9.5
pandas version: 2.0.3
einops version: 0.6.1
transformers version: 4.32.1
mlflow version: 2.6.0
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recomm

In [3]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmpd7if9faj


In [4]:
num_samples = 1

import os
import torch

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"  # Use the 3rd and 4th GPU. Indexing starts from 0.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from monai.transforms import Transform
from monai.transforms import SpatialPadd

import numpy as np

import torch

def compute_bounding_box(data):
    """
    Compute the bounding box of the non-zero region in a multi-dimensional tensor.

    Args:
        data (torch.Tensor): the input data tensor.

    Returns:
        tuple: a tuple containing the starting and ending coordinates of the bounding box.
    """
    # Find the coordinates of all non-zero data points
    coords = torch.nonzero(data)

    # Find the minimum and maximum coordinates along each dimension
    start_coords = coords.min(dim=0).values
    end_coords = coords.max(dim=0).values + 1  # +1 to include the max coordinate

    return start_coords.tolist(), end_coords.tolist()


class RandCropInsideForeground(Transform):
    def __init__(self, roi_key, spatial_size):
        self.roi_key = roi_key
        self.spatial_size = spatial_size

    def __call__(self, data):
        label_data = data[self.roi_key[0]]
        print(label_data.shape)

        # Convert label_data to a CUDA tensor
        label_data_tensor = torch.tensor(label_data).cuda()

        # Ensure the spatial size is within the volume shape
        valid_spatial_size = torch.min(torch.tensor(self.spatial_size).cuda(), torch.tensor(label_data.shape[1:]).cuda())

        # Convert valid_spatial_size back to a numpy array
        valid_spatial_size = valid_spatial_size.cpu().numpy()

        # Generate foreground bounding box
        fg_start, fg_end = compute_bounding_box(label_data)

        # Randomize start and end indices for cropping inside the foreground
        start = []
        for s, e, sz in zip(fg_start[1:], fg_end[1:], valid_spatial_size):
            max_start = max(0, e - sz)
            # Use torch.randint for generating random numbers on GPU
            start_val = torch.randint(s, max_start + 1, (1,)).item()
            start.append(start_val)

        self.roi_start = tuple(start)
        self.roi_end = tuple([s + sz for s, sz in zip(self.roi_start, valid_spatial_size)])

        slices = [slice(s, e) for s, e in zip(self.roi_start, self.roi_end)]
        result = {}
        for key in data.keys():
            if data[key].ndim > 1:
                result[key] = data[key][..., slices[0], slices[1], slices[2]]
            else:
                result[key] = data[key]

        print(result['label'].shape)
        return result



from monai.transforms import Transform, SpatialPadd
import numpy as np

class DynamicAxisPad(Transform):
    """
    Dynamically pad the axis with the smallest size to the desired length.
    """
    def __init__(self, keys, desired_size: int = 96, mode: str = 'constant'):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            desired_size: the desired size after padding.
            mode: padding mode, can be one of ['edge', 'constant', 'reflect', 'replicate'].
        """
        self.keys = keys
        self.desired_size = desired_size
        self.mode = mode

    def __call__(self, data):
        for key in self.keys:
            img = data[key][0]

            # Determine the size of the image
            image_shape = img.shape

            # Find the axis with the smallest size
            min_axis = np.argmin(image_shape)

            # Define the padding size based on the minimum axis
            padding_size = list(image_shape)  # Assuming a 3D image
            padding_size[min_axis] = self.desired_size

            # Apply the padding
            padder = SpatialPadd(keys=[key], spatial_size=padding_size, mode=self.mode)
            data = padder(data)

        return data



train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=False),
        EnsureChannelFirstd(keys=["image", "label"]),
        DeleteItemsd(keys=["image_meta_dict", "label_meta_dict","foreground_start_coord", "foreground_end_coord"]),  # Remove the metadata
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True,),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        # DynamicAxisPad(keys=['image'], desired_size=96, mode='constant'),
        # RandCropInsideForeground(roi_key=["label"], spatial_size=(96, 96, 96)),

    ]
)


# val_transforms = Compose(
#     [
#         LoadImaged(keys=["image", "label"], ensure_channel_first=False),
#         EnsureChannelFirstd(keys=["image", "label"]),
#         DeleteItemsd(keys=["image_meta_dict", "label_meta_dict","foreground_start_coord", "foreground_end_coord"]),  # Remove the metadata
#         ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
#         DynamicAxisPad(keys=['image'], desired_size=96, mode='constant'),
#         CropForegroundd(keys=["image", "label"], source_key="image"),
#         Orientationd(keys=["image", "label"], axcodes="RAS"),
#         Spacingd(
#             keys=["image", "label"],
#             pixdim=(1.5, 1.5, 2.0),
#             mode=("bilinear", "nearest"),
#         ),
#         EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
#     ]
# )

monai.transforms.io.dictionary LoadImaged.__init__:image_only: Current default value of argument `image_only=False` has been deprecated since version 1.1. It will be changed to `image_only=True` in version 1.3.
monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.3.


In [5]:
import os
import pandas as pd
from tqdm import tqdm
from monai.data import Dataset
from torch.utils.data import DataLoader

current_directory = os.getcwd()


split_json = "output_1.json"

datasets = os.path.join(current_directory, split_json)
train_files = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")

# train_ds = Dataset(data=train_files, transform=train_transforms)
# train_loader = DataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)

# val_ds = Dataset(data=val_files, transform=val_transforms)
# val_loader = DataLoader(val_ds, num_workers=0, batch_size=1)



In [5]:
# volume-level transforms for both image and segmentation
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        DynamicAxisPad(keys=['image'], desired_size=96, mode='constant'),
        ScaleIntensityd(keys="image"),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True,),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"]),
    ]
)
# 3D dataset with preprocessing transforms
volume_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_num=8,cache_rate=1,num_workers=8)
# use batch_size=1 to check the volumes because the input volumes have different shapes
check_loader = DataLoader(volume_ds, batch_size=1)
check_data = monai.utils.misc.first(check_loader)
print("first volume's shape: ", check_data["image"].shape, check_data["label"].shape)

Loading dataset: 100%|█████████████████████████████████████| 8/8 [00:15<00:00,  1.98s/it]

first volume's shape:  torch.Size([1, 1, 117, 243, 256]) torch.Size([1, 1, 117, 243, 256])





In [6]:
num_samples = 4
patch_func = monai.transforms.RandSpatialCropSamplesd(
    keys=["image", "label"],
    roi_size=[1, -1, -1],  # dynamic spatial_size for the first two dimensions
    num_samples=num_samples,
    random_size=False,
)

patch_transform = Compose(
    [
        SqueezeDimd(keys=["image", "label"], dim=-1),  # squeeze the last dim
        Resized(keys=["image", "label"], spatial_size=[96, 96]),
        # to use crop/pad instead of resize:
        # ResizeWithPadOrCropd(keys=["img", "seg"], spatial_size=[48, 48], mode="replicate"),
    ]
)
patch_ds = PatchDataset(
    volume_ds,
    transform=patch_transform,
    patch_func=patch_func,
    samples_per_image=num_samples,
)
train_loader = DataLoader(
    patch_ds,
    batch_size=3,
    shuffle=True,  # this shuffles slices from different volumes
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
check_data = monai.utils.misc.first(train_loader)
print("first patch's shape: ", check_data["image"].shape, check_data["label"].shape)

first patch's shape:  torch.Size([3, 1, 48, 48]) torch.Size([3, 1, 48, 48])


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=6,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
).to(device)

#loss_function = monai.losses.DiceLoss(sigmoid=True)
loss_function = monai.losses.DiceCELoss(to_onehot_y=True, softmax=True)
#optimizer = torch.optim.Adam(model.parameters(), 5e-3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scaler = torch.cuda.amp.GradScaler()


In [None]:
epoch_loss_values = []
num_epochs = 20
for epoch in range(num_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{num_epochs}")
    model.train()
    epoch_loss, step = 0, 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)

        # with torch.cuda.amp.autocast():
        #     outputs = model(inputs)
        #     loss = loss_function(outputs, labels)
        # scaler.scale(loss).backward()
        # epoch_loss += loss.item()
        # scaler.unscale_(optimizer)
        # scaler.step(optimizer)
        # scaler.update()
        # optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = loss_function(outputs, labels)

        # Backward pass and optimization
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        
        # optimizer.zero_grad()
        # outputs = model(inputs)
        # loss = loss_function(outputs, labels)
        # loss.backward()
        # optimizer.step()
        # epoch_loss += loss.item()

        
        epoch_len = len(patch_ds) // train_loader.batch_size
        if step % 25 == 0:
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
print("train completed")

----------
epoch 1/20
