#### To test: 
1. Create a folder ../data/luna16/
2. Create a folder ../data/luna16/subset2
    - Under this folder, copy one scan for testing (the script will process every scan found at this location):
      1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405.mhd and its matching .raw file
      (Google Drive: https://drive.google.com/drive/u/1/folders/13wmubTgm-7sh3MxPGxqmVZuoqi0G3ufW)
3. Create a folder ../data/luna16/hdf5
    - Under this folder, copy UNET_weights_H5.h5 (download it from Google Drive)

In [2]:
import pandas as pd
import numpy as np
import h5py
import pandas as pd
import argparse
import SimpleITK as sitk
from PIL import Image
import os, glob 
import os, os.path
import tensorflow as tf
import keras

from ipywidgets import interact
import json
import pickle
from datetime import datetime
from tqdm import tqdm, trange

from UNET_utils import *
%matplotlib inline

In [3]:
# import argparse
# parser = argparse.ArgumentParser(description='Prediction on HOLDOUT subset',add_help=True)
# parser.add_argument("--holdout", type=int, default=0, help="HOLDOUT subset for predictions")
# args = parser.parse_args()
# HOLDOUT = args.holdout

In [4]:
# Hold-out subset index; it is substituted into the names built below.
HOLDOUT = 5

# Root directory of the LUNA16 data; other paths are appended to it.
data_dir = '../data/luna16/'

# Directory name for this hold-out, e.g. 'HO5/'.
HO_dir = 'HO%s/' % HOLDOUT

# Weights file for this hold-out, relative to data_dir.
model_wghts = 'hdf5/UNET_weights_H%s.h5' % HOLDOUT

In [5]:
# Number of slices (third axis) handed to the model in one tile.
SLICES = 8

# Fixed shape every scan is zero-padded into before prediction.
PADDED_SIZE = (448, 448, 368)

# A tile spans the full first two axes of the padded volume and SLICES
# entries along the third.
TILE_SIZE = PADDED_SIZE[:2] + (SLICES,)

In [6]:
def model_create_loadWghts_Model_A(img_size=TILE_SIZE):
    """Build the Model-A 3D U-Net, load its saved weights and compile it.

    Args:
        img_size: Shape of one input tile without the trailing unit axis.
            Defaults to ``TILE_SIZE``.

    Returns:
        The model produced by ``create_unet3D_Model_A`` with the weights at
        ``data_dir + model_wghts`` loaded, compiled with the Adam optimizer,
        ``dice_coef_loss`` as loss and ``dice_coef`` as metric.
    """
    # A trailing axis of length 1 is appended to the tile shape
    # (presumably the channel axis -- confirm against UNET_utils).
    shape_with_unit_axis = tuple(img_size) + (1,)
    weights_path = data_dir + model_wghts

    unet = create_unet3D_Model_A(shape_with_unit_axis, use_upsampling=True)
    unet.load_weights(weights_path)
    unet.compile(optimizer='adam', loss=[dice_coef_loss], metrics=[dice_coef])
    return unet

In [7]:
def model_create_loadWghts(img_size=TILE_SIZE):
    """Build the two-output 3D U-Net, load its saved weights and compile it.

    The model is compiled with one loss and one metric per named output:
    ``PredictionMask`` (dice loss / dice coefficient, loss weight 0.8) and
    ``PredictionClass`` (binary cross-entropy / accuracy, loss weight 0.2).

    Args:
        img_size: Shape of one input tile without the trailing unit axis.
            Defaults to ``TILE_SIZE``.

    Returns:
        The compiled model produced by ``create_UNET3D`` with the weights at
        ``data_dir + model_wghts`` loaded.
    """
    # A trailing axis of length 1 is appended to the tile shape
    # (presumably the channel axis -- confirm against UNET_utils).
    shape_with_unit_axis = tuple(img_size) + (1,)
    weights_path = data_dir + model_wghts

    unet = create_UNET3D(shape_with_unit_axis, use_upsampling=True)
    unet.load_weights(weights_path)
    # To transfer weights only to matching layers, use this call instead:
    #     unet.load_weights(weights_path, by_name=True)

    per_output_loss = {
        'PredictionMask': dice_coef_loss,
        'PredictionClass': 'binary_crossentropy',
    }
    per_output_weight = {'PredictionMask': 0.8, 'PredictionClass': 0.2}
    per_output_metric = {
        'PredictionMask': dice_coef,
        'PredictionClass': 'accuracy',
    }
    unet.compile(optimizer='adam',
                 loss=per_output_loss,
                 loss_weights=per_output_weight,
                 metrics=per_output_metric)

    return unet

In [8]:
def find_mask(model, padded_img):
    """Predict a mask for a padded volume, one tile of slices at a time.

    The volume is cut along its third axis into ``PADDED_SIZE[2] // SLICES``
    consecutive groups of ``SLICES`` slices. Each group is passed to
    ``model.predict`` as a batch of one and the first entry of the result is
    written back at the same position in the output volume.

    Args:
        model: Object exposing ``predict(batch, verbose=...)``.
        padded_img: Array sliced as ``[:, :, a:b]``; assumed to have shape
            ``PADDED_SIZE`` -- confirm at the call site.

    Returns:
        Array of shape ``PADDED_SIZE`` holding the assembled predictions.
    """
    print ()
    n_tiles = PADDED_SIZE[2] // SLICES
    mask_volume = np.zeros(PADDED_SIZE)
    print ("Total tiles : {}".format(n_tiles))

    for tile_idx in tqdm(range(n_tiles), total=n_tiles, unit="tiles"):
        start = tile_idx * SLICES
        stop = start + SLICES

        # Wrap the tile with a leading and a trailing unit axis before
        # handing it to the model.
        chunk = padded_img[:, :, start:stop]
        batch = chunk.reshape((1,) + chunk.shape + (1,))
        outputs = model.predict(batch, verbose=2)

        # The first entry of the prediction is reshaped back to a bare tile.
        mask_volume[:, :, start:stop] = outputs[0].reshape(TILE_SIZE)
    return mask_volume

In [9]:
%%time
# Run Model A over every .mhd scan in subset2, keep the per-scan results in
# memory, then pickle them to the current working directory.
t0 = datetime.now()
predictions_dict = {}
size_dict = {}
seriesuid = None  # stays None when no scan is found; checked before saving
model = model_create_loadWghts_Model_A(TILE_SIZE)

# List the scans once so the progress-bar total always matches the files
# that are actually iterated.
scan_files = glob.glob(data_dir + 'subset2/' + '*.mhd')
fileCount = len(scan_files)

for f in tqdm(scan_files, total=fileCount, unit="files"):
    print ("\n Processing scan file: {}".format(os.path.basename(f)))
    seriesuid = os.path.splitext(os.path.basename(f))[0]
    # Step-1: load the scan and report its raw array shape.
    itk_img = sitk.ReadImage(f)
    img_np_array = sitk.GetArrayFromImage(itk_img)
    original_size = img_np_array.shape
    print ("Original-Size of loaded image : {}".format(original_size))
    # Step-2: normalize with the project helper and convert back to an array.
    itk_img_norm = normalize_img(itk_img)
    img_np_array_norm = sitk.GetArrayFromImage(itk_img_norm)
    normalized_size = img_np_array_norm.shape
    # Step-3: work on a copy with the first and third axes exchanged.
    img = img_np_array_norm.copy()
    # img = normalize_HU(img_np_array_norm)
    img = np.swapaxes(img, 0, 2)   ##needed as SITK swaps axis
    print ("Normalized input image size: {}".format(img.shape))
    # Step-4 / Step-5: zero-pad the scan into the fixed PADDED_SIZE volume.
    # Fail with a clear message instead of an opaque broadcast error when the
    # scan does not fit.
    if any(dim > limit for dim, limit in zip(img.shape, PADDED_SIZE)):
        raise ValueError(
            "Scan {} has shape {} which exceeds PADDED_SIZE {}".format(
                seriesuid, img.shape, PADDED_SIZE))
    padded_img = np.zeros(PADDED_SIZE)
    padded_img[:img.shape[0], :img.shape[1], :img.shape[2]] = img
    print ("Padded-image size: {}".format(padded_img.shape))

    predicted_mask = find_mask(model, padded_img)
    predictions_dict[seriesuid] = (img.shape, padded_img, predicted_mask)
    size_dict[seriesuid] = img.shape
    # Reported per scan; previously only the last scan's sum was printed.
    print('Predicted Mask sum for entire scan: {}'.format(np.sum(predicted_mask)))

if seriesuid is None:
    print('No .mhd scans found under {}'.format(data_dir + 'subset2/'))
else:
    # NOTE: the output files are named after the LAST processed scan even
    # though the dictionaries hold every scan; kept for compatibility with
    # existing consumers of these files.
    # Context managers make sure both files are flushed and closed.
    with open('Model_A_noHU_entire_predictions_{}.dat'.format(seriesuid), 'wb') as predictions_file:
        pickle.dump(predictions_dict, predictions_file)
    with open('Model_A_noHU_entire_size_{}.dat'.format(seriesuid), 'wb') as size_file:
        pickle.dump(size_dict, size_file)
print('Processing runtime: {}'.format(datetime.now() - t0))

  0%|          | 0/1 [00:00<?, ?files/s]


 Processing scan file: 1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405.mhd
Original-Size of loaded image : (321, 512, 512)
Normalized input image size: (285, 285, 321)



  0%|          | 0/46 [00:00<?, ?tiles/s][A

Padded-image size: (448, 448, 368)

Total tiles : 46



  2%|▏         | 1/46 [02:08<1:36:02, 128.06s/tiles][A
  4%|▍         | 2/46 [04:15<1:33:46, 127.88s/tiles][A
  7%|▋         | 3/46 [06:20<1:30:48, 126.70s/tiles][A
  9%|▊         | 4/46 [08:21<1:27:49, 125.47s/tiles][A
 11%|█         | 5/46 [10:23<1:25:12, 124.69s/tiles][A
 13%|█▎        | 6/46 [12:28<1:23:11, 124.79s/tiles][A
 15%|█▌        | 7/46 [14:52<1:22:50, 127.46s/tiles][A
 17%|█▋        | 8/46 [17:09<1:21:30, 128.71s/tiles][A
 20%|█▉        | 9/46 [19:35<1:20:34, 130.66s/tiles][A
 22%|██▏       | 10/46 [21:39<1:17:59, 129.99s/tiles][A
 24%|██▍       | 11/46 [23:40<1:15:19, 129.12s/tiles][A
 26%|██▌       | 12/46 [25:47<1:13:04, 128.94s/tiles][A
 28%|██▊       | 13/46 [27:53<1:10:48, 128.73s/tiles][A
 30%|███       | 14/46 [30:01<1:08:37, 128.67s/tiles][A
 33%|███▎      | 15/46 [32:03<1:06:15, 128.24s/tiles][A
 35%|███▍      | 16/46 [34:04<1:03:53, 127.79s/tiles][A
 37%|███▋      | 17/46 [36:30<1:02:17, 128.88s/tiles][A
 39%|███▉      | 18/46 [38:54<1:00:30, 

Predicted Mask sum for entire scan: 0.0
Processing runtime: 1:37:54.681545
CPU times: user 3h 11min 11s, sys: 1h 16min 35s, total: 4h 27min 46s
Wall time: 1h 37min 54s


In [12]:
def displaySlice(sliceNo):
    """Show one slice of the padded input next to its predicted mask.

    Reads the notebook globals ``padded_img`` and ``predicted_mask`` (set by
    the prediction cell) and indexes both along the third axis.

    Args:
        sliceNo: Zero-based index along the third axis of both volumes.
    """
    # NOTE(review): ``plt`` is not imported in the visible cells; presumably
    # it is provided by ``from UNET_utils import *`` -- confirm.
    plt.figure(figsize=[20, 20])

    plt.subplot(121)
    plt.title("True Image")
    plt.imshow(padded_img[:, :, sliceNo], cmap='bone')

    plt.subplot(122)
    plt.title("Predicted Mask")
    plt.imshow(predicted_mask[:, :, sliceNo], cmap='bone')
    plt.show()


# The slider covers exactly the slices of the last processed scan, i.e.
# indices 0 .. img.shape[2] - 1. The previous range (1, img.shape[2]) made
# slice 0 unreachable and ended one index past the scan.
interact(displaySlice, sliceNo=(0, img.shape[2] - 1, 1));

###### Following sections for reference & WIP code snippets -AL

In [None]:
## Multiple-tile test: splitting each slice into several tiles was a performance hog, so the GPU is
## instead used on entire slices, which gives better performance without compromising predictions. -AL

# slices = 16
# predicted_img = np.zeros(padded_size)

# for i in range(368//slices):
#     tile_1 = padded_img[:224, :224, (i*slices) : slices*(i+1)]
#     tile_2 = padded_img[224:, 224:, (i*slices) : slices*(i+1) ] 

In [None]:
# slices = 8
# predicted_mask = np.zeros(PADDED_SIZE)

# for i in range(24//SLICES):
#     tile = padded_img[:, :, (i*SLICES) : SLICES*(i+1)]
#     tile = tile.reshape(tuple([1] + list (tile.shape) + [1]))
# #     print(tile.shape)

#     tile_predictions = model.predict(tile, verbose=2)
#     tile_mask = tile_predictions[0].reshape(448, 448, 8)
    
#     print (tile_mask.shape)
#     predicted_mask[:, :, (i*SLICES) : SLICES*(i+1)] = tile_mask


In [None]:
# slices = 8
# test_slice = padded_img[:, :, :slices]
# print(test_slice.shape)
# model = model_create_loadWghts(test_slice.shape) 
# # slice_predictions = model.predict(test_slice, verbose=2)

In [None]:
# print ("Shape of predicted mask or segmented image : {}".format(predictions_small_img[0].shape))
# print ("Shape of predicted class : {}".format(predictions_small_img[1].shape))
# predictions_small_img[0] [:, 25 : 26, :]

In [None]:
# ## AL - TEST : making an image of size 48,48,48 with random 0 or 1
# ### Case 2 : As a test created an input image of size (1, 48,48,48,1) 
# # with random 0 or 1; this works fine and able to create predictions successfully
# t2 =  np.random.choice(2,(48,48,48))
# t2 = t2.reshape(tuple([1] + list (t2.shape) + [1]))

# print ("Shape of test input image : {}".format(t2.shape))
# predictions = model.predict(t2, verbose=2)

# print ("Shape of predicted mask or segmented image : {}".format(predictions[0].shape))
# print ("Shape of predicted class : {}".format(predictions[1].shape))
# # predictions[0] [:, 25 : 26, :]

In [None]:
# padded_img[225:232, 225:232, 175]
# predicted_mask[225:232, 225:232, 175]