In [None]:
from dask.distributed import Client,LocalCluster
import xarray as xr
import dask
import glob
import geopandas as gpd
import rioxarray as rio
import shapely
import os
def configure_geo_data_dirs(conda_prefix=None):
    """Point GDAL and PROJ at the data directories inside a conda environment.

    Sets ``GDAL_DATA`` to ``<prefix>/Library/share/gdal`` and ``PROJ_LIB`` to
    ``<prefix>/Library/share``. The variables are only written when that
    ``Library/share/gdal`` directory actually exists, so an environment that
    does not use this layout keeps whatever values it already has.

    Parameters
    ----------
    conda_prefix : str, optional
        Root of the conda environment. Defaults to the ``CONDA_PREFIX``
        environment variable; if that is unset or empty nothing is changed
        (instead of raising ``KeyError``).
    """
    if conda_prefix is None:
        conda_prefix = os.environ.get('CONDA_PREFIX')
    if not conda_prefix:
        # not running inside an activated conda environment
        return

    gdal_data = conda_prefix + r'/Library/share/gdal'
    proj_lib = conda_prefix + r'/Library/share'
    if not os.path.isdir(gdal_data):
        # this environment has no Library/share/gdal folder; do not overwrite
        # the variables with paths that do not exist
        return

    os.environ['GDAL_DATA'] = gdal_data
    os.environ['PROJ_LIB'] = proj_lib


configure_geo_data_dirs()

In [None]:
# Input locations.
# The defaults are the paths that were previously hard-coded in this cell. To run on
# another machine, set the environment variables NCLIMGRID_DATA_DIR and
# NCLIMGRID_BBOX_SHP instead of editing the notebook. Alternative paths used before:
#   C://Users/kerrie/Documents/02_LocalData/nclimgrid_daily/
#   C://Users/kerrie/Documents/02_LocalData/boundaries/study_area_bounding_box.shp
data_dir = os.environ.get('NCLIMGRID_DATA_DIR', r'E://data/nclimgrid_daily/')
# other cells build file names with plain string concatenation (data_dir + '...'),
# so make sure the directory always ends with a path separator
if not data_dir.endswith(('/', '\\')):
    data_dir += '/'
shpfile = os.environ.get('NCLIMGRID_BBOX_SHP', r'E://data/boundaries/study_area_bounding_box.shp')

In [None]:
# Local dask cluster with single-threaded workers.
# `nworkers` is machine specific -- set it equal to the number of threads available:
#   kerrie laptop  (32GB RAM, 20 threads) -> nworkers = 20
#   kerrie desktop (64GB RAM, 16 threads) -> nworkers = 16
nworkers = 16

cluster = LocalCluster(threads_per_worker=1, n_workers=nworkers)
client = Client(cluster)
client

## Data Cleaning

1) drop any variables we don't need
2) dim names are time, lat, lon
3) time is correct datetimes
4) lat is ascending
5) lon is ascending -180 to 180
6) variable names tmax,tmin,prcp
7) units of C for temperature
8) units of mm/day for precip
9) spatial subset
10) round all values to 2 decimal places
11) save data with smallest possible precision
12) write cleaned files, 1 file per variable


In [None]:
# gather the original files matching the ncdd*-grd-scaled.nc pattern, in sorted name order
orig_pattern = data_dir + 'orig/ncdd*-grd-scaled.nc'
files = sorted(glob.glob(orig_pattern))
len(files), files[:3]

In [None]:
# Open all 800+ files as one lazy dataset (nothing is loaded into memory here).
# Chunk sizes are left to 'auto'; an explicit {'time':-1,'lat':-1,'lon':25}
# chunking was also tried at this point.
open_kwargs = dict(chunks='auto', lock=False)
ds = xr.open_mfdataset(files, **open_kwargs)

ds

In [None]:
# step 1: drop data we don't need
# errors='ignore' keeps this cell re-runnable: once 'tavg' has been dropped,
# running the cell again would otherwise raise because the variable is gone
ds = ds.drop_vars('tavg', errors='ignore')

Steps 2-7 aren't necessary; the data already looks good with respect to these items.

In [None]:
# step 8: relabel the precip units attribute, millimeter --> mm/day
# (only the metadata changes, the data values are not touched)
ds['prcp'].attrs.update(units='mm/day')

In [None]:
# step 9: spatial subset -- clip the data to a bounding box

# read the clip shape and convert its geometries to GeoJSON-like mappings
box = gpd.read_file(shpfile)
clip_geometries = box.geometry.apply(shapely.geometry.mapping)

# tag the netcdf data with a CRS (written in place on ds) before clipping
ds.rio.write_crs("epsg:4326", inplace=True)
ds_clip = ds.rio.clip(clip_geometries, box.crs, drop=True, invert=False)
ds_clip

In [None]:
# keep the coordinates, dimensions and per-variable attributes of the clipped data for later
coords = ds_clip.coords
prcp_attrs, tmax_attrs, tmin_attrs = (
    ds_clip[name].attrs for name in ('prcp', 'tmax', 'tmin')
)
dims = ds_clip.dims
print(dims)
coords

step 10 & 11, round and reduce precision

In [None]:
# pick one day for the precision test; we want a day with values on the order of at least 100
testtime = '2000-06-4'
prcp_on_testtime = ds_clip['prcp'].sel(time=testtime)
prcp_on_testtime.plot()

In [None]:
# Compare the maximum precip value for the test day at three precisions;
# float16 is probably too coarse, so check it on this small subset first.
prcp_test = ds_clip.prcp.sel(time=testtime)

print('loading float16')
prcp_16 = prcp_test.astype('float16').load()
print('loading float32')
prcp_32 = prcp_test.astype('float32').load()
print('loading float64')
prcp_64 = prcp_test.load()

tuple(arr.max().item() for arr in (prcp_16, prcp_32, prcp_64))

so we can change the data type from float64 to float32 but not go any smaller

In [None]:
# steps 10 & 11: round to 2 decimal places, then store as float32
# Both operations are chained on the same object. Previously the float32 cast was
# applied to ds_clip again, which threw away the rounded result.
ds = ds_clip.round(decimals=2).astype('float32')
ds

# write files

I think what we have to do to write each chunk to a separate file is:
- chunk xr arrays in space instead of time
- convert to numpy so we can use to_delayed and ravel
- every worker needs the xr metadata for variable and coordinates
- write a dask delayed function that takes the numpy array data and the metadata separately
- inside the dask delayed function re-create the xarray object and write to file

In [None]:
# re-chunk along longitude only: one chunk over all of time and lat, 12 longitudes per chunk
lon_only_chunks = {'time': -1, 'lat': -1, 'lon': 12}
ds = ds.chunk(lon_only_chunks)
ds

In [None]:
# a function to return a list of data chunks
# and the corresponding longitude coord chunks
def xr_ds_to_delayed(ds,varname):
    """Split one variable of ``ds`` into delayed chunks plus matching lon chunks.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose variable ``varname`` is dask-backed, has a ``lon``
        dimension and is chunked along ``lon`` only.
    varname : str
        Name of the variable to split.

    Returns
    -------
    (var_chunks, lon_chunks)
        Two equally long flat arrays of dask delayed objects: the data chunks
        of ``ds[varname]`` and the longitude coordinate values belonging to
        each of those chunks.

    Raises
    ------
    ValueError
        If the variable is not chunked, or is split along a dimension other
        than ``lon`` (the flattened data chunks would then no longer line up
        one-to-one with the longitude chunks).
    """
    data_var = ds[varname]
    chunks = data_var.chunks
    if chunks is None:
        raise ValueError("variable '%s' is not chunked; call ds.chunk(...) first" % varname)

    lon_axis = data_var.dims.index('lon')
    for axis, sizes in enumerate(chunks):
        if axis != lon_axis and len(sizes) != 1:
            raise ValueError(
                "variable '%s' must be chunked along 'lon' only, but dimension '%s' has %d chunks"
                % (varname, data_var.dims[axis], len(sizes))
            )

    # chunk the appropriate variable in ds, delay, ravel to list
    var_chunks = data_var.data.to_delayed().ravel()

    # xarray doesn't allow chunking of coordinates, so we have to make a new variable to chunk.
    # Reuse the variable's own lon chunk sizes so both lists always line up,
    # whatever chunk size the dataset was given.
    lon_values = xr.DataArray(ds.lon.data,coords={'lon':('lon',ds.lon.data)})
    lon_chunks = lon_values.chunk({'lon':chunks[lon_axis]}).data.to_delayed().ravel()
    return var_chunks,lon_chunks

# rebuild an xarray Dataset from one numpy chunk plus its metadata, then write it to its own file
def write_chunk_to_netcdf(datapath,chunk_id,varname,np_datachunk,xr_time,xr_lat,np_lonchunk,lon_meta,var_meta):
    """Write a single (time, lat, lon) data chunk to a netCDF file.

    Parameters
    ----------
    datapath : str
        Prefix (directory incl. trailing separator) the file name is appended to.
    chunk_id : int
        Chunk number; zero-padded to two digits in the file name.
    varname : str
        Name of the data variable inside the file; also starts the file name.
    np_datachunk : array
        Data values with dimensions (time, lat, lon).
    xr_time, xr_lat : xarray objects
        Supply the time / lat coordinate values (``.data``) and attributes (``.attrs``).
    np_lonchunk : array
        Longitude values belonging to this chunk.
    lon_meta, var_meta : mapping
        Attributes for the lon coordinate and for the data variable.

    Returns
    -------
    str
        The zero-padded chunk id used in the file name.
    """
    chunk_coords = {
        'time': ('time', xr_time.data),
        'lat': ('lat', xr_lat.data),
        'lon': ('lon', np_lonchunk),
    }
    xr_datachunk = xr.Dataset({varname: (['time', 'lat', 'lon'], np_datachunk)}, coords=chunk_coords)

    # copy the metadata over, leaving out the entries we do not want in the output;
    # 'comment' is only removed from the coordinates, the data variable keeps it
    attrs_by_name = [
        ('time', xr_time.attrs),
        ('lat', xr_lat.attrs),
        ('lon', lon_meta),
        (varname, var_meta),
    ]
    for name, source_attrs in attrs_by_name:
        unwanted = {'valid_min', 'valid_max', 'id'}
        if name != varname:
            unwanted.add('comment')
        xr_datachunk[name].attrs = {
            key: value for key, value in source_attrs.items() if key not in unwanted
        }

    # write file
    chunk_label = str(chunk_id).zfill(2)
    xr_datachunk.to_netcdf(datapath+varname+'_nClimGridDaily_USsouth_'+chunk_label+'.nc')
    return chunk_label

In [None]:
%%time 
# one delayed write task per longitude chunk of prcp, then run them all on the cluster
var_chunks, lon_chunks = xr_ds_to_delayed(ds, 'prcp')
delayed_write = dask.delayed(write_chunk_to_netcdf)
task_list = [
    delayed_write(data_dir, chunk_id, 'prcp',
                  datachunk, ds.time, ds.lat,
                  lonchunk, ds.lon.attrs, ds.prcp.attrs)
    for chunk_id, (datachunk, lonchunk) in enumerate(zip(var_chunks, lon_chunks))
]
completed_files = dask.compute(*task_list)
completed_files

In [None]:
%%time 
# one delayed write task per longitude chunk of tmax, then run them all on the cluster
var='tmax'
var_chunks,lon_chunks = xr_ds_to_delayed(ds,var)
# pass the attributes of the variable being written (ds[var]); the prcp attributes
# were passed here before, which labelled the tmax files with precip metadata
task_list= [dask.delayed(write_chunk_to_netcdf)(data_dir,chunk_id,var,
                                                datachunk,ds.time,ds.lat,
                                                lonchunk,ds.lon.attrs,ds[var].attrs)
            for chunk_id,(datachunk,lonchunk) in enumerate(zip(var_chunks,lon_chunks))]
completed_files = dask.compute(*task_list)
completed_files

In [None]:
%%time 
# one delayed write task per longitude chunk of tmin, then run them all on the cluster
# (this cell repeated var='tmax', so tmin chunk files were never written)
var='tmin'
var_chunks,lon_chunks = xr_ds_to_delayed(ds,var)
# pass the attributes of the variable being written (ds[var]) rather than the prcp ones
task_list= [dask.delayed(write_chunk_to_netcdf)(data_dir,chunk_id,var,
                                                datachunk,ds.time,ds.lat,
                                                lonchunk,ds.lon.attrs,ds[var].attrs)
            for chunk_id,(datachunk,lonchunk) in enumerate(zip(var_chunks,lon_chunks))]
completed_files = dask.compute(*task_list)

In [None]:
# re-open all per-chunk files of one variable as a single dataset to check the result
var = 'prcp'
chunk_file_pattern = data_dir + var + '_nClimGridDaily_USsouth_*.nc'
files = glob.glob(chunk_file_pattern)
test = xr.open_mfdataset(files)
test

In [None]:
# plot a single time step (index 15) of the re-opened data as a visual sanity check
test[var].isel(time=15).plot()

# old code below to write 1 single file per variable

In [None]:
%%time
# write the whole prcp variable to one netCDF file
filename = 'prcp_nClimGridDaily_1951-2024_USsouth.nc'
print('writing', filename)
prcp_path = data_dir + filename
ds['prcp'].to_netcdf(prcp_path)

In [None]:
%%time
# write the whole tmax variable to one netCDF file
filename = 'tmax_nClimGridDaily_1951-2024_USsouth.nc'
print('writing', filename)
tmax_path = data_dir + filename
ds['tmax'].to_netcdf(tmax_path)

In [None]:
%%time
# write the whole tmin variable to one netCDF file
filename = 'tmin_nClimGridDaily_1951-2024_USsouth.nc'
print('writing', filename)
tmin_path = data_dir + filename
ds['tmin'].to_netcdf(tmin_path)

In [None]:
# re-open the file named by the current value of `filename` to inspect what was written
single_file_path = data_dir + filename
test = xr.open_mfdataset(single_file_path)
test

In [None]:
# shut down the dask client created near the top of the notebook
client.shutdown()