In [1]:
%matplotlib inline

In [2]:
import os
import shutil
import tempfile

from tqdm import tqdm

import numpy as np
import nibabel as nib
import concurrent.futures

from monai.data.utils import pad_list_data_collate

import monai.losses as losses
import matplotlib.pyplot as plt
from monai.networks.layers import Norm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import Transform
from monai.transforms import (
    AsDiscrete,
    Compose,
    SpatialPad,
    SpatialPadd,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    NormalizeIntensity,
    NormalizeIntensityd,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
    RandGaussianNoised,
)


from monai.config import print_config
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.nets import UNet

from monai.data import (
    DataLoader,
    ThreadDataLoader,
    SmartCacheDataset,
    PersistentDataset,
    Dataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)


import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, CyclicLR

from boundary_loss import BDLoss, DC_and_BD_loss

print_config()

MONAI version: 1.1.0
Numpy version: 1.21.5
Pytorch version: 1.13.1+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3
MONAI __file__: /home/emmanuel/PycharmProjects/pythonProject/venv/lib/python3.7/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
Nibabel version: 4.0.2
scikit-image version: 0.19.3
Pillow version: 9.5.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 4.7.1
TorchVision version: 0.14.1+cu117
tqdm version: 4.65.0
lmdb version: 1.4.1
psutil version: 5.9.5
pandas version: 1.3.5
einops version: 0.6.1
transformers version: 4.28.1
mlflow version: 1.30.1
pynrrd version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [3]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = 'xxModel'
model_name =  os.path.join(root_dir, "xxxxbest_model_dice_ce.pth")
print(model_name)

xxModel/xxxxbest_model_dice_ce.pth


In [4]:
class CombineBinaryMaps(Transform):
    def __init__(self, num_classes, keys):
        super().__init__()
        self.num_classes = num_classes

    def __call__(self, data):
        binary_maps = data["label"]
        combined_map = torch.zeros_like(data['image'][0])
#         print(1, data['image'][0].shape, binary_maps[0].shape)

        for i in range(self.num_classes):
            zero_indices = np.where(combined_map == 0)
#             print(zero_indices, binary_maps[i].shape)
            combined_map[zero_indices] += (i + 1) * binary_maps[i][zero_indices]
#             combined_map %= (self.num_classes + 1)

        data["label"] = combined_map.long().unsqueeze(0)
        return data

In [71]:
num_samples = 4

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ToDo: Try windowing with W:350 and L:40
# define the window width and level
window_width = 350
window_level = 40

# calculate the intensity range to clip
intensity_min = window_level - window_width / 2.0
intensity_max = window_level + window_width / 2.0

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True, image_only=False),
        CombineBinaryMaps(keys=["label"], num_classes=10),
#         RandGaussianNoised(keys=["image"], prob=0.50, mean=0.0, std=0.1),

        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
#         CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        SpatialPadd(
            spatial_size=(96, 96, 96),
            keys=["image", "label"]
        ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
       
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
            allow_smaller=True
        ),

        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),

    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True, image_only=False),
        CombineBinaryMaps(keys=["label"], num_classes=10),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        SpatialPadd(
            spatial_size=(96, 96, 96),
            keys=["image", "label"]
        ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)

# Load Dataset

In [5]:
def generate_file_path(root_path):
    all_files = os.listdir(root_path)
    return [{'image': f'{root_path}/{i}/ct.nii.gz',
             'label': [f'{root_path}/{i}/segments/{f}' for f in os.listdir(f'{root_path}/{i}/segments')  if not f.startswith("._")],
#              'label': f'{root_path}/{i}/labels.nii.gz'

            } 
            for i in all_files if os.path.isfile(f'{root_path}/{i}/ct.nii.gz')]


In [6]:
root = 'Dataset'
# root = 'C:/Training/Dataset'
file_list_train = generate_file_path(root_path=f'{root}/train')
file_list_val = generate_file_path(root_path=f'{root}/val')

In [74]:
# [len(x['label']) for x in file_list_val]

In [8]:
[x['label'] for x in file_list_train]

[['Dataset/train/s0452/segments/all_vertebrae_parts.nii.gz',
  'Dataset/train/s0452/segments/all_rib_parts.nii.gz',
  'Dataset/train/s0452/segments/pulmonary_artery.nii.gz',
  'Dataset/train/s0452/segments/lung_lower_lobe_left.nii.gz',
  'Dataset/train/s0452/segments/trachea.nii.gz',
  'Dataset/train/s0452/segments/lung_lower_lobe_right.nii.gz',
  'Dataset/train/s0452/segments/all_scapula_parts.nii.gz',
  'Dataset/train/s0452/segments/lung_upper_lobe_left.nii.gz',
  'Dataset/train/s0452/segments/lung_upper_lobe_right.nii.gz',
  'Dataset/train/s0452/segments/lung_middle_lobe_right.nii.gz'],
 ['Dataset/train/s0476/segments/all_vertebrae_parts.nii.gz',
  'Dataset/train/s0476/segments/all_rib_parts.nii.gz',
  'Dataset/train/s0476/segments/pulmonary_artery.nii.gz',
  'Dataset/train/s0476/segments/lung_lower_lobe_left.nii.gz',
  'Dataset/train/s0476/segments/trachea.nii.gz',
  'Dataset/train/s0476/segments/lung_lower_lobe_right.nii.gz',
  'Dataset/train/s0476/segments/all_scapula_parts.nii.g

In [75]:
train_ds = PersistentDataset(
    data=file_list_train,
    transform=train_transforms,
#     transform=scale,
    cache_dir='train_unet'
)

# train_ds = Dataset(
#     data=file_list_train,
#     transform=train_transforms,
# #     transform=scale,
# #     cache_dir='train_unet'
# )

train_loader = DataLoader(train_ds, batch_size=1, shuffle=True,
#                           num_workers=2, pin_memory=False,
#                                 collate_fn=collate_fn
#                                 collate_fn=lambda x: pad_list_data_collate(x, pad_to_shape=(96, 96, 96))
                               )

In [76]:
val_ds = PersistentDataset(
    data=file_list_val,
    transform=val_transforms,
    cache_dir='val_unet'
#     cache_dir='C:/Training/val'
)

val_loader = DataLoader(val_ds, 
#                         num_workers=0,
                        batch_size=1, pin_memory=False,
#                               collate_fn=lambda x: pad_list_data_collate(x, pad_to_shape=(96, 96, 96))
                       )

# Model

In [77]:
class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0):

        self.tolerance = tolerance
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        else:
            if (self.best_loss - loss) > self.min_delta:
                self.best_loss = loss
                self.counter = 0
            else:
                self.counter += 1
                print(f'Early Stopping patience: {self.counter}, best loss: {self.best_loss}, current_loss: {loss}')
                if self.counter >= self.tolerance:  
                    self.early_stop = True

In [78]:
if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    print(f"Number of available GPUs: {device_count}")
    for i in range(device_count):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("No GPU found, using CPU instead.")

Number of available GPUs: 1
GPU 0: GeForce RTX 2080 Ti


In [79]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = UNet(
#               spatial_dims=3,
#               in_channels=1,
#               out_channels=11,
#               channels=(64, 128, 256),
#               strides=(2, 2),
#               num_res_units=2,
#               norm=Norm.BATCH
#                 ).to(device)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=11,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

In [80]:
torch.backends.cudnn.benchmark = True
early_stopping = EarlyStopping(tolerance=10, min_delta=0.001)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, lambda_dice=1, lambda_ce=1)
bd_loss =  BDLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, patience=10, verbose=True, mode='min', min_lr=1e-8, factor=0.9)
scaler = torch.cuda.amp.GradScaler()

In [81]:
transform = SpatialPad(
            (1, 96, 96, 96),
        )
def validation(epoch_iterator_val):
    model.eval()
    validation_loss = []
    with torch.no_grad():
        for batch in epoch_iterator_val:
            val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
            with torch.cuda.amp.autocast():
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), num_samples, model)
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
            loss = loss_function(val_outputs, val_labels)
            validation_loss.append(loss.item())
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description(f"Validate (loss={loss.item():2.5f})")
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
        
        mean_hausdorff = 0

        validation_loss_mean = np.nanmean(np.nan_to_num(np.array(validation_loss),
                                               nan=np.nan, posinf=np.nan, neginf=np.nan))

        val_loss_values.append(validation_loss_mean)
    return mean_dice_val, mean_hausdorff, validation_loss_mean


def train(global_step, train_loader, dice_val_best, global_step_best):
    model.train()
    epoch_loss = 0
    step = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        step += 1
#         print(step)
        x, y = (batch["image"].cuda(), batch["label"].cuda())
#         a = all(t.shape == x.shape for t in y)
        if x.shape != y.shape:
            print(x.shape, y.shape)
#         with torch.cuda.amp.autocast():
#             logit_map = model(x)
#             loss = loss_function(logit_map, y)
#         scaler.scale(loss).backward()
#         epoch_loss += loss.item()
#         if torch.isnan(loss):
#             print(torch.max(x), torch.min(x))
            
#         scaler.unscale_(optimizer)

#         scaler.step(optimizer)

#         scaler.update()
#         optimizer.zero_grad()

#         epoch_iterator.set_description(f"Training ({global_step} / {max_iterations} Steps) (loss={loss.item():2.5f})")
#         global_step += 1
       

    epoch_loss /= step
    return global_step, dice_val_best, global_step_best, epoch_loss

In [82]:
max_iterations = 90000
post_label = AsDiscrete(to_onehot=11)
post_pred = AsDiscrete(argmax=True, to_onehot=11)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
hausdorff_metric = HausdorffDistanceMetric(include_background=True, get_not_nans=False, reduction="none",
                                           distance_metric="euclidean")
global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
val_loss_values = []
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best, epoch_loss = train(global_step, train_loader, dice_val_best, global_step_best)
    epoch_iterator_val = tqdm(val_loader, desc="Validate", dynamic_ncols=True)
    dice_val, hausdorff_val, loss_val = validation(epoch_iterator_val)
    scheduler.step(epoch_loss)
    epoch_loss_values.append(epoch_loss)
    metric_values.append(dice_val)
    early_stopping(loss_val)
    print(f'Mean Hausdorff disatnce: {hausdorff_val}')
    if dice_val > dice_val_best:
        dice_val_best = dice_val
        global_step_best = global_step
        torch.save(model.state_dict(), model_name)
        print(
            "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
        )
    else:
        print(
            "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                dice_val_best, dice_val
            )
        )
    # early stopping
    if early_stopping.early_stop:
        break
model.load_state_dict(torch.load(model_name))


Training (X / X Steps) (loss=X.X):   1%|▊                                                                                                                                                  | 2/357 [00:00<01:03,  5.61it/s]

torch.Size([4, 1, 96, 96, 96]) torch.Size([4, 10, 96, 96, 96])


Training (X / X Steps) (loss=X.X):   1%|██                                                                                                                                                 | 5/357 [00:01<02:36,  2.25it/s]

torch.Size([4, 1, 96, 96, 96]) torch.Size([4, 10, 96, 96, 96])
1 torch.Size([236, 236, 428]) torch.Size([236, 236, 428])


Training (X / X Steps) (loss=X.X):   2%|██▉                                                                                                                                                | 7/357 [00:08<09:38,  1.65s/it]

torch.Size([4, 1, 96, 96, 96]) torch.Size([4, 10, 96, 96, 96])


Training (X / X Steps) (loss=X.X):   2%|███▎                                                                                                                                               | 8/357 [00:08<07:13,  1.24s/it]

torch.Size([4, 1, 96, 96, 96]) torch.Size([4, 10, 96, 96, 96])
1 torch.Size([285, 285, 427]) torch.Size([285, 285, 427])


Training (X / X Steps) (loss=X.X):   3%|███▋                                                                                                                                               | 9/357 [00:17<19:48,  3.41s/it]

1 torch.Size([264, 264, 419]) torch.Size([264, 264, 419])


Training (X / X Steps) (loss=X.X):   3%|████                                                                                                                                              | 10/357 [00:24<26:22,  4.56s/it]

1 torch.Size([51, 153, 393]) torch.Size([51, 153, 393])


Training (X / X Steps) (loss=X.X):   3%|████▉                                                                                                                                             | 12/357 [00:25<14:27,  2.52s/it]

torch.Size([4, 1, 96, 96, 96]) torch.Size([4, 10, 96, 96, 96])


Training (X / X Steps) (loss=X.X):   4%|█████▎                                                                                                                                            | 13/357 [00:26<11:28,  2.00s/it]

torch.Size([4, 1, 96, 96, 96]) torch.Size([4, 10, 96, 96, 96])
1 torch.Size([145, 145, 147]) torch.Size([145, 145, 147])


Training (X / X Steps) (loss=X.X):   4%|█████▋                                                                                                                                            | 14/357 [00:27<09:13,  1.61s/it]

1 torch.Size([212, 212, 205]) torch.Size([212, 212, 205])


Training (X / X Steps) (loss=X.X):   4%|██████▏                                                                                                                                           | 15/357 [00:29<10:13,  1.79s/it]

1 torch.Size([281, 281, 409]) torch.Size([281, 281, 409])


Training (X / X Steps) (loss=X.X):   5%|██████▉                                                                                                                                           | 17/357 [00:38<16:21,  2.89s/it]

torch.Size([4, 1, 96, 96, 96]) torch.Size([4, 10, 96, 96, 96])
1 torch.Size([195, 195, 270]) torch.Size([195, 195, 270])


Training (X / X Steps) (loss=X.X):   5%|███████▎                                                                                                                                          | 18/357 [00:42<13:20,  2.36s/it]


RuntimeError: applying transform <monai.transforms.croppad.dictionary.RandCropByPosNegLabeld object at 0x7f09cf799ad0>

In [None]:
file_list_train

In [None]:
print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}")

In [None]:
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Iteration Average Loss")
x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("Iteration")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [eval_num * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("Iteration")
plt.plot(x, y)
plt.show()

In [None]:
model.load_state_dict(torch.load(model_name))
model.eval()

In [None]:
case_num = 4
with torch.no_grad():
    img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    val_inputs = torch.unsqueeze(img, 1).to(device)
    val_labels = torch.unsqueeze(label, 1).to(device)
    val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), num_samples, model, overlap=0.8)
    plt.figure("check", (18, 6))
    plt.subplot(1, 3, 1)
    plt.title("image")
    plt.imshow(val_inputs.cpu().numpy()[0, 0, :, :, 200], cmap="gray")
    plt.subplot(1, 3, 2)
    plt.title("label")
    plt.imshow(val_labels.cpu().numpy()[0, 0, :, :, 200])
    plt.subplot(1, 3, 3)
    plt.title("output")
    plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, 200])
    plt.show()

In [None]:
file_list_test = generate_file_path(root_path=f'{root}/test')
test_ds = PersistentDataset(
    data=file_list_test,
    transform=val_transforms,
    cache_dir='test'
#     cache_dir='C:/Training/val'
)

test_loader = ThreadDataLoader(test_ds, num_workers=0, batch_size=1, 
                              collate_fn=lambda x: pad_list_data_collate(x, pad_to_shape=(96, 96, 96)))

In [None]:
def test(epoch_iterator_test):
    model.eval()
    test_loss = []
    with torch.no_grad():
        for batch in epoch_iterator_test:
            test_inputs, test_labels = (batch["image"].to(device), batch["label"].to(device))
            with torch.cuda.amp.autocast():
#                 test_outputs = sliding_window_inference(test_inputs, (96, 96, 96), num_samples, model, overlap=0.8)
                test_outputs = sliding_window_inference(test_inputs, (96, 96, 96), num_samples, model)
            test_labels_list = decollate_batch(test_labels)
            test_labels_convert = [post_label(test_label_tensor) for test_label_tensor in test_labels_list]
            test_outputs_list = decollate_batch(test_outputs)
            test_output_convert = [post_pred(test_pred_tensor) for test_pred_tensor in test_outputs_list]
            loss = loss_function(test_outputs, test_labels)
            test_loss.append(loss.item())
            dice_metric(y_pred=test_output_convert, y=test_labels_convert)
            epoch_iterator_test.set_description(f"Test (loss={loss.item():2.5f})")
        mean_dice_test = dice_metric.aggregate().item()
        dice_metric.reset()

        test_loss_mean = np.nanmean(np.nan_to_num(np.array(test_loss),
                                               nan=np.nan, posinf=np.nan, neginf=np.nan))
    return mean_dice_test, test_loss_mean

In [None]:
epoch_iterator_test = tqdm(test_loader, desc="Test", dynamic_ncols=True)
dice_test, loss_test = test(epoch_iterator_test)

In [None]:
print(f'Test: (loss={loss_test}) (dice={dice_test})')