# Landsat EVI/NDVI Change Detection with H3 Aggregation

Streaming approach for memory-efficient processing of large raster datasets.

**Key design:**
- Process one time period at a time → aggregate to H3 → discard raster
- Controlled concurrency (default: 3 years at a time, via `max_concurrent_years`), so only a few rasters are ever held in memory at once
- DuckDB H3 extension for fast spatial aggregation
- Arrow for zero-copy data handoff
- pandas is used only for the small final result DataFrame

In [3]:
# Core imports
import os
import concurrent.futures
from typing import List, Optional

import numpy as np
import pyarrow as pa
import duckdb
import odc.stac
import boto3
from pystac_client import Client
from pyproj import Transformer
from dotenv import load_dotenv
import h3
import dask
from dask.diagnostics import ProgressBar

# Configure dask for optimal S3 fetching.
# These settings are global: they apply to every lazy odc.stac.load() graph
# that is later computed in this notebook.
dask.config.set({
    'array.slicing.split_large_chunks': False,  # Keep large chunks as-is when slicing (silences dask's large-chunk warning)
    'num_workers': 8,  # Worker count for dask's local scheduler -> parallel chunk fetches
})

<dask.config.set at 0x10c9679e0>

In [15]:
# One-time install of the community H3 extension (persists to disk, so later
# connections only need to LOAD it).
duckdb.sql("INSTALL h3 FROM community")

def get_con():
    """
    Open a fresh DuckDB connection with the H3 extension loaded.

    Returns
    -------
    duckdb.DuckDBPyConnection
        A new connection; the caller owns it.
    """
    connection = duckdb.connect()
    connection.sql("LOAD h3")
    return connection

In [16]:
def calculate_resolution_for_h3(
    h3_res: int, 
    native_resolution: int = 30,
    pixels_per_hex_edge: int = 6
) -> int:
    """
    Pick a raster pixel size that suits a given H3 resolution.

    The average hexagon edge length is divided by the desired number of
    pixels per edge, then snapped to the nearest multiple of the sensor's
    native resolution. The result is never finer than the native resolution.

    Parameters
    ----------
    h3_res : int
        H3 resolution (0-15)
    native_resolution : int
        Native sensor resolution in meters (default 30 for Landsat)
    pixels_per_hex_edge : int
        Target pixels per hex edge (default 6)

    Returns
    -------
    int
        Resolution in meters, a multiple of ``native_resolution``
    """
    edge_length_m = h3.average_hexagon_edge_length(h3_res, unit='m')
    ideal_spacing = edge_length_m / pixels_per_hex_edge

    # Snap to a whole number of native pixels
    snapped = round(ideal_spacing / native_resolution) * native_resolution

    # Small hexagons can snap to 0; fall back to the native pixel size
    if snapped < native_resolution:
        return native_resolution
    return snapped

In [17]:
def configure_aws_access():
    """
    Set up requester-pays S3 access for odc-stac and return the boto3 session.

    Credentials are read from the environment after loading any `.env` file;
    the session is pinned to ``us-west-2``.
    """
    load_dotenv()

    credentials = {
        'aws_access_key_id': os.environ.get('AWS_ACCESS_KEY_ID'),
        'aws_secret_access_key': os.environ.get('AWS_SECRET_ACCESS_KEY'),
    }
    aws_session = boto3.Session(region_name='us-west-2', **credentials)

    odc.stac.configure_s3_access(aws_session=aws_session, requester_pays=True)

    return aws_session

In [18]:
def stratified_sample_items(items, items_per_group: int = 3, max_items: Optional[int] = None):
    """
    Select the clearest images, stratified by month AND path/row.

    Items are grouped by ``(year, month, landsat:wrs_path, landsat:wrs_row)``
    and the ``items_per_group`` least cloudy items of each group are kept.

    Parameters
    ----------
    items : iterable
        Objects exposing ``.datetime`` (with ``year``/``month``) and a
        ``.properties`` mapping, e.g. STAC items.
    items_per_group : int
        Maximum number of items kept per group (default 3).
    max_items : int, optional
        Overall cap. When the per-group selection exceeds it, every group is
        thinned proportionally (keeping at least one item of a non-empty
        group); if that is still too many, the ``max_items`` least cloudy
        items overall are returned.

    Returns
    -------
    list
        Selected items. Ordered by group key, except after the final
        overall trim, where they are ordered by cloud cover.
    """
    from collections import defaultdict
    import math

    def cloud_cover(item):
        # A missing property and an explicit null both rank as fully cloudy.
        # Comparing None with numbers would otherwise raise TypeError in sorted().
        cover = item.properties.get('eo:cloud_cover')
        return 100 if cover is None else cover

    groups = defaultdict(list)
    for item in items:
        dt = item.datetime
        path = item.properties.get('landsat:wrs_path', 0)
        row = item.properties.get('landsat:wrs_row', 0)
        groups[(dt.year, dt.month, path, row)].append(item)

    # Clearest scenes first, truncated to the per-group quota
    clearest = {
        key: sorted(members, key=cloud_cover)[:items_per_group]
        for key, members in groups.items()
    }
    ordered_keys = sorted(clearest)

    total_selected = sum(len(members) for members in clearest.values())

    if max_items is None or total_selected <= max_items:
        return [item for key in ordered_keys for item in clearest[key]]

    # Over budget: thin every group by the same ratio
    keep_ratio = max_items / total_selected
    selected = []
    for key in ordered_keys:
        members = clearest[key]
        keep_count = max(1, math.floor(len(members) * keep_ratio))
        selected.extend(members[:keep_count])

    # The "at least one per group" floor can still overshoot the cap
    if len(selected) > max_items:
        selected = sorted(selected, key=cloud_cover)[:max_items]

    return selected


def process_single_period_to_h3(
    time_of_interest: str,
    bounds: List[float],
    h3_res: int = 7,
    red_band: str = "red",
    blue_band: str = "blue",
    nir_band: str = "nir08",
    collection: str = "landsat-c2-l2",
    cloud_threshold: int = 20,
    evi: bool = True,
    native_resolution: int = 30,
    max_items: int = 50,
    items_per_group: int = 2,
    agg_func: str = "mean",
    verbose: bool = False,
) -> Optional[pa.Table]:
    """
    Load ONE time period, aggregate to H3 immediately, return small result.

    Pipeline: STAC search -> pick the clearest scenes -> lazy load -> per-pixel
    vegetation index -> maximum over time -> reproject valid pixels to lon/lat
    -> aggregate per H3 cell in DuckDB.

    Parameters
    ----------
    time_of_interest : str
        Datetime range passed to the STAC search. Its first four characters
        are used as the tag in log messages.
    bounds : list of float
        Bounding box passed as ``bbox`` to both the STAC search and
        ``odc.stac.load`` (lon/lat order: west, south, east, north).
    h3_res : int
        Target H3 resolution. Also determines the raster pixel spacing via
        ``calculate_resolution_for_h3``.
    red_band, blue_band, nir_band : str
        Asset names to load. The blue band is only loaded when ``evi`` is True.
    collection : str
        STAC collection id.
    cloud_threshold : int
        Only items with ``eo:cloud_cover`` below this value are searched.
    evi : bool
        True computes EVI, False computes NDVI (clipped to [-1, 1]).
    native_resolution : int
        Native sensor resolution in meters.
    max_items : int
        Upper bound on the number of scenes loaded.
    items_per_group : int
        Scenes kept per month/path/row group. When <= 0 the stratified
        sampling is skipped and the ``max_items`` clearest scenes are used.
    agg_func : str
        One of ``mean``, ``sum``, ``max``, ``min``, ``median``. Any other
        value silently falls back to the mean.
    verbose : bool
        Print progress and timings.

    Returns
    -------
    pyarrow.Table or None
        Columns ``hex`` (H3 cell id string) and ``metric`` (DOUBLE).
        ``None`` when the search finds nothing, the load graph cannot be
        built, or the loaded dataset has no data variables.

    Notes
    -----
    Only errors raised while building the load graph are swallowed. Errors
    raised later, while the data is actually fetched in ``.compute()``,
    propagate to the caller.
    """
    import time

    pixel_spacing = calculate_resolution_for_h3(h3_res, native_resolution)
    year_tag = time_of_interest[:4]

    if verbose:
        print(f"  [{year_tag}] Resolution: {pixel_spacing}m")

    configure_aws_access()

    catalog = Client.open("https://earth-search.aws.element84.com/v1")

    items = catalog.search(
        collections=[collection],
        bbox=bounds,
        datetime=time_of_interest,
        query={"eo:cloud_cover": {"lt": cloud_threshold}},
    ).item_collection()

    if len(items) == 0:
        if verbose:
            print(f"  [{year_tag}] No items found")
        return None

    if items_per_group > 0:
        sampled_items = stratified_sample_items(items, items_per_group, max_items)
        if verbose:
            print(f"  [{year_tag}] {len(sampled_items)} items from {len(items)} total")
    else:
        items_sorted = sorted(items, key=lambda x: x.properties.get('eo:cloud_cover', 100))
        sampled_items = items_sorted[:max_items]

    try:
        bands = [red_band, nir_band, blue_band] if evi else [red_band, nir_band]
        t0 = time.time()

        # Use chunks for parallel S3 fetching via dask.
        # NOTE(review): `resolution` is expressed in EPSG:3857 units, which only
        # equal ground meters at the equator - confirm this is acceptable.
        ds = odc.stac.load(
            sampled_items,
            crs="EPSG:3857",
            bands=bands,
            resolution=pixel_spacing,
            bbox=bounds,
            chunks={'x': 2048, 'y': 2048, 'time': 1},  # Parallel chunk loading
            fail_on_error=False,
        )

        if verbose:
            print(f"  [{year_tag}] Load graph built: {time.time() - t0:.1f}s")

    except Exception as e:
        # Best effort: a period that cannot be loaded is reported as "no data"
        if verbose:
            print(f"  [{year_tag}] Error: {e}")
        return None

    if not ds.data_vars:
        return None

    t0 = time.time()

    # Build lazy computation graph - float32 for memory efficiency.
    # Landsat Collection 2 Level-2 surface reflectance uses 0 as its fill value.
    # Mask it BEFORE applying scale/offset: otherwise fill becomes a reflectance
    # of -0.2 in every band, which produces a vegetation index of exactly 0 and
    # leaks into the time-max and the per-hex aggregate.
    ds = ds.where(ds != 0).astype('float32') * 0.0000275 - 0.2

    # Calculate vegetation index (still lazy)
    if evi:
        import xarray as xr
        nir = ds[nir_band]
        red = ds[red_band]
        blue = ds[blue_band]
        denominator = nir + 6 * red - 7.5 * blue + 1
        # Pixels whose denominator is not above 1.0 (or is NaN) are dropped
        vi_arr = xr.where(
            denominator > 1.0,
            2.5 * ((nir - red) / denominator),
            np.float32(np.nan)
        )
    else:
        vi_arr = (ds[nir_band] - ds[red_band]) / (ds[nir_band] + ds[red_band])
        vi_arr = vi_arr.clip(min=-1.0, max=1.0)

    # Max across time - SINGLE .compute() triggers all I/O + compute in parallel
    arr = vi_arr.max(dim="time", skipna=True).compute()

    if verbose:
        print(f"  [{year_tag}] Fetch+VI+max: {time.time() - t0:.1f}s")

    t0 = time.time()

    # Extract coordinates once
    x_coords = arr.coords['x'].values
    y_coords = arr.coords['y'].values
    ny, nx = len(y_coords), len(x_coords)

    # Flatten the grid into one coordinate pair per pixel, in the same
    # row-major order as arr.values.ravel() (x varies fastest).
    # Assumes arr's dimensions are ordered (y, x).
    xx = np.tile(x_coords, ny)
    yy = np.repeat(y_coords, nx)
    values = arr.values.ravel()

    # Filter invalid values BEFORE reprojection (fewer coords to transform)
    mask = np.isfinite(values)
    xx_valid = xx[mask]
    yy_valid = yy[mask]
    values_valid = values[mask].astype('float64')

    del arr, xx, yy, values  # Free memory

    # Reproject only valid points
    transformer = Transformer.from_crs("EPSG:3857", "EPSG:4326", always_xy=True)
    lons, lats = transformer.transform(xx_valid, yy_valid)

    del xx_valid, yy_valid

    if verbose:
        print(f"  [{year_tag}] Reproject {len(lats):,} pts: {time.time() - t0:.1f}s")

    t0 = time.time()

    # Create Arrow table. DuckDB resolves the name `py_table` in the query
    # below from this local variable, so it must stay in this scope.
    py_table = pa.table({
        'lat': pa.array(lats, type=pa.float64()),
        'lon': pa.array(lons, type=pa.float64()),
        'data': pa.array(values_valid, type=pa.float64())
    })

    del lats, lons, values_valid

    # H3 aggregation
    agg_map = {'mean': 'AVG', 'sum': 'SUM', 'max': 'MAX', 'min': 'MIN', 'median': 'MEDIAN'}
    sql_agg = agg_map.get(agg_func, 'AVG')

    query = f"""
        SELECT 
            h3_latlng_to_cell_string(lat, lon, {h3_res}) AS hex,
            {sql_agg}(data)::DOUBLE AS metric
        FROM py_table
        WHERE data IS NOT NULL AND isfinite(data)
        GROUP BY 1
    """

    con = get_con()
    try:
        # fetch_arrow_table() materializes the result, so it outlives the connection
        h3_result = con.execute(query).fetch_arrow_table()
    finally:
        con.close()

    if verbose:
        print(f"  [{year_tag}] H3 agg: {time.time() - t0:.1f}s → {h3_result.num_rows:,} cells")

    return h3_result

In [19]:
def combine_h3_results(h3_tables: List[pa.Table], agg_func: str = "mean") -> Optional[pa.Table]:
    """
    Combine multiple H3 aggregation results into one.

    Re-aggregates overlapping hexagons across time periods. Each input table
    contributes one value per hexagon, so the result is an aggregate of the
    per-table aggregates (e.g. an unweighted mean of per-period means).

    Parameters
    ----------
    h3_tables : list of pyarrow.Table
        Tables with columns ``hex`` and ``metric``.
    agg_func : str
        One of ``mean``, ``sum``, ``max``, ``min``, ``median``. Any other
        value silently falls back to the mean.

    Returns
    -------
    pyarrow.Table or None
        One row per hexagon with columns ``hex`` and ``metric``;
        ``None`` when ``h3_tables`` is empty.
    """
    if not h3_tables:
        return None

    # Referenced by name from the SQL below (DuckDB resolves local variables)
    combined = pa.concat_tables(h3_tables)

    agg_map = {'mean': 'AVG', 'sum': 'SUM', 'max': 'MAX', 'min': 'MIN', 'median': 'MEDIAN'}
    sql_agg = agg_map.get(agg_func, 'AVG')

    query = f"""
        SELECT hex, {sql_agg}(metric)::DOUBLE AS metric
        FROM combined
        GROUP BY hex
    """

    con = get_con()
    try:
        # The Arrow table is fully materialized before the connection closes
        return con.execute(query).fetch_arrow_table()
    finally:
        con.close()

In [20]:
def compute_period_diff(h3_first: pa.Table, h3_second: pa.Table):
    """
    Compute difference between two H3 aggregated periods.

    Arrow in, pandas out (final result is small).

    Parameters
    ----------
    h3_first, h3_second : pyarrow.Table
        Tables with columns ``hex`` and ``metric`` for the early and the
        recent period. DuckDB reads them by parameter name from this scope.

    Returns
    -------
    pandas.DataFrame
        One row per hexagon present in BOTH tables, with columns:
        ``hex``; ``pct_change_evi``, which is ``(second - first) * 100``
        rounded to 3 decimals (an absolute difference scaled by 100, not a
        relative change; the column keeps this name whichever index was
        aggregated); ``early_avg`` and ``current_avg``, the two metrics
        rounded to 3 decimals.
    """
    con = get_con()
    try:
        # .df() materializes the result, so it outlives the connection
        return con.sql("""
        SELECT 
            h3_first.hex,
            ROUND((h3_second.metric - h3_first.metric) * 100, 3) AS pct_change_evi,
            ROUND(h3_first.metric::DOUBLE, 3) AS early_avg,
            ROUND(h3_second.metric::DOUBLE, 3) AS current_avg
        FROM h3_first 
        INNER JOIN h3_second ON h3_first.hex = h3_second.hex
    """).df()
    finally:
        con.close()

In [21]:
def run_parallel(fn, arg_list, max_workers: int = 24):
    """
    Apply ``fn`` to each element of ``arg_list`` on a thread pool.

    Returns the results as a list, in the same order as ``arg_list``.
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
    with executor:
        # Drain the lazy result iterator while the pool is still open
        return [outcome for outcome in executor.map(fn, arg_list)]

In [22]:
def process_years_streaming(
    bounds: List[float],
    years: List[int],
    h3_res: int,
    evi: bool = True,
    cloud_threshold: int = 30,
    max_items: int = 50,
    items_per_group: int = 2,
    agg_func: str = "mean",
    verbose: bool = False
) -> Optional[pa.Table]:
    """
    Process multiple years with concurrency = len(years).

    Every year gets its own worker thread covering the window
    ``{year}-04-20`` to ``{year}-11-10``; the per-year H3 tables are then
    merged into a single table.

    Parameters
    ----------
    bounds : list of float
        Bounding box forwarded to ``process_single_period_to_h3``.
    years : list of int
        Years to process. An empty list yields ``None``.
    h3_res : int
        Target H3 resolution.
    evi, cloud_threshold, max_items, items_per_group, agg_func, verbose
        Forwarded unchanged to ``process_single_period_to_h3``;
        ``agg_func`` is also used to merge the per-year tables.

    Returns
    -------
    pyarrow.Table or None
        Combined table, or ``None`` when no year produced data.
    """
    # ThreadPoolExecutor rejects max_workers=0, so handle "nothing to do" here
    if not years:
        return None

    time_periods = [f"{year}-04-20/{year}-11-10" for year in years]

    h3_results = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(time_periods)) as pool:
        futures = [
            pool.submit(
                process_single_period_to_h3,
                time_of_interest=period,
                bounds=bounds,
                h3_res=h3_res,
                evi=evi,
                cloud_threshold=cloud_threshold,
                max_items=max_items,
                items_per_group=items_per_group,
                agg_func=agg_func,
                verbose=verbose
            )
            for period in time_periods
        ]

        # Collect in completion order; future.result() re-raises worker errors
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            if result is not None:
                h3_results.append(result)

    if not h3_results:
        return None

    return combine_h3_results(h3_results, agg_func)

In [23]:
def get_h3_diff(
    bounds: List[float],
    h3_res: int = 7,
    first_years: List[int] = [1992, 1993, 1994],
    second_years: List[int] = [2021, 2022, 2023],
    cloud_threshold: int = 30,
    evi: bool = True,
    agg_func: str = "mean",
    max_items: int = 50,
    items_per_group: int = 2,
    max_concurrent_years: int = 3,  # Limit concurrency since dask handles parallelism within each year
    verbose: bool = False
):
    """
    Compute vegetation index change between two time periods.

    Uses dask for parallel I/O within each year, with controlled year-level concurrency.

    Parameters
    ----------
    bounds : list of float
        Bounding box forwarded to ``process_single_period_to_h3``.
    h3_res : int
        Target H3 resolution.
    first_years, second_years : list of int
        Years making up the early and the recent period. Each year is
        processed over ``{year}-04-20`` to ``{year}-11-10``. The default
        lists are only read, never mutated.
    cloud_threshold, evi, max_items, items_per_group
        Forwarded unchanged to ``process_single_period_to_h3``.
    agg_func : str
        Aggregation used per year and again when merging the years of a period.
    max_concurrent_years : int
        Maximum number of years processed at the same time. Values below 1
        are treated as 1.
    verbose : bool
        Print progress.

    Returns
    -------
    pandas.DataFrame or None
        Result of ``compute_period_diff`` (columns ``hex``,
        ``pct_change_evi``, ``early_avg``, ``current_avg``), or ``None``
        when either period has no years or produced no data.
    """
    # Without at least one year per period there is nothing to compare, so
    # skip the expensive downloads (and avoid a zero-sized thread pool).
    if not first_years or not second_years:
        if verbose:
            print("Missing data for one or both periods")
        return None

    # A year listed twice, or in both periods, only needs to be fetched once
    all_years = list(dict.fromkeys(list(first_years) + list(second_years)))
    time_periods = [f"{year}-04-20/{year}-11-10" for year in all_years]

    # Limit concurrent years to avoid oversubscription (dask handles parallelism within each)
    num_workers = max(1, min(max_concurrent_years, len(time_periods)))

    if verbose:
        print(f"Processing {len(all_years)} years ({num_workers} concurrent)\n")

    h3_results = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = {}
        for year, period in zip(all_years, time_periods):
            future = pool.submit(
                process_single_period_to_h3,
                time_of_interest=period,
                bounds=bounds,
                h3_res=h3_res,
                evi=evi,
                cloud_threshold=cloud_threshold,
                max_items=max_items,
                items_per_group=items_per_group,
                agg_func=agg_func,
                verbose=verbose
            )
            futures[future] = year

        # future.result() re-raises any error from the worker thread
        for future in concurrent.futures.as_completed(futures):
            year = futures[future]
            result = future.result()
            if result is not None:
                h3_results[year] = result

    # Split results back into first/second periods
    first_tables = [h3_results[y] for y in first_years if y in h3_results]
    second_tables = [h3_results[y] for y in second_years if y in h3_results]

    if not first_tables or not second_tables:
        if verbose:
            print("Missing data for one or both periods")
        return None

    h3_first = combine_h3_results(first_tables, agg_func)
    h3_second = combine_h3_results(second_tables, agg_func)

    if verbose:
        print(f"\nEarly period: {h3_first.num_rows:,} H3 cells")
        print(f"Recent period: {h3_second.num_rows:,} H3 cells")
        print("Computing difference...")

    result = compute_period_diff(h3_first, h3_second)

    if verbose:
        print(f"Final: {len(result):,} matched hexagons")

    return result

In [4]:
def save_to_parquet(df, filename: str):
    """
    Save DataFrame to parquet using DuckDB's COPY statement.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to write. DuckDB reads it by parameter name from this scope.
    filename : str
        Destination path.
    """
    # The path is embedded in a single-quoted SQL literal; double any single
    # quote so a name such as "o'brien.parquet" does not break the statement.
    escaped = str(filename).replace("'", "''")
    duckdb.sql(f"copy (from df) to '{escaped}' (FORMAT parquet, PARQUET_VERSION v2)")
    print(f"Saved to {filename}")


## Example Usage

Get your bounding box from [boundingbox.klokantech.com](https://boundingbox.klokantech.com/) — select **CSV** format and paste the values directly as `BOUNDS = [...]`

In [None]:
# Pripyat/Chernobyl exclusion zone
# Order: [min_lon, min_lat, max_lon, max_lat] in degrees; the list is passed
# as `bbox=` to the STAC search and to odc.stac.load.
BOUNDS = [29.1026, 51.045, 30.6678, 51.8343]

In [None]:
# Run analysis: compare 1992-1994 with 2022-2024 at H3 resolution 8
df = get_h3_diff(
    bounds=BOUNDS,
    h3_res=8,
    first_years=[1992, 1993, 1994],
    second_years=[2022, 2023, 2024],
    evi=True,
    max_items=50,
    items_per_group=2,
    verbose=True
)

# get_h3_diff returns None when a period has no data
if df is not None:
    print(df.describe())

In [27]:
# Optionally save results (skipped when the analysis returned no data)
if df is not None:
    save_to_parquet(df, 'pripyat_EVI_change_v2_res_8.parquet')

Saved to pripyat_EVI_change_v2_res_8.parquet


## Visualization

In [5]:
# Reload the saved result so the visualization can run without recomputing
df =duckdb.sql("from read_parquet('pripyat_EVI_change_v2_res_8.parquet')").df()

In [None]:
def _h3_fill_colors(df, column: str):
    """Map ``df[column]`` to fill colors on the Roma diverging palette."""
    from lonboard.colormap import apply_continuous_cmap
    from matplotlib.colors import TwoSlopeNorm, Normalize
    from palettable.scientific.diverging import Roma_20

    # Clip the color range to the 5th-95th percentile so outliers do not
    # wash out the palette
    vmin = df[column].quantile(0.05)
    vmax = df[column].quantile(0.95)

    # Use TwoSlopeNorm for change columns (center at 0), linear for averages
    if 'change' in column:
        # TwoSlopeNorm raises unless vmin < vcenter < vmax. That fails when
        # (almost) every hexagon changed in the same direction, so push the
        # offending limit just past zero.
        tiny = 1e-9
        norm = TwoSlopeNorm(vmin=min(vmin, -tiny), vcenter=0, vmax=max(vmax, tiny))
    else:
        norm = Normalize(vmin=vmin, vmax=vmax)

    return apply_continuous_cmap(norm(df[column]), Roma_20, alpha=1)


def visualize_h3_diff(
    df,
    mapbox_token: str,
    column: str = 'pct_change_evi',
    opacity: float = 0.3592,
    coverage: float = 1
):
    """
    Visualize H3 difference data with lonboard.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``hex`` column and the column named by ``column``.
    mapbox_token : str
        Access token inserted into the Mapbox satellite basemap tile URL.
    column : str
        Column that drives the fill color. Names containing ``'change'`` get
        a diverging scale centered at 0; all others get a linear scale.
    opacity : float
        Opacity of the hexagon layer.
    coverage : float
        Hexagon coverage passed to the layer.

    Returns
    -------
    (Map, H3HexagonLayer)
        Returns (map, h3_layer) so you can mutate colors without reloading.
    """
    from lonboard import Map, H3HexagonLayer, BitmapTileLayer

    h3_layer = H3HexagonLayer.from_pandas(
        df,
        get_hexagon=df["hex"],
        get_fill_color=_h3_fill_colors(df, column),
        high_precision=True,
        auto_highlight=True,
        extruded=False,
        stroked=False,
        coverage=coverage,
        opacity=opacity,
    )

    basemap = BitmapTileLayer(
        data=f"https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v12/tiles/512/{{z}}/{{x}}/{{y}}@2x?access_token={mapbox_token}",
        tile_size=512,
        max_requests=-1,
    )

    return Map(layers=[basemap, h3_layer]), h3_layer


def update_colors(h3_layer, df, column: str = 'pct_change_evi'):
    """
    Update layer colors without reloading data.

    Recomputes the fill colors from ``df[column]`` with the same scaling
    rules as ``visualize_h3_diff`` and assigns them to the existing layer.
    """
    h3_layer.get_fill_color = _h3_fill_colors(df, column)

In [None]:
# Create map (loads data once)
load_dotenv()
# os.environ.get returns None when MAPBOX_TOKEN is unset; the basemap tile URL
# would then contain "access_token=None"
MAPBOX_TOKEN = os.environ.get('MAPBOX_TOKEN')
m, h3_layer = visualize_h3_diff(df, MAPBOX_TOKEN)
m

In [None]:
from palettable.scientific.diverging import Roma_15
Roma_15.mpl_colormap

In [None]:
# Switch to 90s average (no reload, just updates colors)
update_colors(h3_layer, df, 'early_avg')

In [None]:
# Switch to 2020s average
update_colors(h3_layer, df, 'current_avg')

In [None]:
# Switch back to change
update_colors(h3_layer, df, 'pct_change_evi')