In [None]:
import utils
import metrics
import CNNLSTMModel

import os
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

# Spatial size handed to CNNLSTMModel.create_cnn_lstm_model(target_size=...) in the
# data-preparation cell. For no compression choose -1.
# NOTE(review): confirm whether the model factory expects a bare -1 or (-1, -1) for
# "no compression"; the data generators and the prediction cell are not given this
# value, so nothing in this notebook ties the input shapes to it.
TARGET_SIZE = (128, 128)

# NOTE(review): not referenced anywhere else in this notebook. The prediction and
# plotting cells hard-code scan "2", which matches this value — presumably it was
# meant to be used there (or as a seed); confirm, then use it or remove it.
LUCKY_NUMBER = 2

Prepare data

In [None]:
# Input directories. The trailing separator is kept because later cells build file
# paths by plain string concatenation (photos_path + "2.zip", masks_path + "2.nii.gz").
photos_path = "E:\\images\\"
masks_path = "E:\\masks\\"

# Scan IDs: the first 8 characters of every ".nii.gz" file in photos_path.
# NOTE(review): this assumes every ID is exactly 8 characters long — a shorter stem
# keeps part of the extension and a longer one is truncated; confirm the naming scheme.
# os.listdir returns entries in arbitrary order, so the IDs are sorted to make the
# input to the split independent of the filesystem / machine.
scan_names = sorted(file[:8] for file in os.listdir(photos_path) if file.endswith(".nii.gz"))
assert scan_names, f"No .nii.gz scans found in {photos_path!r}"
# Truncating names could map two files to the same ID, letting one scan end up in
# more than one split; fail loudly instead of training on leaked data.
assert len(set(scan_names)) == len(scan_names), "Duplicate scan IDs after truncating file names to 8 characters"

# Presumably train/val/test fractions of 0.7 / 0.15 / 0.15 — see utils.split_train_val_test.
train, val, test = utils.split_train_val_test(scan_names, 0.7, 0.15, 0.15)
print(f"Training data size: {len(train)}, Validation data size: {len(val)}, Test data size: {len(test)}")

# Generators over the training and validation IDs; consumed by model.fit in the next cell.
train_gen = utils.cbct_data_generator(photos_path, masks_path, train)
val_gen = utils.cbct_data_generator(photos_path, masks_path, val)

epochs = 1
model = CNNLSTMModel.create_cnn_lstm_model(target_size=TARGET_SIZE)
model.summary()

Train model

In [None]:
# Train on the generators from the data-preparation cell; one step per scan ID.
# NOTE(review): steps_per_epoch / validation_steps equal the number of scans in each
# split, which assumes cbct_data_generator yields exactly one scan per step — confirm.
# The returned History is kept so per-epoch loss/metric curves stay available via
# history.history (and Jupyter no longer echoes the History object's repr as output).
history = model.fit(
    train_gen,
    steps_per_epoch=len(train),
    validation_data=val_gen,
    validation_steps=len(val),
    epochs=epochs,
)

Make prediction

In [None]:
# Run the trained model on one scan and threshold its output into a 0/1 mask.
# NOTE(review): scan "2" is hard-coded rather than taken from the `test` split built
# in the data-preparation cell, so it may have been seen during training — confirm
# before reading this as held-out performance.
# NOTE(review): the volume is loaded via load_dicom from a .zip (training IDs come from
# .nii.gz files) and is not resized to TARGET_SIZE here — confirm its shape matches
# what the model was built for.
raw_volume = utils.load_dicom(photos_path + "2.zip")
# Add a leading batch axis and a trailing channel axis in a single indexing step.
test_scan = raw_volume[np.newaxis, ..., np.newaxis]

predictions = model.predict(test_scan)

# Prediction for the single scan in the batch.
predicted_mask = predictions[0]
binary_mask = np.greater(predicted_mask, 0.5).astype(np.uint8)

Plot results

In [None]:
# Index into the first axis of each volume (after the batch axis for test_scan).
slice_index = 25

# Ground-truth mask for the same scan ("2") used in the prediction cell.
# NOTE(review): this comes from load_nifti while the scan comes from load_dicom —
# confirm both loaders return the same axis order, otherwise slice_index selects
# different planes in the scan and in the mask.
original_mask = utils.load_nifti(masks_path + "2.nii.gz")

# Fail with a readable message instead of a bare IndexError (negative indices allowed).
n_slices = min(test_scan.shape[1], predicted_mask.shape[0], original_mask.shape[0])
assert -n_slices <= slice_index < n_slices, (
    f"slice_index={slice_index} is out of range for volumes with {n_slices} slices"
)

# (title, 2-D image) for each panel, left to right.
panels = [
    ("Original Slice", test_scan[0, slice_index, :, :, 0]),
    ("Predicted Mask", predicted_mask[slice_index, :, :, 0]),
    ("Binary Mask", binary_mask[slice_index, :, :, 0]),
    ("Original Mask", original_mask[slice_index, :, :]),
]

# Wide figure so four side-by-side panels are not squeezed.
fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()