In [1]:
import os 
# Set PROJ_LIB before the geospatial libraries (rioxarray, geopandas) are imported in the next cell.
# NOTE(review): this is an absolute, machine-specific path — presumably required so PROJ can
# locate its data files in this pixi environment; confirm, and consider deriving it from the
# active environment instead of hard-coding a home directory.
os.environ['PROJ_LIB'] = "/home/jesseake/skagit-met/.pixi/envs/analysis/share/proj"
import xarray as xr

In [2]:
import dask
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import holoviews as hv
import datetime as dt
import hvplot.xarray
import hvplot.dask
import rioxarray as rxr
import geopandas as gpd
import shapely
hv.extension('bokeh')

In [None]:
# NOTE: this rebinds `dt`. The import cell binds `dt` to the `datetime` module; from here on
# `dt` is the `datetime.datetime` class, which is what `dt.fromisoformat(...)` in openDataset needs.
# Re-running the import cell after this one silently switches `dt` back to the module.
from datetime import datetime as dt
from scipy.interpolate import griddata
# Event windows as (start, end) date strings.
# Presumably atmospheric-river events, given the `ar_` prefix — confirm.
ar_1990 = ('1990-11-03', '1990-11-17')
ar_1995 = ('1995-11-22', '1995-12-06')
ar_2003 = ('2003-10-14', '2003-10-28')
ar_2006 = ('2006-10-31', '2006-11-13')
ar_2011 = ('2011-01-01', '2011-02-01')
ar_2021 = ('2021-11-10', '2021-11-17')
# NOTE(review): ar_2011 and ar_2021 are not part of this list — confirm that is intended.
ars = [ar_1990, ar_1995, ar_2003, ar_2006]
# DEM raster and basin boundary geometry; `skagit_dem` is the DEM clipped to the boundary.
dem = xr.open_dataset('../Data/GIS/SkagitRiver_90mDEM.tif')
poly = gpd.read_file('../Data/GIS/SkagitBoundary.json').geometry
skagit_dem = dem.rio.clip(poly)

# Do this once, since it takes ~ 1 minute
def load_and_clip_pnnl(boundary=None):
    """Open the PNNL historical dataset, attach its lon/lat grid and clip it to the basin.

    The geo_em grid file supplies the 2-D ``CLONG``/``CLAT`` coordinates and the
    ``LANDMASK`` field; they are attached to the kerchunk-referenced PNNL dataset
    as coordinates on its ``y``/``x`` dimensions. Cells whose (CLONG, CLAT) point
    is not contained in the boundary geometry are masked, and rows/columns that
    end up fully masked are dropped.

    Parameters
    ----------
    boundary : optional
        Basin outline. Either an object exposing ``.geometry`` (e.g. a GeoSeries
        or GeoDataFrame, whose first geometry by position is used) or a single
        shapely geometry. Defaults to the notebook-level ``poly``.

    Returns
    -------
    xarray.Dataset
        The PNNL dataset with ``CLONG``, ``CLAT`` and ``LANDMASK`` coordinates,
        restricted to the boundary.
    """
    grid_path = '/data0/skagit_met/PNNL/historical/SERDP6km.geo_em.d01.nc'
    virtual_pnnl = '/data0/skagit_met/PNNL/historical/PNNL_historical.parquet'

    if boundary is None:
        boundary = poly  # notebook-level basin outline
    geometries = getattr(boundary, 'geometry', None)
    # Positional access: a label lookup ([0]) breaks if the index does not contain 0.
    basin_geom = geometries.iloc[0] if geometries is not None else boundary

    # Load the three grid fields into memory so the grid file can be closed right away.
    with xr.open_dataset(grid_path) as ds_grid:
        # squeeze() drops the Time=0 scalar dimension
        dsg = ds_grid.squeeze()[['LANDMASK', 'CLONG', 'CLAT']].load()
    # Rename to x and y to match the data files
    dsg = dsg.set_coords(("CLONG", "CLAT")).rename(dict(south_north='y', west_east='x'))

    lon2d = dsg["CLONG"].values
    lat2d = dsg["CLAT"].values

    pnnl = xr.open_dataset(virtual_pnnl, engine='kerchunk', mask_and_scale=False)
    # Attach by raw values: the grid file's own dimension coordinates are not used.
    pnnl = pnnl.assign_coords(
        CLONG=(("y", "x"), lon2d),
        CLAT=(("y", "x"), lat2d),
    )
    # Also bring in the land mask (kept for visualisation)
    pnnl.coords['LANDMASK'] = dsg.LANDMASK

    inside = shapely.contains_xy(basin_geom, lon2d, lat2d)
    inside_da = xr.DataArray(
        data=inside,
        dims=['y', 'x'],
        coords=dict(CLONG=(['y', 'x'], lon2d), CLAT=(['y', 'x'], lat2d)),
    )
    return pnnl.where(inside_da, drop=True)

# Remember: WRF and PNNL are hourly.
# SNOTEL and WRF already have hgt/HGT data.
def openDataset(dates: tuple[str, str], product: str, frequency: str = '', resolution: str = '', pnnl_ds = None) -> xr.Dataset:
    """Open one product's data for an event window and add derived variables.

    Parameters
    ----------
    dates : tuple[str, str]
        ``(start, end)`` date strings. Used to slice ``time`` for PNNL and ORNL
        and to build the store name for the other products.
    product : str
        Product name. Matching is case-insensitive (``'pnnl'``, ``'ornl'``,
        ``'wrf_era5'``, ``'snotel'``); any other value is opened from a zarr
        store whose name embeds ``product`` exactly as given.
    frequency, resolution : str, optional
        Extra store-name components, inserted (each prefixed by ``_``) only when non-empty.
    pnnl_ds : xarray.Dataset, optional
        Pre-loaded PNNL dataset; required when ``product`` is PNNL.

    Returns
    -------
    xarray.Dataset
        The opened/sliced dataset with product-specific additions:
        ``T2C`` (``T2 - 273.15``) for wrf_era5 and pnnl; ``T2C``/``AVG_T2C``
        (Fahrenheit to Celsius) and ``elevation_m`` for SNOTEL; ``PRCP``
        (``RAINC + RAINNC``) and ``HGT``/``hgt`` as a coordinate for wrf_era5;
        ``tmean`` for ORNL.

    Raises
    ------
    ValueError
        If ``product`` is PNNL and ``pnnl_ds`` was not supplied.
    """
    # Imported locally so the function does not depend on which notebook cell
    # last bound the `dt` alias (it is bound to both the module and the class).
    from datetime import datetime

    kind = product.lower()
    start_date, end_date = dates[0], dates[1]

    if kind == 'pnnl':
        if pnnl_ds is None:
            raise ValueError("product 'pnnl' requires pnnl_ds (the pre-loaded PNNL dataset)")
        ds = pnnl_ds.sel(time=slice(start_date, end_date))
    elif kind == 'ornl':
        # ORNL stores are named by the start and end *years* of the window.
        start = datetime.fromisoformat(start_date).year
        end = datetime.fromisoformat(end_date).year
        ds = xr.open_zarr(f'/data0/skagit_met/atmospheric_rivers/{start}_{end}_ORNL_data.zarr')
        ds = ds.sel(time=slice(start_date, end_date))
    else:
        suffix = ''.join(f'_{part}' for part in (frequency, resolution) if part)
        ds = xr.open_zarr(f'/data0/skagit_met/atmospheric_rivers/{start_date}_{end_date}{suffix}_{product}_data.zarr')

    if kind in ('wrf_era5', 'pnnl'):
        ds['T2C'] = ds['T2'] - 273.15

    if kind == 'snotel':
        if 'AIR TEMP' in ds:
            ds['T2C'] = (ds['AIR TEMP'] - 32) * (5/9)

        if 'AVG AIR TEMP' in ds:
            ds['AVG_T2C'] = (ds['AVG AIR TEMP'] - 32) * (5/9)

        ds['elevation_m'] = ds['elevation_ft'] * 0.3048

    if kind == 'wrf_era5':
        ds['PRCP'] = ds['RAINC'] + ds['RAINNC']
        # The height variable is spelled 'HGT' in some stores and 'hgt' in others.
        # If neither exists, set_coords('hgt') raises, as before.
        hgt_name = 'HGT' if 'HGT' in ds else 'hgt'
        ds = ds.set_coords(hgt_name)

    if kind == 'ornl':
        ds['tmean'] = (ds.tmax + ds.tmin) / 2

    return ds

def embedDEM(ds, dem, ds_x, ds_y, dem_x='x', dem_y='y', dem_var=None, method='nearest'):
    """Attach DEM elevation to ``ds`` as an ``elevation`` coordinate.

    The DEM is sampled at the horizontal positions ``ds[ds_x]`` / ``ds[ds_y]``:

    * both target coordinates 1-D: the DEM is sampled with ``DataArray.interp``;
    * otherwise: the DEM is flattened to scattered points and sampled with
      ``scipy.interpolate.griddata`` at the (same-shaped) target arrays.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to annotate; it is not modified in place.
    dem : xarray.Dataset
        Dataset holding the elevation raster and its ``dem_x``/``dem_y`` coordinates.
    ds_x, ds_y : str
        Names of the x and y position variables in ``ds``.
    dem_x, dem_y : str
        Names of the x and y coordinates in ``dem``.
    dem_var : str, optional
        Elevation variable in ``dem``; defaults to its first data variable.
    method : str
        Interpolation method passed to ``interp`` / ``griddata``.

    Returns
    -------
    xarray.Dataset
        A new dataset with an ``elevation`` coordinate.

    Raises
    ------
    ValueError
        If ``dem_var`` is not given and ``dem`` has no data variables.
    """
    # Pick DEM variable
    if dem_var is None:
        candidates = list(dem.data_vars)
        if not candidates:
            raise ValueError('dem has no data variables; pass dem_var explicitly')
        dem_var = candidates[0]

    dem_elev = dem[dem_var]
    if 'band' in dem_elev.dims:
        dem_elev = dem_elev.isel(band=0).squeeze()

    # Get target coordinates
    target_x = ds[ds_x]
    target_y = ds[ds_y]

    if target_x.ndim == 1 and target_y.ndim == 1:
        # Both 1-D: let xarray do the lookup.
        elev_interp = dem_elev.interp(
            {dem_x: target_x, dem_y: target_y}, method=method
        )
        # Take the dimension order from the result itself so the values are
        # labelled correctly whatever the DEM's own dimension order is.
        elevation = (elev_interp.dims, elev_interp.values)
    else:
        # Scattered-point lookup. The DEM point list is only built on this path,
        # since it is as large as the full DEM.
        dem_x_da = dem[dem_x]
        dem_y_da = dem[dem_y]
        dem_lon = np.asarray(dem_x_da)
        dem_lat = np.asarray(dem_y_da)
        if dem_lon.ndim == 1 and dem_lat.ndim == 1:
            # meshgrid yields (len(y), len(x)) arrays, so the elevation grid
            # must be laid out rows = y, columns = x before flattening.
            row_dim, col_dim = dem_y_da.dims[0], dem_x_da.dims[0]
            if dem_elev.dims == (col_dim, row_dim):
                dem_elev = dem_elev.transpose(row_dim, col_dim)
            dem_lon2d, dem_lat2d = np.meshgrid(dem_lon, dem_lat)
        else:
            dem_lon2d, dem_lat2d = dem_lon, dem_lat

        dem_points = np.column_stack((dem_lon2d.ravel(), dem_lat2d.ravel()))
        dem_values = dem_elev.values.ravel()

        # target_x and target_y must be the same shape
        sampled = griddata(
            points=dem_points,
            values=dem_values,
            xi=(target_x.values, target_y.values),
            method=method,
            fill_value=np.nan
        )
        # Label the result with the y variable's dims (e.g. ('y', 'x')).
        elevation = (target_y.dims, sampled)

    # assign_coords already registers `elevation` as a coordinate.
    return ds.assign_coords(elevation=elevation)

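Load each dataset for every event window and, for the products opened without their own terrain height (PNNL, ORNL, PRISM), embed DEM elevation as a coordinate.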

In [None]:
# Clipped PNNL dataset, loaded once and then sliced per event window below.
pnnl = load_and_clip_pnnl()

# PNNL: positions are the 2-D CLONG/CLAT coordinates.
pnnl_1990 = embedDEM(openDataset(ar_1990, 'pnnl', pnnl_ds = pnnl), skagit_dem, 'CLONG', 'CLAT')
pnnl_1995 = embedDEM(openDataset(ar_1995, 'pnnl', pnnl_ds = pnnl), skagit_dem, 'CLONG', 'CLAT')
pnnl_2003 = embedDEM(openDataset(ar_2003, 'pnnl', pnnl_ds = pnnl), skagit_dem, 'CLONG', 'CLAT')
pnnl_2006 = embedDEM(openDataset(ar_2006, 'pnnl', pnnl_ds = pnnl), skagit_dem, 'CLONG', 'CLAT')
pnnl_2011 = embedDEM(openDataset(ar_2011, 'pnnl', pnnl_ds = pnnl), skagit_dem, 'CLONG', 'CLAT')

# ORNL: positions are lon/lat.
ornl_1990 = embedDEM(openDataset(ar_1990, 'ORNL'), skagit_dem, 'lon', 'lat')
ornl_1995 = embedDEM(openDataset(ar_1995, 'ORNL'), skagit_dem, 'lon', 'lat')
ornl_2003 = embedDEM(openDataset(ar_2003, 'ORNL'), skagit_dem, 'lon', 'lat')
ornl_2006 = embedDEM(openDataset(ar_2006, 'ORNL'), skagit_dem, 'lon', 'lat')
ornl_2011 = embedDEM(openDataset(ar_2011, 'ORNL'), skagit_dem, 'lon', 'lat')

# PRISM: the 1990-2006 stores carry no frequency/resolution in their names;
# the 2011 store is addressed as daily/4km.
prism_1990 = embedDEM(openDataset(ar_1990, 'PRISM'), skagit_dem, 'lon', 'lat')
prism_1995 = embedDEM(openDataset(ar_1995, 'PRISM'), skagit_dem, 'lon', 'lat')
prism_2003 = embedDEM(openDataset(ar_2003, 'PRISM'), skagit_dem, 'lon', 'lat')
prism_2006 = embedDEM(openDataset(ar_2006, 'PRISM'), skagit_dem, 'lon', 'lat')
prism_2011 = embedDEM(openDataset(ar_2011, 'PRISM', 'daily', '4km'), skagit_dem, 'lon', 'lat')

# PRISM 800m stores. Every *_800m variable must request the '800m' store so it
# matches its name and the 'PRISM 800m' legend label in the figure.
prism_1990_800m = embedDEM(openDataset(ar_1990, 'PRISM', 'daily', '800m'), skagit_dem, 'lon', 'lat')
prism_1995_800m = embedDEM(openDataset(ar_1995, 'PRISM', 'daily', '800m'), skagit_dem, 'lon', 'lat')
prism_2003_800m = embedDEM(openDataset(ar_2003, 'PRISM', 'daily', '800m'), skagit_dem, 'lon', 'lat')
prism_2006_800m = embedDEM(openDataset(ar_2006, 'PRISM', 'daily', '800m'), skagit_dem, 'lon', 'lat')
prism_2011_800m = embedDEM(openDataset(ar_2011, 'PRISM', 'daily', '800m'), skagit_dem, 'lon', 'lat')

# WRF: no DEM embedding; openDataset promotes its own HGT/hgt variable to a coordinate.
wrf_1990 = openDataset(ar_1990, 'wrf_era5')
wrf_1995 = openDataset(ar_1995, 'wrf_era5')
wrf_2003 = openDataset(ar_2003, 'wrf_era5')
wrf_2006 = openDataset(ar_2006, 'wrf_era5')
wrf_2011 = openDataset(ar_2011, 'wrf_era5')

# prism_2021_800m = embedDEM(openDataset(ar_2021, 'PRISM', 'daily', '800m'), skagit_dem, 'lon', 'lat')
# hrrr_2021 = xr.open_dataset('/data0/skagit_met/atmospheric_rivers/2021-11-10_2021-11-17_HRRR_data.zarr')

Drop the reference to the full PNNL dataset to save memory; the per-event subsets were already selected above.

In [None]:
# Rebind the name so the notebook no longer holds its own reference to the full clipped PNNL dataset.
pnnl = None

WRF precipitation (`RAINC` + `RAINNC`) is already accumulated over time, so difference it along `time` to get the hourly accumulation.

In [None]:
# WRF's RAINC/RAINNC are running totals, so hourly precipitation is the time
# difference of their sum. PRCP is rebuilt from those source fields on every run,
# which keeps this cell idempotent: re-running it cannot difference values that
# were already differenced.
for wrf_event in (wrf_1990, wrf_1995, wrf_2003, wrf_2006, wrf_2011):
    wrf_event['PRCP'] = (wrf_event['RAINC'] + wrf_event['RAINNC']).diff(dim='time')

In [None]:
def _basin_mean_cumulative(field, spatial_dims):
    """Average ``field`` over ``spatial_dims``, then take the cumulative sum of the result."""
    return field.mean(dim=spatial_dims).cumsum()


def _basin_mean_cumulative_daily(field, spatial_dims):
    """Like ``_basin_mean_cumulative``, then keep the maximum of the running total per 1-day bin."""
    return _basin_mean_cumulative(field, spatial_dims).resample(time='1D').max()


# Spatial dimensions to average over: PRISM/ORNL use lat/lon, WRF/PNNL use y/x.
_LATLON_DIMS = ['lat', 'lon']
_GRID_DIMS = ['y', 'x']

# PRISM and ORNL are daily products; WRF and PNNL are hourly, so their running
# totals are reduced to one value per day.
# NOTE(review): PNNL uses PREC_ACC_NC only — confirm that this field alone is the intended total.

# 1990
prism_90_prcp = _basin_mean_cumulative(prism_1990.ppt, _LATLON_DIMS)
prism_90_800m_prcp = _basin_mean_cumulative(prism_1990_800m.ppt, _LATLON_DIMS)
wrf_90_prcp = _basin_mean_cumulative_daily(wrf_1990.PRCP, _GRID_DIMS)
ornl_90_prcp = _basin_mean_cumulative(ornl_1990.prcp, _LATLON_DIMS)
pnnl_90_prcp = _basin_mean_cumulative_daily(pnnl_1990.PREC_ACC_NC, _GRID_DIMS)

# 1995
prism_95_prcp = _basin_mean_cumulative(prism_1995.ppt, _LATLON_DIMS)
prism_95_800m_prcp = _basin_mean_cumulative(prism_1995_800m.ppt, _LATLON_DIMS)
wrf_95_prcp = _basin_mean_cumulative_daily(wrf_1995.PRCP, _GRID_DIMS)
ornl_95_prcp = _basin_mean_cumulative(ornl_1995.prcp, _LATLON_DIMS)
pnnl_95_prcp = _basin_mean_cumulative_daily(pnnl_1995.PREC_ACC_NC, _GRID_DIMS)

# 2003
prism_03_prcp = _basin_mean_cumulative(prism_2003.ppt, _LATLON_DIMS)
prism_03_800m_prcp = _basin_mean_cumulative(prism_2003_800m.ppt, _LATLON_DIMS)
wrf_03_prcp = _basin_mean_cumulative_daily(wrf_2003.PRCP, _GRID_DIMS)
ornl_03_prcp = _basin_mean_cumulative(ornl_2003.prcp, _LATLON_DIMS)
pnnl_03_prcp = _basin_mean_cumulative_daily(pnnl_2003.PREC_ACC_NC, _GRID_DIMS)

# 2006
prism_06_prcp = _basin_mean_cumulative(prism_2006.ppt, _LATLON_DIMS)
prism_06_800m_prcp = _basin_mean_cumulative(prism_2006_800m.ppt, _LATLON_DIMS)
wrf_06_prcp = _basin_mean_cumulative_daily(wrf_2006.PRCP, _GRID_DIMS)
ornl_06_prcp = _basin_mean_cumulative(ornl_2006.prcp, _LATLON_DIMS)
pnnl_06_prcp = _basin_mean_cumulative_daily(pnnl_2006.PREC_ACC_NC, _GRID_DIMS)

# 2011
prism_11_prcp = _basin_mean_cumulative(prism_2011.ppt, _LATLON_DIMS)
prism_11_800m_prcp = _basin_mean_cumulative(prism_2011_800m.ppt, _LATLON_DIMS)
wrf_11_prcp = _basin_mean_cumulative_daily(wrf_2011.PRCP, _GRID_DIMS)
ornl_11_prcp = _basin_mean_cumulative(ornl_2011.prcp, _LATLON_DIMS)
pnnl_11_prcp = _basin_mean_cumulative_daily(pnnl_2011.PREC_ACC_NC, _GRID_DIMS)

In [None]:
# prcp_90 = ((prism_90_prcp).hvplot.line(label='PRISM 4KM') *\
# (prism_90_800m_prcp).hvplot.line(label='PRISM 800m') *\
# (wrf_90_prcp).hvplot.line(label='ERA5 WRF (UCLA) 9KM') *\
# (ornl_90_prcp).hvplot.line(label='ORNL 4KM') *\
# (pnnl_90_prcp).hvplot.line(label='PNNL 6KM')).opts(title='1990 Basin Cumulative Precipitation', width=600, height=600).opts(show_grid=True, bgcolor='lightgray', show_legend=True,
#                                                                                                                                            legend_position='bottom_right')

# prcp_95 = ((prism_95_prcp).hvplot.line(label='PRISM 4KM') *\
# (prism_95_800m_prcp).hvplot.line(label='PRISM 800m') *\
# (wrf_95_prcp).hvplot.line(label='ERA5 WRF (UCLA) 9KM') *\
# (ornl_95_prcp).hvplot.line(label='ORNL 4KM') *\
# (pnnl_95_prcp).hvplot.line(label='PNNL 6KM')).opts(title='1995 Basin Cumulative Precipitation', width=600, height=600).opts(show_grid=True, bgcolor='lightgray', show_legend=True,
#                                                                                                                                            legend_position='bottom_right')

# prcp_03 = ((prism_03_prcp).hvplot.line(label='PRISM 4KM') *\
# (prism_03_800m_prcp).hvplot.line(label='PRISM 800m') *\
# (wrf_03_prcp).hvplot.line(label='ERA5 WRF (UCLA) 9KM') *\
# (ornl_03_prcp).hvplot.line(label='ORNL 4KM') *\
# (pnnl_03_prcp).hvplot.line(label='PNNL 6KM')).opts(title='2003 Basin Cumulative Precipitation', width=600, height=600).opts(show_grid=True, bgcolor='lightgray', show_legend=True,
#                                                                                                                                            legend_position='bottom_right')

# prcp_06 = ((prism_06_prcp).hvplot.line(label='PRISM 4KM') *\
# (prism_06_800m_prcp).hvplot.line(label='PRISM 800m') *\
# (wrf_06_prcp).hvplot.line(label='ERA5 WRF (UCLA) 9KM') *\
# (ornl_06_prcp).hvplot.line(label='ORNL 4KM') *\
# (pnnl_06_prcp).hvplot.line(label='PNNL 6KM')).opts(title='2006 Basin Cumulative Precipitation', width=600, height=600).opts(show_grid=True, bgcolor='lightgray', show_legend=True,
                                                                                                                                           # legend_position='bottom_right')

# prcp_11 = ((prism_11_prcp).hvplot.line(label='PRISM 4KM') *\
# (prism_11_800m_prcp.cumsum()/10000).hvplot.line(label='PRISM 800m') *\
# (wrf_11_prcp).hvplot.line(label='ERA5 WRF (UCLA) 9KM') *\
# (ornl_11_prcp).hvplot.line(label='ORNL 4KM') *\
# (pnnl_11_prcp).hvplot.line(label='PNNL 6KM')).opts(title='2011 Basin Cumulative Precipitation', width=600, height=600).opts(show_grid=True, bgcolor='lightgray', show_legend=True,
                                                                                                                                           # legend_position='bottom_right')

# cum_precip_plots = (prcp_90 + prcp_95 + prcp_03 + prcp_06 + prcp_11).opts(shared_axes=False).cols(2).opts(legend_position='bottom_left').redim.label(time='Date', ppt='Volume of Precip (m^3)')

In [None]:
# cum_precip_plots

In [None]:
import matplotlib.pyplot as plt

# Legend labels, in the same order as the arrays listed for each panel below.
series_labels = ('PRISM 4KM', 'PRISM 800m', 'ERA5 WRF (UCLA) 9KM', 'ORNL 4KM')

# One (title, arrays) entry per panel. The PNNL 6KM series (pnnl_*_prcp) is
# deliberately not drawn; to include it, append the array to each tuple and add
# a matching label to series_labels.
panel_specs = [
    ('1990 Basin Cumulative Precipitation',
     (prism_90_prcp, prism_90_800m_prcp, wrf_90_prcp, ornl_90_prcp)),
    ('1995 Basin Cumulative Precipitation',
     (prism_95_prcp, prism_95_800m_prcp, wrf_95_prcp, ornl_95_prcp)),
    ('2003 Basin Cumulative Precipitation',
     (prism_03_prcp, prism_03_800m_prcp, wrf_03_prcp, ornl_03_prcp)),
    ('2006 Basin Cumulative Precipitation',
     (prism_06_prcp, prism_06_800m_prcp, wrf_06_prcp, ornl_06_prcp)),
    ('2011 Basin Cumulative Precipitation',
     (prism_11_prcp, prism_11_800m_prcp, wrf_11_prcp, ornl_11_prcp)),
]

# Same structure as before: a list of {'title': ..., 'series': [(array, label), ...]}.
years = [
    {'title': panel_title, 'series': list(zip(arrays, series_labels))}
    for panel_title, arrays in panel_specs
]

# 3 x 2 grid; only five panels are used, so the last axes is removed below.
fig, axs = plt.subplots(3, 2, figsize=(15, 18))
axs = axs.flatten()

for ax, year in zip(axs, years):
    for da, label in year['series']:
        da.plot(ax=ax, label=label)
    # Set after plotting so these override the labels xarray's plot writes.
    ax.set_title(year['title'])
    ax.set_xlabel('Date')
    ax.set_ylabel('Average Accumulated Precip (mm)')
    ax.legend()
    ax.grid(True)
    ax.set_facecolor('lightgray')

axs[-1].remove()
plt.tight_layout()
plt.show()

In [None]:
# Write the multi-panel figure built in the previous cell to a PNG (relative path, 300 dpi).
fig.savefig('cumulative_precip_all_years.png', dpi=300)