# Tree Crown Detection using Mask R-CNN

This notebook implements a tree crown detection model using Mask R-CNN based on the methodology described in the paper. The model can detect and map tree crowns from Google Earth images.

## Setup and Dependencies

In [None]:
# Install required packages if needed
!pip install tensorflow
!pip install numpy
!pip install opencv-python==4.7.0.72
!pip install scikit-image
!pip install matplotlib

[31mERROR: Could not find a version that satisfies the requirement tf-nightly (from versions: none)[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[31mERROR: No matching distribution found for tf-nightly[0m[31m
[0m

In [7]:
!pip freeze #> requirements.txt

absl-py==2.1.0
anyio==4.8.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
astunparse==1.6.3
async-lru==2.0.4
attrs==25.1.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
executing==2.2.0
fastjsonschema==2.21.1
flatbuffers==25.2.10
fonttools==4.56.0
fqdn==1.5.1
gast==0.6.0
google-pasta==0.2.0
grpcio==1.70.0
h11==0.14.0
h5py==3.13.0
httpcore==1.0.7
httpx==0.28.1
idna==3.10
imageio==2.37.0
ipykernel==6.29.5
ipython==9.0.1
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
jedi==0.19.2
Jinja2==3.1.5
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter==1.1.1
jupyter-console==6.6.3
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterla

In [3]:
# Clone Mask RCNN repository if not already installed
!git clone https://github.com/matterport/Mask_RCNN.git
!pip install -e Mask_RCNN

Cloning into 'Mask_RCNN'...
remote: Enumerating objects: 956, done.[K
remote: Total 956 (delta 0), reused 0 (delta 0), pack-reused 956 (from 1)[K
Receiving objects: 100% (956/956), 137.67 MiB | 9.23 MiB/s, done.
Resolving deltas: 100% (558/558), done.
Obtaining file:///Users/dynamicpacific/Dropbox/DEV/forestai-platform-model/Mask_RCNN
  Preparing metadata (setup.py) ... [?25ldone
[?25hInstalling collected packages: mask-rcnn
  Running setup.py develop for mask-rcnn
Successfully installed mask-rcnn-2.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [13]:
# Add Mask_RCNN to Python path
import os
import sys
import numpy as np
import tensorflow as tf
from tensorflow import keras
import cv2
import matplotlib.pyplot as plt
import random
import math
import re
import time
import skimage.draw
import skimage.io
import json
# Add the repository's root directory to Python path
repo_dir = os.path.abspath("./Mask_RCNN")
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

# Now try importing
from mrcnn.config import Config
print("Import successful!")

Import successful!


## Configuration

Set up the directory structure and configure paths.

In [15]:
import os
import ssl
import urllib.request

# Temporarily disable SSL verification (use with caution)
ssl._create_default_https_context = ssl._create_unverified_context

# Set file paths
ROOT_DIR = os.path.abspath("./")
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

# Download weights if needed
if not os.path.exists(COCO_MODEL_PATH):
    print("Downloading COCO weights...")
    urllib.request.urlretrieve(
        "https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5",
        COCO_MODEL_PATH
    )
    print("Download completed.")

Downloading COCO weights...
Download completed.


## Use the existing model to test how it works

In [20]:
!pip install keras==2.2.4

Collecting keras==2.2.4
  Obtaining dependency information for keras==2.2.4 from https://files.pythonhosted.org/packages/5e/10/aa32dad071ce52b5502266b5c659451cfd6ffcbf14e6c8c4f16c0ff5aaab/Keras-2.2.4-py2.py3-none-any.whl.metadata
  Downloading Keras-2.2.4-py2.py3-none-any.whl.metadata (2.2 kB)
Collecting keras-applications>=1.0.6 (from keras==2.2.4)
  Obtaining dependency information for keras-applications>=1.0.6 from https://files.pythonhosted.org/packages/71/e3/19762fdfc62877ae9102edf6342d71b28fbfd9dea3d2f96a882ce099b03f/Keras_Applications-1.0.8-py3-none-any.whl.metadata
  Downloading Keras_Applications-1.0.8-py3-none-any.whl.metadata (1.7 kB)
Collecting keras-preprocessing>=1.0.5 (from keras==2.2.4)
  Obtaining dependency information for keras-preprocessing>=1.0.5 from https://files.pythonhosted.org/packages/79/4c/7c3275a01e12ef9368a892926ab932b33bb13d55794881e3573482b378a7/Keras_Preprocessing-1.1.2-py2.py3-none-any.whl.metadata
  Downloading Keras_Preprocessing-1.1.2-py2.py3-none-a

In [24]:
!pip install opencv-python scikit-image tensorflow

Collecting keras>=3.5.0 (from tensorflow)
  Obtaining dependency information for keras>=3.5.0 from https://files.pythonhosted.org/packages/2b/98/e81c6b2cb522f0eadcc8e16f3cabaccd5462bff6cf52194acfed4a031d3f/keras-3.9.0-py3-none-any.whl.metadata
  Using cached keras-3.9.0-py3-none-any.whl.metadata (6.1 kB)
Using cached keras-3.9.0-py3-none-any.whl (1.3 MB)
Installing collected packages: keras
  Attempting uninstall: keras
    Found existing installation: Keras 2.2.4
    Uninstalling Keras-2.2.4:
      Successfully uninstalled Keras-2.2.4
Successfully installed keras-3.9.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import os
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import label_map_util

# Install TensorFlow Object Detection API if needed
if not os.path.exists("models"):
    os.system("git clone --depth 1 https://github.com/tensorflow/models.git")
    os.system("cd models/research && protoc object_detection/protos/*.proto --python_out=.")
    os.system("cd models/research && pip install -e .")

# Path to the TensorFlow models directory
PATH_TO_MODELS = "./models"

# Download a pre-trained Mask R-CNN model
MODEL_DATE = "20200711"
MODEL_NAME = "mask_rcnn_inception_resnet_v2_1024x1024_coco17_gpu-8"
MODEL_PATH = f"{PATH_TO_MODELS}/research/object_detection/test_data/{MODEL_NAME}/saved_model"

# Download if model doesn't exist
if not os.path.exists(MODEL_PATH):
    os.makedirs(MODEL_PATH, exist_ok=True)
    os.system(f"wget http://download.tensorflow.org/models/object_detection/tf2/20200711/{MODEL_NAME}.tar.gz")
    os.system(f"tar -xzvf {MODEL_NAME}.tar.gz -C {MODEL_PATH}")

# Load the model
print("Loading model... This might take a minute.")
detect_fn = tf.saved_model.load(MODEL_PATH)

# Load label map
PATH_TO_LABELS = f"{PATH_TO_MODELS}/research/object_detection/data/mscoco_label_map.pbtxt"
category_index = label_map_util.create_category_index_from_labelmap(
    PATH_TO_LABELS, use_display_name=True)

def detect_vegetation(image_path, output_path=None):
    """Detect vegetation in an image using TensorFlow's Object Detection API."""
    # Read image
    image_np = cv2.imread(image_path)
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    
    # Convert image to tensor
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.uint8)
    
    # Run detection
    print("Running inference...")
    detections = detect_fn(input_tensor)
    
    # Extract useful data from the result
    boxes = detections['detection_boxes'][0].numpy()
    classes = detections['detection_classes'][0].numpy().astype(np.int32)
    scores = detections['detection_scores'][0].numpy()
    masks = None
    if 'detection_masks' in detections:
        masks = detections['detection_masks'][0].numpy()
    
    # Define vegetation class IDs (potted plant = 64, tree = 47 in COCO)
    vegetation_class_ids = [47, 64]  # COCO IDs for vegetation-related classes
    
    # Filter for vegetation
    vegetation_indices = np.where(np.isin(classes, vegetation_class_ids) & (scores > 0.5))[0]
    
    # Visualize results
    image_np_with_detections = image_np.copy()
    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        boxes[vegetation_indices],
        classes[vegetation_indices],
        scores[vegetation_indices],
        category_index,
        instance_masks=masks[vegetation_indices] if masks is not None else None,
        use_normalized_coordinates=True,
        line_thickness=2
    )
    
    # Display and save results
    plt.figure(figsize=(12, 8))
    plt.imshow(image_np_with_detections)
    plt.axis('off')
    
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Detection results saved to {output_path}")
    
    plt.show()
    
    # Print statistics
    num_detections = len(vegetation_indices)
    print(f"Detected {num_detections} vegetation instances")
    
    # Calculate mask areas if masks are available
    if masks is not None and num_detections > 0:
        vegetation_masks = masks[vegetation_indices]
        areas = np.sum(vegetation_masks, axis=(1, 2))
        avg_area = np.mean(areas)
        print(f"Average vegetation area: {avg_area:.2f} pixels")
    
    return {
        'num_detections': num_detections,
        'classes': classes[vegetation_indices],
        'scores': scores[vegetation_indices],
        'boxes': boxes[vegetation_indices],
        'masks': masks[vegetation_indices] if masks is not None else None
    }

# Example usage
# detect_vegetation("../data/Nursury_Screenshot_GoogleEarth.jpg", "../data/output_detection.png")

Cloning TensorFlow 2.x compatible Mask R-CNN repository...


Cloning into '/Users/dynamicpacific/Dropbox/DEV/forestai-platform-model/TF2-Mask_RCNN'...


Successfully imported Mask R-CNN modules
Loading weights from /Users/dynamicpacific/Dropbox/DEV/forestai-platform-model/mask_rcnn_coco.h5


NotImplementedError: numpy() is only available when eager execution is enabled.

In [None]:
import os
import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt
import urllib.request
import tarfile

def download_model():
    """Download and extract pre-trained Mask R-CNN model for OpenCV."""
    model_path = "mask_rcnn_inception_v2_coco_2018_01_28"
    weights_path = f"{model_path}/frozen_inference_graph.pb"
    
    # Check if model already exists
    if os.path.exists(weights_path):
        print("Model files already exist.")
        return weights_path
    
    # Download model if needed
    if not os.path.exists(model_path):
        os.makedirs(model_path, exist_ok=True)
        
        # Download the model
        model_url = "http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_v2_coco_2018_01_28.tar.gz"
        tar_file = f"{model_path}.tar.gz"
        
        print(f"Downloading model from {model_url}...")
        urllib.request.urlretrieve(model_url, tar_file)
        
        # Extract the model
        print("Extracting model...")
        with tarfile.open(tar_file, "r:gz") as tar:
            tar.extractall()
        
        # Clean up
        os.remove(tar_file)
    
    return weights_path

def detect_trees(image_path, output_path=None):
    """Detect vegetation in an image using OpenCV's DNN module with a pre-trained network."""
    # Download the model
    weights_path = download_model()
    
    # Load image
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not load image from {image_path}")
    
    height, width = image.shape[:2]
    
    # The pre-trained model expects specific input dimensions
    # This model expects RGB images in specific format
    blob = cv2.dnn.blobFromImage(image, 
                                swapRB=True, 
                                crop=False,
                                size=(800, 600))
    
    # Load the network
    print("Loading model...")
    net = cv2.dnn.readNetFromTensorflow(weights_path)
    
    # Set the input
    net.setInput(blob)
    
    # Run forward pass - get all outputs
    print("Running inference...")
    outs = net.forward()
    
    # Create class list for COCO dataset
    classes = ["background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", 
               "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", 
               "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", 
               "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", 
               "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", 
               "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", 
               "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", 
               "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
    
    # Vegetation classes
    vegetation_class_ids = [
        classes.index("potted plant") if "potted plant" in classes else -1,
        classes.index("apple") if "apple" in classes else -1,
        classes.index("orange") if "orange" in classes else -1,
        classes.index("broccoli") if "broccoli" in classes else -1,
        classes.index("carrot") if "carrot" in classes else -1,
        classes.index("banana") if "banana" in classes else -1
    ]
    
    vegetation_class_ids = [id for id in vegetation_class_ids if id != -1]
    
    # Define colors for visualization
    colors = {
        classes.index("potted plant"): (0, 255, 0) if "potted plant" in classes else (0, 255, 0),
        classes.index("apple"): (0, 100, 255) if "apple" in classes else (0, 100, 255),
        classes.index("orange"): (0, 165, 255) if "orange" in classes else (0, 165, 255),
        classes.index("broccoli"): (0, 255, 100) if "broccoli" in classes else (0, 255, 100),
        classes.index("carrot"): (0, 200, 100) if "carrot" in classes else (0, 200, 100),
        classes.index("banana"): (0, 255, 255) if "banana" in classes else (0, 255, 255)
    }
    
    # Alternative implementation using a simpler detection approach
    output_image = image.copy()
    
    # Use a direct object detection model
    from pathlib import Path
    weights_dir = Path("yolo")
    weights_dir.mkdir(exist_ok=True)
    
    # Use YOLOv4 which has better compatibility with OpenCV
    weights_path = weights_dir / "yolov4.weights"
    cfg_path = weights_dir / "yolov4.cfg"
    
    if not weights_path.exists() or not cfg_path.exists():
        print("Downloading YOLOv4 model...")
        urllib.request.urlretrieve("https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights", 
                                  str(weights_path))
        urllib.request.urlretrieve("https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4.cfg", 
                                  str(cfg_path))
    
    # Load the network
    net = cv2.dnn.readNetFromDarknet(str(cfg_path), str(weights_path))
    
    # Specify the output layers
    out_layer_names = net.getUnconnectedOutLayersNames()
    
    # Prepare the image
    blob = cv2.dnn.blobFromImage(image, 1/255.0, (416, 416), swapRB=True, crop=False)
    net.setInput(blob)
    
    # Run inference
    outputs = net.forward(out_layer_names)
    
    # Get bounding boxes, confidences, and class IDs
    boxes = []
    confidences = []
    class_ids = []
    
    for output in outputs:
        for detection in output:
            scores = detection[5:]
            class_id = np.argmax(scores)
            confidence = scores[class_id]
            
            # Filter for vegetation classes
            if confidence > 0.5 and (class_id in [64, 47, 73, 49]):  # potted plant, tree, apple, orange in YOLO
                # Scale the bounding box coordinates to the image size
                box = detection[0:4] * np.array([width, height, width, height])
                centerX, centerY, box_width, box_height = box.astype("int")
                
                # Calculate the top-left corner
                x = int(centerX - (box_width / 2))
                y = int(centerY - (box_height / 2))
                
                boxes.append([x, y, int(box_width), int(box_height)])
                confidences.append(float(confidence))
                class_ids.append(class_id)
    
    # Apply non-maximum suppression
    indices = cv2.dnn.NMSBoxes(boxes, confidences, 0.5, 0.4)
    
    # YOLO class names
    yolo_classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck", "boat", "traffic light", 
                   "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", 
                   "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", 
                   "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", 
                   "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", 
                   "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "sofa", 
                   "pottedplant", "bed", "dining table", "toilet", "tvmonitor", "laptop", "mouse", "remote", "keyboard", 
                   "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", 
                   "teddy bear", "hair drier", "toothbrush"]
    
    # Define vegetation classes in YOLO
    yolo_vegetation = ["pottedplant", "apple", "orange", "banana", "broccoli", "carrot"]
    vegetation_indices = [yolo_classes.index(cls) for cls in yolo_vegetation if cls in yolo_classes]
    
    # Colors for YOLO classes
    yolo_colors = {
        yolo_classes.index("pottedplant"): (0, 255, 0) if "pottedplant" in yolo_classes else (0, 255, 0),
        yolo_classes.index("apple"): (0, 100, 255) if "apple" in yolo_classes else (0, 100, 255),
        yolo_classes.index("orange"): (0, 165, 255) if "orange" in yolo_classes else (0, 165, 255),
        yolo_classes.index("banana"): (0, 255, 255) if "banana" in yolo_classes else (0, 255, 255),
        yolo_classes.index("broccoli"): (0, 255, 100) if "broccoli" in yolo_classes else (0, 255, 100),
        yolo_classes.index("carrot"): (0, 200, 100) if "carrot" in yolo_classes else (0, 200, 100)
    }
    
    vegetation_count = 0
    vegetation_areas = []
    
    # Draw the filtered detections
    if len(indices) > 0:
        for i in indices.flatten():
            # Check if the detection is vegetation
            class_id = class_ids[i]
            if class_id in vegetation_indices:
                vegetation_count += 1
                
                # Get the box coordinates
                x, y, w, h = boxes[i]
                
                # Calculate area
                area = w * h
                vegetation_areas.append(area)
                
                # Draw the bounding box
                color = yolo_colors.get(class_id, (0, 255, 0))  # Default to green
                cv2.rectangle(output_image, (x, y), (x + w, y + h), color, 2)
                
                # Add label
                label = f"{yolo_classes[class_id]}: {confidences[i]:.2f}"
                cv2.putText(output_image, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    
    # Convert to RGB for matplotlib
    output_image_rgb = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
    
    # Display results
    plt.figure(figsize=(12, 8))
    plt.imshow(output_image_rgb)
    plt.title(f"Tree/Vegetation Detection: {vegetation_count} instances")
    plt.axis('off')
    
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Detection results saved to {output_path}")
    
    plt.show()
    
    # Print statistics
    print(f"Detected {vegetation_count} vegetation instances")
    
    if vegetation_areas:
        avg_area = np.mean(vegetation_areas)
        print(f"Average vegetation area: {avg_area:.2f} pixels")
    
    return {
        'count': vegetation_count,
        'areas': vegetation_areas,
        'processed_image': output_image_rgb
    }


Downloading model from http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_v2_coco_2018_01_28.tar.gz...
Extracting model...
Loading model...
Running inference...


error: OpenCV(4.11.0) /Users/xperience/GHA-Actions-OpenCV/_work/opencv-python/opencv-python/opencv/modules/dnn/src/net.cpp:105: error: (-215:Assertion failed) !empty() in function 'forward'


## Model Configuration

Define the configuration for the Mask R-CNN model as specified in the paper.

In [None]:
class TreeCrownConfig(Config):
    """Configuration for training on the tree crown dataset.
    Derives from the base Config class and overrides specific values.
    """
    # Give the configuration a recognizable name
    NAME = "tree_crown"
    
    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # Background + Tree Crown
    
    # Number of training steps per epoch
    STEPS_PER_EPOCH = 100
    
    # Number of validation steps to run at the end of every training epoch
    VALIDATION_STEPS = 50
    
    # Learning rate and momentum (as described in the paper)
    LEARNING_RATE = 0.001
    
    # Backbone architecture for feature extraction
    BACKBONE = "resnet101"
    
    # Input image resizing - keep images with their original aspect ratio
    # and enforce a maximum size limit
    IMAGE_RESIZE_MODE = "square"
    IMAGE_MIN_DIM = 800
    IMAGE_MAX_DIM = 1024
    
    # ROIs below this threshold are discarded
    DETECTION_MIN_CONFIDENCE = 0.7

## Dataset Handler

The TreeCrownDataset class manages the dataset loading and preprocessing.

In [None]:
class TreeCrownDataset(utils.Dataset):
    def load_tree_crowns(self, dataset_dir, subset):
        """Load a subset of the Tree Crown dataset.
        dataset_dir: Root directory of the dataset.
        subset: Subset to load: train or val
        """
        # Add classes. We have only one class to add.
        self.add_class("tree_crown", 1, "tree_crown")
        
        # Train or validation dataset?
        assert subset in ["train", "val"]
        dataset_dir = os.path.join(dataset_dir, subset)
        
        # Load annotations
        # LabelMe format (poly format annotations)
        annotations = self.load_labelme_annotations(dataset_dir)
        
        for a in annotations:
            # Get the x, y coordinates of points of the polygons that make up
            # the outline of each object instance
            polygons = a['polygons']
            image_path = os.path.join(dataset_dir, a['filename'])
            
            # Load the image
            image = skimage.io.imread(image_path)
            height, width = image.shape[:2]
            
            self.add_image(
                "tree_crown",
                image_id=a['filename'],  # use file name as a unique image id
                path=image_path,
                width=width, height=height,
                polygons=polygons)
    
    def load_labelme_annotations(self, dataset_dir):
        """Load LabelMe annotations for tree crown polygons.
        This is specifically designed for the annotation format
        used in the paper with Labelme tool.
        """
        # Implementation would depend on specific format of annotations
        # For this example, assuming we have a JSON file for each image
        # with polygon coordinates for tree crowns
        
        annotations = []
        
        # Scan through all files in the directory
        for filename in os.listdir(dataset_dir):
            if filename.endswith('.json'):  # Labelme annotations are typically JSON
                json_path = os.path.join(dataset_dir, filename)
                
                # Parse the JSON file
                with open(json_path) as f:
                    data = json.load(f)
                
                # Extract image filename from JSON
                image_filename = data['imagePath']
                
                # Extract polygons - adapt this to match actual Labelme format
                polygons = []
                for shape in data['shapes']:
                    if shape['label'] == 'tree_crown':
                        # Convert points to array format
                        points = np.array(shape['points'], dtype=np.int32)
                        polygons.append(points)
                
                annotations.append({
                    'filename': image_filename,
                    'polygons': polygons
                })
        
        return annotations
    
    def load_mask(self, image_id):
        """Generate instance masks for an image.
        Returns:
        masks: A bool array of shape [height, width, instance count] with
            one mask per instance.
        class_ids: a 1D array of class IDs of the instance masks.
        """
        # If not a tree crown dataset image, delegate to parent class.
        image_info = self.image_info[image_id]
        if image_info["source"] != "tree_crown":
            return super(self.__class__, self).load_mask(image_id)
        
        # Convert polygons to a bitmap mask of shape
        # [height, width, instance_count]
        info = self.image_info[image_id]
        mask = np.zeros([info["height"], info["width"], len(info["polygons"])],
                       dtype=np.uint8)
        
        for i, p in enumerate(info["polygons"]):
            # Get indexes of pixels inside the polygon and set them to 1
            rr, cc = skimage.draw.polygon(p[:, 1], p[:, 0])
            mask[rr, cc, i] = 1
        
        # Return mask, and array of class IDs of each instance
        return mask.astype(np.bool), np.ones([mask.shape[-1]], dtype=np.int32)
    
    def image_reference(self, image_id):
        """Return the path of the image."""
        info = self.image_info[image_id]
        if info["source"] == "tree_crown":
            return info["path"]
        else:
            super(self.__class__, self).image_reference(image_id)

## Image Splitting Function

This function divides large satellite images into smaller sub-images as described in the paper.

In [None]:
def split_image(image_path, output_dir, tile_size=(935, 910)):
    """Split a large Google Earth image into smaller sub-images
    as described in the paper.
    
    Args:
        image_path: Path to the large image
        output_dir: Directory to save the sub-images
        tile_size: Size of the sub-images (width, height)
    """
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Load the image
    img = cv2.imread(image_path)
    h, w = img.shape[:2]
    
    # Calculate number of tiles
    n_h = math.ceil(h / tile_size[1])
    n_w = math.ceil(w / tile_size[0])
    
    print(f"Splitting image of size {w}x{h} into {n_w}x{n_h} tiles")
    
    # Split the image
    count = 0
    for i in range(n_h):
        for j in range(n_w):
            x = j * tile_size[0]
            y = i * tile_size[1]
            
            # Handle edge cases
            x_end = min(x + tile_size[0], w)
            y_end = min(y + tile_size[1], h)
            
            # Extract tile
            tile = img[y:y_end, x:x_end]
            
            # Save tile
            tile_path = os.path.join(output_dir, f"tile_{count:03d}.jpg")
            cv2.imwrite(tile_path, tile)
            count += 1
    
    print(f"Split image into {count} tiles")
    return count

# Example usage
# split_image("large_satellite_image.jpg", "tiles/")

## Model Training Function

This function implements the training pipeline for the Mask R-CNN model.

In [None]:
def train_model(config, dataset_dir):
    """Train the Mask R-CNN model for tree crown detection.
    
    Args:
        config: TreeCrownConfig instance
        dataset_dir: Directory containing the dataset
    """
    # Create model in training mode
    model = modellib.MaskRCNN(mode="training", config=config, model_dir=MODEL_DIR)
    
    # Load COCO weights as starting point
    model.load_weights(COCO_MODEL_PATH, by_name=True, exclude=[
        "mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"
    ])
    
    # Load training dataset
    dataset_train = TreeCrownDataset()
    dataset_train.load_tree_crowns(dataset_dir, "train")
    dataset_train.prepare()
    
    # Load validation dataset
    dataset_val = TreeCrownDataset()
    dataset_val.load_tree_crowns(dataset_dir, "val")
    dataset_val.prepare()
    
    # Train the model
    # First, train only the heads (as per the paper's approach)
    print("Training network heads")
    model.train(dataset_train, dataset_val,
               learning_rate=config.LEARNING_RATE,
               epochs=5,
               layers='heads')
    
    # Fine-tune all layers
    print("Fine-tuning all layers")
    model.train(dataset_train, dataset_val,
               learning_rate=config.LEARNING_RATE / 10,
               epochs=10,
               layers='all')
    
    return model

# Example usage
# config = TreeCrownConfig()
# model = train_model(config, "dataset_directory/")

## Detection Function

This function detects tree crowns in new images using the trained model.

In [None]:
def detect_tree_crowns(model, image_path, output_path=None):
    """Detect tree crowns in an image and save the result.
    
    Args:
        model: Trained Mask R-CNN model
        image_path: Path to the input image
        output_path: Path to save the output visualization
        
    Returns:
        Detection results
    """
    # Read the image
    image = skimage.io.imread(image_path)
    
    # Detect tree crowns
    results = model.detect([image], verbose=1)
    r = results[0]
    
    # Visualize results
    fig = plt.figure(figsize=(12, 12))
    visualize.display_instances(
        image, r['rois'], r['masks'], r['class_ids'],
        ['BG', 'Tree Crown'], r['scores'],
        title="Tree Crown Detection",
        figsize=(12, 12)
    )
    
    # Save the figure if output_path is specified
    if output_path:
        plt.savefig(output_path)
        plt.close()
    else:
        plt.show()
    
    return r

# Example usage
# detect_tree_crowns(model, "test_image.jpg", "result.png")

## Results Analysis

This function analyzes the detection results to get statistics about tree crowns.

In [None]:
def analyze_results(results):
    """Analyze the detection results to get statistics about tree crowns.
    
    Args:
        results: List of detection results for multiple images
    
    Returns:
        Dictionary with statistics
    """
    total_trees = 0
    total_area = 0
    area_distribution = {}
    bin_size = 50  # bin size in m²
    
    for r in results:
        n_trees = r['masks'].shape[-1]
        total_trees += n_trees
        
        # Calculate area for each tree crown
        for i in range(n_trees):
            mask = r['masks'][:, :, i]
            area = np.sum(mask) * (0.27**2)  # Convert pixels to m² (0.27m resolution)
            total_area += area
            
            # Update area distribution
            bin_idx = int(area / bin_size)
            if bin_idx not in area_distribution:
                area_distribution[bin_idx] = 0
            area_distribution[bin_idx] += 1
    
    # Prepare distribution data for plotting
    area_bins = []
    tree_counts = []
    for bin_idx in sorted(area_distribution.keys()):
        min_area = bin_idx * bin_size
        max_area = (bin_idx + 1) * bin_size
        area_bins.append(f"[{min_area}, {max_area})")
        tree_counts.append(area_distribution[bin_idx])
    
    stats = {
        'total_trees': total_trees,
        'total_area': total_area,
        'area_distribution': {
            'bins': area_bins,
            'counts': tree_counts
        }
    }
    
    return stats

def plot_area_distribution(stats):
    """Plot the distribution of tree crown areas.
    
    Args:
        stats: Statistics dictionary returned by analyze_results
    """
    plt.figure(figsize=(12, 6))
    plt.bar(stats['area_distribution']['bins'], 
           stats['area_distribution']['counts'])
    plt.xlabel('Crown Area (m²)')
    plt.ylabel('Number of Trees')
    plt.title('Distribution of Tree Crown Areas')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

## Complete Workflow Example

Here's an example of a complete workflow using the functions defined above.

In [None]:
# Example workflow - uncomment and modify as needed

# 1. Split a large satellite image into smaller tiles
# split_image("large_satellite_image.jpg", "tiles/")

# 2. Create and train the model
# config = TreeCrownConfig()
# model = train_model(config, "dataset_directory/")

# 3. Load a trained model for inference
# config = TreeCrownConfig()
# config.BATCH_SIZE = 1  # For inference
# model = modellib.MaskRCNN(mode="inference", config=config, model_dir=MODEL_DIR)
# weights_path = model.find_last()  # Find last trained weights
# model.load_weights(weights_path, by_name=True)

# 4. Process all images in a directory
# results = []
# input_dir = "test_images/"
# output_dir = "results/"
# os.makedirs(output_dir, exist_ok=True)
# for filename in os.listdir(input_dir):
#     if filename.endswith(('.jpg', '.jpeg', '.png')):
#         image_path = os.path.join(input_dir, filename)
#         output_path = os.path.join(output_dir, f"result_{os.path.splitext(filename)[0]}.png")
#         print(f"Processing {image_path}")
#         result = detect_tree_crowns(model, image_path, output_path)
#         results.append(result)

# 5. Analyze and visualize results
# stats = analyze_results(results)
# print(f"Total trees detected: {stats['total_trees']}")
# print(f"Total crown area: {stats['total_area']:.2f} m²")
# plot_area_distribution(stats)

## Processing a Single Image Example

You can use this cell to process a single test image.

In [None]:
# Example for processing a single image
# config = TreeCrownConfig()
# config.BATCH_SIZE = 1  # For inference
# model = modellib.MaskRCNN(mode="inference", config=config, model_dir=MODEL_DIR)
# weights_path = model.find_last()  # Find last trained weights
# model.load_weights(weights_path, by_name=True)

# result = detect_tree_crowns(model, "test_image.jpg")
# print(f"Detected {result['masks'].shape[-1]} trees")

# # Calculate average crown area
# areas = []
# for i in range(result['masks'].shape[-1]):
#     mask = result['masks'][:, :, i]
#     area = np.sum(mask) * (0.27**2)  # Convert pixels to m² (0.27m resolution)
#     areas.append(area)

# avg_area = np.mean(areas) if areas else 0
# print(f"Average crown area: {avg_area:.2f} m²")