In [None]:
import utils
import metrics
import CNNLSTMModel

import os
import matplotlib.pyplot as plt
import numpy as np

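## Prepare data

List the CBCT scan archives (`.zip`) and their segmentation masks (`.nii.gz`), build the training generator, and create the CNN-LSTM model.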

In [None]:
# Locate the training data and build the model.
#
# `os.listdir` returns entries in an arbitrary, filesystem-dependent order.
# Scans and masks are handed to the generator as two separate lists, so they
# are sorted by their shared base name (e.g. "2.zip" <-> "2.nii.gz", the same
# convention the prediction and plotting cells use). The assertion stops the
# run if any scan lacks a mask or vice versa, instead of silently training on
# mismatched pairs.
# NOTE(review): assumes `utils.cbct_data_generator` pairs photos and masks by
# list position; confirm against utils.
photos_path = "./photos/"
masks_path = "./masks/"

PHOTO_EXT = ".zip"
MASK_EXT = ".nii.gz"

photo_ids = sorted(f[: -len(PHOTO_EXT)] for f in os.listdir(photos_path) if f.endswith(PHOTO_EXT))
mask_ids = sorted(f[: -len(MASK_EXT)] for f in os.listdir(masks_path) if f.endswith(MASK_EXT))
assert photo_ids == mask_ids, (
    "scans and masks do not pair up by name: "
    f"scans without mask {sorted(set(photo_ids) - set(mask_ids))}, "
    f"masks without scan {sorted(set(mask_ids) - set(photo_ids))}"
)

photos_names = [scan_id + PHOTO_EXT for scan_id in photo_ids]
masks_names = [scan_id + MASK_EXT for scan_id in mask_ids]

train_gen = utils.cbct_data_generator(photos_path, masks_path, photos_names, masks_names)

epochs = 6  # number of training epochs passed to model.fit below
model = CNNLSTMModel.create_cnn_lstm_model()
model.summary()

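## Train the model

Fit the model on the generator for `epochs` epochs, requesting one batch per scan archive in each epoch.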

In [None]:
# Train on the streaming generator. `steps_per_epoch` is the number of scan
# archives found in the data cell, so each epoch requests one batch per scan.
# NOTE(review): that is exactly one pass over the data only if
# cbct_data_generator yields one scan per batch; confirm the batch size in utils.
model.fit(
    train_gen,
    epochs=epochs,
    steps_per_epoch=len(photos_names),
)

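## Predict on a single scan

Run the model on scan `2` and threshold the predicted values at 0.5 to obtain a binary mask.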

In [None]:
# Run the trained model on one scan and binarise its output.
# NOTE(review): "2.zip" sits in photos_path, so the data cell includes it in
# photos_names. The model has therefore seen this scan during training, and
# this is not a held-out evaluation.
raw_scan = utils.load_dicom(f"{photos_path}2.zip")

# Add a leading batch axis and a trailing channel axis in one step.
# The plotting cell indexes the result as [batch, slice, :, :, channel].
test_scan = raw_scan[np.newaxis, ..., np.newaxis]

predictions = model.predict(test_scan)

# Drop the batch axis, then threshold at 0.5 into a {0, 1} uint8 mask.
predicted_mask = predictions[0]
binary_mask = np.asarray(predicted_mask > 0.5).astype(np.uint8)

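## Plot results

Compare one slice of the input scan, the raw prediction, the thresholded mask, and the ground-truth mask side by side.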

In [None]:
# Visualise one slice of the scan next to the model output and the
# ground-truth mask for the same case ("2", matching the prediction cell).
slice_index = 25  # slice to visualise; must be valid for every volume below

original_mask = utils.load_nifti(masks_path + "2.nii.gz")

# Fail clearly on a bad index. Without this check, a negative index would
# silently wrap around, and a too-large one would raise halfway through
# drawing the figure.
n_scan_slices = test_scan.shape[1]
n_mask_slices = original_mask.shape[0]
assert 0 <= slice_index < min(n_scan_slices, n_mask_slices), (
    f"slice_index={slice_index} out of range "
    f"(scan has {n_scan_slices} slices, mask has {n_mask_slices})"
)

panels = [
    ("Original Slice", test_scan[0, slice_index, :, :, 0]),
    ("Predicted Mask", predicted_mask[slice_index, :, :, 0]),
    ("Binary Mask", binary_mask[slice_index, :, :, 0]),
    # NOTE(review): assumes utils.load_nifti returns the volume slice-first,
    # oriented like the scan. NIfTI data is often stored slice-last; confirm.
    ("Original Mask", original_mask[slice_index, :, :]),
]

fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
fig.tight_layout()
plt.show()