In [9]:
import os

# TF_CPP_MIN_LOG_LEVEL is read when TensorFlow's C++ runtime is loaded, i.e. at
# `import tensorflow`; it must be set *before* that import to have any effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint

from helpers.UNET import UNET
from helpers.load import LoadandAugment

# ---- Configuration ---------------------------------------------------------
# Base path for the dataset. Override with the SATELLITE_DATA_DIR environment
# variable instead of editing this machine-specific absolute path.
path = os.environ.get(
    "SATELLITE_DATA_DIR", "/home/tfuser/project/Satelite/data/dataset/"
)
train_path = os.path.join(path, "train")
val_path = os.path.join(path, "val")
test_path = os.path.join(path, "test")

SEED = 42
BATCH_SIZE = 8                  # 3rd arg of LoadandAugment; presumably the batch size — confirm in helpers/load.py
INPUT_SHAPE = (512, 512, 4)     # passed to UNET(input_shape=...)
EPOCHS = 40
CHECKPOINT_PATH = 'best_model.h5'

# Seed TensorFlow's global RNG before building the loaders and the model so
# weight initialisation (and any tf.random ops the loaders may use) is
# reproducible. NOTE(review): Python/NumPy RNGs used inside helpers, if any,
# are not covered by this call.
tf.random.set_seed(SEED)

# ---- Data ------------------------------------------------------------------
train_data = LoadandAugment(train_path, "train", BATCH_SIZE)
val_data = LoadandAugment(val_path, "val", BATCH_SIZE)

# ---- Model -----------------------------------------------------------------
unet = UNET(input_shape=INPUT_SHAPE)
unet.model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Keep only the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    CHECKPOINT_PATH, monitor='val_loss', verbose=1, save_best_only=True, mode='min'
)

# ---- Training --------------------------------------------------------------
# Keep the History object (per-epoch loss/metrics) for later inspection instead
# of letting it be printed as the cell's bare repr.
history = unet.model.fit(
    train_data.dataset,
    validation_data=val_data.dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_callback],
)


Epoch 1/40
Epoch 1: val_loss improved from inf to 14.72005, saving model to best_model.h5
Epoch 2/40
Epoch 2: val_loss improved from 14.72005 to 1.10219, saving model to best_model.h5
Epoch 3/40
Epoch 3: val_loss did not improve from 1.10219
Epoch 4/40
Epoch 4: val_loss did not improve from 1.10219
Epoch 5/40
Epoch 5: val_loss did not improve from 1.10219
Epoch 6/40
Epoch 6: val_loss did not improve from 1.10219
Epoch 7/40
Epoch 7: val_loss did not improve from 1.10219
Epoch 8/40
Epoch 8: val_loss did not improve from 1.10219
Epoch 9/40
Epoch 9: val_loss did not improve from 1.10219
Epoch 10/40
Epoch 10: val_loss did not improve from 1.10219
Epoch 11/40
Epoch 11: val_loss did not improve from 1.10219
Epoch 12/40
Epoch 12: val_loss did not improve from 1.10219
Epoch 13/40
Epoch 13: val_loss did not improve from 1.10219
Epoch 14/40
Epoch 14: val_loss did not improve from 1.10219
Epoch 15/40
Epoch 15: val_loss did not improve from 1.10219
Epoch 16/40
Epoch 16: val_loss did not improve fro

<keras.callbacks.History at 0x7f4ab7febaf0>

In [11]:

import matplotlib.pyplot as plt
def plot_predictions(images, masks, predictions, num=3, threshold=None):
    """Show input images next to their ground-truth and predicted masks.

    Draws one row per sample with three panels: the input image, the actual
    mask and the predicted mask.

    Args:
        images: Indexable batch of input images (NumPy array or eager tensor).
        masks: Indexable batch of ground-truth masks, aligned with ``images``.
        predictions: Indexable batch of model outputs, aligned with ``images``.
        num: Number of rows to draw; clamped to the size of the smallest batch.
        threshold: If not ``None``, predictions are binarised with
            ``prediction > threshold`` before display; ``None`` (default) shows
            the raw model output.
    """
    # Never index past the end of any batch (the original raised IndexError).
    num = min(num, len(images), len(masks), len(predictions))
    if num <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when num == 1.
    fig, axes = plt.subplots(num, 3, figsize=(15, 5 * num), squeeze=False)

    for i in range(num):
        image = images[i]
        # imshow interprets a 4-channel array as RGBA, so a 4th spectral band
        # would be drawn as alpha transparency. Show only the first three bands.
        # NOTE(review): assumes bands 0-2 are displayable as R, G, B — confirm
        # the band order produced by the loader.
        if len(image.shape) == 3 and image.shape[-1] > 3:
            image = image[..., :3]

        prediction = predictions[i]
        if threshold is not None:
            prediction = prediction > threshold

        panels = (
            (image, "Satellite Image", {}),
            # Pin the grey scale to [0, 1] (the range binary_crossentropy
            # targets/outputs live in) so imshow does not auto-stretch the
            # values and exaggerate small probability differences.
            (masks[i], "Actual Mask", {"cmap": "gray", "vmin": 0, "vmax": 1}),
            (prediction, "Predicted Mask", {"cmap": "gray", "vmin": 0, "vmax": 1}),
        )
        for ax, (data, title, kwargs) in zip(axes[i], panels):
            ax.imshow(data, **kwargs)
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    plt.show()
    
# Restore the weights of the checkpoint with the lowest val_loss written by
# the ModelCheckpoint callback during training.
unet.model.load_weights('best_model.h5')

# Evaluate the model on the held-out test split and report the numbers
# (previously they were computed but never displayed).
test_data = LoadandAugment(test_path, "test", 4)
loss, acc = unet.model.evaluate(test_data.dataset)
print(f"Test Loss: {loss}, Test Accuracy: {acc}")

# Visualise a few predictions from the first test batch. evaluate() computes a
# loss from this dataset alone, so it yields (image, mask) pairs.
images, masks = next(iter(test_data.dataset))
predictions = unet.model.predict(images, verbose=0)
plot_predictions(images, masks, predictions, num=min(3, images.shape[0]))


Test Loss: 0.46007630228996277, Test Accuracy: 0.8033626675605774
