In [None]:
# Import libraries
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.metrics import MeanIoU

In [None]:
# Read the per-split Stenosis annotation tables (train / validation / test).
stenosis_train, stenosis_val, stenosis_test = (
    pd.read_parquet("../datasets/stenosis_train.parquet"),
    pd.read_parquet("../datasets/stenosis_val.parquet"),
    pd.read_parquet("../datasets/stenosis_test.parquet"),
)

# Peek at the first rows of the training annotations
stenosis_train.head()

In [None]:
# Helper function to load tensors from .tf files
def load_tensor_from_file(file_path, out_type=tf.float16):
    """Read a serialized tensor from disk and deserialize it.

    Args:
        file_path: Path to a file whose contents are a serialized
            ``TensorProto`` (the format consumed by ``tf.io.parse_tensor``).
        out_type: dtype the tensor was serialized with. ``parse_tensor``
            requires this to match the stored dtype. Defaults to
            ``tf.float16``, which is what this notebook's files use.

    Returns:
        The deserialized ``tf.Tensor``.
    """
    serialized_tensor = tf.io.read_file(file_path)
    return tf.io.parse_tensor(serialized_tensor, out_type=out_type)

In [None]:
# Load the pre-serialized image (X) and mask (y) tensors for every split.
X_train, y_train = (
    load_tensor_from_file("../datasets/X_train_stenosis.tf"),
    load_tensor_from_file("../datasets/y_train_stenosis.tf"),
)
X_val, y_val = (
    load_tensor_from_file("../datasets/X_val_stenosis.tf"),
    load_tensor_from_file("../datasets/y_val_stenosis.tf"),
)
X_test, y_test = (
    load_tensor_from_file("../datasets/X_test_stenosis.tf"),
    load_tensor_from_file("../datasets/y_test_stenosis.tf"),
)

In [None]:
# Dice loss and combined Dice + binary cross-entropy loss
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss pooled over the whole batch.

    Both tensors are flattened, so a single Dice coefficient is computed over
    every element of every sample in the batch (not averaged per sample).

    Args:
        y_true: Ground-truth mask; presumably binary {0, 1} values (the
            notebook binarizes masks before training). Cast to ``y_pred``'s
            dtype before use.
        y_pred: Predicted probabilities with the same number of elements as
            ``y_true``.
        smooth: Constant added to numerator and denominator so the loss is
            defined (and equals 0) when both masks are empty.

    Returns:
        Scalar tensor ``1 - Dice`` in ``y_pred``'s dtype.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes so e.g. float16 masks can be combined with float32
    # predictions when the loss is called outside of Keras' own casting.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    denominator = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)
    return 1 - (2. * intersection + smooth) / (denominator + smooth)

def dice_bce_loss(y_true, y_pred):
    """Sum of binary cross-entropy and batch-pooled Dice loss.

    ``binary_crossentropy`` averages only over the last axis, so ``bce`` keeps
    the leading dimensions. The scalar Dice term is broadcast onto it, and
    Keras reduces the result to a single loss value during training.

    Args:
        y_true: Ground-truth mask (cast to ``y_pred``'s dtype).
        y_pred: Predicted probabilities.

    Returns:
        Tensor shaped like ``y_pred`` minus its last axis.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Cast once here so both terms see matching dtypes even when this loss
    # is evaluated directly (e.g. float16 masks vs float32 predictions).
    y_true = tf.cast(y_true, y_pred.dtype)
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return bce + dice

In [None]:
# Build Model Architecture
def conv_block(input_tensor, num_filters):
    """Two stacked 3x3 'same'-padded ReLU convolutions, each followed by batch norm.

    The spatial size is preserved; the channel count becomes ``num_filters``.
    """
    features = input_tensor
    for _ in range(2):
        features = layers.Conv2D(num_filters, (3, 3), activation='relu', padding='same')(features)
        features = layers.BatchNormalization()(features)
    return features

def encoder_block(input_tensor, num_filters):
    """Encoder stage: convolve, then halve the spatial resolution.

    Returns:
        ``(skip, pooled)``: the full-resolution feature map (later used as a
        U-Net skip connection) and its 2x2 max-pooled version passed deeper.
    """
    skip = conv_block(input_tensor, num_filters)
    pooled = layers.MaxPooling2D((2, 2))(skip)
    return skip, pooled

def decoder_block(input_tensor, concat_tensor, num_filters):
    """Decoder stage: upsample 2x, merge the encoder skip, then convolve.

    ``concat_tensor`` must match the upsampled ``input_tensor`` spatially,
    i.e. have twice its height and width.
    """
    upsampled = layers.Conv2DTranspose(
        num_filters, (2, 2), strides=(2, 2), padding='same'
    )(input_tensor)
    merged = layers.concatenate([upsampled, concat_tensor], axis=-1)
    return conv_block(merged, num_filters)

def build_unet(input_shape=(256, 256, 1), encoder_filters=(128, 64, 32), bridge_filters=32):
    """Build a U-Net for single-channel binary segmentation.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        encoder_filters: Filter count of each encoder stage, outermost first.
            The decoder mirrors these in reverse order. The default
            reproduces the original 3-level network exactly.
        bridge_filters: Filter count of the bottleneck conv block.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to a 1-channel
        sigmoid probability map of the same spatial size.

    Raises:
        ValueError: If a known spatial dimension is not divisible by
            ``2 ** len(encoder_filters)``. Each encoder stage halves the
            resolution, so the skip connections would not line up.
    """
    # NOTE(review): filters shrink with depth (128 -> 64 -> 32), the reverse of
    # the usual U-Net doubling. Kept as the default to preserve this model.
    downsample = 2 ** len(encoder_filters)
    for dim in input_shape[:2]:
        if dim is not None and dim % downsample != 0:
            raise ValueError(
                f"Spatial dims {input_shape[:2]} must be divisible by {downsample} "
                f"for a U-Net with {len(encoder_filters)} encoder stages."
            )

    inputs = layers.Input(shape=input_shape)

    # Encoder: keep each stage's pre-pooling features for the skip connections.
    skips = []
    x = inputs
    for num_filters in encoder_filters:
        skip, x = encoder_block(x, num_filters)
        skips.append(skip)

    # Bridge
    x = conv_block(x, bridge_filters)

    # Decoder: walk back up, pairing each level with its encoder skip.
    for num_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = decoder_block(x, skip, num_filters)

    # Output: per-pixel foreground probability
    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
    return models.Model(inputs=[inputs], outputs=[outputs])

# Build U-Net model
SEED = 42
tf.keras.backend.clear_session()
# Seed Python, NumPy and TensorFlow RNGs so weight initialisation is reproducible.
tf.keras.utils.set_random_seed(SEED)
model = build_unet()
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# MeanIoU casts the raw sigmoid probabilities straight to integer class ids,
# truncating every value below 1.0 to class 0, so it does not reflect the
# model's predictions. BinaryIoU thresholds at 0.5 first. Averaging over target
# classes [0, 1] gives the two-class mean IoU on thresholded masks. The name
# keeps the history keys ("mean_io_u", "val_mean_io_u") unchanged.
train_iou_metric = tf.keras.metrics.BinaryIoU(
    target_class_ids=[0, 1], threshold=0.5, name="mean_io_u"
)
model.compile(optimizer=adam_optimizer, loss=dice_bce_loss, metrics=[train_iou_metric])

# Model summary
model.summary()

In [None]:
# Binarize masks: pixels stored as 255 become foreground (1.0), all others 0.0.
y_train_binary = tf.cast(y_train == 255, tf.float16)
y_val_binary = tf.cast(y_val == 255, tf.float16)

history = model.fit(
    X_train,
    y_train_binary,
    validation_data=(X_val, y_val_binary),
    epochs=10,
    batch_size=16,
)

In [None]:
# Tabulate per-epoch training metrics (1-based epoch numbers, 4 decimals).
clean_results = pd.DataFrame(history.history).reset_index().round(4)
clean_results['index'] = clean_results['index'] + 1
# Positional rename: assumes history columns are ordered
# loss, IoU, val_loss, val_IoU (Keras logs training keys before val_ keys).
clean_results.columns = ['Epoch', 'Loss', 'Mean IoU', 'Validation Loss', 'Validation Mean IoU']

plt.figure(figsize=(8, 4))
plt.axis('off')
# Stringify per column: `.values` on mixed int/float columns upcasts to float,
# which would render epochs as "1.0", "2.0", ...
plt.table(
    cellText=clean_results.astype(str).values,
    colLabels=clean_results.columns,
    cellLoc='center',
    loc='center',
    bbox=[0, 0, 1, 1],
)
plt.show()

In [None]:
# Per-pixel foreground probabilities (sigmoid output) for every test image
y_pred = model.predict(x=X_test)

In [None]:
# Binarize predictions and ground truth the same way as during training:
# probability > 0.5 -> foreground; mask value 255 -> foreground.
predicted_masks = (y_pred > 0.5).astype(np.float16)
true_masks = tf.cast(y_test == 255, tf.float16)

# Two-class mean IoU (background + foreground) over the whole test split
iou_metric = MeanIoU(num_classes=2)
iou_metric.update_state(true_masks, predicted_masks)
test_iou = iou_metric.result().numpy()
print(f"Test IoU: {test_iou}")

In [None]:
# Side-by-side view of the first test image, its ground-truth mask and the
# model's thresholded prediction.
panels = [
    (X_test[0], 'Test Image'),
    (y_test[0], 'True Mask'),
    (predicted_masks[0], 'Predicted Mask'),
]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(tf.squeeze(image), cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
plt.show()
