# Test data reading and preprocessing for fetal brain segmentation

In [1]:
import os
import sys
import tempfile
from glob import glob
import logging

import nibabel as nib
import numpy as np
import torch
from matplotlib import pyplot as plt
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
from ignite.handlers import ModelCheckpoint
from torch.utils.data import DataLoader

import monai
from monai.data import NiftiDataset, list_data_collate
from monai.transforms import (
    Activationsd,
    AddChanneld,
    NormalizeIntensityd,
    AsDiscreted,
    Compose,
    KeepLargestConnectedComponentd,
    LoadNiftid,
    RandCropByPosNegLabeld,
    RandRotated,
    RandFlipd,
    ToTensord,
    MapTransform
)
from monai.utils import set_determinism

from ipynb.fs.full.io_utils import create_data_list

# Print the MONAI configuration report (MONAI, Python, NumPy, PyTorch and the
# optional dependencies) so the environment used for this run is recorded.
monai.config.print_config()
# Send log records of level INFO and above to stdout so they appear in the cell output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Index of the GPU this notebook prefers. Kept as a top-level name with its
# original value because later cells may refer to it.
cuda_device = 1
# Only select the device when it actually exists: torch.cuda.set_device raises on
# CPU-only machines or when fewer than cuda_device + 1 GPUs are visible, and the
# data-reading checks in this notebook do not strictly need a GPU (pin_memory is
# already guarded by torch.cuda.is_available()).
if torch.cuda.is_available() and cuda_device < torch.cuda.device_count():
    torch.cuda.set_device(cuda_device)
else:
    logging.warning(
        "CUDA device %d is not available; leaving the default device unchanged.",
        cuda_device,
    )
# Fix the random seed through MONAI's set_determinism helper so that repeated
# runs are intended to be reproducible (see the MONAI docs for exactly which
# libraries it seeds in the installed version).
set_determinism(seed=0)

MONAI version: 0.2.0
Python version: 3.7.4 (default, Jul  9 2019, 03:52:42)  [GCC 5.4.0 20160609]
Numpy version: 1.18.2
Pytorch version: 1.4.0

Optional dependencies:
Pytorch Ignite version: 0.3.0
Nibabel version: 3.0.2
scikit-image version: 0.16.2
Pillow version: 7.1.1
Tensorboard version: 2.2.1

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## Create training and validation data lists

In [2]:
# Folders that are searched for the image / label volumes (one per acquisition group).
data_root = [
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupA",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupB1",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupB2",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupC",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupD",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupE",
    "/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupF",
]

# Text files holding the subject IDs that belong to each split.
list_root = "/mnt/data/mranzini/Desktop/GIFT-Surg/Retraining_with_expanded_dataset/config/file_names"
training_list = os.path.join(list_root, "list_train_files.txt")
# Validation subjects are spread over two list files; both are passed on together.
validation_list = [
    os.path.join(list_root, list_name)
    for list_name in ("list_validation_h_files.txt", "list_validation_p_files.txt")
]

# Training split: per the recorded output, each entry is a dict with the 'img'
# and 'seg' file paths of one subject.
train_files = create_data_list(
    data_folder_list=data_root,
    subject_list=training_list,
    img_postfix='_Image',
    label_postfix='_Label',
)

# Sanity check: number of entries plus the first and last one.
print(len(train_files))
print(train_files[0])
print(train_files[-1])

# Validation split, built the same way from the validation subject lists.
val_files = create_data_list(
    data_folder_list=data_root,
    subject_list=validation_list,
    img_postfix='_Image',
    label_postfix='_Label',
)

print(len(val_files))
print(val_files[0])
print(val_files[-1])

# np.savetxt("images_training.txt", images_training, fmt='%s')
# np.savetxt("seg_training.txt", seg_training, fmt='%s')
# np.savetxt("images_validation.txt", images_validation, fmt='%s')
# np.savetxt("seg_validation.txt", seg_validation, fmt='%s')


316
{'img': '/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupA/a01_02_Image.nii.gz', 'seg': '/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupA/a01_02_Label.nii.gz'}
{'img': '/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupE/E18_02_Image.nii.gz', 'seg': '/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupE/E18_02_Label.nii.gz'}
50
{'img': '/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupA/a04_02_Image.nii.gz', 'seg': '/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset/GroupA/a04_02_Label.nii.gz'}
{'img': '/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupE/E11_08_Image.nii.gz', 'seg': '/mnt/data/mranzini/Desktop/GIFT-Surg/Data/NeuroImage_dataset_extension/GroupE/E11_08_Label.nii.gz'}


## Set up transforms and dataset

In [3]:
class ConverToOneHotd(MapTransform):
    """
    Dictionary-based transform that expands a label map into one binary mask per label.

    For every configured key, the stored array is compared against each entry of
    ``labels`` and the resulting masks are stacked along a new leading axis and
    cast to float32.
    """

    def __init__(self, keys, labels):
        """
        Args:
            keys: keys of the data dictionary whose values are converted
                (forwarded to ``MapTransform``).
            labels: iterable of label values; each one yields one output
                channel, in the order given.
        """
        super().__init__(keys)
        self.labels = labels

    def __call__(self, data):
        """
        Return a shallow copy of ``data`` where each configured key holds the stacked masks.

        Assumes ``data[key]`` supports element-wise ``==`` (e.g. a NumPy array);
        the input dictionary itself is left untouched.
        """
        converted = dict(data)
        for key in self.keys:
            volume = converted[key]
            # One boolean mask per requested label value.
            masks = [volume == label for label in self.labels]
            # New axis 0 indexes the labels; booleans become 0.0 / 1.0.
            converted[key] = np.stack(masks, axis=0).astype(np.float32)
        return converted

In [6]:
# Settings for the check below.
num_classes = 1               # only referenced by the commented-out plotting code further down
seg_labels = [1]              # label values that ConverToOneHotd turns into channels
patch_size = [256, 256, 24]   # spatial size requested from the random crop

# Pre-processing and augmentation pipeline applied to every training sample.
train_transforms = Compose(
    [
        # Load the image and label volumes referenced by the 'img' / 'seg' paths.
        LoadNiftid(keys=["img", "seg"]),
        # Label map -> one float32 channel per entry of seg_labels (adds the leading axis).
        ConverToOneHotd(keys=['seg'], labels=seg_labels),
        # Only the image still needs a channel axis.
        AddChanneld(keys=['img']),
        NormalizeIntensityd(keys=['img']),
        # Crop num_samples patches per volume, guided by the 'seg' label.
        RandCropByPosNegLabeld(
            keys=["img", "seg"],
            label_key="seg",
            spatial_size=patch_size,
            pos=1,
            neg=1,
            num_samples=2,
        ),
        # One interpolation mode per key: bilinear for 'img', nearest for 'seg'.
        RandRotated(
            keys=['img', 'seg'],
            range_x=90,
            range_y=90,
            prob=0.5,
            keep_size=True,
            mode=["bilinear", "nearest"],
        ),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        ToTensord(keys=["img", "seg"]),
    ]
)

# Run the check on the full training list.
check_train_files = train_files
print(len(check_train_files))

# Dataset and loader used only to inspect one batch.
check_ds = monai.data.Dataset(data=check_train_files, transform=train_transforms)
# batch_size=2 loads two volumes per batch; each volume contributes num_samples=2
# patches, which is why the recorded output shows a leading dimension of 4.
check_loader = monai.data.DataLoader(
    check_ds,
    batch_size=2,
    shuffle=True,
    num_workers=1,
    pin_memory=torch.cuda.is_available(),
)

# Pull the first batch and report the tensor shapes of image and label.
check_data = monai.utils.misc.first(check_loader)
print(check_data['img'].shape, check_data['seg'].shape)


# plt.figure(figsize=(10, 5))
# plt.subplot(131)
# plt.imshow(check_data['img'][0, 0, :, :, 8], cmap='gray')
# plt.subplot(132)
# plt.imshow(check_data['seg'][0, 0, :, :, 8], interpolation="nearest")
# print("Segmentation limits: channel 0")
# print(torch.min(check_data['seg'][0, 0, :, :, 8]))
# print(torch.max(check_data['seg'][0, 0, :, :, 8]))
# if num_classes == 2:
#     plt.subplot(133)
#     plt.imshow(check_data['seg'][0, 1, : , :, 8], interpolation="nearest")
#     print("Segmentation limits: channel 1")
#     print(torch.min(check_data['seg'][0, 1, :, :, 8]))
#     print(torch.max(check_data['seg'][0, 1, :, :, 8]))
########
# val_transforms = Compose([
#     LoadNiftid(keys=['img', 'seg']),
#     AddChanneld(keys=['img', 'seg']),
#     NormalizeIntensityd(keys=['img']),
#     Resized(keys=['img'], spatial_size=[96, 96], order=1),
#     Resized(keys=['seg'], spatial_size=[96, 96], order=0, anti_aliasing=False),
#     ToTensord(keys=['img', 'seg'])
# ])

# # create a training data loader
# train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
# train_loader = DataLoader(train_ds, batch_size=10, shuffle=True, num_workers=4,
#                           collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
# check_train_data = monai.utils.misc.first(train_loader)
# print(check_train_data['img'].shape, check_train_data['seg'].shape)

# # create a validation data loader
# val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
# val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
#                         pin_memory=torch.cuda.is_available())

316
torch.Size([4, 1, 256, 256, 24]) torch.Size([4, 1, 256, 256, 24])
