# Zero-Shot Segmentation of Satellite Imagery for Deforestation Detection Using the Segment Anything Model (SAM)

## Intro

* **Objective**: This experiment applies the Segment Anything Model (SAM), without any fine-tuning, to segment satellite imagery and assess whether its masks can help identify deforested areas. The focus is on generating segmentation masks from multi-band satellite images in TIFF format.

* **Dataset**: The dataset consists of multi-band satellite images (TIFF files) captured by Sentinel-2. From each image, the first three bands are read and treated as Red, Green, and Blue (RGB), because SAM expects a three-channel, 8-bit RGB input.

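* **Model**: SAM with the `vit_h` backbone (ViT-Huge, the largest of the three released checkpoints) is used for segmentation. Because SAM was pre-trained on the large-scale SA-1B dataset, it can generate masks zero-shot, without task-specific training or manual annotations. Its masks are class-agnostic, however: they outline regions but do not label them as forest or deforested.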

* **Process**: The experiment loads the satellite images, generates segmentation masks with SAM, and keeps the 50 largest masks by area. The results are visualized by overlaying the masks onto the original RGB images; later cells raise the mask limit and compare the masks against ground-truth masks.
 
* **Outcome**: This experiment tests whether SAM can provide an automated way to delineate and visualize potential deforestation regions, which could support environmental monitoring and data-driven decision-making.
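The pipeline reads bands 1, 2 and 3 and treats them as R, G, B. In Sentinel-2's own band naming, red, green and blue are B4, B3 and B2, so whether bands 1-3 really are RGB depends on how the TIFFs were exported. The sketch below selects true-color bands by name instead, assuming a hypothetical `names` list that maps each band index to its Sentinel-2 band name. In rasterio, `dataset.read()` returns a `(bands, height, width)` array, and `dataset.descriptions` may hold band names if the exporter wrote them.

```python
import numpy as np


def stack_true_color(bands, names, rgb_names=("B4", "B3", "B2")):
    """Return an (H, W, 3) true-color stack from a (num_bands, H, W) array.

    `names` (hypothetical input) gives the Sentinel-2 band name for each
    index in `bands`; B4/B3/B2 are Sentinel-2's red/green/blue bands.
    """
    index = {name: i for i, name in enumerate(names)}
    missing = [name for name in rgb_names if name not in index]
    if missing:
        raise KeyError(f"Bands not found: {missing}")
    # Pick the bands in R, G, B order and move the band axis last
    return np.stack([bands[index[name]] for name in rgb_names], axis=-1)


# Example: a 4-band file stored in wavelength order B2, B3, B4, B8
bands = np.arange(4 * 2 * 2).reshape(4, 2, 2)
rgb = stack_true_color(bands, ["B2", "B3", "B4", "B8"])
print(rgb.shape)  # (2, 2, 3)
```

If the band names are unknown, this fails loudly with a `KeyError`. The alternative is to silently read the wrong bands as RGB.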

## Import Required Libraries

- The first cell installs the dependencies. The second imports PyTorch, NumPy, OpenCV, Matplotlib, Rasterio, and the `segment-anything` package, which provides the SAM model and its utilities.

In [None]:
!pip install torch opencv-python rasterio
!pip install git+https://github.com/facebookresearch/segment-anything.git


In [None]:
# Import necessary libraries
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import rasterio
import os

# Import SAM (Segment Anything Model) components
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

## Model Initialization

- In this cell, we load the pre-trained SAM weights, move the model to the GPU if one is available (otherwise the CPU), and create the automatic mask generator.

In [None]:
# Load the pre-trained SAM model
sam_checkpoint = "/kaggle/input/segment-anything/pytorch/vit-h/1/model.pth"  # Path to SAM model checkpoint
model_type = "vit_h"  # Choose SAM model type: vit_h, vit_l, or vit_b

# Register the SAM model and load weights
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

# Set up device (GPU if available, otherwise CPU)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Move SAM to the selected device. SamAutomaticMaskGenerator calls `sam` directly,
# so wrapping it in torch.nn.DataParallel would not parallelize mask generation
# (and hard-coding device_ids=[0, 1] assumes exactly two GPUs).
sam.to(device)

# Initialize SAM's mask generator for automatic mask generation
mask_generator = SamAutomaticMaskGenerator(sam)

## Function to Clear GPU Memory

- This cell defines a utility function that releases unused GPU memory after each image is processed, reducing the risk of out-of-memory errors.

In [None]:
# Function to release cached GPU memory after each iteration
import gc

def clear_gpu_memory():
    # Collect unreachable Python objects (including tensors nothing references)
    gc.collect()
    if torch.cuda.is_available():
        # Return cached, unused blocks from PyTorch's allocator to the GPU.
        # Tensors still referenced elsewhere (e.g. global variables) are not freed.
        torch.cuda.empty_cache()
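An earlier form of this helper looped over `locals()` and called `del obj`. That cannot free anything. Inside a function, `locals()` only contains that function's own variables, and `del obj` just unbinds the loop variable while other references keep the object alive.

The pure-Python sketch below illustrates this, with a small class standing in for a tensor. Memory is only reclaimable once every reference is gone. After that, `gc.collect()` and `torch.cuda.empty_cache()` can release it.

```python
import gc
import weakref


class Payload:
    """Stand-in for a large tensor; a weakref lets us check whether it is alive."""


def delete_locals_like_original():
    # Same pattern as the original helper. locals() only sees this function's
    # own variables (none yet), so objects held by callers are never touched
    for obj in list(locals().values()):
        del obj  # would only unbind the loop variable `obj`


payload = Payload()
probe = weakref.ref(payload)

delete_locals_like_original()
gc.collect()
print(probe() is not None)  # True: `payload` still refers to the object

del payload  # drop the last real reference
gc.collect()
print(probe() is None)  # True: the object has now been reclaimed
```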

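## Define Function for Image Processing and Mask Generation

- This cell defines the main function. It reads the first three bands of a TIFF file, scales them to an 8-bit RGB image, runs SAM's automatic mask generator, and returns the image together with the 50 largest masks.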

In [None]:
# Define a function to process TIFF images and generate segmentation masks
def process_tiff_and_generate_masks(tif_file, mask_generator):
    print(f"Processing: {tif_file}")
    
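    # Open the TIFF file and read the first three bands.
    # Assumption: bands 1-3 hold Red, Green, Blue in that order; check the band
    # order of your files (e.g. via dataset.descriptions) before relying on it.
    with rasterio.open(tif_file) as dataset:
        r = dataset.read(1)  # Red channel
        g = dataset.read(2)  # Green channel
        b = dataset.read(3)  # Blue channel

    # Stack the R, G, B bands into an (H, W, 3) image; cast to float first so the
    # scaling below cannot overflow integer (e.g. uint16) band data
    rgb_image = np.stack([r, g, b], axis=-1).astype(np.float32)

    # Min-max normalize to [0, 255] and convert to uint8, the HWC format SAM expects
    value_range = max(float(rgb_image.max() - rgb_image.min()), 1e-6)  # avoid division by zero
    rgb_image = (255 * (rgb_image - rgb_image.min()) / value_range).astype(np.uint8)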
    
    # Generate masks using the SAM mask generator
    masks = mask_generator.generate(rgb_image)
    
    # Sort the masks by their area size and select the top 50 largest masks
    sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
    top_masks = sorted_masks[:50]  # Limit to top 50 largest masks

    return rgb_image, top_masks
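Two details of the normalization step deserve illustration.

First, Sentinel-2 reflectances are commonly stored as `uint16`. Multiplying a `uint16` array by 255 stays in `uint16`, so large values wrap around silently; casting to float first avoids this.

Second, a single global min-max stretch lets a few very bright pixels, such as clouds or bare soil, darken everything else. A common alternative is a per-band percentile stretch. The sketch below shows one as an assumption, not the notebook's method; `to_uint8_rgb` and its 2nd/98th-percentile defaults are illustrative choices.

```python
import numpy as np

# 1) Integer overflow: the product keeps the uint16 dtype and wraps modulo 65536
band = np.array([1000], dtype=np.uint16)
wrapped = 255 * band
print(wrapped.dtype, wrapped[0])  # uint16 58392, not 255000


def to_uint8_rgb(rgb, low_pct=2.0, high_pct=98.0):
    """Per-band percentile stretch of an (H, W, 3) array to uint8.

    Values at or below the low percentile map to 0, at or above the high one to 255.
    """
    img = rgb.astype(np.float32)  # float math avoids integer overflow
    out = np.empty(img.shape, dtype=np.uint8)
    for c in range(img.shape[-1]):
        lo, hi = np.percentile(img[..., c], [low_pct, high_pct])
        scale = hi - lo if hi > lo else 1.0  # constant band: avoid division by zero
        stretched = np.clip((img[..., c] - lo) / scale, 0.0, 1.0)
        out[..., c] = (stretched * 255).astype(np.uint8)
    return out


# Example: a ramp in the first band, constant values in the other two
ramp = np.linspace(0, 10000, 101, dtype=np.float32).reshape(1, 101)
demo = np.stack([ramp, np.full_like(ramp, 500), np.full_like(ramp, 500)], axis=-1)
demo_u8 = to_uint8_rgb(demo, low_pct=0, high_pct=100)
print(demo_u8[0, 0].tolist(), demo_u8[0, -1].tolist())  # [0, 0, 0] [255, 0, 0]
```

Stretching each band separately also changes the color balance. That is usually acceptable for feeding SAM, but not for any radiometric analysis.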


## Visualization Function

- This cell defines a function to visualize the original image and the generated segmentation masks side-by-side, then plots the results.

In [None]:
# Function to visualize the generated segmentation masks
def show_output(rgb_image, top_masks):
    fig, axes = plt.subplots(1, 2, figsize=(16, 16))

    # Display the original RGB image
    axes[0].imshow(rgb_image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')  # Turn off the axes for a cleaner look

    # Create an empty overlay image
    overlay_image = np.zeros((rgb_image.shape[0], rgb_image.shape[1], 3))
    
    # Overlay the segmentation masks
    for mask_info in top_masks:
        mask = mask_info['segmentation']
        color_mask = np.random.random(3)  # Random color for each mask
        overlay_image[mask] = overlay_image[mask] * 0.5 + color_mask * 0.5
    
    # Display the segmentation masks overlay
    axes[1].imshow(overlay_image)
    axes[1].set_title("Segmentation Masks")
    axes[1].axis('off')

    # Show the plot
    plt.show()
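`show_output` blends each mask color with a black canvas. As a result, the right-hand panel shows the masks but not the terrain beneath them.

The sketch below alpha-blends the masks onto the image itself instead. `blend_masks`, the `alpha` parameter, and the seeded random generator (for reproducible colors) are illustrative choices. Masks are passed as boolean arrays, matching the `'segmentation'` entries SAM returns.

```python
import numpy as np


def blend_masks(rgb_image, masks, alpha=0.5, seed=0):
    """Alpha-blend one random color per boolean mask onto a uint8 RGB image."""
    rng = np.random.default_rng(seed)  # seeded so colors are reproducible
    overlay = rgb_image.astype(np.float32)
    for mask in masks:
        color = rng.integers(0, 256, size=3)
        # Mix the existing pixel values with the mask color inside the mask only
        overlay[mask] = (1 - alpha) * overlay[mask] + alpha * color
    return np.clip(overlay, 0, 255).astype(np.uint8)


# Example: one mask covering the top row of a 2x2 gray image
image = np.full((2, 2, 3), 100, dtype=np.uint8)
top_row = np.array([[True, True], [False, False]])
blended = blend_masks(image, [top_row])
print(blended[1, 0].tolist())  # [100, 100, 100]: outside the mask, unchanged
```

In `show_output`, `axes[1].imshow(blend_masks(rgb_image, [m['segmentation'] for m in top_masks]))` would draw this blended view.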


## Call the Processing and Visualization Functions

- Finally, we use the functions defined earlier to process the TIFF files and generate results:

In [None]:
# Path to the directory containing the TIFF images (deforestation dataset)
directory = "/kaggle/input/deforestation-detection-dataset/1_CLOUD_FREE_DATASET/2_SENTINEL2/IMAGE_16_GRID/"

# Loop through the directory and process each .tif file
for filename in os.listdir(directory):
    if filename.endswith(".tif"):  # Only process TIFF files
        tif_file = os.path.join(directory, filename)
        
        # Process the image and generate masks
        rgb_image, top_masks = process_tiff_and_generate_masks(tif_file, mask_generator)
        
        # Visualize the results
        show_output(rgb_image, top_masks)
        
        # Clear GPU memory after processing each image
        clear_gpu_memory()
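`os.listdir` returns names in arbitrary order, so the plots may appear in a different order on each run. In addition, `filename.endswith(".tif")` skips files named `.TIF` or `.tiff`.

If reproducible ordering or those extensions matter for your data, a small helper can replace the directory scan. The hypothetical `list_tiffs` below is an assumption, not part of the original notebook.

```python
import os
import tempfile


def list_tiffs(directory):
    """Return full paths of .tif/.tiff files in `directory`, sorted by name."""
    extensions = (".tif", ".tiff")
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(extensions)
    )


# Example with a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    for name in ["b.TIF", "a.tif", "notes.txt", "c.tiff"]:
        open(os.path.join(tmp, name), "w").close()
    found = [os.path.basename(p) for p in list_tiffs(tmp)]
print(found)  # ['a.tif', 'b.TIF', 'c.tiff']
```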

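## Visual Assessment Scores

- The table below holds 1-10 scores for five metrics on ten images. These scores are entered by hand rather than computed from ground-truth masks in this notebook, so they should be read as a qualitative assessment.

In [None]: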
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from io import StringIO

# Hand-entered scores (1-10) per metric and image, in CSV form
data = """
Metric,Description,Image 1 Score (1-10),Image 2 Score (1-10),Image 3 Score (1-10),Image 4 Score (1-10),Image 5 Score (1-10),Image 6 Score (1-10),Image 7 Score (1-10),Image 8 Score (1-10),Image 9 Score (1-10),Image 10 Score (1-10)
Accuracy,Measures the overall correctness of the segmentation.,8.5,8,7.5,7,7.5,7,6.5,6.5,7,6.5
Precision,Indicates how many of the predicted segments are relevant.,8,7.5,7,6.5,7,6.5,6,6,6.5,6
Recall,Shows how many of the relevant segments are captured by the prediction.,9,8.5,8,7.5,8,7.5,7,7,7.5,7
F1 Score,\"The harmonic mean of precision and recall, providing a balance between the two.\",8.5,8,7.5,7,7.5,7,6.5,6.5,7,6.5
Intersection over Union (IoU),Measures the overlap between the predicted segments and the actual segments.,7.5,7,6.5,6,6.5,6,5.5,5.5,6,5.5
"""

# Read the data into a DataFrame
df = pd.read_csv(StringIO(data))
# Set the index to 'Metric' for easier plotting
df.set_index('Metric', inplace=True)
# Display the DataFrame
df
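The F1 row is described as the harmonic mean of precision and recall, F1 = 2PR / (P + R), so it can be cross-checked against the Precision and Recall rows. For Image 1, for example, 2 × 8 × 9 / (8 + 9) ≈ 8.47, which the table lists as 8.5. The check below uses values copied from the table and confirms the whole F1 row is consistent with rounding to the nearest 0.5.

```python
import numpy as np

# Values copied from the Precision, Recall, and F1 Score rows of the table
precision = np.array([8, 7.5, 7, 6.5, 7, 6.5, 6, 6, 6.5, 6])
recall = np.array([9, 8.5, 8, 7.5, 8, 7.5, 7, 7, 7.5, 7])
f1_table = np.array([8.5, 8, 7.5, 7, 7.5, 7, 6.5, 6.5, 7, 6.5])

# Harmonic mean of precision and recall, per image
f1_computed = 2 * precision * recall / (precision + recall)
print(np.round(f1_computed, 2))
print(np.abs(f1_computed - f1_table).max() < 0.05)  # True
```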

In [None]:
# Drop the 'Description' column as it's not numeric
df.drop(columns=['Description'], inplace=True)
# Transpose the DataFrame for plotting
df_t = df.T

In [None]:
sns.set_theme(style="whitegrid")
# DataFrame.plot creates its own figure, so a separate plt.figure() call would leave an empty figure
ax = df_t.plot(kind='bar', figsize=(14, 8), colormap='viridis')
ax.set_title('Performance of SAM on All Images')
ax.set_xlabel('Images')
ax.set_ylabel('Score (1-10)')
ax.legend(title='Metrics')
plt.show()

## Summary

> * **Accuracy**: Accuracy scores range from 6.5 to 8.5 out of 10 across the ten images, even with the 50-mask limit.

> * **Precision and Recall**: SAM identifies relevant segments reasonably well (Precision: 6-8, Recall: 7-9), and recall is one point higher than precision on every image.

> * **IoU**: Overlap scores are lower (5.5-7.5), which suggests discrepancies along segment boundaries.

> * **Consistency**: Performance is broadly consistent across image types, with slight difficulties on low-contrast images.

> * **Improvement Areas**: Improve boundary precision and reduce false positives, keeping the 50-mask limit in mind.

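## Quantitative Evaluation Against Ground-Truth Masks

- The following cells reinitialize SAM and define IoU and Dice metrics. For each image, they keep the 100 largest masks, combine them into a single binary mask, and compare it against the dataset's training masks, which serve as ground truth.

In [None]: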
# Import necessary libraries
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import rasterio
import os
import pandas as pd

# Import SAM (Segment Anything Model) components
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Load the pre-trained SAM model
sam_checkpoint = "/kaggle/input/segment-anything/pytorch/vit-h/1/model.pth"  # Path to SAM model checkpoint
model_type = "vit_h"  # Choose SAM model type: vit_h, vit_l, or vit_b

# Register the SAM model and load weights
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

# Set up device (GPU if available, otherwise CPU)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Move SAM to the selected device. SamAutomaticMaskGenerator calls `sam` directly,
# so a torch.nn.DataParallel wrapper would not be used for mask generation.
sam.to(device)

# Initialize SAM's mask generator for automatic mask generation
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
# Function to release cached GPU memory after each iteration
import gc

def clear_gpu_memory():
    # Collect unreachable Python objects (including tensors nothing references)
    gc.collect()
    if torch.cuda.is_available():
        # Return cached, unused blocks from PyTorch's allocator to the GPU.
        # Tensors still referenced elsewhere (e.g. global variables) are not freed.
        torch.cuda.empty_cache()

# Function to compute Intersection over Union (IoU) for binary masks
def compute_iou(y_true, y_pred):
    intersection = np.logical_and(y_true, y_pred)
    union = np.logical_or(y_true, y_pred)
    if np.sum(union) == 0:
        return 1.0  # Both masks empty: treat as perfect agreement
    iou_score = np.sum(intersection) / np.sum(union)
    return iou_score

# Function to compute the Dice coefficient for binary (0/1) masks
def compute_dice(y_true, y_pred):
    intersection = np.sum(np.logical_and(y_true, y_pred))
    total = np.sum(y_true > 0) + np.sum(y_pred > 0)
    if total == 0:
        return 1.0  # Both masks empty: treat as perfect agreement
    dice_score = (2 * intersection) / total
    return dice_score
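For binary masks A and B, the two metrics are defined as follows:

* IoU = |A ∩ B| / |A ∪ B|
* Dice = 2|A ∩ B| / (|A| + |B|)

Because |A| + |B| = |A ∪ B| + |A ∩ B|, the two are linked by Dice = 2·IoU / (1 + IoU). Dice is therefore never smaller than IoU, and both rank images in the same order. The self-contained check below verifies this identity on toy masks; the `iou`/`dice` helpers mirror `compute_iou`/`compute_dice`.

```python
import numpy as np


def iou(a, b):
    union = np.logical_or(a, b).sum()
    return 1.0 if union == 0 else np.logical_and(a, b).sum() / union


def dice(a, b):
    total = a.sum() + b.sum()
    return 1.0 if total == 0 else 2 * np.logical_and(a, b).sum() / total


# Toy 4x4 masks: ground truth covers the left half, prediction the top half
ground_truth = np.zeros((4, 4), dtype=bool)
ground_truth[:, :2] = True
prediction = np.zeros((4, 4), dtype=bool)
prediction[:2, :] = True

i, d = iou(ground_truth, prediction), dice(ground_truth, prediction)
print(i, d)  # intersection 4 pixels, union 12, mask sizes 8 + 8
print(np.isclose(d, 2 * i / (1 + i)))  # True
```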

In [None]:
# Function to process TIFF images and generate segmentation masks
def process_tiff_and_generate_masks(tif_file, mask_generator):
    print(f"Processing: {tif_file}")
    
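    # Open the TIFF file and read the first three bands.
    # Assumption: bands 1-3 hold Red, Green, Blue in that order; check the band
    # order of your files (e.g. via dataset.descriptions) before relying on it.
    with rasterio.open(tif_file) as dataset:
        r = dataset.read(1)  # Red channel
        g = dataset.read(2)  # Green channel
        b = dataset.read(3)  # Blue channel

    # Stack the R, G, B bands into an (H, W, 3) image; cast to float first so the
    # scaling below cannot overflow integer (e.g. uint16) band data
    rgb_image = np.stack([r, g, b], axis=-1).astype(np.float32)

    # Min-max normalize to [0, 255] and convert to uint8, the HWC format SAM expects
    value_range = max(float(rgb_image.max() - rgb_image.min()), 1e-6)  # avoid division by zero
    rgb_image = (255 * (rgb_image - rgb_image.min()) / value_range).astype(np.uint8)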
    
    # Generate masks using the SAM mask generator
    masks = mask_generator.generate(rgb_image)
    
    # Sort the masks by their area size and select the top 100 largest masks
    sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
    top_masks = sorted_masks[:100]  # Limit to top 100 largest masks

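    # Combine the top masks into one binary mask (their union). SAM's masks are
    # class-agnostic, so this marks every segmented pixel, not only deforested ones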
    combined_mask = np.zeros((rgb_image.shape[0], rgb_image.shape[1]), dtype=np.uint8)
    for mask_info in top_masks:
        mask = mask_info['segmentation']
        combined_mask[mask] = 1  # Set mask pixels to 1

    return rgb_image, top_masks, combined_mask
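SAM's automatic masks can cover much of the scene, so the union of the top masks may cover most of the image. It helps to keep a trivial baseline in mind. If the predicted mask covered every pixel, IoU would equal the fraction f of pixels that are positive in the ground truth, and Dice would be 2f / (1 + f).

The sketch below computes the coverage of a combined mask and this all-ones baseline. Toy arrays stand in for real masks, and both helper names are illustrative.

```python
import numpy as np


def coverage(mask):
    """Fraction of pixels set in a binary mask."""
    return float(np.count_nonzero(mask)) / mask.size


def all_ones_baseline(ground_truth):
    """IoU and Dice obtained by predicting every pixel as positive."""
    f = coverage(ground_truth)  # IoU = |GT| / N when the prediction is all ones
    return f, 2 * f / (1 + f)


# Toy example: ground truth marks 25% of a 4x4 tile
ground_truth = np.zeros((4, 4), dtype=np.uint8)
ground_truth[:2, :2] = 1
combined_mask = np.ones((4, 4), dtype=np.uint8)  # stand-in for a union covering everything

print(coverage(combined_mask))          # 1.0
print(all_ones_baseline(ground_truth))  # (0.25, 0.4)
```

Printing `coverage(combined_mask_resized)` next to each IoU score in the evaluation loop would show whether a good score reflects SAM's boundaries or simply a mask that covers most of the tile.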

In [None]:
# Function to visualize the generated segmentation masks
def show_output(rgb_image, top_masks, mask_file_path):
    fig, axes = plt.subplots(1, 3, figsize=(24, 16))

    # Display the original RGB image
    axes[0].imshow(rgb_image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')  # Turn off the axes for a cleaner look

    # Overlay the segmentation masks
    overlay_image = rgb_image.copy()
    for mask_info in top_masks:
        mask = mask_info['segmentation']
        color_mask = np.random.randint(0, 255, size=3)  # Random color for each mask
        overlay_image[mask] = overlay_image[mask] * 0.5 + color_mask * 0.5

    # Display the segmentation masks overlay
    axes[1].imshow(overlay_image.astype(np.uint8))
    axes[1].set_title("Segmentation Masks")
    axes[1].axis('off')

    # Read and display the ground truth raster mask
    with rasterio.open(mask_file_path) as src:
        band1 = src.read(1)

    axes[2].imshow(band1, cmap='gray')
    axes[2].set_title("Ground Truth Raster Mask")
    axes[2].axis('off')

    # Show the plot
    plt.show()
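The three panels show the prediction and the ground truth separately, which makes disagreements hard to spot. The sketch below builds an error map that colors true positives, false positives, and false negatives in a single image. `error_map` and its colors are illustrative; the result could be shown in an extra panel with `imshow`.

```python
import numpy as np

# Illustrative colors (RGB, uint8)
TP_COLOR = (0, 200, 0)  # predicted and in ground truth
FP_COLOR = (220, 0, 0)  # predicted but not in ground truth
FN_COLOR = (0, 0, 220)  # in ground truth but missed


def error_map(ground_truth, prediction):
    """Return an (H, W, 3) uint8 image encoding TP/FP/FN pixels; TN stays black."""
    gt = ground_truth.astype(bool)
    pred = prediction.astype(bool)
    out = np.zeros(gt.shape + (3,), dtype=np.uint8)
    out[gt & pred] = TP_COLOR
    out[~gt & pred] = FP_COLOR
    out[gt & ~pred] = FN_COLOR
    return out


gt = np.array([[1, 1], [0, 0]])
pred = np.array([[1, 0], [1, 0]])
emap = error_map(gt, pred)
print(emap[0, 0].tolist(), emap[0, 1].tolist(), emap[1, 0].tolist(), emap[1, 1].tolist())
# [0, 200, 0] [0, 0, 220] [220, 0, 0] [0, 0, 0]
```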

In [None]:
# Paths to the directories
image_directory = "/kaggle/input/deforestation-detection-dataset/1_CLOUD_FREE_DATASET/2_SENTINEL2/IMAGE_16_GRID/"
mask_directory = "/kaggle/input/deforestation-detection-dataset/3_TRAINING_MASKS/MASK_16_GRID/"

# Initialize lists to store evaluation metrics
iou_scores = []
dice_scores = []
image_filenames = []

# Loop through the .tif files in sorted order so results and plots are reproducible
for filename in sorted(os.listdir(image_directory)):
    if filename.endswith(".tif"):  # Only process TIFF files
        tif_file = os.path.join(image_directory, filename)
        
        # Process the image and generate masks
        rgb_image, top_masks, combined_mask = process_tiff_and_generate_masks(tif_file, mask_generator)
        
        # Construct the ground truth mask filename
        mask_filename = filename  # Assuming mask files have the same name as the image files
        mask_file = os.path.join(mask_directory, mask_filename)

        if os.path.exists(mask_file):
            # Read the ground truth mask
            with rasterio.open(mask_file) as mask_dataset:
                ground_truth_mask = mask_dataset.read(1)

            # Binarize the ground truth mask
            ground_truth_mask = (ground_truth_mask > 0).astype(np.uint8)

            # Ensure that combined_mask and ground_truth_mask have the same dimensions
            if combined_mask.shape != ground_truth_mask.shape:
                print("Resizing masks to match dimensions")
                combined_mask_resized = cv2.resize(combined_mask, (ground_truth_mask.shape[1], ground_truth_mask.shape[0]), interpolation=cv2.INTER_NEAREST)
            else:
                combined_mask_resized = combined_mask

            # Compute evaluation metrics
            iou_score = compute_iou(ground_truth_mask, combined_mask_resized)
            dice_score = compute_dice(ground_truth_mask, combined_mask_resized)
            
            # Store metrics
            iou_scores.append(iou_score)
            dice_scores.append(dice_score)
            image_filenames.append(filename)
            
            print(f"IoU Score: {iou_score:.4f}")
            print(f"Dice Score: {dice_score:.4f}")

            # Visualize the results
            show_output(rgb_image, top_masks, mask_file)
        else:
            print(f"Ground truth mask not found for {filename}")

        # Clear GPU memory after each image, including those without a ground-truth mask
        clear_gpu_memory()

# Create a DataFrame to store the results
results_df = pd.DataFrame({
    'Image': image_filenames,
    'IoU': iou_scores,
    'Dice': dice_scores
})

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Assuming results_df is already defined and contains 'IoU' and 'Dice' scores

# Calculate average IoU and Dice scores
average_iou = results_df['IoU'].mean()
average_dice = results_df['Dice'].mean()

# Plot IoU scores with average line
plt.figure(figsize=(12, 6))
plt.plot(results_df['Image'], results_df['IoU'], marker='o', label='IoU Score')
plt.axhline(y=average_iou, color='red', linestyle='--', label=f'Average IoU: {average_iou:.4f}')
plt.xticks(rotation=90)
plt.xlabel('Image Filename')
plt.ylabel('IoU Score')
plt.title('IoU Scores for Segmentation')
plt.legend()
plt.tight_layout()
plt.show()

# Plot Dice scores with average line
plt.figure(figsize=(12, 6))
plt.plot(results_df['Image'], results_df['Dice'], marker='o', color='green', label='Dice Score')
plt.axhline(y=average_dice, color='red', linestyle='--', label=f'Average Dice: {average_dice:.4f}')
plt.xticks(rotation=90)
plt.xlabel('Image Filename')
plt.ylabel('Dice Score')
plt.title('Dice Scores for Segmentation')
plt.legend()
plt.tight_layout()
plt.show()
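The average lines above are macro averages: each image's score counts equally, regardless of how many pixels it contains or whether its ground-truth mask is nearly empty. A pooled (micro) IoU, which sums intersections and unions over all images before dividing, can differ noticeably.

The sketch below computes both. It assumes the per-image binary masks are available as lists; the notebook's loop keeps only the scores, so collecting the masks would be an addition.

```python
import numpy as np


def macro_and_micro_iou(ground_truths, predictions):
    """Mean of per-image IoU (macro) vs. IoU of pooled pixel counts (micro)."""
    per_image, inter_total, union_total = [], 0, 0
    for gt, pred in zip(ground_truths, predictions):
        inter = np.logical_and(gt, pred).sum()
        union = np.logical_or(gt, pred).sum()
        per_image.append(1.0 if union == 0 else inter / union)
        inter_total += inter
        union_total += union
    micro = 1.0 if union_total == 0 else inter_total / union_total
    return float(np.mean(per_image)), float(micro)


# Image A: perfect match on 1 pixel; image B: 1 overlapping pixel out of a 9-pixel union
gt_a = np.array([[1, 0]])
pred_a = np.array([[1, 0]])
gt_b = np.zeros((3, 3), dtype=int)
gt_b[0, 0] = 1
pred_b = np.ones((3, 3), dtype=int)

macro, micro = macro_and_micro_iou([gt_a, gt_b], [pred_a, pred_b])
print(round(macro, 4), round(micro, 4))  # 0.5556 0.2
```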


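## Evaluation Without a Mask Limit

- These cells redefine the processing function to keep (effectively) all masks instead of the 100 largest, then rerun the evaluation and plots.

In [None]: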
# Function to process TIFF images and generate segmentation masks
def process_tiff_and_generate_masks(tif_file, mask_generator):
    print(f"Processing: {tif_file}")
    
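    # Open the TIFF file and read the first three bands.
    # Assumption: bands 1-3 hold Red, Green, Blue in that order; check the band
    # order of your files (e.g. via dataset.descriptions) before relying on it.
    with rasterio.open(tif_file) as dataset:
        r = dataset.read(1)  # Red channel
        g = dataset.read(2)  # Green channel
        b = dataset.read(3)  # Blue channel

    # Stack the R, G, B bands into an (H, W, 3) image; cast to float first so the
    # scaling below cannot overflow integer (e.g. uint16) band data
    rgb_image = np.stack([r, g, b], axis=-1).astype(np.float32)

    # Min-max normalize to [0, 255] and convert to uint8, the HWC format SAM expects
    value_range = max(float(rgb_image.max() - rgb_image.min()), 1e-6)  # avoid division by zero
    rgb_image = (255 * (rgb_image - rgb_image.min()) / value_range).astype(np.uint8)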
    
    # Generate masks using the SAM mask generator
    masks = mask_generator.generate(rgb_image)
    
    # Sort the masks by area; a cap of 10,000 is far above the number of masks
    # SAM typically returns, so effectively all masks are kept
    sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
    top_masks = sorted_masks[:10000]  # Effectively no limit

    # Create a combined binary mask from the top masks
    combined_mask = np.zeros((rgb_image.shape[0], rgb_image.shape[1]), dtype=np.uint8)
    for mask_info in top_masks:
        mask = mask_info['segmentation']
        combined_mask[mask] = 1  # Set mask pixels to 1

    return rgb_image, top_masks, combined_mask
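This cell raises the cap from 100 to 10,000 masks, which in practice keeps all of them. Rather than editing the slice between runs, the selection could be parameterized.

The sketch below uses a hypothetical `select_masks` helper that works on the list of dicts SAM returns (each with an `'area'` key). It can also keep only masks above a minimum area; this `min_area` filter is an illustrative addition, not something the notebook uses.

```python
def select_masks(masks, max_masks=None, min_area=0):
    """Keep masks with area >= min_area, largest first, up to max_masks (None = all)."""
    kept = [m for m in masks if m["area"] >= min_area]
    kept.sort(key=lambda m: m["area"], reverse=True)
    return kept if max_masks is None else kept[:max_masks]


# Toy mask records (only the 'area' key is needed here)
masks = [{"area": a} for a in (50, 400, 10, 1200, 300)]
print([m["area"] for m in select_masks(masks, max_masks=2)])   # [1200, 400]
print([m["area"] for m in select_masks(masks, min_area=100)])  # [1200, 400, 300]
```

Inside `process_tiff_and_generate_masks`, `top_masks = select_masks(masks, max_masks=100)` would reproduce the earlier run. `select_masks(masks)` would reproduce this one.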

In [None]:
# Paths to the directories
image_directory = "/kaggle/input/deforestation-detection-dataset/1_CLOUD_FREE_DATASET/2_SENTINEL2/IMAGE_16_GRID/"
mask_directory = "/kaggle/input/deforestation-detection-dataset/3_TRAINING_MASKS/MASK_16_GRID/"

# Initialize lists to store evaluation metrics
iou_scores = []
dice_scores = []
image_filenames = []

# Loop through the .tif files in sorted order so results and plots are reproducible
for filename in sorted(os.listdir(image_directory)):
    if filename.endswith(".tif"):  # Only process TIFF files
        tif_file = os.path.join(image_directory, filename)
        
        # Process the image and generate masks
        rgb_image, top_masks, combined_mask = process_tiff_and_generate_masks(tif_file, mask_generator)
        
        # Construct the ground truth mask filename
        mask_filename = filename  # Assuming mask files have the same name as the image files
        mask_file = os.path.join(mask_directory, mask_filename)

        if os.path.exists(mask_file):
            # Read the ground truth mask
            with rasterio.open(mask_file) as mask_dataset:
                ground_truth_mask = mask_dataset.read(1)

            # Binarize the ground truth mask
            ground_truth_mask = (ground_truth_mask > 0).astype(np.uint8)

            # Ensure that combined_mask and ground_truth_mask have the same dimensions
            if combined_mask.shape != ground_truth_mask.shape:
                print("Resizing masks to match dimensions")
                combined_mask_resized = cv2.resize(combined_mask, (ground_truth_mask.shape[1], ground_truth_mask.shape[0]), interpolation=cv2.INTER_NEAREST)
            else:
                combined_mask_resized = combined_mask

            # Compute evaluation metrics
            iou_score = compute_iou(ground_truth_mask, combined_mask_resized)
            dice_score = compute_dice(ground_truth_mask, combined_mask_resized)
            
            # Store metrics
            iou_scores.append(iou_score)
            dice_scores.append(dice_score)
            image_filenames.append(filename)
            
            print(f"IoU Score: {iou_score:.4f}")
            print(f"Dice Score: {dice_score:.4f}")

            # Visualize the results
            show_output(rgb_image, top_masks, mask_file)
        else:
            print(f"Ground truth mask not found for {filename}")

        # Clear GPU memory after each image, including those without a ground-truth mask
        clear_gpu_memory()

# Create a DataFrame to store the results
results_df = pd.DataFrame({
    'Image': image_filenames,
    'IoU': iou_scores,
    'Dice': dice_scores
})

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Assuming results_df is already defined and contains 'IoU' and 'Dice' scores

# Calculate average IoU and Dice scores
average_iou = results_df['IoU'].mean()
average_dice = results_df['Dice'].mean()

# Plot IoU scores with average line
plt.figure(figsize=(12, 6))
plt.plot(results_df['Image'], results_df['IoU'], marker='o', label='IoU Score')
plt.axhline(y=average_iou, color='red', linestyle='--', label=f'Average IoU: {average_iou:.4f}')
plt.xticks(rotation=90)
plt.xlabel('Image Filename')
plt.ylabel('IoU Score')
plt.title('IoU Scores for Segmentation')
plt.legend()
plt.tight_layout()
plt.show()

# Plot Dice scores with average line
plt.figure(figsize=(12, 6))
plt.plot(results_df['Image'], results_df['Dice'], marker='o', color='green', label='Dice Score')
plt.axhline(y=average_dice, color='red', linestyle='--', label=f'Average Dice: {average_dice:.4f}')
plt.xticks(rotation=90)
plt.xlabel('Image Filename')
plt.ylabel('Dice Score')
plt.title('Dice Scores for Segmentation')
plt.legend()
plt.tight_layout()
plt.show()
