In [8]:
import calendar
import logging
import os
import intake
import concurrent.futures
import numpy as np
import pandas as pd
import xarray as xr
import dask
import zarr
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster
import pyproj
import cartopy.crs as ccrs

In [2]:
r = '/caldera/hovenweep/projects/usgs/water'
d = os.path.join(r, 'wymtwsc', 'dketchum')
c404 = os.path.join(d, 'conus404')
dads = os.path.join(d, 'dads')
ghcn = os.path.join(d, 'climate', 'ghcn')

zarr_store = os.path.join(r, 'impd/hytest/conus404/conus404_hourly.zarr')
sites = os.path.join(dads, 'met', 'stations', 'madis_29OCT2024.csv')
csv_files = os.path.join(c404, 'station_data')

In [3]:
def get_quadrants(b):
    mid_longitude = (b[0] + b[2]) / 2
    mid_latitude = (b[1] + b[3]) / 2
    quadrant_nw = (b[0], mid_latitude, mid_longitude, b[3])
    quadrant_ne = (mid_longitude, mid_latitude, b[2], b[3])
    quadrant_sw = (b[0], b[1], mid_longitude, mid_latitude)
    quadrant_se = (mid_longitude, b[1], b[2], mid_latitude)
    quadrants = [quadrant_nw, quadrant_ne, quadrant_sw, quadrant_se]
    return quadrants


In [4]:
bounds = (-125.0, 25.0, -67.0, 53.0)
quadrants = get_quadrants(bounds)
sixteens = [get_quadrants(q) for q in quadrants]
sixteens = [x for xs in sixteens for x in xs]
sixteens[0]


(-125.0, 46.0, -110.5, 53.0)

In [5]:
stations = sites
nc_data = zarr_store
out_data = csv_files
workers=36
overwrite=False
bounds=sixteens[0]
start_yr=2014
end_yr=2014
mode = 'debug'

In [6]:
station_list = pd.read_csv(stations)
if 'LAT' in station_list.columns:
    station_list = station_list.rename(columns={'STAID': 'fid', 'LAT': 'latitude', 'LON': 'longitude'})
station_list.index = station_list['fid']

if bounds:
    w, s, e, n = bounds
    station_list = station_list[(station_list['latitude'] < n) & (station_list['latitude'] >= s)]
    station_list = station_list[(station_list['longitude'] < e) & (station_list['longitude'] >= w)]
else:
    ln = station_list.shape[0]
    w, s, e, n = (-125.0, 25.0, -67.0, 53.0)
    station_list = station_list[(station_list['latitude'] < n) & (station_list['latitude'] >= s)]
    station_list = station_list[(station_list['longitude'] < e) & (station_list['longitude'] >= w)]
    print('dropped {} stations outside NLDAS-2 extent'.format(ln - station_list.shape[0]))

print(f'{len(station_list)} stations to write')
print(f'sample stations for the selected region:\n {station_list.sample(n=5)}')

dates = [(year, month, calendar.monthrange(year, month)[-1])
         for year in range(start_yr, end_yr + 1) for month in range(1, 13)]

station_list

3092 stations to write
sample stations for the selected region:
          fid   latitude   longitude         elev      stype
fid                                                        
BF279  BF279  51.101871 -124.950180  1480.699951   utmesnet
AV924  AV924  46.042000 -118.341003   275.000000  APRSWXNET
PEFW1  PEFW1  47.152222 -120.946671  1225.300049       RAWS
D9605  D9605  48.478168 -111.352173   953.460022  APRSWXNET
BF163  BF163  49.906502 -116.855103  1018.900024   utmesnet


Unnamed: 0_level_0,fid,latitude,longitude,elev,stype
fid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
BNBMD,BNBMD,47.644989,-122.529404,41.174999,AWS
BRMWA,BRMWA,47.641659,-122.607201,44.224998,AWS
CLLMC,CLLMC,47.204441,-120.974098,650.565002,AWS
EVRMT,EVRMT,47.919991,-122.221901,6.405000,AWS
GCOUL,GCOUL,47.938332,-119.004700,509.350006,AWS
...,...,...,...,...,...
2228P,2228P,47.646210,-122.696274,6.100000,utmesnet
D2826,D2826,48.076500,-123.448334,280.380005,APRSWXNET
AP250,AP250,46.638000,-111.943001,1138.000000,APRSWXNET
2306P,2306P,47.660450,-117.424278,568.099976,utmesnet


In [9]:
def projected_coords(row, _bounds=None):
    globe = ccrs.Globe(ellipse='sphere', semimajor_axis=6370000, semiminor_axis=6370000)
    lcc = ccrs.LambertConformal(globe=globe,
                                central_longitude=-97.9000015258789,
                                central_latitude=39.100006103515625,
                                standard_parallels=[30.0, 50.0])
    lcc_wkt = lcc.to_wkt()
    source_crs = 'epsg:4326'
    transformer = pyproj.Transformer.from_crs(source_crs, lcc_wkt)
    if _bounds is not None:
        west, south, east, north = _bounds
        sw_x, sw_y = transformer.transform(south, west)
        ne_x, ne_y = transformer.transform(north, east)
        return sw_x, sw_y, ne_x, ne_y
    else:
        x, y = transformer.transform(row['longitude'], row['latitude'])
        return x, y

bounds_proj = projected_coords(None, _bounds=bounds)
bounds_proj

(-2039715.0340867909,
 1070649.589525464,
 -849923.4175176456,
 1594627.8324152853)

In [10]:
output_mode = 'ba'
hytest_cat = intake.open_catalog(
    "https://raw.githubusercontent.com/hytest-org/hytest/main/dataset_catalog/hytest_intake_catalog.yml")
cat = hytest_cat['conus404-catalog']
if output_mode == 'uncorrected':
    # model output, uncorrected
    dataset = 'conus404-hourly-onprem-hw'
elif output_mode == 'ba':
    # bias-adjusted for precip and temp
    dataset = 'conus404-hourly-ba-onprem-hw'
else:
    raise ValueError('output_mode not recognized')
ds = cat[dataset].to_dask()

  'dims': dict(self._ds.dims),


In [11]:

def get_month_met(station_list_, date_, out_data, overwrite, bounds_=None, output_mode='uncorrected'):
    """"""
    import xoak
    year, month, month_end = date_

    # dataset 1979 to 2022-10-01
    if year == 2022 and month > 9:
        return
    date_string = '{}-{}'.format(year, str(month).rjust(2, '0'))

    fids = station_list_.index.to_list()

    hytest_cat = intake.open_catalog(
        "https://raw.githubusercontent.com/hytest-org/hytest/main/dataset_catalog/hytest_intake_catalog.yml")
    cat = hytest_cat['conus404-catalog']
    if output_mode == 'uncorrected':
        # model output, uncorrected
        dataset = 'conus404-hourly-onprem-hw'
        variables = ['T2', 'TD2', 'QVAPOR', 'U10', 'V10', 'PSFC', 'ACSWDNLSM']
        print('using uncorrected data')
    elif output_mode == 'ba':
        variables = ['RAINRATE', 'T2D']
        # bias-adjusted for precip and temp
        dataset = 'conus404-hourly-ba-onprem-hw'
        print('using bias-adjusted data')
    else:
        raise ValueError('output_mode not recognized')

    ds = cat[dataset].to_dask()
    # extract crs meta before continuing to modify ds
    bounds_proj = projected_coords(row=None, _bounds=bounds)
    ds = ds.sel(time=slice(f'{year}-{month}-01', f'{year}-{month}-{month_end}'))
    ds = ds[variables]
    if bounds_ is not None:
        ds = ds.sel(y=slice(bounds_proj[1], bounds_proj[3]),
                    x=slice(bounds_proj[0], bounds_proj[2]))
    if output_mode == 'uncorrected':
        station_list_ = station_list_.to_xarray()
        ds.xoak.set_index(['lat', 'lon'], 'sklearn_geo_balltree')
        ds = ds.xoak.sel(lat=station_list_.latitude, lon=station_list_.longitude, tolerance=4000)
    else:
        station_list_[['x', 'y']] = station_list_.apply(projected_coords, axis=1, result_type='expand')
        station_list_ = station_list_.to_xarray()
        ds = ds.sel(y=station_list_.y, x=station_list_.x, method='nearest', tolerance=4000)
    
    print(ds)
    ds = xr.merge([station_list_, ds])
    all_df = ds.to_dataframe()

    try:
        ct = 0
        for fid in fids:
            dst_dir = os.path.join(out_data, fid)
            if not os.path.exists(dst_dir):
                os.mkdir(dst_dir)
            _file = os.path.join(dst_dir, f'{fid}_{date_string}.parquet')
            if not os.path.exists(_file) or overwrite:
                df_station = all_df.loc[slice(fid), slice(None)].copy()
                df_station = df_station.groupby(df_station.index.get_level_values('time')).first()
                df_station['dt'] = [i.strftime('%Y%m%d%H') for i in df_station.index]
                df_station.to_parquet(_file, index=False)
                ct += 1
        if ct % 1000 == 0.:
            print(f'{ct} of {len(fids)} for {date_string}')
    except Exception as exc:
        print(f'{date_string}: {exc}')

    del ds

In [12]:
output_target = 'ba'
if mode == 'debug':
    for date in dates:
        get_month_met(station_list, date, out_data, overwrite, bounds, output_target)

elif mode == 'multi':
    with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
        futures = [
            executor.submit(get_month_met, station_list, dt, out_data, overwrite, bounds, output_target)
            for dt in dates]
        concurrent.futures.wait(futures)

elif mode == 'dask':
    cluster = LocalCluster(n_workers=workers, memory_limit='32GB', threads_per_worker=1,
                           silence_logs=logging.ERROR)
    client = Client(cluster)
    print("Dask cluster started with dashboard at:", client.dashboard_link)
    station_list = client.scatter(station_list)
    tasks = [dask.delayed(get_month_met)(station_list, date, out_data, overwrite, bounds, output_target)
             for date in
             dates]
    dask.compute(*tasks)
    client.close()

using bias-adjusted data


  'dims': dict(self._ds.dims),


KeyError: "not all values found in index 'y'"