# Multi-Class Object Detection with NWPU-VHR-10

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/object_detection_nwpu.ipynb)

This notebook demonstrates end-to-end multi-class object detection using the [NWPU-VHR-10](https://github.com/chaozhong2010/VHR-10_dataset_coco) dataset, a benchmark for object detection in very high resolution (VHR) remote sensing imagery.

The dataset contains 800 images with 10 object classes:
- airplane, ship, storage tank, baseball diamond, tennis court
- basketball court, ground track field, harbor, bridge, vehicle

## Install package
To use the `geoai-py` package, ensure it is installed in your environment. Uncomment the command below if needed.

In [None]:
# %pip install geoai-py

## Import libraries

In [None]:
import json
import os

import geoai

## Download NWPU-VHR-10 dataset

In [None]:
data_dir = geoai.download_nwpu_vhr10()

## Explore the dataset

In [None]:
print(f"Dataset directory: {data_dir}")
print(f"Contents: {os.listdir(data_dir)}")

In [None]:
print("\nNWPU-VHR-10 Classes:")
for i, name in enumerate(geoai.NWPU_VHR10_CLASSES):
    print(f"  {i}: {name}")

## Prepare dataset

Split the dataset into training and validation sets.

In [None]:
splits = geoai.prepare_nwpu_vhr10(data_dir, val_split=0.2, seed=42)

In [None]:
print(f"Images directory: {splits['images_dir']}")
print(f"Number of classes: {splits['num_classes']}")
print(f"Class names: {splits['class_names']}")
print(f"Training images: {len(splits['train_image_ids'])}")
print(f"Validation images: {len(splits['val_image_ids'])}")
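
For intuition, a seeded train/validation split over image IDs can be sketched as below. This is an illustration only, not the code inside `prepare_nwpu_vhr10`; the helper `split_image_ids` and the toy ID range are hypothetical.

```python
import random


def split_image_ids(image_ids, val_split=0.2, seed=42):
    """Shuffle IDs with a fixed seed and hold out a validation fraction.

    Hypothetical helper for illustration; geoai may split differently.
    """
    ids = sorted(image_ids)  # sort first so the result ignores input order
    rng = random.Random(seed)  # local RNG, leaves the global random state alone
    rng.shuffle(ids)
    n_val = int(round(len(ids) * val_split))
    return ids[n_val:], ids[:n_val]  # (train_ids, val_ids)


# Toy example with ten image IDs
train_ids, val_ids = split_image_ids(range(1, 11), val_split=0.2, seed=42)
print(f"train: {len(train_ids)}, val: {len(val_ids)}")
```

Fixing the seed makes the split reproducible, so metrics from different runs are computed on the same validation images.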

## Visualize sample annotations

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from PIL import Image

# Load annotations
with open(splits["annotations_path"], "r") as f:
    coco_data = json.load(f)

# Get a few sample images
sample_images = coco_data["images"][:4]
categories = {cat["id"]: cat["name"] for cat in coco_data["categories"]}
# plt.get_cmap works on current Matplotlib; plt.cm.get_cmap was removed in 3.9
cmap = plt.get_cmap("tab10", 10)

fig, axes = plt.subplots(2, 2, figsize=(14, 14))
axes = axes.flatten()

for ax_idx, img_info in enumerate(sample_images):
    img_path = os.path.join(splits["images_dir"], img_info["file_name"])
    img = Image.open(img_path)
    axes[ax_idx].imshow(img)
    axes[ax_idx].set_title(img_info["file_name"], fontsize=10)
    axes[ax_idx].axis("off")

    # Draw annotations for this image
    img_anns = [
        ann for ann in coco_data["annotations"] if ann["image_id"] == img_info["id"]
    ]
    for ann in img_anns:
        x, y, w, h = ann["bbox"]
        cat_id = ann["category_id"]
        color = cmap(cat_id % 10)
        rect = plt.Rectangle(
            (x, y), w, h, linewidth=2, edgecolor=color, facecolor="none"
        )
        axes[ax_idx].add_patch(rect)
        axes[ax_idx].text(
            x,
            y - 3,
            categories.get(cat_id, str(cat_id)),
            color="white",
            fontsize=7,
            bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
        )

plt.tight_layout()
plt.show()
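
The plotting loop above scans the full annotation list once per image, which is fine for four samples. When iterating over many images, it may be cheaper to index annotations by `image_id` once. The sketch below shows one way to do that on a toy COCO-style dictionary; `group_annotations_by_image` and `toy_coco` are hypothetical names, not part of geoai.

```python
from collections import defaultdict


def group_annotations_by_image(coco):
    """Map each image_id to its list of annotations in a single pass."""
    by_image = defaultdict(list)
    for ann in coco["annotations"]:
        by_image[ann["image_id"]].append(ann)
    return by_image


# Toy COCO-style data (made up for illustration)
toy_coco = {
    "images": [{"id": 1, "file_name": "a.jpg"}, {"id": 2, "file_name": "b.jpg"}],
    "annotations": [
        {"image_id": 1, "category_id": 1, "bbox": [10, 20, 30, 40]},
        {"image_id": 1, "category_id": 2, "bbox": [50, 60, 20, 20]},
        {"image_id": 2, "category_id": 1, "bbox": [5, 5, 15, 15]},
    ],
}

anns_by_image = group_annotations_by_image(toy_coco)
for img in toy_coco["images"]:
    print(img["file_name"], len(anns_by_image[img["id"]]))
```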

## Use pretrained model from Hugging Face

A pretrained Mask R-CNN model for NWPU-VHR-10 is available on the Hugging Face Hub. You can download it and run inference without training. To train your own model instead, skip to the "Train multi-class detection model (Optional)" section below.

In [None]:
model_path = geoai.download_nwpu_vhr10_model()

Run inference on a sample image using the pretrained model. When the pretrained model is used, `multiclass_detection` applies the NWPU-VHR-10 class names automatically.

In [None]:
# Pick a sample image from the dataset
sample_img_path = os.path.join(splits["images_dir"], "012.jpg")
output_raster = "nwpu_pretrained_output.tif"

result_path, inference_time, detections = geoai.multiclass_detection(
    input_path=sample_img_path,
    output_path=output_raster,
    model_path=model_path,
    confidence_threshold=0.5,
)

print(f"Inference time: {inference_time:.2f}s")
print(f"Total detections: {len(detections)}")

In [None]:
geoai.visualize_multiclass_detections(
    image_path=sample_img_path,
    detections=detections,
    confidence_threshold=0.5,
    figsize=(12, 10),
)
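
To summarize what was detected, the detections can be tallied per class. The sketch below assumes each detection is a dictionary with `label` and `score` keys, as used later in this notebook; the helper `count_by_class`, the toy class list, and the toy detections are hypothetical, and whether geoai treats the threshold as inclusive is an assumption.

```python
from collections import Counter


def count_by_class(detections, class_names, confidence_threshold=0.5):
    """Count detections per class name, keeping scores at or above the threshold."""
    counts = Counter()
    for det in detections:
        if det["score"] < confidence_threshold:
            continue
        label = det["label"]
        # Fall back to the numeric label when it has no matching name
        in_range = 0 <= label < len(class_names)
        counts[class_names[label] if in_range else str(label)] += 1
    return counts


# Toy inputs (made up for illustration; index 0 stands for background)
toy_class_names = ["background", "airplane", "ship"]
toy_detections = [
    {"box": [0, 0, 10, 10], "label": 1, "score": 0.91},
    {"box": [5, 5, 20, 20], "label": 1, "score": 0.42},
    {"box": [30, 30, 60, 50], "label": 2, "score": 0.77},
]

class_counts = count_by_class(toy_detections, toy_class_names)
print(dict(class_counts))
```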

You can also call `multiclass_detection` without `model_path`. It then downloads the pretrained model automatically and uses the NWPU-VHR-10 class names.

In [None]:
result_path, inference_time, detections = geoai.multiclass_detection(
    input_path=sample_img_path,
    output_path="nwpu_auto_output.tif",
    confidence_threshold=0.5,
)

print(f"Inference time: {inference_time:.2f}s")
print(f"Total detections: {len(detections)}")

# Clean up temporary output files
for f in ["nwpu_pretrained_output.tif", "nwpu_auto_output.tif"]:
    if os.path.exists(f):
        os.remove(f)

## Train multi-class detection model (Optional)

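Alternatively, you can train your own Mask R-CNN model on the NWPU-VHR-10 dataset. The call below sets `pretrained=True`, so training starts from pretrained weights rather than from random initialization. Skip this section if you are using the pretrained model above.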

In [None]:
output_dir = "nwpu_output"

model_path = geoai.train_multiclass_detector(
    images_dir=splits["images_dir"],
    annotations_path=splits["train_annotations"],
    output_dir=output_dir,
    class_names=splits["class_names"],
    num_channels=3,
    batch_size=4,
    num_epochs=20,
    learning_rate=0.005,
    val_split=0.15,
    seed=42,
    pretrained=True,
    verbose=True,
)

## Plot training metrics

In [None]:
import torch

history_path = os.path.join(output_dir, "training_history.pth")
if os.path.exists(history_path):
    history = torch.load(history_path, weights_only=True)

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].plot(history["epochs"], history["train_loss"], label="Train Loss")
    axes[0].plot(history["epochs"], history["val_loss"], label="Val Loss")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].set_title("Training & Validation Loss")
    axes[0].legend()

    axes[1].plot(history["epochs"], history["val_iou"], label="Val IoU", color="green")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("IoU")
    axes[1].set_title("Validation IoU")
    axes[1].legend()

    axes[2].plot(
        history["epochs"], history["lr"], label="Learning Rate", color="orange"
    )
    axes[2].set_xlabel("Epoch")
    axes[2].set_ylabel("LR")
    axes[2].set_title("Learning Rate Schedule")
    axes[2].legend()

    plt.tight_layout()
    plt.show()
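
Besides plotting, it can be useful to report which epoch scored best on validation. The sketch below assumes the history is a dictionary of equal-length lists keyed by `epochs` and `val_iou`, as the plotting code above implies; `best_epoch` and the toy values are hypothetical.

```python
def best_epoch(history, metric="val_iou"):
    """Return (epoch, value) for the highest value of the given metric."""
    values = history[metric]
    idx = max(range(len(values)), key=values.__getitem__)
    return history["epochs"][idx], values[idx]


# Toy history (made up for illustration)
toy_history = {
    "epochs": [1, 2, 3, 4],
    "val_iou": [0.41, 0.55, 0.62, 0.60],
}

epoch, value = best_epoch(toy_history)
print(f"Best epoch: {epoch} (val_iou={value:.2f})")
```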

## Evaluate model with COCO metrics

In [None]:
metrics = geoai.evaluate_multiclass_detector(
    model_path=model_path,
    images_dir=splits["images_dir"],
    annotations_path=splits["val_annotations"],
    num_classes=splits["num_classes"],
    class_names=splits["class_names"][1:],  # Exclude background
    batch_size=4,
)
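
COCO-style bounding-box metrics match predictions to ground truth by intersection over union (IoU), and the headline AP averages over IoU thresholds from 0.50 to 0.95 in steps of 0.05. The sketch below computes IoU for two boxes in `[x1, y1, x2, y2]` format. It is a plain-Python illustration, not the evaluation code geoai uses, and `box_iou` is a hypothetical helper.

```python
def box_iou(a, b):
    """IoU of two axis-aligned boxes given as [x1, y1, x2, y2]."""
    # Corners of the intersection rectangle
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)

    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Two 10x10 boxes overlapping by half: intersection 50, union 150
iou_half = box_iou([0, 0, 10, 10], [5, 0, 15, 10])
print(f"IoU: {iou_half:.3f}")
```

At a threshold of 0.5, this pair would not count as a match, since its IoU is one third.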

## Run inference on sample images

In [None]:
# Pick a validation image for inference
with open(splits["val_annotations"], "r") as f:
    val_data = json.load(f)

# Use the first validation image (change the index to try another one)
test_img_info = val_data["images"][0]
test_img_path = os.path.join(splits["images_dir"], test_img_info["file_name"])
print(f"Test image: {test_img_path}")

In [None]:
output_raster = "nwpu_detection_output.tif"

result_path, inference_time, detections = geoai.multiclass_detection(
    input_path=test_img_path,
    output_path=output_raster,
    model_path=model_path,
    num_classes=splits["num_classes"],
    class_names=splits["class_names"],
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
    batch_size=4,
    num_channels=3,
)

print(f"\nInference time: {inference_time:.2f}s")
print(f"Total detections: {len(detections)}")
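
The `window_size=512` and `overlap=256` arguments suggest sliding-window inference. One common way to lay out such windows along an axis is to step by `window_size - overlap` and add a final window flush with the image edge. The sketch below illustrates that layout under this assumption; it is not geoai's tiling code, and `window_origins` and the 1000 x 700 image size are hypothetical.

```python
def window_origins(length, window_size, overlap):
    """Start offsets of windows that cover [0, length) along one axis."""
    if overlap >= window_size:
        raise ValueError("overlap must be smaller than window_size")
    if length <= window_size:
        return [0]  # a single window already covers the axis
    stride = window_size - overlap
    starts = list(range(0, length - window_size + 1, stride))
    # Add a last window aligned to the edge if the stride stops short of it
    if starts[-1] + window_size < length:
        starts.append(length - window_size)
    return starts


# Hypothetical 1000 x 700 image tiled with 512-pixel windows and 256 overlap
xs = window_origins(1000, 512, 256)
ys = window_origins(700, 512, 256)
tiles = [(x, y) for y in ys for x in xs]
print(f"x starts: {xs}, y starts: {ys}, tiles: {len(tiles)}")
```

A larger overlap means more windows and more compute, but it makes it less likely that an object is cut by every window that contains it.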

## Visualize detections

In [None]:
geoai.visualize_multiclass_detections(
    image_path=test_img_path,
    detections=detections,
    class_names=splits["class_names"],
    confidence_threshold=0.5,
    figsize=(12, 10),
)

## Batch inference on multiple validation images

In [None]:
# Run inference on a few validation images and display results
num_samples = min(4, len(val_data["images"]))
fig, axes = plt.subplots(2, 2, figsize=(16, 16))
axes = axes.flatten()

for idx in range(num_samples):
    img_info = val_data["images"][idx]
    img_path = os.path.join(splits["images_dir"], img_info["file_name"])
    out_path = f"nwpu_detection_{idx}.tif"

    _, _, dets = geoai.multiclass_detection(
        input_path=img_path,
        output_path=out_path,
        model_path=model_path,
        num_classes=splits["num_classes"],
        class_names=splits["class_names"],
        confidence_threshold=0.5,
        num_channels=3,
    )

    # Display
    img = Image.open(img_path)
    axes[idx].imshow(img)
    axes[idx].set_title(
        f"{img_info['file_name']} ({len(dets)} detections)", fontsize=10
    )
    axes[idx].axis("off")

    for det in dets:
        box = det["box"]
        label = det["label"]
        score = det["score"]
        color = cmap(label % 10)
        rect = plt.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=2,
            edgecolor=color,
            facecolor="none",
        )
        axes[idx].add_patch(rect)
        name = (
            splits["class_names"][label]
            if label < len(splits["class_names"])
            else str(label)
        )
        axes[idx].text(
            box[0],
            box[1] - 3,
            f"{name}: {score:.2f}",
            color="white",
            fontsize=7,
            bbox=dict(boxstyle="round,pad=0.2", facecolor=color, alpha=0.7),
        )

    # Clean up temp file
    if os.path.exists(out_path):
        os.remove(out_path)

plt.tight_layout()
plt.show()
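
Note that two box conventions appear in this notebook: COCO annotations store `[x, y, width, height]`, while the detections drawn above are indexed as `[x1, y1, x2, y2]`. A minimal pair of converters, written here as hypothetical helpers rather than geoai functions, makes the relationship explicit.

```python
def xywh_to_xyxy(box):
    """Convert [x, y, width, height] to [x1, y1, x2, y2]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]


def xyxy_to_xywh(box):
    """Convert [x1, y1, x2, y2] to [x, y, width, height]."""
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]


# Round trip on a toy box
coco_box = [10, 20, 30, 40]
corner_box = xywh_to_xyxy(coco_box)
print(corner_box, xyxy_to_xywh(corner_box))
```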

## Summary

In this notebook, we demonstrated:

1. **Downloading** the NWPU-VHR-10 remote sensing object detection dataset
2. **Preparing** train/validation splits from COCO-format annotations
3. **Using a pretrained model** from the Hugging Face Hub for inference without training
4. **Training** a multi-class Mask R-CNN model for 10 object categories (optional)
5. **Evaluating** the model using COCO-style mAP metrics
6. **Running inference** on validation images with multi-class detection
7. **Visualizing** detection results with colored bounding boxes