In [18]:
## Load Imports ##
import nibabel as nib
import numpy as np
from scipy import stats
import time

import warnings
warnings.filterwarnings('ignore')


import os
# Device selection must be in the environment before the deep-learning backend
# initialises CUDA (presumably when `unet` is imported below): order GPUs by PCI
# bus id, then hide all of them so everything runs on CPU.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "",
})

# Make the parent directory importable (presumably where the local modules
# unet / metrics / data_generator / data_loader live); append it only once so
# re-running this cell does not grow sys.path.
import sys
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)
 

from unet import myUnet
from metrics import dice_coef
from data_generator import DataGenerator
from data_loader import *

# Reload edited local modules automatically before each cell executes, so changes
# to the project .py files are picked up without restarting the kernel.
# Re-running this cell in a live kernel makes %load_ext report that the extension
# is already loaded; that message is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
## Classification Parallelization ##
from multiprocessing import Process

def runInParallel(*fns):
    """Run each zero-argument callable in its own child process, then wait for all.

    Every process is started before any is joined, so the callables run
    concurrently. Their return values and exit codes are not inspected, and the
    function itself returns None.

    NOTE(review): under the "spawn"/"forkserver" start methods the targets must be
    picklable (importable by name); functions defined only in a notebook's
    __main__ may fail there — confirm on the target platform.
    """
    def _spawn(target):
        # Create and immediately start one worker, in the caller's order.
        worker = Process(target=target)
        worker.start()
        return worker

    workers = [_spawn(target) for target in fns]
    for worker in workers:
        worker.join()

def a(x):
    """Identity helper: return the argument itself, unchanged (same object).

    NOTE(review): not called in any cell shown — presumably a placeholder target
    for runInParallel experiments; consider removing.
    """
    unchanged = x
    return unchanged



In [25]:
# File paths for the different slice-orientation data sources
patient_fp = '../patient1/'
side_fp = (patient_fp + 'slice_data_side')
back_fp = (patient_fp + 'slice_data_back')
top_fp = (patient_fp + 'slice_data_top')

# NOTE(review): this is NOT the patient's original scan — it is uniform random
# noise of shape (256, 256, 150) standing in for it. In the cells shown it is only
# used for its shape (np.zeros_like / padding). TODO: load the real volume (e.g.
# with nibabel, already imported) and drop this placeholder.
patient = np.random.rand(256,256,150)

# Params for generators: one 256x256 single-channel slice per batch, unshuffled
# (presumably so predictions stay in the same order as x_side / x_back / x_top).
params = {'dim': (256,256),
          'batch_size': 1,
          'n_channels': 1,
          'shuffle': False}


def _make_predict_gen(data_fp, gen_params=params):
    """Load the slices stored under `data_fp` and wrap them in a DataGenerator.

    `load_data` returns a 6-tuple; only its last two items (x, y) are kept.
    NOTE(review): split=(0, 0, 100) presumably sends 100% of the data to that
    last split — confirm against data_loader.load_data.

    Args:
        data_fp: directory of one slice orientation.
        gen_params: keyword arguments forwarded to DataGenerator.

    Returns:
        (x, y, generator) with generator = DataGenerator(x, y, **gen_params).
    """
    (_, _, _, _, x, y) = load_data(data_fp, split=(0, 0, 100))
    return x, y, DataGenerator(x, y, **gen_params)


## Load side, back and top data ##
x_side, y_side, predict_side_gen = _make_predict_gen(side_fp)
x_back, y_back, predict_back_gen = _make_predict_gen(back_fp)
x_top, y_top, predict_top_gen = _make_predict_gen(top_fp)



In [20]:
### Load Models ###
dim = (256, 256)
# Directory prefix (with trailing slash) holding one weights file per orientation
WEIGHTS_DIR = '../weights/'


def _load_slice_model(slice_type, dim=dim, weights_dir=WEIGHTS_DIR):
    """Build the Zhi U-Net for one slice orientation and load its trained weights.

    Args:
        slice_type: orientation name ('side', 'back' or 'top'); selects the file
            `<weights_dir>zhi_unet_<slice_type>.hdf5`.
        dim: (rows, cols) input size passed to myUnet.
        weights_dir: directory prefix of the weights files.

    Returns:
        (unet, model): the myUnet wrapper and the model returned by
        get_unet_zhi(), with its weights loaded.
    """
    model_prefix = 'zhi_unet_' + slice_type
    weights_fp = (weights_dir + model_prefix + '.hdf5')

    unet = myUnet(img_rows=dim[0], img_cols=dim[1])
    model = unet.get_unet_zhi()
    model.load_weights(weights_fp)
    return unet, model


side_unet, side_model = _load_slice_model('side')
back_unet, back_model = _load_slice_model('back')
top_unet, top_model = _load_slice_model('top')



In [31]:
# Derived from the volume's shape instead of hard-coding 256/150 (value is 53 for
# the (256, 256, 150) `patient` array), so it stays correct if the scan changes.
# NOTE(review): presumably the per-side padding that grows the slice axis to the
# in-plane size — confirm against whatever consumes `pad` (unused in cells shown).
pad = (patient.shape[0] - patient.shape[2]) // 2


In [39]:
# Output volume with the same shape as `patient`; allocated before timing starts so
# only prediction is measured. NOTE(review): not yet filled in any cell shown —
# presumably reserved for reconstructing the 3-D segmentation.
mask = np.zeros_like(patient)

### Predict every side slice via its generator (reads from side_fp, not `patient`) ###
# perf_counter is monotonic and high-resolution — the right clock for intervals;
# time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
side_output = side_model.predict_generator(predict_side_gen)
end = time.perf_counter()
print(end - start)





12.17639446258545


In [44]:
### Predict every back slice via its generator (reads from back_fp, not `patient`) ###
# perf_counter is monotonic and high-resolution — the right clock for intervals;
# time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
back_output = back_model.predict_generator(predict_back_gen)
end = time.perf_counter()
print(end - start)

20.66177463531494


In [41]:
### Predict every top slice via its generator (reads from top_fp, not `patient`) ###
# perf_counter is monotonic and high-resolution — the right clock for intervals;
# time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
top_output = top_model.predict_generator(predict_top_gen)
end = time.perf_counter()
print(end - start)

### Reconstruct output ###
# TODO: this section is still empty (moved out of the timed region above);
# presumably it should combine side_output, back_output and top_output into the
# 3-D mask.


20.627891302108765
