In [1]:
import os
import pathlib
import json
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold


In [2]:
def get_attributes(sitk_image):
    """Collect the physical-space meta-data of `sitk_image` into a dict.

    The returned keys are 'orig_pixelid', 'orig_origin', 'orig_direction',
    'orig_spacing' (numpy array) and 'orig_size' (integer numpy array), i.e. the
    format accepted by the `attributes` argument of `resample_sitk_image`.
    """
    return {
        'orig_pixelid': sitk_image.GetPixelIDValue(),
        'orig_origin': sitk_image.GetOrigin(),
        'orig_direction': sitk_image.GetDirection(),
        'orig_spacing': np.array(sitk_image.GetSpacing()),
        'orig_size': np.array(sitk_image.GetSize(), dtype=int),
    }


def resample_sitk_image(sitk_image,
                        new_spacing=(1, 1, 1),
                        new_size=None,
                        attributes=None,
                        interpolator=sitk.sitkLinear,
                        new_origin=(0.0, 0.0, 0.0),
                        fill_value=0):
    """
    Resample a SimpleITK Image.

    Parameters
    ----------
    sitk_image : sitk.Image
        An input image.
    new_spacing : sequence of float
        A distance between adjacent voxels in each dimension given in physical units (mm) for the output image.
        Lists, tuples and numpy arrays are accepted.
    new_size : sequence of int or None
        A number of pixels per dimension of the output image. If None (or empty), it is computed as
        ``ceil(reference_size * reference_spacing / new_spacing)``, where the reference size and spacing come
        from `attributes` when given, otherwise from `sitk_image`. Lists, tuples and numpy arrays are accepted.
    attributes : dict or None
        Meta-data in the format returned by `get_attributes`. When given, its pixel type, direction, spacing
        and size are used in place of `sitk_image`'s own values to define the output grid. If None, the
        input image's own meta-data is used.
        NOTE: the output origin is always ``sitk_image.GetOrigin()``; ``attributes['orig_origin']`` is ignored.
    interpolator
        Available interpolators:
            - sitk.sitkNearestNeighbor : nearest
            - sitk.sitkLinear : linear
            - sitk.sitkGaussian : gaussian
            - sitk.sitkLabelGaussian : label_gaussian
            - sitk.sitkBSpline : bspline
            - sitk.sitkHammingWindowedSinc : hamming_sinc
            - sitk.sitkCosineWindowedSinc : cosine_windowed_sinc
            - sitk.sitkWelchWindowedSinc : welch_windowed_sinc
            - sitk.sitkLanczosWindowedSinc : lanczos_windowed_sinc
    new_origin : tuple of float
        Unused. The output origin is always taken from `sitk_image`; the parameter is kept only so existing
        callers that pass it keep working.
    fill_value : int or float
        Default pixel value assigned to output voxels that map outside the input image.

    Returns
    -------
    sitk.Image
        The resampled image.

    Notes
    -----
    This implementation is based on https://github.com/deepmedic/SimpleITK-examples/blob/master/examples/resample_isotropically.py
    """
    # Reference meta-data that defines the output grid (everything except the origin).
    if attributes:
        pixel_id = attributes['orig_pixelid']
        direction = attributes['orig_direction']
        ref_spacing = np.asarray(attributes['orig_spacing'], dtype=float)
        ref_size = np.asarray(attributes['orig_size'], dtype=int)
    else:
        pixel_id = sitk_image.GetPixelIDValue()
        direction = sitk_image.GetDirection()
        ref_spacing = np.array(sitk_image.GetSpacing(), dtype=float)
        ref_size = np.array(sitk_image.GetSize(), dtype=int)

    # The output origin follows the input image even when `attributes` is given.
    # Callers may rely on this (e.g. passing the attributes of an image whose origin
    # was reset by a numpy round-trip), so it is deliberately not read from `attributes`.
    origin = sitk_image.GetOrigin()

    # SimpleITK setters expect plain Python sequences, not numpy arrays.
    out_spacing = [float(s) for s in new_spacing]

    # `new_size is None` instead of `not new_size`: truthiness of a multi-element
    # numpy array raises ValueError.
    if new_size is None or len(new_size) == 0:
        # Preserve the physical extent (size * spacing), rounded up to whole voxels.
        new_size = np.ceil(ref_size * (ref_spacing / np.asarray(out_spacing))).astype(int)
    out_size = [int(s) for s in new_size]

    resample_filter = sitk.ResampleImageFilter()
    resample_filter.SetSize(out_size)
    resample_filter.SetTransform(sitk.Transform())
    resample_filter.SetInterpolator(interpolator)
    resample_filter.SetOutputOrigin(origin)
    resample_filter.SetOutputSpacing(out_spacing)
    resample_filter.SetOutputDirection(direction)
    resample_filter.SetDefaultPixelValue(fill_value)
    resample_filter.SetOutputPixelType(pixel_id)

    return resample_filter.Execute(sitk_image)

In [8]:
path = 'SegAorta/'
# Non-hidden entries of each cohort folder, sorted for a reproducible order.
# Filtering replaces list.remove('.DS_Store'), which raised ValueError on systems
# where that macOS metadata file is absent and left any other hidden files in place.
dongyang, rider, kits = (
    sorted(name for name in os.listdir(os.path.join(path, cohort)) if not name.startswith('.'))
    for cohort in ('Dongyang', 'Rider', 'KiTS')
)

In [9]:
# Interleave the three cohorts (KiTS, Dongyang, Rider, KiTS, ...). zip() would stop
# at the shortest folder list and silently drop the remaining patients of the
# larger cohorts, so iterate up to the longest list and skip exhausted cohorts.
cohorts = [('KiTS', kits), ('Dongyang', dongyang), ('Rider', rider)]
longest = max(len(names) for _, names in cohorts)
patients = [
    prefix + '/' + names[i]
    for i in range(longest)
    for prefix, names in cohorts
    if i < len(names)
]
    

In [19]:
path = 'SegAorta/'
# One sub-folder per patient. Skip hidden entries such as macOS '.DS_Store'
# (list.remove() raised ValueError when that file was absent) and plain files,
# which the processing loop would fail on; sort for a reproducible order.
patients = sorted(
    name for name in os.listdir(path)
    if not name.startswith('.') and os.path.isdir(os.path.join(path, name))
)

In [20]:
# Number of patient folders that the resampling loop will process.
sum(1 for _ in patients)

56

In [23]:
# Destination root; one sub-folder per patient is created by the resampling loop.
out_path = "segaorta_resampled_nii/"

In [25]:
for p in patients:
    print("Processing:", p)

    ct_path = os.path.join(path, p, p + '.nrrd')
    gt_path = os.path.join(path, p, p + '.seg.nrrd')

    ct = sitk.ReadImage(ct_path)
    gt = sitk.ReadImage(gt_path)

    # --- CT resampling to 1 mm isotropic spacing (linear interpolation) ---
    print("Before resampling:")
    print(p, 'ct:', ct.GetSize(), ct.GetSpacing(), ct.GetOrigin(), ct.GetDirection())

    # resample_sitk_image builds each output grid at its input's own origin, so the
    # saved CT/mask pair is only aligned if both source files share the same origin.
    ct_origin, gt_origin = ct.GetOrigin(), gt.GetOrigin()
    if len(ct_origin) != len(gt_origin) or not np.allclose(ct_origin, gt_origin):
        print("WARNING:", p, "ct and mask origins differ:", ct_origin, gt_origin)

    ct_img = resample_sitk_image(ct,
                                 new_spacing=[1, 1, 1],
                                 new_size=None,
                                 interpolator=sitk.sitkLinear)

    # The numpy round-trip discards the physical meta-data: the result has origin
    # (0, 0, 0), spacing (1, 1, 1) and identity direction (see the printed output).
    # NOTE(review): presumably intentional so CT and mask share a plain voxel grid — confirm.
    ct_img = sitk.GetImageFromArray(sitk.GetArrayFromImage(ct_img))

    print("After resampling:")
    print(p, 'ct:', ct_img.GetSize(), ct_img.GetSpacing(), ct_img.GetDirection(), ct_img.GetOrigin())

    # --- Mask resampling onto the resampled CT grid (nearest neighbour keeps labels) ---
    # Force the mask onto the CT's original spacing before resampling.
    gt.SetSpacing(ct.GetSpacing())

    attributes = get_attributes(ct_img)
    gt_img = resample_sitk_image(gt,
                                 attributes=attributes,
                                 interpolator=sitk.sitkNearestNeighbor)
    gt_img = sitk.GetImageFromArray(sitk.GetArrayFromImage(gt_img))

    print(p, 'mask:', gt_img.GetSize(), gt_img.GetSpacing(), gt_img.GetDirection(), gt_img.GetOrigin())

    # Image and mask must share a voxel grid; never write a mismatched pair.
    assert gt_img.GetSize() == ct_img.GetSize(), (p, ct_img.GetSize(), gt_img.GetSize())

    # exist_ok=True already tolerates an existing folder; no separate exists() check needed.
    patient_out_dir = os.path.join(out_path, p)
    os.makedirs(patient_out_dir, exist_ok=True)

    sitk.WriteImage(ct_img, os.path.join(patient_out_dir, p + '_ct.nii.gz'), useCompression=True)
    sitk.WriteImage(gt_img, os.path.join(patient_out_dir, p + '_gt.nii.gz'), useCompression=True)
    

Processing: D12
Before resampling:
D12 ct: (512, 666, 136) (0.68359375, 0.68359375, 2.9925925925925925) (-171.15, -57.33672, -1023.9001) (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
After resampling:
D12 ct: (350, 456, 407) (1.0, 1.0, 1.0) (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) (0.0, 0.0, 0.0)
D12 mask: (350, 456, 407) (1.0, 1.0, 1.0) (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) (0.0, 0.0, 0.0)
Processing: D15
Before resampling:
D15 ct: (512, 666, 132) (0.6269531249999998, 0.6269531249999998, 3.0304183206106865) (-151.512, -38.77539062500001, -926.1348999999998) (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
After resampling:
D15 ct: (321, 418, 401) (1.0, 1.0, 1.0) (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) (0.0, 0.0, 0.0)
D15 mask: (321, 418, 401) (1.0, 1.0, 1.0) (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) (0.0, 0.0, 0.0)
Processing: D14
Before resampling:
D14 ct: (512, 512, 132) (0.615234375, 0.615234375, 2.999998473282443) (-166.63499999999993, 12.7999999999998, -980.41

### Make the cross-validation JSON datalist

In [14]:
# Accumulates one dict per patient (id, fold, image, label) for the datalist JSON.
train_json = list()

In [15]:
# Number of cross-validation folds used for the KFold split in the next cell.
num_of_folds = 5

In [16]:
data_dir = "segaorta_resampled/"
# NOTE(review): os.listdir order is filesystem-dependent and includes hidden entries
# such as '.DS_Store' (dropped from the datalist in a later cell, but it still takes
# a slot in the split), so fold membership is only reproducible for the same listing.
patients = os.listdir(data_dir)

full_indices = range(len(patients))

kf = KFold(n_splits=num_of_folds, shuffle=True, random_state=786)

# With an integer random_state the split is deterministic, so compute it once
# (the original recomputed it on every fold) and put every patient of the f-th
# test split into fold f.
for f, (_, test_idx) in enumerate(kf.split(full_indices)):
    for i in test_idx:
        p = patients[i]
        train_json.append({
            "id": p,
            "fold": f,
            # NOTE(review): these file names differ from the `_ct.nii.gz` / `_gt.nii.gz`
            # files the resampling loop writes to `out_path` — confirm the intended dataset.
            "image": os.path.join(p, p + "_ct.nrrd"),
            "label": os.path.join(p, p + "_gt.seg.nrrd"),
        })


In [17]:
# Drop datalist entries created for hidden files (e.g. macOS '.DS_Store') that
# os.listdir picked up. Filtering by id works whichever fold such an entry landed
# in and is a no-op when none exists, whereas list.remove() of an exact dict with
# a hard-coded fold raised ValueError in those cases. Slice assignment keeps the
# same list object.
train_json[:] = [entry for entry in train_json if not entry["id"].startswith(".")]

In [18]:
# Wrap the per-patient entries under a top-level "training" key.
train_json_final = dict(training=train_json)

In [19]:
import json

# Serialize the datalist with 4-space indentation so the file stays human-readable.
with open("train_json_orig.json", mode="w") as json_file:
    json_file.write(json.dumps(train_json_final, indent=4))