In [1]:
# Volume shapes as stored in the HDF5 'X' datasets. The first axis is the slice index
# and the last axis is a single channel.
# Step 4 produced INPUT-sized volumes; this notebook zooms them down to OUTPUT size.
INPUT_IMAGE_DIMS, OUTPUT_IMAGE_DIMS = (224, 152, 224, 1), (50, 34, 50, 1)

# True -> plots are written as .jpg files under OUTPUT_FOLDER + 'images/' instead of being shown.
SAVE_IMAGES = True

# Input (step 4 output) and output directories for this step. Both must end with '/',
# because file names are appended to them by plain string concatenation.
INPUT_FOLDER, OUTPUT_FOLDER = '../../output/step4/', '../../output/step5/'

In [2]:
import sys
import h5py
from random import shuffle
import numpy as np
from numpy import ndarray
import pandas as pd
import statistics
import csv
import dicom
import math
from time import time
import os
import shutil
import scipy.ndimage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import scipy.ndimage as ndimage
from scipy.ndimage.interpolation import rotate
from scipy.ndimage.interpolation import shift
import itertools
from itertools import product, combinations
from skimage import measure, morphology, transform
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import datetime
import logging

In [3]:
class Timer:
    """Wall-clock stopwatch that optionally logs start/stop events via the module-level `logger`.

    The timer starts on construction. `stop()` records the elapsed time. `elapsed()`
    returns that recorded value, or the running time if `stop()` has not been called
    since the most recent `start()`.
    """

    def __init__(self, name, debug=True):
        """Create and immediately start a timer.

        Args:
            name: label included in the log messages.
            debug: when True, '[started]' / '[done]' lines are logged via `logger`.
        """
        self._name = name
        self._debug = debug
        self.start()

    def start(self):
        """(Re)start the timer and forget any previously recorded elapsed time."""
        self._start = time()
        # Initialise here so elapsed() works before the first stop() and does not
        # return a stale value after a restart.
        self._lastElapsed = None
        if self._debug:
            logger.info('> [started] ' + self._name + '...')

    def stop(self):
        """Record the seconds elapsed since the last start() and log it in milliseconds."""
        self._lastElapsed = time() - self._start
        if self._debug:
            logger.info('> [done]    {} ({:.3f} ms)'.format(self._name, self._lastElapsed * 1000))

    def elapsed(self):
        """Return the recorded elapsed seconds, or the running time if not stopped yet."""
        if self._lastElapsed is not None:
            return self._lastElapsed
        return time() - self._start


In [4]:
logger = logging.getLogger()
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')

# The root logger defaults to WARNING, which silently drops every logger.info() call made
# before setup_file_logger() lowers the level. Only raise verbosity here, never reduce it,
# so a re-run after setup_file_logger() keeps DEBUG.
if logger.getEffectiveLevel() > logging.INFO:
    logger.setLevel(logging.INFO)

# Re-executing this cell would otherwise stack another console handler and print every
# record once more per re-run. Instead, replace the one added earlier (matched by name).
for _old_handler in [h for h in logger.handlers if h.get_name() == 'notebook-console']:
    logger.removeHandler(_old_handler)
sh = logging.StreamHandler()
sh.set_name('notebook-console')
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
logger.addHandler(sh)

def setup_file_logger(log_file):
    """Attach a DEBUG-level file handler for `log_file` to the root logger.

    Any file handler already attached for the same path is detached and closed first.
    Calling this again therefore does not duplicate every line. This matters for a pipeline
    re-run that deletes and recreates the output directory, where the old handler would keep
    writing to the deleted file. The root logger level is also lowered to DEBUG so debug
    records reach the file.

    Args:
        log_file: path of the log file; it is opened in append mode by logging.FileHandler.
    """
    target = os.path.abspath(log_file)
    for old in [h for h in logger.handlers
                if isinstance(h, logging.FileHandler) and h.baseFilename == target]:
        logger.removeHandler(old)
        old.close()

    hdlr = logging.FileHandler(log_file)
    hdlr.setLevel(logging.DEBUG)
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.DEBUG)

In [5]:
def show_image(pixels, slice_pos, patient_id):
    """Render one slice of a volume, saving it to disk or displaying it depending on SAVE_IMAGES.

    Args:
        pixels: array-like whose first axis selects a slice. Each slice is displayed
            as `[:, :, 0]`, i.e. channel 0 of a 3-D slice.
        slice_pos: relative position along the first axis, presumably in [0, 1]
            (0 -> first slice, 1 -> last slice). Out-of-range values are clamped.
        patient_id: string used to build the saved file name.

    When SAVE_IMAGES is true, the figure is written to
    OUTPUT_FOLDER + 'images/<patient_id>-slice-<slice_pos>.jpg' and closed.
    Otherwise it is shown.
    """
    nr_slices = np.shape(pixels)[0]
    # The old index round(n * (slice_pos - 1)) only worked through negative indexing and
    # wrapped slice_pos=1 back to the FIRST slice. Index forward and clamp instead.
    slice_index = min(max(int(round(nr_slices * slice_pos)), 0), nr_slices - 1)

    fig1, ax1 = plt.subplots(1)
    fig1.set_size_inches(4, 4)
    ax1.imshow(pixels[slice_index][:, :, 0], cmap=plt.cm.gray)

    if SAVE_IMAGES:
        file = OUTPUT_FOLDER + 'images/' + patient_id + '-' + 'slice-' + str(slice_pos) + '.jpg'
        plt.savefig(file)
        plt.close(fig1)
    else:
        plt.show()


In [6]:
def show_slices(pixels, patient_id, nr_slices=12, cols=4):
    """Plot `nr_slices` evenly spaced slices of a volume in a grid with `cols` columns.

    Args:
        pixels: array-like whose first axis selects a slice. Each slice is displayed
            as `[:, :, 0]`.
        patient_id: string used to build the saved file name.
        nr_slices: number of slices (subplots) to draw.
        cols: number of subplot columns; the row count is derived from it.

    When SAVE_IMAGES is true, the figure is written to
    OUTPUT_FOLDER + 'images/<patient_id>-slices.jpg' and closed.
    Otherwise it is shown.
    """
    depth = np.shape(pixels)[0]
    # Exactly enough rows to hold every subplot.
    rows = int(math.ceil(nr_slices / float(cols)))

    fig = plt.figure()
    fig.set_size_inches(cols * 10, rows * 10)
    for i in range(nr_slices):
        # Floor of i * depth / nr_slices is always < depth. The former rounded step
        # round(depth / nr_slices) could overshoot the last slice (IndexError) or be 0.
        slice_pos = (i * depth) // nr_slices
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(pixels[slice_pos][:, :, 0], cmap='gray')

    if SAVE_IMAGES:
        file = OUTPUT_FOLDER + 'images/' + patient_id + '-' + 'slices.jpg'
        plt.savefig(file)
        plt.close(fig)
    else:
        plt.show()


In [7]:
def validate_dataset(dataset_file):
    """Sanity-check an HDF5 file holding 'X' (images) and 'Y' (labels) and log any problems.

    It logs errors, without raising, when:
    - the X and Y lengths differ;
    - an X entry is all zeros;
    - a Y entry is all zeros.

    It then logs both shapes and renders up to 10 evenly spaced samples with
    show_slices for manual inspection.

    Args:
        dataset_file: path of the HDF5 file to check.
    """
    logger.info('VALIDATING OUTPUT DATASET ' + dataset_file)

    with h5py.File(dataset_file, 'r') as h5f:
        x_ds = h5f['X']
        y_ds = h5f['Y']

        if len(x_ds) != len(y_ds):
            logger.error('VALIDATION ERROR: x and y datasets with different lengths')

        for px in range(len(x_ds)):
            arr = np.array(x_ds[px])
            if not np.any(arr):
                logger.error('VALIDATION ERROR: No image found index=' + str(px))

        for py in range(len(y_ds)):
            arr = np.array(y_ds[py])
            if not np.any(arr):
                logger.error('VALIDATION ERROR: No label found index=' + str(py) + ' label=' + str(arr))

        logger.info('X shape=' + str(x_ds.shape))
        logger.info('Y shape=' + str(y_ds.shape))

        logger.info('Taking some shots from the output sample for later verification')
        size = len(x_ds)
        qtty = min(10, size)
        if qtty == 0:
            # Nothing to sample; this also avoids a ZeroDivisionError in the stride below.
            logger.warning('Dataset is empty, no samples to show')
            return
        stride = size / float(qtty)
        for i in range(qtty):
            # (qtty - 1) * stride == size - stride <= size - 1, so pi is always a valid index.
            pi = int(round(i * stride))
            logger.info('patient_index ' + str(pi))
            logger.info('x=')
            show_slices(x_ds[pi], 'validation-' + str(pi))
            logger.info('y=' + str(y_ds[pi]))


In [8]:
def generate_dataset(source_dataset_file, output_dataset_file, source_start_index, source_end_index, image_dims):
    """Copy rows [source_start_index, source_end_index) of 'X' and 'Y' into a new HDF5 file.

    Args:
        source_dataset_file: HDF5 file with 'X' and 'Y' datasets to read from.
        output_dataset_file: HDF5 file to create; it is overwritten if it exists.
        source_start_index: first row to copy (inclusive).
        source_end_index: last row to copy (exclusive); must not exceed len(X).
        image_dims: target volume dims. Output X has shape
            (rows, image_dims[0], image_dims[1], image_dims[2], 1).

    Output 'Y' keeps the per-row shape of the source 'Y'. Both outputs use dtype 'f'.

    Raises:
        ValueError: if the index range is reversed or reaches past the end of the source.
    """
    with h5py.File(source_dataset_file, 'r') as input_h5f:
        input_x_ds = input_h5f['X']
        input_y_ds = input_h5f['Y']
        logger.info('input x shape={}'.format(input_x_ds.shape))

        output_len = source_end_index - source_start_index
        if output_len < 0 or source_end_index > len(input_x_ds):
            raise ValueError('invalid range [{}, {}) for dataset of length {}'.format(
                source_start_index, source_end_index, len(input_x_ds)))

        x_shape = (output_len, image_dims[0], image_dims[1], image_dims[2], 1)
        y_shape = (output_len,) + tuple(input_y_ds.shape[1:])

        with h5py.File(output_dataset_file, 'w') as output_h5f:
            # h5py rejects a chunk shape larger than the data shape, so an empty split
            # (possible with tiny inputs) is stored contiguous instead of chunked per sample.
            chunks = (1,) + x_shape[1:] if output_len > 0 else None
            output_x_ds = output_h5f.create_dataset('X', x_shape, chunks=chunks, dtype='f')
            output_y_ds = output_h5f.create_dataset('Y', y_shape, dtype='f')

            if output_len > 0:
                output_x_ds[...] = input_x_ds[source_start_index:source_end_index]
                output_y_ds[...] = input_y_ds[source_start_index:source_end_index]

In [13]:
def start_processing(input_dir, input_image_dims, output_image_dims, output_dir):
    """Resize every volume of the step-4 dataset, validate it, and split it into train/validate/test.

    Reads 'X'/'Y' from input_dir + 'data-centered-rotated-{d0}-{d1}-{d2}.h5'. Each 'X' volume
    is zoomed by output_image_dims / input_image_dims per spatial axis, with the channel axis
    left at 1. 'Y' is copied unchanged. The results go to
    output_dir + 'data-centered-rotated-{o0}-{o1}-{o2}.h5', which is then validated.
    Finally, contiguous (unshuffled) splits are written: the first 80% to train, the next
    10% to validate, and the remainder to test.

    Args:
        input_dir: directory prefix of the step-4 dataset; must end with a path separator.
        input_image_dims: input volume dims; the first three select the input file name.
        output_image_dims: target volume dims; the first three select the output file names.
        output_dir: directory prefix for all outputs. It is DELETED and recreated.

    Raises:
        ValueError: if the input 'X' and 'Y' datasets have different lengths.
    """
    logger.info('Resizing images. input_dir='+ str(input_dir) + ' output_dir=' + output_dir)

    t = Timer('Preparing output dir')
    shutil.rmtree(output_dir, True)  # ignore_errors=True: a missing directory is fine
    try:
        os.makedirs(output_dir + 'images')
    except OSError as err:
        # Not fatal here, but later image/log writes into this directory will fail.
        logger.warning('Ops! Couldnt create output dir: {}'.format(err))
    t.stop()

    setup_file_logger(output_dir + 'out.log')

    t = Timer('Starting to resize dataset')
    # float() keeps the ratios fractional even under Python 2 integer division.
    resize_factor = tuple(float(output_image_dims[k]) / input_image_dims[k] for k in range(3)) + (1,)
    dataset_file = input_dir + 'data-centered-rotated-{}-{}-{}.h5'.format(*input_image_dims[:3])
    output_dataset_file = output_dir + 'data-centered-rotated-{}-{}-{}.h5'.format(*output_image_dims[:3])
    with h5py.File(dataset_file, 'r') as input_h5f:
        input_x_ds = input_h5f['X']
        input_y_ds = input_h5f['Y']
        logger.info('input x shape={}'.format(input_x_ds.shape))
        len_input_x_ds = len(input_x_ds)
        if len(input_y_ds) != len_input_x_ds:
            raise ValueError('X and Y have different lengths: {} != {}'.format(len_input_x_ds, len(input_y_ds)))

        x_shape = (len_input_x_ds,) + tuple(output_image_dims[:3]) + (1,)
        y_shape = (len_input_x_ds,) + tuple(input_y_ds.shape[1:])
        with h5py.File(output_dataset_file, 'w') as output_h5f:
            output_x_ds = output_h5f.create_dataset('X', x_shape, chunks=(1,) + x_shape[1:], dtype='f')
            output_y_ds = output_h5f.create_dataset('Y', y_shape, dtype='f')

            for pi in range(len_input_x_ds):
                ts = Timer('Resizing patient index ' + str(pi))
                image_pixels = input_x_ds[pi]
                if image_pixels.ndim == len(resize_factor) - 1:
                    # zoom needs one factor per axis. Source volumes may lack the channel axis
                    # (the recorded run's input had shape (N, 224, 152, 224)), so add it.
                    image_pixels = np.expand_dims(image_pixels, axis=-1)
                image_pixels = scipy.ndimage.zoom(image_pixels, resize_factor)
                logger.info('resized shape=' + str(np.shape(image_pixels)))
                output_x_ds[pi] = image_pixels
                output_y_ds[pi] = input_y_ds[pi]
                ts.stop()
    t.stop()

    t = Timer('Validate output dataset')
    validate_dataset(output_dataset_file)
    t.stop()

    # Contiguous split boundaries; the test split takes whatever rounding leaves over.
    train_end = int(round(len_input_x_ds * 0.8))
    validate_end = train_end + int(round(len_input_x_ds * 0.1))
    splits = (
        ('train', 0, train_end),
        ('validate', train_end, validate_end),
        ('test', validate_end, len_input_x_ds),
    )
    for split_name, start_index, end_index in splits:
        t = Timer('Generate {} dataset'.format(split_name))
        split_file = output_dir + '{}-centered-rotated-{}-{}-{}.h5'.format(split_name, *output_image_dims[:3])
        generate_dataset(output_dataset_file, split_file, start_index, end_index, output_image_dims)
        t.stop()


In [12]:
# Run the whole step: resize, validate, then split into train/validate/test files.
# NOTE(review): the banner says "SHARDS MERGE" although this step resizes; text kept as-is.
logger.info('==== PROCESSING SHARDS MERGE ====')
start_processing(
    input_dir=INPUT_FOLDER,
    input_image_dims=INPUT_IMAGE_DIMS,
    output_image_dims=OUTPUT_IMAGE_DIMS,
    output_dir=OUTPUT_FOLDER,
)
logger.info('==== ALL DONE ====')

2017-02-19 22:09:37,114 INFO ==== PROCESSING SHARDS MERGE ====
2017-02-19 22:09:37,115 INFO Resizing images. input_dir=../../output/step4/ output_dir=../../output/step5/
2017-02-19 22:09:37,116 INFO > [started] Preparing output dir...
2017-02-19 22:09:37,118 INFO > [done]    Preparing output dir (1.326 ms)
2017-02-19 22:09:37,119 INFO > [started] Starting to resize dataset...
2017-02-19 22:09:37,120 INFO input x shape=(231, 224, 152, 224)
2017-02-19 22:09:37,123 INFO > [started] Resizing patient index 0...
2017-02-19 22:09:37,162 INFO redim shape=(224, 152, 224, 1)
2017-02-19 22:09:37,862 INFO resized shape=(50, 34, 50, 1)
2017-02-19 22:09:37,864 INFO > [done]    Resizing patient index 0 (740.428 ms)
2017-02-19 22:09:37,865 INFO > [started] Resizing patient index 1...
2017-02-19 22:09:38,016 INFO redim shape=(224, 152, 224, 1)
2017-02-19 22:09:38,695 INFO resized shape=(50, 34, 50, 1)
2017-02-19 22:09:38,697 INFO > [done]    Resizing patient index 1 (832.328 ms)
2017-02-19 22:09:38,698