In [None]:
import utils
import metrics
import CNNLSTMModel
import ConvLSTM2DModel

import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Seed for the NumPy RNG, set before the train/val/test split.
LUCKY_NUMBER = 2
# (height, width) passed to the model as `image_shape`.
# NOTE(review): the original note says "-1 means no compression" — presumably the
# loaders then keep the native resolution; TODO confirm in utils.
TARGET_SIZE = (32, 32)
# Number of slices per scan, passed to the model as `num_slices`.
TARGET_SLICES = 304

# Data directories. Defaults are the original machine-specific locations; set the
# environment variables to run elsewhere without editing the notebook.
# os.path.join(path, "") guarantees a trailing separator, because file paths are
# also built by appending file names directly to these strings.
PHOTOS_PATH = os.path.join(
    os.environ.get("CBCT_PHOTOS_PATH", "/run/media/student/DataStorage/images/"), "")
MASK_PATH = os.path.join(
    os.environ.get("CBCT_MASK_PATH", "/run/media/student/DataStorage/masks/"), "")

In [None]:
# Report visible GPUs and switch each one to on-demand memory allocation.
# Memory growth has to be configured the same way on every GPU, and before the GPUs
# are initialized; otherwise TensorFlow raises RuntimeError (reported, not re-raised).
physical_gpus = tf.config.list_physical_devices('GPU')
print("GPUs Available: ", physical_gpus)

if physical_gpus:
    try:
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
        n_logical = len(tf.config.list_logical_devices('GPU'))
        print(len(physical_gpus), "Physical GPUs,", n_logical, "Logical GPUs")
    except RuntimeError as err:
        # Raised if the GPUs were already initialized before this cell ran.
        print(err)

In [None]:
from tensorflow.keras import mixed_precision

# Use the Keras "mixed_float16" global policy. This must run before the model is
# created below, since layers read the global policy when they are constructed.
# NOTE(review): Keras recommends keeping the final output/activation in float32 for
# numerical stability under this policy — TODO confirm CNNLSTMModel does so.
mixed_precision.set_global_policy("mixed_float16")


Prepare data and build the model

In [None]:
# Seed NumPy and TensorFlow so the split and model initialization are reproducible.
np.random.seed(LUCKY_NUMBER)
tf.random.set_seed(LUCKY_NUMBER)
batch_size = 2
epochs = 10

# One entry per scan ID (first 8 characters of each image file name; the prediction
# cell opens "<ID>_0000.nii.gz"). os.listdir order is filesystem-dependent, so sort to
# keep the seeded split identical across machines; the set removes repeated IDs so
# one scan can never end up in two splits.
scan_names = sorted({file[:8] for file in os.listdir(PHOTOS_PATH) if file.endswith(".nii.gz")})
assert scan_names, f"No .nii.gz scans found in {PHOTOS_PATH}"

train, val, test = utils.split_train_val_test(scan_names, 0.7, 0.15, 0.15)
print(f"Training data size: {len(train)}, Validation data size: {len(val)}, Test data size: {len(test)}")

print(train)
train_gen = utils.cbct_data_generator(PHOTOS_PATH, MASK_PATH, train)
val_gen = utils.cbct_data_generator(PHOTOS_PATH, MASK_PATH, val)


model = CNNLSTMModel.create_cnn_lstm_model(image_shape=TARGET_SIZE, num_slices=TARGET_SLICES)
# Alternative architecture:
# model = ConvLSTM2DModel.create_cnn_convlstm2d_model(image_shape=TARGET_SIZE, num_slices=TARGET_SLICES)
model.summary()

print(model.output_shape)

Train model

In [None]:
# Debug: display the input shape of the layer at index 9.
# NOTE(review): the index is hard-coded and tied to the architecture built above;
# it fails with IndexError if the model has fewer than 10 layers.
inspected_layer = model.layers[9]
inspected_layer.input.shape

In [None]:
# Train from the generator. `batch_size` is intentionally not passed to fit(): Keras
# documents that it must not be specified for generator inputs (the generator yields
# the batches), and the old hard-coded 2 duplicated the `batch_size` variable.
# NOTE(review): cbct_data_generator is not given a batch size; if it yields one scan
# per step (as validation_steps=len(val) below suggests), each epoch covers only
# len(train) // batch_size scans — TODO confirm the generator's batching.
steps_per_epoch = max(1, len(train) // batch_size)  # 0 steps would train nothing

history = model.fit(
    train_gen,
    steps_per_epoch=steps_per_epoch,
    # TODO: re-enable validation once val_gen is confirmed to work.
    # validation_data=val_gen,
    # validation_steps=len(val),
    epochs=epochs,
)

In [None]:
# Persist the trained model through the project helper; where and in which format it
# is written is defined in utils.save_model (not visible in this notebook).
utils.save_model(model)

Make a prediction on a single (training) scan

In [None]:
# Sanity-check the model on a single scan.
# NOTE(review): train[1] is a *training* scan, so this shows fit, not generalization;
# the held-out `test` split is not evaluated anywhere yet. The plotting cell loads the
# ground-truth mask for train[1] too, so change both cells together.
scan_path = os.path.join(PHOTOS_PATH, train[1] + "_0000.nii.gz")
test_scan = utils.load_nifti_cbct_scan(scan_path)
# Add a leading batch axis of size 1 and a trailing channel axis.
test_scan = test_scan[np.newaxis, ..., np.newaxis]

predictions = model.predict(test_scan)

predicted_mask = predictions[0]  # drop the batch axis

binary_mask = (predicted_mask > 0.5).astype(np.float32)

# Debug check: identical outputs on distant slices hint that the model predicts a
# constant. Print a single boolean rather than an element-wise comparison grid.
print("Slices 200 and 300 identical:",
      np.array_equal(predicted_mask[200, :, :, 0], predicted_mask[300, :, :, 0]))

plt.imshow(predicted_mask[120, :, :, 0], cmap="gray")
plt.title("Predicted mask, slice 120")
plt.axis("off")
plt.show()

Plot results for one slice

In [None]:
slice_index = 100

# Ground-truth mask for the scan used in the prediction cell (train[1]).
# NOTE(review): if the mask is not resized like the scan, its resolution may differ
# from the predicted mask — TODO confirm in utils.load_nifti_mask.
mask_path = os.path.join(MASK_PATH, train[1] + ".nii.gz")
original_mask = utils.load_nifti_mask(mask_path)

assert 0 <= slice_index < predicted_mask.shape[0], (
    f"slice_index {slice_index} out of range for {predicted_mask.shape[0]} slices")

# (title, 2-D image) for each panel, left to right.
panels = [
    ("Original Slice", test_scan[0, slice_index, :, :, 0]),
    ("Predicted Mask", predicted_mask[slice_index, :, :, 0]),
    ("Binary Mask", binary_mask[slice_index, :, :, 0]),
    ("Original Mask", original_mask[slice_index, :, :]),
]

fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()