# Tree Crown Tracking (Strict Preset) — 15 Oct

This notebook bundles a reusable tracker class that links tree crown polygons across consecutive orthomosaics using a strict, high-precision configuration that accepts only one-to-one and containment matches. It auto-discovers inputs under `input/input_crowns` and `input/input_om` (falling back to `../../input/...` when run from a nested notebook directory), builds a directed match graph, and reports quality and graph-complexity metrics.

In [1]:
import os
import re
import json
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Iterable

import numpy as np
import geopandas as gpd
import rasterio
from rasterio.mask import mask
from shapely.geometry import mapping, Polygon
import networkx as nx

In [3]:
@dataclass
class MatchCaseConfig:
    """Thresholds and weights for one match case (e.g. ``one_to_one``).

    ``TreeTrackingGraph`` uses one instance per case to gate, score and select
    candidate crown pairs between consecutive orthomosaics.
    """
    name: str  # case label; compared against the 'case' assigned to each candidate
    base_similarity_weights: Dict[str, float]  # weights for 'spatial' / 'area' / 'shape' / 'iou' in the base similarity
    scoring_weights: Dict[str, float]  # weights over score components ('base', 'iou', 'overlap_prev', 'overlap_curr', 'centroid', ...)
    similarity_threshold: float  # candidates scoring below this are discarded
    min_iou: float = 0.0  # minimum intersection-over-union
    min_overlap_prev: float = 0.0  # minimum intersection area / previous-crown area
    min_overlap_curr: float = 0.0  # minimum intersection area / current-crown area
    max_centroid_dist: Optional[float] = None  # maximum centroid distance; None disables this gate
    mutual_best: bool = False  # keep only pairs that are each other's highest-scoring candidate
    allow_multiple: bool = False  # when False, a node already used by an accepted pair cannot be matched again
    max_edges_per_prev: Optional[int] = None  # cap on accepted edges per previous crown; None = uncapped
    max_edges_per_curr: Optional[int] = None  # cap on accepted edges per current crown; None = uncapped

class TreeTrackingGraph:
    """
    High-precision tree crown tracker across orthomosaics using a directed graph.

    - Discovers crown GeoPackages (*.gpkg) and orthomosaics (*.tif)
    - Computes per-crown attributes (centroid, area, compactness, eccentricity, etc.)
    - Builds a match graph between consecutive OMs using strict cases: one_to_one + containment
    - Reports quality metrics and graph complexity

    Strict preset parameters (best-known):
    - base_max_dist ~ 70-80 (projected units);
    - overlap_gate = 0.48; min_base_similarity = 0.35;
    - Cases:
      one_to_one: similarity_threshold=0.82, min_iou=0.40, min_overlap_prev=0.72, min_overlap_curr=0.72, mutual_best=True, max_edges per node=1
      containment: similarity_threshold=0.74, min_overlap_prev=0.82, min_overlap_curr=0.82, max_edges per node=1
    """
    def __init__(self, crown_dir: Optional[str] = None, ortho_dir: Optional[str] = None, output_dir: str = '../../output', simplify_tol: float = 1.0, max_crowns_preview: int = 200, auto_discover: bool = True) -> None:
        """Set up paths, empty caches and the strict case preset, then optionally discover inputs.

        Args:
            crown_dir: directory holding crown ``*.gpkg`` files. When falsy it is
                resolved from ``input/input_crowns`` or ``../../input/input_crowns``
                relative to the current working directory.
            ortho_dir: directory holding orthomosaic ``*.tif`` files; resolved the same
                way from ``input/input_om`` when falsy.
            output_dir: directory used by ``ensure_output_dir`` / ``save_text``.
            simplify_tol: stored as ``self.simplify_tol``.
            max_crowns_preview: stored as ``self.max_crowns_preview``.
            auto_discover: when true, call ``discover_files`` immediately.

        Raises:
            FileNotFoundError: if a directory has to be resolved and neither candidate
                exists, or (with ``auto_discover``) no crown files are found.
        """
        self.output_dir, self.simplify_tol, self.max_crowns_preview = output_dir, simplify_tol, max_crowns_preview
        # Explicit directories win; otherwise probe the repo-root and notebook-relative layouts.
        self.crown_dir = crown_dir if crown_dir else self._resolve_dir('input/input_crowns', '../../input/input_crowns')
        self.ortho_dir = ortho_dir if ortho_dir else self._resolve_dir('input/input_om', '../../input/input_om')

        # Input pairing, filled by discover_files().
        self.file_pairs: List[Tuple[str, Optional[str]]] = []
        self.om_ids: List[int] = []

        # Per-orthomosaic caches, filled by load_data().
        self.crowns_gdfs: Dict[int, gpd.GeoDataFrame] = {}
        self.crown_attrs: Dict[int, List[Dict[str, Any]]] = {}
        self.crown_images: Dict[int, List[Optional[np.ndarray]]] = {}

        # Match graph plus the strict matching preset and the order cases claim nodes.
        self.G: nx.DiGraph = nx.DiGraph()
        self.case_configs: Dict[str, MatchCaseConfig] = self._strict_case_configs()
        self.case_order: List[str] = ['one_to_one', 'containment']

        # Per-case statistics from the most recent graph build.
        self.last_case_counts: Dict[str, int] = {}
        self.last_selected_counts: Dict[str, int] = {}

        if auto_discover:
            self.discover_files()

    @staticmethod
    def _resolve_dir(root_rel: str, nb_rel: str) -> str:
        candidates = [
            os.path.abspath(os.path.join(os.getcwd(), root_rel)),
            os.path.abspath(os.path.join(os.getcwd(), nb_rel)),
        ]
        for p in candidates:
            if os.path.isdir(p):
                return p
        raise FileNotFoundError(f"Could not resolve directory for {root_rel}. Tried: {candidates}")

    @staticmethod
    def _extract_numeric_id(name: str) -> Optional[int]:
        m = re.search(r"(\d+)", os.path.basename(name))
        return int(m.group(1)) if m else None

    def discover_files(self) -> None:
        crown_files = [os.path.join(self.crown_dir, f) for f in os.listdir(self.crown_dir) if f.lower().endswith('.gpkg')]
        ortho_files = [os.path.join(self.ortho_dir, f) for f in os.listdir(self.ortho_dir) if f.lower().endswith('.tif')] if os.path.exists(self.ortho_dir) else []
        if not crown_files:
            raise FileNotFoundError(f"No .gpkg crown files found in {self.crown_dir}")
        crowns_by_id = {}
        for cf in crown_files:
            cid = self._extract_numeric_id(cf)
            crowns_by_id[cid if cid is not None else cf] = cf
        orthos_by_id = {}
        for of in ortho_files:
            oid = self._extract_numeric_id(of)
            orthos_by_id[oid if oid is not None else of] = of
        numeric_ids = sorted(set(k for k in crowns_by_id.keys() if isinstance(k, int)) & set(k for k in orthos_by_id.keys() if isinstance(k, int)))
        file_pairs: List[Tuple[str, Optional[str]]] = []
        if numeric_ids:
            for nid in numeric_ids:
                file_pairs.append((crowns_by_id[nid], orthos_by_id.get(nid)))
            crown_only = sorted(k for k in crowns_by_id.keys() if isinstance(k, int) and k not in numeric_ids)
            for nid in crown_only:
                file_pairs.append((crowns_by_id[nid], None))
        else:
            crown_files_sorted = sorted(crown_files)
            ortho_files_sorted = sorted(ortho_files)
            for i, cf in enumerate(crown_files_sorted):
                of = ortho_files_sorted[i] if i < len(ortho_files_sorted) else None
                file_pairs.append((cf, of))
        om_ids: List[int] = []
        for cf, _ in file_pairs:
            cid = self._extract_numeric_id(cf)
            om_ids.append(cid if cid is not None else len(om_ids) + 1)
        pairs_with_id = sorted([(oid, cf, of) for oid, (cf, of) in zip(om_ids, file_pairs)], key=lambda x: x[0])
        self.file_pairs = [(cf, of) for _, cf, of in pairs_with_id]
        self.om_ids = [oid for oid, _, _ in pairs_with_id]

    def load_data(self, load_images: bool = False) -> None:
        """Read every crown file and cache per-crown attributes and, optionally, image patches.

        Previously cached data is cleared first. For each ``om_id`` paired with
        ``(crown_file, ortho_file)``:

        * ``self.crowns_gdfs[om_id]`` is the GeoDataFrame read from ``crown_file``;
        * ``self.crown_attrs[om_id]`` holds one ``_compute_crown_attributes`` dict per
          row, in row order;
        * ``self.crown_images[om_id]`` holds one entry per row. When ``load_images`` is
          true and ``ortho_file`` exists, each entry is the raster clipped to the crown
          (``mask(..., crop=True)``) with bands moved to the last axis, or None if
          clipping raised. Otherwise every entry is None.
        """
        for cache in (self.crowns_gdfs, self.crown_attrs, self.crown_images):
            cache.clear()
        for om_id, (crown_file, ortho_file) in zip(self.om_ids, self.file_pairs):
            frame = gpd.read_file(crown_file)
            self.crowns_gdfs[om_id] = frame
            geoms = [row.geometry for _, row in frame.iterrows()]
            self.crown_attrs[om_id] = [self._compute_crown_attributes(geom) for geom in geoms]
            if not (load_images and ortho_file and os.path.exists(ortho_file)):
                self.crown_images[om_id] = [None] * len(frame)
                continue
            with rasterio.open(ortho_file) as src:
                patches: List[Optional[np.ndarray]] = []
                for geom in geoms:
                    shapes = [mapping(geom)]
                    try:
                        clipped, _ = mask(src, shapes, crop=True)
                        patches.append(np.moveaxis(clipped, 0, -1))
                    except Exception:
                        # Best effort: crowns that cannot be clipped get no patch.
                        patches.append(None)
            self.crown_images[om_id] = patches

    @staticmethod
    def _compute_crown_attributes(geometry) -> Dict[str, Any]:
        centroid = geometry.centroid
        area = geometry.area
        perimeter = geometry.length
        bounds = geometry.bounds
        compactness = (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0
        try:
            min_rect = geometry.minimum_rotated_rectangle
            coords = list(min_rect.exterior.coords)
            side1 = np.linalg.norm(np.array(coords[0]) - np.array(coords[1]))
            side2 = np.linalg.norm(np.array(coords[1]) - np.array(coords[2]))
            major_axis = max(side1, side2)
            minor_axis = min(side1, side2)
            eccentricity = minor_axis / major_axis if major_axis > 0 else 1
        except Exception:
            eccentricity = 1
        aspect_ratio = (bounds[3] - bounds[1]) / (bounds[2] - bounds[0]) if bounds[2] != bounds[0] else 1
        return {
            'geometry': geometry,
            'centroid': centroid,
            'area': area,
            'perimeter': perimeter,
            'compactness': compactness,
            'eccentricity': eccentricity,
            'aspect_ratio': aspect_ratio,
            'bounds': bounds,
        }

    def _strict_case_configs(self) -> Dict[str, MatchCaseConfig]:
        """Return the strict preset: a ``one_to_one`` case and a ``containment`` case.

        Both cases set ``allow_multiple=False`` with per-node edge caps of 1.
        ``one_to_one`` additionally requires a mutual-best pairing and higher IoU;
        ``containment`` relies on high overlap in both directions.
        """
        single_edge = {'allow_multiple': False, 'max_edges_per_prev': 1, 'max_edges_per_curr': 1}
        specs = {
            'one_to_one': dict(
                base_similarity_weights={'spatial': 0.45, 'area': 0.2, 'shape': 0.15, 'iou': 0.2},
                scoring_weights={'base': 0.55, 'iou': 0.2, 'overlap_prev': 0.1, 'overlap_curr': 0.1, 'centroid': 0.05},
                similarity_threshold=0.82,
                min_iou=0.40,
                min_overlap_prev=0.72,
                min_overlap_curr=0.72,
                max_centroid_dist=45.0,
                mutual_best=True,
            ),
            'containment': dict(
                base_similarity_weights={'spatial': 0.35, 'area': 0.15, 'shape': 0.15, 'iou': 0.35},
                scoring_weights={'base': 0.3, 'overlap_prev': 0.35, 'overlap_curr': 0.35},
                similarity_threshold=0.74,
                min_iou=0.0,
                min_overlap_prev=0.82,
                min_overlap_curr=0.82,
                max_centroid_dist=60.0,
                mutual_best=False,
            ),
        }
        return {case: MatchCaseConfig(name=case, **spec, **single_edge) for case, spec in specs.items()}

    @staticmethod
    def _compute_iou(g1, g2) -> float:
        try:
            intersection = g1.intersection(g2).area
            union = g1.union(g2).area
        except Exception:
            intersection = 0.0
            union = g1.area + g2.area
        return intersection / union if union > 0 else 0.0

    def _weighted_similarity(self, a1: Dict[str, Any], a2: Dict[str, Any], weights: Optional[Dict[str, float]] = None, max_dist: float = 100.0) -> Tuple[float, Dict[str, float]]:
        if weights is None:
            weights = {'spatial': 0.4, 'area': 0.2, 'shape': 0.2, 'iou': 0.2}
        centroid_dist = a1['centroid'].distance(a2['centroid'])
        spatial_sim = max(0.0, 1.0 - (centroid_dist / max_dist))
        area_sim = min(a1['area'], a2['area']) / max(a1['area'], a2['area']) if max(a1['area'], a2['area']) > 0 else 0.0
        compactness_sim = 1.0 - abs(a1['compactness'] - a2['compactness'])
        eccentricity_sim = 1.0 - abs(a1['eccentricity'] - a2['eccentricity'])
        shape_sim = (compactness_sim + eccentricity_sim) / 2.0
        iou_sim = self._compute_iou(a1['geometry'], a2['geometry'])
        total = (weights.get('spatial', 0.0) * spatial_sim + weights.get('area', 0.0) * area_sim + weights.get('shape', 0.0) * shape_sim + weights.get('iou', 0.0) * iou_sim)
        return total, {'spatial': float(spatial_sim), 'area': float(area_sim), 'shape': float(shape_sim), 'iou': float(iou_sim), 'total': float(total)}

    def _compute_pair_metrics(self, prev_attrs: Dict[str, Any], curr_attrs: Dict[str, Any], max_dist: float) -> Dict[str, float]:
        prev_geom = prev_attrs['geometry']
        curr_geom = curr_attrs['geometry']
        try:
            intersection_area = prev_geom.intersection(curr_geom).area
        except Exception:
            intersection_area = 0.0
        try:
            union_area = prev_geom.union(curr_geom).area
        except Exception:
            union_area = prev_attrs['area'] + curr_attrs['area'] - intersection_area
        prev_area = prev_attrs['area'] if prev_attrs['area'] > 0 else 1e-6
        curr_area = curr_attrs['area'] if curr_attrs['area'] > 0 else 1e-6
        overlap_prev = intersection_area / prev_area
        overlap_curr = intersection_area / curr_area
        iou = intersection_area / union_area if union_area > 0 else 0.0
        centroid_dist = prev_attrs['centroid'].distance(curr_attrs['centroid'])
        base_similarity, parts = self._weighted_similarity(prev_attrs, curr_attrs, max_dist=max_dist)
        prev_radius = np.sqrt(prev_area / np.pi)
        curr_radius = np.sqrt(curr_area / np.pi)
        mean_radius = max((prev_radius + curr_radius) / 2.0, 1e-3)
        area_ratio = curr_area / prev_area if prev_area > 0 else np.inf
        if not np.isfinite(area_ratio) or area_ratio <= 0:
            balanced_area_ratio = 0.0
        else:
            balanced_area_ratio = area_ratio if area_ratio <= 1 else 1 / area_ratio
        try:
            prev_contains_curr = prev_geom.buffer(0).contains(curr_geom)
        except Exception:
            prev_contains_curr = False
        try:
            curr_contains_prev = curr_geom.buffer(0).contains(prev_geom)
        except Exception:
            curr_contains_prev = False
        return {
            'intersection_area': float(intersection_area),
            'union_area': float(union_area),
            'overlap_prev': float(overlap_prev),
            'overlap_curr': float(overlap_curr),
            'iou': float(iou),
            'centroid_dist': float(centroid_dist),
            'base_similarity': float(base_similarity),
            'spatial_similarity': float(parts['spatial']),
            'area_similarity': float(parts['area']),
            'shape_similarity': float(parts['shape']),
            'mean_radius': float(mean_radius),
            'area_ratio': float(area_ratio if np.isfinite(area_ratio) else 0.0),
            'balanced_area_ratio': float(balanced_area_ratio),
            'prev_contains_curr': bool(prev_contains_curr),
            'curr_contains_prev': bool(curr_contains_prev),
        }

    def _classify_match_case(self, prev_node: Tuple[int, int], curr_node: Tuple[int, int], features: Dict[str, float], prev_overlap_counts: Dict[Tuple[int, int], int], curr_overlap_counts: Dict[Tuple[int, int], int], overlap_gate: float) -> str:
        if features['prev_contains_curr'] or features['curr_contains_prev']:
            return 'containment'
        # one_to_one: strong mutual overlap & IoU; unique overlaps (strict default).
        overlap_prev = features['overlap_prev']
        overlap_curr = features['overlap_curr']
        iou = features['iou']
        prev_count = prev_overlap_counts.get(prev_node, 0)
        curr_count = curr_overlap_counts.get(curr_node, 0)
        if prev_count == 1 and curr_count == 1 and overlap_prev >= 0.72 and overlap_curr >= 0.72 and iou >= 0.40:
            return 'one_to_one'
        return 'none'

    def _score_candidate(self, base_similarity: float, similarity_parts: Dict[str, float], features: Dict[str, float], config: MatchCaseConfig) -> float:
        centroid_factor = 1.0 - min(1.0, features['centroid_dist'] / (features['mean_radius'] * 3.0))
        components = {
            'base': base_similarity,
            'spatial': similarity_parts.get('spatial', 0.0),
            'area': similarity_parts.get('area', 0.0),
            'shape': similarity_parts.get('shape', 0.0),
            'iou': features['iou'],
            'overlap_prev': features['overlap_prev'],
            'overlap_curr': features['overlap_curr'],
            'centroid': max(0.0, centroid_factor),
            'area_balance': features.get('balanced_area_ratio', 0.0),
        }
        score = 0.0
        for key, weight in config.scoring_weights.items():
            score += weight * components.get(key, 0.0)
        return score

    def _select_candidates_by_case(self, candidates: List[Dict[str, Any]], configs: Dict[str, MatchCaseConfig], case_order: List[str], max_dist: float) -> List[Dict[str, Any]]:
        selected: List[Dict[str, Any]] = []
        used_prev: Dict[Tuple[int, int], int] = defaultdict(int)
        used_curr: Dict[Tuple[int, int], int] = defaultdict(int)
        for case_name in case_order:
            config = configs.get(case_name)
            if not config:
                continue
            case_candidates = [cand for cand in candidates if cand['case'] == case_name]
            if not case_candidates:
                continue
            enriched: List[Dict[str, Any]] = []
            for cand in case_candidates:
                prev_attrs = cand['prev_attrs']
                curr_attrs = cand['curr_attrs']
                features = cand['features']
                if config.max_centroid_dist is not None and features['centroid_dist'] > config.max_centroid_dist:
                    continue
                if features['iou'] < config.min_iou:
                    continue
                if features['overlap_prev'] < config.min_overlap_prev:
                    continue
                if features['overlap_curr'] < config.min_overlap_curr:
                    continue
                base_similarity, parts = self._weighted_similarity(prev_attrs, curr_attrs, weights=config.base_similarity_weights, max_dist=max_dist)
                score = self._score_candidate(base_similarity, parts, features, config)
                if score < config.similarity_threshold:
                    continue
                cand['base_similarity'] = float(base_similarity)
                cand['similarity_parts'] = {k: float(v) for k, v in parts.items()}
                cand['score'] = float(score)
                enriched.append(cand)
            if not enriched:
                continue
            if config.mutual_best:
                best_prev: Dict[Tuple[int, int], Dict[str, Any]] = {}
                best_curr: Dict[Tuple[int, int], Dict[str, Any]] = {}
                for cand in enriched:
                    prev_node = cand['prev_node']
                    curr_node = cand['curr_node']
                    if used_prev.get(prev_node, 0) and not config.allow_multiple:
                        continue
                    if used_curr.get(curr_node, 0) and not config.allow_multiple:
                        continue
                    if cand['score'] < config.similarity_threshold:
                        continue
                    if prev_node not in best_prev or cand['score'] > best_prev[prev_node]['score']:
                        best_prev[prev_node] = cand
                    if curr_node not in best_curr or cand['score'] > best_curr[curr_node]['score']:
                        best_curr[curr_node] = cand
                for cand in enriched:
                    prev_node = cand['prev_node']
                    curr_node = cand['curr_node']
                    if best_prev.get(prev_node) is cand and best_curr.get(curr_node) is cand:
                        if not config.allow_multiple:
                            if used_prev.get(prev_node, 0) or used_curr.get(curr_node, 0):
                                continue
                        if config.max_edges_per_prev is not None and used_prev[prev_node] >= config.max_edges_per_prev:
                            continue
                        if config.max_edges_per_curr is not None and used_curr[curr_node] >= config.max_edges_per_curr:
                            continue
                        selected.append(cand)
                        used_prev[prev_node] += 1
                        used_curr[curr_node] += 1
            else:
                enriched.sort(key=lambda c: c['score'], reverse=True)
                for cand in enriched:
                    prev_node = cand['prev_node']
                    curr_node = cand['curr_node']
                    if not config.allow_multiple:
                        if used_prev.get(prev_node, 0) or used_curr.get(curr_node, 0):
                            continue
                    if config.max_edges_per_prev is not None and used_prev[prev_node] >= config.max_edges_per_prev:
                        continue
                    if config.max_edges_per_curr is not None and used_curr[curr_node] >= config.max_edges_per_curr:
                        continue
                    selected.append(cand)
                    used_prev[prev_node] += 1
                    used_curr[curr_node] += 1
        return selected

    def reset_graph(self) -> None:
        """Discard the current match graph and start again from an empty ``nx.DiGraph``."""
        empty_graph = nx.DiGraph()
        self.G = empty_graph

    def build_graph_conditional(self, case_configs: Optional[Dict[str, MatchCaseConfig]] = None, case_order: Optional[List[str]] = None, base_max_dist: float = 75.0, overlap_gate: float = 0.48, min_base_similarity: float = 0.35, max_candidates_per_prev: Optional[int] = None, max_candidates_per_curr: Optional[int] = None) -> None:
        """Rebuild ``self.G`` by linking crowns of each consecutive orthomosaic pair.

        Nodes are ``(om_id, position)``, where ``position`` indexes
        ``self.crown_attrs[om_id]`` (the crown's row position in its GeoDataFrame).
        Each node carries that crown's attribute dict. For every consecutive pair of
        ids in ``self.om_ids``:

        1. candidate pairs are collected: centroids within ``base_max_dist`` and not
           simultaneously below ``min_base_similarity`` (default-weight base
           similarity) and ``overlap_gate`` (IoU);
        2. each candidate is classified with ``_classify_match_case``; ``'none'`` is dropped;
        3. optionally only the best ``max_candidates_per_prev`` / ``max_candidates_per_curr``
           candidates per node are kept (by base similarity, then IoU);
        4. ``_select_candidates_by_case`` picks the final pairs, which become edges.

        Args:
            case_configs: per-case configs; defaults to ``self.case_configs``. Each is
                shallow-copied with ``dataclasses.replace``.
            case_order: order in which cases claim nodes; defaults to ``self.case_order``.
            base_max_dist: maximum centroid distance for a candidate; also the distance
                scale of the spatial similarity.
            overlap_gate: IoU that admits otherwise low-similarity candidates, and the
                overlap fraction counted per node for the one-to-one uniqueness check.
            min_base_similarity: minimum base similarity unless admitted by ``overlap_gate``.
            max_candidates_per_prev, max_candidates_per_curr: optional per-node caps
                applied before selection.

        Side effects:
            Calls ``load_data`` when nothing is loaded, replaces ``self.G``, and resets
            ``self.last_case_counts`` (classified candidates per case, after trimming)
            and ``self.last_selected_counts`` (selected edges per case).
        """
        if not self.crowns_gdfs:
            self.load_data(load_images=False)
        self.reset_graph()
        configs = {name: replace(cfg) for name, cfg in (case_configs or self.case_configs).items()}
        order = case_order or self.case_order
        self.last_case_counts = {}
        self.last_selected_counts = {name: 0 for name in configs}
        for idx, om_id in enumerate(self.om_ids):
            # Node ids are positions into crown_attrs, so they match the positional
            # pairing below whatever index labels the GeoDataFrame carries.
            for pos, attrs in enumerate(self.crown_attrs[om_id]):
                self.G.add_node((om_id, pos), **attrs)
            if idx == 0:
                continue
            self._link_consecutive_oms(
                self.om_ids[idx - 1], om_id, configs, order, base_max_dist, overlap_gate,
                min_base_similarity, max_candidates_per_prev, max_candidates_per_curr,
            )

    def _link_consecutive_oms(self, prev_om: int, om_id: int, configs: Dict[str, MatchCaseConfig], order: List[str], base_max_dist: float, overlap_gate: float, min_base_similarity: float, max_candidates_per_prev: Optional[int], max_candidates_per_curr: Optional[int]) -> None:
        """Classify, trim and select candidates from ``prev_om`` to ``om_id``, then add edges to ``self.G``."""
        candidates, overlap_counts_prev, overlap_counts_curr = self._collect_pair_candidates(
            prev_om, om_id, base_max_dist, overlap_gate, min_base_similarity,
        )
        for cand in candidates:
            cand['case'] = self._classify_match_case(cand['prev_node'], cand['curr_node'], cand['features'], overlap_counts_prev, overlap_counts_curr, overlap_gate)
        candidates = [cand for cand in candidates if cand['case'] != 'none']
        if not candidates:
            return
        if max_candidates_per_prev is not None:
            candidates = self._trim_candidates(candidates, 'prev_node', max_candidates_per_prev)
        if max_candidates_per_curr is not None:
            candidates = self._trim_candidates(candidates, 'curr_node', max_candidates_per_curr)
        for cand in candidates:
            case_name = cand['case']
            self.last_case_counts[case_name] = self.last_case_counts.get(case_name, 0) + 1

        for cand in self._select_candidates_by_case(candidates, configs, order, base_max_dist):
            case_name = cand['case']
            features = cand['features']
            parts = cand.get('similarity_parts', {})
            self.G.add_edge(
                cand['prev_node'], cand['curr_node'],
                similarity=float(cand.get('score', features['base_similarity'])),
                method='conditional',
                case=case_name,
                overlap_prev=float(features['overlap_prev']),
                overlap_curr=float(features['overlap_curr']),
                iou=float(features['iou']),
                centroid_distance=float(features['centroid_dist']),
                base_similarity=float(cand.get('base_similarity', features['base_similarity'])),
                spatial_similarity=float(parts.get('spatial', features['spatial_similarity'])),
                area_similarity=float(parts.get('area', features['area_similarity'])),
                shape_similarity=float(parts.get('shape', features['shape_similarity'])),
            )
            self.last_selected_counts[case_name] = self.last_selected_counts.get(case_name, 0) + 1

    def _collect_pair_candidates(self, prev_om: int, om_id: int, base_max_dist: float, overlap_gate: float, min_base_similarity: float) -> Tuple[List[Dict[str, Any]], Dict[Tuple[int, int], int], Dict[Tuple[int, int], int]]:
        """Return candidate pairs from ``prev_om`` to ``om_id`` plus per-node counts of overlaps >= ``overlap_gate``."""
        curr_attr_list = self.crown_attrs[om_id]
        curr_nodes = [(om_id, j) for j in range(len(curr_attr_list))]
        candidates: List[Dict[str, Any]] = []
        overlap_counts_prev: Dict[Tuple[int, int], int] = defaultdict(int)
        overlap_counts_curr: Dict[Tuple[int, int], int] = defaultdict(int)
        for i in range(len(self.crown_attrs[prev_om])):
            prev_node = (prev_om, i)
            prev_attrs = self.G.nodes[prev_node]
            for curr_node in curr_nodes:
                curr_attrs = curr_attr_list[curr_node[1]]
                # Cheap centroid check first: far-apart pairs are rejected before the
                # costly polygon overlays done by _compute_pair_metrics.
                if prev_attrs['centroid'].distance(curr_attrs['centroid']) > base_max_dist:
                    continue
                features = self._compute_pair_metrics(prev_attrs, curr_attrs, max_dist=base_max_dist)
                if features['base_similarity'] < min_base_similarity and features['iou'] < overlap_gate:
                    continue
                candidates.append({
                    'prev_node': prev_node,
                    'curr_node': curr_node,
                    'prev_attrs': prev_attrs,
                    'curr_attrs': curr_attrs,
                    'features': features,
                })
                if features['overlap_prev'] >= overlap_gate:
                    overlap_counts_prev[prev_node] += 1
                if features['overlap_curr'] >= overlap_gate:
                    overlap_counts_curr[curr_node] += 1
        return candidates, overlap_counts_prev, overlap_counts_curr

    @staticmethod
    def _trim_candidates(candidates: List[Dict[str, Any]], node_key: str, limit: int) -> List[Dict[str, Any]]:
        """Keep at most ``limit`` candidates per ``cand[node_key]``, best base similarity then IoU first."""
        grouped: Dict[Tuple[int, int], List[Dict[str, Any]]] = defaultdict(list)
        for cand in candidates:
            grouped[cand[node_key]].append(cand)
        trimmed: List[Dict[str, Any]] = []
        for group in grouped.values():
            group.sort(key=lambda c: (c['features']['base_similarity'], c['features']['iou']), reverse=True)
            trimmed.extend(group[:limit])
        return trimmed

    def quality_report(self) -> Tuple[str, Dict[str, Any]]:
        G = self.G
        om_ids = self.om_ids
        metrics: Dict[str, Any] = {
            'total_trees_detected': G.number_of_nodes(),
            'total_edges': G.number_of_edges(),
            'total_possible_matches': 0,
            'successful_matches': 0,
            'match_rate_by_om_pair': {},
            'chain_length_distribution': {},
            'average_chain_length': 0,
            'median_chain_length': 0,
            'max_chain_length': 0,
        }
        chains = self._extract_all_chains()
        chain_lengths = [len(chain) for chain in chains]
        if chain_lengths:
            metrics['average_chain_length'] = float(np.mean(chain_lengths))
            metrics['median_chain_length'] = float(np.median(chain_lengths))
            metrics['max_chain_length'] = int(max(chain_lengths))
            for length in chain_lengths:
                metrics['chain_length_distribution'][int(length)] = metrics['chain_length_distribution'].get(int(length), 0) + 1
        for i in range(len(om_ids) - 1):
            om1, om2 = om_ids[i], om_ids[i + 1]
            om1_nodes = [n for n in G.nodes if n[0] == om1]
            om2_nodes = [n for n in G.nodes if n[0] == om2]
            matches = sum(1 for u, v in G.edges() if u[0] == om1 and v[0] == om2)
            possible_matches = min(len(om1_nodes), len(om2_nodes))
            match_rate = matches / possible_matches if possible_matches > 0 else 0.0
            metrics['match_rate_by_om_pair'][f"{om1}->{om2}"] = {
                'matches': matches,
                'possible': possible_matches,
                'rate': float(match_rate),
            }
            metrics['total_possible_matches'] += possible_matches
            metrics['successful_matches'] += matches
        metrics['overall_match_rate'] = (metrics['successful_matches'] / metrics['total_possible_matches'] if metrics['total_possible_matches'] > 0 else 0.0)
        report = [
            "# Tree Tracking Quality Assessment Report",
            f"Total Trees Detected: {metrics['total_trees_detected']}",
            f"Total Tracking Edges: {metrics['total_edges']}",
            f"Overall Match Rate: {metrics['overall_match_rate']:.3f}",
            f"Average Chain Length: {metrics.get('average_chain_length', 0):.2f}",
            f"Maximum Chain Length: {metrics.get('max_chain_length', 0)}",
            "Match Rates by Orthomosaic Pair:",
        ]
        for pair, data in metrics['match_rate_by_om_pair'].items():
            report.append(f"- {pair}: {data['matches']}/{data['possible']} ({data['rate']:.3f})")
        report.append("\nChain Length Distribution:")
        for length, count in sorted(metrics['chain_length_distribution'].items()):
            report.append(f"- Length {length}: {count} trees")
        if self.last_selected_counts:
            report.append("\nEdge selection by case:")
            for case_name, count in sorted(self.last_selected_counts.items(), key=lambda kv: (-kv[1], kv[0])):
                total_candidates = self.last_case_counts.get(case_name, 0)
                if total_candidates:
                    ratio = count / total_candidates
                    report.append(f"- {case_name}: {count} / {total_candidates} ({ratio:.2f})")
                else:
                    report.append(f"- {case_name}: {count}")
        return "\n".join(report), metrics

    def graph_complexity_metrics(self) -> Dict[str, Any]:
        """Collect structural statistics of ``self.G``.

        Returns a dict with:
            node/edge counts; mean in/out degree (0.0 for an empty graph) and degree
            histograms sorted by degree; counts of nodes with zero in/out degree;
            weakly/strongly connected component counts and sizes (largest first);
            the diameter of every connected component of the undirected view (0 for
            single nodes or when ``nx.diameter`` raises) with its mean, median and max.
        """
        graph = self.G
        out_deg = dict(graph.out_degree())
        in_deg = dict(graph.in_degree())

        def histogram(values: Iterable[int]) -> Dict[int, int]:
            counts: Dict[int, int] = defaultdict(int)
            for value in values:
                counts[int(value)] += 1
            return dict(sorted(counts.items()))

        weak = list(nx.weakly_connected_components(graph))
        strong = list(nx.strongly_connected_components(graph))

        undirected = graph.to_undirected()
        diameters: List[int] = []
        for members in nx.connected_components(undirected):
            if len(members) <= 1:
                diameters.append(0)
                continue
            try:
                diameters.append(int(nx.diameter(undirected.subgraph(members))))
            except Exception:
                diameters.append(0)

        out_values = list(out_deg.values())
        in_values = list(in_deg.values())
        return {
            'num_nodes': graph.number_of_nodes(),
            'num_edges': graph.number_of_edges(),
            'avg_out_degree': float(np.mean(out_values)) if out_deg else 0.0,
            'avg_in_degree': float(np.mean(in_values)) if in_deg else 0.0,
            'out_degree_distribution': histogram(out_values),
            'in_degree_distribution': histogram(in_values),
            'zero_out_degree_nodes': out_values.count(0),
            'zero_in_degree_nodes': in_values.count(0),
            'weak_components': len(weak),
            'weak_component_sizes': sorted((len(c) for c in weak), reverse=True),
            'strong_components': len(strong),
            'strong_component_sizes': sorted((len(c) for c in strong), reverse=True),
            'diameters': diameters,
            'average_diameter': float(np.mean(diameters)) if diameters else 0.0,
            'median_diameter': float(np.median(diameters)) if diameters else 0.0,
            'max_diameter': int(max(diameters)) if diameters else 0,
        }

    def complexity_report(self) -> Tuple[str, Dict[str, Any]]:
        """Render the graph complexity metrics as a Markdown-style text report.

        Returns:
            ``(report_text, metrics)`` where ``metrics`` is the dict returned by
            ``graph_complexity_metrics()`` (returned unchanged).
        """
        stats = self.graph_complexity_metrics()
        # Only the first ten component sizes are shown to keep the report short.
        weak_head = stats['weak_component_sizes'][:10]
        strong_head = stats['strong_component_sizes'][:10]

        lines = [
            "# Graph Complexity Report",
            f"Nodes: {stats['num_nodes']}",
            f"Edges: {stats['num_edges']}",
        ]
        lines += [
            f"Avg out-degree: {stats['avg_out_degree']:.3f}",
            f"Avg in-degree: {stats['avg_in_degree']:.3f}",
            f"Zero out-degree nodes: {stats['zero_out_degree_nodes']}",
            f"Zero in-degree nodes: {stats['zero_in_degree_nodes']}",
            f"Weakly connected components: {stats['weak_components']} (sizes head: {weak_head})",
            f"Strongly connected components: {stats['strong_components']} (sizes head: {strong_head})",
        ]
        for label, key in (("Average diameter", 'average_diameter'), ("Median diameter", 'median_diameter')):
            lines.append(f"{label}: {stats[key]:.3f}")
        lines.append(f"Max diameter: {stats['max_diameter']}")
        return "\n".join(lines), stats

    def _extract_all_chains(self) -> List[List[Tuple[int, int]]]:
        visited = set()
        chains: List[List[Tuple[int, int]]] = []
        chain_starts = [n for n in self.G.nodes if not list(self.G.predecessors(n))]
        for start_node in chain_starts:
            if start_node in visited:
                continue
            chain = self._greedy_chain(start_node)
            chains.append(chain)
            visited.update(chain)
        remaining = set(self.G.nodes) - visited
        for node in remaining:
            chains.append([node])
        return chains

    def _greedy_chain(self, start_node: Tuple[int, int]) -> List[Tuple[int, int]]:
        chain = [start_node]
        current = start_node
        while True:
            successors = list(self.G.successors(current))
            if not successors:
                break
            if len(successors) > 1:
                best_successor = max(successors, key=lambda n: self.G[current][n].get('similarity', 0.0))
                chain.append(best_successor)
                current = best_successor
            else:
                chain.append(successors[0])
                current = successors[0]
        return chain

    def get_matching_chain(self, start_om_id: int, crown_id: int) -> List[Tuple[int, int]]:
        """Return the greedy chain starting at crown ``crown_id`` of OM ``start_om_id``.

        Raises:
            ValueError: if ``(start_om_id, crown_id)`` is not a node of ``self.G``.
        """
        key = (start_om_id, crown_id)
        if key in self.G:
            return self._greedy_chain(key)
        raise ValueError(f"Node {key} not in graph. Build the graph first.")

    def ensure_output_dir(self) -> None:
        """Create ``self.output_dir`` (and any missing parents); no-op if it exists."""
        target = self.output_dir
        os.makedirs(target, exist_ok=True)

    def save_text(self, text: str, filename: str) -> str:
        """Write ``text`` to ``<output_dir>/<filename>`` and return that path.

        The output directory is created first if needed. The file is written
        explicitly as UTF-8. Without an ``encoding`` argument, ``open`` uses
        the locale's preferred encoding, so the saved bytes (and whether
        non-ASCII text can be written at all) would depend on the machine.
        """
        self.ensure_output_dir()
        path = os.path.join(self.output_dir, filename)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)
        return path

    def save_json(self, data: Dict[str, Any], filename: str) -> str:
        """Serialize ``data`` as 2-space-indented JSON into ``<output_dir>/<filename>``.

        Creates the output directory if needed and returns the written path.
        """
        self.ensure_output_dir()
        target = os.path.join(self.output_dir, filename)
        with open(target, 'w') as handle:
            json.dump(data, handle, indent=2)
        return target

    def run_strict_preset(self, *, base_max_dist: float = 75.0, overlap_gate: float = 0.48, min_base_similarity: float = 0.35, save_prefix: Optional[str] = None) -> Dict[str, Any]:
        """Build the strict-preset graph and produce the quality and complexity reports.

        Crown data is loaded (without images) if nothing is loaded yet. The
        graph is then rebuilt by ``build_graph_conditional`` using this
        instance's ``case_configs``/``case_order`` and the given gates. When
        ``save_prefix`` is truthy, both reports and both metric dicts are
        written to ``output_dir`` as ``<prefix>_quality_report.txt``,
        ``<prefix>_quality_metrics.json``, ``<prefix>_complexity_report.txt``
        and ``<prefix>_complexity_metrics.json``. Otherwise the matching
        ``*_path`` entries are ``None``.

        Returns:
            Dict with node/edge counts, report texts, metric dicts and artifact paths.
        """
        if not self.crowns_gdfs:
            self.load_data(load_images=False)
        self.build_graph_conditional(
            case_configs=self.case_configs,
            case_order=self.case_order,
            base_max_dist=base_max_dist,
            overlap_gate=overlap_gate,
            min_base_similarity=min_base_similarity,
        )
        quality_text, quality_metrics = self.quality_report()
        complexity_text, complexity_metrics = self.complexity_report()

        paths: Dict[str, Optional[str]] = dict.fromkeys(
            ('quality_report_path', 'quality_metrics_path', 'complexity_report_path', 'complexity_metrics_path')
        )
        if save_prefix:
            paths['quality_report_path'] = self.save_text(quality_text, f'{save_prefix}_quality_report.txt')
            paths['quality_metrics_path'] = self.save_json(quality_metrics, f'{save_prefix}_quality_metrics.json')
            paths['complexity_report_path'] = self.save_text(complexity_text, f'{save_prefix}_complexity_report.txt')
            paths['complexity_metrics_path'] = self.save_json(complexity_metrics, f'{save_prefix}_complexity_metrics.json')

        summary: Dict[str, Any] = {
            'nodes': self.G.number_of_nodes(),
            'edges': self.G.number_of_edges(),
            'quality_report': quality_text,
            'quality_metrics': quality_metrics,
            'complexity_report': complexity_text,
            'complexity_metrics': complexity_metrics,
        }
        summary.update(paths)
        return summary

In [5]:
# Configure (optional): to override the default input directories, construct the tracker as
# tracker = TreeTrackingGraph(crown_dir='../../input/input_crowns', ortho_dir='../../input/input_om')
tracker = TreeTrackingGraph()
summary = tracker.run_strict_preset(
    base_max_dist=75.0,
    overlap_gate=0.48,
    min_base_similarity=0.35,
    save_prefix='strict_15oct',
)
print(f"Nodes: {summary['nodes']}, Edges: {summary['edges']}")
# Quality report, a separator line, then the complexity report (one per line).
print(summary['quality_report'], '---', summary['complexity_report'], sep='\n')

Nodes: 626, Edges: 2
# Tree Tracking Quality Assessment Report
Total Trees Detected: 626
Total Tracking Edges: 2
Overall Match Rate: 0.004
Average Chain Length: 1.00
Maximum Chain Length: 2
Match Rates by Orthomosaic Pair:
- 1->2: 1/80 (0.013)
- 2->3: 0/116 (0.000)
- 3->4: 0/130 (0.000)
- 4->5: 1/150 (0.007)

Chain Length Distribution:
- Length 1: 622 trees
- Length 2: 2 trees

Edge selection by case:
- one_to_one: 2 / 2 (1.00)
- containment: 0 / 3 (0.00)
---
# Graph Complexity Report
Nodes: 626
Edges: 2
Avg out-degree: 0.003
Avg in-degree: 0.003
Zero out-degree nodes: 624
Zero in-degree nodes: 624
Weakly connected components: 624 (sizes head: [2, 2, 1, 1, 1, 1, 1, 1, 1, 1])
Strongly connected components: 626 (sizes head: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Average diameter: 0.003
Median diameter: 0.000
Max diameter: 1


In [6]:
# Patch quality_report to avoid f-string quote conflicts
from typing import Iterable, Tuple as _Tuple
def _quality_report_patch(self) -> _Tuple[str, Dict[str, Any]]:
    G = self.G
    om_ids = self.om_ids
    metrics: Dict[str, Any] = {
        'total_trees_detected': G.number_of_nodes(),
        'total_edges': G.number_of_edges(),
        'total_possible_matches': 0,
        'successful_matches': 0,
        'match_rate_by_om_pair': {},
        'chain_length_distribution': {},
        'average_chain_length': 0,
        'median_chain_length': 0,
        'max_chain_length': 0,
    }
    chains = self._extract_all_chains()
    chain_lengths = [len(chain) for chain in chains]
    if chain_lengths:
        metrics['average_chain_length'] = float(np.mean(chain_lengths))
        metrics['median_chain_length'] = float(np.median(chain_lengths))
        metrics['max_chain_length'] = int(max(chain_lengths))
        for length in chain_lengths:
            key = int(length)
            metrics['chain_length_distribution'][key] = metrics['chain_length_distribution'].get(key, 0) + 1
    for i in range(len(om_ids) - 1):
        om1, om2 = om_ids[i], om_ids[i + 1]
        om1_nodes = [n for n in G.nodes if n[0] == om1]
        om2_nodes = [n for n in G.nodes if n[0] == om2]
        matches = sum(1 for u, v in G.edges() if u[0] == om1 and v[0] == om2)
        possible_matches = min(len(om1_nodes), len(om2_nodes))
        rate = matches / possible_matches if possible_matches > 0 else 0.0
        metrics['match_rate_by_om_pair'][f"{om1}->{om2}"] = {
            'matches': matches,
            'possible': possible_matches,
            'rate': float(rate),
        }
        metrics['total_possible_matches'] += possible_matches
        metrics['successful_matches'] += matches
    metrics['overall_match_rate'] = (metrics['successful_matches'] / metrics['total_possible_matches'] if metrics['total_possible_matches'] > 0 else 0.0)
    report = []
    report.append('# Tree Tracking Quality Assessment Report')
    report.append(f"Total Trees Detected: {metrics['total_trees_detected']}")
    report.append(f"Total Tracking Edges: {metrics['total_edges']}")
    report.append(f"Overall Match Rate: {metrics['overall_match_rate']:.3f}")
    report.append(f"Average Chain Length: {metrics.get('average_chain_length', 0):.2f}")
    report.append(f"Maximum Chain Length: {metrics.get('max_chain_length', 0)}")
    report.append('Match Rates by Orthomosaic Pair:')
    for pair, data in metrics['match_rate_by_om_pair'].items():
        report.append(f"- {pair}: {data['matches']}/{data['possible']} ({data['rate']:.3f})")
    report.append('\nChain Length Distribution:')
    for length, count in sorted(metrics['chain_length_distribution'].items()):
        report.append(f"- Length {length}: {count} trees")
    if self.last_selected_counts:
        report.append('\nEdge selection by case:')
        for case_name, count in sorted(self.last_selected_counts.items(), key=lambda kv: (-kv[1], kv[0])):
            total_candidates = self.last_case_counts.get(case_name, 0)
            if total_candidates:
                ratio = count / total_candidates
                report.append(f"- {case_name}: {count} / {total_candidates} ({ratio:.2f})")
            else:
                report.append(f"- {case_name}: {count}")
    return '\n'.join(report), metrics
# Install the patched report on the class itself, so every existing and future
# TreeTrackingGraph instance (and run_strict_preset, which calls
# self.quality_report()) uses it from here on, replacing any quality_report
# previously defined on the class.
TreeTrackingGraph.quality_report = _quality_report_patch

In [None]:
# Optional: list a few high-confidence chains (strict edges tend to be high-confidence)
def _edge_info(tracker: TreeTrackingGraph, chain: List[Tuple[int, int]]):
    """Lazily yield ``(u, v, edge_attrs)`` for each consecutive node pair of ``chain``.

    ``edge_attrs`` is ``tracker.G.get_edge_data(u, v)``, replaced by an empty
    dict when that lookup returns nothing (e.g. the edge does not exist).
    """
    for pos in range(1, len(chain)):
        src = chain[pos - 1]
        dst = chain[pos]
        attrs = tracker.G.get_edge_data(src, dst)
        yield src, dst, (attrs if attrs else {})

def filter_high_confidence_chains(tracker: TreeTrackingGraph, min_similarity: float = 0.8, allowed_cases: Optional[set] = None):
    """Return multi-node chains whose every edge is an allowed case with high similarity.

    Args:
        tracker: tracker whose ``_extract_all_chains()`` supplies the chains
            and whose graph ``G`` supplies the edge attributes.
        min_similarity: minimum edge ``similarity`` (missing -> 0.0) required
            on every edge of a chain.
        allowed_cases: accepted values of the edge ``case`` attribute;
            defaults to ``{'one_to_one', 'containment'}``.

    Returns:
        The qualifying chains (lists of ``(om_id, crown_id)``) in extraction order.
    """
    cases = {'one_to_one', 'containment'} if allowed_cases is None else allowed_cases

    def _is_confident(edge: Dict[str, Any]) -> bool:
        return edge.get('case') in cases and edge.get('similarity', 0.0) >= min_similarity

    selected = []
    for chain in tracker._extract_all_chains():
        if len(chain) < 2:
            continue  # singletons have no edges to judge
        edge_data = [data for _, _, data in _edge_info(tracker, chain)]
        if edge_data and all(_is_confident(data) for data in edge_data):
            selected.append(chain)
    return selected

# Print how many strict chains qualify, then up to ten of them with their mean edge similarity.
hc = filter_high_confidence_chains(tracker, min_similarity=0.8)
print(f"High-confidence chains found: {len(hc)}")
for idx, ch in enumerate(hc[:10], start=1):
    sims = [
        tracker.G.get_edge_data(src, dst).get('similarity', 0.0)
        for src, dst in zip(ch, ch[1:])
    ]
    mean_sim = np.mean(sims) if sims else 0
    print(f"{idx:02d}. length={len(ch)} avg_sim={mean_sim:.3f} nodes={ch}")

In [7]:
# Extended matching strategies: shuffled temporal orders and virtual all-pairs edges
import random
from typing import Set

def build_graph_conditional_for_order(self, om_sequence, base_max_dist: float = 75.0, overlap_gate: float = 0.48, min_base_similarity: float = 0.35, max_candidates_per_prev: Optional[int] = None, max_candidates_per_curr: Optional[int] = None, case_configs: Optional[Dict[str, MatchCaseConfig]] = None, case_order: Optional[List[str]] = None) -> None:
    """Rebuild ``self.G`` linking crowns between consecutive OMs of ``om_sequence``.

    The OM order is supplied by the caller, so edges run from
    ``om_sequence[k-1]`` to ``om_sequence[k]`` even when that is not the
    order of ``self.om_ids`` (``build_graph_from_shuffles`` relies on this).

    For each consecutive pair:
    1. Candidates: crown pairs within ``base_max_dist`` whose base similarity
       reaches ``min_base_similarity`` or whose IoU reaches ``overlap_gate``.
       Per node, count the candidates overlapping it by >= ``overlap_gate``.
    2. Classify each candidate with ``_classify_match_case``; drop ``'none'``.
    3. Optionally keep only the best ``max_candidates_per_prev`` /
       ``max_candidates_per_curr`` candidates per node, ranked by
       ``(base_similarity, iou)``.
    4. Select with ``_select_candidates_by_case`` and add the edges.

    Side effects: resets the graph. ``last_case_counts`` receives candidates
    per case (after trimming), and ``last_selected_counts`` receives edges
    added per case.

    Node ids use each GeoDataFrame's index labels, the same labels used to add
    nodes and to key ``crown_attrs``. Positional ``range(len(gdf))`` ids break
    (KeyError) for any non-default index.
    """
    if not self.crowns_gdfs:
        self.load_data(load_images=False)
    self.reset_graph()
    configs = {name: replace(cfg) for name, cfg in (case_configs or self.case_configs).items()}
    order = case_order or self.case_order
    self.last_case_counts = {}
    self.last_selected_counts = {name: 0 for name in configs.keys()}

    def _rank(cand: Dict[str, Any]) -> Tuple[float, float]:
        # Ranking key for per-node trimming.
        return (cand['features']['base_similarity'], cand['features']['iou'])

    def _keep_top_per_node(cands: List[Dict[str, Any]], node_key: str, limit: int) -> List[Dict[str, Any]]:
        # Group by node and keep each node's `limit` best candidates (stable on ties).
        grouped: Dict[Tuple[int, int], List[Dict[str, Any]]] = defaultdict(list)
        for cand in cands:
            grouped[cand[node_key]].append(cand)
        kept: List[Dict[str, Any]] = []
        for group in grouped.values():
            group.sort(key=_rank, reverse=True)
            kept.extend(group[:limit])
        return kept

    # Add nodes once, keyed by (om_id, index label).
    for om_id in om_sequence:
        for crown_id in self.crowns_gdfs[om_id].index:
            self.G.add_node((om_id, crown_id), **self.crown_attrs[om_id][crown_id])

    # Edges for consecutive pairs in the provided order.
    for prev_om, om_id in zip(om_sequence, om_sequence[1:]):
        prev_nodes = [(prev_om, i) for i in self.crowns_gdfs[prev_om].index]
        curr_nodes = [(om_id, j) for j in self.crowns_gdfs[om_id].index]
        candidates: List[Dict[str, Any]] = []
        overlap_counts_prev: Dict[Tuple[int, int], int] = defaultdict(int)
        overlap_counts_curr: Dict[Tuple[int, int], int] = defaultdict(int)
        for prev_node in prev_nodes:
            prev_attrs = self.G.nodes[prev_node]
            for curr_node in curr_nodes:
                curr_attrs = self.crown_attrs[om_id][curr_node[1]]
                features = self._compute_pair_metrics(prev_attrs, curr_attrs, max_dist=base_max_dist)
                if features['centroid_dist'] > base_max_dist:
                    continue
                # Keep the pair if it is similar enough OR overlaps enough.
                if features['base_similarity'] < min_base_similarity and features['iou'] < overlap_gate:
                    continue
                candidates.append({
                    'prev_node': prev_node,
                    'curr_node': curr_node,
                    'prev_attrs': prev_attrs,
                    'curr_attrs': curr_attrs,
                    'features': features,
                })
                # Strong-overlap counts per node feed the case classifier.
                if features['overlap_prev'] >= overlap_gate:
                    overlap_counts_prev[prev_node] += 1
                if features['overlap_curr'] >= overlap_gate:
                    overlap_counts_curr[curr_node] += 1
        if not candidates:
            continue
        for cand in candidates:
            cand['case'] = self._classify_match_case(cand['prev_node'], cand['curr_node'], cand['features'], overlap_counts_prev, overlap_counts_curr, overlap_gate)
        candidates = [cand for cand in candidates if cand['case'] != 'none']
        if not candidates:
            continue
        if max_candidates_per_prev is not None:
            candidates = _keep_top_per_node(candidates, 'prev_node', max_candidates_per_prev)
        if max_candidates_per_curr is not None:
            candidates = _keep_top_per_node(candidates, 'curr_node', max_candidates_per_curr)
        case_counts = defaultdict(int)
        for cand in candidates:
            case_counts[cand['case']] += 1
        for case_name, count in case_counts.items():
            self.last_case_counts[case_name] = self.last_case_counts.get(case_name, 0) + count
        selected = self._select_candidates_by_case(candidates, configs, order, base_max_dist)
        for cand in selected:
            case_name = cand['case']
            features = cand['features']
            similarity_parts = cand.get('similarity_parts', {})
            self.G.add_edge(
                cand['prev_node'], cand['curr_node'],
                similarity=float(cand.get('score', features['base_similarity'])),
                method='conditional_order',
                case=case_name,
                overlap_prev=float(features['overlap_prev']),
                overlap_curr=float(features['overlap_curr']),
                iou=float(features['iou']),
                centroid_distance=float(features['centroid_dist']),
                base_similarity=float(cand.get('base_similarity', features['base_similarity'])),
                spatial_similarity=float(similarity_parts.get('spatial', features['spatial_similarity'])),
                area_similarity=float(similarity_parts.get('area', features['area_similarity'])),
                shape_similarity=float(similarity_parts.get('shape', features['shape_similarity'])),
            )
            self.last_selected_counts[case_name] = self.last_selected_counts.get(case_name, 0) + 1

def build_graph_from_shuffles(self, num_shuffles: int = 8, seed: Optional[int] = 42, base_max_dist: float = 75.0, overlap_gate: float = 0.48, min_base_similarity: float = 0.35, min_frequency: float = 0.5) -> None:
    """Aggregate edges over random orderings; keep edges appearing in >= min_frequency * num_shuffles.

    Procedure:
    - Draw ``num_shuffles`` permutations of ``self.om_ids`` from
      ``random.Random(seed)``.
    - For each permutation, a temporary tracker that shares this instance's
      loaded crown data by reference runs ``build_graph_conditional_for_order``.
    - Every directed edge is counted across runs. Edges seen in at least
      ``max(1, ceil(min_frequency * num_shuffles))`` runs are kept, with their
      mean similarity and the case recorded on first sight.

    Shuffled orders link OMs in either temporal direction, so the result may
    contain backward edges.

    The per-case diagnostics read by ``quality_report`` are refreshed for the
    aggregated graph. ``last_case_counts`` holds the distinct edges observed
    per case, and ``last_selected_counts`` holds those kept. Previously both
    were left over from whichever build ran before, so the report described
    a different graph.
    """
    if not self.crowns_gdfs:
        self.load_data(load_images=False)
    rng = random.Random(seed)
    # (u, v) -> {'count': runs containing the edge, 'sim_sum': summed similarity,
    #            'first': edge attributes from the first run that produced it}
    edge_aggr: Dict[Tuple[Tuple[int, int], Tuple[int, int]], Dict[str, Any]] = {}
    for _ in range(num_shuffles):
        order = list(self.om_ids)
        rng.shuffle(order)
        temp = TreeTrackingGraph(auto_discover=False)
        # Shallow copy: the loaded crown data is shared by reference.
        temp.crown_dir = self.crown_dir
        temp.ortho_dir = self.ortho_dir
        temp.file_pairs = list(self.file_pairs)
        temp.om_ids = list(self.om_ids)
        temp.crowns_gdfs = self.crowns_gdfs
        temp.crown_attrs = self.crown_attrs
        temp.crown_images = self.crown_images
        temp.case_configs = {name: replace(cfg) for name, cfg in self.case_configs.items()}
        temp.case_order = list(self.case_order)
        temp.build_graph_conditional_for_order(order, base_max_dist=base_max_dist, overlap_gate=overlap_gate, min_base_similarity=min_base_similarity)
        for u, v, data in temp.G.edges(data=True):
            rec = edge_aggr.setdefault((u, v), {'count': 0, 'sim_sum': 0.0, 'first': data})
            rec['count'] += 1
            rec['sim_sum'] += float(data.get('similarity', 0.0))

    # Build the final aggregated graph over all OMs, keyed by (om_id, index label).
    self.reset_graph()
    for om_id in self.om_ids:
        for crown_id in self.crowns_gdfs[om_id].index:
            self.G.add_node((om_id, crown_id), **self.crown_attrs[om_id][crown_id])
    keep_min_count = max(1, int(np.ceil(min_frequency * num_shuffles)))
    observed_per_case: Dict[str, int] = defaultdict(int)
    kept_per_case: Dict[str, int] = {name: 0 for name in self.case_configs}
    for (u, v), rec in edge_aggr.items():
        case_name = rec['first'].get('case', 'one_to_one')
        observed_per_case[case_name] += 1
        if rec['count'] < keep_min_count:
            continue
        avg_sim = rec['sim_sum'] / rec['count']  # count >= keep_min_count >= 1
        self.G.add_edge(u, v, similarity=float(avg_sim), agg_count=int(rec['count']), agg_sim_sum=float(rec['sim_sum']), method='shuffle_agg', case=case_name)
        kept_per_case[case_name] = kept_per_case.get(case_name, 0) + 1
    # Diagnostics consumed by quality_report now describe this graph.
    self.last_case_counts = dict(observed_per_case)
    self.last_selected_counts = kept_per_case

def build_graph_virtual_allpairs(self, base_max_dist: float = 75.0, overlap_gate: float = 0.48, min_base_similarity: float = 0.35) -> None:
    """Allow edges between any pair of OMs (i < j). Enforce strict cases and global 1-1 via used_prev/used_curr.

    Candidate generation and classification work as in the consecutive-pair
    builder, but run for every OM pair ``(om_ids[i], om_ids[j])`` with
    ``i < j``. Each pair's candidates are filtered by
    ``_select_candidates_by_case``.

    Selected edges are accepted first-come in pair iteration order (i, then
    j ascending), not by score. A node that already has an outgoing (or
    incoming) edge rejects later ones unless the case config sets
    ``allow_multiple``.

    Side effects: resets the graph. ``last_case_counts`` receives classified
    candidates per case, summed over OM pairs; it is now filled like in the
    other builders, so ``quality_report`` can print selected/candidate
    ratios. ``last_selected_counts`` receives edges added per case.

    Node ids use each GeoDataFrame's index labels, the same labels used to add
    nodes and key ``crown_attrs``, rather than positional ``range(len(gdf))``.
    """
    if not self.crowns_gdfs:
        self.load_data(load_images=False)
    self.reset_graph()
    configs = {name: replace(cfg) for name, cfg in self.case_configs.items()}
    order = self.case_order
    # Add nodes keyed by (om_id, index label).
    for om_id in self.om_ids:
        for crown_id in self.crowns_gdfs[om_id].index:
            self.G.add_node((om_id, crown_id), **self.crown_attrs[om_id][crown_id])
    self.last_case_counts = {}
    self.last_selected_counts = {name: 0 for name in configs.keys()}
    # Number of accepted outgoing / incoming edges per node, across all OM pairs.
    used_prev: Dict[Tuple[int, int], int] = defaultdict(int)
    used_curr: Dict[Tuple[int, int], int] = defaultdict(int)
    for i_idx in range(len(self.om_ids)):
        for j_idx in range(i_idx + 1, len(self.om_ids)):
            prev_om = self.om_ids[i_idx]
            om_id = self.om_ids[j_idx]
            prev_nodes = [(prev_om, i) for i in self.crowns_gdfs[prev_om].index]
            curr_nodes = [(om_id, j) for j in self.crowns_gdfs[om_id].index]
            candidates: List[Dict[str, Any]] = []
            overlap_counts_prev: Dict[Tuple[int, int], int] = defaultdict(int)
            overlap_counts_curr: Dict[Tuple[int, int], int] = defaultdict(int)
            for prev_node in prev_nodes:
                prev_attrs = self.G.nodes[prev_node]
                for curr_node in curr_nodes:
                    curr_attrs = self.crown_attrs[om_id][curr_node[1]]
                    features = self._compute_pair_metrics(prev_attrs, curr_attrs, max_dist=base_max_dist)
                    if features['centroid_dist'] > base_max_dist:
                        continue
                    # Keep the pair if it is similar enough OR overlaps enough.
                    if features['base_similarity'] < min_base_similarity and features['iou'] < overlap_gate:
                        continue
                    cand = {'prev_node': prev_node, 'curr_node': curr_node, 'prev_attrs': prev_attrs, 'curr_attrs': curr_attrs, 'features': features}
                    candidates.append(cand)
                    if features['overlap_prev'] >= overlap_gate:
                        overlap_counts_prev[prev_node] += 1
                    if features['overlap_curr'] >= overlap_gate:
                        overlap_counts_curr[curr_node] += 1
            if not candidates:
                continue
            for cand in candidates:
                cand['case'] = self._classify_match_case(cand['prev_node'], cand['curr_node'], cand['features'], overlap_counts_prev, overlap_counts_curr, overlap_gate)
            candidates = [cand for cand in candidates if cand['case'] != 'none']
            if not candidates:
                continue
            # Candidate diagnostics per case (denominator for quality_report).
            for cand in candidates:
                self.last_case_counts[cand['case']] = self.last_case_counts.get(cand['case'], 0) + 1
            # Select per pair
            selected = self._select_candidates_by_case(candidates, configs, order, base_max_dist)
            # Apply global 1:1 caps
            for cand in selected:
                u = cand['prev_node']
                v = cand['curr_node']
                if used_prev.get(u, 0) and not configs[cand['case']].allow_multiple:
                    continue
                if used_curr.get(v, 0) and not configs[cand['case']].allow_multiple:
                    continue
                features = cand['features']
                similarity_parts = cand.get('similarity_parts', {})
                self.G.add_edge(
                    u, v,
                    similarity=float(cand.get('score', features['base_similarity'])),
                    method='virtual_allpairs',
                    case=cand['case'],
                    overlap_prev=float(features['overlap_prev']),
                    overlap_curr=float(features['overlap_curr']),
                    iou=float(features['iou']),
                    centroid_distance=float(features['centroid_dist']),
                    base_similarity=float(cand.get('base_similarity', features['base_similarity'])),
                    spatial_similarity=float(similarity_parts.get('spatial', features['spatial_similarity'])),
                    area_similarity=float(similarity_parts.get('area', features['area_similarity'])),
                    shape_similarity=float(similarity_parts.get('shape', features['shape_similarity'])),
                )
                used_prev[u] += 1
                used_curr[v] += 1
                self.last_selected_counts[cand['case']] = self.last_selected_counts.get(cand['case'], 0) + 1

# Bind the methods to the class
# (class-level assignment: every TreeTrackingGraph instance, including ones
# created in earlier cells, gains these builders). build_graph_from_shuffles
# calls build_graph_conditional_for_order on a temporary tracker, so that
# binding must be in place before it runs.
TreeTrackingGraph.build_graph_conditional_for_order = build_graph_conditional_for_order
TreeTrackingGraph.build_graph_from_shuffles = build_graph_from_shuffles
TreeTrackingGraph.build_graph_virtual_allpairs = build_graph_virtual_allpairs

## Virtual edges (enhanced): algorithm overview

This variant considers edges between any two orthomosaics i < j, not just consecutive pairs.
We score each candidate crown pair with the strict case-based similarity and multiply by a gap penalty so that long jumps are discouraged.

Key ideas:
- Candidate generation: for every pair of OMs (i, j), compute features (IoU, overlaps, centroid distance) and a base similarity using geometry-derived parts.
- Case classification: only two strict cases are allowed — one_to_one and containment — using strong overlap/IoU gates.
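- Scoring: compute the case-specific weighted score, then multiply it by the gap weight $w_{\text{gap}} = e^{-\alpha\,(\text{gap} - 1)}$, where gap = j − i is the distance between the two OMs' positions (consecutive OMs get $w_{\text{gap}} = 1$). Candidates whose gap-weighted score falls below the case's `similarity_threshold` are dropped.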
- Global selection: sort all candidates across all (i, j) by score * w_gap and greedily select edges with at most one outgoing per node and at most one incoming per node (maintains chains).
- Output: edges annotated with the gap-weighted score (stored as `similarity`), base similarity, per-part similarities (spatial/area/shape), IoU/overlaps/centroid distance, case, gap, and gap_weight.

Benefits:
- Resilient to missing detections in intermediate OMs (can bridge gaps), while still preferring short, consistent links.
- Preserves strictness (1 in, 1 out per node) to avoid merges/splits and keep interpretable tracks.

Parameters to tune:
- alpha (gap decay): higher alpha penalizes long jumps more strongly.
- base_max_dist, overlap_gate, min_base_similarity: spatial/similarity gates to form plausible candidates.


In [9]:
from math import exp

def build_graph_virtual_allpairs_enhanced(self, base_max_dist: float = 75.0, overlap_gate: float = 0.48, min_base_similarity: float = 0.35, alpha: float = 0.9, gap_weighted_threshold: bool = True) -> None:
    """
    Enhanced virtual-all-pairs with gap penalty and global greedy 1:1 selection.
    - Generate candidates for all i<j OM pairs that pass strict gates.
    - Score with case-specific weights and multiply by exp(-alpha * (gap-1)).
    - Sort candidates by final score and greedily select, enforcing <=1 out, <=1 in
      per node (unless the case config sets ``allow_multiple``).

    Args:
        base_max_dist: maximum centroid distance for a candidate pair; also
            passed as ``max_dist`` to the similarity helpers.
        overlap_gate: IoU/overlap level used for candidate gating and for the
            per-node overlap counts passed to ``_classify_match_case``.
        min_base_similarity: pairs below this base similarity are kept only
            if their IoU reaches ``overlap_gate``.
        alpha: gap decay. ``gap = j - i`` in ``self.om_ids`` positions, and the
            weight is ``exp(-alpha * (gap - 1))`` (1.0 for gap 1).
        gap_weighted_threshold: if True (default, original behaviour), the
            case ``similarity_threshold`` is applied to the gap-weighted
            score. If False, it is applied to the raw case score, and the gap
            weight only affects ranking and the stored ``similarity``.

    NOTE(review): with the default behaviour and alpha=0.9, a gap-2 candidate
    (weight ~0.407) needs a raw score of about ``threshold / 0.407`` (~2.0 for
    a 0.82 threshold). If ``_score_candidate`` returns values in [0, 1] (not
    visible here, so confirm), gaps can then never be bridged;
    ``gap_weighted_threshold=False`` enables bridging.

    Side effects:
    - Resets ``self.G``.
    - ``last_case_counts`` counts classified candidates per case, before the
      threshold is applied.
    - ``last_selected_counts`` counts edges added per case.
    - Edge ``similarity`` is the gap-weighted score.

    Node ids use each GeoDataFrame's index labels, the same labels used to add
    nodes and key ``crown_attrs``, rather than positional ``range(len(gdf))``.
    """
    if not self.crowns_gdfs:
        self.load_data(load_images=False)
    self.reset_graph()
    configs = {name: replace(cfg) for name, cfg in self.case_configs.items()}
    # Add nodes keyed by (om_id, index label).
    for om_id in self.om_ids:
        for crown_id in self.crowns_gdfs[om_id].index:
            self.G.add_node((om_id, crown_id), **self.crown_attrs[om_id][crown_id])
    self.last_case_counts = {}
    self.last_selected_counts = {name: 0 for name in configs.keys()}

    all_candidates: List[Dict[str, Any]] = []
    # Collect candidates across all i<j
    for i_idx in range(len(self.om_ids)):
        for j_idx in range(i_idx + 1, len(self.om_ids)):
            prev_om = self.om_ids[i_idx]
            om_id = self.om_ids[j_idx]
            gap = j_idx - i_idx
            prev_nodes = [(prev_om, i) for i in self.crowns_gdfs[prev_om].index]
            curr_nodes = [(om_id, j) for j in self.crowns_gdfs[om_id].index]
            overlap_counts_prev: Dict[Tuple[int, int], int] = defaultdict(int)
            overlap_counts_curr: Dict[Tuple[int, int], int] = defaultdict(int)
            prelim: List[Dict[str, Any]] = []
            for prev_node in prev_nodes:
                prev_attrs = self.G.nodes[prev_node]
                for curr_node in curr_nodes:
                    curr_attrs = self.crown_attrs[om_id][curr_node[1]]
                    features = self._compute_pair_metrics(prev_attrs, curr_attrs, max_dist=base_max_dist)
                    if features['centroid_dist'] > base_max_dist:
                        continue
                    # Keep the pair if it is similar enough OR overlaps enough.
                    if features['base_similarity'] < min_base_similarity and features['iou'] < overlap_gate:
                        continue
                    rec = {'prev_node': prev_node, 'curr_node': curr_node, 'prev_attrs': prev_attrs, 'curr_attrs': curr_attrs, 'features': features, 'gap': gap}
                    prelim.append(rec)
                    if features['overlap_prev'] >= overlap_gate:
                        overlap_counts_prev[prev_node] += 1
                    if features['overlap_curr'] >= overlap_gate:
                        overlap_counts_curr[curr_node] += 1
            if not prelim:
                continue
            # classify and keep
            for cand in prelim:
                cand['case'] = self._classify_match_case(cand['prev_node'], cand['curr_node'], cand['features'], overlap_counts_prev, overlap_counts_curr, overlap_gate)
            prelim = [c for c in prelim if c['case'] != 'none']
            if not prelim:
                continue
            # score per case and apply gap weight
            for cand in prelim:
                cfg = configs[cand['case']]
                base_sim, parts = self._weighted_similarity(cand['prev_attrs'], cand['curr_attrs'], weights=cfg.base_similarity_weights, max_dist=base_max_dist)
                score = self._score_candidate(base_sim, parts, cand['features'], cfg)
                gap_w = exp(-alpha * (cand['gap'] - 1)) if cand['gap'] > 1 else 1.0
                final_score = score * gap_w
                gated_score = final_score if gap_weighted_threshold else score
                if gated_score < cfg.similarity_threshold:
                    continue
                cand['base_similarity'] = float(base_sim)
                cand['similarity_parts'] = {k: float(v) for k, v in parts.items()}
                cand['score'] = float(score)
                cand['gap_weight'] = float(gap_w)
                cand['final_score'] = float(final_score)
                all_candidates.append(cand)
            # count per-case candidates (for diagnostics)
            case_counts = defaultdict(int)
            for cand in prelim:
                case_counts[cand['case']] += 1
            for case_name, count in case_counts.items():
                self.last_case_counts[case_name] = self.last_case_counts.get(case_name, 0) + count

    if not all_candidates:
        return
    # Global greedy selection by final_score (stable sort keeps discovery order on ties).
    all_candidates.sort(key=lambda c: c['final_score'], reverse=True)
    used_prev: Dict[Tuple[int, int], int] = defaultdict(int)
    used_curr: Dict[Tuple[int, int], int] = defaultdict(int)
    for cand in all_candidates:
        u = cand['prev_node']
        v = cand['curr_node']
        cfg = configs[cand['case']]
        if used_prev.get(u, 0) and not cfg.allow_multiple:
            continue
        if used_curr.get(v, 0) and not cfg.allow_multiple:
            continue
        self.G.add_edge(u, v,
                        similarity=float(cand['final_score']),
                        base_similarity=float(cand['base_similarity']),
                        method='virtual_allpairs_enhanced',
                        case=cand['case'],
                        gap=int(cand['gap']), gap_weight=float(cand['gap_weight']),
                        overlap_prev=float(cand['features']['overlap_prev']),
                        overlap_curr=float(cand['features']['overlap_curr']),
                        iou=float(cand['features']['iou']),
                        centroid_distance=float(cand['features']['centroid_dist']),
                        spatial_similarity=float(cand['similarity_parts']['spatial']),
                        area_similarity=float(cand['similarity_parts']['area']),
                        shape_similarity=float(cand['similarity_parts']['shape']))
        used_prev[u] += 1
        used_curr[v] += 1
        self.last_selected_counts[cand['case']] = self.last_selected_counts.get(cand['case'], 0) + 1

# Expose the enhanced all-pairs builder as a TreeTrackingGraph method
# (class-level assignment, so existing instances gain it too).
TreeTrackingGraph.build_graph_virtual_allpairs_enhanced = build_graph_virtual_allpairs_enhanced

In [12]:
def build_graph_virtual_allpairs_enhanced_cfg(self, *, base_max_dist: float = 75.0, overlap_gate: float = 0.48, min_base_similarity: float = 0.35, alpha: float = 0.9, max_gap: Optional[int] = None, case_overrides: Optional[Dict[str, Dict[str, float]]] = None) -> None:
    """Build the match graph over all OM pairs, with per-case overrides and a gap limit.

    Every ordered pair of orthomosaics ``(om_ids[i], om_ids[j])`` with ``i < j``
    is compared crown-by-crown. Surviving pairs are classified into a match
    case, scored with that case's config, multiplied by the gap weight
    ``exp(-alpha * (gap - 1))`` (1.0 for consecutive OMs) and kept when the
    weighted score reaches the case's ``similarity_threshold``. Edges are then
    chosen greedily by weighted score across all OM pairs.

    Args:
        base_max_dist: Pairs whose centroid distance exceeds this are skipped;
            also passed as ``max_dist`` to the pair-metric/similarity helpers.
        overlap_gate: A pair survives pre-filtering if its IoU reaches this even
            when its base similarity is low; also the overlap level at which a
            node is counted as overlapping a partner (input to classification).
        min_base_similarity: Minimum base similarity for pairs below the IoU gate.
        alpha: Decay rate of the gap weight.
        max_gap: If set, OM pairs more than this many steps apart are skipped.
        case_overrides: ``{case_name: {field_name: value}}`` applied to copies of
            ``self.case_configs`` (the originals are left untouched). Unknown
            case names or fields are ignored with a ``UserWarning``.

    Side effects:
        Calls ``self.reset_graph()`` and repopulates ``self.G``; sets
        ``self.last_case_counts`` (classified candidates per case, before the
        score threshold) and ``self.last_selected_counts`` (edges added per case).
    """
    import math
    import warnings

    if not self.crowns_gdfs:
        self.load_data(load_images=False)
    self.reset_graph()

    # Work on copies so overrides never leak into self.case_configs.
    configs = {name: replace(cfg) for name, cfg in self.case_configs.items()}
    for cname, ov in (case_overrides or {}).items():
        if cname not in configs:
            # A typo here would otherwise make a whole sweep silently ineffective.
            warnings.warn(f"case_overrides: unknown case {cname!r} ignored", stacklevel=2)
            continue
        for field_name, value in ov.items():
            if hasattr(configs[cname], field_name):
                setattr(configs[cname], field_name, value)
            else:
                warnings.warn(f"case_overrides[{cname!r}]: unknown field {field_name!r} ignored", stacklevel=2)

    # One node per crown, keyed by (om_id, GeoDataFrame index label).
    for om_id in self.om_ids:
        for crown_id in self.crowns_gdfs[om_id].index:
            self.G.add_node((om_id, crown_id), **self.crown_attrs[om_id][crown_id])
    self.last_case_counts = {}
    self.last_selected_counts = {name: 0 for name in configs}

    all_candidates: List[Dict[str, Any]] = []
    n_oms = len(self.om_ids)
    for i_idx in range(n_oms):
        for j_idx in range(i_idx + 1, n_oms):
            gap = j_idx - i_idx
            if max_gap is not None and gap > max_gap:
                break  # gap only grows with j_idx, so no later pair qualifies
            prev_om = self.om_ids[i_idx]
            om_id = self.om_ids[j_idx]
            # Enumerate the same index labels the nodes were created with; a
            # positional range(len(...)) breaks for a non-RangeIndex.
            prev_nodes = [(prev_om, cid) for cid in self.crowns_gdfs[prev_om].index]
            curr_nodes = [(om_id, cid) for cid in self.crowns_gdfs[om_id].index]
            overlap_counts_prev: Dict[Tuple[Any, Any], int] = defaultdict(int)
            overlap_counts_curr: Dict[Tuple[Any, Any], int] = defaultdict(int)
            prelim: List[Dict[str, Any]] = []
            for prev_node in prev_nodes:
                prev_attrs = self.G.nodes[prev_node]
                for curr_node in curr_nodes:
                    curr_attrs = self.crown_attrs[om_id][curr_node[1]]
                    features = self._compute_pair_metrics(prev_attrs, curr_attrs, max_dist=base_max_dist)
                    if features['centroid_dist'] > base_max_dist:
                        continue
                    # Keep the pair if it is similar enough OR overlaps enough.
                    if features['base_similarity'] < min_base_similarity and features['iou'] < overlap_gate:
                        continue
                    prelim.append({'prev_node': prev_node, 'curr_node': curr_node, 'prev_attrs': prev_attrs, 'curr_attrs': curr_attrs, 'features': features, 'gap': gap})
                    if features['overlap_prev'] >= overlap_gate:
                        overlap_counts_prev[prev_node] += 1
                    if features['overlap_curr'] >= overlap_gate:
                        overlap_counts_curr[curr_node] += 1
            if not prelim:
                continue
            for cand in prelim:
                cand['case'] = self._classify_match_case(cand['prev_node'], cand['curr_node'], cand['features'], overlap_counts_prev, overlap_counts_curr, overlap_gate)
            prelim = [c for c in prelim if c['case'] != 'none']
            if not prelim:
                continue
            # Same gap for every candidate of this OM pair, so compute the weight once.
            gap_w = math.exp(-alpha * (gap - 1)) if gap > 1 else 1.0
            for cand in prelim:
                cfg = configs[cand['case']]
                base_sim, parts = self._weighted_similarity(cand['prev_attrs'], cand['curr_attrs'], weights=cfg.base_similarity_weights, max_dist=base_max_dist)
                score = self._score_candidate(base_sim, parts, cand['features'], cfg)
                final_score = score * gap_w
                if final_score < cfg.similarity_threshold:
                    continue
                cand['base_similarity'] = float(base_sim)
                cand['similarity_parts'] = {k: float(v) for k, v in parts.items()}
                cand['score'] = float(score)
                cand['gap_weight'] = float(gap_w)
                cand['final_score'] = float(final_score)
                all_candidates.append(cand)
            # Classified candidates per case (counted before the score threshold).
            for cand in prelim:
                self.last_case_counts[cand['case']] = self.last_case_counts.get(cand['case'], 0) + 1

    if not all_candidates:
        return
    # Global greedy selection: best weighted score first, across all OM pairs.
    all_candidates.sort(key=lambda c: c['final_score'], reverse=True)
    used_prev: Dict[Tuple[Any, Any], int] = defaultdict(int)
    used_curr: Dict[Tuple[Any, Any], int] = defaultdict(int)
    for cand in all_candidates:
        u, v = cand['prev_node'], cand['curr_node']
        cfg = configs[cand['case']]
        # Without allow_multiple a node takes at most one edge per side; with it,
        # the case's max_edges_per_* limits apply (None = unlimited).
        cap_prev = cfg.max_edges_per_prev if cfg.allow_multiple else 1
        cap_curr = cfg.max_edges_per_curr if cfg.allow_multiple else 1
        if cap_prev is not None and used_prev[u] >= cap_prev:
            continue
        if cap_curr is not None and used_curr[v] >= cap_curr:
            continue
        feats = cand['features']
        parts = cand['similarity_parts']
        self.G.add_edge(
            u, v,
            similarity=float(cand['final_score']),
            base_similarity=float(cand['base_similarity']),
            method='virtual_allpairs_enhanced_cfg',
            case=cand['case'],
            gap=int(cand['gap']),
            gap_weight=float(cand['gap_weight']),
            overlap_prev=float(feats['overlap_prev']),
            overlap_curr=float(feats['overlap_curr']),
            iou=float(feats['iou']),
            centroid_distance=float(feats['centroid_dist']),
            spatial_similarity=float(parts['spatial']),
            area_similarity=float(parts['area']),
            shape_similarity=float(parts['shape']),
        )
        used_prev[u] += 1
        used_curr[v] += 1
        self.last_selected_counts[cand['case']] = self.last_selected_counts.get(cand['case'], 0) + 1

# Attach the module-level function defined above as a TreeTrackingGraph method,
# so it can be called as tracker.build_graph_virtual_allpairs_enhanced_cfg(...).
TreeTrackingGraph.build_graph_virtual_allpairs_enhanced_cfg = build_graph_virtual_allpairs_enhanced_cfg

In [10]:
def compute_detailed_matching_metrics(tracker: "TreeTrackingGraph", *, high_conf_threshold: float = 0.8, high_conf_cases: Iterable[str] = ('one_to_one', 'containment')) -> Dict[str, Any]:
    """Summarize the edges, per-pair match rates and chains of a built tracker graph.

    Args:
        tracker: A tracker whose ``G`` has already been built; its
            ``om_ids``, ``_extract_all_chains()`` and ``graph_complexity_metrics()``
            are also read.
        high_conf_threshold: Minimum edge ``similarity`` for a chain to count as
            high-confidence (every edge of the chain must reach it).
        high_conf_cases: Edge ``case`` values allowed in a high-confidence chain.

    Returns:
        A dict with node/edge counts, edge counts by case and by gap, summary
        statistics (count/mean/median/min/max) of per-edge similarity, IoU,
        overlaps and centroid distance, match rates for each consecutive OM
        pair (``matches / min(nodes in either OM)``; gap > 1 edges are not
        counted there), chain statistics, the longest chains (length at least
        ``max(2, 90th percentile)``), degree distributions, and
        high-confidence chain count / average length.
    """
    G = tracker.G
    om_ids = tracker.om_ids
    strict_cases = set(high_conf_cases)

    def _stats(vals: List[float]) -> Dict[str, float]:
        if not vals:
            return {'count': 0, 'mean': 0.0, 'median': 0.0, 'min': 0.0, 'max': 0.0}
        return {
            'count': len(vals),
            'mean': float(np.mean(vals)),
            'median': float(np.median(vals)),
            'min': float(np.min(vals)),
            'max': float(np.max(vals)),
        }

    out: Dict[str, Any] = {
        'num_nodes': G.number_of_nodes(),
        'num_edges': G.number_of_edges(),
    }

    # Single pass over edges: counts by case/gap/OM pair plus per-edge values.
    by_case: Dict[str, int] = defaultdict(int)
    by_gap: Dict[int, int] = defaultdict(int)
    by_case_gap: Dict[str, Dict[int, int]] = defaultdict(lambda: defaultdict(int))
    by_om_pair: Dict[Tuple[Any, Any], int] = defaultdict(int)
    # Edge attribute name -> collected values; output key is f"{name}_stats".
    values: Dict[str, List[float]] = {
        name: [] for name in ('similarity', 'iou', 'overlap_prev', 'overlap_curr', 'centroid_distance')
    }
    for u, v, d in G.edges(data=True):
        case = d.get('case', 'unknown')
        # Edges without a stored gap fall back to the OM-id difference when
        # both ids are ints, otherwise to 1.
        fallback_gap = v[0] - u[0] if isinstance(v[0], int) and isinstance(u[0], int) else 1
        gap = int(d.get('gap', fallback_gap))
        by_case[case] += 1
        by_gap[gap] += 1
        by_case_gap[case][gap] += 1
        by_om_pair[(u[0], v[0])] += 1
        for name, vals in values.items():
            vals.append(float(d.get(name, 0.0)))
    out['edges_by_case'] = dict(sorted(by_case.items(), key=lambda kv: (-kv[1], kv[0])))
    out['edges_by_gap'] = dict(sorted(by_gap.items()))
    out['edges_by_case_and_gap'] = {k: dict(sorted(v.items())) for k, v in by_case_gap.items()}

    for name, vals in values.items():
        out[f'{name}_stats'] = _stats(vals)

    # Per consecutive OM pair (counts precomputed instead of rescanning G per pair).
    nodes_per_om: Dict[Any, int] = defaultdict(int)
    for n in G.nodes:
        nodes_per_om[n[0]] += 1
    pair_stats: Dict[str, Dict[str, float]] = {}
    for om1, om2 in zip(om_ids, om_ids[1:]):
        matches = by_om_pair.get((om1, om2), 0)
        possible = min(nodes_per_om.get(om1, 0), nodes_per_om.get(om2, 0))
        rate = matches / possible if possible > 0 else 0.0
        pair_stats[f"{om1}->{om2}"] = {
            'matches': matches,
            'possible': possible,
            'rate': float(rate),
        }
    out['pair_stats'] = pair_stats

    # Chains
    chains = tracker._extract_all_chains()
    lengths = [len(c) for c in chains]
    out['chain_count'] = len(chains)
    out['chain_length_stats'] = _stats([float(n) for n in lengths])
    if lengths:
        # Percentile computed once; np.percentile would fail on an empty list.
        long_cutoff = max(2, int(np.percentile(lengths, 90)))
        out['long_chains'] = [c for c in chains if len(c) >= long_cutoff]
    else:
        out['long_chains'] = []

    # Degrees (one call instead of two)
    complexity = tracker.graph_complexity_metrics()
    out['out_degree_distribution'] = complexity['out_degree_distribution']
    out['in_degree_distribution'] = complexity['in_degree_distribution']

    # High-confidence chains: every edge in a strict case and above threshold.
    hc = []
    for ch in chains:
        if len(ch) < 2:
            continue
        edges = [G.get_edge_data(u, v) or {} for u, v in zip(ch, ch[1:])]
        if all(e.get('case') in strict_cases and e.get('similarity', 0.0) >= high_conf_threshold for e in edges):
            hc.append(ch)
    out['high_conf_chain_count'] = len(hc)
    out['high_conf_avg_length'] = float(np.mean([len(c) for c in hc])) if hc else 0.0

    return out

In [11]:
# Build the enhanced virtual all-pairs graph with the strict preset, print the
# quality/complexity reports plus the detailed edge metrics, and write all of
# them to the output directory under one file prefix.

enhanced = TreeTrackingGraph()
enhanced.load_data(load_images=False)
enhanced.build_graph_virtual_allpairs_enhanced(
    base_max_dist=75.0,
    overlap_gate=0.48,
    min_base_similarity=0.35,
    alpha=0.9,
)

(q_report, q), (c_report, c) = enhanced.quality_report(), enhanced.complexity_report()
metrics = compute_detailed_matching_metrics(enhanced)

print("=== Enhanced Virtual All-Pairs ===")
for section in (q_report, c_report):
    print(section)
    print("---")
for label, key in (
    ("Edges by case:", 'edges_by_case'),
    ("Edges by gap:", 'edges_by_gap'),
    ("Similarity stats:", 'similarity_stats'),
    ("IoU stats:", 'iou_stats'),
    ("Overlap(prev) stats:", 'overlap_prev_stats'),
    ("Overlap(curr) stats:", 'overlap_curr_stats'),
    ("Centroid distance stats:", 'centroid_distance_stats'),
):
    print(label, metrics[key])
print("High-confidence chains:", metrics['high_conf_chain_count'], "avg len:", f"{metrics['high_conf_avg_length']:.2f}")

# Persist reports; the final call is left as the cell's last expression so the
# notebook still displays the returned path.
prefix = 'virtual_allpairs_enhanced_15oct'
for saver, payload, fname in (
    (enhanced.save_text, q_report, f'{prefix}_quality_report.txt'),
    (enhanced.save_json, q, f'{prefix}_quality_metrics.json'),
    (enhanced.save_text, c_report, f'{prefix}_complexity_report.txt'),
    (enhanced.save_json, c, f'{prefix}_complexity_metrics.json'),
):
    saver(payload, fname)
enhanced.save_json(metrics, f'{prefix}_detailed_matching_metrics.json')

=== Enhanced Virtual All-Pairs ===
# Tree Tracking Quality Assessment Report
Total Trees Detected: 626
Total Tracking Edges: 2
Overall Match Rate: 0.004
Average Chain Length: 1.00
Maximum Chain Length: 2
Match Rates by Orthomosaic Pair:
- 1->2: 1/80 (0.013)
- 2->3: 0/116 (0.000)
- 3->4: 0/130 (0.000)
- 4->5: 1/150 (0.007)

Chain Length Distribution:
- Length 1: 622 trees
- Length 2: 2 trees

Edge selection by case:
- one_to_one: 2 / 3 (0.67)
- containment: 0 / 20 (0.00)
---
# Graph Complexity Report
Nodes: 626
Edges: 2
Avg out-degree: 0.003
Avg in-degree: 0.003
Zero out-degree nodes: 624
Zero in-degree nodes: 624
Weakly connected components: 624 (sizes head: [2, 2, 1, 1, 1, 1, 1, 1, 1, 1])
Strongly connected components: 626 (sizes head: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Average diameter: 0.003
Median diameter: 0.000
Max diameter: 1
---
Edges by case: {'one_to_one': 2}
Edges by gap: {1: 2}
Similarity stats: {'count': 2, 'mean': 0.8355917390249512, 'median': 0.8355917390249512, 'min': 0.83

'../../output/virtual_allpairs_enhanced_15oct_detailed_matching_metrics.json'

In [13]:
# Parameter sweep over enhanced virtual edges to improve recall while preserving precision
from itertools import product

def run_variant(alpha, base_max_dist, overlap_gate, thr_delta, max_gap=2, min_base_similarity=0.35):
    """Build one enhanced virtual-all-pairs variant and score it for the sweep.

    The one_to_one and containment similarity thresholds are shifted by
    ``thr_delta`` relative to the tracker's default case configs.

    Composite score (higher is better):
    ``2 * high_conf_chain_count + overall_match_rate
    - 0.0005 * (zero_in_degree_nodes + zero_out_degree_nodes)``.

    Returns:
        dict with the tracker ``t``, quality/complexity metrics ``q``/``c``,
        detailed metrics ``m``, both text reports, the ``params`` used and the
        composite ``score``.
    """
    t = TreeTrackingGraph()
    # Read the default thresholds from this tracker rather than constructing
    # two extra throwaway TreeTrackingGraph instances just to inspect them.
    cfg = {
        case: {'similarity_threshold': t.case_configs[case].similarity_threshold + thr_delta}
        for case in ('one_to_one', 'containment')
    }
    t.load_data(load_images=False)
    t.build_graph_virtual_allpairs_enhanced_cfg(base_max_dist=base_max_dist, overlap_gate=overlap_gate, min_base_similarity=min_base_similarity, alpha=alpha, max_gap=max_gap, case_overrides=cfg)
    q_report, q = t.quality_report()
    c_report, c = t.complexity_report()
    m = compute_detailed_matching_metrics(t)
    # composite score: prioritize high-conf chains, then match rate, penalize zero-degree nodes
    zero_degree = c.get('zero_in_degree_nodes', 0) + c.get('zero_out_degree_nodes', 0)
    composite = (m['high_conf_chain_count'] * 2.0) + q['overall_match_rate'] - 0.0005 * zero_degree
    params = {
        'alpha': alpha,
        'base_max_dist': base_max_dist,
        'overlap_gate': overlap_gate,
        'thr_delta': thr_delta,
        'max_gap': max_gap,
        'min_base_similarity': min_base_similarity,
    }
    return {'t': t, 'q': q, 'c': c, 'm': m, 'q_report': q_report, 'c_report': c_report, 'params': params, 'score': composite}

# 3 x 3 x 3 x 3 grid: gap decay, distance gate, overlap gate, threshold shift.
alphas = [0.5, 0.8, 1.0]
base_dists = [75.0, 85.0, 95.0]
overgates = [0.45, 0.48, 0.52]
thr_deltas = [-0.05, -0.02, 0.0]

results = []
for a, d, og, td in product(alphas, base_dists, overgates, thr_deltas):
    res = run_variant(a, d, og, td, max_gap=2)
    results.append(res)
    print(
        f"Tried alpha={a}, dist={d}, og={og}, thrΔ={td} -> "
        f"edges={res['c']['num_edges']}, hc={res['m']['high_conf_chain_count']}, "
        f"rate={res['q']['overall_match_rate']:.3f}"
    )

# Highest composite score wins; max() keeps the first variant on ties.
best = max(results, key=lambda r: r['score']) if results else None
if best is None:
    print('No results found in sweep.')
else:
    print("\nBest params:", best['params'])
    print("Score:", best['score'])
    for text in (best['q_report'], '---', best['c_report']):
        print(text)
    # Persist the winning variant's reports and metrics.
    prefix = 'virtual_allpairs_enhanced_sweep_best'
    best_tracker = best['t']
    best_tracker.save_text(best['q_report'], f'{prefix}_quality_report.txt')
    best_tracker.save_json(best['q'], f'{prefix}_quality_metrics.json')
    best_tracker.save_text(best['c_report'], f'{prefix}_complexity_report.txt')
    best_tracker.save_json(best['c'], f'{prefix}_complexity_metrics.json')
    best_tracker.save_json(best['m'], f'{prefix}_detailed_matching_metrics.json')

Tried alpha=0.5, dist=75.0, og=0.45, thrΔ=-0.05 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=75.0, og=0.45, thrΔ=-0.02 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=75.0, og=0.45, thrΔ=0.0 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=75.0, og=0.48, thrΔ=-0.05 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=75.0, og=0.48, thrΔ=-0.02 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=75.0, og=0.48, thrΔ=0.0 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=75.0, og=0.52, thrΔ=-0.05 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=75.0, og=0.52, thrΔ=-0.02 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=75.0, og=0.52, thrΔ=0.0 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=85.0, og=0.45, thrΔ=-0.05 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=85.0, og=0.45, thrΔ=-0.02 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=85.0, og=0.45, thrΔ=0.0 -> edges=2, hc=2, rate=0.004
Tried alpha=0.5, dist=85.0, og=0.48, thrΔ=-0.05 -> edges=2, hc=2, rate=0.004
Tried a

In [14]:
# Visualization helpers: overlay matched pairs across OMs for top chains
import matplotlib.pyplot as plt

def visualize_chains(tracker: "TreeTrackingGraph", chains: List[List[Tuple[int,int]]], k: int = 5, prefix: str = 'virtual_vis'):
    """Save one PNG per chain (longest first) overlaying consecutive crown outlines.

    For each of up to ``k`` chains, a row of subplots is drawn, one per edge:
    the earlier crown's outline in blue, the later one in orange, titled with
    the edge's similarity and IoU. Files go to ``tracker.output_dir`` as
    ``{prefix}_chain_{n}.png`` (n = 1, 2, ...).

    Chains with fewer than two nodes have no edges and are skipped. Polygon
    outlines are drawn from their exterior; multi-part geometries exposing
    ``.geoms`` are drawn part by part; other geometries draw nothing.

    Returns:
        Number of figures saved.
    """
    def _outlines(geom):
        # Exterior (x, y) sequences for a Polygon, or for each part of a multi-part geometry.
        if hasattr(geom, 'exterior'):
            return [geom.exterior.xy]
        return [part.exterior.xy for part in getattr(geom, 'geoms', ()) if hasattr(part, 'exterior')]

    os.makedirs(tracker.output_dir, exist_ok=True)
    shown = 0
    for ch in sorted(chains, key=len, reverse=True):
        if shown >= k:
            break
        if len(ch) < 2:
            continue
        n_edges = len(ch) - 1
        # squeeze=False always yields a 2-D axes array, even for a single edge.
        fig, axes = plt.subplots(1, n_edges, figsize=(4 * n_edges, 4), squeeze=False)
        try:
            for ax, (u, v) in zip(axes[0], zip(ch, ch[1:])):
                for node, color in ((u, 'tab:blue'), (v, 'tab:orange')):
                    # Fall back to an empty line so the legend keeps the entry.
                    outlines = _outlines(tracker.G.nodes[node]['geometry']) or [([], [])]
                    for part_idx, (x, y) in enumerate(outlines):
                        label = f'{node[0]}:{node[1]}' if part_idx == 0 else '_nolegend_'
                        ax.plot(x, y, color=color, label=label)
                e = tracker.G.get_edge_data(u, v) or {}
                ax.set_title(f"{u[0]}→{v[0]} | sim={e.get('similarity',0):.2f} iou={e.get('iou',0):.2f}")
                ax.legend(loc='best', fontsize=8)
                ax.set_aspect('equal', 'box')
            fig.tight_layout()
            out_path = os.path.join(tracker.output_dir, f"{prefix}_chain_{shown + 1}.png")
            fig.savefig(out_path, dpi=150)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
        shown += 1
    return shown

In [15]:
# Reuse the sweep winner from the earlier cell when it exists in this session;
# otherwise rerun the same grid here. Then render the winner's chains,
# preferring high-confidence ones.
try:
    best_result = best  # noqa: F821
except NameError:
    best_result = None

if best_result is None:
    # Same grid as the sweep cell (run_variant must already be defined).
    from itertools import product
    alphas = [0.5, 0.8, 1.0]
    base_dists = [75.0, 85.0, 95.0]
    overgates = [0.45, 0.48, 0.52]
    thr_deltas = [-0.05, -0.02, 0.0]
    results = []
    for a, d, og, td in product(alphas, base_dists, overgates, thr_deltas):
        res = run_variant(a, d, og, td, max_gap=2)
        results.append(res)
        print(
            f"Tried alpha={a}, dist={d}, og={og}, thrΔ={td} -> "
            f"edges={res['c']['num_edges']}, hc={res['m']['high_conf_chain_count']}, "
            f"rate={res['q']['overall_match_rate']:.3f}"
        )
    best_result = max(results, key=lambda r: r['score']) if results else None

if not best_result:
    print('No best result available to visualize.')
else:
    best_tracker = best_result['t']
    print("\n[Best] params:", best_result['params'])
    print("[Best] edges:", best_result['c']['num_edges'], "high-conf chains:", best_result['m']['high_conf_chain_count'])
    chains = best_tracker._extract_all_chains()

    def _edge_info(tr, ch):
        for u, v in zip(ch, ch[1:]):
            yield tr.G.get_edge_data(u, v) or {}

    def _is_high_conf(ch):
        # Every edge must be a strict case with similarity >= 0.8.
        edge_data = list(_edge_info(best_tracker, ch))
        return bool(edge_data) and all(
            e.get('case') in {'one_to_one', 'containment'} and e.get('similarity', 0.0) >= 0.8
            for e in edge_data
        )

    multi_node = [ch for ch in chains if len(ch) >= 2]
    hc = [ch for ch in multi_node if _is_high_conf(ch)]
    # Fall back to any multi-node chain when none are high-confidence.
    selected = hc if hc else multi_node
    n_plotted = visualize_chains(best_tracker, selected, k=6, prefix='virtual_allpairs_enhanced_best')
    print(f"Saved {n_plotted} chain visualizations to {best_result['t'].output_dir}")


[Best] params: {'alpha': 0.5, 'base_max_dist': 75.0, 'overlap_gate': 0.45, 'thr_delta': -0.05, 'max_gap': 2}
[Best] edges: 2 high-conf chains: 2
Saved 2 chain visualizations to ../../output


In [8]:
# Run: shuffled-order aggregation vs virtual all-pairs; compare metrics
from copy import deepcopy
import math

def run_and_metrics(build_fn, *, high_conf_threshold=0.8, high_conf_cases=('one_to_one', 'containment')):
    """Build a fresh tracker with ``build_fn`` and collect its reports.

    A new TreeTrackingGraph is created and its crown data loaded (without
    images) on every call, so each strategy runs on an independent tracker.

    Args:
        build_fn: Callable taking the tracker and building its graph in place.
        high_conf_threshold: Minimum edge ``similarity`` for every edge of a
            high-confidence chain.
        high_conf_cases: Edge ``case`` values allowed in a high-confidence chain.

    Returns:
        dict with the ``tracker``, quality/complexity reports (``q_report``,
        ``c_report``) and metrics (``q``, ``c``), plus the high-confidence
        chain count ``hc_count`` and average length ``hc_avg_len``.
    """
    tracker = TreeTrackingGraph()
    # Data is loaded per call (not shared across runs).
    tracker.load_data(load_images=False)
    build_fn(tracker)
    q_report, q_metrics = tracker.quality_report()
    c_report, c_metrics = tracker.complexity_report()

    # High-confidence chains: every edge in a strict case with high similarity.
    strict_cases = set(high_conf_cases)
    hc = []
    for ch in tracker._extract_all_chains():
        if len(ch) < 2:
            continue
        edges = [tracker.G.get_edge_data(u, v) or {} for u, v in zip(ch, ch[1:])]
        if all(e.get('case') in strict_cases and e.get('similarity', 0.0) >= high_conf_threshold for e in edges):
            hc.append(ch)
    return {
        'tracker': tracker,
        'q_report': q_report,
        'q': q_metrics,
        'c_report': c_report,
        'c': c_metrics,
        'hc_count': len(hc),
        'hc_avg_len': float(np.mean([len(ch) for ch in hc])) if hc else 0.0,
    }

# Strategy A: aggregate matches over several shuffled OM orderings.
def build_shuffle(tracker: TreeTrackingGraph):
    """Build ``tracker``'s graph via shuffle aggregation (8 shuffles, seed 42)."""
    shuffle_kwargs = {
        'num_shuffles': 8,
        'seed': 42,
        'base_max_dist': 75.0,
        'overlap_gate': 0.48,
        'min_base_similarity': 0.35,
        'min_frequency': 0.5,
    }
    tracker.build_graph_from_shuffles(**shuffle_kwargs)

# Strategy B: compare every OM pair directly ("virtual all-pairs").
def build_virtual(tracker: TreeTrackingGraph):
    """Build ``tracker``'s graph with the plain virtual all-pairs builder."""
    virtual_kwargs = {
        'base_max_dist': 75.0,
        'overlap_gate': 0.48,
        'min_base_similarity': 0.35,
    }
    tracker.build_graph_virtual_allpairs(**virtual_kwargs)

# Run both strategies (shuffle first, then virtual) on fresh trackers.
shuffle_out, virtual_out = (
    run_and_metrics(build_shuffle),
    run_and_metrics(build_virtual),
)

def summarize(name, out):
    """Print a five-line summary of one strategy's run_and_metrics() result."""
    quality, complexity = out['q'], out['c']
    lines = [
        f"\n=== {name} ===",
        f"Edges: {complexity['num_edges']} | Match rate: {quality['overall_match_rate']:.3f}",
        f"Avg chain: {quality.get('average_chain_length',0):.2f} | Max chain: {quality.get('max_chain_length',0)}",
        f"High-conf chains: {out['hc_count']} (avg len {out['hc_avg_len']:.2f})",
        f"Zero in/out: {complexity.get('zero_in_degree_nodes',0)}/{complexity.get('zero_out_degree_nodes',0)} | Max diameter: {complexity.get('max_diameter',0)}",
    ]
    print("\n".join(lines))

# Print both summaries, shuffle aggregation first.
for _label, _result in (('Shuffle aggregation', shuffle_out), ('Virtual all-pairs', virtual_out)):
    summarize(_label, _result)

# Lexicographic preference: more high-confidence chains, then higher match
# rate, then fewer zero-degree nodes.
def decide(a, b):
    """Return 'A' if run ``a`` ranks higher, 'B' if ``b`` does, else 'tie'."""
    def _rank(run):
        complexity = run['c']
        zero_degree = complexity.get('zero_in_degree_nodes', 0) + complexity.get('zero_out_degree_nodes', 0)
        return (run['hc_count'], run['q']['overall_match_rate'], -zero_degree)

    rank_a, rank_b = _rank(a), _rank(b)
    if rank_a > rank_b:
        return 'A'
    if rank_b > rank_a:
        return 'B'
    return 'tie'

# Announce the preferred strategy ('A' = shuffle, 'B' = virtual, else tie).
_preference_messages = {
    'A': "\n>>> Preferred: Shuffle aggregation — more consistent high-confidence chains / match quality.",
    'B': "\n>>> Preferred: Virtual all-pairs — better connectivity or confidence under current data.",
}
choice = decide(shuffle_out, virtual_out)
print(_preference_messages.get(choice, "\n>>> Strategies are comparable on current metrics."))


=== Shuffle aggregation ===
Edges: 0 | Match rate: 0.000
Avg chain: 1.00 | Max chain: 1
High-conf chains: 0 (avg len 0.00)
Zero in/out: 626/626 | Max diameter: 0

=== Virtual all-pairs ===
Edges: 3 | Match rate: 0.004
Avg chain: 1.00 | Max chain: 2
High-conf chains: 3 (avg len 2.00)
Zero in/out: 623/623 | Max diameter: 1

>>> Preferred: Virtual all-pairs — better connectivity or confidence under current data.
