# Crown Tracking on Real Data
This notebook demonstrates tree crown tracking across multiple orthomosaics using real crown polygon files. It includes: matching, graph construction, IoU matrix computation, shifting (alignment), merging, and new crown creation.

## Required Input Files
- A list of crown polygon files (GeoPackage format, `.gpkg`), one per orthomosaic.
- (Optional) Corresponding orthomosaic images (`.tif`) for visualization.

**How to use:**
- Place your crown files in the `input_crowns/` directory.
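- Name them as `OM1.gpkg`, `OM2.gpkg`, `OM3.gpkg`, ... in temporal order. Files are processed in sorted (lexicographic) filename order, so zero-pad the numbers if you have ten or more orthomosaics (e.g. `OM01.gpkg` … `OM10.gpkg`); otherwise `OM10` sorts before `OM2`.
- (Optional) Place orthomosaic images in `input_om/` as `OM1.tif`, `OM2.tif`, ... if you want crown image patches for visualization. They are paired with the crown files by sorted position, so use matching names.

**You need to provide:**
- At least two crown files in `input_crowns/` (e.g., `OM1.gpkg`, `OM2.gpkg`).
- (Optional) Corresponding `.tif` files in `input_om/`.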

In [None]:
import os
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
from shapely.affinity import translate
from sklearn.neighbors import NearestNeighbors

# --- Parameters ---
crown_dir = 'input_crowns'  # required: one crown GeoPackage per orthomosaic
ortho_dir = 'input_om'  # optional: orthomosaic .tif files, paired with the crown files by sorted position
if not os.path.isdir(crown_dir):
    raise FileNotFoundError(f"Crown directory '{crown_dir}' not found: create it and add one .gpkg file per orthomosaic.")
# Files are processed in sorted (lexicographic) filename order; later cells
# treat that order as the temporal order of the orthomosaics.
crown_files = sorted([os.path.join(crown_dir, f) for f in os.listdir(crown_dir) if f.endswith('.gpkg')])
ortho_files = sorted([os.path.join(ortho_dir, f) for f in os.listdir(ortho_dir) if f.endswith('.tif')]) if os.path.exists(ortho_dir) else []
print('Crown files found:', crown_files)
print('Ortho files found:', ortho_files)
# Tracking needs at least one OM pair; fail here with a clear message instead
# of an IndexError several cells later.
if len(crown_files) < 2:
    raise ValueError(f"Found {len(crown_files)} crown file(s) in '{crown_dir}'; at least two are required.")
if ortho_files and len(ortho_files) != len(crown_files):
    print(f'Warning: {len(ortho_files)} ortho files for {len(crown_files)} crown files; '
          'image patches are paired by position, so some crowns may get no image or the wrong one.')

In [None]:
# --- Extract and store image patch for each crown from corresponding orthomosaic ---
import rasterio
from rasterio.mask import mask
from shapely.geometry import mapping
import numpy as np
import os
import matplotlib.pyplot as plt
import warnings


def extract_crown_images(crown_files, ortho_files):
    """Clip one image patch per crown polygon from its orthomosaic.

    Parameters
    ----------
    crown_files : list of str
        Crown polygon files (GeoPackage), one per orthomosaic.
    ortho_files : list of str
        Orthomosaic rasters, paired with ``crown_files`` by position. The list
        may be shorter than ``crown_files`` or empty (orthomosaics are
        optional); crown files without a raster get ``None`` patches.

    Returns
    -------
    list of list
        One inner list per crown file, with one entry per crown row:
        ``crown_images[om_idx][crown_idx]`` is the array returned by
        ``rasterio.mask.mask(..., crop=True)`` with the band axis moved last,
        or ``None`` when there is no raster or clipping failed.
    """
    crown_images = []
    for om_idx, crown_file in enumerate(crown_files):
        gdf = gpd.read_file(crown_file)
        # Positional pairing; a missing raster must not drop this OM entirely
        # (zip() used to truncate), otherwise later cells index past the end.
        ortho_file = ortho_files[om_idx] if om_idx < len(ortho_files) else None
        if ortho_file is None or not os.path.exists(ortho_file):
            crown_images.append([None] * len(gdf))
            continue
        om_images = []
        # Silence raster warnings for this clipping loop only, instead of
        # disabling warnings for every later cell of the notebook.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            with rasterio.open(ortho_file) as src:
                for geom in gdf.geometry:
                    try:
                        out_image, _ = mask(src, [mapping(geom)], crop=True)
                    except Exception:
                        # Best effort (e.g. the crown does not overlap the
                        # raster): keep a None placeholder so indices stay aligned.
                        img_patch = None
                    else:
                        # mask() returns (bands, H, W); move bands last for imshow.
                        img_patch = np.moveaxis(out_image, 0, -1)
                    om_images.append(img_patch)
        crown_images.append(om_images)
    return crown_images


crown_images = extract_crown_images(crown_files, ortho_files)
# crown_images[om_idx][crown_idx] is the image patch for that crown

In [None]:
# --- Display a few example crown image patches extracted from orthomosaics ---
import matplotlib.pyplot as plt
num_examples = 4  # first N crowns inspected per orthomosaic (those without a patch are skipped)
examples = []
for om_idx in range(min(len(crown_images), 2)):  # Show from first two orthomosaics
    for crown_idx in range(min(num_examples, len(crown_images[om_idx]))):
        img_patch = crown_images[om_idx][crown_idx]
        if img_patch is not None:
            examples.append((om_idx+1, crown_idx, img_patch))
if not examples:
    # plt.subplots(1, 0) raises, e.g. when no orthomosaics were provided.
    print('No crown image patches available to display.')
else:
    # squeeze=False keeps `axes` 2-D even for a single example, so indexing
    # works regardless of how many patches were found.
    fig, axes = plt.subplots(1, len(examples), figsize=(2*len(examples), 3), squeeze=False)
    for ax, (om_num, crown_num, img_patch) in zip(axes[0], examples):
        ax.imshow(img_patch)
        ax.set_title(f'OM{om_num} Crown {crown_num}')
        ax.axis('off')
    plt.suptitle('Example Crown Image Patches Extracted from Orthomosaics', fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# --- Visualize all crown polygons from all OMs ---
import matplotlib.pyplot as plt
import geopandas as gpd
import os
from matplotlib.lines import Line2D
# Reuse `crown_files` from the parameters cell. Re-listing a hard-coded
# directory here would overwrite the user's setting, and later cells
# (e.g. `all_crowns`) read from `crown_files`.
colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'magenta']
fig, ax = plt.subplots(figsize=(10, 10))
legend_elements = []
for i, f in enumerate(crown_files):
    color = colors[i % len(colors)]
    gdf = gpd.read_file(f)
    if len(gdf) == 0:
        continue
    gdf.plot(ax=ax, facecolor='none', edgecolor=color, linewidth=1, label=f'OM{i+1}')
    # Polygon plots with facecolor='none' do not produce legend handles on
    # their own, so build proxy line handles explicitly.
    legend_elements.append(Line2D([0], [0], color=color, lw=2, label=f'OM{i+1}'))
ax.set_title('Crown Polygons from All OMs')
if legend_elements:
    ax.legend(handles=legend_elements)
plt.show()

In [None]:
import os
# Point PROJ at pyproj's bundled data directory unless PROJ_LIB is already set.
# NOTE(review): earlier cells already read GeoPackages/rasters, so PROJ may be
# initialised by now and this setting may have no effect — consider moving
# this cell to the top of the notebook.
try:
    # Import the submodule explicitly: `import pyproj` alone is not guaranteed
    # to expose `pyproj.datadir` as an attribute.
    from pyproj.datadir import get_data_dir
    if not os.environ.get("PROJ_LIB"):
        os.environ["PROJ_LIB"] = get_data_dir()
    print("PROJ_LIB set to:", os.environ["PROJ_LIB"])
except Exception as e:
    # Best effort only: report and continue (e.g. pyproj not installed).
    print("Could not set PROJ_LIB automatically:", e)

In [None]:
# Read every crown file once; all_crowns[k] holds the polygons of OM{k+1},
# in the same order as crown_files.
all_crowns = list(map(gpd.read_file, crown_files))
for om_number, crowns_gdf in enumerate(all_crowns, start=1):
    print(f'OM{om_number}:', len(crowns_gdf), 'crowns')

## 1. Shifting (Alignment)
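Align each OM to the previous (already aligned) one. Crown centroids are matched to their nearest neighbour in the previous OM within a distance threshold. After outlier offsets are discarded, the median offset of the matched pairs is applied as a translation.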

In [None]:
def align_crowns(all_crowns, threshold=10, inlier_percentile=90):
    """Translate each OM's crowns onto the previous, already aligned OM.

    For every OM after the first, each crown centroid is matched to its
    nearest centroid in the previously aligned OM, and pairs closer than
    ``threshold`` are kept. Each match gives an offset vector. A first median
    offset is computed, offsets that deviate most from it are discarded, and
    the median of the remaining offsets is applied as a translation.

    Parameters
    ----------
    all_crowns : list of GeoDataFrame
        Crown polygons per orthomosaic, in processing (temporal) order.
    threshold : float
        Maximum centroid distance for a match, in the coordinate units of the
        geometries.
    inlier_percentile : float
        Offsets whose distance from the first median offset is above this
        percentile of those distances are treated as outliers.

    Returns
    -------
    list of GeoDataFrame
        Copies of the inputs. The first is unchanged; later ones are
        translated, or left unchanged when fewer than 3 matches were found.
    """
    if not all_crowns:
        return []
    aligned = [all_crowns[0].copy()]
    for t in range(1, len(all_crowns)):
        ref = aligned[-1]
        curr = all_crowns[t].copy()
        if len(ref) == 0 or len(curr) == 0:
            # NearestNeighbors cannot be fitted/queried with zero samples.
            print(f'Not enough matches for OM{t+1}, skipping alignment.')
            aligned.append(curr)
            continue
        centroids_ref = np.array([[g.centroid.x, g.centroid.y] for g in ref.geometry])
        centroids_curr = np.array([[g.centroid.x, g.centroid.y] for g in curr.geometry])
        nn = NearestNeighbors(n_neighbors=1).fit(centroids_ref)
        distances, indices = nn.kneighbors(centroids_curr)
        # (ref position, curr position) pairs whose centroids are close enough.
        matched = [(indices[k][0], k) for k in range(len(distances)) if distances[k][0] < threshold]
        if len(matched) < 3:
            print(f'Not enough matches for OM{t+1}, skipping alignment.')
            aligned.append(curr)
            continue
        shifts = np.array([centroids_ref[i] - centroids_curr[j] for i, j in matched])
        # Residuals are deviations from a first robust estimate of the shift,
        # not the raw shift lengths.
        initial_shift = np.median(shifts, axis=0)
        residuals = np.linalg.norm(shifts - initial_shift, axis=1)
        # `<=` (not `<`) guarantees a non-empty inlier set. With `<`, equal
        # residuals (e.g. an already aligned OM) selected nothing, and the
        # median of an empty array made the shift NaN.
        inliers = residuals <= np.percentile(residuals, inlier_percentile)
        dx, dy = np.median(shifts[inliers], axis=0)
        print(f'OM{t+1}: Applying median translation dx={dx:.2f}, dy={dy:.2f}')
        curr['geometry'] = curr['geometry'].apply(lambda g: translate(g, xoff=dx, yoff=dy))
        aligned.append(curr)
    return aligned


aligned_crowns = align_crowns(all_crowns)

## 3. Add New Crowns (No Overlap)
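Copy crowns from OM_t into OM_{t+1} when they have no significant overlap (IoU below `iou_thresh`) with any crown of OM_{t+1}, and vice versa. This lets trees that were missed in one orthomosaic still be tracked. Each copied crown carries its image patch along.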

In [None]:
import pandas as pd
def compute_iou(g1, g2):
    """Return the intersection-over-union of two polygon geometries.

    The union area is obtained by inclusion–exclusion,
    ``area(g1) + area(g2) - area(g1 ∩ g2)``. For valid polygons this equals
    ``g1.union(g2).area`` but avoids building the union geometry, which is the
    expensive part when this is called for every crown pair.

    Returns 0.0 when the geometries do not overlap, including when both have
    zero area.
    """
    intersection = g1.intersection(g2).area
    if intersection <= 0:
        # Disjoint or merely touching: no need to compute the union at all.
        return 0.0
    union = g1.area + g2.area - intersection
    return intersection / union if union > 0 else 0.0

def augment_crowns_with_images(all_crowns, crown_images, iou_thresh=0.01):
    """Exchange unmatched crowns between consecutive OMs, with their image patches.

    For each consecutive pair (OM_t, OM_{t+1}):

    * every crown of OM_t whose IoU with *all* crowns of OM_{t+1} is below
      ``iou_thresh`` is appended to OM_{t+1};
    * every crown of OM_{t+1} whose IoU with all crowns of OM_t is below
      ``iou_thresh`` is appended to OM_t.

    Both directions are decided against the crowns as they were before this
    pair was augmented, so rows added in the same step do not suppress each
    other.

    Parameters
    ----------
    all_crowns : list of GeoDataFrame
        Crown polygons per OM (typically the aligned crowns).
    crown_images : list of list
        ``crown_images[t][k]`` is the patch for row position ``k`` of
        ``all_crowns[t]``. Missing OMs or entries are treated as ``None``.
    iou_thresh : float
        IoU below which two crowns count as non-overlapping.

    Returns
    -------
    tuple of (list of GeoDataFrame, list of list)
        Augmented crowns per OM and, aligned by row position, their patches.
    """
    def _images_for(t, n):
        # Patches for OM t as a list of exactly n entries (None-padded), so a
        # missing orthomosaic or a short list cannot cause an IndexError.
        imgs = list(crown_images[t]) if t < len(crown_images) else []
        return (imgs + [None] * n)[:n]

    def _unmatched_positions(source, target_geoms):
        # Row positions in `source` whose IoU with every target is below iou_thresh.
        return [pos for pos, geom in enumerate(source.geometry)
                if all(compute_iou(geom, other) < iou_thresh for other in target_geoms)]

    if len(all_crowns) == 0:
        return [], []
    augmented = [all_crowns[0].copy()]
    augmented_images = [_images_for(0, len(all_crowns[0]))]
    for t in range(1, len(all_crowns)):
        prev = augmented[-1]
        prev_images = augmented_images[-1]
        curr = all_crowns[t].copy()
        curr_images = _images_for(t, len(curr))
        to_curr = _unmatched_positions(prev, list(curr.geometry))
        to_prev = _unmatched_positions(curr, list(prev.geometry))
        if to_prev:
            # Store the additions to OM_t in `augmented` itself. Rebinding a
            # local `prev` would silently discard them.
            augmented[-1] = pd.concat([prev, curr.iloc[to_prev]], ignore_index=True)
            augmented_images[-1] = prev_images + [curr_images[p] for p in to_prev]
        if to_curr:
            # iloc slices keep the GeoDataFrame type and geometry dtype, unlike
            # row.to_frame().T. A single concat avoids quadratic re-copying.
            curr = pd.concat([curr, prev.iloc[to_curr]], ignore_index=True)
            curr_images = curr_images + [prev_images[p] for p in to_curr]
        augmented.append(curr)
        augmented_images.append(curr_images)
    return augmented, augmented_images


augmented_crowns, augmented_crown_images = augment_crowns_with_images(aligned_crowns, crown_images)

## 4. Build Tracking Graph and IoU Matrix
Construct the tracking graph and IoU matrices for the processed crowns.

In [None]:
import networkx as nx
# Build one directed tracking graph and one IoU matrix per consecutive OM pair.
# Node names are 'OM{k}_{pos}' (k is 1-based, pos is the crown's row position).
# An edge OM{t+1}_i -> OM{t+2}_j is added when their IoU exceeds the threshold.
graphs, iou_matrices = [], []
consecutive_pairs = zip(augmented_crowns[:-1], augmented_crowns[1:])
for t, (crowns1, crowns2) in enumerate(consecutive_pairs):
    src_tag, dst_tag = f'OM{t+1}', f'OM{t+2}'
    G = nx.DiGraph()
    iou_matrix = np.zeros((len(crowns1), len(crowns2)))
    for i, g1 in enumerate(crowns1.geometry):
        for j, g2 in enumerate(crowns2.geometry):
            iou_matrix[i, j] = compute_iou(g1, g2)
            if iou_matrix[i, j] > 0.15:  # Example threshold
                G.add_edge(f'{src_tag}_{i}', f'{dst_tag}_{j}', weight=iou_matrix[i, j])
    graphs.append(G)
    iou_matrices.append(iou_matrix)
    print(f'Pair {src_tag} to {dst_tag}: IoU matrix shape {iou_matrix.shape}, edges in graph: {G.number_of_edges()}')

In [None]:
# --- Visualize tracking graphs and IoU matrices for each OM pair ---
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
for t, (G, iou_matrix) in enumerate(zip(graphs, iou_matrices)):
    src_tag, dst_tag = f'OM{t+1}', f'OM{t+2}'
    fig, (heat_ax, graph_ax) = plt.subplots(1, 2, figsize=(14, 6))
    # Left panel: IoU heatmap (rows = source crowns, columns = target crowns).
    heat = heat_ax.imshow(iou_matrix, aspect='auto', cmap='viridis')
    heat_ax.set_title(f'IoU Matrix {src_tag} to {dst_tag}')
    heat_ax.set_xlabel(f'{dst_tag} crowns')
    heat_ax.set_ylabel(f'{src_tag} crowns')
    plt.colorbar(heat, ax=heat_ax, fraction=0.046, pad=0.04)
    # Right panel: the tracking graph, with a seeded spring layout so reruns look the same.
    layout = nx.spring_layout(G, seed=42)
    nx.draw(G, layout, ax=graph_ax, node_size=20, edge_color='gray', arrowsize=10, with_labels=False)
    graph_ax.set_title(f'Tracking Graph {src_tag} to {dst_tag}')
    plt.tight_layout()
    plt.show()

In [None]:
# --- Visualize time series of tracked trees across all 5 orthomosaics ---
import matplotlib.pyplot as plt

def get_tracks(graphs):
    """Greedily follow crowns through the chain of pairwise tracking graphs.

    Tracks start at every node of ``graphs[0]`` that has no incoming edge,
    i.e. the source-side (OM1) nodes. From each start, the track repeatedly
    steps to the successor with the highest edge ``weight`` (IoU) in the next
    graph, and stops at the first graph where the current node has no
    successor.

    Parameters
    ----------
    graphs : list of networkx.DiGraph
        ``graphs[t]`` links OM{t+1} nodes to OM{t+2} nodes.

    Returns
    -------
    list of list of str
        One node path per start node. A path that reaches the last OM has
        ``len(graphs) + 1`` nodes.
    """
    if not graphs:
        return []
    tracks = []
    first = graphs[0]
    for start_node in first.nodes:
        # Target-side (OM2) nodes also live in graphs[0]. They have no
        # outgoing edge there and would only yield useless one-node tracks.
        if first.in_degree(start_node) > 0:
            continue
        track = [start_node]
        current_node = start_node
        for G in graphs:
            # Direct adjacency lookup instead of scanning every edge of G.
            successors = G.succ[current_node] if current_node in G else {}
            if not successors:
                break
            # Highest-IoU successor. max() keeps the first maximum, so ties go
            # to the earliest-inserted edge.
            current_node = max(successors, key=lambda v: successors[v]['weight'])
            track.append(current_node)
        tracks.append(track)
    return tracks


tracks = get_tracks(graphs)

def parse_node(node):
    """Split a graph node name ``'OM{k}_{crown}'`` into zero-based indices.

    Returns ``(om_idx, crown_idx)`` with ``om_idx = k - 1``; ``crown_idx`` is
    the number after the first underscore.
    """
    tokens = node.split('_')
    om_number = int(tokens[0][2:])  # strip the 'OM' prefix
    return om_number - 1, int(tokens[1])

# Keep tracks that visit every orthomosaic: one node per OM, and there is one
# graph per consecutive OM pair. Derived from the data instead of assuming 5 OMs.
num_oms = len(graphs) + 1
full_tracks = [track for track in tracks if len(track) == num_oms]
num_tracks_to_show = len(full_tracks)
print(f"Found {len(full_tracks)} tracks spanning all {num_oms} orthomosaics. Showing {num_tracks_to_show}.")

for track_id, track in enumerate(full_tracks[:num_tracks_to_show]):
    # squeeze=False always yields a 2-D axes array, even for a single panel.
    fig, axes = plt.subplots(1, len(track), figsize=(3*len(track), 3), squeeze=False)
    for ax, node in zip(axes[0], track):
        om_idx, crown_idx = parse_node(node)
        img_patch = None
        if om_idx < len(augmented_crown_images) and crown_idx < len(augmented_crown_images[om_idx]):
            img_patch = augmented_crown_images[om_idx][crown_idx]
        if img_patch is not None:
            ax.imshow(img_patch)
            ax.set_title(f'OM{om_idx+1} Crown {crown_idx}')
        else:
            ax.set_title(f'OM{om_idx+1} Crown {crown_idx}\n(No image)')
        ax.axis('off')
    plt.suptitle(f'Time Series of Tracked Tree {track_id+1} (All {num_oms} OMs)', fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# --- Visualize cases where one tree is matched to multiple trees in another orthomosaic ---
import matplotlib.pyplot as plt

def find_split_matches(graphs):
    """Collect nodes that have more than one outgoing edge in their pair graph.

    Returns a list of ``(t, node, targets)`` tuples: ``t`` is the position of
    the graph in ``graphs``, ``node`` a source node with several successors,
    and ``targets`` those successors in the order ``G.edges(node)`` yields them.
    """
    cases = []
    for pair_idx, graph in enumerate(graphs):
        fanout = ((src, [dst for _, dst in graph.edges(src)]) for src in graph.nodes())
        cases.extend((pair_idx, src, targets) for src, targets in fanout if len(targets) > 1)
    return cases

split_matches = find_split_matches(graphs)
num_cases_to_show = min(10, len(split_matches))
print(f"Found {len(split_matches)} cases where one tree is matched to multiple trees in another OM. Showing {num_cases_to_show}.")

for case_id, (t, node, outgoing) in enumerate(split_matches[:num_cases_to_show]):
    # Reuse the shared node-name parser instead of re-implementing it inline.
    om_idx, crown_idx = parse_node(node)
    # squeeze=False: axes is always 2-D, so axes[0][k] works for any panel count.
    fig, axes = plt.subplots(1, 1 + len(outgoing), figsize=(3*(1 + len(outgoing)), 3), squeeze=False)
    # First panel is the source crown, followed by each matched crown in the next OM.
    panels = [('Source', node)] + [('Match', out_node) for out_node in outgoing]
    for ax, (role, panel_node) in zip(axes[0], panels):
        panel_om, panel_crown = parse_node(panel_node)
        img_patch = None
        if panel_om < len(augmented_crown_images) and panel_crown < len(augmented_crown_images[panel_om]):
            img_patch = augmented_crown_images[panel_om][panel_crown]
        if img_patch is not None:
            ax.imshow(img_patch)
            ax.set_title(f'{role} OM{panel_om+1} Crown {panel_crown}')
        else:
            ax.set_title(f'{role} OM{panel_om+1} Crown {panel_crown}\n(No image)')
        ax.axis('off')
    plt.suptitle(f'Tree Split Case {case_id+1}: OM{om_idx+1} Crown {crown_idx} → {len(outgoing)} Crowns in OM{om_idx+2}', fontsize=14)
    plt.tight_layout()
    plt.show()