# Area sampling using Dask and Xarray
Computational improvement based on test_area_sample_PETandLtheta.ipynb

In [1]:
from dask.distributed import Client
import xarray as xr
import numpy as np
from datetime import datetime
import os
import rasterio as rio
import glob
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
from osgeo import gdal
import cartopy.crs as ccrs

In [2]:
# ---------------------------------------------------------------------------
# Settings a user may want to change
# ---------------------------------------------------------------------------

# Set to True to draw the diagnostic figures in the cells below
plot_results = False

# Study area: bounding box of the region to sample
network_name = "Australia"
minx, miny = 112.908583, -43.658417
maxx, maxy = 153.638417, -10.064583
# Alternative study area:
# network_name = "California"
# minx, miny = -124.5, 32.5
# maxx, maxy = -114, 42.5
bbox = dict(minx=minx, maxx=maxx, miny=miny, maxy=maxy)

# Dask chunk sizes
# https://blog.dask.org/2021/11/02/choosing-dask-chunk-sizes
# Chunk sizes between 100MB and 1GB are generally good
chunks = dict(x=100, y=100)
pet_chunks = dict(longitude=100, latitude=100)

# Thresholds: PET quantiles and the precipitation cut-off
lower_quantile_thresh = 0.25
upper_quantile_thresh = 0.75
precip_thresh = 0.00002

# Analysis period
startDate = datetime(2016, 1, 1)
endDate = datetime(2017, 1, 1)

# ---------------------------------------------------------------------------
# Fixed values (do not change)
# ---------------------------------------------------------------------------

# Time stamps (HHMM) of the eight 3-hourly SMAP L4 granules per day
SMAPL4_times = [f'{hour:02d}30' for hour in range(1, 24, 3)]

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
input_path = r"..\1_data"
output_path = r"..\3_data_out"
SMAPL3_path = "SPL3SMP_E"
SMAPL4_path = "SPL4SMGP"
SMAPL4_grid_path = "SMAPL4SMGP_EASEreference"
PET_path = "PET"

In [3]:
# Start a local Dask cluster: 12 workers x 4 threads = 48 threads in total,
# matching the thread cap given to numexpr below.
# (This machine has 64 CPU cores. The number of worker threads in a thread
# pool is not tied to the number of CPUs or CPU cores in the system.)
os.environ['NUMEXPR_MAX_THREADS'] = '48'
cluster_options = dict(n_workers=12, threads_per_worker=4, memory_limit='auto')
client = Client(**cluster_options)
client
# References:
# https://distributed.dask.org/en/stable/client.html
# https://distributed.dask.org/en/stable/api.html#distributed.LocalCluster
# https://superfastpython.com/threadpool-number-of-workers/

# Rule of thumb: create arrays with a minimum chunk size of at least one
# million elements

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 12
Total threads: 48,Total memory: 95.74 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:56326,Workers: 12
Dashboard: http://127.0.0.1:8787/status,Total threads: 48
Started: Just now,Total memory: 95.74 GiB

0,1
Comm: tcp://127.0.0.1:56409,Total threads: 4
Dashboard: http://127.0.0.1:56410/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56329,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-koppxw0v,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-koppxw0v

0,1
Comm: tcp://127.0.0.1:56403,Total threads: 4
Dashboard: http://127.0.0.1:56404/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56330,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-wy9hqg4y,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-wy9hqg4y

0,1
Comm: tcp://127.0.0.1:56398,Total threads: 4
Dashboard: http://127.0.0.1:56400/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56331,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-lnv8l9w9,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-lnv8l9w9

0,1
Comm: tcp://127.0.0.1:56391,Total threads: 4
Dashboard: http://127.0.0.1:56392/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56332,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-voa50rje,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-voa50rje

0,1
Comm: tcp://127.0.0.1:56383,Total threads: 4
Dashboard: http://127.0.0.1:56386/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56333,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-34wisayu,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-34wisayu

0,1
Comm: tcp://127.0.0.1:56382,Total threads: 4
Dashboard: http://127.0.0.1:56384/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56334,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-1vvlyzlh,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-1vvlyzlh

0,1
Comm: tcp://127.0.0.1:56406,Total threads: 4
Dashboard: http://127.0.0.1:56407/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56335,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-0exadqno,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-0exadqno

0,1
Comm: tcp://127.0.0.1:56388,Total threads: 4
Dashboard: http://127.0.0.1:56389/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56336,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-_tnpbv0e,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-_tnpbv0e

0,1
Comm: tcp://127.0.0.1:56397,Total threads: 4
Dashboard: http://127.0.0.1:56399/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56337,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-sa04bcfq,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-sa04bcfq

0,1
Comm: tcp://127.0.0.1:56394,Total threads: 4
Dashboard: http://127.0.0.1:56395/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56338,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-e02w98sn,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-e02w98sn

0,1
Comm: tcp://127.0.0.1:56412,Total threads: 4
Dashboard: http://127.0.0.1:56413/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56339,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-9a02lo0_,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-9a02lo0_

0,1
Comm: tcp://127.0.0.1:56367,Total threads: 4
Dashboard: http://127.0.0.1:56368/status,Memory: 7.98 GiB
Nanny: tcp://127.0.0.1:56340,
Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-zkvvzjho,Local directory: C:\Users\RARAKI~1\AppData\Local\Temp\dask-worker-space\worker-zkvvzjho


## Read data 

### Read SMAP L4 data

In [4]:
def _preprocess_SMAPL4(ds):
    # Assign missing time dimension
    startTime = datetime.strptime(ds.rangeBeginningDateTime.split(".")[0], '%Y-%m-%dT%H:%M:%S')
    endTime = datetime.strptime(ds.rangeEndingDateTime.split(".")[0], '%Y-%m-%dT%H:%M:%S')
    midTime = startTime + (startTime - endTime)/2
    ds = ds.assign_coords(time=midTime)
    return ds

In [5]:
# Small-scale test to check that the Dask client can load SMAP L4 granules.
# 1 file is 3GB
chunks = {'x': 1200, 'y': 1200, 'time':1, 'band':1}
SMAPL4_fn_pattern_test = f'SMAP_L4_SM_gph_{startDate.year}010*.nc' ####### CHANGE LATER: testing with 2016 Jan 1-9 data #######
# glob.glob() returns paths in arbitrary order, while combine='nested' concatenates
# the files in exactly the order given. Sort so the time axis follows the file names
# (presumably chronological, as the names embed a date/time stamp -- verify).
SMAPL4_file_path_test = sorted(glob.glob(rf'{input_path}/{SMAPL4_path}/{SMAPL4_fn_pattern_test}'))
# TODO/IMPROVEMENT #3: open_mfdataset(parallel=True) is not really making things super fast. Need to optimize Clients
# Load data
ds_SMAPL4_3hrly_test = xr.open_mfdataset(SMAPL4_file_path_test, group='Geophysical_Data', engine="rasterio", preprocess=_preprocess_SMAPL4, chunks=chunks, combine='nested', concat_dim='time', parallel=True)
# Takes approx 10 min. If it is taking more than that, it is probably failing
# Fails when Client is properly connected to the dask dashboard, and succeeds when it's not ...

# # https://docs.xarray.dev/en/stable/generated/xarray.open_mfdataset.html#xarray.open_mfdataset

In [6]:
# Collect every SMAP L4 granule in the input folder and open them lazily as one dataset.
# 1 file is 3GB
chunks = {'x': 1200, 'y': 1200, 'time':1, 'band':1}
SMAPL4_fn_pattern = f'SMAP_L4_SM_gph_*.nc'
# SMAPL4_fn_pattern = f'SMAP_L4_SM_gph_{startDate.year}010*.nc' ####### CHANGE LATER: testing with 2016 Jan 1-9 data #######
# glob.glob() returns paths in arbitrary order, while combine='nested' concatenates
# the files in exactly the order given. Sort so the time axis follows the file names
# (presumably chronological, as the names embed a date/time stamp -- verify).
SMAPL4_file_paths = sorted(glob.glob(rf'{input_path}/{SMAPL4_path}/{SMAPL4_fn_pattern}'))
# TODO/IMPROVEMENT #3: open_mfdataset(parallel=True) is not really making things super fast. Need to optimize Clients

# # https://docs.xarray.dev/en/stable/generated/xarray.open_mfdataset.html#xarray.open_mfdataset

# Load data
ds_SMAPL4_3hrly = xr.open_mfdataset(SMAPL4_file_paths, group='Geophysical_Data', engine="rasterio", preprocess=_preprocess_SMAPL4, chunks=chunks, combine='nested', concat_dim='time', parallel=True)
# Takes approx 10 min. If it is taking more than that, it is probably failing
# Fails when Client is properly connected to the dask dashboard, and succeeds when it's not ...

In [8]:
# Show the lazily loaded 3-hourly precipitation variable (shape, chunking)
ds_SMAPL4_3hrly['precipitation_total_surface_flux']

Unnamed: 0,Array,Chunk
Bytes,535.41 GiB,5.49 MiB
Shape,"(20457, 1, 1822, 3856)","(1, 1, 1200, 1200)"
Dask graph,163656 chunks in 61372 graph layers,163656 chunks in 61372 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 535.41 GiB 5.49 MiB Shape (20457, 1, 1822, 3856) (1, 1, 1200, 1200) Dask graph 163656 chunks in 61372 graph layers Data type float32 numpy.ndarray",20457  1  3856  1822  1,

Unnamed: 0,Array,Chunk
Bytes,535.41 GiB,5.49 MiB
Shape,"(20457, 1, 1822, 3856)","(1, 1, 1200, 1200)"
Dask graph,163656 chunks in 61372 graph layers,163656 chunks in 61372 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [9]:
# Replace the x/y coordinates of the SMAP L4 stack with those read from a
# template granule; the template's y values are multiplied by -1.
SMAPL4_template_fn = r"G:\Araki\SMSigxSMAP\1_data\SPL4SMGP\SMAP_L4_SM_gph_20180911T103000_Vv7032_001_HEGOUT.nc"
SMAPL4_template = xr.open_dataset(SMAPL4_template_fn)
ds_SMAPL4_3hrly = ds_SMAPL4_3hrly.assign_coords(
    x=SMAPL4_template['x'][:],
    y=SMAPL4_template['y'][:]*(-1),
)

# Quick visual check (disabled):
# ax = plt.axes(projection=ccrs.PlateCarree())
# ax.coastlines()
# ds_SMAPL4.precipitation_total_surface_flux.sel(time='2016-01-01 01:30:00').plot(ax=ax)
# ds_SMAPL4.precipitation_total_surface_flux

In [10]:
# Clip the SMAP L4 stack to the study-area bounding box
ds_SMAPL4_3hrly = ds_SMAPL4_3hrly.sel({'x': slice(minx, maxx), 'y': slice(miny, maxy)}).copy()

if plot_results:
    # Map of one 3-hourly precipitation field with coastlines for orientation
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.coastlines(color='white')
    ds_SMAPL4_3hrly['precipitation_total_surface_flux'].sel(time='2016-01-01 01:30:00').plot(ax=ax)
    ds_SMAPL4_3hrly

### Read SMAP L3 data

In [11]:
def _preprocess_SMAPL3(ds):
    # Assign missing time dimension
    # Doesn't care about hour amd minutes, as it is daily data
    startTime = datetime.strptime(ds.rangeBeginningDateTime.split("T")[0], '%Y-%m-%d')
    ds = ds.assign_coords(time=startTime)
    return ds

In [12]:
# Collect every SMAP L3 granule in the input folder and open them lazily as one dataset.
# Test with 2016 Jan 1-9 data first
chunks = {'x': 1200, 'y': 1200, 'time':1, 'band':1}
SMAPL3_fn_pattern = f'SMAP_L3_SM_P_E_*.nc'
# SMAPL3_fn_pattern = f'SMAP_L3_SM_P_E_{startDate.year}01*.nc' ####### CHANGE LATER: testing with 2016 Jan 1-9 data #######
# glob.glob() returns paths in arbitrary order, while combine="nested" concatenates
# the files in exactly the order given. The dS/dt calculation further down takes
# differences along time, so the files must be concatenated in a deterministic order.
# Sorting by name is presumably chronological (names embed the date -- verify).
SMAPL3_file_paths = sorted(glob.glob(rf'{input_path}/{SMAPL3_path}/{SMAPL3_fn_pattern}'))
# Load data
ds_SMAPL3 = xr.open_mfdataset(SMAPL3_file_paths, preprocess=_preprocess_SMAPL3, engine="rasterio", chunks=chunks, combine="nested", concat_dim="time", parallel=True)

# Takes approx. 1min/yr x 8yr = 8min

In [13]:
# Clip SMAP L3 to the study-area bounding box (y is sliced from maxy down to
# miny) and tag the dataset with EPSG:4326.
ds_SMAPL3 = ds_SMAPL3.sel({'x': slice(minx, maxx), 'y': slice(maxy, miny)})
ds_SMAPL3.rio.write_crs('epsg:4326', inplace=True)
# 3.3 sec for 1 mo of data

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.48 GiB 308.97 kiB Shape (2520, 1, 360, 437) (1, 1, 181, 437) Dask graph 5040 chunks in 7562 graph layers Data type float32 numpy.ndarray",2520  1  437  360  1,

Unnamed: 0,Array,Chunk
Bytes,1.48 GiB,308.97 kiB
Shape,"(2520, 1, 360, 437)","(1, 1, 181, 437)"
Dask graph,5040 chunks in 7562 graph layers,5040 chunks in 7562 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [14]:
if plot_results:
    # Map of the soil moisture field for one day
    ds_SMAPL3['soil_moisture'].sel(time='2016-01-03').plot()
    ds_SMAPL3['soil_moisture']
# TODO/IMPROVEMENT: Add dropna(how=all) somewhere to skip calculation of the ocean etc.

### Read Singer PET data

In [15]:
# Collect the daily PET files
PET_fn_pattern = f'*_daily_pet.nc'
# glob.glob() returns paths in arbitrary order, while combine="nested" concatenates
# the files in exactly the order given; sort so the order is deterministic
# (presumably chronological if the file-name prefix is the year -- verify).
PET_file_paths = sorted(glob.glob(rf'{input_path}/{PET_path}/{PET_fn_pattern}'))

# Load data lazily as one dataset concatenated along time
ds_PET = xr.open_mfdataset(PET_file_paths, combine="nested", chunks=pet_chunks, concat_dim="time", parallel=True)
ds_PET['pet']

Unnamed: 0,Array,Chunk
Bytes,61.76 GiB,13.96 MiB
Shape,"(2557, 1801, 3600)","(366, 100, 100)"
Dask graph,4788 chunks in 15 graph layers,4788 chunks in 15 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 61.76 GiB 13.96 MiB Shape (2557, 1801, 3600) (366, 100, 100) Dask graph 4788 chunks in 15 graph layers Data type float32 numpy.ndarray",3600  1801  2557,

Unnamed: 0,Array,Chunk
Bytes,61.76 GiB,13.96 MiB
Shape,"(2557, 1801, 3600)","(366, 100, 100)"
Dask graph,4788 chunks in 15 graph layers,4788 chunks in 15 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [16]:
# Clip PET to the study-area bounding box (axes renamed to x/y first)
ds_PET = ds_PET.rename(latitude='y', longitude='x')
ds_PET = ds_PET.sel({'x': slice(minx, maxx), 'y': slice(maxy, miny)}).copy()

# Linearly interpolate PET onto the SMAP L3 x/y grid and store it next to the
# soil moisture variables
ds_PET.rio.write_crs('epsg:4326', inplace=True)
PET_resampled = ds_PET['pet'].interp(
    coords={'x': ds_SMAPL3['x'], 'y': ds_SMAPL3['y']},
    method='linear',
    kwargs={'fill_value': np.nan},
)
ds_SMAPL3['PET'] = PET_resampled

# Side-by-side check: original PET (left) vs. PET on the SMAP grid (right)
if plot_results:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    for pet_field, pet_axis in ((ds_PET['pet'], ax1), (PET_resampled, ax2)):
        pet_field.sel(time='2016-01-01').plot(vmax=6, ax=pet_axis)

# See: https://docs.xarray.dev/en/stable/generated/xarray.DataArray.interp.html
# If reprojection is needed, see: https://github.com/corteva/rioxarray/issues/119

## Processing data

### Get daily mean values

In [17]:
# SMAP L4: daily mean precipitation, reprojected to match the SMAP L3 grid.
#
# NOTE(review): this cell used to begin with
#     ds_SMAPL4 = ds_SMAPL4_3hrly.chunk(chunks={'x': 50, 'y': 50})
# but the very next statement overwrote ds_SMAPL4 with a result computed from
# the un-rechunked ds_SMAPL4_3hrly, so the rechunk never took effect. The dead
# assignment has been removed; results and chunking are unchanged. If 50x50
# spatial chunks are actually wanted, call .chunk(...) on ds_SMAPL4_3hrly in
# the chain below and check the resulting task-graph size first.
ds_SMAPL4 = ds_SMAPL4_3hrly.precipitation_total_surface_flux.resample(time='D', skipna=True, keep_attrs=True).mean('time')
ds_SMAPL4.rio.write_crs('epsg:4326', inplace=True)
ds_SMAPL4 = ds_SMAPL4.sel(band=1).rio.reproject_match(ds_SMAPL3)
# The 3-hourly stack is no longer needed after this point
del ds_SMAPL4_3hrly

# Takes 20min

Task exception was never retrieved
future: <Task finished name='Task-677559' coro=<Client._gather.<locals>.wait() done, defined at c:\Users\raraki8159\.conda\envs\SMAP_v2\Lib\site-packages\distributed\client.py:2134> exception=AllExit()>
Traceback (most recent call last):
  File "c:\Users\raraki8159\.conda\envs\SMAP_v2\Lib\site-packages\distributed\client.py", line 2143, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name='Task-677541' coro=<Client._gather.<locals>.wait() done, defined at c:\Users\raraki8159\.conda\envs\SMAP_v2\Lib\site-packages\distributed\client.py:2134> exception=AllExit()>
Traceback (most recent call last):
  File "c:\Users\raraki8159\.conda\envs\SMAP_v2\Lib\site-packages\distributed\client.py", line 2143, in wait
    raise AllExit()
distributed.client.AllExit


In [None]:
# Rechunk SMAP L3 into 50 x 50 tiles along x/y, then show the soil moisture variable
ds_SMAPL3 = ds_SMAPL3.chunk({'x': 50, 'y': 50})
ds_SMAPL3['soil_moisture']

In [None]:
# SMAP L3: keep only retrievals whose quality flag is 0 or 8, separately for
# the AM and PM fields, then average the two into one daily value.
am_flag = ds_SMAPL3['retrieval_qual_flag']
pm_flag = ds_SMAPL3['retrieval_qual_flag_pm']
ds_SMAPL3['soil_moisture_am_masked'] = ds_SMAPL3['soil_moisture'].where((am_flag == 0) | (am_flag == 8))
ds_SMAPL3['soil_moisture_pm_masked'] = ds_SMAPL3['soil_moisture_pm'].where((pm_flag == 0) | (pm_flag == 8))

# Stack AM/PM along a temporary dimension; the mean skips NaN, so a day with
# only one valid overpass keeps that value
masked_names = ['soil_moisture_am_masked', 'soil_moisture_pm_masked']
stacked_data = ds_SMAPL3[masked_names].to_array(dim='new_dim')
ds_SMAPL3['soil_moisture_daily'] = stacked_data.mean(skipna=True, dim="new_dim")

In [None]:
if plot_results:
    # Left: daily mean precipitation; right: daily soil moisture (same day)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    for daily_field, daily_axis in ((ds_SMAPL4, ax1), (ds_SMAPL3['soil_moisture_daily'], ax2)):
        daily_field.sel(time='2016-01-01').plot(ax=daily_axis)

### Calculate dS/dt

In [None]:
# Precipitation mask: keep values below the threshold, NaN everywhere else
# https://geohackweek.github.io/nDarrays/09-masking/
below_precip_thresh = ds_SMAPL4 < precip_thresh
precip_mask = ds_SMAPL4.where(below_precip_thresh)

# Where precipitation is present AND no soil moisture record exists, the
# drydown pattern is disrupted and dS must not be calculated. Those records
# receive an extremely large dummy value (9999) so that the resulting dS can
# be recognised and dropped afterwards.
precip_masked_out = precip_mask.isnull()
sm_missing = ds_SMAPL3['soil_moisture_daily'].isnull()
no_sm_record_but_precip_present = ds_SMAPL4.where(precip_masked_out & sm_missing)
keep_original_sm = no_sm_record_but_precip_present.isnull()
ds_SMAPL3['sm_for_dS_calc'] = ds_SMAPL3['soil_moisture_daily'].where(keep_original_sm, 9999)

# Debug prints (disabled):
# print(precip_mask.sel(x=sample_x, y=sample_y, method='nearest').values.T)
# print(ds_SMAPL3['soil_moisture_daily'].sel(x=sample_x, y=sample_y, method='nearest').values.T)
# print(no_sm_record_but_precip_present.sel(x=sample_x, y=sample_y, method='nearest').values.T)
# print(sm_for_dS_calc.sel(x=sample_x, y=sample_y, method='nearest').values.T)

In [None]:
# Calculate dS
# Gaps are back-filled (at most 5 consecutive steps) before differencing along
# time; the difference at a time step is then kept only where the original
# (unfilled) record one step earlier was not null.
ds_SMAPL3['dS'] = ds_SMAPL3['sm_for_dS_calc'].bfill(dim="time", limit=5).diff(dim="time").where(ds_SMAPL3['sm_for_dS_calc'].notnull().shift(time=+1))

# Drop the dS where (precipitation is present) && (soil moisture record does not exist):
# differences involving the 9999 dummy value fall outside (-1, 1) and become NaN
ds_SMAPL3['dS'] = ds_SMAPL3['dS'].where((ds_SMAPL3['dS'] > -1) & (ds_SMAPL3['dS'] < 1))

# Calculate dt
# NOTE: despite its name, `non_nulls` is the running count of *null* records along time
non_nulls = ds_SMAPL3['sm_for_dS_calc'].isnull().cumsum(dim='time')
# NOTE(review): presumably the number of steps to the next valid record -- confirm the +1 offsets
nan_length = non_nulls.where(ds_SMAPL3['sm_for_dS_calc'].notnull()).bfill(dim="time")+1 - non_nulls +1
# dt takes the value above at null records and is 1 everywhere else
ds_SMAPL3['dt'] = nan_length.where(ds_SMAPL3['sm_for_dS_calc'].isnull()).fillna(1)

# Calculate dS/dt
ds_SMAPL3['dSdt'] = ds_SMAPL3['dS']/ds_SMAPL3['dt']
# Move each value one step back in time, so the change ending at t+1 is stored at t
ds_SMAPL3['dSdt'] = ds_SMAPL3['dSdt'].shift(time=-1)

if plot_results:
    ds_SMAPL3['dSdt'].sel(time='2016-01-03').plot()

In [None]:
# Keep dS/dt only where precipitation on day 1 of the soil moisture
# measurement is below the threshold (i.e. where precip_mask holds a value)
precip_below_thresh = precip_mask.notnull()
ds_SMAPL3['dSdt'] = ds_SMAPL3['dSdt'].where(precip_below_thresh)

# Debug prints (disabled):
# print(ds_SMAPL3['dSdt'].sel(x=sample_x, y=sample_y, method='nearest').values.T)
# print(ds_SMAPL3_masked.sel(x=sample_x, y=sample_y, method='nearest').values.T)
# print(test.sel(x=sample_x, y=sample_y, method='nearest').values)

if plot_results:
    # NOTE(review): (-114, 34) matches the commented-out California bounding
    # box, not the active Australia one; method='nearest' will then return
    # the closest pixel available -- confirm the intended sample point.
    sample_x = -114
    sample_y = 34
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 3))
    for series_name, series_axis in (('soil_moisture_daily', ax1), ('dSdt', ax2)):
        ds_SMAPL3[series_name].sel(x=sample_x, y=sample_y, method='nearest').interp(method='linear').plot.scatter(ax=series_axis)
    ds_SMAPL4.sel(x=sample_x, y=sample_y, method='nearest').plot.scatter(ax=ax3)

## Fit regression b/w dS/dt & S for upper/lower PET quantile

### Get upper/lower PET quantile

In [None]:
# PET quantiles along time, using only time steps where precipitation is
# below the threshold. The whole time axis is put into a single chunk first.
n_time_steps = len(ds_SMAPL3['PET'].time)
ds_SMAPL3['PET'] = ds_SMAPL3['PET'].chunk({'time': n_time_steps, 'x': 'auto', 'y': 'auto'})
PET_low_precip = ds_SMAPL3['PET'].where(precip_mask.notnull())
ds_quantile = PET_low_precip.quantile(dim="time", q=[lower_quantile_thresh, upper_quantile_thresh])

if plot_results:
    # Left: lower-quantile PET; right: upper-quantile PET
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    for q_level, q_axis in ((lower_quantile_thresh, ax1), (upper_quantile_thresh, ax2)):
        ds_quantile.sel(quantile=q_level).plot(ax=q_axis, vmax=6)

In [None]:
# PET values at or above the upper quantile / at or below the lower quantile;
# everything else becomes NaN
PET_upper_bound = ds_quantile.sel(quantile=upper_quantile_thresh)
PET_lower_bound = ds_quantile.sel(quantile=lower_quantile_thresh)
ds_SMAPL3['PET_upper_mask'] = ds_SMAPL3['PET'].where(ds_SMAPL3['PET'] >= PET_upper_bound)
ds_SMAPL3['PET_lower_mask'] = ds_SMAPL3['PET'].where(ds_SMAPL3['PET'] <= PET_lower_bound)

In [None]:
if plot_results:
    # All PET values (blue) with the upper (green) and lower (red) quantile
    # subsets overlaid, at the sample pixel
    fig, ax1 = plt.subplots(1, 1, figsize=(5, 5))
    for pet_name, pet_color in (('PET', 'blue'), ('PET_upper_mask', 'green'), ('PET_lower_mask', 'red')):
        ds_SMAPL3[pet_name].sel(x=sample_x, y=sample_y, method='nearest').plot.scatter(ax=ax1, color=pet_color)

### Fit regression line

In [None]:
# Per-pixel minimum of the daily soil moisture over the whole observation period
sm_min = ds_SMAPL3['soil_moisture_daily'].min(dim="time")

In [None]:
# Regression variables: soil moisture shifted by its per-pixel minimum (x)
# and the sign-flipped dS/dt (y)
ds_SMAPL3['shifted_sm'] = ds_SMAPL3['soil_moisture_daily'] - sm_min
ds_SMAPL3['neg_dSdt'] = ds_SMAPL3['dSdt'] * (-1)

# Regression inputs: records in the upper / lower PET quantile where -dS/dt is positive
positive_neg_dSdt = ds_SMAPL3['neg_dSdt'] > 0
input_sm_upper = ds_SMAPL3.where((ds_SMAPL3['PET_upper_mask'].notnull()) & positive_neg_dSdt)
input_sm_lower = ds_SMAPL3.where((ds_SMAPL3['PET_lower_mask'].notnull()) & positive_neg_dSdt)

if plot_results:
    # Original (left) vs. shifted (right) soil moisture at the sample pixel
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    for sm_name, sm_axis in (('soil_moisture_daily', ax1), ('shifted_sm', ax2)):
        ds_SMAPL3[sm_name].sel(x=sample_x, y=sample_y, method='nearest').interp(method='linear').plot(ax=sm_axis)
    print(sm_min.sel(x=sample_x, y=sample_y, method='nearest').values)

In [None]:
if plot_results:
    # PET of the records selected for the upper (left) and lower (right) fit
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    for selected_input, input_axis in ((input_sm_upper, ax1), (input_sm_lower, ax2)):
        selected_input['PET'].sel(time='2016-01-01').plot(ax=input_axis)

In [None]:
def fit_regression_through_origin(input_sm):
    """Fit ``neg_dSdt = a * shifted_sm`` (a line through the origin) per pixel.

    With weights w = 1/x**2 the weighted least-squares slope reduces to the
    mean of the ratios y/x:

        a = sum(y / x) / n

    Proofs in:
    https://onlinelibrary.wiley.com/doi/10.1111/1467-9639.00136
    http://sites.msudenver.edu/ngrevsta/wp-content/uploads/sites/416/2020/02/Notes_07.pdf
    https://www.jstor.org/stable/2527698?seq=2

    Parameters
    ----------
    input_sm
        Dataset with variables ``shifted_sm`` (x) and ``neg_dSdt`` (y) that
        share a ``time`` dimension. Records that must not enter the fit are
        expected to be NaN.

    Returns
    -------
    a
        Slope, reduced over ``time``. NaN where a pixel has no valid record.
    MSE
        Mean squared error of the fitted line, ``SSE / (n - 1)``, where
        ``SSE = sum((y - a*x)**2)`` over the valid records and ``n`` is the
        number of valid records of that pixel. NaN where ``n < 2``.

    Notes
    -----
    Background on regression through the origin and its error metrics:
    https://web.ist.utl.pt/~ist11038/compute/errtheory/,regression/regrthroughorigin.pdf
    https://rpubs.com/aaronsc32/regression-through-the-origin
    https://pubs.cif-ifc.org/doi/pdf/10.5558/tfc71326-3
    https://dynamicecology.wordpress.com/2017/04/13/dont-force-your-regression-through-zero-just-because-you-know-the-true-intercept-has-to-be-zero/
    """
    x = input_sm.shifted_sm
    y = input_sm.neg_dSdt

    # x == 0 would turn y/x into +/-inf and poison the mean, so such records
    # are excluded together with the NaN ones.
    ratio = (y / x).where(x > 0)
    valid = ratio.notnull()

    # Number of records that actually enter the fit, per pixel
    n = valid.sum(dim="time")
    a = ratio.sum(dim="time", skipna=True) / n

    # Residuals are taken from the function's own input (not from a global
    # dataset) and only over the records used for the slope.
    residual = (y - a * x).where(valid)
    SSE = (residual ** 2).sum(dim="time", skipna=True)

    # One parameter is estimated, hence n - 1 degrees of freedom; undefined
    # for pixels with fewer than two valid records.
    MSE = (SSE / (n - 1)).where(n > 1)

    return a, MSE


In [None]:
# Fit the through-origin regression separately for the upper- and
# lower-PET-quantile inputs
upper_fit = fit_regression_through_origin(input_sm_upper)
lower_fit = fit_regression_through_origin(input_sm_lower)
a_upper, MSE_upper = upper_fit
a_lower, MSE_lower = lower_fit

In [None]:
if plot_results:
    # Top row: slopes (upper, lower); bottom row: MSE (upper, lower)
    fig, ((ax1, ax2),(ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    for slope_map, slope_axis in ((a_upper, ax1), (a_lower, ax2)):
        slope_map.sel(band=1).plot(ax=slope_axis, vmin=-1, vmax=1)
    for mse_map, mse_axis in ((MSE_upper, ax3), (MSE_lower, ax4)):
        mse_map.sel(band=1).plot(ax=mse_axis, vmin=0)

In [None]:
# Difference between the slope fitted on upper-PET-quantile records and the
# slope fitted on lower-PET-quantile records
a_diff = a_upper - a_lower
if plot_results:
    a_diff.plot()

In [None]:
# Extract and print the series and fit results of one sample pixel
if plot_results:
    # NOTE(review): (-120, 35.3) matches the commented-out California bounding
    # box, not the active Australia one -- confirm the intended sample point.
    sample_x = -120
    sample_y = 35.3
    S = ds_SMAPL3.shifted_sm.sel(x=sample_x, y=sample_y, method='nearest').values
    dSdt = ds_SMAPL3.neg_dSdt.sel(x=sample_x, y=sample_y, method='nearest').values
    PET = ds_SMAPL3.PET.sel(x=sample_x, y=sample_y, method='nearest').values
    S_min = sm_min.sel(x=sample_x, y=sample_y, method='nearest').values
    a_upper_sel = a_upper.sel(x=sample_x, y=sample_y, method='nearest').values
    a_lower_sel = a_lower.sel(x=sample_x, y=sample_y, method='nearest').values
    a_diff_sel = a_diff.sel(x=sample_x, y=sample_y, method='nearest').values

    print(S.T)
    print(dSdt.T)
    # PET is extracted above but was never printed (dSdt was printed twice)
    print(PET.T)
    print(S_min)
    print(a_upper_sel)
    print(a_lower_sel)
    print(a_diff_sel)


In [None]:
if plot_results:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: shifted soil moisture vs. -dS/dt, coloured by PET, with the
    # two fitted lines through the origin (red: lower PET, blue: upper PET)
    sc1 = ax1.scatter(S, dSdt, c=PET)
    plt.colorbar(sc1)
    x = np.linspace(0, np.nanmax(S),100)
    y_lower = a_lower_sel*x # (x+S_min)
    y_upper = a_upper_sel*x
    for fitted_line, line_format in ((y_lower, '-r'), (y_upper, '-b')):
        ax1.plot(x, fitted_line, line_format)

    # Right panel: the same on the unshifted soil moisture axis; the lines
    # are offset by slope * S_min
    sc = ax2.scatter(S+S_min, dSdt, c=PET)
    plt.colorbar(sc)
    x = np.linspace(0, np.nanmax(S+S_min),100)
    y_lower = a_lower_sel*x - a_lower_sel*S_min
    y_upper = a_upper_sel*x - a_upper_sel*S_min
    for fitted_line, line_format in ((y_lower, '-r'), (y_upper, '-b')):
        ax2.plot(x, fitted_line, line_format)

In [None]:

# Two pixels to compare: one labelled low a_diff and one labelled high a_diff.
# NOTE(review): these coordinates match the commented-out California bounding
# box, not the active Australia one -- confirm the intended sample points.
low_diff_sample_x, low_diff_sample_y = -118, 37.6
high_diff_sample_x, high_diff_sample_y = -120, 35.8

# Nearest grid cell to each requested point
low_diff_sample = ds_SMAPL3.sel({'x': low_diff_sample_x, 'y': low_diff_sample_y}, method='nearest')
high_diff_sample = ds_SMAPL3.sel({'x': high_diff_sample_x, 'y': high_diff_sample_y}, method='nearest')
low_diff_sample



In [None]:
# Time series at the two sample pixels: daily soil moisture (left) and
# -dS/dt (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
for series_name, series_axis in (('soil_moisture_daily', ax1), ('neg_dSdt', ax2)):
    for pixel, pixel_label in ((high_diff_sample, 'high a_diff'), (low_diff_sample, 'low a_diff')):
        pixel[series_name].plot(ax=series_axis, label=pixel_label)
fig.legend()

In [None]:
# Plain arrays for the scatter plots: suffix 1 = high a_diff pixel,
# suffix 2 = low a_diff pixel
x1 = high_diff_sample.soil_moisture_daily.values
x2 = low_diff_sample.soil_moisture_daily.values
y1 = high_diff_sample.neg_dSdt.values
y2 = low_diff_sample.neg_dSdt.values
z1 = high_diff_sample.PET.values
# The low a_diff panel must be coloured by the PET of the low a_diff pixel
# (previously this copied the high a_diff pixel's PET)
z2 = low_diff_sample.PET.values


In [None]:
# S vs. dS/dt scatter for the high (left) and low (right) a_diff pixels,
# coloured by PET
fig, (ax3, ax4) = plt.subplots(1, 2, figsize=(15, 5),sharex = True,sharey=True)

scatter_panels = (
    (ax3, x1, y1, z1, 'high a_diff', high_diff_sample_x, high_diff_sample_y),
    (ax4, x2, y2, z2, 'low a_diff', low_diff_sample_x, low_diff_sample_y),
)
scatter_handles = []
for panel_axis, panel_x, panel_y, panel_c, panel_label, pixel_x, pixel_y in scatter_panels:
    scatter_handles.append(panel_axis.scatter(panel_x, panel_y, c=panel_c, label=panel_label, cmap='viridis'))
    panel_axis.set_ylim([0,0.075])
    panel_axis.set_title(f'{panel_label} at x={pixel_x}, y={pixel_y}')
    panel_axis.set_xlabel('S')
    panel_axis.set_ylabel('dS/dt')
high_diff_plot, low_diff_plot = scatter_handles

# One colour bar per panel
cbar3 = plt.colorbar(high_diff_plot)
cbar3.set_label('PET', rotation=270)
cbar4 = plt.colorbar(low_diff_plot)
cbar4.set_label('PET', rotation=270)
plt.show()


In [None]:
# Per-pixel means over time: daily soil moisture and daily mean precipitation
meanSM = ds_SMAPL3['soil_moisture_daily'].mean(dim="time")
meanP = ds_SMAPL4.mean(dim="time")

# Save results

In [None]:
# Collect all per-pixel results in one dataset.
# DataArray.drop() is deprecated; drop_vars() is the supported way to remove
# the scalar 'quantile' coordinate.
results = xr.Dataset({'a_diff': a_diff.sel(band=1),
                      'a_upper': a_upper.sel(band=1), 'a_lower': a_lower.sel(band=1),
                      'MSE_upper': MSE_upper.sel(band=1), 'MSE_lower': MSE_lower.sel(band=1),
                      'PET_upper': ds_quantile.sel(quantile=upper_quantile_thresh).drop_vars('quantile'),
                      'PET_lower': ds_quantile.sel(quantile=lower_quantile_thresh).drop_vars('quantile'),
                      'meanSM': meanSM.sel(band=1),
                      'meanP': meanP})
# NOTE(review): write_crs() is called without inplace=True and its return value
# is not kept, so `results` itself is not modified here; "spatial_ref" is also
# dropped on the next line. Confirm whether the saved file is meant to carry a CRS.
results.rio.write_crs('epsg:4326')
# Remove leftover coordinates/variables that should not be saved
results = results.drop_vars(["band", "/crs", "projection_information", "quantile", "spatial_ref"])

In [None]:
# Compute the results and write them to a NetCDF file
out_path = r'G:\Araki\SMSigxSMAP\3_data_out\a_diff_202303'
out_file = os.path.join(out_path, 'results_20230318_australia.nc')
results.load().to_netcdf(out_file)

In [None]:
# Log the wall-clock time at which the notebook finished
finished_at = datetime.now()
print(f'Finished running at {finished_at}')