In [78]:
from glob import glob
import os
import monai
from monai.transforms.utils import get_unique_labels
from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    Flipd,
    Rotate90d,
    LabelToMaskd,
    MapTransform,
)

from torch.utils.data import DataLoader
from tqdm import tqdm
from monai.data import ThreadDataLoader, NibabelWriter
import torch
import numpy as np

In [83]:
# Root folder shared by every dataset below; change it once to relocate all data.
DATA_ROOT = '/mnt/B-SSD/unet21d_slices/datasets'

# Dataset name -> folder containing `imagesTr/` and `labelsTr/`.
data_paths = {
    'hmd': f'{DATA_ROOT}/liver/supervised',
    'LITSkaggle': f'{DATA_ROOT}/LITSkaggle',
    'amos22': f'{DATA_ROOT}/amos22/AMOS22',
    'MSD_Brain': f'{DATA_ROOT}/MSD/Task01_BrainTumour',
    'MSD_Heart': f'{DATA_ROOT}/MSD/Task02_Heart',
    'MSD_Liver': f'{DATA_ROOT}/MSD/Task03_Liver',
    'MSD_Hippocampus': f'{DATA_ROOT}/MSD/Task04_Hippocampus',
    'MSD_Prostate': f'{DATA_ROOT}/MSD/Task05_Prostate',
    'MSD_Lung': f'{DATA_ROOT}/MSD/Task06_Lung',
    'MSD_Pancreas': f'{DATA_ROOT}/MSD/Task07_Pancreas',
    'MSD_HepaticVessel': f'{DATA_ROOT}/MSD/Task08_HepaticVessel',
    'MSD_Spleen': f'{DATA_ROOT}/MSD/Task09_Spleen',
    'MSD_Colon': f'{DATA_ROOT}/MSD/Task10_Colon',
}
dataset = 'LITSkaggle'
dataset_path = data_paths[dataset]

ct = sorted(glob(os.path.join(dataset_path, 'imagesTr', '*.nii*')))
mask = sorted(glob(os.path.join(dataset_path, 'labelsTr', '*.nii*')))

print("Found {} CT scans and {} masks".format(len(ct), len(mask)))
# CT volumes and masks are paired purely by sorted position. zip() would silently drop
# unmatched files and shift every later pair, so refuse to continue on a count mismatch.
assert len(ct) == len(mask), (
    f"CT/mask count mismatch in {dataset_path}: {len(ct)} CT scans vs {len(mask)} masks"
)
files = [{'ct': ct_, 'mask': mask_} for ct_, mask_ in zip(ct, mask)]

# Per-slice NIfTI files are written to <processed_path>/<imagefolder|maskfolder>/.
processed_path = os.path.join(f'{DATA_ROOT}/test', dataset)
imagefolder = 'images'
maskfolder = 'mask'

Found 131 CT scans and 131 masks


In [84]:
class ConvertToMultiChannelBasedOnClassesd(MapTransform):
    """
    Convert a label map into a multi-channel binary mask, one channel per foreground label.

    The label values are determined from the item under the first key in ``keys``. The same
    channel layout is then applied to every key. Each input must be channel-first with a
    single channel. For each key the output is a float tensor of shape ``(num_labels, *spatial)``.

    Args:
        keys: keys of the corresponding items to be transformed.
        labels: optional fixed sequence of foreground label values. When given, channel ``i``
            always corresponds to ``labels[i]``. Every sample then gets the same number of
            channels, even if a label is absent from a particular volume (that channel is
            all zeros). When ``None`` (default), labels are discovered per sample, sorted in
            ascending order, with ``background`` removed.
        background: label value treated as background and excluded when labels are discovered.
        allow_missing_keys: forwarded to ``MapTransform``.

    Raises:
        ValueError: if an input does not have exactly one leading channel, or if there is no
            foreground label to encode.
    """

    def __init__(self, keys, labels=None, background=0, allow_missing_keys=False):
        super().__init__(keys, allow_missing_keys)
        self.labels = None if labels is None else list(labels)
        self.background = background

    def _foreground_labels(self, label_map):
        """Return the foreground label values to encode, in channel order."""
        if self.labels is not None:
            return self.labels
        # get_unique_labels returns an unordered set, so sort explicitly. Remove the
        # background by value instead of assuming it is the first element.
        found = get_unique_labels(label_map, is_onehot=False)
        return sorted(value for value in found if value != self.background)

    def __call__(self, data):
        d = dict(data)
        labels = self._foreground_labels(d[self.keys[0]])
        if not labels:
            raise ValueError(
                f"No foreground labels found for key '{self.keys[0]}'; cannot build a multi-channel mask."
            )
        for key in self.keys:
            label_map = d[key]
            if label_map.shape[0] != 1:
                raise ValueError(
                    f"Expected a single-channel label map for key '{key}', got shape {tuple(label_map.shape)}."
                )
            # Each comparison keeps the size-1 leading channel axis. Concatenating along dim 0
            # therefore yields exactly one channel per label. Unlike stack + squeeze, this never
            # drops the channel axis when there is a single label, and never collapses a size-1
            # spatial axis.
            d[key] = torch.cat([label_map == label for label in labels], dim=0).float()
        return d

In [86]:
# Define transforms for image and segmentation.
transforms = Compose(
    [
        LoadImaged(keys=['ct', 'mask'], image_only=True),
        EnsureChannelFirstd(keys=['ct', 'mask']),
        # k=2 quarter turns: a 180-degree rotation in the plane of spatial axes 0 and 1.
        Rotate90d(keys=['ct', 'mask'], k=2, spatial_axes=(0, 1)),
        ConvertToMultiChannelBasedOnClassesd(keys=['mask']),
        # Alternatives kept for other datasets:
        #LabelToMaskd(keys=['mask'], select_labels=[1, 2], merge_channels=True),
        #Flipd(keys=['mask'], spatial_axis=1), #hmd data are flipped
    ]
)

ds = monai.data.Dataset(data=files, transform=transforms)
loader = ThreadDataLoader(ds, num_workers=1, batch_size=1, shuffle=False)

writer_ct = NibabelWriter()
writer_mask = NibabelWriter(output_dtype=np.int8)

# Number of patients to export; None exports every volume. Kept at 1 for quick test runs
# (the output files carry a `_test` suffix).
MAX_PATIENTS = 1

ct_out_dir = os.path.join(processed_path, imagefolder)
mask_out_dir = os.path.join(processed_path, maskfolder)
# Create the output folders up front so writing a slice cannot fail on a missing directory.
os.makedirs(ct_out_dir, exist_ok=True)
os.makedirs(mask_out_dir, exist_ok=True)

patient_num = 0
for batch_data in tqdm(loader):
    # Distinct names so the file-path lists `ct` / `mask` from the config cell are not clobbered.
    ct_vol, mask_vol = batch_data['ct'], batch_data['mask']
    # Indexing below expects (batch, channel, H, W, D). If the mask comes back without a
    # channel axis (4-D, e.g. a single foreground label squeezed away), restore it.
    if mask_vol.ndim == 4:
        mask_vol = mask_vol.unsqueeze(1)

    for slice_idx in tqdm(range(ct_vol.shape[4]), leave=False):
        # first four digits are patient number, last four are slice number
        filename = '{:04d}_{:04d}_test.nii'.format(patient_num, slice_idx)

        writer_ct.set_data_array(ct_vol[0, 0, :, :, slice_idx], channel_dim=None)
        writer_ct.write(os.path.join(ct_out_dir, filename))

        writer_mask.set_data_array(mask_vol[0, :, :, :, slice_idx], channel_dim=0)
        writer_mask.write(os.path.join(mask_out_dir, filename))

    patient_num += 1
    if MAX_PATIENTS is not None and patient_num >= MAX_PATIENTS:
        break

  0%|          | 0/131 [00:00<?, ?it/s]TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if