"""
Script to compute zonal statistics for GEDI L4A points against a multiband HLS raster.
Steps:
1. Load GEDI L4A geospatial points (optionally buffer them into polygons; TODO: check for errors on rasterization).
2. Load single band image - e.g. HLS composite raster for one band.
3. Reproject points to raster CRS and rasterize zones.
4. Compute per-zone statistics (e.g., mean) using xrspatial.zonal_stats.
5. Join results back to the input GeoDataFrame and export to GeoPackage.
"""

In [1]:
import os
import glob
import numpy as np
import geopandas as gpd
import rioxarray
from rasterio import features
from shapely.geometry import Point
from xrspatial import zonal_stats
import pandas as pd
from pyproj import CRS

In [2]:
import sys
sys.path.append('/projects/my-private-bucket/HLS-1DCNN-AGB/code/timeseries_modeling/zonal_stats_vrt_list')
from zonal_stats_vrt import run_zonal_stats 

In [3]:
os.chdir("/tmp")

In [4]:
ref_year = 2022
ref_tile = 89

In [5]:
# s3_private_dir_path = 's3://maap-ops-workspace/private/rodrigo.leite'

In [6]:
# NOTE(review): this URI has a single slash after "s3:" (S3 URIs are normally
# written "s3://..."), and glob.glob only matches paths on the local filesystem.
# The recorded output of this cell is an empty list.
s3_public_dir_path = 's3:/maap-ops-workspace/public/rodrigo.leite'
glob.glob(f'/{s3_public_dir_path}/*')

[]

In [7]:
gpkg_points_path = f'/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t{ref_tile}_veg{ref_year}_outrm.gpkg'
# gpkg_points_path = f'{s3_private_dir_path}/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t{ref_tile}_veg{ref_year}_outrm.gpkg'
gpkg_points_path

'/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t89_veg2022_outrm.gpkg'

In [None]:
# Multiband image path
# img_path = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/yearly/br_af_grid90km_evi2_p95/tile_089/bands_vrt/HLS_89_01-01_12-30_2022_2022_percentile95.0evi2_Blue.vrt'

# Directory with your raster .vrt files
# dir_img = f'/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/yearly/br_af_grid90km_evi2_p95/tile_{ref_tile:03}/bands_vrt'
# dir_img = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/monthly/br_af_grid60km_prj_evi2_max/vrt/'
dir_img = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/monthly/br_af_grid60km_prj_evi2_max/vrt_test/'

# dir_img

In [None]:
output_gpkg_zonalstats_fn = f'/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/l4a_t90km_t{ref_tile}_veg{ref_year}_outrm_zonal_HLS.gpkg'
output_gpkg_zonalstats_fn

In [None]:
# Define excluded band suffixes
exclude_bands = ['ValidMask.vrt', 'count.vrt', 'yearDate.vrt', 'JulianDate.vrt']


In [None]:
ZONE_ID_COL = 'zone'
BUFFER_METERS = 0

## Testing

In [None]:
# Collect the band rasters, skipping the auxiliary bands listed in
# `exclude_bands`. glob returns files in arbitrary order, so the list is sorted:
# the loops below resume by list index, which only makes sense if a given
# index refers to the same file on every run.
img_paths = sorted(
    path
    for path in glob.glob(os.path.join(dir_img, '*.vrt'))
    if not path.endswith(tuple(exclude_bands))
)

In [None]:
# img_paths[-1]

In [None]:

def _int32_background(nodata):
    """Return `nodata` as an int32-safe background label, or 0 if it cannot be one."""
    if nodata is None:
        return 0
    try:
        as_float = float(nodata)
    except (TypeError, ValueError):
        return 0
    # NaN/inf and fractional values cannot be stored in an int32 label array.
    if not np.isfinite(as_float) or as_float != int(as_float):
        return 0
    limits = np.iinfo(np.int32)
    as_int = int(as_float)
    return as_int if limits.min <= as_int <= limits.max else 0


def build_zones_xarr(gpkg_points_path, img_path, ZONE_ID_COL='zone', BUFFER_METERS=0, all_touched=False):
    """Create a zone-labeled xarray aligned to the raster.

    The raster at `img_path` is used as the reference grid. The geometries in
    `gpkg_points_path` are reprojected to the raster CRS, optionally buffered,
    and burned into an int32 array with the raster's shape and transform.

    Parameters
    ----------
    gpkg_points_path : path of the vector file read with `geopandas.read_file`.
    img_path : path of the single-band raster that defines the grid.
    ZONE_ID_COL : column holding the zone ids. If the column is missing,
        sequential ids 1..N are generated in file order.
    BUFFER_METERS : buffer distance applied after reprojection when > 0.
        It is expressed in the units of the raster CRS (assumed to be metres
        - TODO confirm the rasters use a projected CRS).
    all_touched : forwarded to `rasterio.features.rasterize`.

    Returns
    -------
    A shallow copy of the reference raster whose data is replaced by the
    int32 zone labels.

    Notes
    -----
    * Background pixels get the raster nodata value when it is an integer
      that fits in int32, otherwise 0. The previous `fill=nodata` broke for
      rasters whose nodata is None, NaN, fractional or out of int32 range.
      Zone ids should therefore not use that background value.
    * With rasterize's default merge behaviour, geometries that cover the
      same pixel overwrite each other, so only one zone id survives there.
    """
    da = rioxarray.open_rasterio(img_path).squeeze()
    if 'band' in da.dims:
        da = da.squeeze('band', drop=True)

    raster_crs = CRS.from_user_input(da.rio.crs)

    gdf = gpd.read_file(gpkg_points_path)
    if ZONE_ID_COL not in gdf.columns:
        gdf = gdf.reset_index(drop=True)
        gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)

    gdf = gdf.to_crs(raster_crs)

    if BUFFER_METERS > 0:
        gdf["geometry"] = gdf.geometry.buffer(BUFFER_METERS)

    # (geom, value) pairs; ensure values are int32
    shapes = list(zip(gdf.geometry, gdf[ZONE_ID_COL].astype(np.int32)))

    zones_arr = features.rasterize(
        shapes,
        out_shape=da.shape,
        transform=da.rio.transform(),
        fill=_int32_background(da.rio.nodata),
        dtype="int32",
        all_touched=all_touched,
    )

    # Reuse the raster's coords/attrs so the zones stay aligned with the grid.
    zones_xarr = da.copy(deep=False)
    zones_xarr.data = zones_arr
    return zones_xarr


def zonal_stats_raster(zones_xarr, img_path, band_name='b', ZONE_ID_COL='zone'):
    """Compute mean zonal stats for a single raster.

    Parameters
    ----------
    zones_xarr : zone-label raster (as produced by `build_zones_xarr`).
        It must be on the same grid as the raster at `img_path`.
    img_path : path of the single-band raster to summarise.
    band_name : prefix of the output column, which is named `<band_name>_mean`.
    ZONE_ID_COL : name given to the zone-id column of the result, so the
        result can be merged on that column.

    Returns
    -------
    pandas.DataFrame with one row per zone label found in `zones_xarr` and
    the columns `ZONE_ID_COL` and `<band_name>_mean`.
    """
    da = rioxarray.open_rasterio(img_path).squeeze()
    if 'band' in da.dims:
        da = da.squeeze('band', drop=True)

    nodata = da.rio.nodata
    # Turn nodata pixels into NaN. Skip this when the raster declares no nodata.
    if nodata is not None:
        da = da.where(da != nodata)

    zs_df = zonal_stats(
        zones=zones_xarr,
        values=da,
        stats_funcs=['mean'],
        nodata_values=nodata,
    )

    # xrspatial labels the id column 'zone'. Previously ZONE_ID_COL was
    # ignored, so a merge on any other id column name failed.
    return zs_df.rename(columns={'zone': ZONE_ID_COL, 'mean': f'{band_name}_mean'})


def run_zonal_stats(gpkg_points_path, dir_img, output_path, ZONE_ID_COL='zone', BUFFER_METERS=0, exclude_bands=None):
    """Run zonal stats for all rasters in directory and save merged output.

    NOTE: this definition replaces the `run_zonal_stats` imported from
    `zonal_stats_vrt` at the top of the notebook.

    Parameters
    ----------
    gpkg_points_path : vector file with the GEDI points.
    dir_img : directory searched (non-recursively) for `*.vrt` rasters.
    output_path : destination passed to `GeoDataFrame.to_file`.
    ZONE_ID_COL : zone-id column. Sequential ids 1..N are generated when the
        column is missing, the same way `build_zones_xarr` generates them.
    BUFFER_METERS : forwarded to `build_zones_xarr`.
    exclude_bands : iterable of file-name suffixes to skip (default: none).

    Returns
    -------
    The input GeoDataFrame with one `<band>_mean` column per processed raster.

    Raises
    ------
    FileNotFoundError
        If no raster is left to process (previously a bare IndexError).

    Notes
    -----
    The zone raster is built once from the first raster, so all rasters are
    assumed to share its grid. A raster that fails is reported and skipped,
    so the output is still written with the bands that succeeded.
    """
    excluded = tuple(exclude_bands or ())

    # Sorted so the column order of the output is reproducible.
    img_paths = sorted(
        f for f in glob.glob(os.path.join(dir_img, "*.vrt"))
        if not f.endswith(excluded)
    )
    if not img_paths:
        raise FileNotFoundError(
            f"No .vrt rasters to process in {dir_img!r} (after applying exclude_bands)"
        )

    gdf = gpd.read_file(gpkg_points_path)
    if ZONE_ID_COL not in gdf.columns:
        gdf = gdf.reset_index(drop=True)
        gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)

    zones_xarr = build_zones_xarr(gpkg_points_path, img_paths[0], ZONE_ID_COL, BUFFER_METERS)

    failed_bands = []
    for i, img_path in enumerate(img_paths, 1):
        band_name = os.path.splitext(os.path.basename(img_path))[0]
        print(f"[{i}/{len(img_paths)}] Processing {band_name}...")
        try:
            out = zonal_stats_raster(zones_xarr, img_path, band_name, ZONE_ID_COL)
            gdf = gdf.merge(out, on=ZONE_ID_COL, how='left')
        except Exception as e:
            # Deliberate best effort: keep going so one bad raster does not
            # discard the bands already computed.
            print(f"❌ Error on {band_name}: {e}")
            failed_bands.append(band_name)

    gdf.to_file(output_path)
    print(f"✅ Saved: {output_path}")
    if failed_bands:
        print(f"Bands skipped because of errors: {failed_bands}")
    return gdf


In [None]:
# zonal_stats?

In [None]:
xarr_test = build_zones_xarr(gpkg_points_path=gpkg_points_path, 
                             img_path = img_paths[0], 
                             ZONE_ID_COL='zone', 
                             BUFFER_METERS=0, 
                             all_touched=True)

In [None]:
np.unique(xarr_test.data)

In [None]:
band_name = os.path.basename(img_paths[0]).replace('.vrt', '')
band_name

In [None]:
myzonal_test = zonal_stats_raster(zones_xarr = xarr_test, 
                                  img_path = img_paths[0], 
                                  band_name=band_name, 
                                  ZONE_ID_COL='zone')

In [None]:
myzonal_test

In [None]:
len(myzonal_test)

In [None]:
# Load the points into `gdf`, the frame the merge loops below accumulate into.
# The original read the file into `gdf_test` but then operated on `gdf`,
# which was never defined.
gdf = gpd.read_file(gpkg_points_path)
# Create ID for the zonal stats
if ZONE_ID_COL not in gdf.columns:
    gdf = gdf.reset_index(drop=True)
    gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)
gdf_test = gdf
len(gdf_test)

In [None]:
run_zonal_stats(
    gpkg_points_path=gpkg_points_path,
    dir_img=dir_img,
    output_path=output_gpkg_zonalstats_fn,
    ZONE_ID_COL=ZONE_ID_COL, 
    BUFFER_METERS=BUFFER_METERS, 
    exclude_bands=exclude_bands
)

In [None]:
# Test functions

In [None]:
%%time

# Create gpkg array to do zonal_stats
zones_xarr = build_zones_xarr(gpkg_points_path,
                              img_paths[0],
                              ZONE_ID_COL=ZONE_ID_COL,
                              BUFFER_METERS=BUFFER_METERS,
                              all_touched=False)

error_band = None
error_idx = None

for i, img_path in enumerate(img_paths):
    os.chdir("/tmp")
    # Defined before the try so the error report below can always use it
    band_name = os.path.basename(img_path).replace(".vrt", "")
    print(f'[{i}] Zonal stats: {band_name}')

    try:
        # zonal_stats_raster has no BUFFER_METERS parameter: buffering is
        # already applied when the zones are rasterized above.
        out = zonal_stats_raster(zones_xarr=zones_xarr,
                                 img_path=img_path,
                                 band_name=band_name,
                                 ZONE_ID_COL=ZONE_ID_COL)

        gdf = gdf.merge(out, on=ZONE_ID_COL, how="left")

    except Exception as e:
        print("\n❌ Error at index:", i)
        print("Band name:", band_name)
        print("File:", img_path)
        print("Error:", e, "\n")

        # Remember where the run stopped so the next cell can resume from it
        error_band = band_name
        error_idx = i
        break

        

In [None]:
print("Error index:", error_idx)
print("Error band:", error_band)

In [None]:
# Error at index: 95
# Wall time: 52min 54s

In [None]:
# Resume from the band that failed in the previous loop (or from the start)
start_idx = error_idx if 'error_idx' in locals() and error_idx is not None else 0

# Clear the previous failure so the summary printed afterwards reflects this run
error_band = None
error_idx = None

for i, img_path in enumerate(img_paths[start_idx:], start=start_idx):
    os.chdir("/tmp")
    # Defined before the try so the error report below can always use it
    band_name = os.path.basename(img_path).replace(".vrt", "")
    print(f'[{i}] Zonal stats: {band_name}')

    try:
        # zonal_stats_raster expects the zone raster first, then the image path.
        # The previous positional call passed the GeoPackage path and
        # BUFFER_METERS, which do not match its signature.
        out = zonal_stats_raster(zones_xarr=zones_xarr,
                                 img_path=img_path,
                                 band_name=band_name,
                                 ZONE_ID_COL=ZONE_ID_COL)
        gdf = gdf.merge(out, on=ZONE_ID_COL, how="left")

    except Exception as e:
        print("\n❌ Error at index:", i)
        print("Band name:", band_name)
        print("File:", img_path)
        print("Error:", e, "\n")

        error_band = band_name
        error_idx = i
        break


In [None]:
print("Error index:", error_idx)
print("Error band:", error_band)

In [None]:
gdf.columns

In [None]:
# gdf.head()

In [None]:
output_gpkg_zonalstats_withoutliers_fn = '/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/l4a_t90km_t89_veg2022_zonal_HLS.gpkg'

In [None]:
# Export gdf_stats
gdf.to_file(output_gpkg_zonalstats_withoutliers_fn)

In [None]:
# dropna() with default arguments removes every row that has a missing value in
# ANY column, i.e. not only in the zonal-stat columns added above.
gdf_stats_nona = gdf.dropna()

In [None]:
len(gdf)

In [None]:
len(gdf_stats_nona)

In [None]:
output_gpkg_zonalstats_fn

In [None]:
# Export gdf_stats
gdf_stats_nona.to_file(output_gpkg_zonalstats_fn)


In [None]:
output_gpkg_zonalstats_fn

# Test dask

In [None]:
def zonal_stats_raster_dask(
    gpkg_points_path,
    img_path,
    ZONE_ID_COL="zone",
    BUFFER_METERS=0,
    band_name="b",
    chunks="auto",                 # e.g., {'y': 2048, 'x': 2048}
    all_touched=True,
    compute=True                   # if False, return a lazy dask object
):
    """Compute the per-zone mean of one raster with a chunked (dask) read.

    The raster is opened with `chunks`, the geometries of `gpkg_points_path`
    are reprojected to its CRS, optionally buffered, and rasterized onto its
    grid with 0 as background. Pixel values are then grouped by zone label
    and averaged.

    Returns
    -------
    * `compute=True`: a pandas DataFrame with the columns `ZONE_ID_COL` and
      `<band_name>_mean`, without the background zone 0.
    * `compute=False`: the xarray result of the groupby mean, renamed to
      `<band_name>_mean`. Nothing is filtered or renamed beyond that.
    """
    import geopandas as gpd
    import numpy as np
    import rioxarray
    import xarray as xr
    from pyproj import CRS
    from rasterio import features

    mean_col = f"{band_name}_mean"

    # --- raster: chunked open, reduced to a 2-D array ---
    raster = rioxarray.open_rasterio(img_path, chunks=chunks).squeeze()
    if "band" in raster.dims:
        raster = raster.squeeze("band", drop=True)

    target_crs = CRS.from_user_input(raster.rio.crs)
    affine = raster.rio.transform()
    grid_shape = raster.shape
    nodata_value = raster.rio.nodata
    has_nodata = nodata_value is not None
    if has_nodata:
        # nodata pixels become NaN so they can be dropped further down
        raster = raster.where(raster != nodata_value)

    # --- vector: read, make sure a zone id exists, reproject, buffer ---
    points = gpd.read_file(gpkg_points_path)
    if ZONE_ID_COL not in points.columns:
        points = points.reset_index(drop=True)
        points[ZONE_ID_COL] = np.arange(1, len(points) + 1, dtype=np.int32)

    points_proj = points.to_crs(target_crs)
    if BUFFER_METERS and BUFFER_METERS > 0:
        points_proj["geometry"] = points_proj.geometry.buffer(BUFFER_METERS)

    # --- zones: rasterized eagerly, then chunked like the raster ---
    zone_ids = points_proj[ZONE_ID_COL].astype(np.int32)
    labelled_geoms = list(zip(points_proj.geometry, zone_ids))
    zone_grid = features.rasterize(
        shapes=labelled_geoms,
        out_shape=grid_shape,
        transform=affine,
        fill=0,
        dtype="int32",
        all_touched=all_touched,
    )
    zone_da = xr.DataArray(
        zone_grid, coords=raster.coords, dims=raster.dims, name="zones"
    )
    zone_da = zone_da.chunk(raster.chunks)

    # --- flatten both layers to one pixel axis and drop masked pixels ---
    flat = xr.Dataset({"val": raster, "zone": zone_da}).stack(z=("y", "x"))
    if has_nodata:
        flat = flat.dropna("z", subset=["val"])

    # --- mean of the values per zone label ---
    zone_means = flat["val"].groupby(flat["zone"]).mean()

    if not compute:
        # Caller receives the xarray object and can .compute() it later
        return zone_means.rename(mean_col)

    # --- materialize and hand back a plain table ---
    table = zone_means.compute().to_dataframe(name=mean_col).reset_index()
    table = table.rename(columns={"zone": ZONE_ID_COL})
    # zone 0 is the rasterization background, not a real zone
    return table[table[ZONE_ID_COL] != 0]


# Test

In [None]:
####################################################
# 1) Open raster as xarray (single band)
####################################################

# NOTE: `img_path` is not assigned in this cell; it is the loop variable left
# over from the band loops above, so run one of them (or assign it) first.
da = rioxarray.open_rasterio(img_path).squeeze()  # [y, x]
if 'band' in da.dims:
    da = da.squeeze('band', drop=True)

raster_crs = CRS.from_user_input(da.rio.crs)
transform  = da.rio.transform()
out_shape  = da.shape  # (rows, cols)
nodata     = da.rio.nodata

# Mask nodata pixels (they become NaN); nothing to mask if nodata is undeclared
if nodata is not None:
    da = da.where(da != nodata)



####################################################
# 2) Read points
####################################################

# NOTE: `gdf` must already hold the points (loaded in an earlier cell).
# Create ID for the zonal stats
if ZONE_ID_COL not in gdf.columns:
    gdf = gdf.reset_index(drop=True)
    gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)



####################################################
# 3) Reproject to raster CRS and buffer to polygons
####################################################

gdf_proj = gdf.to_crs(raster_crs)

# Buffer
if BUFFER_METERS > 0:
    gdf_proj["geometry"] = gdf_proj.geometry.buffer(BUFFER_METERS)


####################
# 4) Build (geometry, id) tuples for rasterization
####################
geom = list(zip(gdf_proj.geometry, gdf_proj[ZONE_ID_COL].astype(np.int32)))



####################################################
# 5) Rasterize zones (same shape/transform as raster)
####################################################
# features.rasterize() has no `nodata` or `masked` parameters, so passing them
# raised a TypeError. The background is labelled 0 because the raster nodata
# may be None/NaN/float and cannot always be stored in an int32 array.
# This assumes zone ids are >= 1, which holds for the auto-generated ids.
zones_arr = features.rasterize(
    geom,
    out_shape=out_shape,
    transform=transform,
    fill=0,  # background (no zone)
    dtype="int32",
    # all_touched=True
)






In [None]:
####################################################
# # 6) Wrap zones into an xarray aligned with the raster
####################################################

zones_xarr = da.copy(deep=False)  # copies coords/attrs, not data
zones_xarr.data = zones_arr
# zones_xarr.attrs["nodata"] = nodata
# zones_xarr


####################################################
# 7) Compute zonal stats using xrspatial (min/max/mean/etc.)
####################################################

# If your raster has nodata, pass it so stats ignore it
zs_df = zonal_stats(
    zones=zones_xarr,      # your integer-labeled zone raster
    values=da,             # the raster with values to summarize
    # stats_funcs=['mean', 'max', 'min', 'sum', 'std', 'var', 'count'],
    stats_funcs=['mean'],
    nodata_values=nodata,   # very important → ensures nodata is ignored
    return_type='pandas.DataFrame'
)


In [None]:
zs_df.head()

In [None]:
gdf = gdf.merge(zs_df, on=ZONE_ID_COL, how="left")

In [None]:
gdf.columns

In [None]:




####################################################
# # Keep only "mean" and rename the band
####################################################
# zs_df = pd.DataFrame({"mean": pd.Series(zs["mean"])}).rename_axis(ZONE_ID_COL).reset_index()

# Rename column -> e.g. "evi2_mean"
zs_df = zs_df.rename(columns={"mean": f"{band_name}_mean"})


####################################################
# 9) create output with zone id and values
####################################################
out = (gdf[[ZONE_ID_COL]]
       .drop_duplicates()
       .merge(zs_df, on=ZONE_ID_COL, how="left"))



In [None]:
gdf.columns

In [None]:
gdf.columns