# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import rasterio
import cv2
import glob

# Load the trained segmentation model from disk.
# NOTE(review): `dice_coefficient` is not defined in this cell; it must be
# defined (presumably in an earlier notebook cell) before this runs, or the
# load fails with NameError.
model_path = 'forest_segmentation_amazon.h5'
if not os.path.exists(model_path):
    # Raise instead of calling exit(): a bare exit() terminates with status 0
    # (i.e. reports success), and the builtin exit() is an interactive-shell
    # convenience from the `site` module that should not be used in programs.
    raise FileNotFoundError(
        f"Model not found at {model_path!r}. Please ensure the path is correct."
    )
model = load_model(model_path, custom_objects={'dice_coefficient': dice_coefficient})
print("Model loaded successfully.")

# Folder holding the test GeoTIFFs (adjust as needed). This split has no
# masks, so only image paths are collected.
test_image_folder = "/content/drive/MyDrive/folder/AMAZON-1/Test/image/"
# os.path.join avoids the doubled separator that `folder + "/*.tif"` produced
# when the folder already ends with "/".
test_image_paths = sorted(glob.glob(os.path.join(test_image_folder, "*.tif")))
if not test_image_paths:
    # Fail early with a clear message rather than running inference and
    # visualization on an empty test set.
    raise FileNotFoundError(f"No .tif images found in {test_image_folder!r}.")

def load_test_data(image_paths, target_size=(256, 256)):
    """Load raster test images, resize them, and min-max scale them to [0, 1].

    Only images are loaded (no ground-truth masks); this is for inference.

    Args:
        image_paths: Iterable of paths to raster files readable by rasterio.
        target_size: Output size passed straight to ``cv2.resize``, which
            interprets it as ``(width, height)`` -- not ``(height, width)``.

    Returns:
        ``np.ndarray`` of float32 images with shape
        ``(N, height, width, bands)``. Each image is scaled using the
        min/max over the whole image (all bands together). If
        ``image_paths`` is empty, an empty array of shape ``(0,)`` is
        returned. Assumes every raster has the same band count.
    """
    images = []
    for img_path in image_paths:
        # Only the read needs the open dataset; close it before processing.
        with rasterio.open(img_path) as src:
            image = src.read()  # (bands, rows, cols)

        # Reorder to channels-last (rows, cols, bands), as cv2 and Keras expect.
        image = np.transpose(image, (1, 2, 0))
        image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
        image = cv2.normalize(image, None, alpha=0, beta=1,
                              norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        if image.ndim == 2:
            # OpenCV returns single-channel results as 2-D arrays, which drops
            # the band axis of single-band rasters. Restore it so every image
            # is (rows, cols, bands) and the batch stays 4-D.
            image = image[:, :, np.newaxis]

        images.append(image)

    return np.array(images)

# Build the inference array: one 256x256 resized, min-max scaled image per test file.
X_test = load_test_data(test_image_paths, (256, 256))

# Function to make predictions and visualize results
def visualize_predictions(model, X_test, num_samples=5):
    for i in range(num_samples):
        sample_image = X_test[i]

        # Expand dimensions to create a batch of 1 for prediction
        predicted_mask = model.predict(np.expand_dims(sample_image, axis=0))[0]

        # Threshold the predicted mask (convert probability to binary output)
        predicted_mask = (predicted_mask > 0.5).astype(np.uint8)

        # Plot original image and predicted mask
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.title('Input Image')
        plt.imshow(sample_image[:, :, :3])  # Display RGB channels

        plt.subplot(1, 2, 2)
        plt.title('Predicted Mask')
        plt.imshow(predicted_mask[:, :, 0], cmap='gray')

        plt.show()

# Plot input vs. predicted mask for the first five test images.
visualize_predictions(model, X_test, 5)