"""
Script to compute zonal statistics for GEDI L4A points against a multiband HLS raster.
Steps:
1. Load GEDI L4A geospatial points (optionally buffer it into polygons - ## check errors on rastererization).
2. Load single band image - e.g. HLS composite raster for one band.
3. Reproject points to raster CRS and rasterize zones.
4. Compute per-zone statistics (e.g., mean) using xrspatial.zonal_stats.
5. Join results back to the input GeoDataFrame and export to GeoPackage.
"""

In [26]:
import os
import glob
import numpy as np
import geopandas as gpd
import rioxarray
from rasterio import features
from shapely.geometry import Point
from xrspatial import zonal_stats
import pandas as pd
from pyproj import CRS

In [27]:

def zonal_stats_raster(gpkg_points_path,img_path, ZONE_ID_COL = 'zone', BUFFER_METERS =0, band_name = 'b'):
        
    ####################################################
    # 1) Open raster as xarray (single band)
    ####################################################
    
    da = rioxarray.open_rasterio(img_path).squeeze()  # [y, x]
    if 'band' in da.dims:
        da = da.squeeze('band', drop=True)
        
    raster_crs = CRS.from_user_input(da.rio.crs)
    transform  = da.rio.transform()
    out_shape  = da.shape  # (rows, cols)
    nodata     = da.rio.nodata
    
    da = da.where(da != da.rio.nodata)
    
    
    
    ####################################################
    # 2) Read points
    ####################################################
    gdf = gpd.read_file(gpkg_points_path)
    
    # Create ID for the zonal stats
    if ZONE_ID_COL not in gdf.columns:
        gdf = gdf.reset_index(drop=True)
        gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)
    
    
    
    ####################################################
    # 3) Reproject to raster CRS and buffer to polygons 
    ####################################################
    
    gdf_proj = gdf.to_crs(raster_crs)
    
    # BUffer
    if BUFFER_METERS > 0:
        gdf_proj["geometry"] = gdf_proj.geometry.buffer(BUFFER_METERS)
    
    
    ####################
    # 4) Build (geometry, id) tuples for rasterization
    ####################
    geom = list(zip(gdf_proj.geometry, gdf_proj[ZONE_ID_COL]))
    
    
    
    ####################################################
    # 5) Rasterize zones (same shape/transform as raster)
    ####################################################
    zones_arr = features.rasterize(
        geom,
        out_shape=out_shape,
        transform=transform,
        fill= nodata,
        nodata = nodata,
        masked = True,
        # fill=0,# background (no zone)
        dtype="int32",
        # all_touched=True
    )
    
    
    
    ####################################################
    # # 6) Wrap zones into an xarray aligned with the raster
    ####################################################
    
    zones_xarr = da.copy(deep=False)  # copies coords/attrs, not data
    zones_xarr.data = zones_arr
    # zones_xarr.attrs["nodata"] = nodata
    # zones_xarr
    
    
    ####################################################
    # 7) Compute zonal stats using xrspatial (min/max/mean/etc.)
    ####################################################
    
    # If your raster has nodata, pass it so stats ignore it
    zs_df = zonal_stats(
        zones=zones_xarr,      # your integer-labeled zone raster
        values=da,             # the raster with values to summarize
        # stats_funcs=['mean', 'max', 'min', 'sum', 'std', 'var', 'count'],
        stats_funcs=['mean'],
        nodata_values=nodata,   # very important → ensures nodata is ignored
        return_type='pandas.DataFrame'
    )
    
    
    ####################################################
    # # Keep only "mean" and rename the band
    ####################################################
    # zs_df = pd.DataFrame({"mean": pd.Series(zs["mean"])}).rename_axis(ZONE_ID_COL).reset_index()
    
    # Rename column -> e.g. "evi2_mean"
    zs_df = zs_df.rename(columns={"mean": f"{band_name}_mean"})
    
    
    ####################################################
    # 9) create output with zone id and values
    ####################################################
    # out = (gdf[[ZONE_ID_COL]]
    #        .drop_duplicates()
    #        .merge(zs_df, on=ZONE_ID_COL, how="left"))
    
   
    
    return zs_df
    
    

In [28]:
ref_year = 2022
ref_tile = 89

In [29]:
gpkg_points_path = f'/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t{ref_tile}_veg{ref_year}_outrm.gpkg'

In [30]:
# Multiband image path
# img_path = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/yearly/br_af_grid90km_evi2_p95/tile_089/bands_vrt/HLS_89_01-01_12-30_2022_2022_percentile95.0evi2_Blue.vrt'

# Directory with your raster .vrt files
# dir_img = f'/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/yearly/br_af_grid90km_evi2_p95/tile_{ref_tile:03}/bands_vrt'
dir_img = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/GoogleEmbeddings/2022/bands_vrt_2022/'
# dir_img

In [31]:
output_gpkg_zonalstats_fn = f'/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t{ref_tile}_veg{ref_year}_outrm_zonal_GEV.gpkg'

In [32]:
# Define excluded band suffixes
# exclude_bands = ['ValidMask.vrt', 'count.vrt', 'yearDate.vrt', 'JulianDate.vrt']
exclude_bands = []


In [33]:
ZONE_ID_COL = 'zone'
BUFFER_METERS = 0

In [34]:
img_paths = glob.glob(os.path.join(dir_img, "*.vrt"))


img_paths = [
    f for f in img_paths
    if not any(f.endswith(ex) for ex in exclude_bands)
]

In [35]:
# img_paths

In [36]:
# Load the base gdf
gdf = gpd.read_file(gpkg_points_path)



In [37]:
# Create ID for the zonal stats
if ZONE_ID_COL not in gdf.columns:
    gdf = gdf.reset_index(drop=True)
    gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)

In [38]:
# fname = os.path.basename(img_paths[0])
# # fname
# # # -> 'GSE.V1.ANNUAL.2024.A00.vrt'   
# band_name = fname.replace(".vrt", "")
# band_name

In [39]:
%%time
# Loop through rasters
for img_path in img_paths:
   
    fname = os.path.basename(img_path)
    # fname
    # # -> 'GSE.V1.ANNUAL.2024.A00.vrt'   
    band_name = fname.replace(".vrt", "")
    
    print(f'Zonal stats: {band_name}')

    
    # Run your function (it returns a gdf with ZONE_ID + stats)
    out = zonal_stats_raster(gpkg_points_path, img_path, ZONE_ID_COL, BUFFER_METERS, band_name)
    
    # print(f'Merge: {band_name}')
    # # Merge results back into main gdf (on zone_id)
    gdf = gdf.merge(out, on=ZONE_ID_COL, how="left")


# ~10k points 64 bands 90 km tile
# Wall time: 17min 48s

Zonal stats: GSE.V1.ANNUAL.2022.A00
Zonal stats: GSE.V1.ANNUAL.2022.A00
Zonal stats: GSE.V1.ANNUAL.2022.A01
Zonal stats: GSE.V1.ANNUAL.2022.A01
Zonal stats: GSE.V1.ANNUAL.2022.A02
Zonal stats: GSE.V1.ANNUAL.2022.A02
Zonal stats: GSE.V1.ANNUAL.2022.A03
Zonal stats: GSE.V1.ANNUAL.2022.A03
Zonal stats: GSE.V1.ANNUAL.2022.A04
Zonal stats: GSE.V1.ANNUAL.2022.A04
Zonal stats: GSE.V1.ANNUAL.2022.A05
Zonal stats: GSE.V1.ANNUAL.2022.A05
Zonal stats: GSE.V1.ANNUAL.2022.A06
Zonal stats: GSE.V1.ANNUAL.2022.A06
Zonal stats: GSE.V1.ANNUAL.2022.A07
Zonal stats: GSE.V1.ANNUAL.2022.A07
Zonal stats: GSE.V1.ANNUAL.2022.A08
Zonal stats: GSE.V1.ANNUAL.2022.A08
Zonal stats: GSE.V1.ANNUAL.2022.A09
Zonal stats: GSE.V1.ANNUAL.2022.A09
Zonal stats: GSE.V1.ANNUAL.2022.A10
Zonal stats: GSE.V1.ANNUAL.2022.A10
Zonal stats: GSE.V1.ANNUAL.2022.A11
Zonal stats: GSE.V1.ANNUAL.2022.A11
Zonal stats: GSE.V1.ANNUAL.2022.A12
Zonal stats: GSE.V1.ANNUAL.2022.A12
Zonal stats: GSE.V1.ANNUAL.2022.A13
Zonal stats: GSE.V1.ANNUAL.2

In [40]:
# gdf.columns

In [41]:
gdf_stats_nona = gdf.dropna()

In [42]:
len(gdf_stats_nona)

20225

In [None]:
# Export gdf_stats
gdf_stats_nona.to_file(output_gpkg_zonalstats_fn)


In [43]:
output_gpkg_zonalstats_fn

'/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t89_veg2022_outrm_zonal_GEV.gpkg'

# Test

In [23]:
####################################################
# 1) Open raster as xarray (single band)
####################################################

da = rioxarray.open_rasterio(img_path).squeeze()  # [y, x]
if 'band' in da.dims:
    da = da.squeeze('band', drop=True)
    
raster_crs = CRS.from_user_input(da.rio.crs)
transform  = da.rio.transform()
out_shape  = da.shape  # (rows, cols)
nodata     = da.rio.nodata

da = da.where(da != da.rio.nodata)



####################################################
# 2) Read points
####################################################

# Create ID for the zonal stats
if ZONE_ID_COL not in gdf.columns:
    gdf = gdf.reset_index(drop=True)
    gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)



####################################################
# 3) Reproject to raster CRS and buffer to polygons 
####################################################

gdf_proj = gdf.to_crs(raster_crs)

# BUffer
if BUFFER_METERS > 0:
    gdf_proj["geometry"] = gdf_proj.geometry.buffer(BUFFER_METERS)


####################
# 4) Build (geometry, id) tuples for rasterization
####################
geom = list(zip(gdf_proj.geometry, gdf_proj[ZONE_ID_COL]))








In [25]:
print(out_shape, transform, nodata)

(14379, 14156) | 10.00, 0.00, 5466330.00|
| 0.00,-10.00, 7362430.00|
| 0.00, 0.00, 1.00| -inf


In [24]:

####################################################
# 5) Rasterize zones (same shape/transform as raster)
####################################################
zones_arr = features.rasterize(
    geom,
    out_shape=out_shape,
    transform=transform,
    fill= nodata,
    nodata = nodata,
    masked = True,
    # fill=0,# background (no zone)
    dtype="int32",
    # all_touched=True
)


OverflowError: cannot convert float infinity to integer

In [None]:
####################################################
# # 6) Wrap zones into an xarray aligned with the raster
####################################################

zones_xarr = da.copy(deep=False)  # copies coords/attrs, not data
zones_xarr.data = zones_arr
# zones_xarr.attrs["nodata"] = nodata
# zones_xarr


####################################################
# 7) Compute zonal stats using xrspatial (min/max/mean/etc.)
####################################################

# If your raster has nodata, pass it so stats ignore it
zs_df = zonal_stats(
    zones=zones_xarr,      # your integer-labeled zone raster
    values=da,             # the raster with values to summarize
    # stats_funcs=['mean', 'max', 'min', 'sum', 'std', 'var', 'count'],
    stats_funcs=['mean'],
    nodata_values=nodata,   # very important → ensures nodata is ignored
    return_type='pandas.DataFrame'
)


In [None]:
zs_df.head()

In [None]:
gdf = gdf.merge(zs_df, on=ZONE_ID_COL, how="left")

In [None]:
gdf.columns

In [None]:




####################################################
# # Keep only "mean" and rename the band
####################################################
# zs_df = pd.DataFrame({"mean": pd.Series(zs["mean"])}).rename_axis(ZONE_ID_COL).reset_index()

# Rename column -> e.g. "evi2_mean"
zs_df = zs_df.rename(columns={"mean": f"{band_name}_mean"})


####################################################
# 9) create output with zone id and values
####################################################
out = (gdf[[ZONE_ID_COL]]
       .drop_duplicates()
       .merge(zs_df, on=ZONE_ID_COL, how="left"))



In [None]:
gdf.columns

In [None]:
gdf.columns