## Retrieve NWM retrospective precipitation forcing data (Grid-to-point) from version 2.1

* This notebook retrieves precipitation (`RAINRATE`) from the NWM v2.1 retrospective forcing dataset stored on AWS (https://registry.opendata.aws/nwm-archive/).
* Each lat/lon location is reprojected onto the model grid, and its values are taken from the nearest grid cell.

In [None]:
import os
import pandas as pd
import xarray
import numpy
import pyproj
import s3fs
from datetime import datetime
import hvplot.xarray
import dask.array as da

# Start a Dask client for the lazy xarray/zarr work below. With no address,
# Dask creates a local cluster. Leaving `client` as the cell's last
# expression makes Jupyter display its summary (including the dashboard link).
from dask.distributed import Client

client = Client()
client

In [None]:
# Location of the precipitation Zarr store in the public NWM v2.1
# retrospective bucket
s3_path = "s3://noaa-nwm-retrospective-2-1-zarr-pds/precip.zarr"

In [None]:
# Connect to S3 with anonymous (unsigned) requests, then wrap the Zarr
# location in a key/value mapper that xarray can open.
# check=False: do not test the location when creating the mapper.
s3 = s3fs.S3FileSystem(anon=True)
store = s3fs.S3Map(s3=s3, root=s3_path, check=False)

In [None]:
%%time
# load the dataset
ds = xarray.open_zarr(store=store, consolidated=True)

#### Define lat/lon locations where data will be extracted

In [None]:
sitesPath = "./Input/"  # folder holding the station CSV (trailing slash required)

In [None]:
# Read the station table (one row per site). The ID and text columns are
# forced to pandas' string dtype so numeric-looking values stay text.
# ------------------------------------------------------------------------
sites_loc = pd.read_csv(
    sitesPath + 'selStn_precip.csv',
    dtype={'siteID': 'string', 'name': 'string', 'Source': 'string'},
)

# Plain Python lists of the coordinates and IDs, in file order.
lat, lon, siteIDs = (
    sites_loc[column].tolist() for column in ('latitude', 'longitude', 'siteID')
)

In [None]:
# define the input crs
# Lambert Conformal Conic projection used by the forcing grid, defined on a
# sphere of radius 6,370,000 m (a == b).
wrf_proj = pyproj.Proj(proj='lcc',
                       lat_1=30.,
                       lat_2=60.,
                       lat_0=40.0000076293945, lon_0=-97.,  # Center point
                       a=6370000, b=6370000)

# define a target coordinate system to convert locations into the projection of our forcing data
target_crs = wrf_proj

# Obs proj. (station coordinates are geographic WGS84)
wgs_proj = pyproj.Proj(proj='latlong', datum='WGS84')

# Transformer that reprojects station locations to AORC/NWM grid coordinates.
# always_xy=True fixes the axis order to (lon, lat) in and (x, y) out,
# independent of how either CRS declares its axes, which is the order in which
# the station coordinates are passed to transformer.transform(lon, lat).
transformer = pyproj.Transformer.from_crs(wgs_proj.crs, target_crs.crs, always_xy=True)

In [None]:
# Station IDs as a (1, n_sites) array.
sites = numpy.array([siteIDs])

# Reproject the station lon/lat to AORC/NWM grid coordinates.
xx0, yy0 = transformer.transform(lon, lat)

# Both indexers share the 'location' dimension (labelled by site ID), so a
# later .sel(x=xx, y=yy) picks one grid cell per site instead of every x/y
# combination.
xx = xarray.DataArray(xx0, coords={'location': siteIDs}, dims=['location'])
yy = xarray.DataArray(yy0, coords={'location': siteIDs}, dims=['location'])

* [OPTIONAL] Add lat/lon coordinates to the NWM dataset (used for plotting and additional reference). Note that this cell also sorts the dataset by time.

In [None]:
# The .rio accessor used below is registered by rioxarray when it is
# imported; it is not imported elsewhere in this notebook.
import rioxarray  # noqa: F401

# create a 2D grid of coordinate values
X, Y = numpy.meshgrid(ds.x.values, ds.y.values)

# Transform grid X, Y (LCC) back into WGS84 lon/lat.
# Dedicated names are used so this optional cell does not overwrite the
# station `lon`/`lat` lists or the forward `transformer`. Otherwise,
# re-running the station-reprojection cell after this one would silently
# use the grid arrays and the inverse transform.
# always_xy=True fixes the output order to (lon, lat).
grid_to_wgs = pyproj.Transformer.from_crs(wrf_proj.crs, wgs_proj.crs, always_xy=True)
grid_lon, grid_lat = grid_to_wgs.transform(X, Y)

# add geographical coordinate values (lon and lat) to the dataset
ds = ds.assign_coords(lon=(['y', 'x'], grid_lon), lat=(['y', 'x'], grid_lat))

# add crs to file (each call modifies ds in place)
ds.rio.write_crs(ds.crs.attrs['spatial_ref'], inplace=True)
ds.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=True)
ds.rio.write_coordinate_system(inplace=True)

# make sure the data is sorted by time
# NOTE: this sort only happens when this optional cell is run.
ds = ds.sortby('time')

#### Extract data over the full retrospective period by time chunks

In [None]:
savePath = './Output/'  # destination folder for the per-window CSVs (trailing slash required)

In [None]:
# Boundaries of the time windows the extraction loop processes:
# 11 evenly spaced timestamps -> 10 consecutive windows.
# Raise `periods` for smaller (less memory-hungry) windows.
start_date = datetime(1979, 2, 1, 0, 0, 0)
end_date = datetime(2021, 1, 1, 0, 0, 0)

date_list = pd.date_range(start=start_date, end=end_date, periods=11)
print(date_list)

In [None]:
%%time
# Loop to process the data in time chunks
max_lon = ds["x"].max()
max_lat = ds["y"].max()

for i in range(len(date_list)-1):
    print("Processing block of dates from",date_list[i], "to",date_list[i+1])
    timerange = slice(str(date_list[i]), str(date_list[i+1]))
    print(timerange)
    dat = ds.sel(time=timerange,x=slice(1e6,max_lon),y=slice(0,max_lat)).RAINRATE.persist()
    
    # Extract the values at the point locations
    values_temp = dat.sel(x=xx, y=yy, method='nearest').to_dataframe()
    values_temp.to_csv(f'{savePath}NWM_GTPprecipretro_{str(i)}.csv')
    print("Data saved...Done")
    
    # Delete unnecesary data to save memory
    del(dat,values_temp)