In [1]:
# --- Notebook configuration --------------------------------------------------

# Shard results are read from INPUT_FOLDER; the merged dataset, log file and
# images are written under OUTPUT_FOLDER. File names are appended to these
# values by plain string concatenation, so both must end with a '/'.
INPUT_FOLDER = "../../output/step3/"
OUTPUT_FOLDER = "../../output/step4/"

# Shape of one image volume: (depth, height, width, channels).
IMAGE_DIMS = (224, 152, 224, 1)

# Number of shards to merge; shard ids run from 1 to NR_SHARDS inclusive.
NR_SHARDS = 2

# True: plotting helpers save figures under OUTPUT_FOLDER + 'images/'.
# False: figures are displayed with plt.show() instead.
SAVE_IMAGES = True

In [2]:
import sys
import h5py
from random import shuffle
import numpy as np
from numpy import ndarray
import pandas as pd
import statistics
import csv
import dicom
import math
from time import time
import os
import shutil
import scipy.ndimage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import scipy.ndimage as ndimage
from scipy.ndimage.interpolation import rotate
from scipy.ndimage.interpolation import shift
import itertools
from itertools import product, combinations
from skimage import measure, morphology, transform
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import datetime
import logging

In [3]:
class Timer:
    """Simple wall-clock stopwatch that can log when it starts and stops.

    The timer starts as soon as it is constructed. When ``debug`` is true,
    ``start()`` and ``stop()`` emit INFO messages through the module-level
    ``logger``, which therefore has to exist by the time they are called.

    Args:
        name: Label used in the log messages.
        debug: Whether start/stop should log (default True).
    """

    def __init__(self, name, debug=True):
        self._name = name
        self._debug = debug
        # Defined up front so elapsed() is safe to call before stop().
        self._lastElapsed = None
        self.start()

    def start(self):
        """(Re)start the timer and forget any previously frozen duration."""
        self._start = time()
        # A restart must not keep reporting the duration of the previous run.
        self._lastElapsed = None
        if self._debug:
            logger.info('> [started] ' + self._name + '...')

    def stop(self):
        """Freeze the elapsed time (in seconds) and optionally log it in ms."""
        self._lastElapsed = time() - self._start
        if self._debug:
            logger.info('> [done]    {} ({:.3f} ms)'.format(self._name, self._lastElapsed*1000))

    def elapsed(self):
        """Return elapsed seconds.

        After ``stop()`` this is the frozen duration; while the timer is still
        running it is the time since the last ``start()``.
        """
        if self._lastElapsed is not None:
            return self._lastElapsed
        return time() - self._start


In [4]:
# Root logger shared by every cell of this notebook.
logger = logging.getLogger()
# The root logger defaults to WARNING, which silently drops every
# logger.info(...) call made before setup_file_logger() lowers the level.
# Only lower the threshold; never raise one that is already more verbose.
if logger.getEffectiveLevel() > logging.INFO:
    logger.setLevel(logging.INFO)

# Common line format, reused by the file handler created in setup_file_logger().
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')

# Console handler: INFO and above.
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
logger.addHandler(sh)

def setup_file_logger(log_file):
    """Send log records to ``log_file`` in addition to the existing handlers.

    A new ``logging.FileHandler`` accepting DEBUG and above is attached to the
    module-level ``logger`` using the shared ``formatter``, and the logger's
    own level is set to DEBUG.

    Args:
        log_file: Path of the log file to write to.
    """
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.DEBUG)

    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)

In [5]:
def show_image(pixels, slice_pos, patient_id):
    """Render a single slice of a volume in grayscale.

    The slice index is ``round(depth * (slice_pos - 1))`` where ``depth`` is
    the size of the first axis of ``pixels``; channel 0 of that slice is drawn.
    NOTE(review): slice_pos looks like a fraction in [0, 1] that relies on
    negative indexing to pick the slice -- confirm against callers.

    Depending on SAVE_IMAGES the figure is either written to
    ``OUTPUT_FOLDER + 'images/'`` (and closed) or shown with ``plt.show()``.

    Args:
        pixels: Volume indexable as ``pixels[slice][:, :, channel]``.
        slice_pos: Position used to compute the slice index (see above).
        patient_id: String used as the prefix of the saved file name.
    """
    figure, axes = plt.subplots(1)
    figure.set_size_inches(4, 4)

    depth = np.shape(pixels)[0]
    slice_index = round(depth*(slice_pos-1))
    axes.imshow(pixels[slice_index][:,:,0], cmap=plt.cm.gray)

    if not SAVE_IMAGES:
        plt.show()
        return

    target = OUTPUT_FOLDER + 'images/' + patient_id + '-' + 'slice-' + str(slice_pos) + '.jpg'
    plt.savefig(target)
    plt.close(figure)


In [6]:
def show_slices(pixels, patient_id, nr_slices=12, cols=4):
    """Draw a grid of evenly spaced slices of a volume in grayscale.

    ``nr_slices`` slices are taken along the first axis of ``pixels`` at a
    stride of ``round(depth / nr_slices)`` and channel 0 of each is plotted in
    a grid with ``cols`` columns and ``round(nr_slices / cols) + 1`` rows.

    Depending on SAVE_IMAGES the figure is either written to
    ``OUTPUT_FOLDER + 'images/<patient_id>-slices.jpg'`` (and closed) or shown
    with ``plt.show()``.

    Args:
        pixels: Volume indexable as ``pixels[slice][:, :, channel]``.
        patient_id: String used as the prefix of the saved file name.
        nr_slices: Number of slices to draw (default 12).
        cols: Number of grid columns (default 4).
    """
    fig = plt.figure()
    depth = np.shape(pixels)[0]
    slice_depth = round(depth/nr_slices)
    rows = round(nr_slices/cols)+1
    fig.set_size_inches(cols*10, rows*10)
    for i in range(nr_slices):
        # round() can round the stride up, so stride*i may run past the last
        # slice (e.g. depth 18 with 12 slices -> stride 2 -> index 22).
        # Clamp to the last valid index instead of raising IndexError.
        slice_pos = min(int(slice_depth*i), depth-1)
        y = fig.add_subplot(rows,cols,i+1)
        y.imshow(pixels[slice_pos][:,:,0], cmap='gray')

    if(SAVE_IMAGES):
        file = OUTPUT_FOLDER + 'images/' + patient_id + '-' + 'slices.jpg'
        plt.savefig(file)
        plt.close(fig)
    else:
        plt.show()


In [7]:
def validate_dataset(dataset_file):
    """Run sanity checks on a merged HDF5 dataset and log the findings.

    The file is opened read-only and its 'X' and 'Y' datasets are checked:
      * a warning is logged when their lengths differ;
      * a warning is logged for every X or Y entry that is entirely zero;
      * both shapes are logged;
      * up to 10 evenly spaced entries are rendered with show_slices() and
        their labels logged.

    Problems are only logged; nothing is raised for a failed check.

    Args:
        dataset_file: Path of the HDF5 file to validate.
    """
    logger.info('VALIDATING OUTPUT DATASET ' + dataset_file)

    with h5py.File(dataset_file, 'r') as h5f:
        x_ds = h5f['X']
        y_ds = h5f['Y']

        if(len(x_ds) != len(y_ds)):
            logger.warning('VALIDATION ERROR: x and y datasets with different lengths')

        # logger.warn is a deprecated alias (removed in Python 3.13); use warning.
        for px in range(len(x_ds)):
            arr = np.array(x_ds[px])
            if(not np.any(arr)):
                logger.warning('VALIDATION ERROR: No image found index=' + str(px))

        for py in range(len(y_ds)):
            arr = np.array(y_ds[py])
            if(not np.any(arr)):
                logger.warning('VALIDATION ERROR: No label found index=' + str(py) + 'label=' + str(arr))

        logger.info('X shape=' + str(x_ds.shape))
        logger.info('Y shape=' + str(y_ds.shape))

        size = len(x_ds)
        if size == 0:
            # Nothing to sample; also avoids dividing by qtty == 0 below.
            logger.warning('VALIDATION ERROR: dataset is empty')
            return

        # Sample at most 10 entries spread evenly over the dataset.
        qtty = min(10, size)
        f = size/qtty
        for i in range(qtty):
            pi = round(i*f)
            logger.info('patient_index' + str(pi))
            logger.info('x=')
            show_slices(x_ds[pi], 'validation-' + str(pi))
            if pi < len(y_ds):
                logger.info('y=' + str(y_ds[pi]))
            else:
                # Y is shorter than X (already reported above); do not crash.
                logger.warning('VALIDATION ERROR: No label found index=' + str(pi))


In [8]:
def start_processing(input_dir, nr_shards, image_dims, output_dir):
    """Merge the per-shard HDF5 datasets into a single dataset and validate it.

    Steps:
      1. Delete ``output_dir`` (errors ignored) and recreate it with an
         ``images`` sub-directory, then start logging to ``out.log`` in it.
      2. Open every shard file
         ``<input_dir>/<shard_id>/data-centered-rotated-<d0>-<d1>-<d2>.h5``
         (shard ids 1..nr_shards) to count the total number of 'X' entries.
      3. Create the merged file in ``output_dir`` with float datasets 'X'
         (one chunk per entry) and 'Y' (two values per entry) and copy each
         shard's 'X' and 'Y' into consecutive ranges.
      4. Run validate_dataset() on the merged file.

    Args:
        input_dir: Directory containing one sub-directory per shard id.
        nr_shards: Number of shards to merge.
        image_dims: Sequence of 4 ints describing one X entry; the first three
            are also used in the dataset file name.
        output_dir: Directory to (re)create and write the results to.
    """
    logger.info('Merging shard results. nr_shards=' + str(nr_shards) + ' input_dir='+ str(input_dir) + ' output_dir=' + output_dir)

    # Every shard and the merged output share the same file name.
    dataset_name = 'data-centered-rotated-{}-{}-{}.h5'.format(image_dims[0], image_dims[1], image_dims[2])
    # os.path.join yields the same paths as the previous string concatenation
    # when the directories end with '/', and also works when they do not.
    shard_files = [
        (shard_id, os.path.join(input_dir, str(shard_id), dataset_name))
        for shard_id in range(1, nr_shards+1)
    ]
    volume_shape = (image_dims[0], image_dims[1], image_dims[2], image_dims[3])

    t = Timer('Preparing output dir')
    shutil.rmtree(output_dir, True)
    try:
        os.makedirs(os.path.join(output_dir, 'images'))
    except OSError:
        # Best effort: the directory may have survived a partial rmtree.
        # Only filesystem errors are tolerated; KeyboardInterrupt and
        # programming errors are no longer swallowed by a bare except.
        logger.error('Ops! Couldnt create output dir', exc_info=True)
    t.stop()

    setup_file_logger(os.path.join(output_dir, 'out.log'))

    t = Timer('Count total patients among shards')
    total_patients = 0
    for shard_id, dataset_file in shard_files:
        with h5py.File(dataset_file, 'r') as h5f:
            logger.info('shard_id={} shape={}'.format(shard_id,h5f['X'].shape))
            total_patients = total_patients + len(h5f['X'])
    t.stop()

    logger.info('total_patients=' + str(total_patients))

    t = Timer('Creating output merged dataset')
    output_dataset_file = os.path.join(output_dir, dataset_name)
    with h5py.File(output_dataset_file, 'w') as h5f:
        x_ds = h5f.create_dataset('X', (total_patients,) + volume_shape, chunks=(1,) + volume_shape, dtype='f')
        y_ds = h5f.create_dataset('Y', (total_patients, 2), dtype='f')

        logger.info('Merging shards')
        pb = 0  # index in the merged dataset where the next shard begins
        for shard_id, dataset_file in shard_files:
            ts = Timer('Processing shard' + str(shard_id))
            with h5py.File(dataset_file, 'r') as sh5f:
                shard_x_ds = sh5f['X']
                shard_y_ds = sh5f['Y']
                le = len(shard_x_ds)
                pe = pb + le
                # The old message glued "0" and the length together ("019").
                logger.debug('output ' + str(pb) + ':' + str(pe) + ' input ' + str(0) + ':' + str(le))
                x_ds[pb:pe] = shard_x_ds[0:le]
                y_ds[pb:pe] = shard_y_ds[0:le]
                pb = pe
            ts.stop()
    t.stop()

    t = Timer('Output dataset validations')
    validate_dataset(output_dataset_file)
    t.stop()

In [9]:
# Entry point: merge the NR_SHARDS shard datasets found under INPUT_FOLDER into
# a single dataset under OUTPUT_FOLDER and validate it. Note that
# start_processing() deletes OUTPUT_FOLDER before recreating it.
logger.info('==== PROCESSING SHARDS MERGE ====')
start_processing(INPUT_FOLDER, NR_SHARDS, IMAGE_DIMS, OUTPUT_FOLDER)
logger.info('==== ALL DONE ====')

2017-02-19 19:55:23,915 INFO > [started] Count total patients among shards...
2017-02-19 19:55:23,917 INFO shard_id=1 shape=(19, 224, 152, 224, 1)
2017-02-19 19:55:23,919 INFO shard_id=2 shape=(19, 224, 152, 224, 1)
2017-02-19 19:55:23,920 INFO > [done]    Count total patients among shards (4.851 ms)
2017-02-19 19:55:23,921 INFO total_patients=38
2017-02-19 19:55:23,922 INFO > [started] Creating output merged dataset...
2017-02-19 19:55:23,924 INFO Merging shards
2017-02-19 19:55:23,925 INFO > [started] Processing shard1...
2017-02-19 19:55:24,749 INFO > [done]    Processing shard1 (824.677 ms)
2017-02-19 19:55:24,750 INFO > [started] Processing shard2...
2017-02-19 19:55:25,602 INFO > [done]    Processing shard2 (851.977 ms)
2017-02-19 19:55:25,611 INFO > [done]    Creating output merged dataset (1688.656 ms)
2017-02-19 19:55:25,611 INFO > [started] Output dataset validations...
2017-02-19 19:55:25,612 INFO VALIDATING OUTPUT DATASET ../../output/step4/data-centered-rotated-224-152-224