## Sample Prediction Demo

In [None]:
import os
import sys

# Put the parent of the current working directory on the import path (once),
# so the project-local `src` package imported below can be resolved from
# inside the notebook directory.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from src.models import Unet
from src.config import ModelConfig
from src.dataloader import prepare_datasets

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import segmentation_models as sm

# Environment: project configuration, and tell segmentation_models to use
# the tf.keras backend.
config = ModelConfig()
sm.set_framework('tf.keras')

# Data: pick the backbone-specific input preprocessing when a backbone is
# configured (none otherwise), read the split table from config.split_path and
# build the train/val/test datasets with batch size 1 and augmentation off.
if config.backbone:
    preprocess_input = sm.get_preprocessing(config.backbone)
else:
    preprocess_input = None
df = pd.read_csv(config.split_path)
train_ds, val_ds, test_ds = prepare_datasets(
    df,
    batch_size=1,
    input_shape=config.img_shape,
    preprocess_func=preprocess_input,
    augment_flag=False,
)

# Model: a one-class U-Net with a sigmoid output. Without a backbone the
# project's own Unet builder is used; otherwise segmentation_models' Unet
# with the configured encoder weights.
head_kwargs = dict(input_shape=config.img_shape, classes=1, activation='sigmoid')
if not config.backbone:
    unet_model = Unet(**head_kwargs).build()
else:
    unet_model = sm.Unet(
        config.backbone,
        encoder_weights=config.encoder_weights,
        **head_kwargs,
    )

# Compile with Keras defaults (no loss/metrics are passed), then restore the
# trained weights saved at config.save_model_path.
unet_model.compile()
unet_model.load_weights(config.save_model_path)

In [None]:
def _rescale_for_display(array):
    """Return *array* as float32 scaled into [0, 1] for ``plt.imshow``.

    imshow clips float RGB data outside [0, 1]. Backbone preprocessing may
    leave pixels outside that range (e.g. 0-255 or mean-centred values;
    depends on the backbone, TODO confirm). Such arrays are min-max rescaled
    here. Arrays already inside [0, 1] are returned unchanged (only cast to
    float32), so correctly-scaled images keep their original contrast.
    """
    array = np.asarray(array, dtype=np.float32)
    lo, hi = float(array.min()), float(array.max())
    if lo >= 0.0 and hi <= 1.0:
        return array
    if hi == lo:  # constant image: avoid dividing by zero
        return np.zeros_like(array)
    return (array - lo) / (hi - lo)


# Show a few validation samples: input image, ground-truth mask, and the
# model's sigmoid output side by side.
num_samples = 3
samples = val_ds.take(num_samples)

# Fixed [0, 1] display range for the mask panels. Without it imshow stretches
# each panel to its own min/max, so a near-constant prediction (e.g. all ~0.1)
# would render as a bright, seemingly confident mask. The model output is
# sigmoid-activated, so it lies in [0, 1]. The true-mask range is assumed to be
# {0, 1}; TODO confirm against prepare_datasets. A 0/255 mask would still
# render as black/white.
mask_display = {'cmap': 'gray', 'vmin': 0, 'vmax': 1}

for image, mask in samples:
    # Direct call rather than predict(): for one small batch inside a Python
    # loop, Keras recommends __call__ to avoid per-call predict() overhead.
    prediction = unet_model(image, training=False)

    # Drop the batch dimension (batch_size=1) and any singleton channel axis.
    image = _rescale_for_display(tf.squeeze(image))
    mask = tf.squeeze(mask)
    prediction = tf.squeeze(prediction)

    fig, axs = plt.subplots(1, 3, figsize=(12, 3))
    panels = (
        (image, 'Original Image', {}),
        (mask, 'True Mask', mask_display),
        (prediction, 'Predicted Mask', mask_display),
    )
    for ax, (data, title, imshow_kwargs) in zip(axs, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()