## Sample Prediction Demo

In [None]:
import os
import sys

# Make the parent of the notebook's working directory importable so that the
# project-local `src` package used below can be resolved. The path is only
# added once, even if this cell is re-run.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from src.models import Unet
from src.config import ModelConfig
from src.dataloader import prepare_datasets, prediction_dataset, x

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import segmentation_models as sm

# set up environment
# set up environment
config = ModelConfig()
sm.set_framework('tf.keras')

# prepare data
# Backbone-specific input preprocessing from segmentation_models; None when the
# project's own Unet (no backbone) is used.
preprocess_input = sm.get_preprocessing(config.backbone) if config.backbone else None
# Path.glob yields entries in arbitrary (filesystem-dependent) order. Sort so
# that the samples taken below are reproducible across machines, and so that
# image/mask lists stay aligned if they are ever paired by index.
img_paths = sorted(config.img_dir.glob('*'))
mask_paths = sorted(config.mask_dir.glob('*'))

# define model
# define model
# Both variants produce a single-channel sigmoid output for the configured
# input shape. With a backbone configured, the segmentation_models U-Net is
# used; otherwise the project's own Unet builder from src.models is used.
_unet_kwargs = {
    'input_shape': config.img_shape,
    'classes': 1,
    'activation': 'sigmoid',
}
if not config.backbone:
    unet_model = Unet(**_unet_kwargs).build()
else:
    unet_model = sm.Unet(
        config.backbone,
        encoder_weights=config.encoder_weights,
        **_unet_kwargs,
    )

# compile model (no loss/optimizer given; the model is only used for
# inference below) and restore the trained weights
unet_model.compile()
unet_model.load_weights(config.save_model_path)

In [None]:
def _to_display_range(img):
    """Return *img* as an array that ``imshow`` can render without clipping.

    Used for display only; the tensor fed to the model is never modified.
    Integer images are returned unchanged (``imshow`` maps them as 0-255).
    Float images already within [0, 1] are returned as-is. Any other float
    image is min-max rescaled to [0, 1]. This covers images after
    ``preprocess_input``, which may contain negative or >1 values.
    """
    arr = np.asarray(img)
    if np.issubdtype(arr.dtype, np.integer):
        return arr
    arr = arr.astype(np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    if lo >= 0.0 and hi <= 1.0:
        return arr
    if hi > lo:
        return (arr - lo) / (hi - lo)
    # constant image: avoid division by zero
    return np.zeros_like(arr)


ds = prediction_dataset(
    img_paths,
    1,  # presumably the batch size: each element's path is indexed with [0] below
    config.img_shape,
    preprocess_input
)

num_samples = 3
samples = ds.take(num_samples)

for image, path in samples:
    filename = path.numpy()[0].decode('utf-8')
    # For a single small batch, calling the model directly is cheaper than
    # `predict`, which sets up a full data pipeline on every call.
    prediction = unet_model(image, training=False)

    # squeeze dimensions (drop the batch axis and any singleton channel)
    image = tf.squeeze(image)
    prediction = tf.squeeze(prediction)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    axs[0].imshow(_to_display_range(image))
    axs[0].set_title('Original Image')
    axs[0].axis('off')
    # Pin the gray scale to the sigmoid's [0, 1] range. Otherwise imshow
    # stretches each map to its own min/max, and a near-constant
    # (uncertain) prediction would look like a crisp mask.
    axs[1].imshow(prediction, cmap='gray', vmin=0, vmax=1)
    axs[1].set_title('Prediction')
    axs[1].axis('off')

    fig.suptitle(filename)
    plt.tight_layout()
    plt.show()