"""
Script to compute zonal statistics for GEDI L4A points against a multiband HLS raster.
Steps:
1. Load GEDI L4A geospatial points (optionally buffer it into polygons - ## check errors on rastererization).
2. Load single band image - e.g. HLS composite raster for one band.
3. Reproject points to raster CRS and rasterize zones.
4. Compute per-zone statistics (e.g., mean) using xrspatial.zonal_stats.
5. Join results back to the input GeoDataFrame and export to GeoPackage.
"""

In [2]:
import os
import glob
import numpy as np
import geopandas as gpd
import rioxarray
from rasterio import features
from shapely.geometry import Point
from xrspatial import zonal_stats
import pandas as pd
from pyproj import CRS

In [3]:

def zonal_stats_raster(gpkg_points_path, img_path, ZONE_ID_COL='zone', BUFFER_METERS=0, band_name='b'):
    """Compute the mean of a single-band raster for every point (zone).

    Parameters
    ----------
    gpkg_points_path : str, os.PathLike or geopandas.GeoDataFrame
        Vector file with the points, or an already loaded GeoDataFrame.
        Passing the GeoDataFrame avoids re-reading the file for every raster.
        The caller's frame is not modified.
    img_path : str or os.PathLike
        Raster opened with ``rioxarray.open_rasterio``. It is squeezed to 2-D
        (y, x), so it is expected to hold a single band.
    ZONE_ID_COL : str
        Integer column identifying each point's zone. When absent it is
        created as 1..N after resetting the index.
    BUFFER_METERS : float
        If > 0, geometries are buffered by this distance before
        rasterization. The distance is in raster-CRS units (metres only if
        that CRS is metric).
    band_name : str
        Prefix of the output statistic column, ``f"{band_name}_mean"``.

    Returns
    -------
    pandas.DataFrame
        Columns ``ZONE_ID_COL`` (cast to the dtype of the points' id column)
        and ``f"{band_name}_mean"``, one row per zone present in the
        rasterized zone grid. Points outside the raster grid get no row.
        A point whose pixel is taken by a later point also gets no row,
        because rasterize overwrites by default.
    """
    # 1) Open the raster as a 2-D (y, x) DataArray
    da = rioxarray.open_rasterio(img_path).squeeze()
    if 'band' in da.dims:
        da = da.squeeze('band', drop=True)

    raster_crs = CRS.from_user_input(da.rio.crs)
    transform = da.rio.transform()
    out_shape = da.shape  # (rows, cols)
    nodata = da.rio.nodata

    # Mask the raster's nodata to NaN so it does not enter the statistics.
    if nodata is not None:
        da = da.where(da != nodata)

    # 2) Read the points and make sure every point carries a zone id
    if isinstance(gpkg_points_path, gpd.GeoDataFrame):
        gdf = gpkg_points_path
    else:
        gdf = gpd.read_file(gpkg_points_path)

    if ZONE_ID_COL not in gdf.columns:
        # reset_index returns a new frame, so a caller-supplied gdf is untouched
        gdf = gdf.reset_index(drop=True)
        gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)

    stat_col = f"{band_name}_mean"
    if gdf.empty:
        # rasterio raises on an empty shape list; there is nothing to summarise.
        return pd.DataFrame({
            ZONE_ID_COL: pd.Series(dtype=gdf[ZONE_ID_COL].dtype),
            stat_col: pd.Series(dtype="float64"),
        })

    # 3) Reproject to the raster CRS and optionally buffer points to polygons
    gdf_proj = gdf.to_crs(raster_crs)
    if BUFFER_METERS > 0:
        geom_col = gdf_proj.geometry.name  # not necessarily "geometry"
        gdf_proj[geom_col] = gdf_proj.geometry.buffer(BUFFER_METERS)

    # 4) Pick a background label for pixels outside every zone. It must be a
    #    valid int32 and must not equal any zone id, otherwise that zone would
    #    be silently excluded. The raster's own nodata (possibly None, NaN or a
    #    float) is therefore not used here. Generated ids start at 1, so 0 is
    #    free in the default case.
    zone_ids = gdf_proj[ZONE_ID_COL].to_numpy()
    background = 0
    if zone_ids.min() <= 0:
        background = int(zone_ids.min()) - 1

    # 5) Rasterize the zones onto the raster grid (same shape/transform)
    shapes = list(zip(gdf_proj.geometry, zone_ids))
    zones_arr = features.rasterize(
        shapes,
        out_shape=out_shape,
        transform=transform,
        fill=background,
        dtype="int32",
    )

    # 6) Zone raster with the same dims/coords as the values raster
    zones_xarr = da.copy(deep=False, data=zones_arr)

    # 7) Per-zone mean. nodata_values refers to the *zones* raster:
    #    background pixels belong to no zone.
    zs_df = zonal_stats(
        zones=zones_xarr,
        values=da,
        stats_funcs=['mean'],
        nodata_values=background,
        return_type='pandas.DataFrame',
    )

    # 8) xrspatial names the id column "zone" and returns it as float; align
    #    its name and dtype with the points so callers can merge on ZONE_ID_COL.
    zs_df = zs_df.rename(columns={"zone": ZONE_ID_COL, "mean": stat_col})
    zs_df[ZONE_ID_COL] = zs_df[ZONE_ID_COL].astype(gdf[ZONE_ID_COL].dtype)
    return zs_df
    
    

In [4]:
# Reference year of the target variable
ref_year = 2022
# Tile id of the processing grid (used in the input/output paths below)
ref_tile = 89
# First and last year of the time series
start_year = 2020
end_year = 2022

In [5]:
gpkg_points_path = f'/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/l4a_t90km_t{ref_tile}_veg{ref_year}_outrm.gpkg'

In [6]:
# Directory holding the per-band/per-month composite VRTs of the reference tile
dir_img = (
    '/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/monthly/'
    'br_af_grid30km_prj_evi2_p95/'
    f'tile_{ref_tile:03d}/bands_vrt_{start_year}-{end_year}/'
)

In [7]:
# GeoPackage that receives the points with their zonal-statistics time series
output_gpkg_zonalstats_fn = (
    '/projects/my-private-bucket/HLS-1DCNN-AGB/data/shp/gedi/test/'
    f'l4a_t90km_t{ref_tile}_{start_year}-{end_year}_outrm_zonal_TimeSeries.gpkg'
)

In [8]:
# File-name suffixes of auxiliary VRT layers (mask/count/date layers) that
# are skipped when collecting the rasters to summarise.
exclude_bands = [
    'ValidMask.vrt',
    'count.vrt',
    'yearDate.vrt',
    'JulianDate.vrt',
]


In [9]:
# Column holding the integer zone id of each point
ZONE_ID_COL = 'zone'
# Buffer distance applied to points before rasterization (0 = no buffer)
BUFFER_METERS = 0

In [10]:
# Display the input directory built from ref_tile/start_year/end_year
dir_img

'/projects/my-private-bucket/HLS-1DCNN-AGB/data/tif/HLS_composites/monthly/br_af_grid30km_prj_evi2_p95/tile_089/bands_vrt_2020-2022/'

In [11]:
# glob.glob(f'{dir_img}/*')

In [14]:
# img_path = img_paths[0]

In [15]:
# img_path

In [16]:
# stem = Path(img_path).stem
# print(stem)
# parts = stem.split("_")
# print(parts)
# year = int(parts[4])
# print(year)

# month = parts[2].split('-')[0]
# print(month)

In [17]:
from pathlib import Path


# def parse_key(fpath: str):
#     """
#     Extract (band, year, month) so sorting is band-first, then time.
#     """
#     stem = Path(fpath).stem
#     parts = stem.split("_")
    
#     # year is 4th element, month is first two chars of 2nd element
#     year = int(parts[4])
#     month = int(parts[1][:2])
    
#     band = parts[-1]  # last part is band name
    
#     return (band, year, month)



def parse_key(fpath: str):
    """
    Build a sort key ``(band, year, month)`` from an HLS composite file name,
    so files sort band-first, then chronologically.

    Example: "HLS_89_01-01_01-31_2020_2020_percentile95.0evi2_Blue.vrt"
    splits on "_" into
    ['HLS', '89', '01-01', '01-31', '2020', '2020', 'percentile95.0evi2', 'Blue']
    and yields ('Blue', 2020, '01').

    Note: ``month`` is the zero-padded string of the start date (e.g. '01'),
    so chronological order within a year relies on that padding.
    """
    fields = Path(fpath).stem.split("_")
    band = fields[-1]                # trailing token, e.g. 'Blue'
    year = int(fields[4])            # 5th token: start year
    month = fields[2].split('-')[0]  # 3rd token 'MM-DD' of the start date
    return band, year, month



# def parse_key(fpath: str):
#     """
#     Extract (year, month, band) using the date pattern:
#     e.g., HLS_89_01-01_01-31_2020_2020_percentile95.0evi2_Blue.vrt
#                        ^^^^^
#                        month is '01'
#     """
#     stem = Path(fpath).stem
#     parts = stem.split("_")

#     # parts[2] is '01-01' (MM-DD)
#     month = int(parts[2][:2])
#     year = int(parts[3])   # '2020'

#     band = parts[-1]       # 'Blue'

#     return (year, month, band)

In [18]:
# Gather every VRT below dir_img (searching sub-folders too), keeping only
# those whose file name does not end with an excluded suffix.
img_paths = [
    vrt_path
    for vrt_path in glob.glob(f'{dir_img}/**/*.vrt', recursive=True)
    if not vrt_path.endswith(tuple(exclude_bands))
]

In [19]:
# Order band-first, then chronologically (key built by parse_key)
img_paths.sort(key=parse_key)

In [21]:
# img_paths[0:10]

In [33]:
# Load the base gdf: the points that the per-raster statistics are joined onto
gdf = gpd.read_file(gpkg_points_path)



In [34]:
# Give every point a 1-based int32 zone id (after resetting the index),
# unless the file already provides ZONE_ID_COL.
if ZONE_ID_COL not in gdf.columns:
    gdf = gdf.reset_index(drop=True).assign(
        **{ZONE_ID_COL: np.arange(1, len(gdf) + 1, dtype=np.int32)}
    )

In [35]:
# Number of rasters to process (after excluding auxiliary layers)
len(img_paths)

627

In [36]:
%%time
# Loop through rasters: compute the per-zone mean of each raster and join all
# of them to `gdf` by zone id. Per-band tables are collected first and merged
# once at the end. Merging inside the loop copies the ever-growing
# GeoDataFrame once per band, which gets slow with hundreds of bands.
# NOTE: `img_paths[:2]` looks like a quick test subset (see timings below);
# drop the slice for a full run.
band_tables = []
for img_path in img_paths[:2]:
    # e.g. "HLS_89_01-01_12-30_2022_2022_percentile95.0evi2_Blue.vrt"
    #   -> ['HLS', '89', '01-01', '12-30', '2022', '2022', 'percentile95.0evi2', 'Blue']
    fn_parts = os.path.basename(img_path).replace(".vrt", "").split("_")
    # Column prefix, e.g. "HLS.2022-01-01.2022-12-30.Blue"
    band_name = f"{fn_parts[0]}.{fn_parts[4]}-{fn_parts[2]}.{fn_parts[5]}-{fn_parts[3]}.{fn_parts[-1]}"

    # One row per zone: ZONE_ID_COL + "<band_name>_mean"
    out = zonal_stats_raster(gpkg_points_path, img_path, ZONE_ID_COL, BUFFER_METERS, band_name)
    band_tables.append(out.set_index(ZONE_ID_COL))

if band_tables:
    # Align all bands on the zone id (outer join), then attach them to the points.
    stats = pd.concat(band_tables, axis=1)
    # Re-running the cell overwrites earlier results instead of creating
    # "_x"/"_y" suffixed duplicate columns.
    gdf = gdf.drop(columns=gdf.columns.intersection(stats.columns))
    gdf = gdf.merge(stats.reset_index(), on=ZONE_ID_COL, how="left")

# Reference timing of the full run (per-band merge version):
# ~10k points 627 bands 90 km tile
# Wall time: 21min 2s

CPU times: user 2.33 s, sys: 97.5 ms, total: 2.43 s
Wall time: 4.11 s


In [37]:
# Check that one "<band_name>_mean" column was added per processed raster
gdf.columns


Index(['time', 'beam', 'elevation', 'shot_number', 'flags', 'sensitivity',
       'orbit', 'solar_elevation', 'track', 'agbd', 'year', 'month',
       'vegmask_2022', 'geometry', 'zone',
       'HLS.2020-01-01.2020-01-31.Blue_mean',
       'HLS.2020-02-01.2020-02-29.Blue_mean'],
      dtype='object')

In [38]:
# Preview the first rows; NaN in a stats column means the left merge found no
# value for that point in that raster
gdf.head()

Unnamed: 0,time,beam,elevation,shot_number,flags,sensitivity,orbit,solar_elevation,track,agbd,year,month,vegmask_2022,geometry,zone,HLS.2020-01-01.2020-01-31.Blue_mean,HLS.2020-02-01.2020-02-29.Blue_mean
0,2022-01-04 23:15:58.116,6,869.504272,173550600100128255,134,0.986958,17355,-13.757421,8749,40.626755,2022,1,3,POINT (-49.37641 -24.1899),1,,
1,2022-01-04 23:15:58.124,6,868.808105,173550600100128256,134,0.984109,17355,-13.757937,8749,33.338924,2022,1,3,POINT (-49.37604 -24.1895),2,,
2,2022-01-04 23:15:58.132,6,873.284424,173550600100128257,134,0.981359,17355,-13.758456,8749,23.36887,2022,1,3,POINT (-49.37568 -24.18911),3,,
3,2022-01-04 23:15:58.178,11,993.211304,173551100100090665,134,0.986288,17355,-13.746551,8749,8.862288,2022,1,3,POINT (-49.37735 -24.20893),4,,
4,2022-01-04 23:15:58.256,6,887.880554,173550600100128272,134,0.985115,17355,-13.766229,8749,3.462804,2022,1,3,POINT (-49.37025 -24.18319),5,,


In [26]:
# gdf.columns

In [30]:
# Keep only points that received at least one zonal-statistic value. The stat
# columns are found via the "_mean" suffix added by the loop, instead of a
# hard-coded band/date column that may not exist in the current run.
# Use how='any' instead to require a complete time series per point.
stat_cols = [c for c in gdf.columns if c.endswith('_mean')]
if not stat_cols:
    raise KeyError("gdf has no '_mean' columns; run the zonal statistics loop first")
gdf_stats_nona = gdf.dropna(subset=stat_cols, how='all')

In [31]:
# Number of points kept after dropping rows without statistics
len(gdf_stats_nona)

0

In [105]:
# Export the points with their zonal statistics to a GeoPackage. Refuse to
# write an empty layer (e.g. when no point received a raster value) so the
# problem surfaces here rather than as an empty output file.
if gdf_stats_nona.empty:
    raise ValueError(
        f"No points with zonal statistics to export to {output_gpkg_zonalstats_fn}"
    )
gdf_stats_nona.to_file(output_gpkg_zonalstats_fn, driver="GPKG")


In [1]:
# Show the export path. The NameError below presumably comes from a fresh
# kernel (execution count 1) in which the config cells had not been run.
output_gpkg_zonalstats_fn

NameError: name 'output_gpkg_zonalstats_fn' is not defined

# Test

In [81]:
####################################################
# 1) Open raster as xarray (single band)
####################################################
# Debug replica of steps 1-5 of zonal_stats_raster. It runs on the last
# `img_path` left over from the loop and on the notebook-level `gdf`
# (no file re-read).

da = rioxarray.open_rasterio(img_path).squeeze()  # [y, x]
if 'band' in da.dims:
    da = da.squeeze('band', drop=True)
    
raster_crs = CRS.from_user_input(da.rio.crs)
transform  = da.rio.transform()
out_shape  = da.shape  # (rows, cols)
nodata     = da.rio.nodata

# Replace the raster's nodata value with NaN
da = da.where(da != da.rio.nodata)



####################################################
# 2) Read points
####################################################
# (Not re-read here: reuses the in-memory `gdf`.)

# Create ID for the zonal stats
if ZONE_ID_COL not in gdf.columns:
    gdf = gdf.reset_index(drop=True)
    gdf[ZONE_ID_COL] = np.arange(1, len(gdf) + 1, dtype=np.int32)



####################################################
# 3) Reproject to raster CRS and buffer to polygons 
####################################################

gdf_proj = gdf.to_crs(raster_crs)

# Buffer (distance in raster-CRS units)
if BUFFER_METERS > 0:
    gdf_proj["geometry"] = gdf_proj.geometry.buffer(BUFFER_METERS)


####################
# 4) Build (geometry, id) tuples for rasterization
####################
geom = list(zip(gdf_proj.geometry, gdf_proj[ZONE_ID_COL]))



####################################################
# 5) Rasterize zones (same shape/transform as raster)
####################################################
# NOTE(review): the values raster's nodata is reused as the zone background
# (fill). If it is None/NaN it is likely invalid for int32 output, and a zone
# whose id equals it would be lost. `nodata`/`masked` are not among
# rasterio.features.rasterize's documented parameters in the versions I
# know — confirm they are accepted.
zones_arr = features.rasterize(
    geom,
    out_shape=out_shape,
    transform=transform,
    fill= nodata,
    nodata = nodata,
    masked = True,
    # fill=0,# background (no zone)
    dtype="int32",
    # all_touched=True
)






In [82]:
####################################################
# # 6) Wrap zones into an xarray aligned with the raster
####################################################

# Shallow copy shares dims/coords/attrs with `da`; its data is then replaced
# by the zone array (the original `da` keeps its values).
zones_xarr = da.copy(deep=False)
zones_xarr.data = zones_arr
# zones_xarr.attrs["nodata"] = nodata
# zones_xarr


####################################################
# 7) Compute zonal stats using xrspatial (min/max/mean/etc.)
####################################################

# nodata_values refers to the *zones* raster: cells with that label belong to
# no zone. It must match the `fill` used when rasterizing. Nodata in the values
# raster was already turned into NaN by the `where` in step 1.
zs_df = zonal_stats(
    zones=zones_xarr,      # integer-labeled zone raster
    values=da,             # the raster with values to summarize
    # stats_funcs=['mean', 'max', 'min', 'sum', 'std', 'var', 'count'],
    stats_funcs=['mean'],
    nodata_values=nodata,   # zone-raster background label (= rasterize fill)
    return_type='pandas.DataFrame'
)


In [83]:
# Inspect the per-zone means; xrspatial returns the id column as 'zone' (float)
zs_df.head()

Unnamed: 0,zone,mean
0,4.0,0.0103
1,48.0,0.0158
2,49.0,0.0146
3,50.0,0.0153
4,51.0,0.0145


In [84]:
# NOTE: zs_df is not renamed at this point, so this adds a plain 'mean'
# column to gdf (as the gdf.columns output below shows).
gdf = gdf.merge(zs_df, on=ZONE_ID_COL, how="left")

In [85]:
# Columns after the debug merge
gdf.columns

Index(['time', 'beam', 'elevation', 'shot_number', 'flags', 'sensitivity',
       'orbit', 'solar_elevation', 'track', 'agbd', 'year', 'month',
       'vegmask_2022', 'geometry', 'zone', 'mean'],
      dtype='object')

In [76]:




####################################################
# # Keep only "mean" and rename the band
####################################################
# zs_df = pd.DataFrame({"mean": pd.Series(zs["mean"])}).rename_axis(ZONE_ID_COL).reset_index()

# Rename column -> "<band_name>_mean", e.g. "HLS.2020-01-01.2020-01-31.Blue_mean"
zs_df = zs_df.rename(columns={"mean": f"{band_name}_mean"})


####################################################
# 9) create output with zone id and values
####################################################
# One row per unique zone id of gdf with its stat (NaN when the zone has no
# row in zs_df). Merging on ZONE_ID_COL works only because it equals the
# 'zone' column name produced by xrspatial.
out = (gdf[[ZONE_ID_COL]]
       .drop_duplicates()
       .merge(zs_df, on=ZONE_ID_COL, how="left"))



In [75]:
# Columns of gdf at this point of the debug session
gdf.columns

Index(['time', 'beam', 'elevation', 'shot_number', 'flags', 'sensitivity',
       'orbit', 'solar_elevation', 'track', 'agbd', 'year', 'month',
       'vegmask_2020', 'geometry', 'zone'],
      dtype='object')

In [70]:
# Columns of gdf at this point of the debug session
gdf.columns

Index(['time', 'beam', 'elevation', 'shot_number', 'flags', 'sensitivity',
       'orbit', 'solar_elevation', 'track', 'agbd', 'year', 'month',
       'vegmask_2020', 'geometry', 'zone'],
      dtype='object')