In [None]:
import numpy as np
import os
import random
import matplotlib.pyplot as plt
import segmentation_models_3D as sm
import tensorflow as tf
from keras.callbacks import ModelCheckpoint
import timeit
import tensorflow.keras.backend as K
from keras.models import load_model


In [None]:
# Report how many physical GPU devices TensorFlow can see.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# tf.config.experimental.set_memory_growth is the TF2-native API for this. It must
# run before the GPUs are initialized, so re-running this cell after the runtime is
# up raises RuntimeError; that is reported instead of aborting the notebook.
for gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(f"Could not enable memory growth on {gpu.name}: {err}")

# Legacy TF1-style session with the same allow_growth setting, kept for backward
# compatibility in case other code references `config` / `sess`.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)

In [None]:
from ipynb.fs.full.custom_datagen import imageLoader, predict_image, predict_batch, plot_predict

In [None]:
# Data locations and the sorted file listings passed to imageLoader.
# Image and mask listings are handed to imageLoader as parallel lists
# (presumably paired by position — TODO confirm in custom_datagen), so both
# are sorted and checked to have the same length.

train_img_dir = "./data2/train/images/"
train_mask_dir = "./data2/train/masks/"

val_img_dir = "./data2/val/images/"
val_mask_dir = "./data2/val/masks/"

train_img_list = sorted(os.listdir(train_img_dir))
train_mask_list = sorted(os.listdir(train_mask_dir))

val_img_list = sorted(os.listdir(val_img_dir))
val_mask_list = sorted(os.listdir(val_mask_dir))

assert len(train_img_list) == len(train_mask_list), (
    f"train: {len(train_img_list)} images vs {len(train_mask_list)} masks")
assert len(val_img_list) == len(val_mask_list), (
    f"val: {len(val_img_list)} images vs {len(val_mask_list)} masks")

In [None]:
# Legend for imageLoader's last two arguments (mapping as originally documented):
#   Modality:  0 - Flair, 1 - T1ce, 2 - T2, 3 - T1, 'All' - all four channels
#   Mask-Type: 0 - Whole Tumor, 1 - Tumor Core, 2 - Enhancing Tumor
# NOTE(review): the generator-check plot titles channel 1 as "t2" and channel 2
# as "t1ce", which contradicts this legend — confirm the channel order in imageLoader.

batch_size = 2
MODALITY = 'All'   # use every input channel
MASK_TYPE = 0      # 0 - Whole Tumor

train_img_datagen = imageLoader(train_img_dir, train_img_list,
                                train_mask_dir, train_mask_list,
                                batch_size, MODALITY, MASK_TYPE)

val_img_datagen = imageLoader(val_img_dir, val_img_list,
                              val_mask_dir, val_mask_list,
                              batch_size, MODALITY, MASK_TYPE)

In [None]:
# Sanity-check the training generator: pull one batch and show each input
# channel plus the mask of one random sample at one random depth slice.
img, msk = next(train_img_datagen)

img_num = random.randint(0, img.shape[0] - 1)
test_img = img[img_num]
test_mask = msk[img_num]

# random.randint includes both endpoints, so the upper bound is depth - 1
# to keep the slice index in range.
n_slice = random.randint(0, test_mask.shape[2] - 1)

# NOTE(review): the modality legend lists 1 - T1ce and 2 - T2; these titles say
# the reverse — confirm the channel order produced by imageLoader.
channel_titles = ['Image flair', 'Image t2', 'Image t1ce', 'Image t1']

plt.figure(figsize=(12, 8))
for channel, title in enumerate(channel_titles):
    plt.subplot(1, 5, channel + 1)
    plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
    plt.title(title)
plt.subplot(1, 5, 5)
plt.imshow(test_mask[:, :, n_slice])
plt.title('whole tumor')
plt.show()

In [None]:
from ipynb.fs.full.Anam_net import Anam_net
from ipynb.fs.full.UNetpp import unet_pp
from ipynb.fs.full.UNet import unet
from ipynb.fs.full.RescueNet import Rescue_Net

# Build the model once to check that it constructs, and print its summary and
# input/output shapes (128x128x128 volumes, 4 channels, 1 output class).
# NOTE(review): this builds Rescue_Net, but the training cell checkpoints to
# 'Anam_Net_All_whole_tumor.hdf5' — confirm which architecture is intended.
model_kwargs = dict(
    IMG_HEIGHT=128,
    IMG_WIDTH=128,
    IMG_DEPTH=128,
    IMG_CHANNELS=4,
    num_classes=1,
)
model = Rescue_Net(**model_kwargs)

model.summary()
for shape in (model.input_shape, model.output_shape):
    print(shape)

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient between a ground-truth mask and a prediction.

    Both tensors are flattened across all axes (batch included), so a single
    score is computed for the whole batch rather than averaged per sample.

    Args:
        y_true: ground-truth mask tensor; cast to ``y_pred``'s dtype.
        y_pred: predicted tensor with the same number of elements as ``y_true``.
        smooth: added to numerator and denominator; keeps the ratio defined
            (equal to 1) when both masks are empty.

    Returns:
        Scalar tensor ``(2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    y_pred_f = K.flatten(y_pred)
    # An integer mask (e.g. uint8) cannot be multiplied with a float prediction,
    # so align the dtypes before the elementwise product.
    y_true_f = K.cast(K.flatten(y_true), y_pred_f.dtype)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """Dice loss, ``1 - dice_coef``: 0 for perfect overlap."""
    return 1. - dice_coef(y_true, y_pred)

In [None]:
# Optimizer: Adam with a fixed learning rate.
LR = 0.0001
optim = tf.keras.optimizers.Adam(learning_rate=LR)

# Metrics reported during training/validation: soft Dice, Keras precision and
# recall, and segmentation_models_3D IOUScore with threshold=0.5.
metrics = [
    dice_coef,
    'Precision',
    'Recall',
    sm.metrics.IOUScore(threshold=0.5),
]

In [None]:
# Train the model and checkpoint the best epoch.
# Integer division drops a trailing partial batch; max(1, ...) keeps Keras from
# receiving 0 steps when a split holds fewer files than one batch.
steps_per_epoch = max(1, len(train_img_list) // batch_size)
val_steps_per_epoch = max(1, len(val_img_list) // batch_size)

EPOCHS = 1

model.compile(optimizer=optim, loss=dice_coef_loss, metrics=metrics)

# Select the saved model on validation loss; monitoring training loss would
# ignore the validation set evaluated every epoch and favour overfitting.
# NOTE(review): the filename says Anam_Net but `model` is built with Rescue_Net;
# the name is kept unchanged because the inference cells load this exact path.
checkpoint = ModelCheckpoint('Anam_Net_All_whole_tumor.hdf5', monitor='val_loss',
                             save_best_only=True, mode='auto', save_freq="epoch")

start = timeit.default_timer()

history = model.fit(train_img_datagen,
                    steps_per_epoch=steps_per_epoch,
                    epochs=EPOCHS,
                    verbose=1,
                    validation_data=val_img_datagen,
                    validation_steps=val_steps_per_epoch,
                    callbacks=[checkpoint])

stop = timeit.default_timer()

print('Time: ', stop - start)

In [None]:
# Reload the saved checkpoint for inference. compile=False skips restoring the
# training setup (presumably because the custom Dice loss/metric would otherwise
# need custom_objects).
my_model = load_model('./Anam_Net_All_whole_tumor.hdf5', compile=False)

# Time batch prediction over the validation split.
# NOTE(review): the elapsed time is divided by a hard-coded 8 — presumably the
# number of volumes predicted; confirm against predict_batch / the val split size.
start = timeit.default_timer()
predict_batch(my_model, val_img_dir, val_img_list, val_mask_dir, val_mask_list,
              'All', 0, batch_size)
stop = timeit.default_timer()
elapsed = stop - start
print('Time: ', elapsed / 8)

In [None]:
# Reload the checkpoint so this cell runs on its own, then predict a single
# example (img_num=0) with all modalities and mask type 0.
my_model = load_model('./Anam_Net_All_whole_tumor.hdf5', compile=False)

prediction_result = predict_image(my_model, 'All', 0, img_num=0)
test_img, test_mask, test_prediction = prediction_result

In [None]:
# Plot individual slices from test predictions for verification:
# one fixed depth slice of the image, ground-truth mask and prediction returned
# by predict_image. Slice 60 lies inside the IMG_DEPTH=128 the model is built
# for; change it to inspect a different slice.
n_slice = 60
plot_predict(test_img, test_mask, test_prediction, n_slice)