# Test data reading and preprocessing for fetal brain segmentation

In [None]:
import os
import sys
import tempfile
from glob import glob
import logging

import nibabel as nib
import numpy as np
import torch
from matplotlib import pyplot as plt
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
from ignite.handlers import ModelCheckpoint
from torch.utils.data import DataLoader

import monai
from monai.data import NiftiDataset, list_data_collate
from monai.transforms import (
    Activationsd,
    AddChanneld,
    NormalizeIntensityd,
    AsDiscreted,
    Resized,
    Compose,
    KeepLargestConnectedComponentd,
    LoadNiftid,
    RandCropByPosNegLabeld,
    RandRotated,
    RandFlipd,
    ToTensord,
    MapTransform,
    CropForegroundd,
    SpatialCrop
)
from monai.utils import set_determinism

# from ipynb.fs.full.io_utils import create_data_list
sys.path.append("/mnt/data/mranzini/Desktop/GIFT-Surg/FBS_Monai/basic_unet_monai/src/")
from io_utils import create_data_list
from custom_transform import ConverToOneHotd, MinimumPadd, CropForegroundAnisotropicMargind

monai.config.print_config()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Pin subsequent CUDA work to a specific GPU. torch.cuda.set_device raises when CUDA is
# unavailable or the index does not exist, so guard it and fall back to the default
# device (the data loaders below already handle the CPU case via torch.cuda.is_available()).
cuda_device = 1
if torch.cuda.is_available() and cuda_device < torch.cuda.device_count():
    torch.cuda.set_device(cuda_device)
else:
    logging.warning(
        "CUDA device %d is not available (%d CUDA device(s) found); using the default device.",
        cuda_device,
        torch.cuda.device_count(),
    )

# Fix the random seed so the random transforms below are reproducible between runs.
set_determinism(seed=0)

## Create training and validation data lists

In [None]:
# Folders to search for the data: the original NeuroImage dataset (groups A, B1, B2)
# and its extension (groups C-F). All paths share one root so the notebook can be
# relocated by editing a single line.
gift_surg_root = "/mnt/data/mranzini/Desktop/GIFT-Surg"
data_root = (
    [os.path.join(gift_surg_root, "Data", "NeuroImage_dataset", group)
     for group in ("GroupA", "GroupB1", "GroupB2")]
    + [os.path.join(gift_surg_root, "Data", "NeuroImage_dataset_extension", group)
       for group in ("GroupC", "GroupD", "GroupE", "GroupF")]
)

# Text files listing the subject IDs of each split.
list_root = os.path.join(gift_surg_root, "Retraining_with_expanded_dataset", "config", "file_names")
training_list = os.path.join(list_root, "list_train_files.txt")
validation_list = [os.path.join(list_root, "list_validation_h_files.txt"),
                   os.path.join(list_root, "list_validation_p_files.txt")]

# create_data_list (project helper from io_utils) presumably pairs each listed subject
# with its '*_Image' / '*_Label' files found under data_root -- see io_utils.
train_files = create_data_list(data_folder_list=data_root,
                               subject_list=training_list,
                               img_postfix='_Image',
                               label_postfix='_Label')
# Fail with an explicit message rather than an IndexError on train_files[0].
if len(train_files) == 0:
    raise RuntimeError(f"No training data found for the subjects listed in {training_list}")
print(len(train_files))
print(train_files[0])
print(train_files[-1])

val_files = create_data_list(data_folder_list=data_root,
                             subject_list=validation_list,
                             img_postfix='_Image',
                             label_postfix='_Label')
if len(val_files) == 0:
    raise RuntimeError(f"No validation data found for the subjects listed in {validation_list}")
print(len(val_files))
print(val_files[0])
print(val_files[-1])


## Set up transforms and dataset

In [None]:
num_classes = 1
seg_labels = [1]
patch_size = [96, 96, 36]
slice_idx = 8  # index along the last spatial axis shown in the preview plots (< patch_size[2])

# data preprocessing for training:
train_transforms = Compose(
    [
        LoadNiftid(keys=["img", "seg"]),
        ConverToOneHotd(keys=["seg"], labels=seg_labels),
        AddChanneld(keys=["img"]),
        NormalizeIntensityd(keys=["img"]),
        MinimumPadd(keys=["img", "seg"], k=(-1, -1, patch_size[2])),
        # Resample the label with "nearest" so it keeps its discrete values. Without an
        # explicit mode both keys use MONAI's default ("area"), which averages label voxels
        # and produces fractional, non-binary segmentations. "area" is kept for the image.
        Resized(
            keys=["img", "seg"],
            spatial_size=(patch_size[0], patch_size[1], -1),
            mode=["area", "nearest"],
        ),
        RandCropByPosNegLabeld(
            keys=["img", "seg"], label_key="seg", spatial_size=patch_size, pos=1, neg=1, num_samples=2
        ),
        # Augmentations disabled for this first check:
        # RandRotated(keys=["img", "seg"], range_x=90, range_y=90, prob=0.5, keep_size=True,
        #             mode=["bilinear", "nearest"]),
        # RandFlipd(keys=["img", "seg"], spatial_axis=[0, 1]),
        ToTensord(keys=["img", "seg"]),
    ]
)
# create training data loader
check_train_files = train_files
print(len(check_train_files))

# A handful of subjects is enough to sanity-check the pipeline.
check_ds = monai.data.Dataset(data=check_train_files[:4], transform=train_transforms)
# One volume per batch. RandCropByPosNegLabeld(num_samples=2) turns each volume into
# 2 patches, so a collated batch presumably holds 2 patches -- see the printed shapes.
check_loader = monai.data.DataLoader(check_ds,
                                     batch_size=1,
                                     shuffle=True, num_workers=1,
                                     pin_memory=torch.cuda.is_available())

# Draw two batches: shuffling and random cropping mean they can differ.
check_data = monai.utils.misc.first(check_loader)
print(check_data["img"].shape, check_data["seg"].shape)
check_data = monai.utils.misc.first(check_loader)
print(check_data["img"].shape, check_data["seg"].shape)

plt.figure(figsize=(10, 5))
plt.subplot(131)
plt.imshow(check_data['img'][0, 0, :, :, slice_idx], cmap='gray')
plt.subplot(132)
plt.imshow(check_data['seg'][0, 0, :, :, slice_idx], interpolation="nearest")
print("Segmentation limits: channel 0")
print(torch.min(check_data['seg'][0, 0, :, :, slice_idx]))
print(torch.max(check_data['seg'][0, 0, :, :, slice_idx]))
if num_classes == 2:
    plt.subplot(133)
    plt.imshow(check_data['seg'][0, 1, :, :, slice_idx], interpolation="nearest")
    print("Segmentation limits: channel 1")
    print(torch.min(check_data['seg'][0, 1, :, :, slice_idx]))
    print(torch.max(check_data['seg'][0, 1, :, :, slice_idx]))

# Transforms that use the segmentation as a bounding box for cropping

In [None]:
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple

from monai.config import IndexSelection, KeysCollection

def generate_spatial_bounding_box_anisotropic_margin(
    img: np.ndarray,
    select_fn: Callable = lambda x: x > 0,
    channel_indices: Optional["IndexSelection"] = None,
    margin: Optional[Union[int, Sequence[int]]] = 0,
) -> Tuple[List[int], List[int]]:
    """
    Generate the spatial bounding box of the foreground in ``img``, with a per-dimension margin.

    Foreground is selected with ``select_fn`` on all channels (or only on ``channel_indices``)
    and combined across channels with a logical OR; the box of the resulting spatial mask is
    then enlarged by ``margin`` on both sides and clipped to the image extent.

    Args:
        img: channel-first source image, shape (C, spatial_1, ..., spatial_N).
        select_fn: function to select expected foreground, default is to select values > 0.
        channel_indices: if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin: margin added on both sides of the box. Either a single value (Python or
            numpy scalar) used for every spatial dim, or one value per spatial dim, which
            allows anisotropic margins. ``None`` means 0; negative values are treated as 0.

    Returns:
        ``(box_start, box_end)``: lists of Python ints, one entry per spatial dim;
        ``box_end`` is exclusive.

    Raises:
        ValueError: if ``margin`` does not have one entry per spatial dim, or if no
            foreground voxel is found.
    """
    data = img[[*(ensure_tuple(channel_indices))]] if channel_indices is not None else img
    # Collapse the channel axis: a voxel is foreground if selected in any channel.
    data = np.any(select_fn(data), axis=0)
    nonzero_idx = np.nonzero(data)

    if margin is None:
        margin = 0
    # np.ndim(...) == 0 accepts Python ints *and* numpy scalars, whereas
    # isinstance(margin, int) rejects e.g. np.int64 and then fails when iterating it.
    if np.ndim(margin) == 0:
        margin = (margin,) * data.ndim
    margin = [m if m > 0 else 0 for m in margin]
    if len(margin) != data.ndim:
        raise ValueError("defined margin has different number of dimensions than input")

    box_start: List[int] = []
    box_end: List[int] = []
    for dim, (idx, m) in enumerate(zip(nonzero_idx, margin)):
        if idx.size == 0:
            raise ValueError(f"did not find nonzero index at spatial dim {dim}")
        # int(...) so callers get Python ints, as the return annotation promises.
        box_start.append(int(max(0, np.min(idx) - m)))
        box_end.append(int(min(data.shape[dim], np.max(idx) + m + 1)))
    return box_start, box_end

class MyCropForegroundd(MapTransform):
    """
    Dictionary transform that crops every ``keys`` entry to the foreground box of ``source_key``.

    On each call the bounding box is computed once from ``data[source_key]`` with
    ``generate_spatial_bounding_box_anisotropic_margin`` (so ``margin`` may be a single
    int or one value per spatial dimension), and the same spatial crop is then applied
    to every field listed in ``keys``. Example foreground definitions:

    - values > 0 in an image field (the default ``select_fn``);
    - label == 3 in a label field (``select_fn=lambda x: x == 3``);
    - values > 0 in the third channel of a one-hot label (``channel_indices=2``).
    """

    def __init__(
        self,
        keys: KeysCollection,
        source_key: str,
        select_fn: Callable = lambda x: x > 0,
        channel_indices: Optional[IndexSelection] = None,
        margin: Optional[Sequence[int]] = 0,
    ) -> None:
        """
        Args:
            keys: keys of the fields to crop.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            source_key: field used to compute the foreground bounding box (image, label, ...).
            select_fn: predicate marking foreground voxels; defaults to values > 0.
            channel_indices: channels of the source field to inspect; None inspects all channels.
            margin: margin added around the box, either one int for all spatial dims or
                one value per spatial dim.
        """
        super().__init__(keys)
        self.source_key = source_key
        self.select_fn = select_fn
        self.channel_indices = None if channel_indices is None else ensure_tuple(channel_indices)
        self.margin = margin

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        output = dict(data)
        start, end = generate_spatial_bounding_box_anisotropic_margin(
            output[self.source_key],
            self.select_fn,
            self.channel_indices,
            self.margin,
        )
        crop = SpatialCrop(roi_start=start, roi_end=end)
        for name in self.keys:
            output[name] = crop(output[name])
        return output

In [None]:
# Second pipeline: use the segmentation to crop a bounding box around the foreground
# label (plus a margin) before padding/resizing, then apply random rotation and flipping.
num_classes = 1
seg_labels = [1]
patch_size = [96, 96, 36]
slice_idx = 18  # index along the last spatial axis shown in the preview plots

# data preprocessing for training:
train_transforms = Compose(
    [
        LoadNiftid(keys=["img", "seg"]),
        ConverToOneHotd(keys=["seg"], labels=seg_labels),
        AddChanneld(keys=["img"]),
        NormalizeIntensityd(keys=["img"]),
        # Anisotropic margin: 20 on the first two spatial axes, 5 on the last
        # (presumably in voxels, as in generate_spatial_bounding_box_anisotropic_margin).
        CropForegroundAnisotropicMargind(keys=["img", "seg"], source_key="seg", margin=[20, 20, 5]),
        MinimumPadd(keys=["img", "seg"], k=(-1, -1, patch_size[2])),
        Resized(keys=["img", "seg"], spatial_size=(patch_size[0], patch_size[1], -1), mode=["trilinear", "nearest"]),
        # NOTE(review): MONAI documents range_x/range_y in radians, in which case 90 samples
        # angles from (-90, 90) rad, i.e. effectively any rotation. If +/-90 degrees was
        # intended, use np.pi / 2 -- confirm against the installed MONAI version.
        RandRotated(keys=["img", "seg"], range_x=90, range_y=90, prob=0.5, keep_size=True,
                    mode=["bilinear", "nearest"]),
        RandFlipd(keys=["img", "seg"], spatial_axis=[0, 1]),
        ToTensord(keys=["img", "seg"]),
    ]
)
# create training data loader
check_train_files = train_files
print(len(check_train_files))

# A handful of subjects is enough to sanity-check the pipeline.
check_ds = monai.data.Dataset(data=check_train_files[:4], transform=train_transforms)
# One volume per batch: there is no fixed-size patch cropping in this pipeline, so the
# size along the last axis can differ between subjects after the foreground crop.
check_loader = monai.data.DataLoader(check_ds,
                                     batch_size=1,
                                     shuffle=True, num_workers=1,
                                     pin_memory=torch.cuda.is_available())

# Draw two batches: shuffling and random augmentation mean they can differ.
check_data = monai.utils.misc.first(check_loader)
print(check_data["img"].shape, check_data["seg"].shape)
check_data = monai.utils.misc.first(check_loader)
print(check_data["img"].shape, check_data["seg"].shape)

plt.figure(figsize=(10, 5))
plt.subplot(131)
plt.imshow(check_data['img'][0, 0, :, :, slice_idx], cmap='gray')
plt.subplot(132)
plt.imshow(check_data['seg'][0, 0, :, :, slice_idx], interpolation="nearest")
print("Segmentation limits: channel 0")
print(torch.min(check_data['seg'][0, 0, :, :, slice_idx]))
print(torch.max(check_data['seg'][0, 0, :, :, slice_idx]))
if num_classes == 2:
    plt.subplot(133)
    plt.imshow(check_data['seg'][0, 1, :, :, slice_idx], interpolation="nearest")
    print("Segmentation limits: channel 1")
    print(torch.min(check_data['seg'][0, 1, :, :, slice_idx]))
    print(torch.max(check_data['seg'][0, 1, :, :, slice_idx]))

In [None]:
# Inspect raw label values along one row of the previewed slice
# (e.g. to confirm they remain 0/1 after resizing and rotation).
label_row = check_data['seg'][0, 0, 70, 10:40, 18]
print(label_row)