<a href="https://colab.research.google.com/github/codenameoge/100days-of-Computer-Vision/blob/main/remote_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#%% [markdown]
# # Remote Sensing Image Segmentation with Partial Cross-Entropy Loss
#
# This notebook implements the following:
# 1. Downloads and preprocesses the **Global Land Cover Mapping (OpenEarthMap)** dataset from Kaggle.
# 2. Implements a custom **Partial Cross-Entropy Loss** that computes loss only on randomly sampled points from the segmentation mask.
# 3. Builds a simple U-Net-style (encoder-decoder) segmentation network.
# 4. Trains and evaluates the model on the remote sensing dataset.
# 5. Visualizes segmentation predictions.
#
# The code is designed for Google Colab and uses PyTorch for model development.


In [None]:
# Install necessary packages and set up Kaggle API
!pip install kaggle
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

import os
import random
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from glob import glob
from tqdm import tqdm

Looking in indexes: https://download.pytorch.org/whl/cu118


In [None]:
## Kaggle Setup & Dataset Download
#
# Configure your Kaggle credentials and download the dataset from:
# [Global Land Cover Mapping (OpenEarthMap)](https://www.kaggle.com/datasets/aletbm/global-land-cover-mapping-openearthmap/data)

# Set up the Kaggle API. Replace the placeholders with your own credentials
# (from your Kaggle account settings) and never commit a real API key to a shared notebook.
!mkdir -p ~/.kaggle
!echo '{"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_KAGGLE_KEY"}' > ~/.kaggle/kaggle.json
!chmod 600 ~/.kaggle/kaggle.json
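Hard-coding an API key in a notebook exposes it to anyone who can read the notebook. A minimal sketch, assuming the credentials are available in the environment variables `KAGGLE_USERNAME` and `KAGGLE_KEY`, writes `kaggle.json` with owner-only permissions instead. The helper name `write_kaggle_config` is hypothetical. The Kaggle CLI can also read those two environment variables directly, in which case no file is needed at all.

```python
import json
import os
import stat
from pathlib import Path


def write_kaggle_config(username, key, config_dir=Path.home() / ".kaggle"):
    """Write kaggle.json with owner-only (0600) permissions and return its path."""
    config_dir = Path(config_dir)
    config_dir.mkdir(parents=True, exist_ok=True)
    config_path = config_dir / "kaggle.json"
    config_path.write_text(json.dumps({"username": username, "key": key}))
    # Equivalent to `chmod 600`: read/write for the owner only
    os.chmod(config_path, stat.S_IRUSR | stat.S_IWUSR)
    return config_path


# Hypothetical usage: credentials come from the environment, not from the notebook source
if "KAGGLE_USERNAME" in os.environ and "KAGGLE_KEY" in os.environ:
    write_kaggle_config(os.environ["KAGGLE_USERNAME"], os.environ["KAGGLE_KEY"])
```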

In [None]:
# Download dataset from Kaggle (this may take a few minutes)
!kaggle datasets download -d aletbm/global-land-cover-mapping-openearthmap --unzip -p /content/OpenEarthMap

Dataset URL: https://www.kaggle.com/datasets/aletbm/global-land-cover-mapping-openearthmap
License(s): CC-BY-NC-SA-4.0
Downloading global-land-cover-mapping-openearthmap.zip to /content/OpenEarthMap
100% 8.47G/8.47G [03:19<00:00, 45.6MB/s]


In [None]:
# Define preprocessing parameters
IMG_SIZE = (256, 256)        # Resize images and masks to 256x256
SAMPLE_POINTS = 100          # Number of random points to sample from each mask
BATCH_SIZE = 8
EPOCHS = 10
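With these settings the loss only ever sees a tiny fraction of each mask. The short calculation below repeats the two constants defined above (so it runs on its own) and computes that fraction. Because points are drawn with replacement, the number of distinct supervised pixels can be slightly lower than `SAMPLE_POINTS`.

```python
IMG_SIZE = (256, 256)
SAMPLE_POINTS = 100

pixels_per_mask = IMG_SIZE[0] * IMG_SIZE[1]            # 65,536 pixels after resizing
supervised_fraction = SAMPLE_POINTS / pixels_per_mask  # share of pixels that carry a label
print(f"{pixels_per_mask} pixels, {supervised_fraction:.4%} supervised")
```

Roughly 0.15% of the pixels in each resized mask contribute to the loss.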

In [None]:
# Define dataset base path and split files
DATASET_PATH = "/content/OpenEarthMap"
IMAGE_SPLITS = {"train": "train.txt", "val": "val.txt", "test": "test.txt"}
# Set directories for images and labels
IMAGE_DIR = os.path.join(DATASET_PATH, "images")
MASK_DIR = os.path.join(DATASET_PATH, "label")

In [None]:
## Loading Filenames for Each Split
#
# The function `load_filenames()` reads the filenames from `train.txt`, `val.txt`, and `test.txt`
# and constructs full paths by including the subfolder names (train, val, test).

def load_filenames():
    """
    Load filenames for each split and build full paths.
    """
    split_files = {split: os.path.join(DATASET_PATH, filename) for split, filename in IMAGE_SPLITS.items()}
    dataset = {"train": [], "val": [], "test": []}

    for split, split_file in split_files.items():
        if not os.path.exists(split_file):
            print(f"Error: {split_file} does not exist!")
            continue

        with open(split_file, 'r') as f:
            filenames = [line.strip() for line in f.readlines()]

        # Include subfolder (train, val, or test) in the full paths
        image_files = [os.path.join(IMAGE_DIR, split, f) for f in filenames]
        mask_files = [os.path.join(MASK_DIR, split, f) for f in filenames]

        # (Optional) Print first few samples for debugging
        for img, mask in zip(image_files[:5], mask_files[:5]):
            print(f"Image: {img}, Mask: {mask}")

        dataset[split] = list(zip(image_files, mask_files))

    return dataset
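The same pairing logic can be written with `pathlib`, which also makes it easy to skip blank lines and to check that every listed image has a mask before any pixels are loaded. This is a minimal sketch: the function name `build_split_pairs` is hypothetical, and the directory layout (`<split>.txt` at the dataset root, files under `images/<split>/` and `label/<split>/`) is the one `load_filenames()` assumes.

```python
from pathlib import Path


def build_split_pairs(root, split, image_dir="images", mask_dir="label"):
    """Return (image_path, mask_path, both_exist) triples for one split file such as train.txt."""
    root = Path(root)
    split_file = root / f"{split}.txt"
    # Ignore blank lines so a trailing newline does not produce an empty filename
    names = [line.strip() for line in split_file.read_text().splitlines() if line.strip()]
    pairs = []
    for name in names:
        img = root / image_dir / split / name
        mask = root / mask_dir / split / name
        pairs.append((img, mask, img.exists() and mask.exists()))
    return pairs
```

Pairs whose flag is `False` can be reported once up front instead of being skipped one by one during preprocessing.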

In [None]:
## Helper Functions for Preprocessing and Visualization
#
# These functions load and preprocess images, sample random point labels from masks,
# and visualize an image-mask pair.

def visualize_sample(image_path, mask_path):
    """Visualize an image and its corresponding mask."""
    if not os.path.exists(image_path) or not os.path.exists(mask_path):
        print(f"Error: File not found -> {image_path} or {mask_path}")
        return

    image = cv2.imread(image_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    if image is None or mask is None:
        print(f"Error: Unable to load {image_path} or {mask_path}")
        return

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title("Satellite Image")
    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap='gray')
    plt.title("Segmentation Mask")
    plt.show()


def preprocess_image(image_path, mask_path):
    """Load and preprocess image and mask."""
    if not os.path.exists(image_path) or not os.path.exists(mask_path):
        print(f"Warning: Skipping missing file {image_path} or {mask_path}")
        return None, None

    image = cv2.imread(image_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    if image is None or mask is None:
        print(f"Warning: Unable to load {image_path} or {mask_path}")
        return None, None

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, IMG_SIZE).astype(np.float32) / 255.0  # Normalize to [0, 1]; float32 halves memory vs float64

    mask = cv2.resize(mask, IMG_SIZE, interpolation=cv2.INTER_NEAREST)
    return image, mask


def sample_point_labels(mask, num_points=SAMPLE_POINTS):
    """
    Randomly sample num_points from the mask.
    Returns a list of (x, y, label) tuples.
    """
    height, width = mask.shape
    sampled_points = []
    for _ in range(num_points):
        x = random.randint(0, width - 1)
        y = random.randint(0, height - 1)
        label = int(mask[y, x])  # Convert label to int
        sampled_points.append((x, y, label))
    return sampled_points
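`sample_point_labels()` draws one point at a time with Python's `random` module. An equivalent vectorized sketch using NumPy's `Generator` API is shown below; it reads all labels with a single fancy-indexing call and makes the sampling reproducible through a seed. Like the original, it samples with replacement. The name `sample_point_labels_np` is hypothetical.

```python
import numpy as np


def sample_point_labels_np(mask, num_points=100, seed=None):
    """Return num_points (x, y, label) tuples drawn uniformly (with replacement) from mask."""
    rng = np.random.default_rng(seed)
    height, width = mask.shape
    xs = rng.integers(0, width, size=num_points)   # upper bound is exclusive
    ys = rng.integers(0, height, size=num_points)
    labels = mask[ys, xs]                          # fancy indexing reads all labels at once
    return [(int(x), int(y), int(lab)) for x, y, lab in zip(xs, ys, labels)]


mask = np.arange(12, dtype=np.uint8).reshape(3, 4)  # toy 3x4 mask where label = y * 4 + x
points = sample_point_labels_np(mask, num_points=5, seed=0)
```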

In [None]:
## Process the Dataset
#
# This function loads filenames for train, val, and test splits and processes each sample:
# - Loads and preprocesses images and masks.
# - Samples random point labels from the mask.
#
# **Note:** Since the test set does not have labels in this dataset, we only process the train and val sets.

def process_dataset():
    dataset_splits = load_filenames()
    processed_data = {"train": [], "val": []}  # Only process train and val sets

    for split in ["train", "val"]:
        files = dataset_splits[split]
        for img_path, mask_path in tqdm(files, total=len(files), desc=f"Processing {split} set"):
            if not os.path.exists(img_path) or not os.path.exists(mask_path):
                print(f"Warning: Missing file {img_path} or {mask_path}")
                continue

            image, mask = preprocess_image(img_path, mask_path)
            if image is None or mask is None:
                print(f"Warning: Skipping corrupted file {img_path} or {mask_path}")
                continue

            point_labels = sample_point_labels(mask)
            processed_data[split].append({
                "image": image,
                "mask": mask,
                "points": point_labels
            })
    return processed_data
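`process_dataset()` keeps every resized image and mask in RAM, and dividing a `uint8` image by `255.0` yields `float64` (8 bytes per value) unless it is cast first. A back-of-the-envelope estimate for the 3,000 training files counted by the progress bar shows how much memory that takes and what `float32` storage would save. It is a lower bound: the validation set and the per-sample point lists add more.

```python
import numpy as np

NUM_TRAIN = 3000          # files listed for the train split (progress-bar total)
H, W, C = 256, 256, 3     # IMG_SIZE and RGB channels


def gib(num_bytes):
    """Convert bytes to GiB."""
    return num_bytes / 2**30


image_values = NUM_TRAIN * H * W * C
float64_bytes = image_values * np.dtype(np.float64).itemsize  # 8 bytes per value
float32_bytes = image_values * np.dtype(np.float32).itemsize  # 4 bytes per value
mask_bytes = NUM_TRAIN * H * W * np.dtype(np.uint8).itemsize  # masks stay uint8

print(f"images float64: {gib(float64_bytes):.2f} GiB")  # 4.39 GiB
print(f"images float32: {gib(float32_bytes):.2f} GiB")  # 2.20 GiB
print(f"masks uint8:    {gib(mask_bytes):.2f} GiB")     # 0.18 GiB
```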

In [None]:
## PyTorch Dataset and DataLoader
#
# We define a custom dataset class that returns an image, mask, and sampled point labels for each sample.
# Then, we create DataLoaders for the train and validation sets.

class RemoteSensingDataset(Dataset):
    def __init__(self, dataset_split):
        # Filter out any samples that are None
        self.data = [sample for sample in dataset_split if sample and sample.get("image") is not None and sample.get("mask") is not None]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # Convert image to tensor with shape (C, H, W)
        image = torch.tensor(sample["image"].transpose(2, 0, 1), dtype=torch.float32)
        mask = torch.tensor(sample["mask"], dtype=torch.long)
        points = sample["points"]
        return image, mask, points

def custom_collate_fn(batch):
    """
    Custom collate function to handle the 'points' field properly.
    It stacks images and masks but leaves the points as a list.
    """
    images, masks, points = zip(*batch)
    images = torch.stack(images, 0)
    masks = torch.stack(masks, 0)
    # points is left as a tuple of lists (one per sample)
    return images, masks, points

# Build the DataLoaders with the custom collate function:
def create_dataloaders(processed_data):
    train_dataset = RemoteSensingDataset(processed_data["train"])
    val_dataset = RemoteSensingDataset(processed_data["val"])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn)
    return train_loader, val_loader

# When iterating over train_loader, sampled_points is a tuple holding one list of (x, y, label) tuples per image in the batch.
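To see what `custom_collate_fn` produces, the same `zip(*batch)` pattern can be run on NumPy arrays standing in for tensors (`np.stack` plays the role of `torch.stack`; `collate_like_custom` is a hypothetical name). Images and masks gain a leading batch dimension, while the points stay a tuple with one list per sample. PyTorch's default collate would instead transpose the nested point lists into per-point tensors across the batch (and fail if samples had different numbers of points), which is not the layout the loss expects.

```python
import numpy as np


def collate_like_custom(batch):
    """Mirror custom_collate_fn with NumPy: stack images and masks, keep points as-is."""
    images, masks, points = zip(*batch)
    return np.stack(images, 0), np.stack(masks, 0), points


# Two toy samples: a 3x4x4 "image", a 4x4 mask, and a different number of points each
batch = [
    (np.zeros((3, 4, 4), np.float32), np.zeros((4, 4), np.int64), [(0, 0, 1), (3, 2, 5)]),
    (np.ones((3, 4, 4), np.float32), np.ones((4, 4), np.int64), [(1, 1, 2)]),
]
images, masks, points = collate_like_custom(batch)
print(images.shape, masks.shape, points)
```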

In [None]:
## Partial Cross-Entropy Loss
#
# This custom loss computes cross-entropy only at the sampled point locations.
# It expects each sample's `sampled_points` to be a list of `(x, y, label)` tuples;
# the dense `targets` mask is accepted as an argument but not used.

class PartialCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(PartialCrossEntropyLoss, self).__init__()
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, predictions, targets, sampled_points):
        # predictions: [batch_size, num_classes, height, width]
        batch_size = predictions.shape[0]
        loss_list = []  # List to accumulate losses for each sample

        for i in range(batch_size):
            if len(sampled_points[i]) == 0:
                continue  # Skip if no points in this sample

            try:
                # Unpack the list of (x, y, label) tuples into separate lists
                x_coords, y_coords, labels = zip(*sampled_points[i])
            except ValueError as e:
                print(f"Error unpacking sampled points in sample {i}: {e}")
                print(f"Full sampled_points[{i}]: {sampled_points[i]}")
                continue  # Skip this sample if unpacking fails

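            # Gather logits at the sampled pixels. predictions[i] is [C, H, W]; indexing it
            # with two adjacent index tensors gives [C, N], so transpose to [N, C].
            ys = torch.as_tensor(y_coords, dtype=torch.long, device=predictions.device)
            xs = torch.as_tensor(x_coords, dtype=torch.long, device=predictions.device)
            pred_values = predictions[i][:, ys, xs].t()  # shape: [num_points, num_classes]
            target_values = torch.tensor(labels, dtype=torch.long, device=predictions.device)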

            # Compute cross entropy loss on the sampled points and take the mean
            sample_loss = self.criterion(pred_values, target_values).mean()
            loss_list.append(sample_loss)

        if loss_list:
            return torch.stack(loss_list).mean()
        else:
            # Return a zero tensor with grad enabled if no loss was computed (this should be rare)
            return torch.tensor(0.0, device=predictions.device, requires_grad=True)
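Written out, the per-image loss is the mean negative log-softmax probability of the labeled class at each sampled point (x_k, y_k, c_k), and the batch loss averages those per-image means. The NumPy sketch below computes the same quantity without PyTorch, which is handy for sanity-checking the module on tiny inputs; the function name `partial_ce_numpy` is hypothetical.

```python
import numpy as np


def partial_ce_numpy(logits, sampled_points):
    """logits: [B, C, H, W]; sampled_points: one list of (x, y, label) per image."""
    per_image = []
    for i, points in enumerate(sampled_points):
        if not points:
            continue                                  # same skip rule as the module
        xs, ys, labels = (np.array(v) for v in zip(*points))
        z = logits[i][:, ys, xs].T                    # [N, C] logits at the sampled pixels
        z = z - z.max(axis=1, keepdims=True)          # stabilize the softmax
        log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
        per_image.append(-log_probs[np.arange(len(labels)), labels].mean())
    return float(np.mean(per_image)) if per_image else 0.0


# Uniform logits over 6 classes: every point has probability 1/6, so the loss is log(6)
logits = np.zeros((2, 6, 8, 8))
points = [[(0, 0, 1), (7, 3, 5)], [(2, 2, 0)]]
print(partial_ce_numpy(logits, points))  # → 1.791759... (= log 6)
```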

In [None]:
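## U-Net Model for Segmentation
#
# Here, we implement a deliberately simple U-Net-style encoder-decoder. Unlike a full U-Net,
# it has no downsampling, upsampling, or skip connections: every layer keeps the input resolution.
# You can extend the architecture as needed.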

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=6):
        """
        in_channels: number of input image channels (RGB -> 3)
        out_channels: number of segmentation classes (adjust according to your dataset)
        """
        super(UNet, self).__init__()
        # Encoder: Two convolution layers
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
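        # Decoder: two more convolutions; the final 1x1 conv maps features to per-class logits.
        # No upsampling is needed because no layer changes the spatial resolution.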
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, kernel_size=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
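Because every layer uses stride 1 and there is no pooling, each output pixel sees only a small neighborhood: stacking stride-1, undilated convolutions with kernel sizes k_1, ..., k_L gives a receptive field of 1 + sum(k_l - 1) pixels per side. The calculation below applies that formula, plus the standard `in * out * k * k + out` parameter count for a `Conv2d` with bias, to the four layers above with the default `out_channels=6`.

```python
# (in_channels, out_channels, kernel_size) for the four Conv2d layers in UNet, out_channels=6
LAYERS = [(3, 64, 3), (64, 128, 3), (128, 64, 3), (64, 6, 1)]


def receptive_field(kernel_sizes):
    """Receptive field (pixels per side) of stacked stride-1, undilated convolutions."""
    return 1 + sum(k - 1 for k in kernel_sizes)


def conv_params(c_in, c_out, k):
    """Weights plus biases of one Conv2d layer."""
    return c_in * c_out * k * k + c_out


rf = receptive_field([k for _, _, k in LAYERS])
total = sum(conv_params(*layer) for layer in LAYERS)
print(f"receptive field: {rf}x{rf} pixels, parameters: {total:,}")  # 7x7, 149,830
```

A 7x7 window is small relative to a 256x256 tile, which is one motivation for adding the downsampling path and skip connections of a full U-Net.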

In [None]:
## Initialize Model, Optimizer, and Loss Function
#
# We create an instance of the U-Net model, set up the Adam optimizer, and instantiate the custom Partial Cross-Entropy Loss.

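# Assumption: OpenEarthMap masks use labels 0-8 (0 = unknown, 1-8 = the eight land-cover
# classes), so the network needs 9 output channels. Verify with np.unique() on a few masks:
# CrossEntropyLoss fails if any label is >= the number of output channels.
NUM_CLASSES = 9
model = UNet(out_channels=NUM_CLASSES)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = PartialCrossEntropyLoss()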

In [None]:
## Training, Evaluation, and Visualization Functions
#
# The following functions handle training, evaluation, and visualization:
#
# - **train_model()**: Trains the model for a specified number of epochs.
# - **evaluate_model()**: Computes the average validation loss.
# - **visualize_predictions()**: Displays original images, ground truth masks, and model predictions.

def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=EPOCHS):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion.to(device)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        # Debug: Print sampled points for the first batch in each epoch (optional)
        for images, masks, sampled_points in train_loader:
            print("Sampled Points Example:", sampled_points[:3])
            break  # Print once per epoch

        for images, masks, sampled_points in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks, sampled_points)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}")


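def evaluate_model(model, val_loader, criterion=None):
    """Return (and print) the average partial cross-entropy loss over the validation set."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use the loss passed in; fall back to a fresh PartialCrossEntropyLoss instead of a global
    if criterion is None:
        criterion = PartialCrossEntropyLoss()
    model.to(device)
    model.eval()

    total_loss = 0.0
    with torch.no_grad():
        for images, masks, sampled_points in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks, sampled_points)
            total_loss += loss.item()

    avg_val_loss = total_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}")
    return avg_val_loss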


def visualize_predictions(model, val_loader, num_samples=3):
    """Show up to num_samples validation images with their ground truth and predictions."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    with torch.no_grad():
        # Take the first validation batch and keep at most num_samples images
        images, masks, _ = next(iter(val_loader))
        images, masks = images[:num_samples], masks[:num_samples]
        outputs = model(images.to(device))
        predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        num_classes = outputs.shape[1]

    n = len(images)
    # One row per image, three columns (Original, Ground Truth, Prediction);
    # squeeze=False keeps axes 2-D even when n == 1
    fig, axes = plt.subplots(n, 3, figsize=(12, 4 * n), squeeze=False)
    for j in range(n):
        axes[j][0].imshow(images[j].permute(1, 2, 0).numpy())
        axes[j][0].set_title("Original Image")
        # Fix the color range so the same class gets the same gray level in both panels
        axes[j][1].imshow(masks[j].numpy(), cmap="gray", vmin=0, vmax=num_classes - 1)
        axes[j][1].set_title("Ground Truth")
        axes[j][2].imshow(predictions[j], cmap="gray", vmin=0, vmax=num_classes - 1)
        axes[j][2].set_title("Model Prediction")
        for ax in axes[j]:
            ax.axis("off")

    plt.tight_layout()
    plt.show()
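`evaluate_model()` reports only the partial loss on the sampled points, but the validation masks are dense, so a pixel-level metric can also be computed from them. Below is a minimal NumPy sketch of pixel accuracy and mean intersection-over-union (mIoU) from a confusion matrix, assuming integer label maps with values in `[0, num_classes)`. The helper names are hypothetical, and classes absent from both ground truth and prediction are excluded from the mean.

```python
import numpy as np


def confusion_matrix(gt, pred, num_classes):
    """Rows are ground-truth classes, columns are predicted classes."""
    gt = np.asarray(gt).ravel().astype(np.int64)      # int64 avoids uint8 overflow below
    pred = np.asarray(pred).ravel().astype(np.int64)
    valid = (gt >= 0) & (gt < num_classes)            # ignore out-of-range labels
    idx = num_classes * gt[valid] + pred[valid]
    return np.bincount(idx, minlength=num_classes**2).reshape(num_classes, num_classes)


def pixel_accuracy_and_miou(cm):
    """Pixel accuracy and mean IoU over classes that appear in gt or pred."""
    tp = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    present = union > 0
    iou = tp[present] / union[present]
    return tp.sum() / cm.sum(), iou.mean()


gt = np.array([[0, 0, 1], [1, 2, 2]])
pred = np.array([[0, 1, 1], [1, 2, 0]])
cm = confusion_matrix(gt, pred, num_classes=3)
acc, miou = pixel_accuracy_and_miou(cm)
print(f"pixel accuracy {acc:.3f}, mIoU {miou:.3f}")  # pixel accuracy 0.667, mIoU 0.500
```

In the notebook, `gt` would be `masks[j].numpy()` and `pred` the argmax of the model output; summing `cm` over the whole validation loader before calling `pixel_accuracy_and_miou` gives dataset-level scores.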

In [None]:
## Run the Full Pipeline
#
# Finally, we process the dataset, create data loaders, train the model, evaluate its performance on the validation set, and visualize predictions.

if __name__ == "__main__":
    # Process the dataset (process_dataset() loads the split filenames itself)
    processed_data = process_dataset()
    print("Dataset processing complete! Ready for training.")
    print(f"Processed Train Samples: {len(processed_data['train'])}")
    print(f"Processed Validation Samples: {len(processed_data['val'])}")

    # Create DataLoaders
    train_loader, val_loader = create_dataloaders(processed_data)
    print(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}")

    # Train the model
    train_model(model, train_loader, val_loader, criterion, optimizer)

    # Evaluate the model on validation set
    evaluate_model(model, val_loader)

    # Visualize predictions on validation samples
    visualize_predictions(model, val_loader)

    # Optionally, save the trained model
    torch.save(model.state_dict(), "unet_model.pth")
    print("Model saved successfully!")

Image: /content/OpenEarthMap/images/train/aachen_1.tif, Mask: /content/OpenEarthMap/label/train/aachen_1.tif
Image: /content/OpenEarthMap/images/train/aachen_10.tif, Mask: /content/OpenEarthMap/label/train/aachen_10.tif
Image: /content/OpenEarthMap/images/train/aachen_12.tif, Mask: /content/OpenEarthMap/label/train/aachen_12.tif
Image: /content/OpenEarthMap/images/train/aachen_13.tif, Mask: /content/OpenEarthMap/label/train/aachen_13.tif
Image: /content/OpenEarthMap/images/train/aachen_14.tif, Mask: /content/OpenEarthMap/label/train/aachen_14.tif
Image: /content/OpenEarthMap/images/val/aachen_11.tif, Mask: /content/OpenEarthMap/label/val/aachen_11.tif
Image: /content/OpenEarthMap/images/val/aachen_39.tif, Mask: /content/OpenEarthMap/label/val/aachen_39.tif
Image: /content/OpenEarthMap/images/val/aachen_42.tif, Mask: /content/OpenEarthMap/label/val/aachen_42.tif
Image: /content/OpenEarthMap/images/val/aachen_5.tif, Mask: /content/OpenEarthMap/label/val/aachen_5.tif
Image: /content/OpenE

Processing train set:  94%|█████████▍| 2813/3000 [01:21<00:10, 17.44it/s]