# Apple Detection and Yield Estimation from 3D Point Clouds

## ACM SAC 2026 - Modular UAV-Based Framework

[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

This notebook implements a modular framework for automated apple detection, counting, and yield estimation from UAV-derived 3D point clouds.

**Key Results:**
- Regression achieves R² ≈ 0.99 for apple count estimation
- Clustering with MBB validation achieves R² ≈ 0.75
- Processes 260M+ points across 50 orchard sections

### Framework Configurations

| Method | Color Filter | Clustering | Object Approximation | Description |
|--------|-------------|------------|---------------------|-------------|
| **M1** | HSV | DBSCAN | Minimum Bounding Box (MBB) | Accepts clusters whose axis-aligned bounding box is cube-like and whose space diagonal lies within the size window |
| **M2** | HSV | DBSCAN | Sphere | Validates clusters using sphere radius |
| **M3** | ExR-LAB | DBSCAN | Sphere | Uses Excess Red + LAB color space filtering |

### Paper Reference

This code accompanies the paper: *"A Modular UAV-Based Framework for Apple Detection and Yield Extrapolation from 3D Point Clouds"* submitted to ACM SAC 2026.

---

## 1. Configuration and Setup

Run this cell to select which method to execute. With `INTERACTIVE_MODE = True` you will be prompted to choose M1, M2, M3, or all; set it to `False` to skip the prompt and use `DEFAULT_METHOD`.

In [None]:
# ============================================================================
#                          USER CONFIGURATION
# ============================================================================

# ---------------------------------------------------------------------------
# INTERACTIVE METHOD SELECTION
# Set INTERACTIVE_MODE = True to select method via input prompt
# Set INTERACTIVE_MODE = False to use the DEFAULT_METHOD value
# ---------------------------------------------------------------------------

# Switch between prompting on stdin (True) and using DEFAULT_METHOD as-is (False).
INTERACTIVE_MODE = True   # Set to False to skip the prompt and use DEFAULT_METHOD
# Method identifier assigned to METHOD when no prompt is shown.
DEFAULT_METHOD = "M1"     # Used when INTERACTIVE_MODE = False

def get_method_selection():
    """Ask on stdin which detection method to run and return the choice.

    Prints a menu, then keeps prompting until the answer (case-insensitive,
    surrounding whitespace ignored) is one of M1, M2, M3, or all.

    Returns:
        "M1", "M2", "M3", or the lowercase string "all".
    """
    rule = "=" * 62
    menu_lines = (
        "",
        rule,
        "           APPLE DETECTION METHOD SELECTION",
        rule,
        "  M1  : HSV + DBSCAN + MBB (Minimum Bounding Box)",
        "  M2  : HSV + DBSCAN + Sphere",
        "  M3  : ExR-LAB + DBSCAN + Sphere",
        "  all : Run all methods and generate comparison",
        rule,
        "",
    )
    for line in menu_lines:
        print(line)

    # Upper-cased answer -> value handed back to the caller.
    choices = {"M1": "M1", "M2": "M2", "M3": "M3", "ALL": "all"}
    while True:
        answer = input("Enter method (M1/M2/M3/all): ").strip().upper()
        selected = choices.get(answer)
        if selected is None:
            print("Invalid selection. Please enter M1, M2, M3, or all.")
            continue
        print(f"\n>>> Selected: {selected}")
        return selected

# Resolve the method for this run: take the configured default when the
# prompt is disabled, otherwise ask the user.
if not INTERACTIVE_MODE:
    METHOD = DEFAULT_METHOD
    print(f"Using default method: {METHOD}")
else:
    METHOD = get_method_selection()

# --- Input / output locations -------------------------------------------
INPUT_DIR = "./data/las_files"      # Folder scanned for .las point cloud files
GT_CSV = "./data/ground_truth.csv"  # Ground truth CSV: (filename, true_count) or (sample, count)
OUTPUT_BASE_DIR = "./Results"        # Root folder; each method writes to its own subfolder

# --- Regression training ------------------------------------------------
TRAIN_FROM_SCRATCH = False           # True retrains the regression models
PRETRAINED_MODEL_DIR = "./models"    # Where pretrained models are kept

# --- Output switches ----------------------------------------------------
SAVE_VISUALIZATIONS = True           # Write a storyboard PNG per sample
SAVE_INTERMEDIATE_FILES = True       # Write per-sample cluster CSV and filtered-point NPZ files

# Echo the effective configuration in one go.
print("\n".join([
    "\nConfiguration:",
    f"  Method: {METHOD}",
    f"  Input directory: {INPUT_DIR}",
    f"  Ground truth CSV: {GT_CSV}",
    f"  Output directory: {OUTPUT_BASE_DIR}",
    f"  Save visualizations: {SAVE_VISUALIZATIONS}",
    f"  Save intermediate files: {SAVE_INTERMEDIATE_FILES}",
]))

## 2. Dependencies and Imports

In [None]:
# Standard library imports
import os
import sys
import glob
import json
import pickle
import re
import warnings
import math
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, asdict

# Scientific computing
import numpy as np
import pandas as pd

# Point cloud processing
import laspy

# Machine learning and clustering
from sklearn.cluster import DBSCAN
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.neighbors import KNeighborsRegressor

# Color processing
import matplotlib.colors as mcolors
from skimage.color import rgb2lab

# Coordinate transformation
from pyproj import Transformer

# Visualization
import matplotlib
matplotlib.use("Agg")  # Use non-interactive backend for file saving
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Silence RuntimeWarning messages from here on.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Confirm the imports succeeded and report the core library versions.
print("\n".join([
    "All dependencies imported successfully.",
    f"NumPy version: {np.__version__}",
    f"Pandas version: {pd.__version__}",
]))

## 3. Parameter Definitions

All parameters from the paper (Table 1) are defined here.

In [None]:
@dataclass
class HSVParams:
    """Thresholds for the HSV red filter (Table 1 in paper).

    Hue uses a wrap-around window: red sits at both ends of the hue axis,
    so a low and a high bound are stored instead of a single interval.
    """
    # Hue window (wraps through 0/1)
    H_LOW_WRAP: float = 0.96    # lower bound of the high-hue red band
    H_HIGH_WRAP: float = 0.10   # upper bound of the low-hue red band
    # Saturation window
    S_MIN: float = 0.20
    S_MAX: float = 0.80
    # Value (brightness) window
    V_MIN: float = 0.50
    V_MAX: float = 1.00

@dataclass
class ExRLABParams:
    """Thresholds for the two-stage Excess Red + LAB a* filter."""
    # Stage 1: Excess Red index, ExR = 2R - G - B
    RED_THRESHOLD: float = 0.20
    # Stage 2: accepted window on the LAB a* channel
    A_MIN: float = 10.0
    A_MAX: float = 60.0

@dataclass
class DBSCANParams:
    """Settings passed to DBSCAN when clustering the red-filtered points."""
    # Neighborhood radius (meters)
    EPS: float = 0.028
    # Minimum number of samples for a cluster
    MIN_SAMPLES: int = 136

@dataclass
class MBBParams:
    """Acceptance limits for the bounding-box (MBB) cluster validation."""
    # Allowed range for the box's space diagonal (m)
    DIAM_MIN: float = 0.04
    DIAM_MAX: float = 0.20
    # Smallest shortest/longest side ratio still considered cube-like
    CUBE_RATIO_MIN: float = 0.50

@dataclass
class SphereParams:
    """Acceptance limits for the sphere-based cluster validation."""
    # Allowed sphere radius range (m)
    RADIUS_MIN: float = 0.03
    RADIUS_MAX: float = 0.10

# Module-level default parameter sets. The filter and validation functions
# use these when no explicit params object is passed; DBSCAN reads
# dbscan_params directly.
hsv_params, exr_lab_params = HSVParams(), ExRLABParams()
dbscan_params = DBSCANParams()
mbb_params, sphere_params = MBBParams(), SphereParams()

# Lon/lat handling for reprojection:
#   None  -> auto-detect from the coordinate values
#   True  -> always treat XY as lon/lat and reproject
#   False -> never reproject
FORCE_LONLAT = None

# Summarize the active thresholds.
print("\n".join([
    "Parameters initialized (from Table 1 in paper):",
    f"  HSV: H in [0, {hsv_params.H_HIGH_WRAP}] or [{hsv_params.H_LOW_WRAP}, 1]",
    f"       S in [{hsv_params.S_MIN}, {hsv_params.S_MAX}], V in [{hsv_params.V_MIN}, {hsv_params.V_MAX}]",
    f"  DBSCAN: eps={dbscan_params.EPS}m, min_samples={dbscan_params.MIN_SAMPLES}",
    f"  MBB: D in [{mbb_params.DIAM_MIN}, {mbb_params.DIAM_MAX}]m, cube_ratio >= {mbb_params.CUBE_RATIO_MIN}",
    f"  Sphere: R in [{sphere_params.RADIUS_MIN}, {sphere_params.RADIUS_MAX}]m",
]))

## 4. Utility Functions

Core utility functions for file I/O, coordinate transformation, and data preprocessing.

In [None]:
# ============================================================================
#                          FILE I/O FUNCTIONS
# ============================================================================

def read_las_xyzrgb(las_path: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load a LAS file and return its coordinates and normalized colors.

    Args:
        las_path: Path to .las file

    Returns:
        xyz: (N, 3) array of coordinates
        rgb: (N, 3) array of RGB values in [0, 1]; all zeros when the file
             has no red/green/blue attributes
    """
    cloud = laspy.read(las_path)
    xyz = np.vstack((cloud.x, cloud.y, cloud.z)).T

    has_color = hasattr(cloud, "red") and hasattr(cloud, "green") and hasattr(cloud, "blue")
    if not has_color:
        # No color attributes: hand back black points of matching length.
        return xyz, np.zeros((xyz.shape[0], 3), dtype=np.float32)

    red = np.asarray(cloud.red, dtype=np.float32)
    green = np.asarray(cloud.green, dtype=np.float32)
    blue = np.asarray(cloud.blue, dtype=np.float32)

    # Heuristic bit-depth detection: any channel value above 255 means the
    # colors are stored as 16-bit, otherwise they are treated as 8-bit.
    peak = max(red.max(initial=0), green.max(initial=0), blue.max(initial=0))
    scale = 65535.0 if peak > 255 else 255.0

    rgb = np.clip(np.vstack((red, green, blue)).T / scale, 0.0, 1.0)
    return xyz, rgb


def load_ground_truth(gt_csv_path: str) -> Dict[str, float]:
    """
    Load ground truth apple counts from CSV.

    Accepts column formats (header matching is case-insensitive and ignores
    surrounding whitespace):
      - (filename, true_count)
      - (sample, count)

    Rows with a missing sample name, an empty sample name, or a count that
    cannot be parsed as a number are skipped. If a key occurs more than
    once, the last row wins.

    Args:
        gt_csv_path: Path to ground truth CSV

    Returns:
        Dictionary mapping sample keys (lowercase, surrounding whitespace
        stripped, trailing ".ext" removed) to counts. Empty when the file
        does not exist.

    Raises:
        ValueError: If neither accepted column pair is present.
    """
    if not os.path.exists(gt_csv_path):
        print(f"Warning: Ground truth CSV not found: {gt_csv_path}")
        return {}

    df = pd.read_csv(gt_csv_path)
    # normalized header -> header as written in the file
    cols = {str(c).strip().lower(): c for c in df.columns}

    if {"filename", "true_count"}.issubset(cols.keys()):
        name_col, count_col = cols["filename"], cols["true_count"]
    elif {"sample", "count"}.issubset(cols.keys()):
        name_col, count_col = cols["sample"], cols["count"]
    else:
        available_cols = list(df.columns)
        raise ValueError(f"GT CSV must have (filename, true_count) or (sample, count). Found: {available_cols}")

    # Drop nameless rows BEFORE converting to str: astype(str) would turn a
    # missing name into the literal key "nan" and it would survive filtering.
    named = df.dropna(subset=[name_col])
    names = named[name_col].astype(str).str.strip()
    keys = names.str.lower().str.replace(r"\.[^.]+$", "", regex=True)
    counts = pd.to_numeric(named[count_col], errors="coerce")

    usable = counts.notna() & (keys != "")
    gt_map = dict(zip(keys[usable], counts[usable]))
    print(f"Loaded ground truth for {len(gt_map)} samples")
    return gt_map


def setup_output_directories(base_dir: str, method: str) -> Dict[str, str]:
    """
    Build (and create on disk) the per-method output folder layout.

    Args:
        base_dir: Base output directory
        method: Method name, used as the subfolder under base_dir

    Returns:
        Dictionary with the keys "root", "intermediate", "visualizations",
        "metrics" and "models", each mapped to its directory path
    """
    method_root = os.path.join(base_dir, method)

    paths = {"root": method_root}
    for subfolder in ("intermediate", "visualizations", "metrics", "models"):
        paths[subfolder] = os.path.join(method_root, subfolder)

    # exist_ok keeps repeated runs from failing on already-created folders.
    for directory in paths.values():
        os.makedirs(directory, exist_ok=True)

    return paths

print("File I/O functions defined.")

In [None]:
# ============================================================================
#                    COORDINATE TRANSFORMATION FUNCTIONS
# ============================================================================

def looks_like_lonlat(xy: np.ndarray) -> bool:
    """
    Heuristically detect whether XY coordinates are longitude/latitude degrees.

    The coordinates are treated as lon/lat when every value is finite, X lies
    in [-180, 180], Y lies in [-90, 90], and the extent along each axis is
    under 2 units.

    NOTE(review): a small cloud stored in a local metric frame with values in
    these ranges also satisfies the test; use the force_lonlat switch of the
    caller to override auto-detection for such data.

    Args:
        xy: (N, 2) array of X/Y coordinates

    Returns:
        True if the coordinates look like lon/lat degrees, False otherwise
        (including for an empty array).
    """
    # An empty cloud carries no evidence, and min()/max() below would raise
    # on a zero-size array.
    if xy.shape[0] == 0:
        return False

    x, y = xy[:, 0], xy[:, 1]
    if not (np.all(np.isfinite(x)) and np.all(np.isfinite(y))):
        return False

    x_min, x_max = x.min(), x.max()
    y_min, y_max = y.min(), y.max()

    in_range = (x_min >= -180) and (x_max <= 180) and (y_min >= -90) and (y_max <= 90)
    span_ok = (x_max - x_min) < 2.0 and (y_max - y_min) < 2.0
    # Cast so callers get a plain bool, as annotated, not numpy.bool_.
    return bool(in_range and span_ok)


def utm_epsg_from_lonlat(lon: float, lat: float) -> str:
    """
    Determine the WGS 84 / UTM EPSG code for a geographic position.

    Args:
        lon: Longitude in degrees
        lat: Latitude in degrees

    Returns:
        "EPSG:326ZZ" for lat >= 0 (northern hemisphere) or "EPSG:327ZZ"
        otherwise, where ZZ is the two-digit UTM zone number (01-60).
    """
    zone = int(np.floor((lon + 180) / 6) + 1)
    # lon == 180 would give zone 61, which is not a UTM zone; clamp to the
    # valid 1..60 range so the returned code always names a UTM CRS.
    zone = min(max(zone, 1), 60)
    prefix = "EPSG:326" if lat >= 0 else "EPSG:327"
    return f"{prefix}{zone:02d}"


def reproject_xy_to_meters(xy: np.ndarray, force_lonlat: Optional[bool] = None) -> Tuple[np.ndarray, Optional[str]]:
    """
    Convert XY coordinates to a metric UTM CRS when they are lon/lat.

    Args:
        xy: (N, 2) array of X/Y coordinates
        force_lonlat: False never reprojects, True always reprojects, any
            other value lets looks_like_lonlat() decide

    Returns:
        (xy_out, epsg): the reprojected coordinates and the target EPSG
        string, or a copy of the input and None when nothing was reprojected
    """
    # Decide first whether the input is left untouched.
    leave_as_is = force_lonlat is False
    if not leave_as_is and force_lonlat is not True:
        leave_as_is = not looks_like_lonlat(xy)
    if leave_as_is:
        return xy.copy(), None

    longitudes, latitudes = xy[:, 0], xy[:, 1]
    # The UTM zone is picked from the mean position of the cloud.
    target_crs = utm_epsg_from_lonlat(float(np.mean(longitudes)), float(np.mean(latitudes)))

    projector = Transformer.from_crs("EPSG:4326", target_crs, always_xy=True)
    easting, northing = projector.transform(longitudes, latitudes)
    return np.column_stack([easting, northing]), target_crs

print("Coordinate transformation functions defined.")

## 5. Color Filtering Functions

Implementation of HSV and ExR-LAB color filtering methods.

In [None]:
# ============================================================================
#                         COLOR FILTERING FUNCTIONS
# ============================================================================

def hsv_red_mask(rgb: np.ndarray, params: HSVParams = None) -> np.ndarray:
    """
    Select red points by thresholding in HSV space.

    Red straddles the 0/1 seam of the hue axis, so the hue test accepts
    either end of the range; saturation and value must each fall inside
    their configured window.

    Args:
        rgb: (N, 3) array of RGB values
        params: Thresholds to apply; the module-level hsv_params when None

    Returns:
        Boolean mask of length N, True where a point passes all three tests
    """
    cfg = hsv_params if params is None else params

    hue, sat, val = mcolors.rgb_to_hsv(rgb).T

    is_red_hue = (hue >= cfg.H_LOW_WRAP) | (hue <= cfg.H_HIGH_WRAP)
    sat_in_window = (sat >= cfg.S_MIN) & (sat <= cfg.S_MAX)
    val_in_window = (val >= cfg.V_MIN) & (val <= cfg.V_MAX)

    return is_red_hue & sat_in_window & val_in_window


def excess_red_mask(rgb: np.ndarray, params: ExRLABParams = None) -> np.ndarray:
    """
    Select red points with an Excess Red test followed by a LAB a* window.

    Stage 1 keeps points whose ExR = 2*R - G - B exceeds the threshold;
    stage 2 keeps points whose LAB a* value lies inside [A_MIN, A_MAX].

    Args:
        rgb: (N, 3) array of RGB values
        params: Thresholds to apply; the module-level exr_lab_params when None

    Returns:
        Boolean mask of length N, True where a point passes both stages
    """
    cfg = exr_lab_params if params is None else params

    red, green, blue = rgb[:, 0], rgb[:, 1], rgb[:, 2]
    exr = 2.0 * red - green - blue

    # rgb2lab is fed an (N, 1, 3) image; flatten the result back to (N, 3).
    a_star = rgb2lab(rgb.reshape(-1, 1, 3)).reshape(-1, 3)[:, 1]

    passes_exr = exr > cfg.RED_THRESHOLD
    passes_a_star = (a_star >= cfg.A_MIN) & (a_star <= cfg.A_MAX)
    return passes_exr & passes_a_star

print("Color filtering functions defined.")

## 6. Object Approximation Functions

Implementation of MBB and Sphere geometric approximations.

In [None]:
# ============================================================================
#                      OBJECT APPROXIMATION FUNCTIONS
# ============================================================================

def compute_aabb(points: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return (min corner, max corner, side lengths) of the axis-aligned box around points."""
    lower, upper = points.min(axis=0), points.max(axis=0)
    return lower, upper, upper - lower


def validate_mbb(sides: np.ndarray, params: MBBParams = None) -> Tuple[bool, Dict[str, float]]:
    """
    Accept or reject a cluster from its bounding-box side lengths.

    A cluster is accepted when the box is cube-like (shortest/longest side
    ratio at or above CUBE_RATIO_MIN) and its space diagonal lies within
    [DIAM_MIN, DIAM_MAX].

    Args:
        sides: Box side lengths (dx, dy, dz)
        params: Limits to apply; the module-level mbb_params when None

    Returns:
        (is_valid, metrics): the verdict plus a dict with the side lengths,
        space diagonal, cube ratio and the two individual checks
    """
    cfg = mbb_params if params is None else params

    dx, dy, dz = (float(sides[axis]) for axis in range(3))
    shortest, longest = min(dx, dy, dz), max(dx, dy, dz)

    diagonal = np.sqrt(dx**2 + dy**2 + dz**2)
    # A box with no extent at all gets ratio 0 rather than dividing by zero.
    ratio = shortest / longest if longest > 0 else 0

    cubelike = ratio >= cfg.CUBE_RATIO_MIN
    size_ok = cfg.DIAM_MIN <= diagonal <= cfg.DIAM_MAX

    report = {
        "dx": dx, "dy": dy, "dz": dz,
        "space_diagonal": diagonal,
        "cube_ratio": ratio,
        "is_cubelike": cubelike,
        "in_size_window": size_ok,
    }
    return cubelike and size_ok, report


def validate_sphere(sides: np.ndarray, params: SphereParams = None) -> Tuple[bool, Dict[str, float]]:
    """
    Accept or reject a cluster using a sphere fitted to its shortest side.

    The sphere diameter is the smallest bounding-box side; the cluster is
    accepted when the resulting radius lies within [RADIUS_MIN, RADIUS_MAX].

    Args:
        sides: Box side lengths
        params: Limits to apply; the module-level sphere_params when None

    Returns:
        (is_valid, metrics): the verdict plus a dict with the sphere radius,
        diameter, volume and the size check
    """
    cfg = sphere_params if params is None else params

    diameter = float(np.min(sides))
    radius = diameter / 2.0
    size_ok = cfg.RADIUS_MIN <= radius <= cfg.RADIUS_MAX

    report = {
        "radius": radius,
        "diameter": diameter,
        "volume": (4.0 / 3.0) * np.pi * (radius ** 3),
        "in_size_window": size_ok,
    }
    return size_ok, report


def estimate_ellipsoid_volume(sides: np.ndarray, alpha: float = 1.15) -> float:
    """
    Estimate a cluster's volume with an ellipsoid of semi-axes (a, a, c).

    a is half the shorter horizontal side; c is alpha * a, capped at half
    the vertical side.

    Args:
        sides: Box side lengths (dx, dy, dz)
        alpha: Ratio of the vertical to the horizontal semi-axis before capping

    Returns:
        Ellipsoid volume, (4/3) * pi * a^2 * c
    """
    dx, dy, dz = (float(sides[axis]) for axis in range(3))
    semi_horizontal = min(dx, dy) / 2.0
    semi_vertical = min(alpha * semi_horizontal, dz / 2.0)
    return (4.0 / 3.0) * np.pi * (semi_horizontal ** 2) * semi_vertical

print("Object approximation functions defined.")

## 7. Visualization Functions

Functions to generate storyboard visualizations of detection results.

In [None]:
# ============================================================================
#                         VISUALIZATION FUNCTIONS
# ============================================================================

def generate_storyboard(sample_id: str, xyz: np.ndarray, rgb: np.ndarray,
                        red_xyz: np.ndarray, labels: np.ndarray,
                        accepted_mask: np.ndarray, output_path: str,
                        method_name: str, ground_truth: Optional[float] = None):
    """
    Generate a 4-panel storyboard visualization and save it to output_path.

    Panels:
    1. Original point cloud (RGB)
    2. Red-filtered points
    3. DBSCAN clusters (colored by label)
    4. Accepted clusters only (valid apples)

    Args:
        sample_id: Sample name shown in the figure title
        xyz: (N, 3) coordinates of the full cloud
        rgb: (N, 3) colors of the full cloud, used as scatter colors
        red_xyz: (M, 3) coordinates of the color-filtered points
        labels: Cluster label per red point; negative labels are noise
        accepted_mask: One flag per non-noise cluster, in ascending label
            order; may be empty
        output_path: Destination image path
        method_name: Method name shown in the figure title
        ground_truth: Optional true count appended to the panel 4 title
    """
    max_points = 50000  # per-panel cap to keep plotting fast

    # Empty inputs can arrive as float arrays (np.array([])). A float array
    # cannot be used as an index mask, so coerce to bool before indexing.
    labels = np.asarray(labels)
    accepted_mask = np.asarray(accepted_mask).astype(bool)

    unique_labels = np.unique(labels[labels >= 0])
    n_clusters = len(unique_labels)
    n_accepted = int(accepted_mask.sum())

    # The mask is only usable when it covers every cluster.
    if len(accepted_mask) >= n_clusters:
        accepted_labels = unique_labels[accepted_mask[:n_clusters]]
    else:
        accepted_labels = unique_labels[:0]

    fig = plt.figure(figsize=(16, 12))
    try:
        # Panel 1: Original point cloud (randomly subsampled when large)
        if xyz.shape[0] > max_points:
            idx = np.random.choice(xyz.shape[0], max_points, replace=False)
            xyz_sub, rgb_sub = xyz[idx], rgb[idx]
        else:
            xyz_sub, rgb_sub = xyz, rgb
        ax1 = fig.add_subplot(2, 2, 1, projection='3d')
        ax1.scatter(xyz_sub[:, 0], xyz_sub[:, 1], xyz_sub[:, 2],
                    c=rgb_sub, s=0.1, alpha=0.5)
        ax1.set_title(f"1. Original Point Cloud (N={xyz.shape[0]:,})")

        # Panel 2: Red-filtered points
        ax2 = fig.add_subplot(2, 2, 2, projection='3d')
        if red_xyz.shape[0] > 0:
            if red_xyz.shape[0] > max_points:
                pick = np.random.choice(red_xyz.shape[0], max_points, replace=False)
                red_sub = red_xyz[pick]
            else:
                red_sub = red_xyz
            ax2.scatter(red_sub[:, 0], red_sub[:, 1], red_sub[:, 2],
                        c='red', s=0.5, alpha=0.6)
        ax2.set_title(f"2. Red-Filtered Points (N={red_xyz.shape[0]:,})")

        # Panel 3: DBSCAN clusters, one color per label
        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        if n_clusters > 0:
            colors = plt.cm.tab20(np.linspace(0, 1, n_clusters))
            for i, lab in enumerate(unique_labels):
                pts = red_xyz[labels == lab]
                if pts.shape[0] > 0:
                    ax3.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                                c=[colors[i % len(colors)]], s=1, alpha=0.7)
        ax3.set_title(f"3. DBSCAN Clusters (N={n_clusters})")

        # Panel 4: Accepted clusters
        ax4 = fig.add_subplot(2, 2, 4, projection='3d')
        for lab in accepted_labels:
            pts = red_xyz[labels == lab]
            if pts.shape[0] > 0:
                ax4.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                            c='green', s=2, alpha=0.8)
        # Test against None/NaN rather than truthiness so a true count of 0
        # is still shown.
        has_gt = ground_truth is not None and np.isfinite(ground_truth)
        gt_str = f", GT={int(ground_truth)}" if has_gt else ""
        ax4.set_title(f"4. Accepted Apples (N={n_accepted}{gt_str})")

        for ax in (ax1, ax2, ax3, ax4):
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')

        fig.suptitle(f"{sample_id} - {method_name} Detection Pipeline", fontsize=14, fontweight='bold')
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


def generate_comparison_plot(all_results: Dict, output_dir: str):
    """
    Generate comparison bar charts and scatter plots for all methods.

    Writes two images into output_dir:
      - method_comparison.png: RMSE, R² and mean relative error per method
      - predicted_vs_actual.png: one predicted-vs-ground-truth panel per method

    Args:
        all_results: Maps method name to a dict with "metrics" (metric name
            to value) and "results" (objects exposing detected_count and
            ground_truth). Any number of methods is supported.
        output_dir: Existing directory that receives the two images
    """
    methods = list(all_results.keys())
    if not methods:
        print("No method results to plot; skipping comparison figures.")
        return

    # Cycle the palette so more than three methods still get a color.
    palette = ['#2ecc71', '#3498db', '#e74c3c']
    method_colors = [palette[i % len(palette)] for i in range(len(methods))]

    def metric_values(key):
        return [all_results[m]["metrics"].get(key, 0) for m in methods]

    # Metrics comparison bar chart
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    try:
        axes[0].bar(methods, metric_values("rmse"), color=method_colors)
        axes[0].set_title("RMSE (lower is better)")
        axes[0].set_ylabel("RMSE")

        r2_vals = metric_values("r2")
        axes[1].bar(methods, r2_vals, color=method_colors)
        axes[1].set_title("R² Score (higher is better)")
        axes[1].set_ylabel("R²")
        # R² can be negative; extend the axis downwards so such bars stay
        # visible instead of being clipped at 0.
        finite_r2 = [v for v in r2_vals if np.isfinite(v)]
        axes[1].set_ylim(min([0.0] + finite_r2), 1)

        axes[2].bar(methods, metric_values("mean_rel_error"), color=method_colors)
        axes[2].set_title("Mean Relative Error (lower is better)")
        axes[2].set_ylabel("MRE")

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "method_comparison.png"), dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)

    # Predicted vs Actual scatter plots: one panel per method. squeeze=False
    # keeps `axes` two-dimensional even for a single method.
    n_panels = len(methods)
    fig, axes = plt.subplots(1, n_panels, figsize=(max(14, 4.5 * n_panels), 4), squeeze=False)
    try:
        for ax, method, color in zip(axes[0], methods, method_colors):
            results = all_results[method]["results"]
            valid = [(r.detected_count, r.ground_truth) for r in results
                     if r.ground_truth is not None and np.isfinite(r.ground_truth)]
            if not valid:
                continue
            pred, actual = zip(*valid)
            ax.scatter(actual, pred, c=color, alpha=0.6, edgecolors='black', linewidth=0.5)
            max_val = max(max(pred), max(actual)) * 1.1
            ax.plot([0, max_val], [0, max_val], 'k--', alpha=0.5, label='Perfect')
            ax.set_xlabel('Ground Truth')
            ax.set_ylabel('Predicted')
            ax.set_title(f"{method} (R²={all_results[method]['metrics'].get('r2', 0):.3f})")
            ax.legend()

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "predicted_vs_actual.png"), dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)

print("Visualization functions defined.")

## 8. Core Detection Pipeline

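Main per-file detection pipeline: load the point cloud, apply color filtering, cluster with DBSCAN, and validate clusters geometrically. Saving of intermediate files is handled by the method runner in the next section.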

In [None]:
# ============================================================================
#                         CORE DETECTION PIPELINE
# ============================================================================

@dataclass
class DetectionResult:
    """Per-file outcome of the detection pipeline."""
    # --- Identification and counts ---
    sample_id: str
    n_total_points: int
    n_red_points: int
    n_clusters: int
    n_accepted_clusters: int
    detected_count: int
    # --- Reference value and per-cluster details ---
    ground_truth: Optional[float]
    cluster_metrics: List[Dict]
    total_volume: float
    # --- Optional arrays kept for visualization / intermediate files ---
    xyz: Optional[np.ndarray] = None
    rgb: Optional[np.ndarray] = None
    red_xyz: Optional[np.ndarray] = None
    labels: Optional[np.ndarray] = None
    accepted_mask: Optional[np.ndarray] = None


def detect_apples_single_file(
    las_path: str,
    color_method: str = "HSV",
    validation_method: str = "MBB",
    gt_map: Optional[Dict[str, float]] = None,
    keep_arrays: bool = False
) -> DetectionResult:
    """
    Run apple detection pipeline on a single LAS file.

    Pipeline:
    1. Load LAS file (XYZ + RGB)
    2. Reproject to meters if needed
    3. Apply color filtering (HSV or ExR-LAB)
    4. Run DBSCAN clustering
    5. Validate clusters using geometric approximation (MBB or Sphere)
    6. Count accepted clusters as detected apples

    Args:
        las_path: Path to the .las file; its lowercase basename without
            extension becomes the sample id
        color_method: "HSV" (case-insensitive) selects the HSV filter; any
            other value selects the ExR-LAB filter
        validation_method: "MBB" (case-insensitive) selects bounding-box
            validation; any other value selects sphere validation
        gt_map: Optional mapping from sample id to ground truth count
        keep_arrays: When True the result also carries the point arrays,
            cluster labels and the per-cluster accepted mask

    Returns:
        DetectionResult for the file. With keep_arrays=True, accepted_mask
        is a boolean array with one entry per non-noise cluster in ascending
        label order, and labels is an integer array (both empty when there
        is nothing to cluster).
    """
    sample_id = os.path.splitext(os.path.basename(las_path))[0].lower()
    ground_truth = gt_map.get(sample_id) if gt_map else None

    # Load point cloud
    xyz, rgb = read_las_xyzrgb(las_path)
    n_total_points = xyz.shape[0]

    if n_total_points > 0:
        # Reproject to meters (Z is passed through unchanged)
        xy_m, _ = reproject_xy_to_meters(xyz[:, :2], FORCE_LONLAT)
        xyz_m = np.column_stack([xy_m, xyz[:, 2]])

        # Apply color filtering
        if color_method.upper() == "HSV":
            red_mask = hsv_red_mask(rgb)
        else:
            red_mask = excess_red_mask(rgb)
        red_xyz = xyz_m[red_mask]
    else:
        # Empty file: skip reprojection and filtering, which need at least
        # one point, and fall through to an all-zero result.
        xyz_m = xyz
        red_xyz = xyz

    n_red_points = red_xyz.shape[0]

    # Defaults describe "nothing detected"; typed empties keep downstream
    # boolean indexing valid (np.array([]) alone would be float64).
    labels = np.array([], dtype=int)
    cluster_metrics = []
    accepted_list = []
    n_clusters = 0
    detected_count = 0
    total_volume = 0.0

    if n_red_points > 0:
        # DBSCAN clustering; label -1 marks noise
        labels = DBSCAN(eps=dbscan_params.EPS, min_samples=dbscan_params.MIN_SAMPLES).fit_predict(red_xyz)
        unique_labels = [lab for lab in np.unique(labels) if lab != -1]
        n_clusters = len(unique_labels)
        use_mbb = validation_method.upper() == "MBB"

        # Validate clusters; accepted_list gets exactly one entry per label
        for lab in unique_labels:
            pts = red_xyz[labels == lab]
            if pts.shape[0] < 2:
                accepted_list.append(False)
                continue

            _, _, sides = compute_aabb(pts)
            centroid = pts.mean(axis=0)

            if use_mbb:
                is_valid, metrics = validate_mbb(sides)
            else:
                is_valid, metrics = validate_sphere(sides)
            is_valid = bool(is_valid)

            metrics["cluster_id"] = int(lab)
            metrics["n_points"] = pts.shape[0]
            metrics["centroid_x"] = float(centroid[0])
            metrics["centroid_y"] = float(centroid[1])
            metrics["centroid_z"] = float(centroid[2])
            metrics["accepted"] = is_valid
            cluster_metrics.append(metrics)
            accepted_list.append(is_valid)

            if is_valid:
                detected_count += 1
                total_volume += estimate_ellipsoid_volume(sides)

    return DetectionResult(
        sample_id=sample_id, n_total_points=n_total_points, n_red_points=n_red_points,
        n_clusters=n_clusters, n_accepted_clusters=detected_count, detected_count=detected_count,
        ground_truth=ground_truth,
        cluster_metrics=cluster_metrics, total_volume=total_volume,
        xyz=xyz_m if keep_arrays else None, rgb=rgb if keep_arrays else None,
        red_xyz=red_xyz if keep_arrays else None, labels=labels if keep_arrays else None,
        accepted_mask=np.array(accepted_list, dtype=bool) if keep_arrays else None
    )

print("Core detection pipeline defined.")

## 9. Method Runner Functions

Functions to run M1, M2, M3, and combined analysis with full output saving.

In [None]:
# ============================================================================
#                          METHOD RUNNER FUNCTIONS
# ============================================================================

def compute_metrics(results: List[DetectionResult]) -> Dict[str, float]:
    """
    Score detected counts against ground truth.

    Only results with a finite, non-None ground truth are scored. R² is NaN
    when the ground truth values have zero variance; the mean relative error
    uses only samples whose ground truth is positive.

    Returns:
        Dict with "rmse", "mae", "r2", "mean_rel_error" and "n_samples"
        (all NaN with n_samples 0 when nothing can be scored)
    """
    scored = [
        res for res in results
        if res.ground_truth is not None and np.isfinite(res.ground_truth)
    ]
    if len(scored) == 0:
        return {"rmse": np.nan, "mae": np.nan, "r2": np.nan, "mean_rel_error": np.nan, "n_samples": 0}

    truth = np.array([res.ground_truth for res in scored])
    predicted = np.array([res.detected_count for res in scored])

    summary = {
        "rmse": float(np.sqrt(mean_squared_error(truth, predicted))),
        "mae": float(np.mean(np.abs(predicted - truth))),
        "r2": float(r2_score(truth, predicted)) if np.var(truth) > 0 else np.nan,
    }

    # Relative error is undefined for a zero ground truth, so skip those.
    relative = []
    for res in scored:
        if res.ground_truth > 0:
            relative.append(abs(res.detected_count - res.ground_truth) / res.ground_truth)
    summary["mean_rel_error"] = float(np.mean(relative)) if relative else np.nan

    summary["n_samples"] = len(scored)
    return summary


def save_intermediate_files(result: DetectionResult, paths: Dict[str, str], method_name: str):
    """
    Write the per-sample intermediate artifacts into paths["intermediate"].

    Produces "<sample>_clusters.csv" when cluster metrics exist and
    "<sample>_filtered.npz" (red points plus their labels) when red points
    are available. method_name is accepted but not used here.
    """
    stem = result.sample_id
    out_dir = paths["intermediate"]

    # Per-cluster metrics table
    if result.cluster_metrics:
        clusters_csv = os.path.join(out_dir, f"{stem}_clusters.csv")
        pd.DataFrame(result.cluster_metrics).to_csv(clusters_csv, index=False)

    # Filtered points as a compressed NPZ archive
    red_points = result.red_xyz
    if red_points is None or red_points.shape[0] == 0:
        return

    point_labels = np.array([]) if result.labels is None else result.labels
    np.savez_compressed(
        os.path.join(out_dir, f"{stem}_filtered.npz"),
        red_xyz=red_points,
        labels=point_labels,
    )


def run_method(method_name: str, input_dir: str, gt_csv: str, output_base: str,
               save_viz: bool = True, save_intermediate: bool = True) -> Tuple[List[DetectionResult], Dict]:
    """
    Run a specific detection method on all LAS files.
    Saves results to metrics/, intermediate/, and visualizations/ folders.

    Args:
        method_name: "M1", "M2" or "M3"
        input_dir: Directory searched for "*.las" files
        gt_csv: Ground truth CSV path; a missing file means no ground truth
        output_base: Base output directory; outputs go to <output_base>/<method_name>
        save_viz: Write a storyboard image per sample
        save_intermediate: Write per-sample intermediate files

    Returns:
        (results, metrics): one DetectionResult per file (large arrays
        cleared) and the metrics dict; ([], {}) when no LAS files are found.

    Raises:
        ValueError: If method_name is not M1, M2 or M3.
    """
    method_config = {
        "M1": {"color": "HSV", "validation": "MBB"},
        "M2": {"color": "HSV", "validation": "Sphere"},
        "M3": {"color": "ExR-LAB", "validation": "Sphere"},
    }

    if method_name not in method_config:
        raise ValueError(f"Unknown method: {method_name}. Choose from M1, M2, M3.")

    config = method_config[method_name]
    print(f"\n{'='*60}")
    print(f"Running {method_name}: {config['color']} + DBSCAN + {config['validation']}")
    print(f"{'='*60}")

    paths = setup_output_directories(output_base, method_name)
    # load_ground_truth returns {} (and warns) for a missing file, so no
    # separate existence check is needed here.
    gt_map = load_ground_truth(gt_csv)

    las_files = sorted(glob.glob(os.path.join(input_dir, "*.las")))
    if not las_files:
        print(f"No LAS files found in {input_dir}")
        return [], {}

    n_files = len(las_files)
    print(f"Found {n_files} LAS files")
    # Arrays are only needed when something downstream consumes them.
    keep_arrays = save_viz or save_intermediate

    results = []
    for i, las_path in enumerate(las_files, start=1):
        result = detect_apples_single_file(
            las_path, color_method=config["color"], validation_method=config["validation"],
            gt_map=gt_map, keep_arrays=keep_arrays
        )
        results.append(result)

        # Explicit None/NaN test so a true count of 0 is still reported.
        has_gt = result.ground_truth is not None and np.isfinite(result.ground_truth)
        gt_str = f"GT={int(result.ground_truth)}" if has_gt else "GT=N/A"
        print(f"  [{i:3d}/{n_files}] {result.sample_id}: detected={result.detected_count}, {gt_str}")

        # Save intermediate files
        if save_intermediate and result.red_xyz is not None:
            save_intermediate_files(result, paths, method_name)

        # Generate visualization
        if save_viz and result.xyz is not None and result.red_xyz is not None:
            viz_path = os.path.join(paths["visualizations"], f"{result.sample_id}_storyboard.png")
            try:
                generate_storyboard(
                    result.sample_id, result.xyz, result.rgb, result.red_xyz,
                    result.labels if result.labels is not None else np.array([]),
                    result.accepted_mask if result.accepted_mask is not None else np.array([]),
                    viz_path, method_name, result.ground_truth
                )
            except Exception as e:
                # Best effort: a failed plot must not abort the whole batch.
                print(f"    Warning: Could not generate visualization: {e}")

        # Clear arrays to save memory
        result.xyz = None
        result.rgb = None
        result.red_xyz = None
        result.labels = None

    # Compute metrics
    metrics = compute_metrics(results)

    print(f"\n{method_name} Results:")
    print(f"  RMSE: {metrics['rmse']:.3f}")
    print(f"  MAE: {metrics['mae']:.3f}")
    print(f"  R2: {metrics['r2']:.4f}")
    print(f"  Mean Relative Error: {metrics['mean_rel_error']:.3f}")

    # Save results CSV
    results_df = pd.DataFrame([{
        "sample": r.sample_id, "predicted": r.detected_count, "ground_truth": r.ground_truth,
        "n_red_points": r.n_red_points, "n_clusters": r.n_clusters,
        "n_accepted": r.n_accepted_clusters, "total_volume_m3": r.total_volume,
    } for r in results])
    results_csv = os.path.join(paths["metrics"], f"{method_name}_results.csv")
    results_df.to_csv(results_csv, index=False)

    # Save summary CSV
    summary_df = pd.DataFrame([metrics])
    summary_csv = os.path.join(paths["metrics"], f"{method_name}_summary.csv")
    summary_df.to_csv(summary_csv, index=False)

    # Save metadata JSON. All parameter sets are recorded, including the
    # ExR-LAB thresholds that M3 depends on, so a run can be reproduced.
    metadata = {
        "method": method_name, "color_filter": config["color"], "validation": config["validation"],
        "timestamp": datetime.now().isoformat(), "n_files": n_files,
        "parameters": {
            "dbscan_eps": dbscan_params.EPS, "dbscan_min_samples": dbscan_params.MIN_SAMPLES,
            "hsv_params": asdict(hsv_params), "exr_lab_params": asdict(exr_lab_params),
            "mbb_params": asdict(mbb_params), "sphere_params": asdict(sphere_params)
        }
    }
    with open(os.path.join(paths["metrics"], f"{method_name}_metadata.json"), 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nSaved outputs to: {paths['root']}")
    print(f"  - metrics/{method_name}_results.csv")
    print(f"  - metrics/{method_name}_summary.csv")
    if save_intermediate:
        print(f"  - intermediate/ ({n_files} files)")
    if save_viz:
        print(f"  - visualizations/ ({n_files} storyboards)")

    return results, metrics


def run_M1(input_dir=INPUT_DIR, gt_csv=GT_CSV, output_base=OUTPUT_BASE_DIR):
    """Execute configuration M1 (HSV filter, DBSCAN clustering, MBB validation).

    Thin wrapper that forwards to ``run_method`` with the ``"M1"`` label and
    the module-level ``SAVE_VISUALIZATIONS`` / ``SAVE_INTERMEDIATE_FILES``
    flags.

    Args:
        input_dir: Forwarded to ``run_method``; defaults to ``INPUT_DIR``.
        gt_csv: Forwarded to ``run_method``; defaults to ``GT_CSV``.
        output_base: Forwarded to ``run_method``; defaults to
            ``OUTPUT_BASE_DIR``.

    Returns:
        Whatever ``run_method`` returns (a ``(results, metrics)`` pair).

    Note:
        The default values are captured when this function is defined, so
        later reassignment of the configuration globals does not change them.
    """
    outcome = run_method(
        "M1",
        input_dir,
        gt_csv,
        output_base,
        SAVE_VISUALIZATIONS,
        SAVE_INTERMEDIATE_FILES,
    )
    return outcome

def run_M2(input_dir=INPUT_DIR, gt_csv=GT_CSV, output_base=OUTPUT_BASE_DIR):
    """Execute configuration M2 (HSV filter, DBSCAN clustering, sphere validation).

    Thin wrapper that forwards to ``run_method`` with the ``"M2"`` label and
    the module-level ``SAVE_VISUALIZATIONS`` / ``SAVE_INTERMEDIATE_FILES``
    flags.

    Args:
        input_dir: Forwarded to ``run_method``; defaults to ``INPUT_DIR``.
        gt_csv: Forwarded to ``run_method``; defaults to ``GT_CSV``.
        output_base: Forwarded to ``run_method``; defaults to
            ``OUTPUT_BASE_DIR``.

    Returns:
        Whatever ``run_method`` returns (a ``(results, metrics)`` pair).

    Note:
        The default values are captured when this function is defined, so
        later reassignment of the configuration globals does not change them.
    """
    outcome = run_method(
        "M2",
        input_dir,
        gt_csv,
        output_base,
        SAVE_VISUALIZATIONS,
        SAVE_INTERMEDIATE_FILES,
    )
    return outcome

def run_M3(input_dir=INPUT_DIR, gt_csv=GT_CSV, output_base=OUTPUT_BASE_DIR):
    """Execute configuration M3 (ExR-LAB filter, DBSCAN clustering, sphere validation).

    Thin wrapper that forwards to ``run_method`` with the ``"M3"`` label and
    the module-level ``SAVE_VISUALIZATIONS`` / ``SAVE_INTERMEDIATE_FILES``
    flags.

    Args:
        input_dir: Forwarded to ``run_method``; defaults to ``INPUT_DIR``.
        gt_csv: Forwarded to ``run_method``; defaults to ``GT_CSV``.
        output_base: Forwarded to ``run_method``; defaults to
            ``OUTPUT_BASE_DIR``.

    Returns:
        Whatever ``run_method`` returns (a ``(results, metrics)`` pair).

    Note:
        The default values are captured when this function is defined, so
        later reassignment of the configuration globals does not change them.
    """
    outcome = run_method(
        "M3",
        input_dir,
        gt_csv,
        output_base,
        SAVE_VISUALIZATIONS,
        SAVE_INTERMEDIATE_FILES,
    )
    return outcome

def run_all(input_dir=INPUT_DIR, gt_csv=GT_CSV, output_base=OUTPUT_BASE_DIR):
    """Execute M1, M2 and M3 in sequence and write a cross-method comparison.

    Every method goes through ``run_method`` with the module-level
    ``SAVE_VISUALIZATIONS`` / ``SAVE_INTERMEDIATE_FILES`` flags. Afterwards a
    ``Combined`` folder is created under ``output_base``; it receives
    ``method_comparison.csv`` and is handed to ``generate_comparison_plot``
    together with the collected results. A comparison table is also printed.

    Args:
        input_dir: Forwarded to ``run_method``; defaults to ``INPUT_DIR``.
        gt_csv: Forwarded to ``run_method``; defaults to ``GT_CSV``.
        output_base: Root folder for outputs; defaults to ``OUTPUT_BASE_DIR``.

    Returns:
        Dict keyed by method name; each value is a dict with the entries
        ``"results"`` and ``"metrics"`` as returned by ``run_method``.
    """
    collected = {}
    for name in ("M1", "M2", "M3"):
        per_sample, scores = run_method(
            name,
            input_dir,
            gt_csv,
            output_base,
            SAVE_VISUALIZATIONS,
            SAVE_INTERMEDIATE_FILES,
        )
        collected[name] = {"results": per_sample, "metrics": scores}

    # Shared folder for artefacts that span all methods.
    combined_dir = os.path.join(output_base, "Combined")
    os.makedirs(combined_dir, exist_ok=True)

    # One table row per method: the method name followed by its metrics.
    rows = []
    for name, payload in collected.items():
        row = {"Method": name}
        row.update(payload["metrics"])
        rows.append(row)
    comparison_df = pd.DataFrame(rows)
    comparison_csv = os.path.join(combined_dir, "method_comparison.csv")
    comparison_df.to_csv(comparison_csv, index=False)

    generate_comparison_plot(collected, combined_dir)

    rule = "=" * 60
    print(f"\n{rule}")
    print("Method Comparison Summary")
    print(rule)
    print(comparison_df.to_string(index=False))
    print(f"\nSaved to: {combined_dir}")

    return collected

# Status message printed once the method runner definitions have been executed.
print("Method runner functions defined.")

## 10. Regression Training (Optional)

Train regression models to predict apple count from per-sample clustering features (red-point count, cluster count, accepted-cluster count, and total volume).

In [None]:
# ============================================================================
#                       REGRESSION MODEL TRAINING
# ============================================================================

def train_regression_models(results: List[DetectionResult], output_dir: str) -> Dict:
    """
    Train regression models to predict apple count from clustering features.

    Features per sample: n_red_points, n_clusters, n_accepted_clusters,
    total_volume. The target is the sample's ground_truth. Samples whose
    ground_truth is None or non-finite are ignored.

    The usable samples are split 70/30 (fixed random_state=42). Four
    regressors (decision tree, random forest, gradient boosting, KNN) are
    fitted on the training part and scored on the held-out part. Each fitted
    model is pickled to ``output_dir`` as ``regression_<name>.pkl``.

    Args:
        results: Detection results providing the feature attributes and
            ground_truth.
        output_dir: Folder that receives the pickled models. It is created
            if it does not exist yet.

    Returns:
        Dict keyed by model name; each value holds "model", "r2", "rmse" and
        "path". An empty dict is returned (and nothing is written) when fewer
        than 5 usable samples are available.
    """
    valid = [r for r in results if r.ground_truth is not None and np.isfinite(r.ground_truth)]

    if len(valid) < 5:
        print("Not enough samples for training (need at least 5)")
        return {}

    # Models are written below; make sure the target folder exists so that a
    # direct call with a fresh path does not fail with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    # Prepare features
    X = np.array([[r.n_red_points, r.n_clusters, r.n_accepted_clusters, r.total_volume] for r in valid])
    y = np.array([r.ground_truth for r in valid])

    # Train/test split (fixed seed keeps the split reproducible)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # Models to train
    models = {
        "DecisionTree": DecisionTreeRegressor(random_state=42),
        "RandomForest": RandomForestRegressor(n_estimators=100, random_state=42),
        "GradientBoosting": GradientBoostingRegressor(n_estimators=100, random_state=42),
        # KNN cannot use more neighbours than there are training samples.
        "KNN": KNeighborsRegressor(n_neighbors=min(5, len(X_train)))
    }

    trained = {}
    print("\nTraining regression models:")

    for name, model in models.items():
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        r2 = r2_score(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))

        print(f"  {name}: R²={r2:.4f}, RMSE={rmse:.3f}")

        # Save model
        model_path = os.path.join(output_dir, f"regression_{name.lower()}.pkl")
        with open(model_path, 'wb') as f:
            pickle.dump(model, f)

        trained[name] = {"model": model, "r2": r2, "rmse": rmse, "path": model_path}

    print(f"\nModels saved to: {output_dir}")
    return trained

# Status message printed once the regression training definition has been executed.
print("Regression training functions defined.")

## 11. Main Execution

Run the selected method based on the configuration in Cell 1.

In [None]:
# ============================================================================
#                            MAIN EXECUTION
# ============================================================================

def main():
    """Run the pipeline for the configured METHOD and optionally train regressors.

    Steps:
      1. Print a header naming the selected method.
      2. Abort with a message if INPUT_DIR does not exist.
      3. Run M1, M2, M3 or all methods (METHOD is matched case-insensitively);
         abort with a message for any other value.
      4. If TRAIN_FROM_SCRATCH is set and results are available, train the
         regression models and store them under ``<method>/models``. For
         "all", the M1 results and the M1 folder are used.
      5. Print a completion footer.

    Returns:
        None.
    """
    header_rule = "#" * 60
    print(f"\n{header_rule}")
    print("# Apple Detection Framework - ACM SAC 2026")
    print(f"# Selected Method: {METHOD}")
    print(f"{header_rule}\n")

    if not os.path.exists(INPUT_DIR):
        print(f"ERROR: Input directory does not exist: {INPUT_DIR}")
        print("Please update INPUT_DIR in Cell 1 or create the directory.")
        return

    choice = METHOD.upper()
    single_runners = {"M1": run_M1, "M2": run_M2, "M3": run_M3}

    if choice in single_runners:
        results, _ = single_runners[choice]()
        train_folder = choice
    elif choice == "ALL":
        all_results = run_all()
        # Regression training for the combined run is based on M1's results.
        results = all_results.get("M1", {}).get("results", [])
        train_folder = "M1"
    else:
        print(f"Unknown method: {METHOD}")
        print("Valid options: M1, M2, M3, all")
        return

    # Optional: Train regression models
    if TRAIN_FROM_SCRATCH and results:
        models_dir = os.path.join(OUTPUT_BASE_DIR, train_folder, "models")
        os.makedirs(models_dir, exist_ok=True)
        train_regression_models(results, models_dir)

    footer_rule = "=" * 60
    print(f"\n{footer_rule}")
    print("Processing complete!")
    print(f"Results saved to: {OUTPUT_BASE_DIR}")
    print(footer_rule)

# Run the pipeline immediately when this cell is executed.
main()

---

## Quick Verification

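List the framework's components and the output folders it produces. Note that this cell only prints a static summary; it does not check that the listed names are actually defined.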

In [None]:
# Static summary of the framework's components and output folders.
# The lines are printed verbatim, one per print call.
_separator = "-" * 50
for _line in (
    "Framework Verification:",
    _separator,
    "Color filters:     hsv_red_mask, excess_red_mask",
    "Validation:        validate_mbb, validate_sphere",
    "Pipeline:          detect_apples_single_file",
    "Methods:           run_M1, run_M2, run_M3, run_all",
    "Visualization:     generate_storyboard, generate_comparison_plot",
    "Training:          train_regression_models",
    _separator,
    "All components ready.",
    "\nOutput folders will contain:",
    "  - metrics/       : CSV results and summary",
    "  - intermediate/  : Cluster CSVs and filtered points (.npz)",
    "  - visualizations/: 4-panel storyboard PNGs",
    "  - models/        : Trained regression models (.pkl)",
):
    print(_line)