# Validation 4: Drone GSD Validation

**Purpose**: Scientific validation that SPIDS drone GSD calculations match the theoretical formula  
**Status**: Phase 4 - Drone Learning Track  
**Duration**: 25-30 minutes  

---

## Hypothesis

**Measured GSD matches theoretical formula GSD = H x p / f within 10% error.**

## Method

1. Load drone presets (drone_50m_survey, drone_100m_mapping)
2. Create checkerboard targets with known square sizes
3. Simulate drone camera imaging with SPIDS forward model
4. Measure effective GSD from checkerboard pattern
5. Compare measured vs theoretical GSD

## Success Criteria

- GSD validation: <10% error vs theory
- All tested presets pass validation
- Swath widths consistent with SW = H x W_sensor / f
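Here is a sketch of the pass/fail rule. It assumes `compare_to_theoretical` applies a strict relative-error threshold. Its exact rule (strict or inclusive bound, how `error_percent` is defined) is not shown in this notebook and may differ. The helper names and numbers are hypothetical.

```python
def relative_error_percent(measured: float, theoretical: float) -> float:
    """Relative error of a measurement against theory, in percent."""
    return abs(measured - theoretical) / abs(theoretical) * 100.0


def gsd_passes(measured_cm: float, theoretical_cm: float, tolerance: float = 0.10) -> bool:
    """Pass if the relative error is strictly below the tolerance (0.10 = 10%)."""
    return relative_error_percent(measured_cm, theoretical_cm) < tolerance * 100.0


def all_presets_pass(results: dict, tolerance: float = 0.10) -> bool:
    """results maps preset name -> (measured_cm, theoretical_cm)."""
    return all(gsd_passes(m, t, tolerance) for m, t in results.values())


# Hypothetical numbers for illustration only (not results from this notebook)
example = {
    "preset_a": (6.8, 6.5),
    "preset_b": (14.5, 13.0),
}
for name, (m, t) in example.items():
    status = "PASS" if gsd_passes(m, t) else "FAIL"
    print(f"{name}: {relative_error_percent(m, t):.1f}% -> {status}")
print("All presets pass:", all_presets_pass(example))
```

The overall hypothesis holds only if every tested preset passes individually. One failing preset is enough to reject it.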

---

## Setup

In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch

# SPIDS imports
from prism.core.instruments import create_instrument
from prism.core.targets import CheckerboardTarget
from prism.scenarios import get_scenario_preset, list_scenario_presets
from prism.utils.metrics import compute_ssim, psnr
from prism.validation.baselines import GSDBaseline, compare_to_theoretical


# Plotting style
plt.rcParams["figure.figsize"] = (14, 10)
plt.rcParams["font.size"] = 11
plt.rcParams["axes.titlesize"] = 12

# Device configuration - use CPU for validation to ensure consistency
device = torch.device("cpu")
print(f"Using device: {device}")
print(f"Validation started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("\nSetup complete!")

---

# Section 1: Theoretical Background

## The GSD Formula

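Ground Sampling Distance (GSD) is the ground distance spanned by one pixel. Equivalently, it is the spacing between adjacent pixel centers projected onto the ground. For a nadir-pointing camera over flat terrain, with $H \gg f$: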

$$\text{GSD} = \frac{H \times p}{f}$$

Where:
- $H$ = Flight altitude (meters)
- $p$ = Pixel pitch (sensor pixel size, in meters)
- $f$ = Focal length (meters)
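Here is a worked example of the formula. The sketch converts the usual datasheet units (µm pixel pitch, mm focal length) to meters before applying it. The 4 µm pixel pitch is an illustrative assumption, not a SPIDS preset value.

```python
def gsd_m(altitude_m: float, pixel_pitch_um: float, focal_length_mm: float) -> float:
    """Ground sampling distance in meters: GSD = H * p / f.

    The pixel pitch (micrometers) and focal length (millimeters) are
    converted to meters before the formula is applied.
    """
    p_m = pixel_pitch_um * 1e-6
    f_m = focal_length_mm * 1e-3
    return altitude_m * p_m / f_m


# Illustrative values (assumed, not read from a SPIDS preset)
for H in (50.0, 100.0):
    g = gsd_m(H, pixel_pitch_um=4.0, focal_length_mm=50.0)
    print(f"H = {H:>5.0f} m -> GSD = {g * 100:.2f} cm")
```

GSD is linear in $H$, so doubling the altitude at a fixed lens and sensor doubles the GSD.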

## Swath Width Formula

$$\text{Swath Width} = \frac{H \times W_{sensor}}{f}$$

where $W_{sensor}$ is the physical sensor width (meters) across the flight track.
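This is a numeric check of the swath formula and its link to GSD. It assumes a sensor $N$ pixels wide, so that $W_{sensor} = N \times p$. Substituting into both formulas gives $\text{Swath} = N \times \text{GSD}$, and the sketch confirms this. The 6000-pixel, 6 µm sensor is hypothetical. It happens to be 36 mm wide, matching the full-frame default used in the code below.

```python
def swath_width_m(altitude_m: float, sensor_width_mm: float, focal_length_mm: float) -> float:
    """Ground swath width in meters: SW = H * W_sensor / f."""
    return altitude_m * (sensor_width_mm * 1e-3) / (focal_length_mm * 1e-3)


def gsd_m(altitude_m: float, pixel_pitch_um: float, focal_length_mm: float) -> float:
    """Ground sampling distance in meters: GSD = H * p / f."""
    return altitude_m * (pixel_pitch_um * 1e-6) / (focal_length_mm * 1e-3)


# Hypothetical sensor: 6000 pixels across at 6.0 um pitch -> 36 mm wide
n_pixels_across = 6000
pitch_um = 6.0
sensor_width_mm = n_pixels_across * pitch_um * 1e-3

H, f_mm = 50.0, 50.0
sw = swath_width_m(H, sensor_width_mm, f_mm)
g = gsd_m(H, pitch_um, f_mm)
print(f"Swath = {sw:.1f} m, GSD = {g * 100:.2f} cm, N x GSD = {n_pixels_across * g:.1f} m")
```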

## Test Presets

| Preset | Altitude | Lens | Expected GSD |
|--------|----------|------|-------------|
| drone_50m_survey | 50m | 50mm f/4.0 | ~6.5 cm |
| drone_100m_mapping | 100m | 50mm f/4.0 | ~13 cm |
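Inverting the formula, $p = \text{GSD} \times f / H$, gives a quick sanity check on this table. Both rows imply an effective pixel pitch of about 65 µm. That is roughly ten times a typical photographic sensor pixel, which is a few µm. The value is worth confirming against `pixel_pitch_um` from `scenario.get_info()`: it may reflect a deliberately coarse simulated pixel, or a units slip. With a 6.5 µm pitch, the same geometry would give about 0.65 cm and 1.3 cm.

```python
def implied_pixel_pitch_um(gsd_cm: float, altitude_m: float, focal_length_mm: float) -> float:
    """Invert GSD = H * p / f to recover the pixel pitch p, in micrometers."""
    gsd_m = gsd_cm / 100.0
    f_m = focal_length_mm / 1000.0
    return gsd_m * f_m / altitude_m * 1e6


# (preset, altitude_m, focal_length_mm, expected GSD in cm), copied from the table above
table_rows = [
    ("drone_50m_survey", 50.0, 50.0, 6.5),
    ("drone_100m_mapping", 100.0, 50.0, 13.0),
]
for name, H, f_mm, gsd_cm in table_rows:
    print(f"{name}: implied pixel pitch = {implied_pixel_pitch_um(gsd_cm, H, f_mm):.1f} um")
```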

In [None]:
# Define presets to validate
presets_to_validate = [
    "drone_50m_survey",
    "drone_100m_mapping",
]

print("Theoretical GSD Calculations")
print("=" * 80)
print(f"{'Preset':<25} {'Altitude':>10} {'Pixel':>12} {'Focal':>10} {'GSD (cm)':>12}")
print("-" * 80)

theoretical_values = {}
for preset_name in presets_to_validate:
    scenario = get_scenario_preset(preset_name)
    info = scenario.get_info()

    # Extract parameters
    altitude_m = info["altitude_m"]
    pixel_pitch_um = info["pixel_pitch_um"]
    focal_length_mm = info["focal_length_mm"]

    # Calculate theoretical GSD using GSDBaseline
    pixel_pitch_m = pixel_pitch_um * 1e-6
    focal_length_m = focal_length_mm * 1e-3

    theoretical_gsd_m = GSDBaseline.gsd(altitude_m, pixel_pitch_m, focal_length_m)
    theoretical_gsd_cm = theoretical_gsd_m * 100

    # Also get swath width
    sensor_width_mm = info.get("sensor_width_mm", 36.0)  # Full-frame default
    swath_width_m = GSDBaseline.swath_width(altitude_m, sensor_width_mm * 1e-3, focal_length_m)

    theoretical_values[preset_name] = {
        "altitude_m": altitude_m,
        "pixel_pitch_um": pixel_pitch_um,
        "focal_length_mm": focal_length_mm,
        "theoretical_gsd_cm": theoretical_gsd_cm,
        "theoretical_gsd_m": theoretical_gsd_m,
        "swath_width_m": swath_width_m,
        "scenario": scenario,
        "info": info,
    }

    print(
        f"  {preset_name:<23} {altitude_m:>8.0f} m "
        f"{pixel_pitch_um:>8.1f} um {focal_length_mm:>8.0f} mm "
        f"{theoretical_gsd_cm:>10.2f}"
    )

print("=" * 80)
print("\nFormula: GSD = H x p / f")

---

# Section 2: Checkerboard Target Design

For each preset, we create a checkerboard target with:
- Square size based on expected GSD (should span multiple pixels)
- Field of view matching the swath width
- Sufficient squares to measure periodicity
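One caveat applies to the design code that follows. Its `pixels_per_square` column is `square_size_m / gsd_m`, which assumes the `n_pixels` grid samples the field at exactly the theoretical GSD. Suppose `CheckerboardTarget` lays out `n_squares` squares across `size` pixels (an assumption about its API). Then the rendered target actually holds `n_pixels / n_squares` pixels per square, and the two figures agree only when `field_size_m` is close to `n_pixels × GSD`. Here is a minimal check with hypothetical numbers; `design_consistency` is an illustrative helper, not part of SPIDS.

```python
def design_consistency(square_size_m: float, gsd_m: float, field_size_m: float, n_pixels: int) -> dict:
    """Compare designed vs rendered pixels-per-square for a checkerboard target.

    Assumes n_squares = int(field / square) squares are placed across an
    n_pixels-wide grid, mirroring design_checkerboard_for_preset.
    """
    n_squares = int(field_size_m / square_size_m)
    designed = square_size_m / gsd_m  # assumes grid spacing == theoretical GSD
    rendered = n_pixels / n_squares   # what an n_pixels grid actually holds
    return {
        "n_squares": n_squares,
        "grid_spacing_m": field_size_m / n_pixels,
        "designed_pix_per_square": designed,
        "rendered_pix_per_square": rendered,
        "ratio": rendered / designed,
    }


gsd, square, n_pix = 0.065, 0.65, 512  # hypothetical values
for label, field in [("field = n_pixels * GSD", n_pix * gsd), ("field = 300 m", 300.0)]:
    d = design_consistency(square, gsd, field, n_pix)
    print(
        f"{label:<24} grid spacing {d['grid_spacing_m']:.3f} m, "
        f"designed {d['designed_pix_per_square']:.1f} vs "
        f"rendered {d['rendered_pix_per_square']:.2f} px/square"
    )
```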

In [None]:
def design_checkerboard_for_preset(preset_info: dict, n_pixels: int = 512) -> dict:
    """
    Design a checkerboard target for GSD validation.

    The square size is chosen to be clearly resolvable (10x GSD) while
    still fitting enough squares to measure the pattern accurately.
    """
    gsd_cm = preset_info["theoretical_gsd_cm"]
    swath_m = preset_info["swath_width_m"]

    # Square size: 10x GSD ensures clear resolution
    # Minimum 10cm squares for practical visibility
    square_size_cm = max(10.0, gsd_cm * 10)
    square_size_m = square_size_cm / 100

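    # Field of view: use swath width
    field_size_m = swath_m

    # Number of squares that fit across the field
    n_squares = int(field_size_m / square_size_m)

    # Pixels per square, assuming the n_pixels grid samples the field at the
    # theoretical GSD (true only if field_size_m ~= n_pixels * gsd_m)
    gsd_m = preset_info["theoretical_gsd_m"]
    pixels_per_square = square_size_m / gsd_m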

    return {
        "square_size_m": square_size_m,
        "square_size_cm": square_size_cm,
        "field_size_m": field_size_m,
        "n_squares": n_squares,
        "pixels_per_square": pixels_per_square,
        "n_pixels": n_pixels,
    }


# Design targets for each preset
print("Checkerboard Target Design")
print("=" * 85)
print(f"{'Preset':<25} {'Square Size':>12} {'Field Size':>12} {'# Squares':>10} {'Pix/Square':>12}")
print("-" * 85)

for preset_name, info in theoretical_values.items():
    target_design = design_checkerboard_for_preset(info)
    info["target_design"] = target_design

    print(
        f"  {preset_name:<23} {target_design['square_size_cm']:>8.0f} cm "
        f"{target_design['field_size_m']:>10.1f} m "
        f"{target_design['n_squares']:>10} "
        f"{target_design['pixels_per_square']:>12.1f}"
    )

print("=" * 85)
print("\nNote: one checkerboard period spans 2 squares, so Nyquist requires > 1 pixel per square (>= 2 recommended)")

---

# Section 3: GSD Measurement Pipeline

For each preset:
1. Create drone camera instrument from scenario
2. Create checkerboard target with designed parameters
3. Simulate imaging through SPIDS forward model
4. Measure effective GSD from checkerboard periodicity
5. Compare to theoretical GSD
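Step 4 relies on the checkerboard's FFT peaking at its fundamental frequency. A checkerboard is the product of two square waves, so the fundamental sits on the diagonal, at `n_pixels / (2 * pixels_per_square)` cycles per image along each axis. The self-contained sketch below tests that idea on a synthetic, pixel-aligned checkerboard instead of a SPIDS measurement. `make_checkerboard` and `pixels_per_square_from_fft` are hypothetical helpers written for this illustration.

```python
import numpy as np


def make_checkerboard(n_pixels: int, pixels_per_square: int) -> np.ndarray:
    """Binary checkerboard whose square edges align with the pixel grid."""
    idx = np.arange(n_pixels) // pixels_per_square
    return ((idx[:, None] + idx[None, :]) % 2).astype(float)


def pixels_per_square_from_fft(img: np.ndarray, dc_half_width: int = 5) -> float:
    """Estimate the square size in pixels from the strongest non-DC FFT peak."""
    h, w = img.shape
    power = np.abs(np.fft.fftshift(np.fft.fft2(img))) ** 2
    cy, cx = h // 2, w // 2
    # Zero a symmetric block around the DC bin
    power[cy - dc_half_width : cy + dc_half_width + 1, cx - dc_half_width : cx + dc_half_width + 1] = 0
    py, px = np.unravel_index(np.argmax(power), power.shape)
    fy, fx = abs(int(py) - cy), abs(int(px) - cx)  # cycles per image
    period_px = (h / fy) if fy >= fx else (w / fx)
    return period_px / 2  # one period spans two squares


img = make_checkerboard(512, 16)
print(pixels_per_square_from_fft(img))
```

The GSD then follows as `square_size_m / pixels_per_square`, as in the measurement function below.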

In [None]:
def measure_gsd_from_checkerboard(
    measurement: np.ndarray,
    square_size_m: float,
    field_size_m: float,
) -> dict:
    """
    Measure effective GSD from checkerboard pattern.

    Uses FFT to find the dominant spatial frequency, which corresponds
    to the checkerboard period. From this, we can derive the effective GSD.
    """
    h, w = measurement.shape

    # Compute 2D FFT
    fft_result = np.fft.fft2(measurement)
    fft_shifted = np.fft.fftshift(fft_result)
    power_spectrum = np.abs(fft_shifted) ** 2

    # Find peaks in power spectrum (excluding DC component)
    center_y, center_x = h // 2, w // 2

    # Mask out the DC component (symmetric square around the center bin)
    dc_mask_size = 5
    power_spectrum_masked = power_spectrum.copy()
    power_spectrum_masked[
        center_y - dc_mask_size : center_y + dc_mask_size + 1,
        center_x - dc_mask_size : center_x + dc_mask_size + 1,
    ] = 0

    # Find the strongest remaining frequency component
    peak_idx = np.unravel_index(np.argmax(power_spectrum_masked), power_spectrum_masked.shape)

    # Offset from the center bin = frequency in cycles per image along each axis
    freq_y = abs(int(peak_idx[0]) - center_y)
    freq_x = abs(int(peak_idx[1]) - center_x)
    dominant_freq_pixels = max(freq_y, freq_x)

    if dominant_freq_pixels == 0:
        # Fallback: rough estimate from the number of rows that contain edges
        grad_y = np.abs(np.gradient(measurement, axis=0))
        edge_rows = np.sum(grad_y.mean(axis=1) > grad_y.mean() * 0.5)
        pixels_per_square = h / max(1, edge_rows)
    else:
        # Period in pixels = axis length / frequency, using the axis on which the
        # dominant frequency was found; one period spans 2 squares (black + white)
        axis_length = h if freq_y >= freq_x else w
        period_pixels = axis_length / dominant_freq_pixels
        pixels_per_square = period_pixels / 2

    # Calculate measured GSD
    # GSD = physical_square_size / pixels_per_square
    measured_gsd_m = square_size_m / pixels_per_square
    measured_gsd_cm = measured_gsd_m * 100

    # Also calculate from field of view (as sanity check)
    gsd_from_fov_m = field_size_m / h

    return {
        "measured_gsd_m": measured_gsd_m,
        "measured_gsd_cm": measured_gsd_cm,
        "pixels_per_square": pixels_per_square,
        "dominant_freq_pixels": dominant_freq_pixels,
        "gsd_from_fov_m": gsd_from_fov_m,
        "power_spectrum": power_spectrum,
    }


def validate_preset_gsd(
    preset_name: str,
    preset_info: dict,
    n_pixels: int = 512,
    tolerance: float = 0.10,
) -> dict:
    """
    Run complete GSD validation for a drone preset.

    Parameters
    ----------
    preset_name : str
        Name of the drone preset
    preset_info : dict
        Preset information from theoretical_values
    n_pixels : int
        Computational grid resolution
    tolerance : float
        Acceptable relative error (default 10%)

    Returns
    -------
    dict
        Validation results
    """
    scenario = preset_info["scenario"]
    target_design = preset_info["target_design"]
    theoretical_gsd_cm = preset_info["theoretical_gsd_cm"]

    print(f"\n{'=' * 60}")
    print(f"Validating: {preset_name}")
    print(f"{'=' * 60}")
    print(f"  Altitude: {preset_info['altitude_m']:.0f} m")
    print(f"  Theoretical GSD: {theoretical_gsd_cm:.2f} cm")
    print(f"  Field of view: {target_design['field_size_m']:.1f} m")
    print(f"  Square size: {target_design['square_size_cm']:.0f} cm")

    # Create checkerboard target
    print("  Creating checkerboard target...")
    target = CheckerboardTarget(
        size=n_pixels,
        n_squares=target_design["n_squares"],
    )
    ground_truth = target.generate().to(device)

    # Create camera instrument from scenario
    print("  Creating drone camera instrument...")
    instrument_config = scenario.to_instrument_config()
    camera = create_instrument(instrument_config)

    # Simulate imaging
    print("  Simulating drone camera measurement...")
    with torch.no_grad():
        input_field = ground_truth.unsqueeze(0).unsqueeze(0).float()
        measurement = camera.forward(input_field)

    measurement_2d = measurement.squeeze().cpu().numpy()
    gt_2d = ground_truth.cpu().numpy()

    # Measure GSD from checkerboard
    print("  Measuring GSD from checkerboard pattern...")
    gsd_measurement = measure_gsd_from_checkerboard(
        measurement_2d,
        target_design["square_size_m"],
        target_design["field_size_m"],
    )

    measured_gsd_cm = gsd_measurement["measured_gsd_cm"]

    # Compute quality metrics
    print("  Computing quality metrics...")
    gt_norm = (gt_2d - gt_2d.min()) / (gt_2d.max() - gt_2d.min() + 1e-8)
    meas_norm = (measurement_2d - measurement_2d.min()) / (
        measurement_2d.max() - measurement_2d.min() + 1e-8
    )

    # float() turns a 0-d tensor into a plain number for printing and averaging
    ssim_value = float(compute_ssim(torch.from_numpy(meas_norm), torch.from_numpy(gt_norm)))
    psnr_value = float(psnr(torch.from_numpy(meas_norm), torch.from_numpy(gt_norm)))

    # Compare to theoretical
    validation_result = compare_to_theoretical(
        measured=measured_gsd_cm,
        theoretical=theoretical_gsd_cm,
        tolerance=tolerance,
    )

    # Print results
    print("\n  Results:")
    print(f"    SSIM: {ssim_value:.4f}")
    print(f"    PSNR: {psnr_value:.2f} dB")
    print(f"    Measured GSD: {measured_gsd_cm:.2f} cm")
    print(f"    Theoretical GSD: {theoretical_gsd_cm:.2f} cm")
    print(f"    Error: {validation_result.error_percent:.1f}%")
    print(f"    Status: {validation_result.status}")

    return {
        "preset_name": preset_name,
        "altitude_m": preset_info["altitude_m"],
        "theoretical_gsd_cm": theoretical_gsd_cm,
        "measured_gsd_cm": measured_gsd_cm,
        "error_percent": validation_result.error_percent,
        "passed": validation_result.passed,
        "status": validation_result.status,
        "ssim": ssim_value,
        "psnr": psnr_value,
        "pixels_per_square": gsd_measurement["pixels_per_square"],
        "swath_width_m": preset_info["swath_width_m"],
        "ground_truth": gt_2d,
        "measurement": measurement_2d,
        "target_design": target_design,
    }

In [None]:
# Run validation for all presets
validation_results = []

for preset_name, preset_info in theoretical_values.items():
    result = validate_preset_gsd(preset_name, preset_info)
    validation_results.append(result)

print("\n" + "=" * 60)
print("All validations complete!")
print("=" * 60)

---

# Section 4: Results Visualization

In [None]:
# Create comparison figure for all presets
n_presets = len(validation_results)
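# squeeze=False keeps axes 2D even when there is only one preset
fig, axes = plt.subplots(n_presets, 3, figsize=(16, 5 * n_presets), squeeze=False)

for row, result in enumerate(validation_results):
    gt = result["ground_truth"]
    meas = result["measurement"]

    # Normalize both images to [0, 1] so the difference reflects structure, not scale
    gt_n = (gt - gt.min()) / (gt.max() - gt.min() + 1e-8)
    meas_n = (meas - meas.min()) / (meas.max() - meas.min() + 1e-8)
    diff = gt_n - meas_n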

    swath_m = result["swath_width_m"]
    extent = [0, swath_m, 0, swath_m]

    # Ground truth
    axes[row, 0].imshow(gt, cmap="gray", extent=extent, origin="lower")
    axes[row, 0].set_title(f"Ground Truth\n{result['preset_name']}")
    axes[row, 0].set_xlabel("Position (m)")
    axes[row, 0].set_ylabel("Position (m)")

    # Measurement
    status_color = "green" if result["passed"] else "red"
    im1 = axes[row, 1].imshow(meas, cmap="gray", extent=extent, origin="lower")
    axes[row, 1].set_title(
        f"Drone Camera Image\n"
        f"Alt={result['altitude_m']:.0f}m, GSD={result['theoretical_gsd_cm']:.1f}cm\n"
        f"SSIM={result['ssim']:.3f}",
        color=status_color,
    )
    axes[row, 1].set_xlabel("Position (m)")

    # Difference
    im2 = axes[row, 2].imshow(diff, cmap="RdBu", extent=extent, origin="lower")
    axes[row, 2].set_title(
        f"Difference\nError: {result['error_percent']:.1f}% - {result['status']}",
        color=status_color,
    )
    axes[row, 2].set_xlabel("Position (m)")
    plt.colorbar(im2, ax=axes[row, 2], fraction=0.046)

plt.tight_layout()
plt.suptitle("SPIDS Drone GSD Validation", fontsize=14, y=1.01)
plt.show()

---

# Section 5: Validation Summary Table

In [None]:
# Create validation summary table
print("\n" + "=" * 100)
print("GSD VALIDATION SUMMARY")
print("=" * 100)
print(
    f"{'Preset':<25} {'Altitude':>10} {'Theoretical':>14} "
    f"{'Measured':>12} {'Error':>8} {'SSIM':>8} {'Status':>10}"
)
print("-" * 100)

all_passed = True
for result in validation_results:
    status_icon = "PASS" if result["passed"] else "FAIL"
    if not result["passed"]:
        all_passed = False

    print(
        f"  {result['preset_name']:<23} "
        f"{result['altitude_m']:>8.0f} m "
        f"{result['theoretical_gsd_cm']:>10.2f} cm "
        f"{result['measured_gsd_cm']:>10.2f} cm "
        f"{result['error_percent']:>7.1f}% "
        f"{result['ssim']:>8.4f} "
        f"{status_icon:>10}"
    )

print("=" * 100)

# Overall status
print(f"\nOverall Validation: {'PASS' if all_passed else 'FAIL'}")
print("Tolerance: 10% relative error")
print("Formula: GSD = H x p / f")

In [None]:
# Create GSD comparison plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: GSD comparison
presets = [r["preset_name"].replace("drone_", "") for r in validation_results]
theoretical = [r["theoretical_gsd_cm"] for r in validation_results]
measured = [r["measured_gsd_cm"] for r in validation_results]

x = np.arange(len(presets))
width = 0.35

bars1 = axes[0].bar(
    x - width / 2, theoretical, width, label="Theoretical", color="steelblue", alpha=0.8
)
bars2 = axes[0].bar(x + width / 2, measured, width, label="Measured", color="coral", alpha=0.8)

axes[0].set_ylabel("GSD (cm)")
axes[0].set_xlabel("Drone Preset")
axes[0].set_title("Theoretical vs Measured GSD")
axes[0].set_xticks(x)
axes[0].set_xticklabels(presets, rotation=15, ha="right")
axes[0].legend()

# Add error annotations
for i, result in enumerate(validation_results):
    error = result["error_percent"]
    color = "green" if result["passed"] else "red"
    axes[0].annotate(
        f"{error:.1f}%",
        xy=(i + width / 2, measured[i]),
        xytext=(0, 5),
        textcoords="offset points",
        ha="center",
        fontsize=10,
        fontweight="bold",
        color=color,
    )

# Right: Error percentage
errors = [r["error_percent"] for r in validation_results]
colors = ["green" if r["passed"] else "red" for r in validation_results]

bars = axes[1].bar(x, errors, color=colors, alpha=0.8)
axes[1].axhline(y=10, color="red", linestyle="--", linewidth=2, label="10% tolerance")
axes[1].set_ylabel("Error (%)")
axes[1].set_xlabel("Drone Preset")
axes[1].set_title("GSD Error vs Tolerance")
axes[1].set_xticks(x)
axes[1].set_xticklabels(presets, rotation=15, ha="right")
axes[1].legend()
axes[1].set_ylim(0, max(15, 1.2 * max(errors)))  # keep bars above 15% visible

plt.tight_layout()
plt.show()

---

# Section 6: Detailed Analysis

In [None]:
# Detailed analysis of each preset
print("\nDETAILED VALIDATION ANALYSIS")
print("=" * 80)

for result in validation_results:
    preset_info = theoretical_values[result["preset_name"]]

    print(f"\n{result['preset_name'].upper()}")
    print("-" * 40)

    print("  Camera Configuration:")
    print(f"    Altitude: {result['altitude_m']:.0f} m")
    print(f"    Focal Length: {preset_info['focal_length_mm']:.0f} mm")
    print(f"    Pixel Pitch: {preset_info['pixel_pitch_um']:.1f} um")
    print(f"    Swath Width: {result['swath_width_m']:.1f} m")

    print("\n  GSD Analysis:")
    print(f"    Theoretical GSD: {result['theoretical_gsd_cm']:.2f} cm")
    print(f"    Measured GSD: {result['measured_gsd_cm']:.2f} cm")
    print(f"    Error: {result['error_percent']:.2f}%")

    print("\n  Checkerboard Analysis:")
    print(f"    Square Size: {result['target_design']['square_size_cm']:.0f} cm")
    print(f"    Pixels per Square: {result['pixels_per_square']:.1f}")

    print("\n  Quality Metrics:")
    print(f"    SSIM: {result['ssim']:.4f}")
    print(f"    PSNR: {result['psnr']:.2f} dB")

    # Validation status
    status_str = "PASS" if result["passed"] else "FAIL"
    print(f"\n  Validation: {status_str} (tolerance: 10%)")

print("\n" + "=" * 80)

---

# Section 7: Additional Drone Presets Overview

In [None]:
# Show all available drone presets and their theoretical GSD
all_drone_presets = list_scenario_presets("drone")

print("ALL SPIDS DRONE PRESETS - Theoretical GSD Values")
print("=" * 95)
print(f"{'Preset':<28} {'Altitude':>10} {'Lens':>12} {'Pixel':>10} {'GSD':>10} {'Swath':>10}")
print("-" * 95)

for preset_name in all_drone_presets:
    scenario = get_scenario_preset(preset_name)
    info = scenario.get_info()

    altitude_m = info["altitude_m"]
    pixel_pitch_um = info["pixel_pitch_um"]
    focal_length_mm = info["focal_length_mm"]

    # Calculate theoretical GSD
    gsd_cm = GSDBaseline.gsd(altitude_m, pixel_pitch_um * 1e-6, focal_length_mm * 1e-3) * 100

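    # Use the preset's swath if provided; otherwise SW = H x W_sensor / f
    # (36 mm full-frame width assumed, as in Section 1)
    swath_m = info.get("swath_width_m")
    if swath_m is None:
        swath_m = GSDBaseline.swath_width(
            altitude_m, info.get("sensor_width_mm", 36.0) * 1e-3, focal_length_mm * 1e-3
        )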
    lens = info.get("lens", "N/A")

    # Mark validated presets
    validated = " *" if preset_name in presets_to_validate else ""

    print(
        f"{preset_name + validated:<28} {altitude_m:>8.0f} m {lens:>12} "
        f"{pixel_pitch_um:>8.1f} um {gsd_cm:>8.1f} cm {swath_m:>8.0f} m"
    )

print("=" * 95)
print("\n* = Validated in this notebook")
print(f"\nTotal: {len(all_drone_presets)} drone presets available")

---

# Section 8: Conclusions

In [None]:
# Final summary and conclusions
passed_count = sum(1 for r in validation_results if r["passed"])
total_count = len(validation_results)
all_passed = passed_count == total_count  # recomputed so this cell does not depend on Section 5
avg_error = np.mean([r["error_percent"] for r in validation_results])
avg_ssim = np.mean([r["ssim"] for r in validation_results])

print("\n" + "=" * 70)
print("VALIDATION CONCLUSIONS")
print("=" * 70)

print("\nHypothesis: Measured GSD matches GSD = H x p / f within 10% error")
print("\nResults:")
print(f"  - Presets validated: {passed_count}/{total_count}")
print(f"  - Average error: {avg_error:.1f}%")
print(f"  - Average SSIM: {avg_ssim:.4f}")

if all_passed:
    print("\nCONCLUSION: HYPOTHESIS CONFIRMED")
    print("  SPIDS GSD calculations match theoretical predictions within 10% tolerance.")
    print("  The GSD formula GSD = H x p / f is validated for drone imaging.")
else:
    failed_presets = [r["preset_name"] for r in validation_results if not r["passed"]]
    verdict = "NOT CONFIRMED" if passed_count == 0 else "PARTIALLY CONFIRMED"
    print(f"\nCONCLUSION: HYPOTHESIS {verdict}")
    print("  These presets exceeded the 10% error tolerance:")
    for preset in failed_presets:
        print(f"    - {preset}")

print(f"\nValidation completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("=" * 70)

---

## Expected Results (Reference)

Based on the GSD formula, the expected results are:

| Preset | Altitude | Theoretical GSD | Expected Measured | Expected Error | Status |
|--------|----------|-----------------|-------------------|----------------|--------|
| drone_50m_survey | 50m | 6.5 cm | 6.2-7.0 cm | <10% | PASS |
| drone_100m_mapping | 100m | 13.0 cm | 12.0-14.0 cm | <10% | PASS |
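A few lines of arithmetic check these reference ranges against the 10% criterion. The values are copied from the table above.

```python
expected_ranges = {
    # preset: (theoretical_cm, low_cm, high_cm), copied from the table above
    "drone_50m_survey": (6.5, 6.2, 7.0),
    "drone_100m_mapping": (13.0, 12.0, 14.0),
}

worst_case = {}
for name, (theory, low, high) in expected_ranges.items():
    # Largest relative deviation of either range endpoint from theory, in percent
    worst_case[name] = max(abs(low - theory), abs(high - theory)) / theory * 100
    print(f"{name}: worst-case error {worst_case[name]:.1f}% (limit 10%)")
```

Both ranges stay within about 7.7% of theory, so any measurement inside them would pass.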

**Note**: Actual results may vary based on:
- Wave propagation effects (coherent vs incoherent)
- Discretization effects (pixel grid alignment, integer-bin location of the FFT peak)
- Target design (checkerboard square size vs GSD ratio)
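Discretization also enters through the FFT itself. `measure_gsd_from_checkerboard` takes the integer peak bin, so the period is resolved only to ±0.5 cycles per image, a relative error of up to 0.5 / f0. One common refinement is parabolic interpolation of the log-magnitude around a Hann-windowed peak. It is swapped in here for illustration and is not part of SPIDS. The 1D sketch below uses a hypothetical non-integer square size; `peak_freq_bins` is an illustrative helper.

```python
import numpy as np


def peak_freq_bins(profile: np.ndarray, refine: bool = True) -> float:
    """Dominant frequency (cycles per record) of a 1D profile.

    Applies a Hann window, takes the largest non-DC rfft bin, and optionally
    refines it by fitting a parabola to the log-magnitude at bins k-1, k, k+1.
    """
    x = profile - profile.mean()
    mag = np.abs(np.fft.rfft(x * np.hanning(len(x))))
    k = int(np.argmax(mag[1:])) + 1  # skip the DC bin
    if not refine or k + 1 >= len(mag):
        return float(k)
    a, b, c = np.log(mag[k - 1 : k + 2] + 1e-12)
    return float(k + 0.5 * (a - c) / (a - 2 * b + c))


n = 512
true_pps = 13.7  # hypothetical non-integer pixels per square
x = np.arange(n)
profile = np.where((x // true_pps) % 2 == 0, 1.0, -1.0)  # period = 2 * true_pps pixels

for refine in (False, True):
    est_pps = n / peak_freq_bins(profile, refine) / 2
    err = abs(est_pps - true_pps) / true_pps * 100
    print(f"refine={refine}: pixels/square = {est_pps:.2f} (error {err:.2f}%)")
```

The same interpolation could be applied along each axis of the 2D spectrum. Whether the bin error matters depends on f0: it grows as squares get larger relative to the grid.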

---

## References

- GSD Formula: GSD = H x p / f (altitude x pixel_pitch / focal_length)
- Swath Width: SW = H x W_sensor / f
- SPIDS Drone Presets: `prism.scenarios.get_scenario_preset()`

---

## Related Materials

- **[Learning 4: GSD Basics](../../notebooks/learning_04_gsd_basics.ipynb)**: GSD fundamentals
- **[Learning 5: Drone Altitudes](../../notebooks/learning_05_drone_altitudes.ipynb)**: Altitude comparison
- **[Python API: Drone Mapping](../../python_api/07_drone_mapping.py)**: Production workflow

In [None]:
print("\n" + "=" * 70)
print("Validation 4: Drone GSD Validation - COMPLETE")
print("=" * 70)
print("\nResults Summary:")
for r in validation_results:
    status = "PASS" if r["passed"] else "FAIL"
    print(f"  [{status}] {r['preset_name']}: {r['error_percent']:.1f}% error")
print("\nPhase 4 Drone Learning Track - Validation Complete!")