In [1]:
from __future__ import print_function

import h5py
import os
import numpy as np
import dicom
from scipy.misc import imresize
from utils import preprocess, rotation_augmentation, shift_augmentation
import matplotlib.pyplot as plt
% matplotlib inline

# Configuration shared by the cells below.
img_resize = True  # when true, load_images passes every frame through crop_resize
img_shape = (256, 256)  # target size handed to imresize inside crop_resize
# Root folder of the data set. It can be overridden through the SADSB_DATA_DIR
# environment variable so that the notebook does not have to be edited per machine;
# the historical location remains the default.
DATA_DIR = os.environ.get('SADSB_DATA_DIR', '/media/haidar/Storage/Data/SADSB/')
if not DATA_DIR.endswith('/'):
    DATA_DIR += '/'  # later cells build paths with plain string concatenation

In [2]:
def crop_resize(img):
    """
    Cut the largest centered square out of an image and resize it to img_shape.

    If the image has fewer rows than columns it is transposed first, so the crop
    is always taken along the first axis.

    :param img: image to be cropped and resized.
    :return: the result of imresize on the square crop.
    """
    rows, cols = img.shape[:2]
    if rows < cols:
        img = img.T
        rows, cols = img.shape[:2]

    side = min(rows, cols)
    # offsets of the centered square; both differences are non-negative
    top = (rows - side) // 2
    left = (cols - side) // 2
    square = img[top: top + side, left: left + side]
    return imresize(square, img_shape)


def _fit_frame_count(frames, n_frames=30):
    """
    Return a list holding exactly n_frames frames.

    A short sequence is extended by cycling through its frames from the start,
    a long one is truncated. An empty sequence is returned as an empty list
    because there is nothing to repeat.

    :param frames: list of frames belonging to one slice
    :param n_frames: number of frames required per slice
    """
    if not frames:
        return []
    return [frames[i % len(frames)] for i in range(n_frames)]


def load_images(from_dir, verbose=True):
    """
    Load images in the form study x slices x frames x (image dimensions).
    Each slice contains 30 time series frames so that it is ready for the convolutional network.

    Every folder whose own name contains "sax" and that holds .dcm files is treated as
    one slice. The study id is taken from the path component two levels above that
    folder, i.e. a layout like <from_dir>/<study id>/<something>/<sax folder> is assumed.
    Folders and files are visited in sorted name order so that the frame order and the
    order of the returned ids do not depend on the file system.

    :param from_dir: directory with images (train or validate)
    :param verbose: if true then print a progress line every 1000 images
    :return: tuple (ids, study_to_images); ids lists the study ids in the order they were
             first seen and study_to_images maps each id to a numpy array built from its slices
    """
    print('-'*50)
    print('Loading all DICOM images from {0}...'.format(from_dir))
    print('-'*50)

    ids = []  # keeps the ids of the studies in first-seen order
    slices_by_study = dict()  # study id -> list of 30-frame slices
    total = 0
    from_dir = from_dir if from_dir.endswith('/') else from_dir + '/'
    for subdir, dirs, files in os.walk(from_dir):
        dirs.sort()  # in-place sort makes os.walk descend in a deterministic order
        subdir = subdir.replace('\\', '/')  # windows path fix
        path_parts = subdir.split('/')
        # look at the folder's own name only, so a "sax" somewhere in the parent
        # path does not make every folder look like a short-axis slice
        if "sax" not in path_parts[-1]:
            continue
        dcm_files = sorted(f for f in files if f.endswith('.dcm'))
        if not dcm_files:
            continue
        # resolved only for real slice folders; evaluating it for every walked
        # directory raised IndexError on short relative paths
        study_id = path_parts[-3]

        frames = []
        for f in dcm_files:
            image = dicom.read_file(os.path.join(subdir, f))
            image = image.pixel_array.astype(float)
            peak = np.max(image)
            if peak > 0:  # an all-zero frame would otherwise turn into NaNs
                image /= peak  # scale to [0,1]
            if img_resize:
                image = crop_resize(image)
            frames.append(image)
            if verbose and total % 1000 == 0:
                print('Images processed {0}'.format(total))
            total += 1

        if study_id not in slices_by_study:
            ids.append(study_id)
            slices_by_study[study_id] = []
        slices_by_study[study_id].append(_fit_frame_count(frames))

    print('-'*50)
    print('All DICOM in {0} images loaded.'.format(from_dir))
    print('-'*50)

    study_to_images = dict()  # dictionary for studies to images
    for study_id in ids:
        study_to_images[study_id] = np.array(slices_by_study[study_id])

    return ids, study_to_images


def map_studies_results(csv_path=None):
    """
    Maps studies to their respective targets.

    The first line of the file is treated as a header and skipped; blank lines are
    ignored. Every other line must consist of exactly three comma separated fields:
    the study id followed by two numeric targets.

    :param csv_path: path of the csv file; defaults to DATA_DIR + 'train.csv'
    :return: dictionary mapping the study id (string) to a list with the two targets
             as floats, in the order in which they appear in the file
    """
    if csv_path is None:
        csv_path = DATA_DIR+'train.csv'

    id_to_results = dict()
    with open(csv_path) as train_csv:  # the file is closed even if parsing fails
        next(train_csv, None)  # skip the header line
        for line in train_csv:
            line = line.strip()
            if not line:
                continue  # e.g. a trailing empty line at the end of the file
            # NOTE(review): the names assume the column order id, diastole, systole;
            # the values are stored in file order either way - confirm against the header
            study_id, diastole, systole = line.split(',')
            id_to_results[study_id] = [float(diastole), float(systole)]

    return id_to_results

def append_data(ds, data):
    """
    Grow a resizable dataset along its first axis and store data in the new rows.

    :param ds: dataset offering shape, resize(size, axis=...) and slice assignment
    :param data: array whose first dimension is the number of rows to add
    """
    first_new_row = ds.shape[0]
    ds.resize(first_new_row + data.shape[0], axis=0)
    ds[first_new_row:, ...] = data

In [3]:
"""
Loads the training data set including X and y and saves it to .h5 file.
"""
print('-'*50)
print('Writing training data to .h5 file...')
print('-'*50)

TRAIN_FRACTION = 0.92  # probability that a study is assigned to the training split
SPLIT_SEED = 0  # fixed seed so that the split is identical on every run

study_ids, images = load_images(DATA_DIR+'train')  # load images and their ids
studies_to_results = map_studies_results()  # load the dictionary of studies to targets

if not study_ids:
    raise ValueError('No studies found in {0}'.format(DATA_DIR+'train'))

# The first study only provides the trailing dimensions and the dtype of the
# datasets; it is no longer written to both splits.
first_study = images[study_ids[0]]
shpX = first_study.shape[1:]
rng = np.random.RandomState(SPLIT_SEED)

YTr = []  # one target pair per row of XTrain
YVl = []  # one target pair per row of XVal

with h5py.File(DATA_DIR+'trainData.h5', 'w') as dataFile:
    # both datasets start empty and grow along the first axis through append_data
    XTrain = dataFile.create_dataset('XTrain', shape=(0,) + tuple(shpX), dtype=first_study.dtype,
                                     maxshape=(None,) + tuple(shpX))
    XVal = dataFile.create_dataset('XVal', shape=(0,) + tuple(shpX), dtype=first_study.dtype,
                                   maxshape=(None,) + tuple(shpX))
    for study_id in study_ids:
        # augmentation (preprocess / rotation_augmentation / shift_augmentation) is not applied here
        x = images[study_id]
        y = studies_to_results[study_id]
        # randomly assign each whole study to either training or validation, so
        # slices of one study never end up in both splits
        if rng.rand() <= TRAIN_FRACTION:
            append_data(XTrain, x)
            YTr.extend([y] * x.shape[0])
        else:
            append_data(XVal, x)
            YVl.extend([y] * x.shape[0])

    # reshape keeps a (rows, 2) layout even when a split received no study
    dataFile.create_dataset('YTrain', data=np.array(YTr, dtype=float).reshape(-1, 2))
    dataFile.create_dataset('YVal', data=np.array(YVl, dtype=float).reshape(-1, 2))

print('Done saving (split) Training data.')

--------------------------------------------------
Writing training data to .h5 file...
--------------------------------------------------
--------------------------------------------------
Loading all DICOM images from /media/haidar/Storage/Data/SADSB/train...
--------------------------------------------------
Images processed 0
Images processed 1000
Images processed 2000
Images processed 3000
Images processed 4000
Images processed 5000
Images processed 6000
Images processed 7000
Images processed 8000
Images processed 9000
Images processed 10000
Images processed 11000
Images processed 12000
Images processed 13000
Images processed 14000
Images processed 15000
Images processed 16000
Images processed 17000
Images processed 18000
Images processed 19000
Images processed 20000
Images processed 21000
Images processed 22000
Images processed 23000
Images processed 24000
Images processed 25000
Images processed 26000
Images processed 27000
Images processed 28000
Images processed 29000
Images pro

In [4]:
"""
Loads the validation data set including X and study ids and saves it to .h5 file.
"""
print('-'*50)
print('Writing validation data to .h5 file...')
print('-'*50)

ids, images = load_images(DATA_DIR+'validate')

# The first study creates the dataset; its slice shape fixes the trailing dimensions.
study_id = ids[0]
x = images[study_id]
y = study_id
Y = [y] * x.shape[0]  # one study id per row of XVal

shpX = x.shape[1:]

with h5py.File(DATA_DIR+'valData.h5', 'w') as dataFile:
    XVal = dataFile.create_dataset('XVal', data=x,
                                   maxshape=(None, shpX[0], shpX[1], shpX[2]))

    # every remaining study is appended below the rows written so far
    for study_id in ids[1:]:
        x = images[study_id]
        y = study_id
        append_data(XVal, x)
        Y.extend([y] * x.shape[0])
    YVal = dataFile.create_dataset('YVal', data=Y)

print('Done saving validation data.')

--------------------------------------------------
Writing validation data to .h5 file...
--------------------------------------------------
--------------------------------------------------
Loading all DICOM images from /media/haidar/Storage/Data/SADSB/validate...
--------------------------------------------------
Images processed 0
Images processed 1000
Images processed 2000
Images processed 3000
Images processed 4000
Images processed 5000
Images processed 6000
Images processed 7000
Images processed 8000
Images processed 9000
Images processed 10000
Images processed 11000
Images processed 12000
Images processed 13000
Images processed 14000
Images processed 15000
Images processed 16000
Images processed 17000
Images processed 18000
Images processed 19000
Images processed 20000
Images processed 21000
Images processed 22000
Images processed 23000
Images processed 24000
Images processed 25000
Images processed 26000
Images processed 27000
Images processed 28000
Images processed 29000
Image