<a href="https://colab.research.google.com/github/matt-needle/project-data/blob/master/docs/examples/train_object_detection_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train an Object Detection Model with GeoAI

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/train_object_detection_model.ipynb)

## Install package
To use the `geoai-py` package, make sure it is installed in your environment. Run the command below if it is not already installed.

In [3]:
%pip install geoai-py

Collecting geoai-py
  Using cached geoai_py-0.7.1-py2.py3-none-any.whl.metadata (6.7 kB)
Collecting buildingregulariser (from geoai-py)
  Using cached buildingregulariser-0.2.2-py3-none-any.whl.metadata (6.9 kB)
Collecting contextily (from geoai-py)
  Using cached contextily-1.6.2-py3-none-any.whl.metadata (2.9 kB)
Collecting jupyter-server-proxy (from geoai-py)
  Using cached jupyter_server_proxy-4.4.0-py3-none-any.whl.metadata (8.7 kB)
Collecting leafmap (from geoai-py)
  Using cached leafmap-0.48.6-py2.py3-none-any.whl.metadata (16 kB)
Collecting localtileserver (from geoai-py)
  Using cached localtileserver-0.10.6-py3-none-any.whl.metadata (5.2 kB)
Collecting mapclassify (from geoai-py)
  Using cached mapclassify-2.9.0-py3-none-any.whl.metadata (3.1 kB)
Collecting maplibre (from geoai-py)
  Using cached maplibre-0.3.4-py3-none-any.whl.metadata (3.9 kB)
Collecting overturemaps (from geoai-py)
  Using cached overturemaps-0.15.0-py3-none-any.whl.metadata (4.0 kB)
Collecting planetary-

## Import libraries

In [4]:
import os
import geoai

out_folder = "output"

In [5]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0
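`nvcc` reports the CUDA toolkit version, which does not by itself mean a GPU is attached to the runtime (the training log later in this notebook reports `Using device: cpu`). A minimal sketch of a more direct check, assuming PyTorch is importable in the environment; `pick_device` is a hypothetical helper, not part of geoai:

```python
def pick_device():
    """Return "cuda" if PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch  # assumed to be installed alongside geoai-py
    except ImportError:
        # Without PyTorch there is nothing to train with, so report CPU.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print("Training would run on:", pick_device())
```

In Colab, a result of `cpu` usually means the runtime type has no GPU accelerator selected.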


## Download sample data

## Train object detection model

In [6]:
import os

from google.colab import drive
drive.mount('/content/drive')

out_folder = "output"

# print('images:', len(os.listdir('/content/drive/MyDrive/geoai/' + out_folder + '/images')), 'labels:', len(os.listdir('/content/drive/MyDrive/geoai/' + out_folder + '/labels')))

# Access your data
# !cp /content/drive/MyDrive/geoai/output .

Mounted at /content/drive


In [None]:
!cp -r /content/drive/MyDrive/geoai/output ./output
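The same copy can be done from Python, which makes it easy to confirm how many files arrived. This is a sketch under the assumption that the source folder exists on the mounted drive; `copy_dataset` is a hypothetical helper name.

```python
import shutil
from pathlib import Path


def copy_dataset(src, dst):
    """Recursively copy src into dst and return the number of files now in dst."""
    src, dst = Path(src), Path(dst)
    if not src.is_dir():
        raise FileNotFoundError(f"Source folder not found: {src}")
    # dirs_exist_ok=True (Python 3.8+) lets the copy be re-run without errors.
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return sum(1 for p in dst.rglob("*") if p.is_file())


# Example call with the paths used in this notebook:
# copy_dataset("/content/drive/MyDrive/geoai/output", "./output")
```

Unlike `cp -r` into an existing directory, `copytree` merges into `dst` rather than nesting a second `output` folder inside it.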

In [None]:
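# Train directly from the dataset on Google Drive.
# To train from the local copy made above, use out_folder = "/content/output" instead.
out_folder = "/content/drive/MyDrive/geoai/output"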

geoai.train_MaskRCNN_model(
    images_dir=f"{out_folder}/images",
    labels_dir=f"{out_folder}/labels",
    output_dir=f"{out_folder}/models",
    num_channels=4,
    pretrained=True,
    batch_size=12,
    num_epochs=100,
    learning_rate=0.001,
    val_split=0.1,
)

Using device: cpu
Found 948 image files and 12 label files
Using 12 matching files
Training on 10 images, validating on 2 images


Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|██████████| 170M/170M [00:01<00:00, 117MB/s]
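The log above reports 948 image files but only 12 label files, so only 12 pairs are used for training. A quick way to see which files are unpaired is to compare the two folders before training. The sketch below assumes images and labels are GeoTIFFs paired by identical file stem; that pairing rule is an assumption for illustration, and geoai's own matching logic may differ. `match_by_stem` is a hypothetical helper.

```python
from pathlib import Path

# Assumed raster extensions; adjust if the dataset uses something else.
RASTER_EXTS = {".tif", ".tiff"}


def match_by_stem(images_dir, labels_dir):
    """Pair image and label rasters by file stem and report unpaired files."""
    def index(folder):
        return {
            p.stem: p
            for p in Path(folder).iterdir()
            if p.is_file() and p.suffix.lower() in RASTER_EXTS
        }

    images, labels = index(images_dir), index(labels_dir)
    return {
        "matched": sorted(images.keys() & labels.keys()),
        "images_without_label": sorted(images.keys() - labels.keys()),
        "labels_without_image": sorted(labels.keys() - images.keys()),
    }


# Example call with the folders used in this notebook:
# report = match_by_stem(f"{out_folder}/images", f"{out_folder}/labels")
# print({name: len(stems) for name, stems in report.items()})
```

A large `images_without_label` count, as the log suggests here, usually means the labels folder was only partially copied or generated.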


In [None]:
images_dir=f"{out_folder}/images"

os.listdir(images_dir)

## Run inference

In [None]:
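# NOTE: test_raster_path and test_raster_url (used in the cells below) are not
# defined in this notebook. Set them to the test raster's local path and URL
# before running the remaining cells.
masks_path = "naip_test_prediction.tif"
model_path = f"{out_folder}/models/best_model.pth"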

In [None]:
geoai.object_detection(
    test_raster_path,
    masks_path,
    model_path,
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
    batch_size=4,
    num_channels=4,
)
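The `window_size=512` and `overlap=256` arguments suggest a sliding-window scheme in which consecutive windows start 256 pixels apart. Below is a minimal sketch of that arithmetic along one axis, assuming the stride equals `window_size - overlap` and the last window is clamped to the raster edge. `window_origins` is a hypothetical helper, not part of geoai, whose actual tiling may differ.

```python
def window_origins(length, window, overlap):
    """Return start offsets of sliding windows covering `length` pixels."""
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    if length <= window:
        # A single window (padded or cropped by the caller) covers everything.
        return [0]
    stride = window - overlap
    origins = list(range(0, length - window + 1, stride))
    # Add a final window flush with the edge if the stride did not land there.
    if origins[-1] != length - window:
        origins.append(length - window)
    return origins


print(window_origins(1024, 512, 256))  # [0, 256, 512]
print(window_origins(1000, 512, 256))  # [0, 256, 488]
```

With a 50% overlap, every interior pixel falls inside two windows per axis, which reduces edge artifacts at the cost of roughly four times as many windows as non-overlapping tiling.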

## Vectorize masks

In [None]:
output_path = "naip_test_prediction.geojson"
gdf = geoai.orthogonalize(masks_path, output_path, epsilon=2)

## Visualize results

In [None]:
geoai.view_vector_interactive(output_path, tiles=test_raster_url)

In [None]:
geoai.create_split_map(
    left_layer=output_path,
    right_layer=test_raster_url,
    left_args={"style": {"color": "red", "fillOpacity": 0.2}},
    basemap=test_raster_url,
)

![image](https://github.com/user-attachments/assets/8dfcc69e-7a6c-408a-9fae-10b81b7d85dc)