In [None]:
%load_ext autoreload
%autoreload 2

import zarr
import xarray as xr
import os

In [None]:
# Rechunked TerraClimate Zarr store (opened as ds_conus) living under the
# 'rechunker/terraclimate/target.zarr/' prefix of the 'carbonplan-scratch'
# Azure Blob container. The storage account key comes from the
# BLOB_ACCOUNT_KEY environment variable; a KeyError here means it is unset.
rechunked_store = zarr.storage.ABSStore(
    'carbonplan-scratch',
    prefix='rechunker/terraclimate/target.zarr/',
    account_name="carbonplan",
    account_key=os.environ["BLOB_ACCOUNT_KEY"],
)

# Consolidated metadata lets the whole hierarchy be read in one request.
ds_conus = xr.open_zarr(rechunked_store, consolidated=True)
ds_conus

In [None]:
from carbonplan.data import cat

# Raw (not rechunked) global TerraClimate dataset from the carbonplan data
# catalog, materialised through the catalog entry's to_dask().
raw_terraclimate_entry = cat.terraclimate.raw_terraclimate
ds_global = raw_terraclimate_entry.to_dask()

In [None]:
# Elevation raster (the CONUS 4000m product, per the URL). Values <= -1000 are
# treated as nodata and masked to NaN, size-1 dimensions (e.g. a single band)
# are squeezed out, and the result is attached to ds_conus as 'dem'.
# NOTE(review): xr.open_rasterio is deprecated in newer xarray releases
# (rioxarray.open_rasterio is the replacement).
DEM_URL = 'https://storage.googleapis.com/carbonplan-data/processed/edna/conus/4000m/raster.tif'
dem_raw = xr.open_rasterio(DEM_URL).load()
dem = dem_raw.where(dem_raw > -1000).squeeze(drop=True)
ds_conus['dem'] = dem
dem

In [None]:

# Available water capacity (AWC) raster (presumably 4 km, per the filename).
# The location can be overridden with the AWC_PATH environment variable so the
# notebook runs outside the original environment; the default is the original
# hard-coded local path.
awc_path = os.environ.get('AWC_PATH', '/home/jovyan/awc_4000m.tif')
awc = xr.open_rasterio(awc_path).load()
# 255 is treated as the nodata fill value and masked to NaN; size-1 dimensions
# (e.g. a single band) are squeezed out before attaching to ds_conus.
awc = awc.where(awc != 255).squeeze(drop=True)
ds_conus['awc'] = awc
awc

In [None]:
from metpy.calc import dewpoint
from metpy.units import units
import numpy as np
from cmip6_downscaling.disagg import terraclimate

In [None]:
# Display the CONUS dataset again, now that the 'dem' and 'awc' variables have been attached.
ds_conus

In [None]:
# Time series for one CONUS grid cell: the cell nearest to x=-2e6, y=3e6,
# flattened into a DataFrame (one row per time step).
point_ds = ds_conus.sel(x=-2e6, y=3e6, method='nearest').squeeze(drop=True)
df = point_ds.to_dataframe()

# Rescale available water capacity (divide by 100, then multiply by 1000).
# NOTE(review): presumably a unit conversion into what terraclimate.aetmod
# expects — confirm source and target units.
df['awc'] = df['awc'] / 100 * 1000

In [None]:
# Raw global TerraClimate series at the same location as df, matched by the
# lat/lon stored in the first row of df. Use .iloc for the positional lookup:
# Series[0] on a non-integer (time) index relies on a positional fallback that
# pandas has deprecated.
point_lat = df['lat'].iloc[0]
point_lon = df['lon'].iloc[0]
df_global = (
    ds_global
    .sel(lat=point_lat, lon=point_lon, method='nearest')
    .squeeze(drop=True)
    .to_dataframe()
)
df_global.head()

In [None]:
# Overwrite df in place with the non-NA values of df_global, aligned on index
# and column labels (join/overwrite spelled out; these are pandas' defaults).
# NOTE(review): labels that don't align (e.g. differing time stamps) are
# silently left unchanged — worth checking that the indexes actually match.
df.update(df_global, join='left', overwrite=True)


In [None]:
# Monthly mean temperature as the midpoint of tmax and tmin.
df['tmean'] = (df['tmax'] + df['tmin']) / 2

# Dewpoint from vapor pressure. The original code treated `vap` as kPa (it
# multiplied by 1000 and tagged the result as pascal), so tag it as kPa
# directly. Ask for the magnitude in degC explicitly rather than relying on
# np.asarray to silently strip whatever units dewpoint returns.
vap_kpa = df['vap'].to_numpy() * units.kPa
df['tdew'] = dewpoint(vap_kpa).m_as('degC')
df.head()

In [None]:
%time
WM2_TO_MGM2D = 86400 / 1e6

import pandas as pd
import matplotlib.pyplot as plt

df_v2 = pd.DataFrame(index=df.index, columns=['snowpack', 'h2o_input', 'albedo', 'et0', 'aet', 'soil', 'runoff'])
# df['awc'] *= 25.4 * 10

snowpack_prev = 0.
tmean_prev = df['tmean'][0]
soil_prev = 0 # df['awc'][0]

for i, row in df.iterrows():
    out = terraclimate.snowmod(
        row['tmean'],
        row['ppt'],
        radiation=row['srad'] * WM2_TO_MGM2D,
        snowpack_prev=snowpack_prev)
    
    out['et0'] = terraclimate.monthly_et0(
        row['srad'] * WM2_TO_MGM2D,
        row['tmax'],
        row['tmin'],
        row['ws'],
        row['tdew'],
        tmean_prev,
        row['lat'],
        row['dem'],
        i.month - 1,
    )

    out.update(terraclimate.aetmod(
        out['et0'],
        out['h2o_input'],
        row['awc'] ,
        soil_prev=soil_prev))
    df_v2.loc[i] = out
    
    tmean_prev = row['tmean']
    snowpack_prev = out['snowpack']
    soil_prev = out['soil']

In [None]:
# Runoff over the last 48 months. v1 runoff is the `q` column of df; v2 is the
# column of df_v2 named by `var`. Both names are explicit so changing one
# doesn't silently compare against the wrong v1 series.
var, var_v1 = 'runoff', 'q'
s = slice(-48, None)
fig, ax = plt.subplots()
df[var_v1].iloc[s].plot(ax=ax, label=f'{var_v1}-v1')
df_v2[var].iloc[s].plot(ax=ax, label=f'{var}-v2')
ax.set(title=f'{var}: v1 ({var_v1}) vs v2', ylabel=var)
ax.legend();

In [None]:
# Evapotranspiration over the last 48 months: v1 potential ET (pet) for
# context, plus actual ET from v1 (df['aet']) and from the v2 loop
# (df_v2['aet']). df_v2['et0'] can be added the same way if needed.
s = slice(-48, None)
fig, ax = plt.subplots()
df['pet'].iloc[s].plot(ax=ax, label='pet-v1')
df['aet'].iloc[s].plot(ax=ax, label='aet-v1')
df_v2['aet'].iloc[s].plot(ax=ax, label='aet-v2')
ax.set(title='Evapotranspiration: v1 vs v2', ylabel='ET')
ax.legend();

In [None]:
# Snow over the last 72 months: v1 `swe` column vs the v2 loop's snowpack.
s = slice(-72, None)
fig, ax = plt.subplots()
df['swe'].iloc[s].plot(ax=ax, label='swe-v1')
df_v2['snowpack'].iloc[s].plot(ax=ax, label='snowpack-v2')
ax.set(title='Snowpack: v1 (swe) vs v2', ylabel='snowpack')
ax.legend();

In [None]:
# Forcing check over the last 48 months: precipitation on the left axis and
# minimum temperature on a secondary (right) axis.
s = slice(-48, None)
fig, ax = plt.subplots()
df['ppt'].iloc[s].plot(ax=ax, label='ppt')
df['tmin'].iloc[s].plot(ax=ax, label='tmin', secondary_y=True)
# pandas attaches the twin axis created for secondary_y as ax.right_ax.
ax_right = ax.right_ax
ax.set(title='ppt and tmin', ylabel='ppt')
ax_right.set_ylabel('tmin')
# plt.legend() only lists lines of the current axes, so gather the lines of
# both axes to show ppt and tmin in a single legend.
lines = list(ax.get_lines()) + list(ax_right.get_lines())
ax_right.legend(lines, [line.get_label() for line in lines]);