# Crown Tracking (Refined)

A clean, modular implementation for matching and tracking tree crowns across multiple orthomosaics over time. This notebook defines a reusable class and provides an interactive viewer to explore tracks.

## What this notebook provides

- A `CrownTracker` class with:
  - Data loading from GPKG polygons and orthomosaic GeoTIFFs
  - CRS unification and spatial indexing
  - Pairwise matching via configurable strategy (IoU overlap or centroid distance)
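  - Multi-date track construction (greedy one-to-one linking between consecutive dates, gated by the IoU or centroid-distance threshold)
  - Quality metrics and simple reports (planned; not yet implemented in the class shown below)
  - Interactive viewer: crown centroids of a focus layer next to per-layer orthomosaic overviews (click-to-crop across dates is planned; `crop_raster_by_polygon` already returns the crop for a single crown)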
- A quick-start example using the provided `OM1.gpkg`–`OM5.gpkg` crowns and `sit_om*.tif` orthomosaics (the demo below runs on OM1 and OM2 only)

## Install/Import dependencies

This implementation uses geopandas, shapely, rasterio, rtree/pygeos (if available), numpy, pandas, and plotly for interactive viewing.

If something is missing, install in your environment:

- geopandas
- shapely
- rasterio
- rtree (optional but recommended)
- numpy
- pandas
- plotly

In [None]:
# Imports
import json
import os
import warnings
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import Polygon, box, Point
from shapely.ops import unary_union
import rasterio
from rasterio.windows import from_bounds
import plotly.graph_objects as go
from plotly.subplots import make_subplots

try:
    # spatial index speedups
    import rtree  # noqa: F401
except Exception:
    rtree = None

@dataclass
class CrownLayerConfig:
    """Describe one layer: a label, its crown polygons, and an optional raster.

    Attributes:
        name: Short layer label, e.g. 'OM1'.
        crowns_path: Path to the GPKG file holding the crown polygons.
        raster_path: Path to the GeoTIFF for this layer; optional, defaults to None.
    """

    name: str
    crowns_path: str
    raster_path: Optional[str] = None

class CrownTracker:
    """
    CrownTracker: load crown polygons from multiple dates, match pairwise, and build tracks.

    Core ideas:
    - Each date/layer has a set of polygon crowns stored as a GeoDataFrame in a unified CRS
    - Optional orthomosaics provide pixel data for visualization/cropping
    - Matching strategies:
        * IoU (intersection-over-union) between polygons
        * Centroid distance with maximum distance gate (in CRS units)
    - Tracks are built across layers in the order they are listed, linking each
      consecutive pair of layers one-to-one with a best-score-first greedy assignment

    Raster datasets stay open for the lifetime of the tracker. Call `close()` when
    done, or use the tracker as a context manager (`with CrownTracker(...) as t:`).

    NOTE(review): crops index the raster with polygon coordinates in the layer CRS,
    so this presumably expects each raster to share the unified CRS -- confirm.
    """

    _MATCH_STRATEGIES = ("iou", "centroid")

    def __init__(self,
                 layers: List["CrownLayerConfig"],
                 match_strategy: str = "iou",
                 iou_threshold: float = 0.3,
                 max_centroid_dist: float = 1.5,
                 target_crs: Optional[str] = None):
        """
        Store the matching parameters and load every configured layer.

        Args:
            layers: Layer configs, in the order tracks should be propagated.
            match_strategy: 'iou' or 'centroid'.
            iou_threshold: Minimum IoU for a pair to count as a match ('iou' strategy).
            max_centroid_dist: Maximum centroid distance, in CRS units ('centroid' strategy).
            target_crs: CRS to reproject every layer to; when falsy, the CRS of the
                first layer is used.

        Raises:
            ValueError: Unknown strategy, duplicate layer names, or a crowns file
                that is empty or has no CRS.
        """
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if match_strategy not in self._MATCH_STRATEGIES:
            raise ValueError(
                f"match_strategy must be one of {self._MATCH_STRATEGIES}, got {match_strategy!r}"
            )
        # Tracks are keyed by layer name, so duplicates would overwrite each other.
        names = [cfg.name for cfg in layers]
        if len(set(names)) != len(names):
            raise ValueError(f"Layer names must be unique, got {names}")

        self.layers_cfg = layers
        self.match_strategy = match_strategy
        self.iou_threshold = iou_threshold
        self.max_centroid_dist = max_centroid_dist
        self.target_crs = target_crs

        self.layers: List[Dict] = []  # each: {name, gdf, raster, transform, raster_path}
        self.tracks: Dict[int, Dict[str, int]] = {}  # track_id -> {layer_name: feature_id}
        try:
            self._load_layers()
        except Exception:
            # Do not leak rasters opened before the failing layer.
            self.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def close(self) -> None:
        """Close every open raster dataset; safe to call more than once."""
        for layer in self.layers:
            raster = layer.get('raster')
            if raster is not None:
                raster.close()
                # crop_raster_by_polygon treats None as "no raster available".
                layer['raster'] = None

    def _load_layers(self):
        """Read each crowns file, reproject to the unified CRS, and open its raster if present."""
        base_crs = None
        for cfg in self.layers_cfg:
            gdf = gpd.read_file(cfg.crowns_path)
            if gdf.empty:
                raise ValueError(f"No crowns in {cfg.crowns_path}")
            if gdf.crs is None:
                raise ValueError(f"CRS missing in {cfg.crowns_path}")

            # The first layer's CRS is the fallback when no target CRS was given.
            if base_crs is None:
                base_crs = gdf.crs
            gdf = gdf.to_crs(self.target_crs or base_crs)

            # Ensure an ID column
            if 'crown_id' not in gdf.columns:
                gdf = gdf.reset_index(drop=True)
                gdf['crown_id'] = gdf.index.astype(int)
            elif not gdf['crown_id'].is_unique:
                warnings.warn(
                    f"Duplicate crown_id values in {cfg.crowns_path}; tracks may be ambiguous",
                    stacklevel=2,
                )

            # Attach raster if present
            raster_ds = None
            transform = None
            if cfg.raster_path:
                if os.path.exists(cfg.raster_path):
                    raster_ds = rasterio.open(cfg.raster_path)
                    transform = raster_ds.transform
                else:
                    # The raster is optional, so keep going -- but say so.
                    warnings.warn(
                        f"Raster not found for layer {cfg.name}: {cfg.raster_path}",
                        stacklevel=2,
                    )

            self.layers.append({
                'name': cfg.name,
                'gdf': gdf,
                'raster': raster_ds,
                'transform': transform,
                'raster_path': cfg.raster_path
            })

    @staticmethod
    def _pairwise_iou(polya: "Polygon", polyb: "Polygon") -> float:
        """Return intersection area over union area, or 0.0 when the union has no area."""
        inter = polya.intersection(polyb).area
        union = polya.union(polyb).area
        return float(inter / union) if union > 0 else 0.0

    @staticmethod
    def _centroid_distance(polya: "Polygon", polyb: "Polygon") -> float:
        """Return the distance between the two centroids."""
        return polya.centroid.distance(polyb.centroid)

    def _match_two_layers(self, gdf_a: "gpd.GeoDataFrame", gdf_b: "gpd.GeoDataFrame") -> pd.DataFrame:
        """
        Score every candidate pair between two layers.

        Candidates come from a bounding-box query against the spatial index of
        `gdf_b`; each candidate is then scored with the configured strategy and
        kept only if it passes that strategy's gate.

        Returns:
            DataFrame with columns id_a, id_b, score, sorted by score descending
            (higher is better for both strategies).
        """
        use_iou = self.match_strategy == 'iou'
        gate = self.max_centroid_dist
        sidx = gdf_b.sindex
        # Hoisted out of the loop: positional access into layer B.
        geoms_b = gdf_b.geometry.values
        ids_b = gdf_b['crown_id'].values

        matches = []
        for id_a, geom_a in zip(gdf_a['crown_id'].values, gdf_a.geometry.values):
            if geom_a is None or geom_a.is_empty:
                continue  # nothing to match against
            minx, miny, maxx, maxy = geom_a.bounds
            if use_iou:
                query = (minx, miny, maxx, maxy)
            else:
                # Widen the bbox by the gate. A centroid lies inside its polygon's
                # bbox, so every B crown whose centroid is within the gate is found.
                query = (minx - gate, miny - gate, maxx + gate, maxy + gate)

            for pos_b in sidx.intersection(query):
                geom_b = geoms_b[pos_b]
                if use_iou:
                    iou = self._pairwise_iou(geom_a, geom_b)
                    if iou >= self.iou_threshold:
                        matches.append((id_a, ids_b[pos_b], iou))
                else:
                    dist = self._centroid_distance(geom_a, geom_b)
                    if dist <= gate:
                        # Inverse distance as score so higher=better when sorting;
                        # the epsilon avoids division by zero for coincident centroids.
                        matches.append((id_a, ids_b[pos_b], 1.0 / (1e-6 + dist)))

        df = pd.DataFrame(matches, columns=['id_a', 'id_b', 'score'])
        # Stable sort keeps tied scores in discovery order, so results are reproducible.
        return df.sort_values('score', ascending=False, kind='mergesort', ignore_index=True)

    def build_tracks(self) -> pd.DataFrame:
        """
        Greedy track building across layers in listed order.
        - Start with layer 0: each crown creates a new track
        - For each next layer, score pairs against the previous layer and accept
          them best-score-first, so a weak pair can never take a crown away from
          a stronger pair
        - One-to-one per step (no merging/splitting)
        - Crowns left unmatched in a layer start new tracks

        Returns a DataFrame with rows: track_id, layer, crown_id.
        Also stores the mapping in `self.tracks`.

        Raises:
            RuntimeError: No layers are loaded.
        """
        if not self.layers:
            raise RuntimeError("No layers loaded")

        first = self.layers[0]
        tracks: Dict[int, Dict[str, int]] = {}
        next_tid = 0
        for cid in first['gdf']['crown_id'].values:
            tracks[next_tid] = {first['name']: int(cid)}
            next_tid += 1

        for prev, curr in zip(self.layers, self.layers[1:]):
            prev_name, curr_name = prev['name'], curr['name']
            df_matches = self._match_two_layers(prev['gdf'], curr['gdf'])

            # Tracks that reached the previous layer, keyed by their crown there.
            open_tracks = {
                mapping[prev_name]: tid
                for tid, mapping in tracks.items()
                if prev_name in mapping
            }

            used_a = set()
            used_b = set()
            # df_matches is sorted best-first; walk it once and take each pair
            # whose two ends are both still free.
            for id_a, id_b in zip(df_matches['id_a'].tolist(), df_matches['id_b'].tolist()):
                if id_a in used_a or id_b in used_b:
                    continue
                tid = open_tracks.get(id_a)
                if tid is None:
                    continue
                tracks[tid][curr_name] = int(id_b)
                used_a.add(id_a)
                used_b.add(id_b)

            # Any unmatched crowns in curr become new tracks
            leftovers = sorted(set(curr['gdf']['crown_id'].values) - used_b)
            for cid in leftovers:
                tracks[next_tid] = {curr_name: int(cid)}
                next_tid += 1

        # Store and return as DataFrame
        self.tracks = tracks
        records = [
            {"track_id": tid, "layer": lname, "crown_id": cid}
            for tid, mapping in tracks.items()
            for lname, cid in mapping.items()
        ]
        # Explicit columns keep the schema stable even if there are no records.
        return pd.DataFrame(records, columns=["track_id", "layer", "crown_id"])

    def crop_raster_by_polygon(self, layer_name: str, crown_id: int, pad: float = 0.0) -> Optional[np.ndarray]:
        """
        Read the raster pixels under the bounding box of one crown.

        Args:
            layer_name: Name of the layer to crop from.
            crown_id: `crown_id` of the polygon in that layer.
            pad: Extra margin around the bounding box, in CRS units; ignored unless > 0.

        Returns:
            Array in HxWxC order, or None when the layer is unknown, has no open
            raster, or has no crown with that id.
        """
        layer = next((lyr for lyr in self.layers if lyr['name'] == layer_name), None)
        if layer is None or layer['raster'] is None:
            return None
        gdf = layer['gdf']
        hits = gdf.loc[gdf['crown_id'] == crown_id]
        if hits.empty:
            return None
        minx, miny, maxx, maxy = hits.geometry.values[0].bounds
        if pad > 0:
            # Only the bbox is used, so widen it directly instead of buffering the polygon.
            minx, miny, maxx, maxy = minx - pad, miny - pad, maxx + pad, maxy + pad
        raster = layer['raster']
        window = from_bounds(minx, miny, maxx, maxy, transform=raster.transform)
        data = raster.read(window=window)
        # transpose to HxWxC for visualization
        return np.transpose(data, (1, 2, 0))

    @staticmethod
    def _read_overview(raster) -> Optional[np.ndarray]:
        """Return a decimated HxWxC overview of `raster`, or None if unavailable."""
        if raster is None:
            return None
        # Roughly one tenth of the resolution, but never below 256 pixels per side.
        out_shape = (raster.count, max(256, raster.height // 10), max(256, raster.width // 10))
        try:
            bands = raster.read(out_shape=out_shape)
        except Exception as exc:
            # Best-effort: the viewer falls back to a placeholder, but report why.
            warnings.warn(
                f"Could not read overview from {getattr(raster, 'name', raster)}: {exc}",
                stacklevel=3,
            )
            return None
        return np.transpose(bands, (1, 2, 0))

    def interactive_view(self, tracks_df: pd.DataFrame, focus_layer: Optional[str] = None, pad: float = 0.0):
        """
        Build a two-panel Plotly figure.

        Left: centroids of the focus layer's crowns, with the crown id on hover.
        Right: one decimated raster overview per layer (a black placeholder when a
        layer has no readable raster). All overviews start hidden; the buttons
        above the panel show one layer at a time.

        Args:
            tracks_df: Currently unused; kept for interface compatibility.
            focus_layer: Layer shown on the left; defaults to the first layer.
            pad: Currently unused; kept for interface compatibility.

        Returns:
            The Plotly figure.

        Raises:
            RuntimeError: No layers are loaded.
            ValueError: `focus_layer` does not name a loaded layer.
        """
        if not self.layers:
            raise RuntimeError("No layers loaded")
        if focus_layer is None:
            focus_layer = self.layers[0]['name']
        layer = next((lyr for lyr in self.layers if lyr['name'] == focus_layer), None)
        if layer is None:
            known = [lyr['name'] for lyr in self.layers]
            raise ValueError(f"Unknown focus_layer {focus_layer!r}; available layers: {known}")

        gdf = layer['gdf']
        # Extract centroids for the scatter points
        cents = gdf.geometry.centroid
        hover = [f"{focus_layer} id={cid}" for cid in gdf['crown_id'].values]

        fig = make_subplots(
            rows=1,
            cols=2,
            column_widths=[0.5, 0.5],
            subplot_titles=(f"Crowns: {focus_layer}", "Selected Track: crops across dates"),
        )

        # Left scatter
        fig.add_trace(
            go.Scattergl(
                x=cents.x.values,
                y=cents.y.values,
                mode='markers',
                marker=dict(size=5, color='green'),
                name='crowns',
                text=hover,
                hoverinfo='text',
            ),
            row=1,
            col=1,
        )

        # Right panel: one hidden image trace per layer, exposed via buttons.
        n_layers = len(self.layers)
        buttons = []
        for i, lyr in enumerate(self.layers):
            img_array = self._read_overview(lyr['raster'])
            if img_array is None:
                # fallback placeholder
                img_array = np.zeros((256, 256, 3), dtype=np.uint8)

            fig.add_trace(go.Image(z=img_array, name=f"{lyr['name']} overview", visible=False), row=1, col=2)

            buttons.append(dict(
                label=lyr['name'],
                method='update',
                args=[
                    # Trace 0 (scatter) stays visible; only this layer's image is shown.
                    {"visible": [True] + [j == i for j in range(n_layers)]},
                    {"title": f"Selected Track: {lyr['name']}"}
                ]
            ))

        fig.update_layout(
            height=700,
            showlegend=False,
            updatemenus=[dict(type='buttons', buttons=buttons, direction='right', x=0.55, y=1.15)],
        )
        fig.update_xaxes(scaleanchor='y', row=1, col=1)
        return fig



## Quick-start usage

1. Define the layers (GPKG crowns and optional orthomosaic tifs)
2. Instantiate `CrownTracker`
3. Build tracks
4. Show an interactive view

The example below uses `input/input_crowns/OM1.gpkg` and `OM2.gpkg` together with `input/input_om/sit_om1.tif` and `sit_om2.tif`.

In [None]:
# Configure layers and run a small demo for OM1 and OM2
import os
from pathlib import Path

# Project root: set the DRONE_PHENOLOGY_ROOT environment variable to run the demo
# on another machine; the original checkout location remains the default.
root = Path(os.environ.get("DRONE_PHENOLOGY_ROOT", "/Users/hbot07/VS Code/Drone-Phenology-Monitoring"))
crowns_dir = root / "input" / "input_crowns"
rasters_dir = root / "input" / "input_om"

# One config per date: crowns in OM<i>.gpkg, orthomosaic in sit_om<i>.tif.
layers_cfg = [
    CrownLayerConfig(
        name=f"OM{i}",
        crowns_path=str(crowns_dir / f"OM{i}.gpkg"),
        raster_path=str(rasters_dir / f"sit_om{i}.tif"),
    )
    for i in (1, 2)
]

tracker = CrownTracker(layers_cfg, match_strategy="iou", iou_threshold=0.3)
tracks_df = tracker.build_tracks()
print("Built tracks:", tracks_df["track_id"].nunique())
tracks_df.head()

In [None]:
# Interactive view: the left panel plots OM1 crown centroids (hover shows the crown id);
# the buttons above the right panel switch between the per-layer raster overviews.
fig = tracker.interactive_view(tracks_df, focus_layer="OM1")
fig