In [3]:
import os 
# Point the PROJ_LIB environment variable at this pixi environment's PROJ data directory.
# NOTE(review): presumably this must run before the geospatial imports
# (rioxarray / geopandas) so PROJ can locate its data files — confirm.
# NOTE(review): this is a user-specific absolute path; the notebook will not
# be portable to another machine or account without editing it.
os.environ['PROJ_LIB'] = "/home/jesseake/skagit-met/.pixi/envs/analysis/share/proj"
import xarray as xr

In [2]:
import dask
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import holoviews as hv
import datetime as dt
import hvplot.xarray
import rioxarray as rxr
import geopandas as gpd
from shapely import vectorized
hv.extension('bokeh', 'matplotlib')
# hv.extension('matplotlib')

In [None]:
# NOTE(review): an earlier cell runs `import datetime as dt` (the module).
# This line rebinds `dt` to the `datetime` class, so the meaning of `dt`
# depends on which cell ran last. openDataset relies on the class binding.
from datetime import datetime as dt
# Event windows as (start_date, end_date) ISO date strings, one per event
# (presumably atmospheric-river events, judging by the `ar_` prefix — confirm).
ar_1990 = ('1990-11-03', '1990-11-17')
ar_1995 = ('1995-11-22', '1995-12-06')
ar_2003 = ('2003-10-14', '2003-10-28')
ar_2006 = ('2006-10-31', '2006-11-13')
ar_2011 = ('2011-01-01', '2011-02-01')
ar_2021 = ('2021-11-10', '2021-11-17')
# NOTE(review): only the first four windows are collected here; ar_2011 and
# ar_2021 are defined above but not included — confirm this is intentional.
ars = [ar_1990, ar_1995, ar_2003, ar_2006]
# Load the DEM raster and the basin boundary geometry, then clip the DEM
# to the boundary polygon(s).
dem = xr.open_dataset('../Data/GIS/SkagitRiver_90mDEM.tif')
poly = gpd.read_file('../Data/GIS/SkagitBoundary.json').geometry
skagit_dem = dem.rio.clip(poly)

# Remember: WRF and PNNL are hourly
# SNOTEL and WRF already have hgt/HGT data
def openDataset(dates: tuple[str, str], product: str, frequency: str = '', resolution: str = '') -> xr.Dataset:
    """Open one meteorological product and add derived convenience variables.

    Product matching is case-insensitive for the special-cased products
    ('pnnl', 'ornl', 'wrf_era5', 'snotel'). For any product that falls through
    to the generic branch, ``product`` is inserted verbatim into the store
    path, so its case must match the file name on disk.

    Parameters
    ----------
    dates
        ``(start, end)`` date strings. Used to slice ``time`` for PNNL and
        ORNL, to derive the start/end years in the ORNL store name, and
        inserted verbatim into the store name for all other products.
    product
        Product name, e.g. ``'pnnl'``, ``'ORNL'``, ``'wrf_era5'``, ``'SNOTEL'``.
    frequency, resolution
        Optional name fragments; when non-empty each is appended to the
        generic store name as ``_<value>``.

    Returns
    -------
    xr.Dataset
        The opened dataset plus derived variables:
        ``T2C`` (wrf_era5, pnnl), ``T2C`` / ``AVG_T2C`` / ``elevation_m``
        (SNOTEL), ``PRCP`` with ``HGT``/``hgt`` as a coordinate (wrf_era5),
        ``tmean`` (ORNL).

    Raises
    ------
    ValueError
        From ``datetime.fromisoformat`` if an ORNL date is not ISO formatted.
    """
    # Imported locally because the notebook binds the global name `dt` to both
    # the datetime module and the datetime class in different cells.
    from datetime import datetime

    product_key = product.lower()

    if product_key == 'pnnl':
        grid_path = '/data0/skagit_met/PNNL/historical/SERDP6km.geo_em.d01.nc'
        ds_grid = xr.open_dataset(grid_path).squeeze()  # Drop Time=0 scalar dimension
        # Keep the lon/lat grids (and the land mask for visualisation).
        dsg = ds_grid[['LANDMASK', 'CLONG', 'CLAT']]
        # Promote lon/lat to coordinates and rename dims to x/y to match the data files.
        dsg = dsg.set_coords(("CLONG", "CLAT")).rename(dict(south_north='y', west_east='x'))
        virtual_pnnl = '/data0/skagit_met/PNNL/historical/PNNL_historical.parquet'
        pnnl = xr.open_dataset(virtual_pnnl, engine='kerchunk', mask_and_scale=False)
        # assign_coords returns a new Dataset; the result must be kept or
        # CLONG/CLAT are never attached and the mask lookup below fails.
        pnnl = pnnl.assign_coords(dsg.coords)
        # Also bring in land mask
        pnnl.coords['LANDMASK'] = dsg.LANDMASK
        # Keep only grid cells whose (lon, lat) falls inside the first boundary geometry.
        mask = vectorized.contains(poly.geometry[0], pnnl.CLONG.values, pnnl.CLAT.values)
        ds = pnnl.where(mask)
        ds = ds.sel(time=slice(dates[0], dates[1]))
    elif product_key == 'ornl':
        # The ORNL store is named by the start and end *years* of the window.
        start = datetime.fromisoformat(dates[0]).year
        end = datetime.fromisoformat(dates[1]).year
        ds = xr.open_zarr(f'/data0/skagit_met/atmospheric_rivers/{start}_{end}_ORNL_data.zarr')
        ds = ds.sel(time=slice(dates[0], dates[1]))
    else:
        freq_part = "_" + frequency if frequency else ""
        res_part = "_" + resolution if resolution else ""
        ds = xr.open_zarr(
            f'/data0/skagit_met/atmospheric_rivers/{dates[0]}_{dates[1]}{freq_part}{res_part}_{product}_data.zarr'
        )

    if product_key in ('wrf_era5', 'pnnl'):
        # Subtract the kelvin->Celsius offset (assumes T2 is in kelvin — TODO confirm).
        ds['T2C'] = ds['T2'] - 273.15

    if product_key == 'snotel':
        # Fahrenheit->Celsius formula (assumes source temperatures are in °F — TODO confirm).
        if 'AIR TEMP' in ds:
            ds['T2C'] = (ds['AIR TEMP'] - 32) * (5/9)

        if 'AVG AIR TEMP' in ds:
            ds['AVG_T2C'] = (ds['AVG AIR TEMP'] - 32) * (5/9)

        # Feet -> metres.
        ds['elevation_m'] = ds['elevation_ft'] * 0.3048

    if product_key == 'wrf_era5':
        ds['PRCP'] = (ds['RAINC'] + ds['RAINNC'])
        # The height variable may be named 'HGT' or 'hgt'; pick whichever is
        # present instead of catching a broad exception.
        hgt_name = 'HGT' if 'HGT' in ds.variables else 'hgt'
        ds = ds.set_coords(hgt_name)

    if product_key == 'ornl':
        ds['tmean'] = (ds.tmax + ds.tmin) / 2

    return ds

## For PNNL, use CLONG/CLAT for x/y and 'x', 'y' for index_x, index_y
def embedDEM(ds: xr.Dataset, dem: xr.Dataset, x: str, y: str, index_x: str = None, index_y: str = None, method: str  = 'nearest') -> xr.Dataset:
    """Sample a DEM onto a dataset's grid and attach it as an ``elevation`` coordinate.

    The input ``ds`` is not modified; a new dataset is returned.

    Parameters
    ----------
    ds
        Dataset to annotate. ``ds[x]`` and ``ds[y]`` supply the target
        positions passed to ``dem.interp``.
    dem
        DEM dataset with ``x``/``y`` dimensions, a ``band`` dimension and a
        ``band_data`` variable.
    x, y
        Names of the variables/coordinates in ``ds`` holding the horizontal
        positions (assumed to be in the same CRS as the DEM's x/y — TODO confirm).
    index_x, index_y
        Names of the dimensions the elevation array should be attached to.
        Default to ``x`` / ``y``; set them when the position variables are
        2-D fields whose dimensions have different names.
    method
        Interpolation method forwarded to ``Dataset.interp``.

    Returns
    -------
    xr.Dataset
        A new dataset carrying ``elevation`` with dims ``(index_y or y, index_x or x)``.
    """
    dem_resampled = dem.interp(x=ds[x], y=ds[y], method=method)
    # Collapse the raster band dimension by averaging across bands.
    avg_elevation = dem_resampled.mean(dim='band')
    dims = (index_y or y, index_x or x)
    # assign_coords builds a new Dataset, so the caller's `ds` is left
    # untouched (item assignment on `ds` would mutate it in place).
    return ds.assign_coords(elevation=(dims, avg_elevation.band_data.values))

In [79]:
# Product notes (presumably nominal grid spacing — confirm):
# PNNL - 6 KM
# PRISM - 4 KM
# UCLA wrf_era5 - 9KM 
# ORNL - 4KM
ornl_file = '../data/weather_data/1981_2011_ORNL_data.zarr'
# NOTE(review): openDataset's first parameter is `dates`, and its ORNL branch
# calls dt.fromisoformat(dates[0]) and builds its own store path under
# /data0/skagit_met/atmospheric_rivers/. Passing a file path here makes
# dates[0] the single character '.', so this call does not match the function
# as currently defined — confirm which version of openDataset this cell was
# run against, or pass a (start, end) date tuple instead.
# Load ORNL and attach DEM elevation sampled at the dataset's lon/lat positions.
ornl = embedDEM(openDataset(ornl_file, 'ORNL'), skagit_dem, 'lon', 'lat')

In [80]:
# Group by Water Year
# ornl['water_year'] = ornl.time(lambda x: x.year if x.month < 10 else x.year + 1)
# z = ornl['time'].year if ornl['time'].month < 10 else ornl['time'].year + 1
# ornl.map_blocks(lambda x: x.time.year if x.time.month < 10 else x.time.year + 1)
# Define the function to calculate water year
def calculate_water_year(time):
    """Return the water year for one timestamp or a collection of timestamps.

    A water year is labelled by the calendar year in which it ends: dates in
    October through December belong to the following calendar year's water
    year, all other dates to the current calendar year.
    """
    stamps = pd.to_datetime(time)  # accept strings, datetime64 values, arrays, ...
    rolls_over = stamps.month >= 10  # True from October onward
    return stamps.year + rolls_over

# Define a wrapper function for map_blocks
def add_water_year(da):
    """Attach a ``water_year`` coordinate, defined along ``time``, to ``da``.

    Intended as the callable passed to ``map_blocks``: the water year is
    computed from the object's ``time`` coordinate via
    ``calculate_water_year`` and a new object with the extra coordinate is
    returned.
    """
    # Run the water-year calculation over the time coordinate.
    wy = xr.apply_ufunc(calculate_water_year, da["time"])
    # Register the result as a non-index coordinate on the time dimension.
    return da.assign_coords(water_year=("time", wy.data))

# Apply add_water_year block-by-block so every chunk of `ornl` gains the
# `water_year` coordinate used by the groupby plots below.
# NOTE(review): this rebinds `ornl`, so the cell is not a pure function of the
# load cell's output; re-running it operates on an already-annotated dataset.
ornl = ornl.map_blocks(add_water_year)

In [99]:
# Bar chart of precipitation summed over the whole grid, totalled per water year.
# Dividing by 1000 matches the "Precip (m)" label if prcp is in mm — TODO confirm.
# NOTE(review): this is a sum over all lat/lon cells, not a basin average.
(
    ornl.prcp
    .sum(dim=['lat', 'lon'])
    .groupby('water_year')
    .sum()
    / 1000
).hvplot.bar(
    x='water_year',
    xlabel='Water Year',
    ylabel='Precip (m)',
    xticks=ornl.water_year.values,
    y='prcp',
)

In [123]:
# Line plot of the peak cumulative grid-summed precipitation within each water year:
# accumulate within each water year, scale by 1/1000, then take the per-year maximum.
(
    (
        ornl.prcp
        .sum(dim=['lat', 'lon'])
        .groupby('water_year')
        .cumsum()
        / 1000
    )
    .groupby('water_year')
    .max()
    .hvplot.line()
)

In [76]:
# t = xr.open_dataset('../data/weather_data/test/DaymetV4_VIC4_rhum_2003.nc')

In [78]:
# t