# Create SAR-derived runoff onset / snow pillow-derived runoff onset comparison dataset

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import dask.dataframe as dd
import seaborn as sns
import xarray as xr
import geopandas as gpd
import coiled
import rasterio
import xyzservices.providers as xyz
import matplotlib.cm as cm
import xyzservices as xyz
import easysnowdata
from global_snowmelt_runoff_onset.config import Config
import global_snowmelt_runoff_onset.processing as processing
import seaborn as sns
import matplotlib.pyplot as plt
import contextily as ctx


In [None]:
# Load the v6 processing configuration; later cells read config.valid_tiles_gdf and config.global_runoff_store from it.
config = Config('../config/global_config_v6.txt')

In [None]:
# Automatic weather station collection from easysnowdata. get_entire_data_archive()
# presumably populates StationsWUS.entire_data_archive (read below) — may be slow to fetch.
StationsWUS = easysnowdata.automatic_weather_stations.StationCollection()
StationsWUS.get_entire_data_archive()

In [None]:
# Processing tiles that contain at least one station, keeping a single row per
# tile (row, col) even when several stations fall inside it.
tiles_with_stations_gdf = gpd.sjoin(
    config.valid_tiles_gdf,
    StationsWUS.all_stations,
    how='inner',
    predicate='contains',
).drop_duplicates(subset=['row', 'col'])
tiles_with_stations_gdf

In [None]:
def _weighted_mean(values, weights):
    """Return the weighted mean of ``values``, or NaN when empty or the weights sum to zero.

    ``np.average`` raises ZeroDivisionError when the weights sum to zero; that
    case is mapped to NaN so a single all-zero year cannot abort the summary.
    """
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    if values.size == 0 or weights.sum() == 0:
        return np.nan
    return np.average(values, weights=weights)


def calculate_temporal_resolutions(gdf, years=range(2015, 2025)):
    """Compute pixel-count-weighted mean temporal resolution per year and overall.

    For each year the frame may carry a ``tr_{year}`` column (temporal
    resolution) and a ``pix_ct_{year}`` column (pixel count, used as weight).
    Years lacking either column are skipped, as are rows with NaN in either
    column of a pair.

    Parameters
    ----------
    gdf : pandas.DataFrame or geopandas.GeoDataFrame
        Table with the per-year ``tr_*`` / ``pix_ct_*`` columns.
    years : iterable of int, default ``range(2015, 2025)``
        Years to consider.

    Returns
    -------
    yearly_means : dict
        ``{year: weighted mean}`` for years with at least one valid row and a
        non-zero total weight, in the order of ``years``.
    overall_mean : float
        Weighted mean over every valid (row, year) pair; NaN if there is none.
    """
    # Only years with both columns can contribute, to the yearly AND the overall mean.
    present_years = [
        year for year in years
        if f'tr_{year}' in gdf.columns and f'pix_ct_{year}' in gdf.columns
    ]

    yearly_means = {}
    for year in present_years:
        tr_col = f'tr_{year}'
        pix_col = f'pix_ct_{year}'
        year_data = gdf[[tr_col, pix_col]].dropna()
        weighted_mean = _weighted_mean(year_data[tr_col], year_data[pix_col])
        if not np.isnan(weighted_mean):
            yearly_means[year] = weighted_mean

    tr_cols = [f'tr_{year}' for year in present_years]
    pix_cols = [f'pix_ct_{year}' for year in present_years]
    # Both blocks share the same (row, year) layout, so a row-major flatten keeps
    # every tr value paired with its own pixel count.
    all_data = pd.DataFrame({
        'tr': gdf[tr_cols].to_numpy().ravel(),
        'pix': gdf[pix_cols].to_numpy().ravel(),
    }).dropna()

    overall_mean = _weighted_mean(all_data['tr'], all_data['pix'])
    return yearly_means, overall_mean

# Calculate results: pixel-count-weighted temporal resolution (tr_* columns) of the
# tiles that contain at least one station, per year and over all years.
yearly_means, overall_mean = calculate_temporal_resolutions(tiles_with_stations_gdf)

# Print results
for year, mean in yearly_means.items():
    print(f"Weighted average temporal resolution for {year}: {mean:.2f}")
print(f"\nOverall weighted average temporal resolution: {overall_mean:.2f}")

In [None]:
# Station SWE ('WTEQ' variable of the archive) restricted to 2014-10-01 .. 2024-09-30,
# i.e. ten full water years.
stations_WUS_da = StationsWUS.entire_data_archive['WTEQ'].sel(time=slice('2014-10-01','2024-09-30'))
stations_WUS_da

In [None]:
# f,ax=plt.subplots(figsize=(20,80))
# stations_WUS_WY2023_da = stations_WUS_da.where(stations_WUS_da.WY==2023,drop=True)
# stations_WUS_WY2023_normalized_da = stations_WUS_WY2023_da/stations_WUS_WY2023_da.max(dim='time')


# stations_WUS_WY2023_normalized_da.plot(ax=ax,vmin=0.5,vmax=1,cmap='gist_rainbow')

In [None]:
# f,ax=plt.subplots(figsize=(40,200))

# stations_WUS_da.sel(time=slice('2021-10-01','2022-09-30')).plot(ax=ax,vmax=1.4)

In [None]:
# Basic QC of the station SWE series: drop negative values, drop samples next to
# large day-to-day jumps, and drop samples in sparsely observed periods.
stations_WUS_da = stations_WUS_da.where(stations_WUS_da>=0)

# Jump filter: keep a sample only if the neighbouring absolute differences are < 0.2
# (units as in the archive, presumably metres of SWE — confirm). NaN differences
# compare False, so samples adjacent to gaps are dropped too.
# NOTE(review): diff() labels each difference with the later timestamp, so
# shift(time=1) / shift(time=-1) test |x[t-1]-x[t-2]| and |x[t+1]-x[t]|; |x[t]-x[t-1]|
# itself is not tested, and one spike masks t-1..t+1. where() with a cond that lacks
# the first timestamp also drops that timestamp (inner join) — confirm intended.
abs_diffs = np.abs(stations_WUS_da.diff(dim='time'))
abs_diffs_forward = abs_diffs.shift(time=1)
abs_diffs_backward = abs_diffs.shift(time=-1)
jump_mask = (abs_diffs_forward < 0.2) & (abs_diffs_backward < 0.2)
stations_WUS_da = stations_WUS_da.where(jump_mask)

# Density filter: require at least `window` valid samples in a centred 2*window-long
# window. With the default min_periods (= full window), the rolling sum is NaN near the
# record edges, so those samples are dropped as well.
window=10
valid_mask = ~np.isnan(stations_WUS_da)
rolling_valid = valid_mask.rolling(time=window*2, center=True).sum()
stations_WUS_da = stations_WUS_da.where(rolling_valid >= window)  


def check_missing_data(group):
    """Blank out stations whose SWE record for one water year is unreliable.

    Meant for ``DataArray.groupby('WY').apply``: ``group`` holds one water
    year and has a ``time`` dimension (presumably plus ``station``). A
    station's entire record for the year is set to NaN if any of these hold:

    * more than 30 missing samples in November-March (April is NOT included),
    * a gap of more than 10 days between consecutive valid samples, or no
      valid sample at all,
    * the first or the last valid sample of the year exceeds 0.1 (the record
      should start and end near zero SWE; units as in the input — presumably
      metres, confirm).

    Returns a new DataArray; ``group`` itself is not modified. All selections
    are by dimension name, so the result does not depend on axis order.
    """
    # Nov-Mar window: a station with many missing winter samples is rejected.
    nov_to_mar_mask = group.time.dt.month.isin([11, 12, 1, 2, 3])
    winter_group = group.where(nov_to_mar_mask, drop=True)
    missing_data = winter_group.isnull().sum(dim='time') > 30

    valid_data = ~np.isnan(group)
    time_values = group.time.values

    def calc_location_gaps(location_data):
        """Return True if a 1-D series has no valid sample or a gap longer than 10 days."""
        valid_indices = ~np.isnan(location_data)
        if not np.any(valid_indices):
            return True
        gaps = np.diff(time_values[valid_indices])
        return np.any(gaps / np.timedelta64(1, 'D') > 10)

    large_gaps = xr.apply_ufunc(
        calc_location_gaps,
        group,
        input_core_dims=[['time']],
        vectorize=True,
        output_dtypes=[bool],
    )

    # Check for proper seasonal evolution: the first and last valid values should be small.
    if valid_data.any():
        n_times = group.sizes['time']
        first_valid_idx = valid_data.argmax(dim='time')
        # Reverse along 'time' by name (not by axis position) to locate the last valid sample.
        last_valid_idx = valid_data.isel(time=slice(None, None, -1)).argmax(dim='time')
        first_valid = group.isel(time=first_valid_idx)
        last_valid = group.isel(time=n_times - last_valid_idx - 1)
        improper_evolution = (first_valid > 0.1) | (last_valid > 0.1)
    else:
        improper_evolution = True

    columns_to_nan = missing_data | large_gaps | improper_evolution
    # where() broadcasts the per-station mask over time by dimension name.
    return group.where(~columns_to_nan)

def check_seasonal_snow_swe(group, swe_threshold=0.05, window=60, min_days=55):
    """Keep only stations with a sustained snowpack during this water year.

    A station is kept when some centred ``window``-sample rolling window has
    at least ``min_days`` samples with SWE >= ``swe_threshold``. Every other
    station is set to NaN for the whole group. The defaults (0.05, 60, 55)
    reproduce the previous hard-coded values. The output file written later
    is labelled "5cm SWE for 60 days", so values are presumably metres —
    confirm.

    Meant for ``DataArray.groupby('WY').apply``. Returns a new DataArray;
    ``group`` is not modified.
    """
    # NaN SWE compares False, so missing samples count as "not enough snow".
    sufficient_swe = (group >= swe_threshold).rolling(
        time=window, center=True, min_periods=min_days
    ).sum()
    # A station qualifies if any window meets the criterion.
    columns_to_keep = (sufficient_swe >= min_days).any(dim='time')
    return group.where(columns_to_keep)

def find_pct_max_timing(da, pct, dim='time', skipna=True):
    """Return the DOWY of the last sample where SWE is at or above ``pct`` x its maximum.

    ``da`` is expected to carry a ``DOWY`` coordinate along ``dim`` and a
    ``WY`` coordinate, which is dropped from the result. It is intended for
    ``groupby('WY').map``. Stations never reaching the threshold (for
    example all-NaN records) come back NaN. Only positive DOWY values are
    kept.

    Parameters
    ----------
    da : xarray.DataArray
        SWE for one water year.
    pct : float
        Fraction of the maximum, e.g. 0.95.
    dim : str, default 'time'
        Time dimension to reduce over. It is used consistently for the
        maximum, the reversal and the DOWY swap.
    skipna : bool, default True
        Passed to ``max``.
    """
    threshold = da.max(dim=dim, skipna=skipna) * pct
    above_thresh = xr.where(da >= threshold, 1, np.nan)
    # idxmax returns the FIRST maximum, so reverse along `dim` to get the LAST day above the threshold.
    reversed_above = above_thresh.isel({dim: slice(None, None, -1)})
    last_above_dowy = reversed_above.swap_dims({dim: 'DOWY'}).idxmax(dim='DOWY', skipna=True)
    return last_above_dowy.drop_vars('WY').where(lambda x: x > 0)


# Per-water-year QC: drop unreliable station-years, then station-years without a sustained snowpack.
stations_WUS_da = stations_WUS_da.groupby('WY').apply(check_missing_data)
stations_WUS_da = stations_WUS_da.groupby('WY').apply(check_seasonal_snow_swe)


# DOWY of the annual SWE maximum per station and WY. Time is reversed before idxmax so the
# LAST occurrence of the maximum is returned; .where(x > 0) keeps only positive DOWY values.
#stations_WUS_max_SWE_timing_da = stations_WUS_da.fillna(-9999).groupby("WY").map(lambda x: x.sel(time=slice(None, None, -1)).idxmax("time",skipna=True).DOWY.drop_vars('WY')).where(lambda x: x>0) #reversed to get last max value instead of first
stations_WUS_max_SWE_timing_da = stations_WUS_da.groupby("WY").map(lambda x: x.sel(time=slice(None, None, -1)).swap_dims({'time':'DOWY'}).idxmax("DOWY",skipna=True).drop_vars('WY')).where(lambda x: x>0)
runoff_onset_max_swe_timing_WUS_ds = stations_WUS_max_SWE_timing_da.to_dataset(name='station_max_SWE_timing')

# Annual maximum SWE per station and WY.
runoff_onset_max_swe_timing_WUS_ds["station_max_SWE_value"] = stations_WUS_da.groupby("WY").max()

# Last DOWY with SWE at or above 99/95/90/50 % of the annual maximum.
# The lambda is consumed by .map within the same iteration, so it sees the current pct.
# NOTE(review): int(pct * 100) truncates; fine for these values but round() would be safer for others.
pct_list = [0.99, 0.95, 0.9, 0.5]
for pct in pct_list:
    pct_str = str(int(pct * 100))
    runoff_onset_max_swe_timing_WUS_ds[f'station_max_SWE_{pct_str}pct_timing'] = stations_WUS_da.groupby("WY").map(lambda x: find_pct_max_timing(x, pct)).where(lambda x: x>0)

runoff_onset_max_swe_timing_WUS_ds

In [None]:
# Number of non-NaN stations per water year for each variable. The buffer_radius / fcf
# dimensions only exist once the SAR comparison arrays are assigned further down, so the
# 1000 m / FCF<=100 slice is selected only when present; this lets the notebook run top-to-bottom.
station_counts = runoff_onset_max_swe_timing_WUS_ds.count(dim='station')
count_selection = {
    dim_name: value
    for dim_name, value in {'buffer_radius': 1000, 'fcf': 100}.items()
    if dim_name in station_counts.dims
}
station_counts.sel(count_selection)

In [None]:
# Quick look at the 95%-of-max timing (station x WY). The very tall figure presumably
# leaves room for one row per station.
f,ax=plt.subplots(figsize=(10,200))
#stations_WUS_max_SWE_timing_da.dropna(dim='station',how='any').plot(ax=ax)
#stations_WUS_max_SWE_timing_da.plot(ax=ax)
runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_95pct_timing'].plot(ax=ax)

In [None]:
# Open the global runoff-onset Zarr store for inspection. The per-station tasks below reopen
# the store themselves (get_station_buffered_runoff_onset_delayed), so no active cell uses this handle.
runoff_onset_global_ds = xr.open_zarr(config.global_runoff_store, consolidated=True,decode_coords='all')
runoff_onset_global_ds

In [None]:
# stations_gdf = StationsWUS.all_stations
# stations_gdf

In [None]:
# Station metadata with updated SNOTEL locations, indexed by station code (looked up by get_station_gdf).
# NOTE: machine-specific path.
stations_gdf = gpd.read_file('~/repos/updated_snotel_locations/snotel_stations_with_updated_locations.geojson').set_index('code')
stations_gdf

In [None]:
def get_station_gdf(stations_gdf, station_code, buffer_radius=None):
    """Return the rows of ``stations_gdf`` indexed by ``station_code`` in a local UTM CRS.

    The CRS is the one ``estimate_utm_crs`` picks for the selected rows. When
    ``buffer_radius`` is truthy, the geometry is replaced by a buffer of that
    distance in the UTM CRS units (metres). When it is ``None`` or ``0``, the
    geometry is left untouched.
    """
    selected = stations_gdf.loc[stations_gdf.index == station_code]
    utm_epsg = selected.estimate_utm_crs().to_epsg()
    projected = selected.to_crs(epsg=utm_epsg)
    if not buffer_radius:
        return projected
    projected['geometry'] = projected.geometry.buffer(buffer_radius)
    return projected



def get_station_buffered_runoff_onset(runoff_onset_WUS_ds, station_gdf):
    """Subset a runoff-onset dataset to a station footprint and attach ancillary layers.

    Steps:

    1. Cut ``runoff_onset_WUS_ds`` to the bounding box of ``station_gdf``.
    2. Reproject it to the station CRS.
    3. Fetch forest cover fraction (``fcf``), a Copernicus DEM
       (``resolution=30``) and ESA WorldCover (``worldcover``) for that area.
    4. Match each layer to the same grid.
    5. Clip the whole dataset to the station geometry.

    Parameters
    ----------
    runoff_onset_WUS_ds : xarray.Dataset
        Runoff-onset dataset with a CRS readable by rioxarray.
    station_gdf : geopandas.GeoDataFrame
        Station footprint, e.g. from ``get_station_gdf``.

    Returns
    -------
    xarray.Dataset
        The clipped dataset including ``fcf``, ``dem`` and ``worldcover``.
    """
    resampling = rasterio.enums.Resampling
    # (variable name, fetcher taking a WGS84 bbox, resampling used to match the runoff-onset grid)
    ancillary_layers = (
        ('fcf',
         lambda bbox: easysnowdata.remote_sensing.get_forest_cover_fraction(bbox, mask_nodata=True),
         resampling.bilinear),
        ('dem',
         lambda bbox: easysnowdata.topography.get_copernicus_dem(bbox, resolution=30),
         resampling.bilinear),
        ('worldcover',
         lambda bbox: easysnowdata.remote_sensing.get_esa_worldcover(bbox, mask_nodata=True),
         resampling.nearest),
    )

    station_ds = (
        runoff_onset_WUS_ds
        .rio.clip_box(*station_gdf.total_bounds, crs=station_gdf.crs)
        .rio.reproject(station_gdf.crs)
    )
    for var_name, fetch, method in ancillary_layers:
        bbox = station_ds.rio.transform_bounds('EPSG:4326')
        station_ds[var_name] = fetch(bbox).rio.reproject_match(station_ds, resampling=method)

    # Alternative: rio.clip(..., all_touched=True) also keeps pixels that only touch the footprint.
    return station_ds.rio.clip(station_gdf.geometry)


def get_station_buffered_runoff_onset_delayed(station_gdf):
    """Build the station runoff-onset subset after opening the global store itself.

    Opens ``config.global_runoff_store`` inside the call, so it can run from
    ``process_station_buffer`` tasks submitted to the cluster. The steps are:

    1. Cut the store to the bounding box of ``station_gdf``.
    2. Reproject it to the station CRS.
    3. Attach ``fcf`` (bilinear), ``dem`` (bilinear, ``resolution=30``) and
       ``worldcover``. WorldCover uses *mode* resampling here, unlike the
       nearest-neighbour used in ``get_station_buffered_runoff_onset``.
    4. Clip to the station geometry.

    Returns
    -------
    xarray.Dataset
        The clipped dataset including ``fcf``, ``dem`` and ``worldcover``.
    """
    runoff_onset_WUS_ds = xr.open_zarr(config.global_runoff_store, consolidated=True, decode_coords='all')

    resampling = rasterio.enums.Resampling
    # (variable name, fetcher taking a WGS84 bbox, resampling used to match the runoff-onset grid)
    ancillary_layers = (
        ('fcf',
         lambda bbox: easysnowdata.remote_sensing.get_forest_cover_fraction(bbox, mask_nodata=True),
         resampling.bilinear),
        ('dem',
         lambda bbox: easysnowdata.topography.get_copernicus_dem(bbox, resolution=30),
         resampling.bilinear),
        ('worldcover',
         lambda bbox: easysnowdata.remote_sensing.get_esa_worldcover(bbox, mask_nodata=True),
         resampling.mode),
    )

    station_ds = (
        runoff_onset_WUS_ds
        .rio.clip_box(*station_gdf.total_bounds, crs=station_gdf.crs)
        .rio.reproject(station_gdf.crs)
    )
    for var_name, fetch, method in ancillary_layers:
        bbox = station_ds.rio.transform_bounds('EPSG:4326')
        station_ds[var_name] = fetch(bbox).rio.reproject_match(station_ds, resampling=method)

    # Alternative: rio.clip(..., all_touched=True) also keeps pixels that only touch the footprint.
    return station_ds.rio.clip(station_gdf.geometry)

### Create comparison dataset

In [None]:
# Coiled cluster that runs process_station_buffer (one task per station): 60 spot workers
# with 4 CPUs / 8 GB each. The GDAL_DISABLE_READDIR_ON_OPEN setting presumably avoids
# directory listings when remote rasters are opened (GDAL config option).
cluster = coiled.Cluster(idle_timeout="10 minutes",
                         #shutdown_on_close=False,
                         #wait_for_workers=True,
                         #n_workers=[41,170], # 170
                         #n_workers=[31,86],
                         n_workers=60,
                         #n_workers=8,
                         #n_workers=10,
                         worker_memory="8 GB", #coiled.list_instance_types(backend="azure")
                         worker_cpu=4,
                         #worker_options={"nthreads": 1},
                         #worker_options={"nthreads": 32},# 16 8 4 oversubscribe?
                         #scheduler_memory="128 GB",
                         scheduler_memory="16 GB",
                         spot_policy="spot", # spot usually
                         #software="sar_snowmelt_timing",
                         environ={"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR"},
                         #container="mcr.microsoft.com/planetary-computer/python:latest",
                         workspace="uwtacolab",
                         
                         )

client = cluster.get_client()

#use the following config for the problem tiles, otherwise 4 and 32, 8 and 32
                        #  worker_memory="64 GB", 
                        #  worker_cpu=8,

In [None]:
# Restart the workers (clears their state and memory) before submitting station tasks.
client.restart()

In [None]:
# Parameter grids for the SAR/station comparison.
fcf_thresh_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
buffer_radii = [100, 200, 300, 500, 1000]


def _nan_result_array(dims, name):
    """Return an all-NaN DataArray over ``dims`` (a subset of station/WY/buffer_radius/fcf).

    Station and WY coordinates come from ``runoff_onset_max_swe_timing_WUS_ds``;
    buffer_radius and fcf come from the parameter grids above. All are read at
    call time.
    """
    available_coords = {
        "station": runoff_onset_max_swe_timing_WUS_ds.station,
        "WY": runoff_onset_max_swe_timing_WUS_ds.WY,
        "buffer_radius": buffer_radii,
        "fcf": fcf_thresh_values,
    }
    return xr.DataArray(
        data=np.nan,
        dims=dims,
        coords={dim_name: available_coords[dim_name] for dim_name in dims},
        name=name,
    )


# Initialize output arrays (filled from the per-station results further down).
runoff_onset_timing_WUS_ds = _nan_result_array(["station", "WY", "buffer_radius", "fcf"], "runoff_onset_timing")
# Holds a fraction of pixels (not a count), so name it accordingly.
frac_valid_fcf_WUS_ds = _nan_result_array(["station", "buffer_radius", "fcf"], "frac_valid_fcf")
mean_fcf_WUS_ds = _nan_result_array(["station", "buffer_radius", "fcf"], "mean_fcf")

In [None]:
def process_station_buffer(station_code):
    """Summarize SAR runoff onset around one station for every buffer radius and FCF threshold.

    The ancillary-augmented runoff-onset dataset is built and computed once,
    for the largest radius in ``buffer_radii``. It is then re-clipped for each
    radius. For every radius it also applies each threshold in
    ``fcf_thresh_values``. Pixels are used only if their forest cover fraction
    is ``<= fcf_thresh`` and their WorldCover class is neither 80 nor 50
    (water / built-up per the ESA WorldCover legend — confirm).

    Returns
    -------
    list of dict
        One dict per (buffer_radius, fcf) pair with keys:

        * ``station``, ``buffer_radius``, ``fcf``.
        * ``runoff_onset``: spatial median over x/y, with ``water_year``
          renamed to ``WY``.
        * ``frac_valid``: share of pixels with an FCF value that pass both
          masks.
        * ``mean_fcf``: mean FCF of the pixels passing the FCF threshold.
    """
    results = []
    station_gdf_max = get_station_gdf(stations_gdf, station_code, max(buffer_radii))
    runoff_onset_station_ds = get_station_buffered_runoff_onset_delayed(station_gdf_max).compute()

    for buffer_radius in buffer_radii:
        station_gdf = get_station_gdf(stations_gdf, station_code, buffer_radius)
        clipped_ds = runoff_onset_station_ds.rio.clip(station_gdf.geometry)
        fcf_da = clipped_ds['fcf']
        worldcover_da = clipped_ds['worldcover']

        # Independent of the FCF threshold, so computed once per radius.
        # To also mask tree cover, add `& (worldcover_da != 10)`.
        landcover_ok = (worldcover_da != 80) & (worldcover_da != 50)
        n_fcf_pixels = fcf_da.count(dim=['x', 'y'])

        for fcf_thresh in fcf_thresh_values:
            fcf_ok = fcf_da <= fcf_thresh
            masked_onset = clipped_ds['runoff_onset'].where(fcf_ok).where(landcover_ok)
            results.append({
                'station': station_code,
                'buffer_radius': buffer_radius,
                'fcf': fcf_thresh,
                'runoff_onset': masked_onset.median(dim=['x', 'y']).rename({'water_year': 'WY'}),
                'frac_valid': fcf_da.where(fcf_ok).where(landcover_ok).count(dim=['x', 'y']) / n_fcf_pixels,
                # NOTE(review): unlike the two entries above, mean_fcf ignores the land-cover mask — confirm intended.
                'mean_fcf': fcf_da.where(fcf_ok).mean(dim=['x', 'y']),
            })
    return results


# One cluster task per station. retries=100 is a client.submit option (not passed to the function).
# NOTE(review): it also retries deterministic failures (e.g. possibly a station code absent from
# stations_gdf) up to 100 times — consider a smaller value.
futures = [client.submit(process_station_buffer, station_code, retries=100) for station_code in runoff_onset_max_swe_timing_WUS_ds.station.values]
futures

In [None]:
# Progress check: count the futures in each status (e.g. pending / finished / error).
stati = [future.status for future in futures]
np.unique(stati,return_counts=True)

In [None]:
# Block until every station task finishes; with default settings gather raises if any task failed.
results = client.gather(futures)
results

In [None]:
# Fill arrays with results
for result_group in results:
    if result_group is not None:
        for result in result_group:
            station = result['station']
            buffer_radius = result['buffer_radius']
            fcf = result['fcf']
            runoff_onset_timing_WUS_ds.loc[dict(station=station, buffer_radius=buffer_radius, fcf=fcf)] = result['runoff_onset']
            frac_valid_fcf_WUS_ds.loc[dict(station=station, buffer_radius=buffer_radius, fcf=fcf)] = result['frac_valid']
            mean_fcf_WUS_ds.loc[dict(station=station, buffer_radius=buffer_radius, fcf=fcf)] = result['mean_fcf']

# Final assignments
runoff_onset_max_swe_timing_WUS_ds = runoff_onset_max_swe_timing_WUS_ds.assign(
    runoff_onset_timing=runoff_onset_timing_WUS_ds,
    frac_valid_fcf=frac_valid_fcf_WUS_ds,
    mean_fcf=mean_fcf_WUS_ds
)

runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'] = (
    runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'] - 
    runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing']
)

for pct in pct_list:
    pct_str = str(int(pct * 100))
    runoff_onset_max_swe_timing_WUS_ds[f'sar_minus_{pct_str}pct'] = (
        runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'] - runoff_onset_max_swe_timing_WUS_ds[f'station_max_SWE_{pct_str}pct_timing'])
    

runoff_onset_max_swe_timing_WUS_ds

In [None]:
#runoff_onset_max_swe_timing_WUS_ds.drop_vars('geometry').to_netcdf(f'station_comparison/snotel_sar_differences_UPDATED_SNOTEL_LOCATIONS_v5.nc')
#runoff_onset_max_swe_timing_WUS_ds.drop_vars('geometry').to_netcdf(f'station_comparison/snotel_sar_differences_treemask_UPDATED_SNOTEL_LOCATIONS_andreq20cmSWEfor60days_v5.nc')
# Write the v6 comparison dataset. The 'geometry' variable is dropped first
# (presumably because its geometry objects cannot be serialized to netCDF — confirm).
from pathlib import Path

comparison_nc_path = Path('comparison_datasets/snotel_sar_differences_UPDATED_SNOTEL_LOCATIONS_andreq5cmSWEfor60days_v6.nc')
# to_netcdf does not create missing folders, so ensure the output directory exists.
comparison_nc_path.parent.mkdir(parents=True, exist_ok=True)
runoff_onset_max_swe_timing_WUS_ds.drop_vars('geometry').to_netcdf(comparison_nc_path)

## Code graveyard

In [None]:
#fcf_thresh_values = [10,20,30,40,50,60,70,80,90,100]
# buffer_radii = [50,100,300,500,1000]

# def process_single_station(station_code):
#     fcf_thresh_values = [10,20,30,40,50,60,70,80,90,100]
#     buffer_radii = [50,100,300,500,1000]
#     """Process all buffer/fcf combinations for a single station"""
#     station_results = {
#         'runoff_onset': {},
#         'frac_valid': {},
#         'mean_fcf': {}
#     }
    
#     runoff_onset_global_ds = xr.open_zarr(config.global_runoff_store, consolidated=True,decode_coords='all')
    
#     try:
#         for buffer_radius in buffer_radii:
#             station_gdf = get_station_gdf(station_code, buffer_radius)
#             runoff_onset_station_ds = get_station_buffered_runoff_onset(
#                 runoff_onset_global_ds, 
#                 station_gdf
#             )
            
#             for fcf_thresh in fcf_thresh_values:
#                 key = (station_code, buffer_radius, fcf_thresh)
                
#                 # Calculate metrics for this combination
#                 runoff_onset = runoff_onset_station_ds['runoff_onset'].where(
#                     runoff_onset_station_ds['fcf'] <= fcf_thresh
#                 ).median(dim=['x', 'y']).rename({'water_year':'WY'})
                
#                 frac_valid = (runoff_onset_station_ds['fcf'].where(
#                     runoff_onset_station_ds['fcf'] <= fcf_thresh
#                 ).count(dim=['x', 'y'])/runoff_onset_station_ds['fcf'].count(dim=['x', 'y']))
                
#                 mean_fcf = runoff_onset_station_ds['fcf'].where(
#                     runoff_onset_station_ds['fcf'] <= fcf_thresh
#                 ).mean(dim=['x', 'y'])
                
#                 station_results['runoff_onset'][key] = runoff_onset
#                 station_results['frac_valid'][key] = frac_valid
#                 station_results['mean_fcf'][key] = mean_fcf
                
#     except Exception as e:
#         print(f"Error processing station {station_code}: {e}")
        
#     return station_results


# station_codes = list(runoff_onset_max_swe_timing_WUS_ds.station.values)

# # futures = []
# # for station_code in runoff_onset_max_swe_timing_WUS_ds.station.values:
# #     future = client.submit(
# #         process_single_station,
# #         station_code,
# #     )
# #     futures.append(future)

# futures = [client.submit(process_single_station, station_code) for station_code in station_codes]

# # Gather results
# print(f"Computing {len(futures)} station tasks in parallel...")
# results = client.gather(futures)
# results


# # Create output arrays with same structure as before
# runoff_onset_timing = xr.DataArray(
#     data=np.nan,
#     dims=["station", "WY", "buffer_radius", "fcf"],
#     coords={
#         "station": runoff_onset_max_swe_timing_WUS_ds.station,
#         "WY": runoff_onset_max_swe_timing_WUS_ds.WY,
#         "buffer_radius": buffer_radii,
#         "fcf": fcf_thresh_values
#     }
# )

# frac_valid_fcf = xr.DataArray(
#     data=np.nan,
#     dims=["station", "buffer_radius", "fcf"],
#     coords={
#         "station": runoff_onset_max_swe_timing_WUS_ds.station,
#         "buffer_radius": buffer_radii,
#         "fcf": fcf_thresh_values
#     }
# )

# mean_fcf = xr.DataArray(
#     data=np.nan,
#     dims=["station", "buffer_radius", "fcf"],
#     coords={
#         "station": runoff_onset_max_swe_timing_WUS_ds.station,
#         "buffer_radius": buffer_radii,
#         "fcf": fcf_thresh_values
#     }
# )


# # Populate arrays from results
# for station_result in results:
#     for key, onset in station_result['runoff_onset'].items():
#         station, buffer, fcf = key
#         runoff_onset_timing.loc[dict(station=station, buffer_radius=buffer, fcf=fcf)] = onset
#         frac_valid_fcf.loc[dict(station=station, buffer_radius=buffer, fcf=fcf)] = station_result['frac_valid'][key]
#         mean_fcf.loc[dict(station=station, buffer_radius=buffer, fcf=fcf)] = station_result['mean_fcf'][key]

# # Create final dataset and save
# runoff_onset_max_swe_timing_WUS_ds = runoff_onset_max_swe_timing_WUS_ds.assign(
#     runoff_onset_timing=runoff_onset_timing,
#     frac_valid_fcf=frac_valid_fcf,
#     mean_fcf=mean_fcf
# )

# runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'] = (
#     runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'] - 
#     runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing']
# )

# runoff_onset_max_swe_timing_WUS_ds

# runoff_onset_max_swe_timing_WUS_ds.drop_vars('geometry').to_netcdf('snotel_sar_differences_vv_v4.nc')


# fcf_thresh_values = [10,20,30,40,50,60,70,80,90,100]
# buffer_radii = [100,200,300,500,1000]

# runoff_onset_timing_WUS_ds = xr.DataArray(
#     data=np.nan,  # Initialize with NaNs
#     dims=["station","WY","buffer_radius","fcf"],
#     coords={"station": runoff_onset_max_swe_timing_WUS_ds.station, "WY": runoff_onset_max_swe_timing_WUS_ds.WY, "buffer_radius": buffer_radii, "fcf": fcf_thresh_values},
#     name="runoff_onset_timing"
# )

# frac_valid_fcf_WUS_ds = xr.DataArray(
#     data=np.nan,  # Initialize with NaNs
#     dims=["station","buffer_radius","fcf"],
#     coords={"station": runoff_onset_max_swe_timing_WUS_ds.station, "buffer_radius": buffer_radii, "fcf": fcf_thresh_values},
#     name="pixel_count"
# )

# mean_fcf_WUS_ds = xr.DataArray(
#     data=np.nan,  # Initialize with NaNs
#     dims=["station","buffer_radius","fcf"],
#     coords={"station": runoff_onset_max_swe_timing_WUS_ds.station, "buffer_radius": buffer_radii, "fcf": fcf_thresh_values},
#     name="mean_fcf"
# )


# # Iterate over each station
# for station_code in tqdm.tqdm(runoff_onset_max_swe_timing_WUS_ds.station.values):
#     print(f'Processing station {station_code}')
#     for buffer_radius in buffer_radii:
#         print(f'Processing buffer radius {buffer_radius}')
#         try: 
#             station_gdf = get_station_gdf(station_code, buffer_radius)

#             runoff_onset_station_ds = get_station_buffered_runoff_onset(runoff_onset_global_ds, station_gdf)
#             runoff_onset_station_ds = runoff_onset_station_ds.where(runoff_onset_station_ds['worldcover']!=80)
            
#             # Assign the calculated timing to the new DataArray
#             for fcf_thresh in fcf_thresh_values:
#                 runoff_onset_timing_WUS_ds.loc[dict(station=station_code,buffer_radius=buffer_radius,fcf=fcf_thresh)] = runoff_onset_station_ds['runoff_onset'].where(runoff_onset_station_ds['fcf'] <= fcf_thresh).median(dim=['x', 'y']).rename({'water_year':'WY'})
#                 frac_valid_fcf_WUS_ds.loc[dict(station=station_code,buffer_radius=buffer_radius,fcf=fcf_thresh)] = (runoff_onset_station_ds['fcf'].where(runoff_onset_station_ds['fcf'] <= fcf_thresh).count(dim=['x', 'y'])/runoff_onset_station_ds['fcf'].count(dim=['x', 'y']))
#                 mean_fcf_WUS_ds.loc[dict(station=station_code,buffer_radius=buffer_radius,fcf=fcf_thresh)] = runoff_onset_station_ds['fcf'].where(runoff_onset_station_ds['fcf'] <= fcf_thresh).mean(dim=['x', 'y'])

#         except Exception as e:
#             print(e)

# runoff_onset_max_swe_timing_WUS_ds = runoff_onset_max_swe_timing_WUS_ds.assign(runoff_onset_timing=runoff_onset_timing_WUS_ds)
# runoff_onset_max_swe_timing_WUS_ds = runoff_onset_max_swe_timing_WUS_ds.assign(frac_valid_fcf=frac_valid_fcf_WUS_ds)
# runoff_onset_max_swe_timing_WUS_ds = runoff_onset_max_swe_timing_WUS_ds.assign(mean_fcf=mean_fcf_WUS_ds)
# runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'] = runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'] - runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing']
# runoff_onset_max_swe_timing_WUS_ds.drop_vars('geometry').to_netcdf(f'snotel_sar_differences_vv_v4.nc')