In [None]:
import coiled
import dask
import dask.dataframe as dd
import easysnowdata
import geopandas as gpd
import hvplot.xarray  # registers the .hvplot accessor on xarray objects
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import seaborn as sns
import tqdm
import xarray as xr
import xvec
import xyzservices.providers as xyz

from global_snowmelt_runoff_onset.config import Config

In [None]:
config = Config('../config/global_config.txt')

In [None]:
StationsWUS = easysnowdata.automatic_weather_stations.StationCollection()
StationsWUS.get_entire_data_archive()

In [None]:
tiles_with_stations

In [None]:
tiles_with_stations_gdf = gpd.sjoin(
    config.valid_tiles_gdf,
    StationsWUS.all_stations,
    how='inner',
    predicate='contains'
)
tiles_with_stations_gdf = tiles_with_stations_gdf.drop_duplicates(subset=['row','col'])
tiles_with_stations_gdf

In [None]:
tiles_with_stations_gdf.explore()

In [None]:
# Station SWE ('WTEQ') restricted to 2014-10-01 .. 2024-09-30 (water years 2015-2024)
study_period = slice('2014-10-01', '2024-09-30')
wteq_archive_da = StationsWUS.entire_data_archive['WTEQ']
stations_WUS_da = wteq_archive_da.sel(time=study_period)
stations_WUS_da

In [None]:
#stations_WUS_da.fillna(-9999).groupby("WY")#.map(lambda x: x.argmax("time",skipna=True))

In [None]:
# Quality-control the daily SWE record: mask negative values and days whose absolute change
# from the previous record meets or exceeds the threshold (treated as spikes / sensor noise).
# NOTE(review): the threshold is in the archive's WTEQ units (presumably metres) — confirm.
max_daily_swe_change = 0.2
# diff() has no value for the first time step. Re-align it onto the full time axis (first day
# treated as "no change") so the inner-join alignment in `&`/`.where` does not silently drop
# 2014-10-01 — which would shift every WY2015 argmax index below by one day.
abs_diffs = np.abs(stations_WUS_da.diff(dim='time')).reindex(time=stations_WUS_da.time, fill_value=0)
stations_WUS_da = stations_WUS_da.where((stations_WUS_da>=0) & (abs_diffs<max_daily_swe_change))


def check_missing_data(group, months=(11, 12, 1, 2, 3, 4, 5), max_missing_days=30):
    """Blank out series with too many missing days during the snow season of one water year.

    Intended for ``DataArray.groupby('WY').map``: ``group`` holds one water year of data with a
    ``time`` dimension (presumably ``(station, time)`` — any order works).

    Parameters
    ----------
    group : xarray.DataArray
        One water year of data with a ``time`` dimension.
    months : sequence of int, default November through May
        Calendar months in which missing values are counted.
    max_missing_days : int, default 30
        Series with more than this many missing time steps within ``months`` are set to NaN
        for the whole group.

    Returns
    -------
    xarray.DataArray
        A new array; ``group`` itself is not modified.
    """
    season_mask = group.time.dt.month.isin(list(months))
    season_group = group.where(season_mask, drop=True)
    missing_data_counts = season_group.isnull().sum(dim='time')
    too_sparse = missing_data_counts > max_missing_days
    # Label-based masking broadcasts over `time` regardless of dimension order and returns a new
    # array, instead of positional boolean assignment into `group` (which can be a view of the
    # parent array during groupby and only works when the masked dim comes first).
    return group.where(~too_sparse)


def check_seasonal_snow_swe(group, swe_threshold=0.01, window=60, min_days=55):
    """Blank out series that never show a sustained seasonal snowpack in one water year.

    A series is kept if, in at least one centred ``window``-step rolling window, at least
    ``min_days`` time steps have SWE >= ``swe_threshold``. Intended for
    ``DataArray.groupby('WY').map``.

    Parameters
    ----------
    group : xarray.DataArray
        One water year of SWE with a ``time`` dimension.
    swe_threshold : float, default 0.01
        SWE value counted as snow-covered (archive units, presumably metres — confirm).
    window : int, default 60
        Rolling window length in time steps.
    min_days : int, default 55
        Required count of snow-covered steps in a window. It is also used as the rolling
        ``min_periods``, so edge windows with fewer observations yield NaN.

    Returns
    -------
    xarray.DataArray
        A new array; ``group`` itself is not modified.
    """
    # NaN >= threshold evaluates False, so missing days count as snow-free
    snow_days = (group >= swe_threshold).rolling(time=window, center=True, min_periods=min_days).sum()
    has_seasonal_snow = (snow_days >= min_days).any(dim='time')
    return group.where(has_seasonal_snow)

# Drop station-years with too many missing snow-season days or without a seasonal snowpack.
# (GroupBy.apply is a pending-deprecated alias of GroupBy.map.)
stations_WUS_da = stations_WUS_da.groupby('WY').map(check_missing_data)
stations_WUS_da = stations_WUS_da.groupby('WY').map(check_seasonal_snow_swe)

# Peak-SWE timing per station and water year: the 0-based index of the maximum within each WY
# group, i.e. time steps after the group's first record. NaNs are filled with -9999, so an
# all-NaN station-year returns index 0 and is masked by `> 0` (as is a true maximum on day 0).
# NOTE(review): the index equals "days since Oct 1" only if each WY group starts on Oct 1
# without gaps — confirm before comparing with the SAR day-of-water-year values.
stations_WUS_max_SWE_timing_da = (
    stations_WUS_da.fillna(-9999)
    .groupby("WY")
    .map(lambda wy_da: wy_da.argmax("time", skipna=True))
    .where(lambda timing_da: timing_da > 0)
)
runoff_onset_max_swe_timing_WUS_ds = stations_WUS_max_SWE_timing_da.to_dataset(name='station_max_SWE_timing')
runoff_onset_max_swe_timing_WUS_ds

In [None]:
#aws_WUS_max_SWE_timing_da.dropna(dim='station',how='all')
#aws_WUS_max_SWE_timing_da.dropna(dim='station',how='any')

In [None]:
# Distribution across stations of peak-SWE timing (index of max SWE within each water year)
# for every water year: a boxplot with a median line, then a violin plot.
max_swe_timing_df = stations_WUS_max_SWE_timing_da.to_dataframe().reset_index()
timing_col = stations_WUS_max_SWE_timing_da.name

f, ax = plt.subplots(figsize=(12, 6))
sns.boxplot(data=max_swe_timing_df, x='WY', y=timing_col, color='lightblue', ax=ax)
# Line connecting the yearly medians; pointplot's default estimator is the mean
sns.pointplot(data=max_swe_timing_df, x='WY', y=timing_col, color='darkblue', markers='o',
              estimator=np.median, ax=ax)
ax.tick_params(axis='x', labelrotation=45)
ax.set(title='Distribution of Max SWE Timing Across Stations by Water Year',
       xlabel='Water Year', ylabel='Max SWE Timing (days)')
f.tight_layout()

f, ax = plt.subplots(figsize=(12, 6))
sns.violinplot(data=max_swe_timing_df, x='WY', y=timing_col, color='lightblue',
               inner='box', ax=ax)  # quartile box inside each violin
ax.tick_params(axis='x', labelrotation=45)
ax.set(title='Distribution of Max SWE Timing Across Stations by Water Year',
       xlabel='Water Year', ylabel='Max SWE Timing (days)')
f.tight_layout()


In [None]:
f,ax=plt.subplots(figsize=(10,30))
stations_WUS_max_SWE_timing_da.dropna(dim='station',how='any').plot(ax=ax)

In [None]:
StationsWUS.all_stations.explore()

In [None]:
runoff_onset_global_ds = xr.open_zarr(
    config.global_runoff_store,
    consolidated=True,
    decode_coords='all',
)
runoff_onset_global_ds

# NOTE(review): this lon/lat box excludes everything west of -120 (much of the Pacific states)
# and all of Alaska; the per-station analysis in this notebook passes runoff_onset_global_ds
# instead of this clipped subset.
wus_bbox = (-120, 30, -110, 50)  # (minx, miny, maxx, maxy) in EPSG:4326
runoff_onset_WUS_ds = runoff_onset_global_ds.rio.clip_box(*wus_bbox, crs='EPSG:4326')
runoff_onset_WUS_ds

In [None]:
def get_station_gdf(station_code, buffer_radius=None, stations_gdf=None):
    """Return one station as a GeoDataFrame in its estimated UTM CRS, optionally buffered.

    Parameters
    ----------
    station_code : str
        Index value of the station in the station table (e.g. '908_WA_SNTL').
    buffer_radius : float, optional
        If truthy, the point geometry is replaced by a buffer of this radius in UTM units
        (metres). ``None`` or ``0`` leaves the point unbuffered.
    stations_gdf : geopandas.GeoDataFrame, optional
        Station table to search; defaults to ``StationsWUS.all_stations``.

    Returns
    -------
    geopandas.GeoDataFrame
        The matching row(s), reprojected to the UTM zone estimated from the station location.

    Raises
    ------
    KeyError
        If ``station_code`` is not present in the station table.
    """
    if stations_gdf is None:
        stations_gdf = StationsWUS.all_stations
    station_gdf = stations_gdf[stations_gdf.index == station_code]
    if station_gdf.empty:
        # Fail clearly instead of an obscure error from UTM estimation on an empty frame
        raise KeyError(f'Station {station_code!r} not found in station table')
    station_epsg = station_gdf.estimate_utm_crs().to_epsg()
    station_gdf = station_gdf.to_crs(epsg=station_epsg)
    if buffer_radius:
        # assign() returns a new frame rather than mutating the one derived from the station table
        station_gdf = station_gdf.assign(geometry=station_gdf.geometry.buffer(buffer_radius))
    return station_gdf



def get_station_buffered_runoff_onset(runoff_onset_WUS_ds, station_gdf):
    """Cut runoff-onset data to a station footprint and attach forest cover and elevation.

    The dataset is box-clipped to the station's bounds, reprojected to the station CRS and
    clipped to the station geometry. Forest cover fraction ('fcf') and a 30 m Copernicus DEM
    ('dem') are then fetched for the same lon/lat bounds and bilinearly resampled onto that grid.

    Parameters
    ----------
    runoff_onset_WUS_ds : xarray.Dataset
        Runoff-onset dataset with rioxarray spatial metadata (any extent containing the station).
    station_gdf : geopandas.GeoDataFrame
        Station geometry (point or buffer), e.g. from ``get_station_gdf``.

    Returns
    -------
    xarray.Dataset
        Clipped dataset with added 'fcf' and 'dem' variables.
    """
    target_crs = station_gdf.crs
    clipped_ds = (
        runoff_onset_WUS_ds
        .rio.clip_box(*station_gdf.total_bounds, crs=target_crs)
        .rio.reproject(target_crs)
        .rio.clip(station_gdf.geometry)
    )

    lonlat_bounds = clipped_ds.rio.transform_bounds('EPSG:4326')
    bilinear = rasterio.enums.Resampling.bilinear

    forest_cover_da = easysnowdata.remote_sensing.get_forest_cover_fraction(lonlat_bounds)
    clipped_ds['fcf'] = forest_cover_da.rio.reproject_match(clipped_ds, resampling=bilinear)

    elevation_da = easysnowdata.topography.get_copernicus_dem(lonlat_bounds, resolution=30)
    clipped_ds['dem'] = elevation_da.rio.reproject_match(clipped_ds, resampling=bilinear)

    return clipped_ds

In [None]:
buffer_radius = 500
# Stations inspected during exploration; only the last entry is used by the cells below
explored_station_codes = [
    '907_UT_SNTL',
    '1135_UT_SNTL',
    '448_MT_SNTL',
    '301_CA_SNTL',
    '679_WA_SNTL',
    '916_MT_SNTL',
    '1267_AK_SNTL',
    '908_WA_SNTL',
]
station_code = explored_station_codes[-1]
station_code

In [None]:
station_gdf = get_station_gdf(station_code, buffer_radius) # check auto expand, cant use clip if yes
#station_gdf = get_station_gdf(station_code)
station_gdf.explore(tiles=xyz.Esri.WorldImagery())

In [None]:
runoff_onset_station_ds = get_station_buffered_runoff_onset(runoff_onset_global_ds, station_gdf)
runoff_onset_station_ds

In [None]:
runoff_onset_station_ds['runoff_onset'].where(runoff_onset_station_ds['fcf']<50).median(dim=['x','y'])

In [None]:
runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing'].sel(station=station_code)


In [None]:
f,ax=plt.subplots(figsize=(12,7))
runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing'].sel(station=station_code).plot(ax=ax, label='Max SWE Timing',color='black')

fcf_thresh_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

cmap = cm.get_cmap('viridis', len(fcf_thresh_values))

for i, fcf_thresh in enumerate(fcf_thresh_values):
    color = cmap(i)
    runoff_onset_station_ds['runoff_onset'].where(runoff_onset_station_ds['fcf'] < fcf_thresh).median(dim=['x', 'y']).plot.scatter(ax=ax, label=f'{fcf_thresh}th percentile', color=color)
ax.legend()

In [None]:
f,ax=plt.subplots(figsize=(12,7))
runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing'].sel(station=station_code).plot(ax=ax, label='Max SWE Timing',color='black')

fcf_thresh_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

cmap = cm.get_cmap('inferno', len(fcf_thresh_values))

for i, fcf_thresh in enumerate(fcf_thresh_values):
    color = cmap(i)
    runoff_onset_station_ds['runoff_onset'].where((runoff_onset_station_ds['fcf'] < fcf_thresh) & (runoff_onset_station_ds['fcf'] >= fcf_thresh-10)).median(dim=['x', 'y']).plot.scatter(ax=ax, label=f'{fcf_thresh-10}<=fcf<{fcf_thresh}', color=color)
ax.legend()

In [None]:
import hvplot.xarray

In [None]:
easysnowdata.utils.datetime_to_DOWY('2024-04-10')

In [None]:
#f,ax=plt.subplots(figsize=(12,7))
stations_WUS_da.sel(station=station_code).hvplot.scatter(x='time')

In [None]:
runoff_onset_station_ds['runoff_onset'].plot.imshow(col='water_year',col_wrap=5,robust=True,vmin=50,vmax=250)

In [None]:
runoff_onset_station_ds['runoff_onset'].where(runoff_onset_station_ds['fcf']<20).plot.imshow(col='water_year',col_wrap=5,robust=True,vmin=50,vmax=250)

In [None]:
runoff_onset_station_ds['fcf'].where(runoff_onset_station_ds['runoff_onset_median']>0).plot.imshow()

In [None]:
runoff_onset_station_ds['dem'].where(runoff_onset_station_ds['runoff_onset_median']>0).plot.imshow()

In [None]:
# Now test how sensitive the SAR-vs-station timing comparison is to the FCF threshold and buffer radius

In [None]:
# Sensitivity of SAR runoff-onset timing to the FCF threshold and the buffer radius (UTM metres)
# around each station. For each station/buffer/threshold, compute the per-water-year median
# runoff onset over pixels with fcf <= threshold.
fcf_thresh_values = [10,20,30,40,50,60,70,80,90,100]
buffer_radii = [50,100,200,300,400,500,1000]

runoff_onset_timing_WUS_ds = xr.DataArray(
    data=np.nan,  # filled below; combinations that fail stay NaN
    dims=["station","WY","buffer_radii","fcf"],
    coords={"station": runoff_onset_max_swe_timing_WUS_ds.station, "WY": runoff_onset_max_swe_timing_WUS_ds.WY, "buffer_radii": buffer_radii, "fcf": fcf_thresh_values},
    name="runoff_onset_timing"
)

failed_station_buffers = []  # (station_code, buffer_radius, repr(error)) for later inspection

for station_code in tqdm.tqdm(runoff_onset_max_swe_timing_WUS_ds.station.values, desc='stations'):
    for buffer_radius in buffer_radii:
        try:
            station_gdf = get_station_gdf(station_code, buffer_radius)
            runoff_onset_station_ds = get_station_buffered_runoff_onset(runoff_onset_global_ds, station_gdf)
            # Load the small clipped arrays once instead of re-reading them for every threshold
            onset_da = runoff_onset_station_ds['runoff_onset'].load()
            fcf_da = runoff_onset_station_ds['fcf'].load()

            for fcf_thresh in fcf_thresh_values:
                median_onset_da = (
                    onset_da.where(fcf_da <= fcf_thresh)
                    .median(dim=['x', 'y'])
                    .rename({'water_year': 'WY'})
                    # Align to the station water years. Without this, a SAR record covering
                    # different years makes the .loc assignment raise and the whole
                    # station/buffer is lost. NOTE(review): assumes WY and water_year share a dtype.
                    .reindex(WY=runoff_onset_timing_WUS_ds.WY)
                )
                runoff_onset_timing_WUS_ds.loc[dict(station=station_code, buffer_radii=buffer_radius, fcf=fcf_thresh)] = median_onset_da
        except Exception as e:
            # Best-effort over many stations: record the failure and keep going
            failed_station_buffers.append((station_code, buffer_radius, repr(e)))
            print(f'{station_code} (buffer {buffer_radius}): {e!r}')

runoff_onset_max_swe_timing_WUS_ds = runoff_onset_max_swe_timing_WUS_ds.assign(runoff_onset_timing=runoff_onset_timing_WUS_ds)
# NOTE(review): station_max_SWE_timing is a 0-based index within each WY group; confirm the SAR
# runoff onset uses the same day-of-water-year origin before interpreting these differences.
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'] = runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'] - runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing']
# The 'geometry' variable is dropped before writing netCDF
runoff_onset_max_swe_timing_WUS_ds.drop_vars('geometry').to_netcdf('snotel_sar_differences.nc')
print(f'{len(failed_station_buffers)} station/buffer combinations failed')

In [None]:
test_ds = xr.open_dataset('snotel_sar_differences.nc').compute()
test_ds

In [None]:
runoff_onset_max_swe_timing_WUS_ds

In [None]:
runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'].sel(fcf=10).median(dim='station').plot()

In [None]:
runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing'].median(dim='station').plot()

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=50).median(dim='station')

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=100).median(dim='station').plot(vmin=-20,vmax=20,cmap='RdBu_r')

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=50).median(dim='station').plot(vmin=-20,vmax=20,cmap='RdBu_r')

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=10).median(dim='station').plot(vmin=-20,vmax=20,cmap='RdBu_r')

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=10).count(dim='station').plot()

In [None]:
# Per-water-year histograms of SAR minus station timing for one FCF threshold and buffer
# radius, with the mean absolute error for the same selection printed for each year.
fcf_thresh = 20
buffer_radius = 500
sar_minus_stations_da = runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations']
water_years = runoff_onset_max_swe_timing_WUS_ds.WY.values
# squeeze=False keeps a 2-D axes array even when there is only one water year
f, axs = plt.subplots(len(water_years), 1, figsize=(6, 20), sharex=True, squeeze=False)
for ax, WY in zip(axs[:, 0], water_years):
    # The histogram and the MAE use the same fcf/buffer selection
    wy_diff_da = sar_minus_stations_da.sel(WY=WY, fcf=fcf_thresh, buffer_radii=buffer_radius)
    wy_diff_da.plot.hist(ax=ax, bins=50)
    ax.set_title(f'WY {WY} (fcf<={fcf_thresh}, buffer {buffer_radius} m)')
    print(f'For {WY} with fcf_thresh of {fcf_thresh} and buffer_radius of {buffer_radius}, MAE = {np.abs(wy_diff_da).mean().values}')
f.tight_layout()

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=20,buffer_radii=100).plot.hist(bins=50)#.median(dim='station').plot(vmin=-20,vmax=20,cmap='RdBu_r')

In [None]:
sar_minus_stations_df = runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=20,buffer_radii=100).median(dim='WY').to_dataframe()
sar_minus_stations_gdf = gpd.GeoDataFrame(sar_minus_stations_df, geometry=gpd.points_from_xy(sar_minus_stations_df.longitude, sar_minus_stations_df.latitude))
sar_minus_stations_gdf

In [None]:
import xyzservices as xyz

In [None]:
sar_minus_stations_gdf.explore('sar_minus_stations', tiles=xyz.providers.Esri.WorldImagery)

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=20,buffer_radii=1000).plot.hist(bins=50)#.median(dim='station').plot(vmin=-20,vmax=20,cmap='RdBu_r')

In [None]:
np.abs(runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].sel(fcf=20)).median(dim='station').plot(vmin=0,vmax=30)

In [None]:
runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'].sel(fcf=100)#.plot.hist(alpha=0.5,bins=500)


In [None]:
f,ax=plt.subplots()
runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing'].plot.hist(ax=ax,alpha=0.5,bins=50)
for fcf_thresh in fcf_thresh_values:
    runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'].sel(fcf=fcf_thresh).plot.hist(ax=ax,alpha=0.2,bins=50)
ax.legend()



In [None]:
f,ax=plt.subplots(figsize=(10,10))
runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing'].plot.hist(ax=ax,alpha=0.5,bins=50,color='black')
runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'].sel(fcf=10).plot.hist(ax=ax,alpha=0.2,bins=50,color='red')
runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'].sel(fcf=50).plot.hist(ax=ax,alpha=0.2,bins=50,color='green')
runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'].sel(fcf=100).plot.hist(ax=ax,alpha=0.2,bins=50,color='blue')

ax.legend()

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations']

In [None]:
runoff_onset_max_swe_timing_WUS_ds['station_max_SWE_timing'].count(dim='station').plot()

In [None]:
runoff_onset_max_swe_timing_WUS_ds['runoff_onset_timing'].count(dim='station').plot()

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].count(dim='station').plot()

In [None]:
runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations'].mean(['station']).plot()

In [None]:
np.abs(runoff_onset_max_swe_timing_WUS_ds['sar_minus_stations']).mean(['station']).plot()

In [None]:
runoff_onset_max_swe_timing_WUS_ds.drop_vars('geometry').to_netcdf('snotel_sar_differences.nc')