In [1]:
# Diagnostics: show interpreter and key package versions
import sys
print('Python:', sys.version)
try:
    import rasterio, detectron2, detectree2
    print('rasterio:', getattr(rasterio, '__version__', 'unknown'))
    import importlib
    print('detectron2:', importlib.metadata.version('detectron2'))
    # print('detectree2:', getattr(detectree2, '__version__', 'unknown'))
except Exception as e:
    print('Import error:', e)


Python: 3.10.18 (main, Jun  5 2025, 08:13:51) [Clang 14.0.6 ]
Import error: No module named 'detectron2'


# Detectree2 crown detection class

This notebook provides a clean, class-based pipeline to run Detectree2 crown detection on orthomosaics found in `input/input_om`, and saves crowns as GeoPackages named like `OM1.gpkg` into `output/detected_polygons`.

In [4]:
import os
import subprocess
import warnings
from dataclasses import dataclass
from typing import Optional

import geopandas as gpd
import matplotlib.pyplot as plt
from PIL import Image
# Pillow compatibility shim for detectron2 transforms
if not hasattr(Image, 'LINEAR'):
    Image.LINEAR = Image.BILINEAR
if not hasattr(Image, 'CUBIC'):
    Image.CUBIC = Image.BICUBIC
if not hasattr(Image, 'LANCZOS'):
    Image.LANCZOS = Image.LANCZOS if hasattr(Image, 'LANCZOS') else Image.BICUBIC

# PyTorch safe creators shim to handle numpy scalar/dtype inputs from downstream libs
import numpy as np
import torch

def _map_numpy_dtype_to_torch(dtype):
    mapping = {
        np.float32: torch.float32,
        np.float64: torch.float64,
        np.float16: torch.float16,
        np.int64: torch.int64,
        np.int32: torch.int32,
        np.int16: torch.int16,
        np.int8: torch.int8,
        np.uint8: torch.uint8,
        np.bool_: torch.bool,
    }
    if isinstance(dtype, np.dtype):
        dtype = dtype.type
    return mapping.get(dtype, dtype)

def _coerce_numpy_scalars(obj):
    # Recursively convert numpy scalar types to native Python scalars within common containers
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (list, tuple)):
        coerced = [_coerce_numpy_scalars(o) for o in obj]
        return type(obj)(coerced) if not isinstance(obj, tuple) else tuple(coerced)
    if isinstance(obj, dict):
        return {k: _coerce_numpy_scalars(v) for k, v in obj.items()}
    return obj

if not getattr(torch, "_wrapped_torch_creators_numpy_dtype", False):
    def _wrap_creator(fn):
        def _inner(*args, **kwargs):
            if args:
                args = ( _coerce_numpy_scalars(args[0]), ) + args[1:]
            if "dtype" in kwargs:
                kwargs["dtype"] = _map_numpy_dtype_to_torch(kwargs["dtype"])
            return fn(*args, **kwargs)
        return _inner

    # Patch common creators used in preprocessing/prediction
    for name in [
        "tensor", "as_tensor", "zeros", "zeros_like", "ones", "ones_like",
        "empty", "empty_like", "full", "full_like", "arange", "linspace",
        "logspace", "eye", "rand", "randn",
    ]:
        if hasattr(torch, name):
            setattr(torch, name, _wrap_creator(getattr(torch, name)))

    torch._wrapped_torch_creators_numpy_dtype = True

import rasterio
from rasterio.plot import show

# detectree2 / detectron2
from detectree2.preprocessing.tiling import tile_data
from detectree2.models.train import setup_cfg
from detectree2.models.predict import predict_on_data
from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns
from detectron2.engine import DefaultPredictor


@dataclass
class DetectreeConfig:
    tiles_buffer: int = 20
    tile_width: int = 45
    tile_height: int = 45
    dtype_bool: bool = True
    iou_threshold: float = 0.7
    confidence_threshold: float = 0.5
    simplify_tolerance: float = 0.3
    device: str = "cpu"  # set to 'cuda' if GPU available


class DetectreeRunner:
    """
    Clean, minimal class to tile, predict, stitch, clean and save crowns as GPKG.

    Contract
    - Input: path to orthomosaic (tif), path to trained .pth, optional output dir.
    - Output: GeoDataFrame of crowns and a saved GPKG file.
    - Errors: raises FileNotFoundError for missing inputs; RuntimeError for missing dependencies.
    """

    def __init__(self, config: Optional[DetectreeConfig] = None):
        self.cfg = config or DetectreeConfig()

    @staticmethod
    def _ensure_dir(path: str):
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def install_detectree2():
        subprocess.check_call(["pip", "install", "git+https://github.com/PatBall1/detectree2.git"])  # optional helper

    def _setup_detection_cfg(self, trained_model_path: str):
        if not os.path.exists(trained_model_path):
            raise FileNotFoundError(f"Model not found: {trained_model_path}")
        cfg = setup_cfg(update_model=trained_model_path)
        # Force device if specified
        try:
            cfg.MODEL.DEVICE = self.cfg.device
        except Exception:
            pass
        return cfg

    def _tile(self, img_path: str, tiles_path: str):
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        self._ensure_dir(tiles_path)
        tile_data(
            img_path,
            tiles_path,
            self.cfg.tiles_buffer,
            self.cfg.tile_width,
            self.cfg.tile_height,
            dtype_bool=self.cfg.dtype_bool,
        )

    def _predict(self, tiles_path: str, cfg):
        predictor = DefaultPredictor(cfg)
        predict_on_data(directory=tiles_path, predictor=predictor)

    def _geo_project(self, tiles_path: str) -> str:
        predictions_folder = os.path.join(tiles_path, "predictions/")
        geo_predictions_folder = os.path.join(tiles_path, "predictions_geo/")
        project_to_geojson(tiles_path, predictions_folder, geo_predictions_folder)
        return geo_predictions_folder

    def _stitch_and_clean(self, geo_predictions_folder: str) -> gpd.GeoDataFrame:
        crowns = stitch_crowns(geo_predictions_folder, 1)
        # CRS safety: ensure geodataframe has CRS by reading from a sample tile if missing
        if crowns.crs is None:
            # Try to infer CRS from a sidecar if present; otherwise leave None
            pass
        crowns = crowns[crowns.is_valid]
        crowns = crowns.set_geometry(crowns.simplify(self.cfg.simplify_tolerance))
        crowns = clean_crowns(crowns, self.cfg.iou_threshold, self.cfg.confidence_threshold)
        return crowns

    @staticmethod
    def save_gpkg(crowns: gpd.GeoDataFrame, out_gpkg_path: str):
        DetectreeRunner._ensure_dir(os.path.dirname(out_gpkg_path))
        # If CRS is missing, try setting from EPSG:3857 as a fallback (most orthos are projected); better to inherit from source raster if available
        if crowns.crs is None:
            warnings.warn("Crowns GeoDataFrame has no CRS; saving without CRS. Consider setting CRS from source raster.")
        crowns.to_file(out_gpkg_path, driver="GPKG")

    @staticmethod
    def visualize(crowns: gpd.GeoDataFrame, img_path: Optional[str] = None, title: str = "Predicted Crowns"):
        fig, ax = plt.subplots(figsize=(10, 10))
        if img_path and os.path.exists(img_path):
            with rasterio.open(img_path) as src:
                show(src, ax=ax)
            ax.set_title("Orthomosaic with Predicted Crowns")
        else:
            ax.set_title(title)
        crowns.plot(ax=ax, facecolor='none', edgecolor='red')
        plt.show()

    def run(self, img_path: str, trained_model_path: str, work_dir: str, out_gpkg_path: str, visualize: bool = False) -> gpd.GeoDataFrame:
        # Validate model early to avoid costly tiling if missing
        cfg = self._setup_detection_cfg(trained_model_path)
        tiles_path = os.path.join(work_dir, "tiles")
        self._tile(img_path, tiles_path)
        self._predict(tiles_path, cfg)
        geo_dir = self._geo_project(tiles_path)
        crowns = self._stitch_and_clean(geo_dir)
        self.save_gpkg(crowns, out_gpkg_path)
        if visualize:
            self.visualize(crowns, img_path)
        return crowns

ModuleNotFoundError: No module named 'torch'

In [None]:
# Path helpers and naming scheme
from pathlib import Path

def infer_output_name_from_orthomosaic(om_filename: str) -> str:
    """
    Map sit_om1.tif -> OM1.gpkg, sit_om2.tif -> OM2.gpkg, etc.
    Fallback: base name uppercased without extension + .gpkg
    """
    stem = Path(om_filename).stem
    # Try sit_om{n}
    if stem.lower().startswith("sit_om") and stem[6:].isdigit():
        n = stem[6:]
        return f"OM{n}.gpkg"
    # Try lhc_om{n}
    if stem.lower().startswith("lhc_om") and stem[6:].isdigit():
        n = stem[6:]
        return f"OM{n}.gpkg"
    # generic fallback
    return f"{stem.upper()}.gpkg"

INPUT_OM_DIR = "/Users/hbot07/VS Code/Drone-Phenology-Monitoring/input/input_om"
OUTPUT_DIR = "/Users/hbot07/VS Code/Drone-Phenology-Monitoring/output/detected_polygons"
WORK_DIR = "/Users/hbot07/VS Code/Drone-Phenology-Monitoring/output/work"

os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(WORK_DIR, exist_ok=True)

# Model path set to provided model
MODEL_PATH = "/Users/hbot07/VS Code/Drone-Phenology-Monitoring/input/detectree_models/250312_flexi.pth"

# Pick one orthomosaic to test
TEST_OM = "sit_om1.tif"
IMG_PATH = os.path.join(INPUT_OM_DIR, TEST_OM)
OUT_GPKG = os.path.join(OUTPUT_DIR, infer_output_name_from_orthomosaic(TEST_OM))
TILES_WORK = os.path.join(WORK_DIR, Path(TEST_OM).stem)

runner = DetectreeRunner()
print("Image:", IMG_PATH)
print("Model:", MODEL_PATH)
print("Output GPKG:", OUT_GPKG)
print("Work dir:", TILES_WORK)

In [None]:
# Run pipeline on the selected orthomosaic and save GPKG with visualization
import traceback
try:
    crowns = runner.run(
        img_path=IMG_PATH,
        trained_model_path=MODEL_PATH,
        work_dir=TILES_WORK,
        out_gpkg_path=OUT_GPKG,
        visualize=True,
    )
    print(f"Saved crowns: {OUT_GPKG}")
except FileNotFoundError as e:
    print("Missing file:", e)
    print("Tip: ensure the model .pth exists at MODEL_PATH or set DETECTREE_MODEL env var.")
except Exception as e:
    print("Error while running detection:")
    traceback.print_exc()
    print("If detectron2/detectree2 are not installed for your Python kernel, install them and re-run.")

In [None]:
# Convenience function to run by orthomosaic file name

def run_and_save(om_filename: str, model_path: str = MODEL_PATH):
    img_path = os.path.join(INPUT_OM_DIR, om_filename)
    out_name = infer_output_name_from_orthomosaic(om_filename)
    out_path = os.path.join(OUTPUT_DIR, out_name)
    work = os.path.join(WORK_DIR, Path(om_filename).stem)
    return runner.run(img_path, model_path, work, out_path, visualize=False)

# Example:
# run_and_save('sit_om2.tif')