"""
Script to compute zonal statistics for GEDI L4A points against a single band of a multiband HLS composite raster.
Steps:
1. Load GEDI L4A geospatial points (optionally buffer them into polygons; TODO: check errors on rasterization).
2. Load a single-band image, e.g. the HLS composite raster for one band.
3. Reproject the points to the raster CRS and rasterize them into zones.
4. Compute per-zone statistics (e.g., mean) using xrspatial.zonal_stats.
5. Join the results back to the input GeoDataFrame and export to GeoPackage.
"""

In [35]:
import os
import numpy as np
import geopandas as gpd
import rioxarray
from rasterio import features
from shapely.geometry import Point
from xrspatial import zonal_stats
import pandas as pd
from pyproj import CRS

In [36]:
# GeoPackage holding the GEDI L4A points to sample (loaded with gpd.read_file below)
l4a_path = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t89_veg2020_outrm.gpkg'

In [37]:
# Path to the image to sample: a single-band VRT (here the Blue band) of the HLS composite
img_path = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/yearly/br_af_grid90km_evi2_p95/tile_089/bands_vrt/HLS_89_01-01_12-30_2022_2022_percentile95.0evi2_Blue.vrt'

In [38]:
# Get band name: derive the output variable name from the image file name
fname = os.path.basename(img_path)
# -> "HLS_89_01-01_12-30_2022_2022_percentile95.0evi2_Blue.vrt"

# Strip the extension with splitext so any raster extension (.vrt, .tif, ...) is
# removed, and only at the end of the name.
# NOTE: assumes the file name carries an extension (the stem itself contains a dot).
fn_parts = os.path.splitext(fname)[0].split("_")
# ['HLS', '89', '01-01', '12-30', '2022', '2022', 'percentile95.0evi2', 'Blue']

# The indices used below need at least six "_"-separated fields; fail with a
# readable message instead of a bare IndexError.
if len(fn_parts) < 6:
    raise ValueError(
        f"Unexpected file name '{fname}': expected at least 6 '_'-separated fields, "
        f"got {len(fn_parts)}."
    )

# Rebuild the string to variable name.
# Fields 2/3 are month-day strings and 4/5 are years (presumably start/end of the
# compositing window - confirm against the composite naming scheme).
prefix, first_md, last_md, first_year, last_year = (fn_parts[i] for i in (0, 2, 3, 4, 5))
band = fn_parts[-1]
band_name = f"{prefix}.{first_year}-{first_md}.{last_year}-{last_md}.{band}"
band_name

'HLS.2022-01-01.2022-12-30.Blue'

In [39]:
# Buffer distance applied around each point after reprojection to the raster CRS.
# NOTE(review): buffering happens in raster-CRS units, so "meters" assumes the raster
# CRS is projected in metres - confirm.
# BUFFER_METERS = 12.5         # set >0 to make polygons around each point; set 0 to treat each point as a single-pixel sample
BUFFER_METERS = 0         
ZONE_ID_COL   = "zone"   # name for per-feature unique IDs

In [40]:
# 1) Open raster as xarray (single band)
da = rioxarray.open_rasterio(img_path)
if "band" in da.dims:
    # This workflow samples exactly one band: fail early with a clear message
    # rather than letting squeeze() raise on a multiband file.
    if da.sizes["band"] != 1:
        raise ValueError(
            f"Expected a single-band raster, got {da.sizes['band']} bands: {img_path}"
        )
    da = da.squeeze("band", drop=True)  # [y, x]

# Grid description reused when rasterizing the zones onto the same pixel grid
raster_crs = CRS.from_user_input(da.rio.crs)
transform  = da.rio.transform()
out_shape  = da.shape  # (rows, cols)
nodata     = da.rio.nodata

# Replace nodata pixels with NaN so they do not contribute to the statistics.
# Skip when the raster declares no nodata value (nothing to mask).
if nodata is not None:
    da = da.where(da != nodata)

In [41]:
# temp_tif_fn = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/temp_blue_img.tif'
# da.rio.to_raster(temp_tif_fn)

In [42]:
# 2) Load the GEDI points into a GeoDataFrame
gdf = gpd.read_file(l4a_path)

# Every feature needs an integer label for the zonal stats: keep the column if the
# file already provides it, otherwise number the rows 1..N on a fresh 0-based index.
if ZONE_ID_COL not in gdf.columns:
    zone_labels = np.arange(len(gdf), dtype=np.int32) + 1
    gdf = gdf.reset_index(drop=True).assign(**{ZONE_ID_COL: zone_labels})

In [43]:
# 3) Reproject to raster CRS and (optionally) buffer the points into polygons
gdf_proj = gdf.to_crs(raster_crs)


# Buffer
if BUFFER_METERS > 0:
    # buffer() works in CRS units: on a geographic CRS the distance would be read
    # as degrees, silently producing enormous polygons.
    if raster_crs.is_geographic:
        raise ValueError(
            "BUFFER_METERS > 0 requires a projected raster CRS; "
            f"got geographic CRS: {raster_crs.name}"
        )
    gdf_proj["geometry"] = gdf_proj.geometry.buffer(BUFFER_METERS)



In [44]:
# temp_pts_prj = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/temp_refpts_prj.gpkg'
# gdf_proj.to_file(temp_pts_prj)

In [45]:
# 4) Pair each reprojected geometry with its zone ID; these (geometry, value)
#    tuples are the shapes handed to features.rasterize
geom = [
    (shape, zone_id)
    for shape, zone_id in zip(gdf_proj.geometry, gdf_proj[ZONE_ID_COL])
]


In [46]:
# geom
# ?features.rasterize

In [47]:
# 5) Rasterize zones (same shape/transform as raster)
# The background uses a dedicated label instead of the value raster's nodata:
# that nodata may be missing, non-integer, or equal to a real zone ID, and the
# `nodata=`/`masked=` keywords are (as far as I know) only in recent rasterio releases.
ZONE_BACKGROUND = 0
if (gdf_proj[ZONE_ID_COL] == ZONE_BACKGROUND).any():
    raise ValueError(
        f"Zone ID {ZONE_BACKGROUND} is reserved for the background; "
        f"column '{ZONE_ID_COL}' must not contain it."
    )

# NOTE(review): with rasterize's default merge algorithm a pixel keeps only the last
# value burned into it, so features sharing a pixel with a later feature end up with
# no zone pixels (and NaN statistics after the join) - confirm this is acceptable.
zones_arr = features.rasterize(
    geom,
    out_shape=out_shape,
    transform=transform,
    fill=ZONE_BACKGROUND,  # background (no zone)
    dtype="int32",
    # all_touched=True
)

# Mask the background so downstream code sees "no zone" there, as the masked
# output of the previous call did.
zones_arr = np.ma.masked_equal(zones_arr, ZONE_BACKGROUND)

In [48]:
# 6) Put the zone labels on the value raster's grid: a shallow copy of `da`
#    keeps its coords/attrs while the pixel data is swapped for the labels
zones_xarr = da.copy(deep=False, data=zones_arr)

In [49]:
# temp_tif_fn = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/temp_zones_img.tif'
# zones_xarr.rio.to_raster(temp_tif_fn)

In [50]:
# ?zonal_stats

In [51]:
%%time
# 7) Compute zonal stats using xrspatial (min/max/mean/etc.)
#    If your raster has nodata, pass it so stats ignore it
zs_df = zonal_stats(
    zones=zones_xarr,      # your integer-labeled zone raster
    values=da,             # the raster with values to summarize
    # stats_funcs=['mean', 'max', 'min', 'sum', 'std', 'var', 'count'],
    stats_funcs=['mean'],
    nodata_values=nodata,   # very important → ensures nodata is ignored
    return_type='pandas.DataFrame'
)


CPU times: user 565 ms, sys: 47 ms, total: 612 ms
Wall time: 611 ms


In [25]:
# zs_df

In [52]:
%%time
from rasterstats import point_query

values = point_query(gdf, img_path)

CPU times: user 24.9 s, sys: 11.2 s, total: 36.1 s
Wall time: 3min


In [54]:
# values

In [26]:
# Label the statistic with the band it was computed from,
# e.g. "HLS.2022-01-01.2022-12-30.Blue_mean"
mean_col = f"{band_name}_mean"
zs_df = zs_df.rename(columns={"mean": mean_col})

print(zs_df.head())

   zone  HLS.2022-01-01.2022-12-30.Blue_mean
0  44.0                               0.0243
1  59.0                               0.0116
2  60.0                               0.0221
3  61.0                               0.0137
4  62.0                               0.0149


In [27]:
# 9) Look up the statistics for every distinct zone ID of the input table;
#    zones without a row in zs_df get NaN (left join)
unique_zone_ids = gdf[[ZONE_ID_COL]].drop_duplicates()
out = unique_zone_ids.merge(zs_df, how="left", on=ZONE_ID_COL)

# 10) Attach those statistics to the original features (one row per feature)
gdf_stats = gdf.merge(out, how="left", on=ZONE_ID_COL)

In [28]:
# # 9) Join back to attributes test with gdf_proj
# out = (gdf_proj[[ZONE_ID_COL]]
#        .drop_duplicates()
#        .merge(zs_df, on=ZONE_ID_COL, how="left"))

# # 10) (Optional) Merge to original gdf (one row per feature) and save
# gdf_stats = gdf_proj.merge(out, on=ZONE_ID_COL, how="left")

In [29]:
# gdf_stats["HLS.2022-01-01.2022-12-30.Blue_mean"].isna().sum()

In [30]:
col = f"{band_name}_mean"
# --- Filter to NaN points in the target column ---
# These are the features that received no statistic from the zonal-stats join.
nan_pts = gdf_stats[gdf_stats[col].isna()].copy()
print(f"NaN points in '{col}': {len(nan_pts)}")

# Having no missing values is the good outcome: report it instead of raising, so
# the export cells below still run.
if nan_pts.empty:
    print(f"No NaN points found in column '{col}'.")



NaN points in 'HLS.2022-01-01.2022-12-30.Blue_mean': 598


In [31]:
# Keep only the features that received a value for this band
gdf_stats_nona = gdf_stats.dropna(subset=[f"{band_name}_mean"])

In [32]:
# Number of features left after dropping those without a value
len(gdf_stats_nona)

10639

In [33]:
# Output GeoPackage for the points with their sampled band statistic
l4a_zonalstats_fn = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t89_veg2020_outrm_zonal.gpkg'

In [34]:
# Export the NaN-free table (gdf_stats_nona). The driver is set explicitly so the
# output is a GeoPackage even on geopandas versions that do not infer the format
# from the ".gpkg" extension.
gdf_stats_nona.to_file(l4a_zonalstats_fn, driver="GPKG")