# Debug the data reader and transforms for DynUNet

In [None]:
import os
import sys
import logging
import yaml
import datetime
import argparse
from pathlib import Path

import ignite
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn.functional import interpolate

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from monai.config import print_config
from monai.data import DataLoader, PersistentDataset, Dataset
from monai.utils import misc
from monai.engines import SupervisedTrainer
from monai.losses import DiceLoss
from monai.networks.nets import DynUNet
from monai.transforms import (
    Compose,
    LoadNiftid,
    AddChanneld,
    CropForegroundd,
    Spacingd,
    Orientationd,
    SpatialPadd,
    NormalizeIntensityd,
    RandSpatialCropd,
    RandCropByPosNegLabeld,
    RandZoomd,
    CastToTyped,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandScaleIntensityd,
    RandRotated,
    RandFlipd,
    SqueezeDimd,
    ToTensord,
    Activationsd,
)
from monai.utils import set_determinism

sys.path.append("/mnt/data/mranzini/Desktop/GIFT-Surg/FBS_Monai/basic_unet_monai/src/")
from io_utils import create_data_list
from custom_transform import ConverToOneHotd

# Send log records of level INFO and above to stdout, then print the MONAI
# configuration information.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
print_config()

# Select GPU index 1 as the current CUDA device for this process.
cuda_device=1
torch.cuda.set_device(cuda_device)
# Seed MONAI's determinism helper with a fixed value (0).
set_determinism(seed=0)

## Create training and validation lists

In [None]:
# Folders that are searched for the image/label files.
data_root = [
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupA",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupB1",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupB2",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupC",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupD",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupE",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupF",
]

# Text files listing the subject IDs of each split: one file for training,
# two files (h and p) for validation.
list_root = "/mnt/data/mranzini/Desktop/GIFT-Surg/Retraining_with_expanded_dataset/config/file_names"
training_list = os.path.join(list_root, "list_train_files.txt")
validation_list = [
    os.path.join(list_root, list_name)
    for list_name in ("list_validation_h_files.txt", "list_validation_p_files.txt")
]

# Arguments that are identical for the training and the validation lookup.
shared_list_kwargs = {
    "data_folder_list": data_root,
    "img_postfix": "_Image",
    "label_postfix": "_Label",
}

# Training list: print its size followed by its first and last entries.
train_files = create_data_list(subject_list=training_list, **shared_list_kwargs)
print(len(train_files))
for position in (0, -1):
    print(train_files[position])

# Validation list: same report as for the training list.
val_files = create_data_list(subject_list=validation_list, **shared_list_kwargs)
print(len(val_files))
for position in (0, -1):
    print(val_files[position])

## Define a custom transform to change spacing in-plane only

In [None]:
import copy
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union

from monai.transforms import MapTransform, Spacing, Spacingd
from monai.config import KeysCollection
from monai.utils import (
    GridSampleMode,
    GridSamplePadMode,
    InterpolateMode,
    ensure_tuple,
    ensure_tuple_rep,
    fall_back_tuple,
)
# Type aliases for transform arguments that accept either a single value or a
# sequence of values; each value may be an enum member or a plain string.

# Interpolation mode(s) expressed as GridSampleMode members or strings.
GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
# Padding mode(s) expressed as GridSamplePadMode members or strings.
GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str]
# Interpolation mode(s) expressed as InterpolateMode members or strings.
# NOTE(review): this alias is not referenced in the visible part of the file.
InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str]

class InPlaneSpacingd(Spacingd):
    """Dictionary-based resampling that can keep selected axes at their original spacing.

    Behaves like ``Spacingd`` except that every entry of ``pixdim`` equal to
    ``-1.0`` is replaced, separately for each item, by the spacing stored for
    that axis in the item's metadata (``meta_data["pixdim"][1:4]``). For
    example ``pixdim=(0.8, 0.8, -1.0)`` requests 0.8 along the first two axes
    and leaves the third axis at the spacing recorded in the metadata.

    NOTE(review): the original spacing is taken from the ``"pixdim"`` metadata
    entry, not derived from the current affine; confirm both agree if another
    transform has already resampled or reoriented the data.

    Args:
        keys: keys of the items to resample.
        pixdim: requested spacing per axis; ``-1.0`` marks an axis to keep.
        diagonal: forwarded to the ``Spacing`` transform built for each item.
        mode: interpolation mode, forwarded to ``Spacingd`` and used per key.
        padding_mode: padding mode, forwarded to ``Spacingd`` and used per key.
        align_corners: forwarded to ``Spacingd`` and used per key.
        dtype: forwarded to ``Spacingd`` and used per key.
        meta_key_postfix: the metadata of ``key`` is looked up under
            ``f"{key}_{meta_key_postfix}"``.
    """

    def __init__(self,
                 keys: KeysCollection,
                 pixdim: Sequence[float],
                 diagonal: bool = False,
                 mode: GridSampleModeSequence = GridSampleMode.BILINEAR,
                 padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER,
                 align_corners: Union[Sequence[bool], bool] = False,
                 dtype: Optional[Union[Sequence[np.dtype], np.dtype]] = np.float64,
                 meta_key_postfix: str = "meta_dict",
        ) -> None:
        # Pass everything by keyword so the call does not depend on the
        # positional order of the parent's parameters.
        super().__init__(keys=keys,
                         pixdim=pixdim,
                         diagonal=diagonal,
                         mode=mode,
                         padding_mode=padding_mode,
                         align_corners=align_corners,
                         dtype=dtype,
                         meta_key_postfix=meta_key_postfix)
        self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64)
        self.diagonal = diagonal
        # Flat indices of the axes whose spacing must be left untouched.
        self.dim_to_keep = np.flatnonzero(self.pixdim == -1.0)

    def __call__(self,
                data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]]
                )-> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]:
        """Resample every configured key and return the updated dictionary.

        Args:
            data: mapping holding, for each key, the array to resample and a
                metadata dict under ``f"{key}_{meta_key_postfix}"`` with an
                ``"affine"`` entry (and a ``"pixdim"`` entry whenever at least
                one axis is marked with ``-1.0``).

        Returns:
            A shallow copy of ``data`` in which each configured key holds the
            resampled array. The ``"affine"`` entry of each metadata dict is
            overwritten in place, so the metadata dicts of the input mapping
            see the new affine as well.
        """
        d = dict(data)
        for idx, key in enumerate(self.keys):
            meta_data = d[f"{key}_{self.meta_key_postfix}"]

            # Start from the requested spacing and, only if some axis is
            # marked to be kept, fill it in from the spacing in the metadata.
            target_pixdim = self.pixdim.copy()
            if self.dim_to_keep.size:
                # Entries 1..3 of the stored pixdim are used as the per-axis
                # spacing; asarray also accepts list-like metadata.
                stored_spacing = np.asarray(meta_data["pixdim"], dtype=np.float64)[1:4]
                target_pixdim[self.dim_to_keep] = stored_spacing[self.dim_to_keep]

            # The target spacing can differ between items, so the transform
            # is built per key rather than once in __init__.
            spacing_transform = Spacing(target_pixdim, diagonal=self.diagonal)

            # Resample the array using the affine stored in its metadata.
            d[key], _, new_affine = spacing_transform(
                data_array=d[key],
                affine=meta_data["affine"],
                mode=self.mode[idx],
                padding_mode=self.padding_mode[idx],
                align_corners=self.align_corners[idx],
                dtype=self.dtype[idx],
            )

            # Store the affine that matches the resampled array.
            meta_data["affine"] = new_affine
        return d

In [None]:
seg_labels = [0, 1]
pixdim = [0.8, 0.8, -1.0]


def _print_sample_summary(sample):
    """Print pixdim, original affine, current affine and the array shapes of a sample."""
    image_header = sample["image_meta_dict"]
    for field in ("pixdim", "original_affine", "affine"):
        print(image_header[field])
    print(sample["image"].shape, sample["label"].shape)


# Reference pipeline: load both volumes, convert the label with the project's
# ConverToOneHotd transform and add a channel axis to the image.
transf1 = Compose(
    [
        LoadNiftid(keys=["image", "label"]),
        ConverToOneHotd(keys=["label"], labels=seg_labels),
        AddChanneld(keys=["image"]),
    ]
)

# Same pipeline followed by the in-plane resampling under test
# (bilinear for the image, nearest for the label).
in_plane_resampling = InPlaneSpacingd(
    keys=["image", "label"],
    pixdim=pixdim,
    mode=["bilinear", "nearest"],
)
transf2 = Compose([transf1, in_plane_resampling])

# Run the same subject through both pipelines and compare the reports.
data = transf1(train_files[1])
_print_sample_summary(data)

data2 = transf2(train_files[1])
_print_sample_summary(data2)

In [None]:
slice_z = 30
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(12, 10))

# Row 0 shows `data`, row 1 shows `data2`; the columns are the image followed
# by label channels 0 and 1, all taken at the same index along the last axis.
for row, sample in enumerate((data, data2)):
    image_slice = np.squeeze(sample["image"][0, :, :, slice_z])
    ax[row, 0].imshow(image_slice, cmap="gray", interpolation=None)
    for channel in (0, 1):
        label_slice = np.squeeze(sample["label"][channel, :, :, slice_z])
        ax[row, channel + 1].imshow(label_slice, vmin=0.0, vmax=1.0, interpolation=None)

## Test the full series of transforms

In [None]:
from custom_transform import ConverToOneHotd, InPlaneSpacingd

spacing = (0.8, 0.8, -1.0)
seg_labels = [0, 1]
patch_size = (448, 512, 1)
num_classes = len(seg_labels)
batch_size = 4

# Stage 1: read the files and bring image and label to channel-first arrays.
loading_stage = [
    LoadNiftid(keys=["image", "label"]),
    ConverToOneHotd(keys=["label"], labels=seg_labels),
    AddChanneld(keys=["image"]),
]

# Stage 2: crop, resample, pad, normalise and extract a patch of `patch_size`,
# then drop the last axis.
spatial_stage = [
    CropForegroundd(keys=["image", "label"], source_key="image"),
    InPlaneSpacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest")),
    SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
    NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
    RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
    SqueezeDimd(keys=["image", "label"], dim=-1),
]

# Stage 3: random augmentations (plus the dtype cast that sits between them).
# NOTE(review): range_x/range_y are given as 90; some MONAI releases interpret
# these ranges in radians rather than degrees - confirm against the installed
# version.
augmentation_stage = [
    RandZoomd(
        keys=["image", "label"],
        prob=0.16,
        min_zoom=0.9,
        max_zoom=1.2,
        mode=("bilinear", "nearest"),
        align_corners=(True, None),
    ),
    CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
    RandGaussianNoised(keys=["image"], prob=0.15, std=0.01),
    RandGaussianSmoothd(
        keys=["image"],
        prob=0.15,
        sigma_x=(0.5, 1.15),
        sigma_y=(0.5, 1.15),
        sigma_z=(0.5, 1.15),
    ),
    RandScaleIntensityd(keys=["image"], prob=0.15, factors=0.3),
    RandRotated(
        keys=["image", "label"],
        prob=0.2,
        range_x=90,
        range_y=90,
        keep_size=True,
        mode=["bilinear", "nearest"],
        padding_mode=["zeros", "border"],
    ),
    RandFlipd(["image", "label"], prob=0.5, spatial_axis=[0, 1]),
]

# Stage 4: convert the arrays to tensors.
tensor_stage = [ToTensord(keys=["image", "label"])]

train_transforms = Compose(loading_stage + spatial_stage + augmentation_stage + tensor_stage)

# Apply the pipeline to a single subject and report the resulting shapes.
check_transforms = train_transforms(train_files[10])
print("Before data loader:")
print(check_transforms["image"].shape, check_transforms["label"].shape)

# Wrap the training list in a dataset and loader, then inspect the first batch.
train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, num_workers=1, shuffle=True, batch_size=batch_size)
check_data = misc.first(train_loader)
print("Training data tensor shapes")
print(check_data["image"].shape, check_data["label"].shape)

In [None]:
# Show every element of the inspected batch in its own figure: the image in
# the first panel and the label channels next to it, printing the min/max of
# each displayed label channel.
# Channel 1 is only shown when there are exactly two classes.
shown_label_channels = 2 if num_classes == 2 else 1

# Iterate over the elements actually present in the batch, which may be fewer
# than `batch_size`.
for batch_element in range(check_data["image"].shape[0]):
    plt.figure(figsize=(10, 5))
    # Figure-level title: a plain `plt.title` here would attach to an implicit
    # axes created before the subplots instead of labelling the whole figure.
    plt.suptitle(f"Batch element = {batch_element}")
    plt.subplot(131)
    plt.imshow(check_data["image"][batch_element, 0, :, :], cmap='gray')
    for channel in range(shown_label_channels):
        # Panels 132 and 133 hold label channels 0 and 1 respectively.
        plt.subplot(132 + channel)
        label_map = check_data["label"][batch_element, channel, :, :]
        plt.imshow(label_map, interpolation="nearest", vmin=0.0, vmax=1.0)
        print(f"Segmentation limits: channel {channel}")
        print(torch.min(label_map))
        print(torch.max(label_map))