### Workflow overview

For each year in the range of years specified, do the following:

- read global grid (ALOS mask)
- put monthly pxv data onto global grid
- put daily deviation pxv data onto global grid
- create daily data 
    - calculate daily precip, srad, tmax, tmin, vapr, 10m wind 
    - interpolate to 2m wind and change units for wind and srad
    - write as npy stacks (precip, srad, tmax, tmin, vapr, 2m wind) 
- compute daily relative humidity and write as npy stack
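The "create daily data" step above can be sketched on synthetic data. This is a minimal sketch that assumes the same labelling as the main loop: monthly arrays carry a `month` coordinate (1-12) and daily arrays carry a `time` coordinate. Precip multiplies each day's fraction of the monthly total by the monthly accumulation, and the other variables add the daily deviation to the monthly mean. The grid size and values are invented for illustration.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Daily time axis for a non-leap year and month labels 1..12
time = pd.date_range("2021-01-01", "2021-12-31", freq="D")
months = np.arange(1, 13)
ny, nx = 2, 3  # tiny illustrative grid

# Non-precip variables: daily = monthly mean + daily deviation
# (monthly mean for month m is simply m here, so results are easy to check)
monthly_mean = xr.DataArray(
    np.repeat(months.astype(float), ny * nx).reshape(12, ny, nx),
    dims=("month", "y", "x"),
    coords={"month": months},
)
daily_dev = xr.DataArray(
    np.full((time.size, ny, nx), 0.5),
    dims=("time", "y", "x"),
    coords={"time": time},
)
daily_mean_based = daily_dev.groupby("time.month") + monthly_mean

# Precip: daily = daily fraction of the month * monthly accumulation
monthly_acc = xr.DataArray(
    np.full((12, ny, nx), 100.0),
    dims=("month", "y", "x"),
    coords={"month": months},
)
frac = 1.0 / time.days_in_month.to_numpy()  # spread each month evenly over its days
daily_frac = xr.DataArray(
    np.broadcast_to(frac[:, None, None], (time.size, ny, nx)).copy(),
    dims=("time", "y", "x"),
    coords={"time": time},
)
daily_precip = daily_frac.groupby("time.month") * monthly_acc
```

Grouping on the virtual coordinate `time.month` lets xarray match each day to one of the 12 monthly slices without an explicit loop over months.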

### Processing time

Approximately 8-10 minutes per year using a full node (40 tasks), including all 6 daily pyaez variables.

### NOTE

Srad and Wind cannot be processed on their own with the code below. They must be processed in the same run as Precip, with Precip listed before them in varnames.

This is because the range of valid Srad and Wind values in the daily deviation pxv files spans the fill value (-9999), so the fill value alone cannot separate masked grid cells from valid data. Instead, a mask created from a different variable identifies which Srad and Wind grid cells should be scaled and which should be masked. Precip is used to create this additional mask.
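A toy illustration of why a second variable's mask is needed. The arrays below are hypothetical, not read from the pxv files: once valid values can equal (or fall below) the fill value, a plain `!= -9999` test misclassifies them, so the valid-data mask is borrowed from Precip, whose valid values are never negative.

```python
import numpy as np

FILL = -9999  # fill value used in the pxv files

# Hypothetical Precip daily values (int16) for 5 grid points; FILL marks no-data
precip_raw = np.array([120, FILL, 0, 35, FILL], dtype=np.int16)

# Hypothetical Srad daily values for the same points: point 2 is a valid value
# that happens to equal FILL, and point 3 is below FILL (so min() < FILL)
srad_raw = np.array([1500, FILL, FILL, -12000, FILL], dtype=np.int16)

# The notebook's check for "data range spans the fill value"
spans_fill = srad_raw.min() < FILL

# Naive approach: treating every FILL as missing also drops valid point 2
naive_valid = srad_raw != FILL

# Mask borrowed from Precip: scale everything, then mask where Precip is missing
valid = precip_raw != FILL
srad = np.where(valid, srad_raw.astype(np.float32) * 1000.0, np.nan)  # kJ -> J
```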

### Python environment information

This notebook must be run in an environment where the following packages are installed:

xarray, rioxarray, numpy, pandas, dask

matplotlib is optional; it is only needed if you uncomment the plotting sections of the loop.

It is suggested that the user build a conda environment with these packages on HPC Orion using the system miniconda module. The user also needs to create a Jupyter kernel from this conda environment before launching the notebook.
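A quick way to confirm that the selected Jupyter kernel points at the right environment is to check that each package can be found before running the notebook. This sketch uses only the standard library; `missing_packages` is a hypothetical helper, not part of the notebook.

```python
import importlib.util

# packages listed above; matplotlib is only needed for the optional plots
REQUIRED = ["xarray", "rioxarray", "numpy", "pandas", "dask"]
OPTIONAL = ["matplotlib"]

def missing_packages(names):
    """Return the package names that cannot be found by the active kernel."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_packages(REQUIRED)
if missing:
    print("missing required packages:", ", ".join(missing))
if missing_packages(OPTIONAL):
    print("matplotlib not found; leave the plotting sections commented out")
```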

### How to launch this notebook

Instructions for launching are provided elsewhere for security of the system. Please ask Kerrie if you need a copy of the instructions.

# Set Up

In [None]:
import numpy as np
from calendar import isleap
import xarray as xr
import dask.array as da
import dask
import rioxarray as rio
import pandas as pd
import os
from time import time as timer
import logging
logging.captureWarnings(True)

# import matplotlib.pyplot as plt

### Things we need to know up front from Gunther

In [None]:
# 1) shape of the ALOS mask. This is the global grid ny and nx including Antarctica. 
ny_global=2160
nx_global=4320

# 2) shape of the global grid ny and nx excluding Antarctica. (ny corresponds to IRmax in Gunther's fortran program)
ny=1800
nx=nx_global

# 3) data type of what's inside the daily deviation and monthly pxv files 
# fortran 2 byte int = python np.int16
# fortran 4 byte float = np.float32
dtype_d=np.int16   # 2 byte integers for daily data
dtype_m=np.float32 # 4 byte floats for monthly data

# 4) fill value used in the daily deviation and monthly pxv files
fillval_d=-9999   # daily
fillval_m=-9999.0 # monthly

# 5) total number of data points (number of lines) in the pxv files (this includes points set to the fillval) 
npts=2295358

# 6) total number of data points in the pxv files with valid data values
# this is so we can identify which points should be masked and which should be scaled 
# when the data range spans over the fillval which is a problem for srad and wind
npts_valid_d=2287408  # daily files
npts_valid_m=2268708  # monthly files

# 7) scale factors that convert the daily deviation/distribution values in the pxv files
#    to the units of the monthly data (see table below)

# Variable   Monthly data   Daily deviations/distr.   Scale factor
# Precip     mm/day         %_of_month×100            0.0001
# Srad       J/m2/day       kJ/m2/day                 1000.
# Tmax       °C             °C×100                    0.01
# Tmin       °C             °C×100                    0.01
# Vapr       hPa            Pa                        0.01
# Wind       m/sec          mm/sec                    0.001

# in alphabetical order by variable name 
scale_factors=[0.0001,1000.,0.01,0.01,0.01,0.001]

# 8) how many decimal places the daily output variables should have
output_trunc=[4,2,3,3,2,3]
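As a worked example of the scale factor table, the sketch below applies the same `scale_factors` (same order as `varnames`) to one raw int16 daily value per variable. The raw values are invented for illustration.

```python
import numpy as np

# same order as varnames / scale_factors in the notebook
varnames = ["Precip", "Srad", "Tmax-2m", "Tmin-2m", "Vapr", "Wind-10m"]
scale_factors = np.array([0.0001, 1000.0, 0.01, 0.01, 0.01, 0.001])

# hypothetical raw int16 daily values, one per variable
raw = np.array([350, 12, 250, -130, 150, 2200], dtype=np.int16)

# cast to float before scaling so the results are not truncated to integers
scaled = raw.astype(np.float32) * scale_factors.astype(np.float32)

for name, r, s in zip(varnames, raw, scaled):
    print(f"{name:9s} raw={int(r):6d} scaled={float(s):.4f}")
```

Reading across: Precip 350 (3.50% of the month) becomes a fraction of 0.035, Srad 12 kJ/m2/day becomes 12000 J/m2/day, Tmax 250 becomes 2.5 °C, Tmin -130 becomes -1.3 °C, Vapr 150 Pa becomes 1.5 hPa, and Wind 2200 mm/sec becomes 2.2 m/sec.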

### other constants 

In [None]:
fao_basedir='/change/this/to/the/shared/datasets/directory/' # the group's shared datasets dir

# pxv things
pxv_basedir=fao_basedir+'gaez_v5/clim/AgERA5/Hist/' # this is where Gunther has the pxv files
dataset='AgERA5'
experiment='Hist'
pxvsuf='_5m.pxv'
connector='_'
dailytag='365'
sep='/'
pxvdirnames=['prec','srad','tmax','tmin','vapr','wind']
varnames=['Precip','Srad','Tmax-2m','Tmin-2m','Vapr','Wind-10m']

# raster things
maskfile=fao_basedir+'gaez_v5/land/ALOSmask5m_fill.rst'
ydimname='y'
xdimname='x'

# output things
out_basedir=fao_basedir+'pyaez/inputs/global/daily365_npy/'

# parallel computing things
xrchunks={'time':-1,'y':-1,'x':54} # 80 chunks, xr format
dachunks=(-1,54,-1) # the same 80 chunks, da format

# months to process
nmonths=12

### user to input which years to process here

In [None]:
years=np.arange(2021,2024) # the end year is not inclusive

# Begin Main Code

### First, get the ALOS mask

In [None]:
### get the mask from rst into an array of 2 dims (y,x)
### and check that it has the expected number of data points

# open the maskfile but don't include antarctica so mask has shape (y:1800,x:4320)
ds=xr.open_dataset(maskfile,engine='rasterio').isel(y=slice(0,ny)).squeeze()
del ds.coords['band']

# clean up some metadata
ds[xdimname]=ds[xdimname].astype(np.float32)
ds[ydimname]=ds[ydimname].astype(np.float32)
mask2D=ds.band_data

# convert to 0 & 1 mask
mask2D=xr.where(mask2D>0,1,0).astype('int8')

mask1D=mask2D.stack(space=[ydimname,xdimname]) # collapse mask to 1D: 1800*4320 = 7776000 points
inds_data=mask1D==1  # keep track of which points are not masked out

# error checking
npts_mask=int(mask2D.sum().data) # number of data points in mask
print('total data points in mask file', npts_mask,'expecting',npts)
# if this throws an error, coordinate with Gunther
assert npts_mask==npts, f"data points expected {npts}, data points present in mask {npts_mask}"
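The monthly and daily pxv reads in the main loop follow the same pattern: read a flat binary array with `np.fromfile`, confirm that the value count divides evenly by the number of time steps, and reshape to (points, time steps). Below is a self-contained sketch of that pattern on a small temporary file. `read_pxv` is a hypothetical helper, and it assumes, as the notebook's `reshape(npts, nmonths)` does, that all time steps for one point are stored contiguously.

```python
import os
import tempfile
import numpy as np

def read_pxv(path, dtype, nsteps, npts_expected, count=-1):
    """Read a pxv file into an (npts, nsteps) array and sanity-check its size."""
    values = np.fromfile(path, dtype=dtype, count=count)
    if values.size % nsteps != 0:
        raise ValueError(f"{path}: {values.size} values not divisible by {nsteps} steps")
    npts = values.size // nsteps
    if npts != npts_expected:
        raise ValueError(f"{path}: found {npts} points, expecting {npts_expected}")
    return values.reshape(npts, nsteps)

# Toy "monthly" file: 4 points x 12 months, followed by extra values that stand in
# for the additional data the real monthly files contain (hence count=)
npts, nmonths = 4, 12
data = np.arange(npts * nmonths, dtype=np.float32)
extra = np.full(5, -9999.0, dtype=np.float32)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_5m.pxv")
    np.concatenate([data, extra]).tofile(path)
    arr = read_pxv(path, np.float32, nmonths, npts, count=nmonths * npts)
```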

In [None]:
# this function puts data from pxv files onto a 2-dimensional global grid
# we will call it later with dask delayed to compute in parallel
def data_to_nd_array(i,inds,arr1D,pxv,arr2D):
    arr1D[inds]=pxv.squeeze()  # remove singleton dim (day)
    arr2D[:,:]=arr1D.unstack() # put 1D data onto the 2D grid
    return arr2D.copy()
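data_to_nd_array places the values of one time step into the stacked 1D mask and unstacks them back to (y, x). Assuming, as the stacking above does, that points are ordered row by row (y outer, x inner), the same placement can be sketched in plain NumPy with a boolean mask. This is an illustrative equivalent on a toy grid, not a replacement for the notebook's function.

```python
import numpy as np

FILL = np.float32(-9999.0)

# toy land mask (1 = point present in the pxv file), shape (y, x)
mask2d = np.array([[0, 1, 1],
                   [1, 0, 1]], dtype=np.int8)

# one time step of pxv values, one per land point, in row-major (y, then x) order
pxv_step = np.array([10.0, 11.0, 12.0, 13.0], dtype=np.float32)

def to_grid(values, mask, fill=FILL):
    """Place 1D land-point values onto a 2D grid, filling masked cells."""
    grid = np.full(mask.shape, fill, dtype=values.dtype)
    # boolean indexing visits cells in C (row-major) order; values.size must equal mask.sum()
    grid[mask.astype(bool)] = values
    return grid

grid = to_grid(pxv_step, mask2d)
```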

In [None]:
# main data processing loop through years
# inputs: the monthly and daily deviation pxv files from Gunther
# outputs: npy stacks of daily precip, srad, tmax, tmin, vapr, 2m wind, and rhum for pyaez

for yyyy in years:
    
    start_time=timer()
    print('########################################################################')
    print('################################',yyyy,'################################')
    print('########################################################################')

    for varind, var in enumerate(varnames):
        
        print('*****************************************')
        print('*************** Processing',varnames[varind],'***************')    
        print('*****************************************')
        
        ######################################################################################################
        #  STEP 1: Translate monthly means from PXV to xarray data structures on global grid    
        ######################################################################################################
        print('################################ STEP1: PROCESSING MONTHLY PXV ################################')
        pxvfile_m=pxv_basedir+pxvdirnames[varind]+sep+varnames[varind]+connector+dataset+connector+experiment+connector+str(yyyy)+pxvsuf
        filename=pxvfile_m.split(sep)[-1]

        # read file to 1D array
        # monthly files have more data in them than we need so we subset the read with count=nmonths*npts
        with open(pxvfile_m,'rb') as f:
            array1D_m=np.fromfile(f,dtype=dtype_m,count=nmonths*npts)

        # limit precision here?

        # error checking
        nvals=array1D_m.shape[0]             # total number of data values from the file
        npts_flt=nvals/nmonths               # number of grid points in the file, float format
        npts_int=int(nvals/nmonths)          # convert to integer
        # check number of grids found in flt and int are equivalent, if not, the file was read incorrectly
        assert npts_flt*10==float(npts_int*10), f"reading pxv file {filename} with incorrect number of months: {nmonths}"
        # check number of grids found in monthly pxv data file is the number expected, if not, coordinate with Gunther
        assert npts_int==npts, f"pxv file {filename} has {npts_int} total data points, expecting {npts} total data points"
        print(f"total data points from monthly pxv file {npts_int}, expecting {npts}")

        # reshape the array to 2 dimensions (npoints,nmonths)
        array2D_m=array1D_m.reshape(npts,nmonths) # reshape

        # find out if data value range spans across the fillvalue
        # if it does, we'll need to apply an extra mask later
        flag_m=True if array2D_m.min() < fillval_m else False

        # check data values
        print('data min/max values',array2D_m.min(),array2D_m.max())
        print('apply extra mask?',flag_m)

        # set up for putting data on full grid
        empty1D_m=mask1D.copy().astype(dtype_m)            # placeholder array for 1D space 
        empty1D_m.rio.write_nodata(fillval_m,inplace=True) # set the fill value attribute
        empty1D_m[:]=fillval_m                             # fill array with fill value

        empty2D_m=mask2D.copy().astype(dtype_m)            # placeholder array for 2D grid 
        empty2D_m.rio.write_nodata(fillval_m,inplace=True) # set the fill value attribute
        empty2D_m[:,:]=fillval_m                           # fill array with fill value

        # put the monthly pxv data onto a global grid using dask parallel computing 
        # first convert data to a chunked dask array object, 1 month per chunk 
        # and save array chunks to a list of delayed dask objects
        pxv_delay=da.from_array(array2D_m,chunks=(-1,1)).to_delayed().ravel() 

        # using the function we defined earlier (data_to_nd_array),
        # build a list of computational tasks to be executed in parallel
        task_list=[dask.delayed(data_to_nd_array)(imonth,inds_data,empty1D_m.copy(),pxvdata,empty2D_m.copy())\
                   for imonth,pxvdata in enumerate(pxv_delay)] 
        # double check there is 1 task per month of data
        assert len(task_list)==nmonths, f'{len(task_list)} tasks in list, should be {nmonths}' 

        # execute all computations
        print('putting 1D data on a 2D grid...')
        result_chunks_m=dask.compute(*task_list)

        # concatenate the resulting monthly data chunks along a new time dimension
        print('concatenating...')
        data3D_m=xr.concat(result_chunks_m,dim='time')

        # replace fillval with nan
        print('adding nans...')
        data3D_m=xr.where(data3D_m==fillval_m,np.nan,data3D_m)

        # check we have correct number of non-missing data points
        data_mask_m=xr.where(np.isnan(data3D_m.data),0,1)         # 0 where nan, 1 where not nan for all times
        ngrids_data_m=int(data_mask_m.sum()/data_mask_m.shape[0]) # divide by time to get data points per time step
        print('total number of non-missing data points',ngrids_data_m,'expecting',npts_valid_m)
        assert ngrids_data_m==npts_valid_m, f'data mask creation issue. found {ngrids_data_m} valid data points (non missing), expecting {npts_valid_m}'    

        # visual check January
        # figure=plt.figure(figsize=(6,4))
        # data3D_m.isel(time=0).plot()
        # plt.title(varnames[varind]+' data from monthly mean pxv, Jan '+str(yyyy))
        # plt.tight_layout()
        # plt.show()
        del array1D_m,pxvfile_m,filename,nvals,npts_flt,npts_int,array2D_m,flag_m,empty1D_m,empty2D_m,pxv_delay,task_list,result_chunks_m
        ######################################################################################################
        # END STEP 1
        #####################################################################################################

        ######################################################################################################
        # STEP 2: Translate daily deviations from PXV to scaled xarray data structures on global grid
        ######################################################################################################
        print('################################ STEP 2: PROCESSING DAILY DEV PXV ################################')
        pxvfile=pxv_basedir+pxvdirnames[varind]+sep+varnames[varind]+dailytag+connector+dataset+connector+experiment+connector+str(yyyy)+pxvsuf
        filename=pxvfile.split(sep)[-1]

        # read entire file into 1D array
        with open(pxvfile,'rb') as f:
            array1D_d=np.fromfile(f,dtype=dtype_d)            

        # error checking
        nvals=array1D_d.shape[0]           # total number of data points
        ndays=366 if isleap(yyyy) else 365 # number of days of data at each grid point
        npts_flt=nvals/ndays               # number of grid points in the file, float format
        npts_int=int(nvals/ndays)          # convert to integer
        # check number of grids found in flt and int are equivalent, if not, the file was read incorrectly
        assert npts_flt*10==float(npts_int*10), f"reading pxv file {filename} with incorrect number of days: {ndays}"
        # check number of grids found in daily dev pxv data file is the number expected, if not, coordinate with Gunther
        assert npts_int==npts, f"pxv file {filename} has {npts_int} total data points, expecting {npts} total data points"
        print('total data points in daily dev pxv file',npts_int,'expecting',npts)

        # reshape the array to 2 dimensions (npoints,ndays)
        array2D_d=array1D_d.reshape(npts,ndays)

        # find out if data value range spans across the fillvalue
        # if it does, we'll need to apply an extra mask later
        flag_d=True if array2D_d.min() < fillval_d else False

        # check data values
        print('data min/max values before scaling:',array2D_d.min(),array2D_d.max())
        print('apply extra mask?',flag_d)

        # set up for putting data on full grid
        empty1D_d=mask1D.copy().astype(dtype_d)            # placeholder array for 1D space 
        empty1D_d.rio.write_nodata(fillval_d,inplace=True) # set the fill value attribute
        empty1D_d[:]=fillval_d                             # fill array with fill value  

        empty2D_d=mask2D.copy().astype(dtype_d)            # placeholder array for 2D grid 
        empty2D_d.rio.write_nodata(fillval_d,inplace=True) # set the fill value attribute
        empty2D_d[:,:]=fillval_d                           # fill array with fill value  

        # put the daily dev pxv data onto a global grid using dask parallel computing 
        # first convert data to a chunked dask array object, 1 day per chunk 
        # and save array chunks to a list of delayed dask objects
        pxv_delay=da.from_array(array2D_d,chunks=(-1,1)).to_delayed().ravel() 

        # using the function we defined earlier (data_to_nd_array),
        # build a list of computational tasks to be executed in parallel
        task_list=[dask.delayed(data_to_nd_array)(iday,inds_data,empty1D_d.copy(),pxvdata,empty2D_d.copy())\
                   for iday,pxvdata in enumerate(pxv_delay)] 
        # double check we've got 1 task per day of data        
        assert len(task_list)==ndays, f'{len(task_list)} tasks in list, should be {ndays}' 

        # execute all computations
        print('putting 1D data on a 2D grid...')
        result_chunks_d=dask.compute(*task_list)

        # concatenate the resulting daily chunks along a new time dimension
        print('concatenating...')
        data3D_d=xr.concat(result_chunks_d,dim='time')

        # Now scale data, change fill value to nan, and apply extra mask if necessary
        # the extra mask is required if the valid range of the data includes the -9999 fillval
        # which is the case for srad and wind      
        print('changing dtype...')      
        data3D_d=data3D_d.astype(np.float32) # force float32 for nans in output

        # if valid data range includes fillval (srad, wind), do scaling then additional masking
        if flag_d:
            print('processing pxv variable with data values that span over the fillval')
            print('scaling...')
            data3D_d=data3D_d*scale_factors[varind]    
            print('applying additional mask...')
            data3D_d=xr.where(data_mask_d,data3D_d,np.nan)
            # verify that the masking worked
            valid_arr=np.where(np.isnan(data3D_d.data[14,:,:]),0,1) # pick one day to verify  
            nvalid=valid_arr.sum()
            print('total number of non-missing data points',nvalid,'expecting',npts_valid_d)
            assert nvalid==npts_valid_d, f'data mask application issue. found {nvalid} valid data points (non missing), expecting {npts_valid_d}'    
            del valid_arr,nvalid
        # if valid data range doesn't include fillval (precip,tmin,tmax,vapr), convert fillval to nan then scale
        else:
            print('adding nans...')
            data3D_d=xr.where(data3D_d==fillval_d,np.nan,data3D_d)
            print('scaling...')
            data3D_d=data3D_d*scale_factors[varind]

            # compute the additional mask from precipitation
            if varnames[varind] =='Precip':
                print('computing additional mask...')
                data_mask_d=xr.where(np.isnan(data3D_d.data),0,1)  
                ngrids_data_d=int(data_mask_d.sum()/data_mask_d.shape[0])
                print('total number of non-missing data points',ngrids_data_d,'expecting',npts_valid_d)
                assert ngrids_data_d==npts_valid_d, f'data mask creation issue. found {ngrids_data_d} valid data points (non missing), expecting {npts_valid_d}' 

        print('data min/max values after scaling:',data3D_d.min().data,data3D_d.max().data)
                
        # visual check January
        # figure=plt.figure(figsize=(6,4))
        # data3D_d.isel(time=14).plot()
        # plt.title(varnames[varind]+' scaled data from daily dev pxv, 15 Jan '+str(yyyy))
        # plt.tight_layout()
        # plt.show()
        del array1D_d,pxvfile,filename,nvals,npts_flt,npts_int,array2D_d,flag_d,empty1D_d,empty2D_d,pxv_delay,task_list,result_chunks_d
        ######################################################################################################
        # END STEP 2
        ######################################################################################################

        ######################################################################################################
        # STEP 3: Create daily data from monthly means and daily deviations
        ######################################################################################################
        print('################################ STEP 3: CREATING DAILY DATA FROM MONTHLY MEAN AND DAILY DEV ################################')
        
        # create metadata labels for the time dimension
        if (yyyy>=1980) & (yyyy<=2024):
            time_m=pd.date_range(str(yyyy)+'-01-01',str(yyyy)+'-12-31',freq='MS')
            time_d=pd.date_range(str(yyyy)+'-01-01',str(yyyy)+'-12-31',freq='D')
        else:
            time_m=pd.date_range('1900-01-01','1900-12-31',freq='MS')  
            time_d=pd.date_range('1900-01-01','1900-12-31',freq='D') 

        # assign the time labels to the data arrays
        data3D_m=data3D_m.assign_coords(time=("time",time_m))
        data3D_d=data3D_d.assign_coords(time=("time",time_d))

        # if precipitation, compute daily value from monthly accumulation and daily fraction
        if varnames[varind]=='Precip':
            print('computing daily values for Precip...')
            # chunk the data
            var_acc=data3D_m.chunk(xrchunks)
            var_frac=data3D_d.chunk(xrchunks)

            # time labels --> months for the groupby in next step
            var_acc=var_acc.rename({'time':'month'})
            months=np.arange(12)+1
            var_acc['month']=months

            # lazy parallel calculation
            with dask.config.set(**{'array.slicing.split_large_chunks': False}):
                var_daily=var_frac.groupby('time.month')*var_acc  # multiply (not add): daily fraction times monthly accumulation
            del var_acc, var_frac
        # if not precipitation, compute daily value from monthly mean and daily deviation
        else:
            print('computing daily values for',varnames[varind],'...')
            # chunk the data
            var_mean=data3D_m.chunk(xrchunks)
            var_prime=data3D_d.chunk(xrchunks)

            # time labels --> months for the groupby in next step
            var_mean=var_mean.rename({'time':'month'})
            months=np.arange(12)+1
            var_mean['month']=months

            # lazy parallel calculation
            with dask.config.set(**{'array.slicing.split_large_chunks': False}):
                var_daily=var_prime.groupby('time.month') + var_mean
            del var_mean, var_prime

        # change srad units
        attrs={}
        newvarname=varnames[varind]
        if varnames[varind] == 'Srad':
            print('changing units to W/m2...')
            # attrs=ds[varname].attrs
            attrs['units']='W/m2'

            # Convert J/m2/day to W/m2
            s_per_day=86400
            var_daily=var_daily/s_per_day
            var_daily.attrs=attrs    

        # interpolate winds to 2m and change units
        if varnames[varind] == 'Wind-10m':
            # interp from 10m to 2m height
            print('interpolating wind to 2m...')
            z=10
            z_adjust=4.87/(np.log(67.8*z-5.42))
            var_daily=var_daily*z_adjust

            # fix metadata
            newvarname='Wind-2m'
            attrs={'standard_name':newvarname,'long_name':'2m Wind Speed','units':'m/s'}
            var_daily.attrs=attrs   

        # drop leap day if there is one and reorder the array dimensions to (y,x,time)
        if (yyyy>=1980) & (yyyy<=2024):
            dropdate=str(yyyy)+'-02-29'
        else: 
            dropdate='1900-02-29'

        with dask.config.set(**{'array.slicing.split_large_chunks': False}):
            try:
                data_out = var_daily.drop_sel(time=dropdate).transpose('y','x','time')
                print('dropping date',dropdate)
                del dropdate
            except (KeyError, ValueError): # no leap day in this year, so nothing to drop
                data_out = var_daily.transpose('y','x','time')    

        # limit precision here?
        # data_out=np.trunc(data_out*10**output_trunc[varind])/(10**output_trunc[varind]) 

        print('min/max values of daily data:',data_out.min().compute().data,data_out.max().compute().data)
        print('data_out dtype:',data_out.dtype)
        
        # visual check January
        # figure=plt.figure(figsize=(6,4))
        # data_out.isel(time=14).plot()
        # plt.title(varnames[varind]+' daily data for input to pyaez, 15 Jan '+str(yyyy))    
        # plt.tight_layout()
        # plt.show()
        del var_daily, data3D_m, data3D_d, time_m, time_d
        ######################################################################################################
        # END STEP 3
        ######################################################################################################

        ######################################################################################################
        # STEP 4: write out npy file
        ######################################################################################################
        print('################################ STEP 4: WRITING DATA FILES ################################')
        out_dir=out_basedir+str(yyyy)+sep+newvarname+sep # directory to write files to

        # create dir for writing npy if it doesn't exist
        os.makedirs(out_dir,exist_ok=True)

        # execute the parallel computation and write files
        # the npy stack is chunked along the longitude dimension (axis 1)
        print('writing stack to',out_dir+'...')  
        da.to_npy_stack(out_dir,data_out.data,axis=1)            
        del out_dir, data_out
        print('done with',var)
        print('####################################################################################')
        ######################################################################################################
        # END STEP 4
        ######################################################################################################   

    print('*****************************************')
    print('*************** Processing Rhum ***************')    
    print('*****************************************')
    
    ######################################################################################################
    # STEP 1: create relative humidity
    ######################################################################################################
    print('################################ STEP RH1: RELATIVE HUMIDITY CALC ################################')
    
    # lazy load chunked data arrays and convert vapr units
    vapr=da.from_npy_stack(out_basedir+str(yyyy)+'/Vapr/').rechunk(dachunks)*0.1 # hPa-->kPa
    tmax=da.from_npy_stack(out_basedir+str(yyyy)+'/Tmax-2m/').rechunk(dachunks)
    tmin=da.from_npy_stack(out_basedir+str(yyyy)+'/Tmin-2m/').rechunk(dachunks)

    # limit precision
    vapr=(np.trunc(vapr*10**3)/(10**3))
    tmax=(np.trunc(tmax*10**3)/(10**3))
    tmin=(np.trunc(tmin*10**3)/(10**3))
    
    # lazy parallel calculation and limit precision
    print('lazy calc...')
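    # mean saturation vapor pressure (FAO-56 eqs. 11-12): e°(T)=0.6108*exp(17.27*T/(T+237.3)) in kPa
    vapr_sat=(0.5*0.6108*( np.exp((17.27*tmax)/(tmax+237.3)) + np.exp((17.27*tmin)/(tmin+237.3)) )) # kPa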
    Rhum=(vapr/vapr_sat) # relative humidity in fraction not percent
    Rhum=np.trunc(Rhum*10**4)/(10**4) # limit precision

    ######################################################################################################
    # STEP 2: write out npy file
    ######################################################################################################
    print('################################ STEP RH2: WRITING DATA FILE ################################')
    out_dir=out_basedir+str(yyyy)+sep+'Rhum'+sep # directory to write files to

    # create dir for writing npy if it doesn't exist
    os.makedirs(out_dir,exist_ok=True)

    # execute the parallel computation and write files
    # the npy stack is chunked along the longitude dimension (axis 1)        
    print('computing and writing stack to',out_dir+'...')     
    da.to_npy_stack(out_dir,Rhum,axis=1)   
    
    del out_dir, Rhum, vapr, tmax, tmin, vapr_sat
    print('done with Rhum')
    print('####################################################################################')        
    task_time=(timer()-start_time)/60.
    print('DONE',yyyy,'IN',task_time,'MINUTES')
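The unit conversions and the relative humidity calculation in the loop can be checked on scalar values. The sketch below restates them as small NumPy functions (the function names are hypothetical, not from the notebook): Srad from J/m2/day to W/m2 by dividing by 86400 s, the FAO-56 logarithmic wind profile from 10 m to 2 m, and relative humidity as actual vapor pressure divided by the mean of the saturation vapor pressures at Tmax and Tmin, using the FAO-56 form e°(T) = 0.6108 exp(17.27 T / (T + 237.3)) in kPa.

```python
import numpy as np

SECONDS_PER_DAY = 86400.0

def srad_to_wm2(srad_j_m2_day):
    """Daily radiation sum (J/m2/day) -> mean flux (W/m2)."""
    return srad_j_m2_day / SECONDS_PER_DAY

def wind_10m_to_2m(u10, z=10.0):
    """FAO-56 log wind profile: u2 = uz * 4.87 / ln(67.8 z - 5.42)."""
    return u10 * 4.87 / np.log(67.8 * z - 5.42)

def sat_vapor_pressure_kpa(t_c):
    """FAO-56 saturation vapor pressure e°(T) in kPa, with T in degC."""
    return 0.6108 * np.exp(17.27 * t_c / (t_c + 237.3))

def relative_humidity(vapr_hpa, tmax_c, tmin_c):
    """RH as a fraction: ea / mean(e°(Tmax), e°(Tmin)), with vapr converted hPa -> kPa."""
    ea = vapr_hpa * 0.1
    es = 0.5 * (sat_vapor_pressure_kpa(tmax_c) + sat_vapor_pressure_kpa(tmin_c))
    return ea / es
```

For a 10 m measurement height the wind factor is 4.87 / ln(672.58), roughly 0.748, so 2 m winds are about three quarters of the 10 m values.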