In [1]:
# Four values used as the non-batch dimensions of the CNN input layer (see
# network()); the first three are also embedded in the HDF5 cache file name
# built by predict_patient().
IMAGE_DIMS = (50, 34, 50, 1)

# NOTE(review): not referenced by any of the cells visible in this notebook.
SAVE_IMAGES = True

# Directory prefixes. They are joined to file names with plain string
# concatenation, so the trailing '/' is required.
INPUT_FOLDER = '../../input/sample_images/'
OUTPUT_FOLDER = '../../output/step10/'

# CSV whose 'id' column lists the patients to predict (see get_patient_ids()).
PATIENTS_FILE = '../../input/sample_dummy_submission.csv'
# Checkpoint path handed to the tflearn model's load() in prepare_cnn().
CNN_MODEL_FILE = '../../output/train-local/tf-checkpoint-best7826'

# Process-wide cache for the loaded model, filled by prepare_cnn().
# Re-running this cell resets it to None, so the next prepare_cnn() call
# builds the network again.
_model = None

In [2]:
import csv
import sys
import h5py
from random import shuffle
import numpy as np # linear algebra
from numpy import ndarray
import statistics
import csv
import dicom
import math
from time import time
import os
import shutil
import scipy.ndimage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import scipy.ndimage as ndimage
from scipy.ndimage.interpolation import rotate
from scipy.ndimage.interpolation import shift
import itertools
from itertools import product, combinations
from skimage import measure, morphology, transform
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import logging

import tflearn
from tflearn.layers.core import *
from tflearn.layers.conv import *
from tflearn.data_utils import *
from tflearn.layers.normalization import *
from tflearn.layers.estimator import regression

from modules.logging import logger
import modules.logging
import modules.lungprepare as lungprepare
import modules.utils as utils
from modules.utils import Timer


In [3]:
def get_patient_ids(patients_file):
    """Return the patient ids listed in a CSV file.

    Args:
        patients_file: path to a CSV file with a header row containing an
            ``id`` column.

    Returns:
        list of the ``id`` values, in file order (empty if the file has no
        data rows).

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if a row has no ``id`` column.
    """
    # The context manager closes the handle as soon as the rows are read; the
    # csv module asks for newline='' so quoted line breaks are parsed correctly.
    with open(patients_file, newline='') as csv_file:
        patients = [row['id'] for row in csv.DictReader(csv_file)]

    logger.info('found {} patients for prediction'.format(len(patients)))
    return patients

In [4]:
def network(image_dims):
    """Assemble the tflearn graph for the 3D CNN classifier.

    Layer order: input -> two (conv_3d + max_pool_3d) stages with 32 and 64
    filters -> 64-unit fully connected layer -> dropout -> 2-unit softmax ->
    regression (adam, categorical cross-entropy).

    Args:
        image_dims: indexable with at least four entries; entries 0..3 become
            the non-batch dimensions of the input layer.

    Returns:
        The tflearn regression layer that terminates the graph.
    """
    # NOTE(review): `tf` is not imported explicitly in this notebook; it
    # presumably arrives through the tflearn wildcard imports -- confirm.
    input_shape = [None, image_dims[0], image_dims[1], image_dims[2], image_dims[3]]
    layer = input_data(shape=input_shape, dtype=tf.float32)

    # Two identical convolution + pooling stages that differ only in the
    # number of filters.
    for n_filters in (32, 64):
        layer = conv_3d(layer, n_filters, 3, strides=1, activation='relu')
        layer = max_pool_3d(layer, [1, 2, 2, 2, 1], strides=[1, 2, 2, 2, 1])

    layer = fully_connected(layer, 64, activation='relu')
    # NOTE(review): tflearn's dropout takes a keep probability, so 0.8 should
    # mean "keep 80%" -- confirm against the tflearn version in use.
    layer = dropout(layer, 0.8)

    layer = fully_connected(layer, 2, activation='softmax')

    return regression(
        layer,
        optimizer='adam',
        loss='categorical_crossentropy',
        learning_rate=0.001,
    )

In [5]:
def prepare_cnn(cnn_model_file, image_dims):
    """Return the CNN model, building and loading it on first use.

    The model is cached in the module-level ``_model``. Once cached, later
    calls return it unchanged and ignore both arguments.

    Args:
        cnn_model_file: checkpoint path passed to the model's ``load``.
        image_dims: input dimensions forwarded to ``network``.

    Returns:
        The cached ``tflearn`` DNN model.

    Raises:
        Whatever ``network``, the DNN constructor or ``load`` raise; in that
        case nothing is cached.
    """
    global _model
    if _model is not None:
        logger.info('CNN model already loaded. Reusing.')
        return _model

    logger.info('Prepare CNN')

    logger.info('Load CNN network...')
    net = network(image_dims)

    logger.info('Start engine...')
    model = tflearn.models.dnn.DNN(net)

    logger.info('Load previous training...')
    model.load(cnn_model_file)

    # Publish to the cache only after the checkpoint loaded. Assigning the
    # global before load() would leave an unloaded model cached if load()
    # raised, and every later call would silently reuse it.
    _model = model
    return _model

In [6]:
def predict_patient(model, input_dir, patient_id, image_dims, output_dir):
    """Run the CNN forward pass for one patient and return its prediction.

    The pre-processed image volume is looked up in an HDF5 cache file kept in
    ``output_dir``. On a cache miss it is produced by
    ``lungprepare.process_patient_images`` and stored in the cache.

    Args:
        model: object exposing ``predict(batch)``.
        input_dir: directory prefix; ``input_dir + patient_id`` is handed to
            ``lungprepare``, so it must end with a path separator.
        patient_id: patient identifier string; also the HDF5 dataset key.
        image_dims: forwarded to ``lungprepare``; its first three entries are
            part of the cache file name.
        output_dir: directory prefix (with trailing separator) holding the
            cache file.

    Returns:
        The value returned by ``model.predict`` for a batch of one volume, or
        ``None`` when no lung volume could be produced for the patient.
    """
    logger.info('>>> Predict patient_id ' + patient_id)
    logger.info('Loading pre-processed images for patient')

    dataset_file = output_dir + 'predict-centered-rotated-{}-{}-{}.h5'.format(image_dims[0], image_dims[1], image_dims[2])

    # Patient pre-processed image cache.
    patient_pixels = None
    with h5py.File(dataset_file, 'a') as h5f:
        if patient_id in h5f:
            logger.debug('Patient image found in cache. Using it.')
            # Copy into memory so the data outlives the HDF5 file handle.
            patient_pixels = np.array(h5f[patient_id])
        else:
            logger.debug('Patient image not found in cache')
            t = Timer('Preparing patient scan image volume. patient_id=' + patient_id)
            patient_pixels = lungprepare.process_patient_images(input_dir + patient_id, image_dims)
            if patient_pixels is None:
                # Actually skip: without this return the code went on to
                # store None in the HDF5 file, which fails, and never stopped
                # the timer.
                logger.warning('Patient lung not found. Skipping.')
                t.stop()
                return None
            logger.debug('Storing patient image in cache')
            h5f[patient_id] = patient_pixels
            t.stop()

    t = Timer('Predicting result on CNN (forward)')
    # The model receives a batch, so wrap the single volume in a leading axis.
    y = model.predict(np.expand_dims(patient_pixels, axis=0))
    logger.info('PATIENT '+ patient_id +' PREDICT=' + str(y))
    utils.show_slices(patient_pixels, patient_id)
    t.stop()

    return y

In [7]:
def start_processing(input_dir, patients_file, cnn_model_file, max_patients, image_dims, output_dir):
    """Predict up to ``max_patients`` patients listed in ``patients_file``.

    Creates ``output_dir + 'images/'`` if needed, attaches a file logger at
    ``output_dir + 'out.log'``, loads the CNN and runs ``predict_patient`` for
    each selected patient id, logging every result.

    Args:
        input_dir: directory prefix (with trailing separator) holding one
            entry per patient id.
        patients_file: CSV file read by ``get_patient_ids``.
        cnn_model_file: checkpoint path for ``prepare_cnn``.
        max_patients: maximum number of patients to process, taken from the
            start of the id list; ``None`` means no limit.
        image_dims: input dimensions for the network and the pre-processing.
        output_dir: directory prefix (with trailing separator) for the log,
            the images folder and the pre-processed image cache.
    """
    logger.info('Predicting patients. ' + ' max_patients='+ str(max_patients) + ' input_dir=' + input_dir + ' output_dir=' + output_dir)

    logger.info('Preparing output dir')
    # The directory is deliberately not wiped: predict_patient keeps its HDF5
    # cache under output_dir, so clearing it would discard the cache.
    try:
        # exist_ok avoids a bogus warning on every re-run, when the directory
        # is already there.
        os.makedirs(output_dir + 'images/', exist_ok=True)
    except OSError as err:
        # Best effort, as before: report the real failure and carry on.
        logger.warning('Ops! Couldnt create output dir: ' + str(err))

    modules.logging.setup_file_logger(output_dir + 'out.log')

    model = prepare_cnn(cnn_model_file, image_dims)

    logger.info('Collect patient ids for analysis')
    patient_ids = get_patient_ids(patients_file)
    total_patients = len(patient_ids)
    logger.debug('Found ' + str(total_patients) + ' patients')

    for count, patient_id in enumerate(patient_ids):
        # Same cut-off test as the original manual counter.
        if max_patients is not None and count > (max_patients - 1):
            break

        y = predict_patient(model, input_dir, patient_id, image_dims, output_dir)
        logger.info("Prediction for patient " + patient_id + ' is ' + str(y))

In [None]:
# Entry-point cell: run the prediction pipeline with the settings from the
# configuration cell.
logger.info('==== PROCESSING PREDICTION ====')
start_processing(
    input_dir=INPUT_FOLDER,
    patients_file=PATIENTS_FILE,
    cnn_model_file=CNN_MODEL_FILE,
    max_patients=9,  # only the first 9 ids in PATIENTS_FILE are predicted
    image_dims=IMAGE_DIMS,
    output_dir=OUTPUT_FOLDER,
)
logger.info('==== ALL DONE ====')

2017-02-21 04:15:26,986 INFO ==== PROCESSING PREDICTION ====
2017-02-21 04:15:26,988 INFO Predicting patients.  max_patients=9 input_dir=../../input/sample_images/ output_dir=../../output/step10/
2017-02-21 04:15:26,989 INFO Preparing output dir
2017-02-21 04:15:26,992 INFO CNN model already loaded. Reusing.
2017-02-21 04:15:26,993 INFO Collect patient ids for analysis
2017-02-21 04:15:26,994 INFO found 19 patients for prediction
2017-02-21 04:15:26,995 DEBUG Found 19 patients
2017-02-21 04:15:26,996 INFO >>> Predict patient_id 0de72529c30fe642bc60dcb75c87f6bd
2017-02-21 04:15:26,997 INFO Loading pre-processed images for patient
2017-02-21 04:15:27,000 DEBUG Patient image found in cache. Using it.
2017-02-21 04:15:27,005 INFO > [started] Predicting result on CNN (forward)...
2017-02-21 04:15:27,937 INFO PATIENT 0de72529c30fe642bc60dcb75c87f6bd PREDICT=[[0.904554009437561, 0.095445916056633]]
