# Tree Crown Detection using Mask R-CNN

This notebook implements a tree crown detection model using Mask R-CNN based on the methodology described in the paper. The model can detect and map tree crowns from Google Earth images.

## Setup and Dependencies

In [None]:
# Install required packages if needed
!pip install tensorflow
!pip install numpy
!pip install opencv-python==4.7.0.72
!pip install scikit-image
!pip install matplotlib

[31mERROR: Could not find a version that satisfies the requirement tf-nightly (from versions: none)[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[31mERROR: No matching distribution found for tf-nightly[0m[31m
[0m

In [7]:
!pip freeze #> requirements.txt

absl-py==2.1.0
anyio==4.8.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
astunparse==1.6.3
async-lru==2.0.4
attrs==25.1.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
executing==2.2.0
fastjsonschema==2.21.1
flatbuffers==25.2.10
fonttools==4.56.0
fqdn==1.5.1
gast==0.6.0
google-pasta==0.2.0
grpcio==1.70.0
h11==0.14.0
h5py==3.13.0
httpcore==1.0.7
httpx==0.28.1
idna==3.10
imageio==2.37.0
ipykernel==6.29.5
ipython==9.0.1
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
jedi==0.19.2
Jinja2==3.1.5
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter==1.1.1
jupyter-console==6.6.3
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterla

In [3]:
# Clone Mask RCNN repository if not already installed
!git clone https://github.com/matterport/Mask_RCNN.git
!pip install -e Mask_RCNN

Cloning into 'Mask_RCNN'...
remote: Enumerating objects: 956, done.[K
remote: Total 956 (delta 0), reused 0 (delta 0), pack-reused 956 (from 1)[K
Receiving objects: 100% (956/956), 137.67 MiB | 9.23 MiB/s, done.
Resolving deltas: 100% (558/558), done.
Obtaining file:///Users/dynamicpacific/Dropbox/DEV/forestai-platform-model/Mask_RCNN
  Preparing metadata (setup.py) ... [?25ldone
[?25hInstalling collected packages: mask-rcnn
  Running setup.py develop for mask-rcnn
Successfully installed mask-rcnn-2.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [13]:
# Add Mask_RCNN to Python path
import os
import sys
import numpy as np
import tensorflow as tf
from tensorflow import keras
import cv2
import matplotlib.pyplot as plt
import random
import math
import re
import time
import skimage.draw
import skimage.io
import json
# Add the repository's root directory to Python path
repo_dir = os.path.abspath("./Mask_RCNN")
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

# Now try importing
from mrcnn.config import Config
print("Import successful!")

Import successful!


## Configuration

Set up the directory structure and configure paths.

In [15]:
import os
import ssl
import urllib.request

# Temporarily disable SSL verification (use with caution)
ssl._create_default_https_context = ssl._create_unverified_context

# Set file paths
ROOT_DIR = os.path.abspath("./")
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

# Download weights if needed
if not os.path.exists(COCO_MODEL_PATH):
    print("Downloading COCO weights...")
    urllib.request.urlretrieve(
        "https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5",
        COCO_MODEL_PATH
    )
    print("Download completed.")

Downloading COCO weights...
Download completed.


## Use the existing model to test how it works

In [20]:
!pip install keras==2.2.4

Collecting keras==2.2.4
  Obtaining dependency information for keras==2.2.4 from https://files.pythonhosted.org/packages/5e/10/aa32dad071ce52b5502266b5c659451cfd6ffcbf14e6c8c4f16c0ff5aaab/Keras-2.2.4-py2.py3-none-any.whl.metadata
  Downloading Keras-2.2.4-py2.py3-none-any.whl.metadata (2.2 kB)
Collecting keras-applications>=1.0.6 (from keras==2.2.4)
  Obtaining dependency information for keras-applications>=1.0.6 from https://files.pythonhosted.org/packages/71/e3/19762fdfc62877ae9102edf6342d71b28fbfd9dea3d2f96a882ce099b03f/Keras_Applications-1.0.8-py3-none-any.whl.metadata
  Downloading Keras_Applications-1.0.8-py3-none-any.whl.metadata (1.7 kB)
Collecting keras-preprocessing>=1.0.5 (from keras==2.2.4)
  Obtaining dependency information for keras-preprocessing>=1.0.5 from https://files.pythonhosted.org/packages/79/4c/7c3275a01e12ef9368a892926ab932b33bb13d55794881e3573482b378a7/Keras_Preprocessing-1.1.2-py2.py3-none-any.whl.metadata
  Downloading Keras_Preprocessing-1.1.2-py2.py3-none-a

In [24]:
!pip install opencv-python scikit-image tensorflow

Collecting keras>=3.5.0 (from tensorflow)
  Obtaining dependency information for keras>=3.5.0 from https://files.pythonhosted.org/packages/2b/98/e81c6b2cb522f0eadcc8e16f3cabaccd5462bff6cf52194acfed4a031d3f/keras-3.9.0-py3-none-any.whl.metadata
  Using cached keras-3.9.0-py3-none-any.whl.metadata (6.1 kB)
Using cached keras-3.9.0-py3-none-any.whl (1.3 MB)
Installing collected packages: keras
  Attempting uninstall: keras
    Found existing installation: Keras 2.2.4
    Uninstalling Keras-2.2.4:
      Successfully uninstalled Keras-2.2.4
Successfully installed keras-3.9.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


# choose a different approach to use geoai-py

In [3]:
!pip install geoai-py

Collecting geoai-py
  Obtaining dependency information for geoai-py from https://files.pythonhosted.org/packages/19/75/18354dc9f89c410f9b3bd3b7fb4bb9e667fa174780faad4f17175b8a3c89/geoai_py-0.3.6-py2.py3-none-any.whl.metadata
  Downloading geoai_py-0.3.6-py2.py3-none-any.whl.metadata (6.3 kB)
Collecting albumentations (from geoai-py)
  Obtaining dependency information for albumentations from https://files.pythonhosted.org/packages/97/d3/cf3aab593209d1be5e4bca54aeea297225708bd25f06426d6b8ec3630a76/albumentations-2.0.5-py3-none-any.whl.metadata
  Downloading albumentations-2.0.5-py3-none-any.whl.metadata (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.7/41.7 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting contextily (from geoai-py)
  Obtaining dependency information for contextily from https://files.pythonhosted.org/packages/fb/46/07a029b73f9a5c7bbf9b538e6441c42014a448f335a1cc780616f2594bad/contextily-1.6.2-py3-none-any.whl.metadata
  Downloading 

In [4]:
import geoai

In [5]:
raster_url = (
    "https://huggingface.co/datasets/giswqs/geospatial/resolve/main/trees_brazil.tif"
)
raster_path = geoai.download_file(raster_url)

trees_brazil.tif: 100%|██████████| 3.93M/3.93M [00:00<00:00, 11.6MB/s]


In [6]:
geoai.view_raster(raster_url)

In [7]:
segmenter = geoai.CLIPSegmentation(tile_size=512, overlap=32)

preprocessor_config.json:   0%|          | 0.00/380 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/974 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.73k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/603M [00:00<?, ?B/s]

Model loaded on cpu


In [8]:
output_path = "tree_masks.tif"
text_prompt = "trees"

In [9]:
segmenter.segment_image(
    raster_path,
    output_path=output_path,
    text_prompt=text_prompt,
    threshold=0.5,
    smoothing_sigma=1.0,
)

Processing tiles: 100%|██████████| 15/15 [00:11<00:00,  1.26it/s]

Segmentation saved to tree_masks.tif





'tree_masks.tif'

In [10]:
geoai.view_raster(
    output_path,
    nodata=0,
    opacity=0.8,
    colormap="greens",
    layer_name="Trees",
    basemap=raster_url,
)

2025-03-10 16:36:19,820 - ERROR - Exception possibly due to cache backend.
Traceback (most recent call last):
  File "/Users/dynamicpacific/Dropbox/DEV/forestai-platform-model/tf_env_py311/lib/python3.11/site-packages/flask_caching/__init__.py", line 435, in decorated_function
    self.cache.set(
  File "/Users/dynamicpacific/Dropbox/DEV/forestai-platform-model/tf_env_py311/lib/python3.11/site-packages/cachelib/simple.py", line 75, in set
    self._prune()
  File "/Users/dynamicpacific/Dropbox/DEV/forestai-platform-model/tf_env_py311/lib/python3.11/site-packages/cachelib/simple.py", line 52, in _prune
    self._remove_expired(now)
  File "/Users/dynamicpacific/Dropbox/DEV/forestai-platform-model/tf_env_py311/lib/python3.11/site-packages/cachelib/simple.py", line 36, in _remove_expired
    toremove = [k for k, (expires, _) in self._cache.items() if expires < now]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dynamicpacific/Dropbox/DEV/fore

In [11]:
geoai.create_split_map(
    left_layer=output_path,
    right_layer=raster_url,
    left_label="Trees",
    right_label="Satellite Image",
    left_args={"nodata": 0, "opacity": 0.8, "colormap": "greens"},
    basemap=raster_url,
)

Map(center=[20, 0], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoom_out_text…

In [13]:
import leafmap
from geoai.download import (
    download_naip
)

In [14]:
m = leafmap.Map(center=[40.785091, -73.968285], zoom=16)
#  Latitude: 40.785091
# Longitude: -73.968285
m.add_basemap("Google Satellite")
m

Map(center=[40.785091, -73.968285], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title'…

In [15]:
bbox = m.user_roi_bounds()
# if bbox is None:
#     bbox = (-117.6029, 47.65, -117.5936, 47.6563)

In [16]:
# Download NAIP imagery for the specified region
downloaded_files = download_naip(
    bbox=bbox,
    output_dir="naip_data",
    max_items=1,
    # year=2020,
)

print(f"Downloaded {len(downloaded_files)} files.")

Found 1 NAIP items.
Downloading item 1/1: m_4007309_sw_18_030_20230820_20231019.tif


m_4007309_sw_18_030_20230820_20231019.tif: 100%|██████████| 1.14G/1.14G [02:32<00:00, 8.07MiB/s] 

Successfully saved to naip_data/m_4007309_sw_18_030_20230820_20231019.tif
Downloaded 1 files.





In [17]:
geoai.view_raster(downloaded_files[0])

In [32]:
import sys
sys.path.insert(0, "/Users/dynamicpacific/Dropbox/DEV/geoai")

import geoai  # Now it will load from the local path
# segmenter = geoai.CLIPSegmentation(tile_size=512, overlap=32)

In [33]:
segmenter = geoai.CLIPSegmentation(tile_size=512, overlap=32)

Model loaded on cpu


In [34]:
output_path = "tree_masks_centralpark_mps.tif"
text_prompt = "trees"

In [35]:
segmenter.segment_image(
    downloaded_files[0],
    output_path=output_path,
    text_prompt=text_prompt,
    threshold=0.5,
    smoothing_sigma=1.0,
)

Processing tiles: 100%|██████████| 2173/2173 [22:49<00:00,  1.59it/s] 


Segmentation saved to tree_masks_centralpark_mps.tif


'tree_masks_centralpark_mps.tif'

In [31]:
geoai.view_raster(
    "tree_masks_centralpark.tif",
    nodata=0,
    opacity=0.6,
    colormap="reds",
    layer_name="Trees",
    basemap=downloaded_files[0],
)

In [36]:
# the result from text prompt "trees" is not good, let's try another using mask cnn
detector = geoai.SolarPanelDetector()  # Initialize the detector for solar panels (there are several model available)

Model path not specified, downloading from Hugging Face...


solar_panel_detection.pth:   0%|          | 0.00/176M [00:00<?, ?B/s]

Model downloaded to: /Users/dynamicpacific/.cache/huggingface/hub/models--giswqs--geoai/snapshots/f519baf6bb6d3c8b37b4f1f98e8d6a57bbc21f57/solar_panel_detection.pth
Model loaded successfully


In [1]:
# H5 to PTH conversion
import torch
import h5py
import numpy as np

# Load H5 model
h5_file = h5py.File('~/.cache/huggingface/hub/models--giswqs--geoai/snapshots/f519baf6bb6d3c8b37b4f1f98e8d6a57bbc21f57/mask_rcnn_shapes_0010.h5', 'r')

# Create PyTorch model structure (matching architecture)
pytorch_model = YourPyTorchModelClass()

# Map and transfer weights
# (Need to match layer names/structure between frameworks)
for name, param in pytorch_model.named_parameters():
    # Find corresponding weight in h5_file
    # This is the complex part requiring architecture knowledge
    
# Save as PyTorch model
torch.save(pytorch_model.state_dict(), 'mask_rcnn_shapes_0010.pth')


import torch
import h5py
import numpy as np
from torchvision.models.detection import maskrcnn_resnet50_fpn

def convert_h5_to_pth(h5_path, pth_path):
    """Convert Mask R-CNN H5 model to PyTorch PTH format."""
    print(f"Loading H5 model from: {h5_path}")
    
    # Load H5 file
    h5_file = h5py.File(h5_path, 'r')
    
    # Print H5 structure to help with mapping
    def print_structure(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(f"Dataset: {name}, Shape: {obj.shape}, Dtype: {obj.dtype}")
    
    print("H5 Model Structure:")
    h5_file.visititems(print_structure)
    
    # Create PyTorch Mask R-CNN model
    num_classes = 2  # Background + trees
    pytorch_model = maskrcnn_resnet50_fpn(
        weights=None,
        progress=False,
        num_classes=num_classes
    )
    
    # Map H5 weights to PyTorch model
    state_dict = pytorch_model.state_dict()
    
    # Create mapping dictionary - this requires manual inspection of the H5 file structure
    # This is a starting point that needs customization based on actual model structure
    mapping = {
        # Example mappings (these need to be adjusted based on actual H5 structure)
        'mrcnn_class_logits': 'roi_heads.box_predictor.cls_score',
        'mrcnn_bbox_fc': 'roi_heads.box_predictor.bbox_pred',
        'mrcnn_mask': 'roi_heads.mask_predictor.mask_fcn_logits',
        'fpn_c5p5': 'backbone.fpn.inner_blocks.0',
        'fpn_c4p4': 'backbone.fpn.inner_blocks.1',
        'fpn_c3p3': 'backbone.fpn.inner_blocks.2',
        'fpn_c2p2': 'backbone.fpn.inner_blocks.3',
        'fpn_p5': 'backbone.fpn.layer_blocks.0',
        'fpn_p4': 'backbone.fpn.layer_blocks.1',
        'fpn_p3': 'backbone.fpn.layer_blocks.2',
        'fpn_p2': 'backbone.fpn.layer_blocks.3',
        # Add more mappings based on H5 structure
    }
    
    # Transfer weights using mapping
    converted_params = 0
    for h5_name, pt_name_prefix in mapping.items():
        if h5_name in h5_file:
            # Handle different parameter types (weights, biases)
            for suffix in ['weight', 'bias']:
                pt_name = f"{pt_name_prefix}.{suffix}"
                if pt_name in state_dict:
                    try:
                        # Get weight from H5
                        h5_weight = np.array(h5_file[h5_name][suffix])
                        
                        # Handle dimensionality differences
                        if suffix == 'weight' and len(h5_weight.shape) == 4:
                            # NHWC (TensorFlow) to NCHW (PyTorch) conversion for conv layers
                            h5_weight = np.transpose(h5_weight, (3, 2, 0, 1))
                        
                        # Convert to torch tensor
                        torch_weight = torch.from_numpy(h5_weight)
                        
                        # Check shape compatibility
                        if torch_weight.shape == state_dict[pt_name].shape:
                            state_dict[pt_name] = torch_weight
                            converted_params += 1
                            print(f"Converted: {h5_name}/{suffix} -> {pt_name}")
                        else:
                            print(f"Shape mismatch: {h5_name}/{suffix} {torch_weight.shape} vs {pt_name} {state_dict[pt_name].shape}")
                    except Exception as e:
                        print(f"Error converting {h5_name}/{suffix}: {e}")
        else:
            print(f"Warning: {h5_name} not found in H5 file")
    
    print(f"Converted {converted_params} parameters")
    
    # Load state dict into model
    pytorch_model.load_state_dict(state_dict, strict=False)
    
    # Save as PyTorch model
    torch.save(pytorch_model.state_dict(), pth_path)
    print(f"Saved PyTorch model to: {pth_path}")
    
    return pytorch_model

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description='Convert Mask R-CNN H5 model to PyTorch PTH')
    parser.add_argument('h5_path', help='Path to input H5 model file')
    parser.add_argument('pth_path', help='Path to output PTH model file')
    
    args = parser.parse_args()
    convert_h5_to_pth(args.h5_path, args.pth_path)

IndentationError: expected an indented block after 'for' statement on line 14 (3304413080.py, line 19)

## Model Configuration

Define the configuration for the Mask R-CNN model as specified in the paper.

In [2]:
class TreeCrownConfig(Config):
    """Configuration for training on the tree crown dataset.
    Derives from the base Config class and overrides specific values.
    """
    # Give the configuration a recognizable name
    NAME = "tree_crown"
    
    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # Background + Tree Crown
    
    # Number of training steps per epoch
    STEPS_PER_EPOCH = 100
    
    # Number of validation steps to run at the end of every training epoch
    VALIDATION_STEPS = 50
    
    # Learning rate and momentum (as described in the paper)
    LEARNING_RATE = 0.001
    
    # Backbone architecture for feature extraction
    BACKBONE = "resnet101"
    
    # Input image resizing - keep images with their original aspect ratio
    # and enforce a maximum size limit
    IMAGE_RESIZE_MODE = "square"
    IMAGE_MIN_DIM = 800
    IMAGE_MAX_DIM = 1024
    
    # ROIs below this threshold are discarded
    DETECTION_MIN_CONFIDENCE = 0.7

NameError: name 'Config' is not defined

## Dataset Handler

The TreeCrownDataset class manages the dataset loading and preprocessing.

In [None]:
class TreeCrownDataset(utils.Dataset):
    def load_tree_crowns(self, dataset_dir, subset):
        """Load a subset of the Tree Crown dataset.
        dataset_dir: Root directory of the dataset.
        subset: Subset to load: train or val
        """
        # Add classes. We have only one class to add.
        self.add_class("tree_crown", 1, "tree_crown")
        
        # Train or validation dataset?
        assert subset in ["train", "val"]
        dataset_dir = os.path.join(dataset_dir, subset)
        
        # Load annotations
        # LabelMe format (poly format annotations)
        annotations = self.load_labelme_annotations(dataset_dir)
        
        for a in annotations:
            # Get the x, y coordinates of points of the polygons that make up
            # the outline of each object instance
            polygons = a['polygons']
            image_path = os.path.join(dataset_dir, a['filename'])
            
            # Load the image
            image = skimage.io.imread(image_path)
            height, width = image.shape[:2]
            
            self.add_image(
                "tree_crown",
                image_id=a['filename'],  # use file name as a unique image id
                path=image_path,
                width=width, height=height,
                polygons=polygons)
    
    def load_labelme_annotations(self, dataset_dir):
        """Load LabelMe annotations for tree crown polygons.
        This is specifically designed for the annotation format
        used in the paper with Labelme tool.
        """
        # Implementation would depend on specific format of annotations
        # For this example, assuming we have a JSON file for each image
        # with polygon coordinates for tree crowns
        
        annotations = []
        
        # Scan through all files in the directory
        for filename in os.listdir(dataset_dir):
            if filename.endswith('.json'):  # Labelme annotations are typically JSON
                json_path = os.path.join(dataset_dir, filename)
                
                # Parse the JSON file
                with open(json_path) as f:
                    data = json.load(f)
                
                # Extract image filename from JSON
                image_filename = data['imagePath']
                
                # Extract polygons - adapt this to match actual Labelme format
                polygons = []
                for shape in data['shapes']:
                    if shape['label'] == 'tree_crown':
                        # Convert points to array format
                        points = np.array(shape['points'], dtype=np.int32)
                        polygons.append(points)
                
                annotations.append({
                    'filename': image_filename,
                    'polygons': polygons
                })
        
        return annotations
    
    def load_mask(self, image_id):
        """Generate instance masks for an image.
        Returns:
        masks: A bool array of shape [height, width, instance count] with
            one mask per instance.
        class_ids: a 1D array of class IDs of the instance masks.
        """
        # If not a tree crown dataset image, delegate to parent class.
        image_info = self.image_info[image_id]
        if image_info["source"] != "tree_crown":
            return super(self.__class__, self).load_mask(image_id)
        
        # Convert polygons to a bitmap mask of shape
        # [height, width, instance_count]
        info = self.image_info[image_id]
        mask = np.zeros([info["height"], info["width"], len(info["polygons"])],
                       dtype=np.uint8)
        
        for i, p in enumerate(info["polygons"]):
            # Get indexes of pixels inside the polygon and set them to 1
            rr, cc = skimage.draw.polygon(p[:, 1], p[:, 0])
            mask[rr, cc, i] = 1
        
        # Return mask, and array of class IDs of each instance
        return mask.astype(np.bool), np.ones([mask.shape[-1]], dtype=np.int32)
    
    def image_reference(self, image_id):
        """Return the path of the image."""
        info = self.image_info[image_id]
        if info["source"] == "tree_crown":
            return info["path"]
        else:
            super(self.__class__, self).image_reference(image_id)

## Image Splitting Function

This function divides large satellite images into smaller sub-images as described in the paper.

In [None]:
def split_image(image_path, output_dir, tile_size=(935, 910)):
    """Split a large Google Earth image into smaller sub-images
    as described in the paper.
    
    Args:
        image_path: Path to the large image
        output_dir: Directory to save the sub-images
        tile_size: Size of the sub-images (width, height)
    """
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Load the image
    img = cv2.imread(image_path)
    h, w = img.shape[:2]
    
    # Calculate number of tiles
    n_h = math.ceil(h / tile_size[1])
    n_w = math.ceil(w / tile_size[0])
    
    print(f"Splitting image of size {w}x{h} into {n_w}x{n_h} tiles")
    
    # Split the image
    count = 0
    for i in range(n_h):
        for j in range(n_w):
            x = j * tile_size[0]
            y = i * tile_size[1]
            
            # Handle edge cases
            x_end = min(x + tile_size[0], w)
            y_end = min(y + tile_size[1], h)
            
            # Extract tile
            tile = img[y:y_end, x:x_end]
            
            # Save tile
            tile_path = os.path.join(output_dir, f"tile_{count:03d}.jpg")
            cv2.imwrite(tile_path, tile)
            count += 1
    
    print(f"Split image into {count} tiles")
    return count

# Example usage
# split_image("large_satellite_image.jpg", "tiles/")

## Model Training Function

This function implements the training pipeline for the Mask R-CNN model.

In [None]:
def train_model(config, dataset_dir):
    """Train the Mask R-CNN model for tree crown detection.
    
    Args:
        config: TreeCrownConfig instance
        dataset_dir: Directory containing the dataset
    """
    # Create model in training mode
    model = modellib.MaskRCNN(mode="training", config=config, model_dir=MODEL_DIR)
    
    # Load COCO weights as starting point
    model.load_weights(COCO_MODEL_PATH, by_name=True, exclude=[
        "mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"
    ])
    
    # Load training dataset
    dataset_train = TreeCrownDataset()
    dataset_train.load_tree_crowns(dataset_dir, "train")
    dataset_train.prepare()
    
    # Load validation dataset
    dataset_val = TreeCrownDataset()
    dataset_val.load_tree_crowns(dataset_dir, "val")
    dataset_val.prepare()
    
    # Train the model
    # First, train only the heads (as per the paper's approach)
    print("Training network heads")
    model.train(dataset_train, dataset_val,
               learning_rate=config.LEARNING_RATE,
               epochs=5,
               layers='heads')
    
    # Fine-tune all layers
    print("Fine-tuning all layers")
    model.train(dataset_train, dataset_val,
               learning_rate=config.LEARNING_RATE / 10,
               epochs=10,
               layers='all')
    
    return model

# Example usage
# config = TreeCrownConfig()
# model = train_model(config, "dataset_directory/")

## Detection Function

This function detects tree crowns in new images using the trained model.

In [None]:
def detect_tree_crowns(model, image_path, output_path=None):
    """Detect tree crowns in an image and save the result.
    
    Args:
        model: Trained Mask R-CNN model
        image_path: Path to the input image
        output_path: Path to save the output visualization
        
    Returns:
        Detection results
    """
    # Read the image
    image = skimage.io.imread(image_path)
    
    # Detect tree crowns
    results = model.detect([image], verbose=1)
    r = results[0]
    
    # Visualize results
    fig = plt.figure(figsize=(12, 12))
    visualize.display_instances(
        image, r['rois'], r['masks'], r['class_ids'],
        ['BG', 'Tree Crown'], r['scores'],
        title="Tree Crown Detection",
        figsize=(12, 12)
    )
    
    # Save the figure if output_path is specified
    if output_path:
        plt.savefig(output_path)
        plt.close()
    else:
        plt.show()
    
    return r

# Example usage
# detect_tree_crowns(model, "test_image.jpg", "result.png")

## Results Analysis

This function analyzes the detection results to get statistics about tree crowns.

In [None]:
def analyze_results(results):
    """Analyze the detection results to get statistics about tree crowns.
    
    Args:
        results: List of detection results for multiple images
    
    Returns:
        Dictionary with statistics
    """
    total_trees = 0
    total_area = 0
    area_distribution = {}
    bin_size = 50  # bin size in m²
    
    for r in results:
        n_trees = r['masks'].shape[-1]
        total_trees += n_trees
        
        # Calculate area for each tree crown
        for i in range(n_trees):
            mask = r['masks'][:, :, i]
            area = np.sum(mask) * (0.27**2)  # Convert pixels to m² (0.27m resolution)
            total_area += area
            
            # Update area distribution
            bin_idx = int(area / bin_size)
            if bin_idx not in area_distribution:
                area_distribution[bin_idx] = 0
            area_distribution[bin_idx] += 1
    
    # Prepare distribution data for plotting
    area_bins = []
    tree_counts = []
    for bin_idx in sorted(area_distribution.keys()):
        min_area = bin_idx * bin_size
        max_area = (bin_idx + 1) * bin_size
        area_bins.append(f"[{min_area}, {max_area})")
        tree_counts.append(area_distribution[bin_idx])
    
    stats = {
        'total_trees': total_trees,
        'total_area': total_area,
        'area_distribution': {
            'bins': area_bins,
            'counts': tree_counts
        }
    }
    
    return stats

def plot_area_distribution(stats):
    """Plot the distribution of tree crown areas.
    
    Args:
        stats: Statistics dictionary returned by analyze_results
    """
    plt.figure(figsize=(12, 6))
    plt.bar(stats['area_distribution']['bins'], 
           stats['area_distribution']['counts'])
    plt.xlabel('Crown Area (m²)')
    plt.ylabel('Number of Trees')
    plt.title('Distribution of Tree Crown Areas')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

## Complete Workflow Example

Here's an example of a complete workflow using the functions defined above.

In [None]:
# Example workflow - uncomment and modify as needed

# 1. Split a large satellite image into smaller tiles
# split_image("large_satellite_image.jpg", "tiles/")

# 2. Create and train the model
# config = TreeCrownConfig()
# model = train_model(config, "dataset_directory/")

# 3. Load a trained model for inference
# config = TreeCrownConfig()
# config.BATCH_SIZE = 1  # For inference
# model = modellib.MaskRCNN(mode="inference", config=config, model_dir=MODEL_DIR)
# weights_path = model.find_last()  # Find last trained weights
# model.load_weights(weights_path, by_name=True)

# 4. Process all images in a directory
# results = []
# input_dir = "test_images/"
# output_dir = "results/"
# os.makedirs(output_dir, exist_ok=True)
# for filename in os.listdir(input_dir):
#     if filename.endswith(('.jpg', '.jpeg', '.png')):
#         image_path = os.path.join(input_dir, filename)
#         output_path = os.path.join(output_dir, f"result_{os.path.splitext(filename)[0]}.png")
#         print(f"Processing {image_path}")
#         result = detect_tree_crowns(model, image_path, output_path)
#         results.append(result)

# 5. Analyze and visualize results
# stats = analyze_results(results)
# print(f"Total trees detected: {stats['total_trees']}")
# print(f"Total crown area: {stats['total_area']:.2f} m²")
# plot_area_distribution(stats)

## Processing a Single Image Example

You can use this cell to process a single test image.

In [None]:
# Example for processing a single image
# config = TreeCrownConfig()
# config.BATCH_SIZE = 1  # For inference
# model = modellib.MaskRCNN(mode="inference", config=config, model_dir=MODEL_DIR)
# weights_path = model.find_last()  # Find last trained weights
# model.load_weights(weights_path, by_name=True)

# result = detect_tree_crowns(model, "test_image.jpg")
# print(f"Detected {result['masks'].shape[-1]} trees")

# # Calculate average crown area
# areas = []
# for i in range(result['masks'].shape[-1]):
#     mask = result['masks'][:, :, i]
#     area = np.sum(mask) * (0.27**2)  # Convert pixels to m² (0.27m resolution)
#     areas.append(area)

# avg_area = np.mean(areas) if areas else 0
# print(f"Average crown area: {avg_area:.2f} m²")