In [None]:
import numpy as np
import xarray as xr
from dask.distributed import Client,LocalCluster
import matplotlib.pyplot as plt
import scipy

In [None]:
# Directory holding the processed input files. Keep the trailing slash:
# later cells build file paths by plain string concatenation.
path_to_data = "/path/to/our/shared/datasets/dir/processed_data/"

In [None]:
# Start a local Dask cluster with one thread per worker and connect a client to it.
nworkers = 40
cluster = LocalCluster(
    n_workers=nworkers,
    threads_per_worker=1,
)
client = Client(cluster)

# Do not continue until all requested workers have registered (timeout=10).
client.wait_for_workers(n_workers=nworkers, timeout=10)

# Last expression of the cell: display the client summary.
client

In [None]:
def custom_reg(data_chunk, min_obs = 10):
    """Per-pixel least-squares trend of ``data_chunk`` against calendar year.

    Intended to be applied chunk-wise (e.g. through ``map_blocks``); every
    chunk must therefore contain the complete ``time`` axis.

    Parameters
    ----------
    data_chunk : xarray.DataArray
        Data with a datetime-like ``time`` coordinate plus ``y`` and ``x``
        coordinates. The code assumes the dimension order ``(time, y, x)``.
        Non-finite values (NaN, +/-inf) are treated as missing.
    min_obs : int, default 10
        Minimum number of finite observations a pixel needs along ``time``.
        Pixels with fewer observations get NaN for the slope and p-value.
        Values below 3 are not meaningful because the t-test uses
        ``n_obs - 2`` degrees of freedom.

    Returns
    -------
    xarray.Dataset
        ``wealth_pc_trend`` : slope of the fit (data units per year).
        ``p_value``         : two-sided p-value of the slope (Student t-test
                              on the correlation, ``n_obs - 2`` dof).
        ``n_obs``           : number of finite observations per pixel
                              (not masked by ``min_obs``).
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not load
    # ``scipy.stats`` on older SciPy releases, and this function may run in a
    # separate worker process.
    from scipy import stats

    # Boolean mask of usable samples; isfinite also rejects +/-inf, not only NaN.
    valid = np.isfinite(data_chunk)

    n_obs = valid.sum('time')  # 2D count per pixel, keeps the y/x coordinate labels
    # NaN where a pixel has fewer than ``min_obs`` samples (``>=`` so that a pixel
    # with exactly ``min_obs`` samples is kept).
    n_obs_clean = n_obs.where(n_obs >= min_obs)
    enough = np.isfinite(n_obs_clean)

    # Calendar year broadcast to 3D (time, y, x) so it can be masked like the data.
    x = data_chunk.time.dt.year.expand_dims(dim = {"y": data_chunk.y, "x": data_chunk.x}, axis = (1, 2))

    # Mask x and y identically: drop missing samples and under-sampled pixels.
    # Masking y with ``valid`` as well keeps any +/-inf out of the statistics.
    keep = valid & enough
    x_clean = x.where(keep)
    y_clean = data_chunk.where(keep)

    # Population moments (divide by n); the n cancels in slope and correlation.
    y_mean = y_clean.mean('time')
    y_var = ((y_clean - y_mean)**2).sum(dim = 'time')/n_obs_clean
    y_std = np.sqrt(y_var)
    x_mean = x_clean.mean('time')
    x_var = ((x_clean - x_mean)**2).sum(dim = 'time')/n_obs_clean
    x_std = np.sqrt(x_var)

    cov = ((x_clean - x_mean)*(y_clean - y_mean)).sum(dim = 'time')/n_obs_clean
    cor = cov/(x_std*y_std)
    slope = cov/x_var

    # Significance of the slope via the t statistic of the correlation.
    dof = n_obs_clean - 2
    t_stats = cor*np.sqrt(dof)/np.sqrt(1 - cor**2)
    # sf gives the upper tail; doubling it yields the two-sided p-value.
    p = stats.t.sf(np.abs(t_stats).values, dof.values)*2
    p = xr.DataArray(p, dims = cor.dims, coords = cor.coords)

    # Bundle the results into a dataset.
    result = slope.to_dataset(name = 'wealth_pc_trend').assign(p_value = p, n_obs = n_obs)
    return result

In [None]:
%%time
print('lazy load data')
# lazy load wealth to chunked array (not in memory)
wealth = xr.open_dataset(path_to_data+'wealth.nc',chunks={'time':-1,'y':1000,'x':1000}).wealth_pc

# set up a chunked dataset template that has the exact dims, coords, and variable names as the output of custom_reg
print('create output template')
template=wealth.isel(time=0).drop_vars(['time'])
template.attrs = {'standard_name': 'wealth_pc_trend',
                  'long_name': 'Annual time trend for the per capita wealth'}
template.name=template.attrs['standard_name']
template=template.to_dataset()
template['p_value']=template['wealth_pc_trend']
template['n_obs']=template['wealth_pc_trend']

# do the parallel compute
print('execute')
wealth_trend=wealth.map_blocks(custom_reg,template=template,kwargs={'min_obs':10}).compute()
wealth_trend

In [None]:
# Map the trend over a spatial window, masking pixels whose p-value is not below 0.1.
(
    wealth_trend["wealth_pc_trend"]
    .sel(y=slice(2E6, 0.34E6), x=slice(0.34E6, 2E6))
    .where(wealth_trend["p_value"] < 0.1)
    .plot()
)

# Compare custom_reg to polyfit_parallel

In [None]:
# function to call with map_blocks
# operates on xarray chunks and returns xarray chunks
def polyfit_parallel(data_chunk,skipna):
    """Fit a degree-1 polynomial along ``time`` (in calendar years) and return the slope.

    Parameters
    ----------
    data_chunk : xarray.DataArray
        Chunk with a datetime-like ``time`` coordinate. It is not modified.
    skipna : bool
        Forwarded to ``polyfit`` as its ``skipna`` argument.

    Returns
    -------
    xarray.DataArray
        The degree-1 coefficient of the fit, with the ``degree`` coordinate removed.
    """
    # assign_coords returns a new object, so the caller's chunk keeps its
    # original datetime ``time`` coordinate.
    by_year = data_chunk.assign_coords(datetime=data_chunk.time)
    # Fit against the integer year rather than the raw datetime values.
    by_year = by_year.assign_coords(time=by_year.datetime.dt.year)
    result_chunk = by_year.polyfit('time',1,skipna=skipna)
    return result_chunk.polyfit_coefficients.sel(degree=1).drop_vars('degree')

In [None]:
%%time
print('lazy load data')
# lazy load wealth to chunked array (not in memory)
wealth = xr.open_dataset(path_to_data+'wealth.nc',chunks={'time':-1,'y':1000,'x':1000}).wealth_pc

# set up a chunked array template that has the exact dims and coords as the output of function polyfit_parallel will have
print('create output template array')
template=wealth.isel(time=0).drop_vars(['time','spatial_ref'])
template.attrs={'standard_name':'wealth_pc trend'} # put whatever attributes you want

# do the parallel compute
print('execute')
wealth_trend_polyfit=wealth.map_blocks(polyfit_parallel,template=template,kwargs={'skipna':'True'}).compute()
wealth_trend_polyfit

In [None]:
# Three panels over the same spatial window: each method's slope and their difference.
fig = plt.figure(figsize=(20, 6))

# Left: slope from custom_reg.
ax = fig.add_subplot(1, 3, 1)
wealth_trend['wealth_pc_trend'].sel(y=slice(2E6, 0.34E6), x=slice(0.34E6, 2E6)).plot(ax=ax)
ax.set_title('custom_reg')

# Middle: slope from polyfit_parallel.
ax = fig.add_subplot(1, 3, 2)
wealth_trend_polyfit.sel(y=slice(2E6, 0.34E6), x=slice(0.34E6, 2E6)).plot(ax=ax)
ax.set_title('polyfit_parallel')

# Right: difference between the two, on a diverging colormap.
ax = fig.add_subplot(1, 3, 3)
(wealth_trend['wealth_pc_trend'] - wealth_trend_polyfit).sel(
    y=slice(2E6, 0.34E6), x=slice(0.34E6, 2E6)
).plot(ax=ax, cmap='bwr')
ax.set_title('custom_reg minus polyfit_parallel')

plt.show()


In [None]:
# Smallest and largest difference between the two slope estimates, as plain Python numbers.
diff = wealth_trend['wealth_pc_trend'] - wealth_trend_polyfit
(diff.min().item(), diff.max().item())

In [None]:
# Final clean-up for the Dask client created near the top of the notebook.
# Alternative kept for reference (presumably to reset the workers and keep working):
# client.restart()
client.shutdown()