In [2]:
%load_ext autoreload
%autoreload 2  
import sys
sys.path.insert(0,'../')
import src.utils
import pandas as pd 
import tqdm
import os
import multiprocessing
import numpy as np 
import SimpleITK as sitk 
import matplotlib.pyplot as plt 
import scipy.misc
import multiprocessing
import json 
import tensorflow as tf

In [3]:
# Dataset root; labels.csv is read from it with "," as the decimal mark.
path_to_folder = 'data/Dataset_1/'
labels = pd.read_csv(
    os.path.join(path_to_folder, 'labels.csv'),
    decimal=",",
)
# One entry per row of the labels table.
list_series = labels['Sequence_id'].tolist()

In [4]:
def resample_image(original_image, reference_image, T0,
                   interpolator=sitk.sitkLinear, default_intensity_value=0.0):
    """Resample ``original_image`` onto the grid of ``reference_image``.

    Args:
        original_image: image to resample.
        reference_image: image whose grid is handed to ``sitk.Resample`` as the
            resampling reference.
        T0: transform forwarded to ``sitk.Resample``; it is wrapped in a generic
            ``sitk.Transform`` first.
        interpolator: interpolation mode passed to ``sitk.Resample``
            (linear by default).
        default_intensity_value: default pixel value passed to ``sitk.Resample``.

    Returns:
        The image returned by ``sitk.Resample``.
    """
    mapping = sitk.Transform(T0)
    return sitk.Resample(original_image, reference_image, mapping,
                         interpolator, default_intensity_value)

def _bytes_feature(value):
    """Returns a bytes_list Feature from a string / byte.

    Args:
        value: ``bytes`` payload, or a ``str`` which is UTF-8 encoded first.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the single value.
    """
    # BytesList rejects text strings under Python 3, so encode them here
    # instead of forcing every caller to do it.
    if isinstance(value, str):
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Wrap a single float / double in a ``tf.train.Feature`` (float_list)."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)


def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a ``tf.train.Feature`` (int64_list)."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

def preprocess(path_to_folder, seed=0, size=112, valid_ratio=0.2, test_ratio=0.2):
    """Resample every DICOM series of a dataset onto a common grid and write TFRecords.

    Reads ``labels.csv`` (comma as decimal separator) from ``path_to_folder``,
    shuffles its ``Sequence_id`` values with ``seed``, splits them into
    test / valid / train subsets, stores the split in ``split.json`` and writes
    one ``<mode>_set_<Sequence_id>.tfrecords`` file per series into
    ``data/_tf_records_<size>_seed_<seed>``.

    The reference grid takes its physical extent from the first series of the
    shuffled list only; every series is then resampled onto that grid.

    Args:
        path_to_folder: dataset root containing ``labels.csv`` and one
            sub-folder per ``Sequence_id`` holding the DICOM files.
        seed: seed given to ``np.random.seed`` before shuffling (this reseeds
            NumPy's global generator, as the original code did).
        size: number of voxels along each axis of the resampled volume.
        valid_ratio: fraction of the series assigned to the validation split.
        test_ratio: fraction of the series assigned to the test split.

    Raises:
        ValueError: if ``size`` is smaller than 2 or the ratios are negative
            or sum to more than 1.
    """
    if size < 2:
        raise ValueError('size must be at least 2, got {}'.format(size))
    if valid_ratio < 0 or test_ratio < 0 or valid_ratio + test_ratio > 1:
        raise ValueError('valid_ratio and test_ratio must be non-negative '
                         'and sum to at most 1')

    dimension = 3
    tf_record_path = 'data/_tf_records_{}_seed_{}'.format(size, seed)
    labels = pd.read_csv(os.path.join(path_to_folder, 'labels.csv'),
                         decimal=",")

    # Same seeding + permutation calls as before, so existing splits are reproduced.
    np.random.seed(seed)
    list_series = np.random.permutation(labels.Sequence_id.tolist())
    n_test_set = int(len(list_series) * test_ratio)
    n_valid_set = int(len(list_series) * valid_ratio)
    test_series = list_series[:n_test_set]
    valid_series = list_series[n_test_set:n_test_set + n_valid_set]
    train_series = list_series[n_test_set + n_valid_set:]

    # Constant-time lookup of the split each series belongs to.
    mode_by_serie = {}
    for mode, members in (('train', train_series),
                          ('test', test_series),
                          ('valid', valid_series)):
        for member in members:
            mode_by_serie[member] = mode

    os.makedirs(tf_record_path, exist_ok=True)
    split_dict = {'train_id': list(map(str, train_series)),
                  'valid_id': list(map(str, valid_series)),
                  'test_id': list(map(str, test_series))}
    with open(os.path.join(tf_record_path, 'split.json'), 'w') as json_file:
        json.dump(split_dict, json_file)

    reference = None
    # Wrapping the sequence itself (not enumerate) lets tqdm show the total.
    for serie in tqdm.tqdm(list_series):
        serie_name = str(serie)
        reader = sitk.ImageSeriesReader()
        dicom_names = reader.GetGDCMSeriesFileNames(
            os.path.join(path_to_folder, serie_name))
        reader.SetFileNames(dicom_names)
        image = reader.Execute()

        # The reference grid is built once, from the first series processed.
        if reference is None:
            reference = _build_reference_image(image, size, dimension)
        reference_image, reference_origin, reference_center = reference

        # Resample with the centered transform: the previous code built it but
        # then passed the un-centered transform, leaving the centering unused.
        centered_transform = _build_centered_transform(
            image, reference_origin, reference_center, dimension)
        normalized_image = resample_image(image, reference_image,
                                          centered_transform)
        image_array = sitk.GetArrayFromImage(normalized_image).astype('int16')

        patient_infos = labels[labels.Sequence_id == int(serie)]
        # The reference grid has `size` voxels on every axis, so h == w == d.
        h, w, d = normalized_image.GetSize()
        r_h, r_w, r_d = normalized_image.GetSpacing()
        features = {
            'age': _int64_feature(patient_infos['age'].iloc[0]),
            'Sequence_id': _bytes_feature(
                str(patient_infos['Sequence_id'].iloc[0]).encode('utf-8')),
            'EDSS': _float_feature(float(patient_infos['EDSS'].iloc[0])),
            'examination_date': _bytes_feature(
                patient_infos['examination_date'].iloc[0].encode('utf-8')),
            'serie_description': _bytes_feature(serie_name.encode('utf-8')),
            'h': _int64_feature(h),
            'w': _int64_feature(w),
            'd': _int64_feature(d),
            'r_h': _float_feature(r_h),
            'r_w': _float_feature(r_w),
            'r_d': _float_feature(r_d),
            # tobytes() is the non-deprecated spelling of tostring().
            'image_raw': _bytes_feature(image_array.tobytes()),
        }

        example = tf.train.Example(features=tf.train.Features(feature=features))
        record_name = '{}_set_{}.tfrecords'.format(mode_by_serie[serie], serie)
        with tf.io.TFRecordWriter(
                os.path.join(tf_record_path, record_name)) as writer:
            writer.write(example.SerializeToString())


def _build_reference_image(image, size, dimension):
    """Create the empty reference grid, and return it with its origin and center."""
    # Distance between the first and last voxel centers of `image` per axis.
    reference_physical_size = np.array(
        [(sz - 1) * spc for sz, spc in zip(image.GetSize(), image.GetSpacing())])
    reference_origin = np.zeros(dimension)
    reference_direction = np.identity(dimension).flatten()
    reference_size = [size] * dimension
    # Spacing chosen so that `size` voxels span the same physical extent.
    reference_spacing = [phys_sz / (sz - 1)
                         for sz, phys_sz in zip(reference_size,
                                                reference_physical_size)]
    reference_image = sitk.Image(reference_size, image.GetPixelIDValue())
    reference_image.SetOrigin(reference_origin)
    reference_image.SetSpacing(reference_spacing)
    reference_image.SetDirection(reference_direction)
    reference_center = np.array(
        reference_image.TransformContinuousIndexToPhysicalPoint(
            np.array(reference_image.GetSize()) / 2.0))
    return reference_image, reference_origin, reference_center


def _build_centered_transform(image, reference_origin, reference_center, dimension):
    """Build the transform aligning the center of `image` with the reference center."""
    transform = sitk.AffineTransform(dimension)
    transform.SetMatrix(image.GetDirection())
    transform.SetTranslation(np.array(image.GetOrigin()) - reference_origin)
    # Align the centers of the original and reference image instead of their origins.
    centering_transform = sitk.TranslationTransform(dimension)
    img_center = np.array(image.TransformContinuousIndexToPhysicalPoint(
        np.array(image.GetSize()) / 2.0))
    centering_transform.SetOffset(np.array(
        transform.GetInverse().TransformPoint(img_center) - reference_center))
    # NOTE(review): newer SimpleITK releases expose AddTransform on
    # CompositeTransform rather than Transform; fall back to Transform when
    # CompositeTransform is not available.
    composite_cls = getattr(sitk, 'CompositeTransform', sitk.Transform)
    centered_transform = composite_cls(transform)
    centered_transform.AddTransform(centering_transform)
    return centered_transform

In [5]:
# Convert every series listed in labels.csv under `path_to_folder` into TFRecord files.
preprocess(path_to_folder)

480it [07:00,  1.14it/s]
