In [4]:
%load_ext autoreload
%autoreload 2  
import sys
sys.path.insert(0,'../')
import src.utils
import pandas as pd 
import tqdm
import os
import multiprocessing
import numpy as np 
import SimpleITK as sitk 
import matplotlib.pyplot as plt 
import scipy.misc
import multiprocessing
import json 
import tensorflow as tf

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
# Root folder of the dataset: one sub-folder per series plus a 'labels.csv' index.
path_to_folder = 'data/Dataset_1/'

# The label file is parsed with ',' as the decimal separator.
labels = pd.read_csv(os.path.join(path_to_folder, 'labels.csv'), decimal=",")

# Identifiers of every series listed in the label file.
list_series = labels["Sequence_id"].tolist()

In [5]:
def threshold_based_crop(image):
    """
    Crop an image to the axis-aligned bounding box of its Otsu foreground.

    Otsu's threshold estimator is used to split the image into background and
    foreground (in medical imaging the background is usually air); the image is
    then cropped to the bounding box of the foreground label.

    Args:
        image (SimpleITK image): An image whose anatomy and background intensities
                                 form a bi-modal distribution (the assumption
                                 underlying Otsu's method).
    Return:
        The region of `image` covered by the foreground's bounding box.
    """
    # Pixels in [min_intensity, otsu_threshold] receive the background label;
    # pixels above the threshold receive the foreground label. The anatomy is
    # brighter than the background, so it ends up with the foreground label.
    background_label = 0
    foreground_label = 255

    otsu_mask = sitk.OtsuThreshold(image, background_label, foreground_label)
    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(otsu_mask)
    box = shape_stats.GetBoundingBox(foreground_label)

    # The first half of the bounding box is the start index, the second half the size.
    half = int(len(box) / 2)
    roi_index = box[0:half]
    roi_size = box[half:]
    return sitk.RegionOfInterest(image, roi_size, roi_index)
    
    
def resample_image(original_image, reference_image, T0,
                    interpolator = sitk.sitkLinear, default_intensity_value = 0.0):
    """
    Resample `original_image` onto the grid of `reference_image`.

    Args:
        original_image: Image to resample.
        reference_image: Image passed to `sitk.Resample` as the reference grid.
        T0: Transform, wrapped in a new `sitk.Transform` before resampling.
        interpolator: Interpolation mode handed to `sitk.Resample`.
        default_intensity_value: Default pixel value handed to `sitk.Resample`.
    Return:
        The resampled image.
    """
    wrapped_transform = sitk.Transform(T0)
    return sitk.Resample(
        original_image,
        reference_image,
        wrapped_transform,
        interpolator,
        default_intensity_value,
    )

def _bytes_feature(value):
    """Wrap a single string / bytes value into a `tf.train.Feature` holding a BytesList."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)


def _float_feature(value):
    """Wrap a single float / double value into a `tf.train.Feature` holding a FloatList."""
    payload = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=payload)


def _int64_feature(value):
    """Wrap a single bool / enum / int / uint value into a `tf.train.Feature` holding an Int64List."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)

def preprocess(path_to_folder, normalization = 'mean'):
    """
    Convert every DICOM series of a dataset into one TFRecord file per series.

    The series listed in `<path_to_folder>/labels.csv` (minus those listed under
    the "ignored" key of `ignored.json` in the current working directory) are
    shuffled with a fixed seed and split into test / valid / train subsets. The
    split is saved as `split.json` in the output folder. Each series is then
    read, resampled onto a 300^3 reference grid with unit spacing, cropped to its
    Otsu foreground bounding box, trimmed to at most `size` voxels per axis,
    normalized and written to `<mode>_set_<serie>.tfrecords`.

    Args:
        path_to_folder (str): Dataset root containing `labels.csv` and one
            sub-folder of DICOM files per series (named after its Sequence_id).
        normalization (str): 'mean' -> clip((x - mean) / max, 0, 1);
            'min' -> clip((x - min) / max, 0, 1); 'clahe' is not implemented.
            Any other value leaves the int16 volume un-normalized.

    Raises:
        NotImplementedError: if `normalization == 'clahe'` (raised before any
            file is read or written).
    """
    # Fail fast: nothing below can handle this mode, so do not start the
    # (long) conversion loop only to crash on the first series.
    if normalization == 'clahe':
        raise NotImplementedError("'clahe' normalization is not implemented")

    ##### Fixed configuration
    seed = 0
    size = 160
    valid_ratio = 0.2
    test_ratio = 0.2
    reference_spacing = [1, 1, 1]
    dimension = 3
    tf_record_path = 'data/_tf_records_{}_seed_{}'.format(size, seed)

    ##### Series selection and train / valid / test split
    labels = pd.read_csv(os.path.join(path_to_folder, 'labels.csv'),
                         decimal=",")
    with open('ignored.json') as json_file:
        ignored_dict = json.load(json_file)
    list_series = labels.Sequence_id.tolist()
    np.random.seed(seed)
    list_series = [serie for serie in list_series
                   if str(serie) not in ignored_dict["ignored"]]
    list_series = np.random.permutation(list_series)

    # The permuted list is laid out as [test | valid | train].
    n_test_set = int(len(list_series) * test_ratio)
    n_valid_set = int(len(list_series) * valid_ratio)
    test_series = list_series[:n_test_set]
    valid_series = list_series[n_test_set:n_test_set + n_valid_set]
    train_series = list_series[n_test_set + n_valid_set:]
    split_dict = {'train_id': list(map(str, train_series)),
                  'valid_id': list(map(str, valid_series)),
                  'test_id': list(map(str, test_series))}
    os.makedirs(tf_record_path, exist_ok=True)
    with open(os.path.join(tf_record_path, 'split.json'), 'w') as json_file:
        json.dump(split_dict, json_file)

    ##### Per-series conversion
    for i, serie in tqdm.tqdm(enumerate(list_series), total=len(list_series)):
        # The subset is derived from the position in the permuted list, which
        # mirrors the slicing used to build the split above.
        if i < n_test_set:
            mode = 'test'
        elif i < n_test_set + n_valid_set:
            mode = 'valid'
        else:
            mode = 'train'

        serie_name = str(serie)
        serie_path = os.path.join(path_to_folder, serie_name)

        # Read the DICOM series once.
        reader = sitk.ImageSeriesReader()
        reader.SetFileNames(reader.GetGDCMSeriesFileNames(serie_path))
        image = reader.Execute()

        patient_infos = labels[labels.Sequence_id == int(serie)]
        features = {
            'age': _int64_feature(patient_infos['age'].iloc[0]),
            'Sequence_id': _bytes_feature(str(patient_infos['Sequence_id'].iloc[0])
                                          .encode('utf-8')),
            'EDSS': _float_feature(float(patient_infos['EDSS'].iloc[0])),
            'examination_date': _bytes_feature(patient_infos['examination_date'].iloc[0]
                                               .encode('utf-8')),
        }

        # Reference grid: 300^3 voxels, origin at zero, identity direction and
        # `reference_spacing` spacing, with the same pixel type as the input.
        reference_origin = np.zeros(dimension)
        reference_direction = np.identity(dimension).flatten()
        reference_size = [300] * dimension

        reference_image = sitk.Image(reference_size, image.GetPixelIDValue())
        reference_image.SetOrigin(reference_origin)
        reference_image.SetSpacing(reference_spacing)
        reference_image.SetDirection(reference_direction)
        reference_center = np.array(
            reference_image.TransformContinuousIndexToPhysicalPoint(
                np.array(reference_image.GetSize()) / 2.0))

        # Map the reference grid onto the image (direction + origin), then
        # shift so that the centers of both volumes coincide.
        transform = sitk.AffineTransform(dimension)
        transform.SetMatrix(image.GetDirection())
        transform.SetTranslation(np.array(image.GetOrigin()) - reference_origin)
        centering_transform = sitk.TranslationTransform(dimension)
        img_center = np.array(
            image.TransformContinuousIndexToPhysicalPoint(
                np.array(image.GetSize()) / 2.0))
        centering_transform.SetOffset(
            np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
        # NOTE(review): Transform.AddTransform is the SimpleITK 1.x API; newer
        # releases appear to require sitk.CompositeTransform instead - confirm
        # against the installed SimpleITK version before upgrading.
        centered_transform = sitk.Transform(transform)
        centered_transform.AddTransform(centering_transform)
        aug_transform = sitk.AffineTransform(dimension)
        aug_transform.SetCenter(reference_center)
        aug_transform.SetMatrix(image.GetDirection())
        T_all = sitk.Transform(centered_transform)
        T_all.AddTransform(aug_transform.GetInverse())
        aug_image = sitk.Resample(image, reference_image, T_all, sitk.sitkLinear, 0.0)

        # Crop to the foreground and keep at most the last `size` voxels per axis.
        cropped_image = threshold_based_crop(aug_image)
        image_array = np.swapaxes(
            sitk.GetArrayFromImage(cropped_image).astype('int16'), 0, 2
        )[-size:, -size:, -size:]

        # Normalization turns the int16 volume into a floating-point one;
        # with an unrecognized mode the volume stays int16.
        if normalization == 'mean':
            mean_arr = np.mean(image_array)
            max_arr = np.amax(image_array)
            image_array = np.clip((image_array - mean_arr) / max_arr, 0, 1)
        elif normalization == 'min':
            min_arr = np.amin(image_array)
            max_arr = np.amax(image_array)
            image_array = np.clip((image_array - min_arr) / max_arr, 0, 1)

        # `tobytes` returns the same bytes as the deprecated `tostring`.
        features['image_raw'] = _bytes_feature(image_array.tobytes())
        features['serie_description'] = _bytes_feature(serie_name.encode('utf-8'))

        # Shape of the serialized volume (needed to decode 'image_raw').
        h, w, d = image_array.shape
        features['h'] = _int64_feature(h)
        features['w'] = _int64_feature(w)
        features['d'] = _int64_feature(d)

        # NOTE(review): this records the spacing of the ORIGINAL image, while
        # the serialized volume was resampled to `reference_spacing` - confirm
        # that this is what downstream readers expect.
        r_h, r_w, r_d = image.GetSpacing()
        features['r_h'] = _float_feature(r_h)
        features['r_w'] = _float_feature(r_w)
        features['r_d'] = _float_feature(r_d)

        example = tf.train.Example(features=tf.train.Features(feature=features))
        record_name = '{}_set_{}.tfrecords'.format(mode, serie)
        with tf.io.TFRecordWriter(os.path.join(tf_record_path, record_name)) as writer:
            writer.write(example.SerializeToString())

In [5]:
# Convert the whole dataset to TFRecords using the default normalization.
preprocess(path_to_folder=path_to_folder)

480it [07:00,  1.14it/s]
