In [1]:
import os
import shutil
import tempfile
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
    EnsureChannelFirstd,
    DeleteItemsd,  # Remove the metadata
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR

from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)


import torch

import multiprocessing

#multiprocessing.set_start_method('spawn', force=True)

print_config()

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

MONAI version: 1.2.0+95.ga4e4894d
Numpy version: 1.24.4
Pytorch version: 2.0.1+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a4e4894dca25f5e87b9306abfc472805f92b69da
MONAI __file__: /usr/local/lib/python3.8/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.3.0
Nibabel version: 5.1.0
scikit-image version: 0.21.0
scipy version: 1.10.1
Pillow version: 10.0.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 4.7.1
TorchVision version: 0.15.2+cu117
tqdm version: 4.66.1
lmdb version: 1.4.1
psutil version: 5.9.5
pandas version: 2.0.3
einops version: 0.6.1
transformers version: 4.32.1
mlflow version: 2.6.0
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recomm

In [18]:
num_samples = 1

import os
import torch

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"  # Use the 3rd and 4th GPU. Indexing starts from 0.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from monai.transforms import Transform
from monai.transforms import SpatialPadd

import numpy as np

import torch

def compute_bounding_box(data):
    """
    Compute the bounding box of the non-zero region in a multi-dimensional tensor.

    Args:
        data (torch.Tensor): the input data tensor.

    Returns:
        tuple: a tuple containing the starting and ending coordinates of the bounding box.
    """
    # Find the coordinates of all non-zero data points
    coords = torch.nonzero(data)

    # Find the minimum and maximum coordinates along each dimension
    start_coords = coords.min(dim=0).values
    end_coords = coords.max(dim=0).values + 1  # +1 to include the max coordinate

    return start_coords.tolist(), end_coords.tolist()


class RandCropInsideForeground(Transform):
    def __init__(self, roi_key, spatial_size):
        self.roi_key = roi_key
        self.spatial_size = spatial_size

    def __call__(self, data):
        label_data = data[self.roi_key[0]]
        print(label_data.shape)

        # Convert label_data to a CUDA tensor
        label_data_tensor = torch.tensor(label_data).cuda()

        # Ensure the spatial size is within the volume shape
        valid_spatial_size = torch.min(torch.tensor(self.spatial_size).cuda(), torch.tensor(label_data.shape[1:]).cuda())

        # Convert valid_spatial_size back to a numpy array
        valid_spatial_size = valid_spatial_size.cpu().numpy()

        # Generate foreground bounding box
        fg_start, fg_end = compute_bounding_box(label_data)

        # Randomize start and end indices for cropping inside the foreground
        start = []
        for s, e, sz in zip(fg_start[1:], fg_end[1:], valid_spatial_size):
            max_start = max(0, e - sz)
            # Use torch.randint for generating random numbers on GPU
            start_val = torch.randint(s, max_start + 1, (1,)).item()
            start.append(start_val)

        self.roi_start = tuple(start)
        self.roi_end = tuple([s + sz for s, sz in zip(self.roi_start, valid_spatial_size)])

        slices = [slice(s, e) for s, e in zip(self.roi_start, self.roi_end)]
        result = {}
        for key in data.keys():
            if data[key].ndim > 1:
                result[key] = data[key][..., slices[0], slices[1], slices[2]]
            else:
                result[key] = data[key]

        print(result['label'].shape)
        return result



from monai.transforms import Transform, SpatialPadd
import numpy as np

class DynamicAxisPad(Transform):
    """
    Dynamically pad the axis with the smallest size to the desired length.
    """
    def __init__(self, keys, desired_size: int = 96, mode: str = 'constant'):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            desired_size: the desired size after padding.
            mode: padding mode, can be one of ['edge', 'constant', 'reflect', 'replicate'].
        """
        self.keys = keys
        self.desired_size = desired_size
        self.mode = mode

    def __call__(self, data):
        for key in self.keys:
            img = data[key][0]

            # Determine the size of the image
            image_shape = img.shape

            # Find the axis with the smallest size
            min_axis = np.argmin(image_shape)

            # Define the padding size based on the minimum axis
            padding_size = list(image_shape)  # Assuming a 3D image
            padding_size[min_axis] = self.desired_size

            # Apply the padding
            padder = SpatialPadd(keys=[key], spatial_size=padding_size, mode=self.mode)
            data = padder(data)

        return data



train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=False),
        EnsureChannelFirstd(keys=["image", "label"]),
        DeleteItemsd(keys=["image_meta_dict", "label_meta_dict","foreground_start_coord", "foreground_end_coord"]),  # Remove the metadata
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True,),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        DynamicAxisPad(keys=['image'], desired_size=96, mode='constant'),
        RandCropInsideForeground(roi_key=["label"], spatial_size=(96, 96, 96)),

    ]
)


val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=False),
        EnsureChannelFirstd(keys=["image", "label"]),
        DeleteItemsd(keys=["image_meta_dict", "label_meta_dict","foreground_start_coord", "foreground_end_coord"]),  # Remove the metadata
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        DynamicAxisPad(keys=['image'], desired_size=96, mode='constant'),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)

In [19]:
import os
import pandas as pd
from tqdm import tqdm
from monai.data import Dataset
from torch.utils.data import DataLoader

current_directory = os.getcwd()


split_json = "output_1.json"

datasets = os.path.join(current_directory, split_json)
datalist = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")

train_ds = Dataset(data=datalist, transform=train_transforms)
train_loader = DataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)
val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, num_workers=0, batch_size=1)



In [20]:
set_track_meta(True)

In [21]:
sample_batch = next(iter(train_loader))
# print(type(sample_batch))
# print(sample_batch.keys()) if isinstance(sample_batch, dict) else print(sample_batch)


RuntimeError: applying transform <monai.transforms.compose.Compose object at 0x7f9c1c5792b0>

In [9]:
import torch.nn as nn

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=5,
    feature_size=48,
    use_checkpoint=True,
)

# device_model = torch.device("cuda:1")
# model = model.to(device_model)

if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

model.to('cuda')


Using 3 GPUs!


DataParallel(
  (module): SwinUNETR(
    (swinViT): SwinTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv3d(1, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (layers1): ModuleList(
        (0): BasicLayer(
          (blocks): ModuleList(
            (0-1): 2 x SwinTransformerBlock(
              (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
              (attn): WindowAttention(
                (qkv): Linear(in_features=48, out_features=144, bias=True)
                (attn_drop): Dropout(p=0.0, inplace=False)
                (proj): Linear(in_features=48, out_features=48, bias=True)
                (proj_drop): Dropout(p=0.0, inplace=False)
                (softmax): Softmax(dim=-1)
              )
              (drop_path): Identity()
              (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
              (mlp): MLPBlock(
                (linear1): Linear(in_features=48, out

In [10]:
weight = torch.load("./model_swinvit.pt")
#model.load_from(weights=weight)
model.module.load_from(weights=weight)
print("Using pretrained self-supervied Swin UNETR backbone weights !")

Using pretrained self-supervied Swin UNETR backbone weights !


In [11]:
torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scaler = torch.cuda.amp.GradScaler()


In [12]:
def validation(epoch_iterator_val):
    model.eval()
    with torch.no_grad():
        for batch in epoch_iterator_val:
            val_inputs, val_labels = (batch['image'].cuda(), batch['label'].cuda())

            # Place data on the first GPU
            # x = batch["image"].to("cuda:0")
            # y = batch["label"].to("cuda:0")
    
            # # Transfer data to second GPU for model processing
            # x = x.to("cuda:1")
            # y = y.to("cuda:1")
            
            with torch.cuda.amp.autocast():
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val


def train(global_step, train_loader, dice_val_best, global_step_best):
    model.train()
    epoch_loss = 0
    step = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        step += 1
        #print(type(batch))
        print(batch['image'].shape)
        x, y = (batch['image'].cuda(), batch['label'].cuda())

        # Place data on the first GPU
        # x = batch["image"].to("cuda:0")
        # y = batch["label"].to("cuda:0")

        # # Transfer data to second GPU for model processing
        # x = x.to("cuda:1")
        # y = y.to("cuda:1")
        
        with torch.cuda.amp.autocast():
            logit_map = model(x)
            loss = loss_function(logit_map, y)
        scaler.scale(loss).backward()
        epoch_loss += loss.item()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        epoch_iterator.set_description(f"Training ({global_step} / {max_iterations} Steps) (loss={loss:2.5f})")
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
    return global_step, dice_val_best, global_step_best

In [13]:
max_iterations = 10000
eval_num = 500
post_label = AsDiscrete(to_onehot=5)
post_pred = AsDiscrete(argmax=True, to_onehot=5)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
model.load_state_dict(torch.load(os.path.join(current_directory, "best_metric_model.pth")))

Training (X / X Steps) (loss=X.X):   0%|                          | 0/70 [00:00<?, ?it/s]

torch.Size([1, 128, 342, 256])
torch.Size([1, 96, 96, 96])
torch.Size([1, 1, 96, 96, 96])


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [421,0,0], thread: [27,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [421,0,0], thread: [28,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [421,0,0], thread: [29,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [421,0,0], thread: [30,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): bloc

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
