In [1]:
import xarray as xr
import numpy as np
import pandas as pd

import dask
import dask.array as da

import matplotlib.pyplot as plt

# Get test data

Using 10 years and 4 grid cells of data

In [2]:
pr_file = r'D://data/nclimgrid_daily/prcp_nClimGridDaily_1951-2024_USsouth.nc'
tmax_file = r'D://data/nclimgrid_daily/tmax_nClimGridDaily_1951-2024_USsouth.nc'

In [3]:
year_start='1951'
year_end='2024'
lat1, lat2 = 32, 34
lon1, lon2 = -90, -88

In [None]:
# pr = xr.open_dataset(pr_file, chunks=-1).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
# # pr = xr.open_dataset(pr_file, chunks={'time':10,'lat':-1,'lon':-1}).prcp.sel(time=slice(year_start, year_end))#,lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
# pr = (pr / 25.4).round(2)  # convert to inches
# pr

In [None]:
# pr.isel(time=0).plot()
# mask=np.isfinite(pr.isel(time=0))
# mask.sum().compute()

In [None]:
# tmax = xr.open_dataset(tmax_file, chunks=-1).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
# tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit
# tmax

In [None]:
# %%time
# pr = pr.compute()
# tmax = tmax.compute()

# METHOD: xarray apply_ufunc to individual cells

In [None]:
def kbdi_single_grid(tmax_1d, pr_1d):
    # Ensure inputs are NumPy arrays
    T = np.asarray(tmax_1d)
    PR = np.asarray(pr_1d)

    if not np.all(np.isfinite(PR)):
        return np.full_like(PR, np.nan)

    # Create a time index
    time_index = np.arange(len(PR),dtype='float32')

    # 7-day rolling precipitation sum
    ndays = 7
    pr_thresh = 8.0  # inches
    pr_weeksum = np.convolve(PR, np.ones(ndays), mode='valid').astype('float32')
    pr_weeksum = np.concatenate([np.full(ndays - 1, np.nan), pr_weeksum]).astype('float32')

    try:
        day_int = np.where(pr_weeksum > pr_thresh)[0][0].astype('int32')
    except IndexError:
        return np.full_like(PR, np.nan)

    # Rain mask and consecutive rain days
    rainmask = np.where(PR > 0, 1, 0).astype('int32')
    cumsum = np.cumsum(rainmask)
    reset = np.where(rainmask == 0, cumsum, np.nan)
    reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset))
    rr = cumsum - reset

    # Categorize rain days
    cat = np.where(rr >= 3, 5, rr)
    consec_day2 = np.where(rr == 2)[0].astype('int32')
    consec_day1 = consec_day2 - 1
    cat[consec_day2] = 5
    cat[consec_day1] = 5
    cat = np.where(cat == 5, 2, cat)

    # Pnet calculation
    acc_thresh = 0.2  # inches
    pnet = np.copy(PR)
    pnet[cat == 1] = pnet[cat == 1] - acc_thresh
    pnet = np.where(pnet < 0, 0, pnet)

    # Adjust for consecutive rain days
    consec_inds = np.where(cat == 2)[0].astype('int32')
    accpr = 0.0
    thresh_flag = False

    for i, ind in enumerate(consec_inds):
        accpr += PR[ind]
        if accpr <= acc_thresh and not thresh_flag:
            pnet[ind] = 0
        elif accpr > acc_thresh and not thresh_flag:
            accpr -= acc_thresh
            pnet[ind] = accpr
            thresh_flag = True
        else:
            pnet[ind] = PR[ind]
        if i != len(consec_inds) - 1 and consec_inds[i + 1] != consec_inds[i] + 1:
            accpr = 0.0
            thresh_flag = False

    # Mean annual precipitation (approximate)
    days_per_year = 365
    n_years = len(PR) // days_per_year
    ann_pr = []
    for i in range(n_years):
        year_data = PR[i * days_per_year:(i + 1) * days_per_year]
        if np.count_nonzero(~np.isnan(year_data)) >= 360:
            ann_pr.append(np.nansum(year_data))
    if len(ann_pr) == 0:
        return np.full_like(PR, np.nan)
    mean_ann_pr = np.mean(ann_pr)

    # KBDI calculation
    KBDI = np.full_like(PR, np.nan)
    if day_int < len(PR):
        KBDI[day_int] = 0
    else:
        return KBDI

    denominator = 1 + 10.88 * np.exp(-0.0441 * mean_ann_pr)
    for it in range(day_int + 1, len(PR)):
        Q = max(0, KBDI[it - 1] - pnet[it] * 100)
        numerator = (800 - Q) * (0.968 * np.exp(0.0486 * T[it]) - 8.3)
        KBDI[it] = Q + (numerator / denominator) * 1e-3
    
    return KBDI

### xarray inputs in memory

In [None]:
# lazy
pr = xr.open_dataset(pr_file, chunks=-1).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
pr = (pr / 25.4).round(2)  # convert to inches

tmax = xr.open_dataset(tmax_file, chunks=-1).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit

In [None]:
%%time
# compute, hold in memory
pr = pr.compute()
tmax = tmax.compute()

In [None]:
%%time

# Apply to all grid points

kbdi = xr.apply_ufunc(
    kbdi_single_grid,
    tmax,
    pr,
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[["time"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[float],
)

***54s for data inputs in unchunked xarray data arrays of shape (26907, 24, 24), 59MB*** This is about 1/500th of the grid, so this is too slow. We need an order of magnitude faster

***3m 18s for data inputs in chunked xarray data arrays of shape (26907, 48, 48), 236MB (59x4)*** It is not scaling linearly


### dask inputs lazy


In [None]:
del pr,tmax,kbdi

In [None]:
# lazy
pr = xr.open_dataset(pr_file, chunks=-1).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
pr = (pr / 25.4).round(2)  # convert to inches

tmax = xr.open_dataset(tmax_file, chunks=-1).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit

In [None]:
# lazy
# Apply to all grid points

kbdi = xr.apply_ufunc(
    kbdi_single_grid,
    tmax,
    pr,
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[["time"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[float],
)

In [None]:
%%time
kbdi.compute()

***46s for data inputs in a single chunk dask array of shape (26907, 24, 24), 59MB*** 

***3m 9s for data inputs in a single chunk dask array of shape (26907, 48, 48), 236MB*** this scales approxiately linearly

### the same method but using LocalCluster

In [None]:
del pr,tmax,kbdi

In [None]:
from dask.distributed import Client,LocalCluster

nworkers=20
cluster=LocalCluster(n_workers=nworkers,threads_per_worker=1) # a cluster where each thread is a separate process or "worker"
client=Client(cluster)  # connect to your compute cluster
client.wait_for_workers(n_workers=nworkers,timeout=10) # wait up to 10s for the cluster to be fully ready, error if not ready in 10s
client # print info

In [None]:
# lazy
pr = xr.open_dataset(pr_file, chunks=-1).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
pr = (pr / 25.4).round(2)  # convert to inches

tmax = xr.open_dataset(tmax_file, chunks=-1).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit

In [None]:
# lazy
# Apply to all grid points

kbdi = xr.apply_ufunc(
    kbdi_single_grid,
    tmax,
    pr,
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[["time"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[float],
)

In [None]:
%%time
kbdi.compute()

***untested for data inputs in a single chunk dask array of shape (26907, 24, 24), 59MB*** 

***memory crash for data inputs in a single chunk dask array of shape (26907, 48, 48), 236MB*** 

# METHOD: dask delayed

This method sends chunks to dask delayed and performs vectorized compute on each chunk. Everything that can be computed outside of the time loop should be in a separate function.

In [4]:
# units conversion functions
def prep_pr(p):
    p = (p / 25.4).round(2)
    p.coords['time_index']=('time',time_index)
    return p

def convert_t_units(t):
    t = ((t * 9 / 5) + 32).round(2)
    return t

In [5]:
# mean annual precip function
def mean_ann_precip(p):
    mean_ann_pr = p.groupby('time.year').sum(min_count=360).mean('year')
    return mean_ann_pr

In [6]:
# initialization function
def init_date(ndays,thresh,p):

    # find number of grids with data
    mask=xr.where(np.isfinite(p.mean('time')),1,0)
    ngrids = mask.sum().data

    # rolling sum
    p_rollsum=p.rolling(time=ndays,min_periods=ndays,center=False).sum('time')    

    # quantify how many grids are never saturated
    threshmask = xr.where((p_rollsum>=thresh).sum('time')>0,1,0) # 1=init date found, 0=no init date found
    nbad = xr.where((mask)&(threshmask==0),1,0).sum().data # 1=grid on land with data but no init date found
    
    if nbad == 0:      
        return xr.where(p_rollsum>thresh,p_rollsum.time_index,np.nan).min('time')
        # return p_rollsum[p_rollsum>thresh].isel(time=0).time_index.item()
    else:
        return nbad        

In [7]:
# rain category function

def rain_cat(p):

    # Rain mask and consecutive rain days
    # rainmask = np.where(p > 0, 1, 0).astype('int32')
    # cumsum = np.cumsum(rainmask, axis=0)
    # reset = np.where(rainmask == 0, cumsum, np.nan)
    # reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset))
    # rr = cumsum - reset

    # rainmask=xr.where(p>0,1,0).astype('int32')
    # rr=rainmask.cumsum('time')-rainmask.cumsum('time').where(rainmask == 0).ffill(dim='time').fillna(0)    
    rainmask=xr.where(p>0,1,0).astype('int8')
    temp=rainmask.cumsum('time').astype('int8')
    rr=temp-temp.where(rainmask == 0).ffill(dim='time').fillna(0).astype('int8')

    cat = np.where(rr >= 3, 5, rr)
    consec_day2 = np.where(rr.data == 2)
    consec_day2 = [arr.astype('int32') for arr in consec_day2]
    consec_day1 = [arr-1 for arr in consec_day2]
    cat[consec_day2] = 5
    cat[consec_day1] = 5
    cat = np.where(cat == 5, 2, cat)
    # # Categorize rain days
    # cat = np.where(rr >= 3, 5, rr)
    # consec_day2 = np.where(rr == 2)[0].astype('int32')
    # consec_day1 = consec_day2 - 1
    # cat[consec_day2] = 5
    # cat[consec_day1] = 5
    # cat = np.where(cat == 5, 2, cat)
    return cat

In [8]:
# pnet function

In [9]:
# kbdi function

In [10]:
%%time
# main code

# %%time these
#1) get one cell for the time dim, create time index, delay it
time = xr.open_dataset(pr_file).time.sel(time=slice(year_start, year_end))
time_index = np.arange(len(time),dtype='float32')
timeind_delay = dask.delayed(time_index)

#2) lazy read pr into chunked object
chunks = {'time':-1,'lat':24,'lon':24}
pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)

#3) call units function
pr_inch = prep_pr(pr)
pr_inch = pr_inch.compute()
# pr_inch
# %%time these
#5) call delayed initialization function
day_int = init_date(7,8,pr_inch)

#6) call dealyed rain category function (time this because it may be faster to start at day_int)
cat = rain_cat(pr_inch)#.compute()
cat

#7) call delayed pnet function
# clean up
# lazy read t into chunk object
# tmax = xr.open_dataset(tmax_file, chunks=chunks).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))
# 
#8) call delayed kbdi function


CPU times: total: 10.2 s
Wall time: 10.1 s


array([[[2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        ...,
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2]],

       [[2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        ...,
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2]],

       [[2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        ...,
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2]],

       ...,

       [[2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        ...,
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2]],

       [[2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        [2, 2, 2, ..., 2, 2, 2],
        ...,
        [2, 2, 2, ..., 

In [11]:
# rainmask=xr.where(pr_inch>0,1,0).astype('int8')
# rr=rainmask.cumsum('time')-rainmask.cumsum('time').where(rainmask == 0).ffill(dim='time').fillna(0)

rainmask=xr.where(pr_inch>0,1,0).astype('int8')
temp=rainmask.cumsum('time').astype('int8')
rr=temp-temp.where(rainmask == 0).ffill(dim='time').fillna(0).astype('int8')
rr


In [None]:
%%time
# rr = rr.compute()
# rr

In [12]:
%%time
# input to this needs to be in-memory xarray object rr, less than 2s to compute, 45 with dask arrays
cat = np.where(rr >= 3, 5, rr)
consec_day2 = np.where(rr.data == 2)
consec_day2 = [arr.astype('int32') for arr in consec_day2]
consec_day1 = [arr-1 for arr in consec_day2]
cat[consec_day2] = 5
cat[consec_day1] = 5
cat = np.where(cat == 5, 2, cat)

CPU times: total: 1.75 s
Wall time: 1.75 s


In [None]:
%%time
consec_day2 = np.where(rr.data == 2)

In [None]:
%%time
consec_day2 = [arr.astype('int32') for arr in consec_day2]


In [None]:
%%time
consec_day1 = [arr-1 for arr in consec_day2]


In [None]:
%%time
cat[consec_day2] = 5


In [None]:
%%time
cat[consec_day1] = 5


In [None]:
%%time
cat = np.where(cat == 5, 2, cat)

In [None]:
%%time
# Pnet calculation
acc_thresh = 0.2  # inches
pnet = np.copy(pr_inch)
pnet[cat == 1] = pnet[cat == 1] - acc_thresh
pnet = np.where(pnet < 0, 0, pnet)
pnet
# # Adjust for consecutive rain days
# consec_inds = np.where(cat == 2)[0].astype('int32')
# accpr = 0.0
# thresh_flag = False

# for i, ind in enumerate(consec_inds):
#     accpr += PR[ind]
#     if accpr <= acc_thresh and not thresh_flag:
#         pnet[ind] = 0
#     elif accpr > acc_thresh and not thresh_flag:
#         accpr -= acc_thresh
#         pnet[ind] = accpr
#         thresh_flag = True
#     else:
#         pnet[ind] = PR[ind]
#     if i != len(consec_inds) - 1 and consec_inds[i + 1] != consec_inds[i] + 1:
#         accpr = 0.0
#         thresh_flag = False

In [None]:
np.unique(cat)

In [None]:
[arr-1 for arr in test2]

In [None]:
def kbdi_single_grid(tmax_1d, pr_1d):
    # Ensure inputs are NumPy arrays
    T = np.asarray(tmax_1d)
    PR = np.asarray(pr_1d)

    if not np.all(np.isfinite(PR)):
        return np.full_like(PR, np.nan)

    # Create a time index
    time_index = np.arange(len(PR),dtype='float32')

    # 7-day rolling precipitation sum
    ndays = 7
    pr_thresh = 8.0  # inches
    pr_weeksum = np.convolve(PR, np.ones(ndays), mode='valid').astype('float32')
    pr_weeksum = np.concatenate([np.full(ndays - 1, np.nan), pr_weeksum]).astype('float32')

    try:
        day_int = np.where(pr_weeksum > pr_thresh)[0][0].astype('int32')
    except IndexError:
        return np.full_like(PR, np.nan)

    # Rain mask and consecutive rain days
    rainmask = np.where(PR > 0, 1, 0).astype('int32')
    cumsum = np.cumsum(rainmask)
    reset = np.where(rainmask == 0, cumsum, np.nan)
    reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset))
    rr = cumsum - reset

    # Categorize rain days
    cat = np.where(rr >= 3, 5, rr)
    consec_day2 = np.where(rr == 2)[0].astype('int32')
    consec_day1 = consec_day2 - 1
    cat[consec_day2] = 5
    cat[consec_day1] = 5
    cat = np.where(cat == 5, 2, cat)

    # Pnet calculation
    acc_thresh = 0.2  # inches
    pnet = np.copy(PR)
    pnet[cat == 1] = pnet[cat == 1] - acc_thresh
    pnet = np.where(pnet < 0, 0, pnet)

    # Adjust for consecutive rain days
    consec_inds = np.where(cat == 2)[0].astype('int32')
    accpr = 0.0
    thresh_flag = False

    for i, ind in enumerate(consec_inds):
        accpr += PR[ind]
        if accpr <= acc_thresh and not thresh_flag:
            pnet[ind] = 0
        elif accpr > acc_thresh and not thresh_flag:
            accpr -= acc_thresh
            pnet[ind] = accpr
            thresh_flag = True
        else:
            pnet[ind] = PR[ind]
        if i != len(consec_inds) - 1 and consec_inds[i + 1] != consec_inds[i] + 1:
            accpr = 0.0
            thresh_flag = False

    # # Mean annual precipitation (approximate)
    # days_per_year = 365
    # n_years = len(PR) // days_per_year
    # ann_pr = []
    # for i in range(n_years):
    #     year_data = PR[i * days_per_year:(i + 1) * days_per_year]
    #     if np.count_nonzero(~np.isnan(year_data)) >= 360:
    #         ann_pr.append(np.nansum(year_data))
    # if len(ann_pr) == 0:
    #     return np.full_like(PR, np.nan)
    # mean_ann_pr = np.mean(ann_pr)

    # KBDI calculation
    KBDI = np.full_like(PR, np.nan)
    if day_int < len(PR):
        KBDI[day_int] = 0
    else:
        return KBDI

    denominator = 1 + 10.88 * np.exp(-0.0441 * mean_ann_pr)
    for it in range(day_int + 1, len(PR)):
        Q = max(0, KBDI[it - 1] - pnet[it] * 100)
        numerator = (800 - Q) * (0.968 * np.exp(0.0486 * T[it]) - 8.3)
        KBDI[it] = Q + (numerator / denominator) * 1e-3
    
    return KBDI