In [1]:
import xarray as xr
import numpy as np
import pandas as pd

import dask
import dask.array as da

import matplotlib.pyplot as plt

# Get test data

In [None]:
# pr_file = r'D://data/nclimgrid_daily/prcp_nClimGridDaily_1951-2024_USsouth.nc'
# tmax_file = r'D://data/nclimgrid_daily/tmax_nClimGridDaily_1951-2024_USsouth.nc'

pr_file = r'C://Users/kerrie/Documents/02_LocalData/nclimgrid_daily/prcp_nClimGridDaily_1951-2024_USsouth.nc'
tmax_file = r'C://Users/kerrie/Documents/02_LocalData/nclimgrid_daily/tmax_nClimGridDaily_1951-2024_USsouth.nc'

In [None]:
year_start='1951'
year_end='2023'
lat1, lat2 = 32, 34
lon1, lon2 = -90, -88

In [None]:
%%time
# chunks = {'time':-1,'lat':24,'lon':24}
chunks = -1

pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)
pr = (pr / 25.4).round(2)  # convert to inches

tmax = xr.open_dataset(tmax_file).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit

# print(pr)
pr = pr.compute()
tmax = tmax.compute()
pr

# Vectorized Calculations

Some things are made easier by xarray functions but some things go way too slow when stored in xarray data structures.

First, we'll take care of the xarray things that go quickly.

Then, we'll complete the rest of the calculations with numpy only.

## xarray calcs

In [None]:
%%time
mean_ann_pr = pr.groupby('time.year').sum(min_count=360).mean('year')
landmask = xr.where(np.isfinite(mean_ann_pr),1,0).astype('int32')

In [None]:
%%time
ndays=7
thresh=8 # inches
badval=-100000
pr_weeksum=pr.rolling(time=ndays,min_periods=ndays,center=False).sum('time')   

# quantify how many grids are never saturated
threshmask = xr.where((pr_weeksum>=thresh).sum('time')>0,1,0) # 1=init date found, 0=no init date found
# nbad = xr.where((landmask)&(threshmask==0),1,0).sum().item() # 1=grid on land with data but no init date found

day_int = (pr_weeksum>=thresh).argmax('time').astype('int32')
day_int = xr.where(threshmask,day_int,badval).astype('int32')

day_int

## numpy calcs

In [None]:
%%time
# Rain mask and consecutive rain days
rainmask = np.where(pr > 0, 1, 0).astype('int32')
cumsum = np.cumsum(rainmask,axis=0)
reset = np.where(rainmask == 0, cumsum, np.nan)
reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset),axis=0)
# rr_np = cumsum - reset
rr = cumsum - reset

In [None]:
# %%time
# rainmask=xr.where(pr>0,1,0) # 1/0 rain/no rain mask
# rr = rainmask.cumsum('time')-rainmask.cumsum('time').where(rainmask == 0).ffill(dim='time').fillna(0)
# # rr.isel(time=0).plot()
# # rainmask.isel(time=0).plot()

In [None]:
# print(rr_np.shape,rr.shape)
# # ((rr[0,:,:]-rr_np[0,:,:])>0).sum()
# ((rr-rr_np)>0).sum()

In [None]:
%%time
# Categorize rain days
cat = np.where(rr >= 3, 5, rr)
# np.argwhere(rr.data==2)
consecday2_timeind,consecday2_latind,consecday2_lonind = np.where(rr == 2)#.astype('int32')
consecday1_timeind = consecday2_timeind - 1
cat[consecday2_timeind,consecday2_latind,consecday2_lonind] = 5
cat[consecday1_timeind,consecday2_latind,consecday2_lonind]  = 5
cat = np.where(cat == 5, 2, cat)
np.unique(cat)

Pnet could go one of two ways:
- loop through every time and use 2D spatial arrays
- nested loops of consecutive rain indexes and 1D collapsed space

Trying the time loop with vectorized space...

In [None]:
def calc_pnet_vectorized(PR, cat, acc_thresh=0.2):
    """
    Vectorized calculation of pnet for 3D arrays
    
    Parameters:
    PR: 3D array (time, lat, lon) - precipitation data
    cat: 3D array (time, lat, lon) - category data (0, 1, or 2)
    acc_thresh: float - accumulation threshold
    
    Returns:
    pnet: 3D array (time, lat, lon) - processed precipitation
    """
    time_steps, lat_size, lon_size = PR.shape
    pnet = PR.copy()
    
    # Handle category 1 (single rain days) - vectorized
    cat1_mask = (cat == 1)
    pnet = np.where(cat1_mask, pnet - acc_thresh, pnet)
    pnet = np.where(pnet < 0, 0, pnet)
    
    # Handle category 2 (consecutive rain days) - needs loop over time
    accpr = np.zeros((lat_size, lon_size))  # 2D accumulation array
    thresh_flag = np.zeros((lat_size, lon_size), dtype=bool)  # 2D flag array
    
    for t in range(time_steps):
        cat2_mask = (cat[t] == 2)  # 2D mask for current time step
        
        # Only process locations with category 2 on this day
        if np.any(cat2_mask):
            # Accumulate precipitation
            accpr[cat2_mask] += PR[t, cat2_mask]
            
            # Check if we need to reset for new consecutive events
            # This requires detecting breaks in consecutive sequences
            if t > 0:
                prev_cat2 = (cat[t-1] == 2)
                new_event_mask = cat2_mask & ~prev_cat2
                accpr[new_event_mask] = PR[t, new_event_mask]
                thresh_flag[new_event_mask] = False
            
            # Apply threshold logic
            # Case 1: Below threshold and flag not set
            below_thresh_mask = cat2_mask & (accpr <= acc_thresh) & ~thresh_flag
            pnet[t, below_thresh_mask] = 0
            
            # Case 2: Above threshold and flag not set (first time over threshold)
            above_thresh_mask = cat2_mask & (accpr > acc_thresh) & ~thresh_flag
            pnet[t, above_thresh_mask] = accpr[above_thresh_mask] - acc_thresh
            thresh_flag[above_thresh_mask] = True
            
            # Case 3: Flag already set (subsequent days after threshold met)
            # pnet remains unchanged (already copied from PR)
            
            # Reset accumulation for locations where consecutive event ends
            if t < time_steps - 1:
                next_cat2 = (cat[t+1] == 2) if t+1 < time_steps else np.zeros_like(cat2_mask)
                end_event_mask = cat2_mask & ~next_cat2
                accpr[end_event_mask] = 0
                thresh_flag[end_event_mask] = False
    
    return pnet.round(3)

In [None]:
%%time
pnet=calc_pnet_vectorized(pr.data,cat)

Yay, this is the same result as the single grid cell calculation. 

Last step is the kbdi calc...

In [None]:
%%time
# KBDI calculation (inches, Fahrenheit)
# nan initialization
KBDI = np.full(pr.shape,np.nan).astype('float32')

# Replace nan with 0 at init date for each grid
ntimes,nlats,nlons=pr.shape
lat_inds, lon_inds = np.meshgrid(np.arange(nlats), np.arange(nlons), indexing='ij')
KBDI[day_int, lat_inds, lon_inds] = 0

# time independent part of the equation
denominator = 1 + 10.88 * np.exp(-0.0441*mean_ann_pr.data)

# looping in time
for it in range(ntimes):
    if it>0:
        # 2D flags to identify initialization date at each grid
        flag_today = np.isfinite(KBDI[it,:,:])
        flag_prev = np.isfinite(KBDI[it-1,:,:])

        # parts of the KBDI equation
        Q = KBDI[it-1,:,:] - pnet[it,:,:]*100
        Q = np.where(Q<0,0,Q) # correct any negatives
        numerator = (800 - Q) * (0.968 * np.exp(0.0486*tmax.data[it,:,:]) - 8.3)
        
        # replace nan KBDI with finite value only after initialization date at each grid
        # this happens when today's flag is False but yesterday's was True (finds the 0 at each grid)
        KBDI[it,:,:] = np.where((~flag_today)&(flag_prev),Q + (numerator/denominator)*1E-3,KBDI[it,:,:])  
        KBDI[it,:,:] = np.where(KBDI[it,:,:]< 0.0,0,KBDI[it,:,:])  # correct any negatives

        del Q,numerator
        

In [None]:
# for it in range(istart,iend):
#     print(round(test[it].item(),3), round(pr.sel(lat=32,lon=-90,method='nearest')[it].item(),3))


In [None]:
KBDI = xr.DataArray(KBDI, coords = pr.coords)
# KBDI.coords['time_index'] = ('time',time_index)

KBDI.isel(lat=0,lon=0).plot(figsize=(20,3))
plt.axhline(200,color='lightgreen',ls='dashed')
plt.axhline(400,color='gold',ls='dashed')
plt.axhline(600,color='firebrick',ls='dashed')

In [None]:
# look at min/max
print(KBDI[:,0,0].min().item(), KBDI[:,0,0].max().item())

I have verified that this result is the same as from the single grid calculations.

# All in one place

In [None]:
%%time
pr_file = r'C://Users/kerrie/Documents/02_LocalData/nclimgrid_daily/prcp_nClimGridDaily_1951-2024_USsouth.nc'
tmax_file = r'C://Users/kerrie/Documents/02_LocalData/nclimgrid_daily/tmax_nClimGridDaily_1951-2024_USsouth.nc'
year_start='1951'
year_end='2023'
lat1, lat2 = 32, 35
lon1, lon2 = -90, -87
chunks = -1

pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)
pr = (pr / 25.4).round(2)  # convert to inches

tmax = xr.open_dataset(tmax_file).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit

pr = pr.compute()
tmax = tmax.compute()

print(pr.shape)

In [None]:
%%time

mean_ann_pr = pr.groupby('time.year').sum(min_count=360).mean('year')
landmask = xr.where(np.isfinite(mean_ann_pr),1,0).astype('int32')

ndays=7
thresh=8 # inches
badval=-100000
pr_weeksum=pr.rolling(time=ndays,min_periods=ndays,center=False).sum('time')   

# quantify how many grids are never saturated
threshmask = xr.where((pr_weeksum>=thresh).sum('time')>0,1,0) # 1=init date found, 0=no init date found
# nbad = xr.where((landmask)&(threshmask==0),1,0).sum().item() # 1=grid on land with data but no init date found

day_int = (pr_weeksum>=thresh).argmax('time').astype('int32')
day_int = xr.where(threshmask,day_int,badval).astype('int32')

# Rain mask and consecutive rain days
rainmask = np.where(pr > 0, 1, 0).astype('int32')
cumsum = np.cumsum(rainmask,axis=0)
reset = np.where(rainmask == 0, cumsum, np.nan)
reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset),axis=0)
# rr_np = cumsum - reset
rr = cumsum - reset

# Categorize rain days
cat = np.where(rr >= 3, 5, rr)
consecday2_timeind,consecday2_latind,consecday2_lonind = np.where(rr == 2)#.astype('int32')
consecday1_timeind = consecday2_timeind - 1
cat[consecday2_timeind,consecday2_latind,consecday2_lonind] = 5
cat[consecday1_timeind,consecday2_latind,consecday2_lonind]  = 5
cat = np.where(cat == 5, 2, cat)

def calc_pnet_vectorized(PR, cat, acc_thresh=0.2):
    """
    Vectorized calculation of pnet for 3D arrays
    
    Parameters:
    PR: 3D array (time, lat, lon) - precipitation data
    cat: 3D array (time, lat, lon) - category data (0, 1, or 2)
    acc_thresh: float - accumulation threshold
    
    Returns:
    pnet: 3D array (time, lat, lon) - processed precipitation
    """
    time_steps, lat_size, lon_size = PR.shape
    pnet = PR.copy()
    
    # Handle category 1 (single rain days) - vectorized
    cat1_mask = (cat == 1)
    pnet = np.where(cat1_mask, pnet - acc_thresh, pnet)
    pnet = np.where(pnet < 0, 0, pnet)
    
    # Handle category 2 (consecutive rain days) - needs loop over time
    accpr = np.zeros((lat_size, lon_size))  # 2D accumulation array
    thresh_flag = np.zeros((lat_size, lon_size), dtype=bool)  # 2D flag array
    
    for t in range(time_steps):
        cat2_mask = (cat[t] == 2)  # 2D mask for current time step
        
        # Only process locations with category 2 on this day
        if np.any(cat2_mask):
            # Accumulate precipitation
            accpr[cat2_mask] += PR[t, cat2_mask]
            
            # Check if we need to reset for new consecutive events
            # This requires detecting breaks in consecutive sequences
            if t > 0:
                prev_cat2 = (cat[t-1] == 2)
                new_event_mask = cat2_mask & ~prev_cat2
                accpr[new_event_mask] = PR[t, new_event_mask]
                thresh_flag[new_event_mask] = False
            
            # Apply threshold logic
            # Case 1: Below threshold and flag not set
            below_thresh_mask = cat2_mask & (accpr <= acc_thresh) & ~thresh_flag
            pnet[t, below_thresh_mask] = 0
            
            # Case 2: Above threshold and flag not set (first time over threshold)
            above_thresh_mask = cat2_mask & (accpr > acc_thresh) & ~thresh_flag
            pnet[t, above_thresh_mask] = accpr[above_thresh_mask] - acc_thresh
            thresh_flag[above_thresh_mask] = True
            
            # Case 3: Flag already set (subsequent days after threshold met)
            # pnet remains unchanged (already copied from PR)
            
            # Reset accumulation for locations where consecutive event ends
            if t < time_steps - 1:
                next_cat2 = (cat[t+1] == 2) if t+1 < time_steps else np.zeros_like(cat2_mask)
                end_event_mask = cat2_mask & ~next_cat2
                accpr[end_event_mask] = 0
                thresh_flag[end_event_mask] = False
    
    return pnet.round(3)

pnet=calc_pnet_vectorized(pr.data,cat)

# KBDI calculation (inches, Fahrenheit)
# nan initialization
KBDI = np.full(pr.shape,np.nan).astype('float32')

# Replace nan with 0 at init date for each grid
ntimes,nlats,nlons=pr.shape
lat_inds, lon_inds = np.meshgrid(np.arange(nlats), np.arange(nlons), indexing='ij')
KBDI[day_int, lat_inds, lon_inds] = 0

# time independent part of the equation
denominator = 1 + 10.88 * np.exp(-0.0441*mean_ann_pr.data)

# looping in time
for it in range(ntimes):
    if it>0:
        # 2D flags to identify initialization date at each grid
        flag_today = np.isfinite(KBDI[it,:,:])
        flag_prev = np.isfinite(KBDI[it-1,:,:])

        # parts of the KBDI equation
        Q = KBDI[it-1,:,:] - pnet[it,:,:]*100
        Q = np.where(Q<0,0,Q) # correct any negatives
        numerator = (800 - Q) * (0.968 * np.exp(0.0486*tmax.data[it,:,:]) - 8.3)
        
        # replace nan KBDI with finite value only after initialization date at each grid
        # this happens when today's flag is False but yesterday's was True (finds the 0 at each grid)
        KBDI[it,:,:] = np.where((~flag_today)&(flag_prev),Q + (numerator/denominator)*1E-3,KBDI[it,:,:])  
        KBDI[it,:,:] = np.where(KBDI[it,:,:]< 0.0,0,KBDI[it,:,:])  # correct any negatives

        del Q,numerator

KBDI = xr.DataArray(KBDI, coords = pr.coords)

In [None]:
pr.data.nbytes/1E6

The single grid version of this that used apply_ufunc to vectorize/parallelize (test_parallel.ipynb) took 3.5 minutes for the full timeseries on a 48x48 chunk in space. The vectorized version above is running ten times that amount of data in less than 30 seconds.

Next step, parallelize with dask

# Parallelize with Dask

In [10]:
%%time
pr_file = r'C://Users/kerrie/Documents/02_LocalData/nclimgrid_daily/prcp_nClimGridDaily_1951-2024_USsouth.nc'
tmax_file = r'C://Users/kerrie/Documents/02_LocalData/nclimgrid_daily/tmax_nClimGridDaily_1951-2024_USsouth.nc'
year_start='1951'
year_end='2023'
lat1, lat2 = 32, 34
lon1, lon2 = -90, -88
chunks = {'time':-1,'lat':24,'lon':24}
chunks_2D = {'lat':24,'lon':24}

pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)
pr = (pr / 25.4).round(2)  # convert to inches

tmax = xr.open_dataset(tmax_file).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit

pr

CPU times: total: 31.2 ms
Wall time: 24 ms


Unnamed: 0,Array,Chunk
Bytes,234.34 MiB,58.59 MiB
Shape,"(26663, 48, 48)","(26663, 24, 24)"
Dask graph,4 chunks in 4 graph layers,4 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 234.34 MiB 58.59 MiB Shape (26663, 48, 48) (26663, 24, 24) Dask graph 4 chunks in 4 graph layers Data type float32 numpy.ndarray",48  48  26663,

Unnamed: 0,Array,Chunk
Bytes,234.34 MiB,58.59 MiB
Shape,"(26663, 48, 48)","(26663, 24, 24)"
Dask graph,4 chunks in 4 graph layers,4 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [27]:
# functions operate on xarray data arrays

def calc_mean_pr(pr):
    # mean annual precipitation at each grid
    mean_ann_pr = pr.groupby('time.year').sum(min_count=360).mean('year')
    return mean_ann_pr

def find_initialization_index(pr):
    # create a land mask
    # landmask = xr.where(np.isfinite(mean_ann_pr),1,0).astype('int32')

    # find the saturation/initialization date (8 inches precip over a week)
    
    # constants
    ndays=7
    thresh=8  # inches
    badval=-100000  # value to use for grids that never reach saturation
    
    # rolling weekly sum of precipitation
    pr_weeksum=pr.rolling(time=ndays,min_periods=ndays,center=False).sum('time')   
    
    # quantify how many grids are never saturated
    threshmask = xr.where((pr_weeksum>=thresh).sum('time')>0,1,0) # 1=init date found, 0=no init date found
    # nbad = xr.where((landmask)&(threshmask==0),1,0).sum().item() # 1=grid on land with data but no init date found

    # save the initialization index for each grid
    day_int = (pr_weeksum>=thresh).argmax('time').astype('int32') # will yield 0 if no saturation reached
    day_int = xr.where(threshmask,day_int,badval).astype('int32') # indicate grids that don't reach saturation with badval

    return day_int

In [18]:
# functions that operate on numpy arrays

def rain_categories(pr):
    # Rain mask and consecutive rain days
    rainmask = np.where(pr > 0, 1, 0).astype('int32')
    cumsum = np.cumsum(rainmask,axis=0)
    reset = np.where(rainmask == 0, cumsum, np.nan)
    reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset),axis=0)
    # rr_np = cumsum - reset
    rr = cumsum - reset

    # Categorize rain days
    cat = np.where(rr >= 3, 5, rr)
    consecday2_timeind,consecday2_latind,consecday2_lonind = np.where(rr == 2)
    consecday1_timeind = consecday2_timeind - 1
    cat[consecday2_timeind,consecday2_latind,consecday2_lonind] = 5
    cat[consecday1_timeind,consecday2_latind,consecday2_lonind]  = 5
    cat = np.where(cat == 5, 2, cat)
    return cat

def calc_pnet_vectorized(PR, cat, acc_thresh=0.2):
    """
    Vectorized calculation of pnet for 3D arrays
    
    Parameters:
    PR: 3D array (time, lat, lon) - precipitation data
    cat: 3D array (time, lat, lon) - category data (0, 1, or 2)
    acc_thresh: float - accumulation threshold
    
    Returns:
    pnet: 3D array (time, lat, lon) - processed precipitation
    """
    time_steps, lat_size, lon_size = PR.shape
    pnet = PR.copy()
    
    # Handle category 1 (single rain days) - vectorized
    cat1_mask = (cat == 1)
    pnet = np.where(cat1_mask, pnet - acc_thresh, pnet)
    pnet = np.where(pnet < 0, 0, pnet)
    
    # Handle category 2 (consecutive rain days) - needs loop over time
    accpr = np.zeros((lat_size, lon_size))  # 2D accumulation array
    thresh_flag = np.zeros((lat_size, lon_size), dtype=bool)  # 2D flag array
    
    for t in range(time_steps):
        cat2_mask = (cat[t] == 2)  # 2D mask for current time step
        
        # Only process locations with category 2 on this day
        if np.any(cat2_mask):
            # Accumulate precipitation
            accpr[cat2_mask] += PR[t, cat2_mask]
            
            # Check if we need to reset for new consecutive events
            # This requires detecting breaks in consecutive sequences
            if t > 0:
                prev_cat2 = (cat[t-1] == 2)
                new_event_mask = cat2_mask & ~prev_cat2
                accpr[new_event_mask] = PR[t, new_event_mask]
                thresh_flag[new_event_mask] = False
            
            # Apply threshold logic
            # Case 1: Below threshold and flag not set
            below_thresh_mask = cat2_mask & (accpr <= acc_thresh) & ~thresh_flag
            pnet[t, below_thresh_mask] = 0
            
            # Case 2: Above threshold and flag not set (first time over threshold)
            above_thresh_mask = cat2_mask & (accpr > acc_thresh) & ~thresh_flag
            pnet[t, above_thresh_mask] = accpr[above_thresh_mask] - acc_thresh
            thresh_flag[above_thresh_mask] = True
            
            # Case 3: Flag already set (subsequent days after threshold met)
            # pnet remains unchanged (already copied from PR)
            
            # Reset accumulation for locations where consecutive event ends
            if t < time_steps - 1:
                next_cat2 = (cat[t+1] == 2) if t+1 < time_steps else np.zeros_like(cat2_mask)
                end_event_mask = cat2_mask & ~next_cat2
                accpr[end_event_mask] = 0
                thresh_flag[end_event_mask] = False
    
    return pnet.round(3)

def calc_kbdi(pr,mean_ann_pr,pnet,tmax,day_int):
    # KBDI calculation (inches, Fahrenheit)
    # nan initialization
    KBDI = np.full(pr.shape,np.nan).astype('float32')
    
    # Replace nan with 0 at init date for each grid
    ntimes,nlats,nlons=pr.shape
    lat_inds, lon_inds = np.meshgrid(np.arange(nlats), np.arange(nlons), indexing='ij')
    KBDI[day_int, lat_inds, lon_inds] = 0
    
    # time independent part of the equation
    denominator = 1 + 10.88 * np.exp(-0.0441*mean_ann_pr.data)
    
    # looping in time
    for it in range(ntimes):
        if it>0:
            # 2D flags to identify initialization date at each grid
            flag_today = np.isfinite(KBDI[it,:,:])
            flag_prev = np.isfinite(KBDI[it-1,:,:])
    
            # parts of the KBDI equation
            Q = KBDI[it-1,:,:] - pnet[it,:,:]*100
            Q = np.where(Q<0,0,Q) # correct any negatives
            numerator = (800 - Q) * (0.968 * np.exp(0.0486*tmax.data[it,:,:]) - 8.3)
            
            # replace nan KBDI with finite value only after initialization date at each grid
            # this happens when today's flag is False but yesterday's was True (finds the 0 at each grid)
            KBDI[it,:,:] = np.where((~flag_today)&(flag_prev),Q + (numerator/denominator)*1E-3,KBDI[it,:,:])  
            KBDI[it,:,:] = np.where(KBDI[it,:,:]< 0.0,0,KBDI[it,:,:])  # correct any negatives
    
            del Q,numerator
    return KBDI


In [12]:
%%time
mean_ann_pr = calc_mean_pr(pr).compute()
mean_ann_pr = mean_ann_pr.chunk(chunks_2D)
# result = calc_2D_inputs(pr)
mean_ann_pr

CPU times: total: 4.8 s
Wall time: 3.85 s


Unnamed: 0,Array,Chunk
Bytes,9.00 kiB,2.25 kiB
Shape,"(48, 48)","(24, 24)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 9.00 kiB 2.25 kiB Shape (48, 48) (24, 24) Dask graph 4 chunks in 1 graph layer Data type float32 numpy.ndarray",48  48,

Unnamed: 0,Array,Chunk
Bytes,9.00 kiB,2.25 kiB
Shape,"(48, 48)","(24, 24)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [26]:
%%time
day_int = calc_2D_inputs(pr,mean_ann_pr).compute()
day_int = day_int.chunk(chunks_2D)

CPU times: total: 3.95 s
Wall time: 3.15 s


In [None]:
# %%time
# day_int = day_int.compute()

In [24]:
%%time
# running this with pr as dask array does not work bcz of the np functions used
# cat = rain_categories(pr.data)

# this will work with dask delayed tho because it uses numpy arrays not dask arrays
test = pr.data.to_delayed().ravel()
task_list = [dask.delayed(rain_categories)(arr) for arr in test]
result = dask.compute(*task_list)

CPU times: total: 6.66 s
Wall time: 4.08 s


In [None]:
%%time
# pnet = calc_pnet_vectorized(pr.data, cat)



In [None]:
%%time
KBDI = calc_kbdi(pr,mean_ann_pr,pnet,tmax,day_int)

# Consolidating

In [31]:
# functions operate on xarray data arrays

def calc_mean_pr(pr):
    # mean annual precipitation at each grid
    mean_ann_pr = pr.groupby('time.year').sum(min_count=360).mean('year')
    return mean_ann_pr

def find_initialization_index(pr):
    # create a land mask
    # landmask = xr.where(np.isfinite(mean_ann_pr),1,0).astype('int32')

    # find the saturation/initialization date (8 inches precip over a week)
    
    # constants
    ndays=7
    thresh=8  # inches
    badval=-100000  # value to use for grids that never reach saturation
    
    # rolling weekly sum of precipitation
    pr_weeksum=pr.rolling(time=ndays,min_periods=ndays,center=False).sum('time')   
    
    # quantify how many grids are never saturated
    threshmask = xr.where((pr_weeksum>=thresh).sum('time')>0,1,0) # 1=init date found, 0=no init date found
    # nbad = xr.where((landmask)&(threshmask==0),1,0).sum().item() # 1=grid on land with data but no init date found

    # save the initialization index for each grid
    day_int = (pr_weeksum>=thresh).argmax('time').astype('int32') # will yield 0 if no saturation reached
    day_int = xr.where(threshmask,day_int,badval).astype('int32') # indicate grids that don't reach saturation with badval

    return day_int

In [60]:
# functions that operate on numpy arrays

def rain_categories(pr):
    # Rain mask and consecutive rain days
    rainmask = np.where(pr > 0, 1, 0).astype('int32')
    cumsum = np.cumsum(rainmask,axis=0)
    reset = np.where(rainmask == 0, cumsum, np.nan)
    reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset),axis=0)
    # rr_np = cumsum - reset
    rr = cumsum - reset

    # Categorize rain days
    cat = np.where(rr >= 3, 5, rr)
    consecday2_timeind,consecday2_latind,consecday2_lonind = np.where(rr == 2)
    consecday1_timeind = consecday2_timeind - 1
    cat[consecday2_timeind,consecday2_latind,consecday2_lonind] = 5
    cat[consecday1_timeind,consecday2_latind,consecday2_lonind]  = 5
    cat = np.where(cat == 5, 2, cat)
    return cat

def calc_pnet_vectorized(PR, cat, acc_thresh=0.2):
    """
    Vectorized calculation of pnet for 3D arrays
    
    Parameters:
    PR: 3D array (time, lat, lon) - precipitation data
    cat: 3D array (time, lat, lon) - category data (0, 1, or 2)
    acc_thresh: float - accumulation threshold
    
    Returns:
    pnet: 3D array (time, lat, lon) - processed precipitation
    """
    time_steps, lat_size, lon_size = PR.shape
    pnet = PR.copy()
    
    # Handle category 1 (single rain days) - vectorized
    cat1_mask = (cat == 1)
    pnet = np.where(cat1_mask, pnet - acc_thresh, pnet)
    pnet = np.where(pnet < 0, 0, pnet)
    
    # Handle category 2 (consecutive rain days) - needs loop over time
    accpr = np.zeros((lat_size, lon_size))  # 2D accumulation array
    thresh_flag = np.zeros((lat_size, lon_size), dtype=bool)  # 2D flag array
    
    for t in range(time_steps):
        cat2_mask = (cat[t] == 2)  # 2D mask for current time step
        
        # Only process locations with category 2 on this day
        if np.any(cat2_mask):
            # Accumulate precipitation
            accpr[cat2_mask] += PR[t, cat2_mask]
            
            # Check if we need to reset for new consecutive events
            # This requires detecting breaks in consecutive sequences
            if t > 0:
                prev_cat2 = (cat[t-1] == 2)
                new_event_mask = cat2_mask & ~prev_cat2
                accpr[new_event_mask] = PR[t, new_event_mask]
                thresh_flag[new_event_mask] = False
            
            # Apply threshold logic
            # Case 1: Below threshold and flag not set
            below_thresh_mask = cat2_mask & (accpr <= acc_thresh) & ~thresh_flag
            pnet[t, below_thresh_mask] = 0
            
            # Case 2: Above threshold and flag not set (first time over threshold)
            above_thresh_mask = cat2_mask & (accpr > acc_thresh) & ~thresh_flag
            pnet[t, above_thresh_mask] = accpr[above_thresh_mask] - acc_thresh
            thresh_flag[above_thresh_mask] = True
            
            # Case 3: Flag already set (subsequent days after threshold met)
            # pnet remains unchanged (already copied from PR)
            
            # Reset accumulation for locations where consecutive event ends
            if t < time_steps - 1:
                next_cat2 = (cat[t+1] == 2) if t+1 < time_steps else np.zeros_like(cat2_mask)
                end_event_mask = cat2_mask & ~next_cat2
                accpr[end_event_mask] = 0
                thresh_flag[end_event_mask] = False
    
    return pnet.round(3)

def calc_kbdi(pr,tmax,mean_ann_pr,day_int):
    cat = rain_categories(pr)
    pnet = calc_pnet_vectorized(pr, cat)    
    
    # KBDI calculation (inches, Fahrenheit)
    # nan initialization
    KBDI = np.full(pr.shape,np.nan).astype('float32')
    
    # # Replace nan with 0 at init date for each grid
    # ntimes,nlats,nlons=pr.shape
    # lat_inds, lon_inds = np.meshgrid(np.arange(nlats), np.arange(nlons), indexing='ij')
    # KBDI[day_int, lat_inds, lon_inds] = 0
    
    # # time independent part of the equation
    # denominator = 1 + 10.88 * np.exp(-0.0441*mean_ann_pr.data)
    
    # # looping in time
    # for it in range(ntimes):
    #     if it>0:
    #         # 2D flags to identify initialization date at each grid
    #         flag_today = np.isfinite(KBDI[it,:,:])
    #         flag_prev = np.isfinite(KBDI[it-1,:,:])
    
    #         # parts of the KBDI equation
    #         Q = KBDI[it-1,:,:] - pnet[it,:,:]*100
    #         Q = np.where(Q<0,0,Q) # correct any negatives
    #         numerator = (800 - Q) * (0.968 * np.exp(0.0486*tmax.data[it,:,:]) - 8.3)
            
    #         # replace nan KBDI with finite value only after initialization date at each grid
    #         # this happens when today's flag is False but yesterday's was True (finds the 0 at each grid)
    #         KBDI[it,:,:] = np.where((~flag_today)&(flag_prev),Q + (numerator/denominator)*1E-3,KBDI[it,:,:])  
    #         KBDI[it,:,:] = np.where(KBDI[it,:,:]< 0.0,0,KBDI[it,:,:])  # correct any negatives
    
    #         del Q,numerator
    # return KBDI
    return day_int


In [32]:
%%time
mean_ann_pr = calc_mean_pr(pr).compute()
mean_ann_pr = mean_ann_pr.chunk(chunks_2D).data
mean_ann_pr

CPU times: total: 4.77 s
Wall time: 3.89 s


Unnamed: 0,Array,Chunk
Bytes,9.00 kiB,2.25 kiB
Shape,"(48, 48)","(24, 24)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 9.00 kiB 2.25 kiB Shape (48, 48) (24, 24) Dask graph 4 chunks in 1 graph layer Data type float32 numpy.ndarray",48  48,

Unnamed: 0,Array,Chunk
Bytes,9.00 kiB,2.25 kiB
Shape,"(48, 48)","(24, 24)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [41]:
%%time
day_int = find_initialization_index(pr).compute()
day_int = day_int.chunk(chunks_2D).data
day_int

CPU times: total: 3.92 s
Wall time: 3.11 s


Unnamed: 0,Array,Chunk
Bytes,9.00 kiB,2.25 kiB
Shape,"(48, 48)","(24, 24)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,int32 numpy.ndarray,int32 numpy.ndarray
"Array Chunk Bytes 9.00 kiB 2.25 kiB Shape (48, 48) (24, 24) Dask graph 4 chunks in 1 graph layer Data type int32 numpy.ndarray",48  48,

Unnamed: 0,Array,Chunk
Bytes,9.00 kiB,2.25 kiB
Shape,"(48, 48)","(24, 24)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,int32 numpy.ndarray,int32 numpy.ndarray


In [61]:
%%time

pr_delayed = pr.data.to_delayed().ravel()
tmax_delayed = tmax.data.to_delayed().ravel()
mpr_delayed = mean_ann_pr.to_delayed().ravel()
init_delayed = day_int.to_delayed().ravel()
zipvars = zip(pr_delayed,tmax_delayed,mpr_delayed,init_delayed)

task_list = [dask.delayed(calc_kbdi)(p,t,mp,ind) for p,t,mp,ind in zipvars]
len(task_list)

CPU times: total: 0 ns
Wall time: 2.69 ms


4

In [62]:
%%time
result = dask.compute(*task_list)

CPU times: total: 21.6 s
Wall time: 11.5 s


In [63]:
result

(array([[ 1563,  1563,  1563,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  2453],
        [ 3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705],
        [ 5026,  3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705],
        [ 5026,  4004,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705],
        [ 5026,  5026,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,  3705,
          3705,  3705,  3705,  3705,  3705,  3705],
        [ 5026,  502

In [None]:
mean_ann_pr = pr.groupby('time.year').sum(min_count=360).mean('year')
landmask = xr.where(np.isfinite(mean_ann_pr),1,0).astype('int32')

ndays=7
thresh=8 # inches
badval=-100000
pr_weeksum=pr.rolling(time=ndays,min_periods=ndays,center=False).sum('time')   

# quantify how many grids are never saturated
threshmask = xr.where((pr_weeksum>=thresh).sum('time')>0,1,0) # 1=init date found, 0=no init date found
# nbad = xr.where((landmask)&(threshmask==0),1,0).sum().item() # 1=grid on land with data but no init date found

day_int = (pr_weeksum>=thresh).argmax('time').astype('int32')
day_int = xr.where(threshmask,day_int,badval).astype('int32')

# Rain mask and consecutive rain days
rainmask = np.where(pr > 0, 1, 0).astype('int32')
cumsum = np.cumsum(rainmask,axis=0)
reset = np.where(rainmask == 0, cumsum, np.nan)
reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset),axis=0)
# rr_np = cumsum - reset
rr = cumsum - reset

# Categorize rain days
cat = np.where(rr >= 3, 5, rr)
consecday2_timeind,consecday2_latind,consecday2_lonind = np.where(rr == 2)#.astype('int32')
consecday1_timeind = consecday2_timeind - 1
cat[consecday2_timeind,consecday2_latind,consecday2_lonind] = 5
cat[consecday1_timeind,consecday2_latind,consecday2_lonind]  = 5
cat = np.where(cat == 5, 2, cat)

def calc_pnet_vectorized(PR, cat, acc_thresh=0.2):
    """
    Vectorized calculation of pnet for 3D arrays
    
    Parameters:
    PR: 3D array (time, lat, lon) - precipitation data
    cat: 3D array (time, lat, lon) - category data (0, 1, or 2)
    acc_thresh: float - accumulation threshold
    
    Returns:
    pnet: 3D array (time, lat, lon) - processed precipitation
    """
    time_steps, lat_size, lon_size = PR.shape
    pnet = PR.copy()
    
    # Handle category 1 (single rain days) - vectorized
    cat1_mask = (cat == 1)
    pnet = np.where(cat1_mask, pnet - acc_thresh, pnet)
    pnet = np.where(pnet < 0, 0, pnet)
    
    # Handle category 2 (consecutive rain days) - needs loop over time
    accpr = np.zeros((lat_size, lon_size))  # 2D accumulation array
    thresh_flag = np.zeros((lat_size, lon_size), dtype=bool)  # 2D flag array
    
    for t in range(time_steps):
        cat2_mask = (cat[t] == 2)  # 2D mask for current time step
        
        # Only process locations with category 2 on this day
        if np.any(cat2_mask):
            # Accumulate precipitation
            accpr[cat2_mask] += PR[t, cat2_mask]
            
            # Check if we need to reset for new consecutive events
            # This requires detecting breaks in consecutive sequences
            if t > 0:
                prev_cat2 = (cat[t-1] == 2)
                new_event_mask = cat2_mask & ~prev_cat2
                accpr[new_event_mask] = PR[t, new_event_mask]
                thresh_flag[new_event_mask] = False
            
            # Apply threshold logic
            # Case 1: Below threshold and flag not set
            below_thresh_mask = cat2_mask & (accpr <= acc_thresh) & ~thresh_flag
            pnet[t, below_thresh_mask] = 0
            
            # Case 2: Above threshold and flag not set (first time over threshold)
            above_thresh_mask = cat2_mask & (accpr > acc_thresh) & ~thresh_flag
            pnet[t, above_thresh_mask] = accpr[above_thresh_mask] - acc_thresh
            thresh_flag[above_thresh_mask] = True
            
            # Case 3: Flag already set (subsequent days after threshold met)
            # pnet remains unchanged (already copied from PR)
            
            # Reset accumulation for locations where consecutive event ends
            if t < time_steps - 1:
                next_cat2 = (cat[t+1] == 2) if t+1 < time_steps else np.zeros_like(cat2_mask)
                end_event_mask = cat2_mask & ~next_cat2
                accpr[end_event_mask] = 0
                thresh_flag[end_event_mask] = False
    
    return pnet.round(3)

pnet=calc_pnet_vectorized(pr.data,cat)

# KBDI calculation (inches, Fahrenheit)
# nan initialization
KBDI = np.full(pr.shape,np.nan).astype('float32')

# Replace nan with 0 at init date for each grid
ntimes,nlats,nlons=pr.shape
lat_inds, lon_inds = np.meshgrid(np.arange(nlats), np.arange(nlons), indexing='ij')
KBDI[day_int, lat_inds, lon_inds] = 0

# time independent part of the equation
denominator = 1 + 10.88 * np.exp(-0.0441*mean_ann_pr.data)

# looping in time
for it in range(ntimes):
    if it>0:
        # 2D flags to identify initialization date at each grid
        flag_today = np.isfinite(KBDI[it,:,:])
        flag_prev = np.isfinite(KBDI[it-1,:,:])

        # parts of the KBDI equation
        Q = KBDI[it-1,:,:] - pnet[it,:,:]*100
        Q = np.where(Q<0,0,Q) # correct any negatives
        numerator = (800 - Q) * (0.968 * np.exp(0.0486*tmax.data[it,:,:]) - 8.3)
        
        # replace nan KBDI with finite value only after initialization date at each grid
        # this happens when today's flag is False but yesterday's was True (finds the 0 at each grid)
        KBDI[it,:,:] = np.where((~flag_today)&(flag_prev),Q + (numerator/denominator)*1E-3,KBDI[it,:,:])  
        KBDI[it,:,:] = np.where(KBDI[it,:,:]< 0.0,0,KBDI[it,:,:])  # correct any negatives

        del Q,numerator

KBDI = xr.DataArray(KBDI, coords = pr.coords)

In [None]:
# test = xr.open_dataset('kbdi_singlegrid_applyufuncresult.nc').__xarray_dataarray_variable__.sel(time=slice(None,'2023'))
# test

In [None]:
# (KBDI[:,0,0].isel(time=slice(1562,1562+30))-test.isel(time=slice(1562,1562+30))).plot(figsize=(20,3))
# pr[:,0,0].isel(time=slice(1562,1562+30)).plot()
# plt.show()

In [None]:
KBDI[1563:,0,0]

In [None]:
np.argmax(KBDI[:,0,0]==0)

In [None]:
(np.isfinite(KBDI[:,0,0])).sum()

In [None]:
np.argwhere(KBDI[:,0,0]==0)

In [None]:
plt.plot(test[0:30,0,0])

In [None]:
test[0:30,0,0]

In [None]:
def pnet_singlegrid(PR,CAT,acc_thresh = 0.2):

    # Category 0 (no rain days)
    pnet = PR.copy()
    # Category 1 (single rain days)
    pnet = np.where(CAT==1,pnet-acc_thresh,pnet)
    pnet = np.where(pnet<0,0,pnet)
    # Category 2 (consecutive rain days) 
    consec_inds = np.argwhere(CAT==2).flatten()
    # initializations
    thresh_flag=False
    end_event=False
    accpr=0. 
    # loop through days in each multi-day rain event
    for i,ind in enumerate(consec_inds): 
        # accumulated precip per rain event
        accpr=accpr+PR[ind] 
        # if not over the threshold yet, Pnet is 0
        if accpr<=acc_thresh and not thresh_flag:
            pnet[ind]=0        
        # on the day the threshold is met, subtract the threshold amount and change flag    
        elif accpr>acc_thresh and not thresh_flag:
            accpr=accpr-acc_thresh # accumulate precip and subtract threshold
            pnet[ind]=accpr
            thresh_flag=True        
        # any days after the threshold is met, precip will remain unchanged
        else:
            pnet[ind]=PR[ind]     
        # reset accumulation and flag for the next consecutive rain event
        if i != len(consec_inds)-1:
            if (consec_inds[i+1] != consec_inds[i]+1): 
                accpr=0.
                thresh_flag=False
    return pnet

In [None]:
pr_input = pr.isel(lat=0,lon=0).data
cat_input = cat[:,0,0]
test2 = pnet_singlegrid(pr_input,cat_input)

In [None]:
max(test2.round(2)-test[:,0,0].round(2))

In [None]:
def kbdi_single_grid(tmax_1d, pr_1d):
    # Ensure inputs are NumPy arrays
    T = np.asarray(tmax_1d)
    PR = np.asarray(pr_1d)

    if not np.all(np.isfinite(PR)):
        return np.full_like(PR, np.nan)

    # Create a time index
    time_index = np.arange(len(PR),dtype='float32')

    # 7-day rolling precipitation sum
    ndays = 7
    pr_thresh = 8.0  # inches
    pr_weeksum = np.convolve(PR, np.ones(ndays), mode='valid').astype('float32')
    pr_weeksum = np.concatenate([np.full(ndays - 1, np.nan), pr_weeksum]).astype('float32')

    try:
        day_int = np.where(pr_weeksum > pr_thresh)[0][0].astype('int32')
    except IndexError:
        return np.full_like(PR, np.nan)

    # Rain mask and consecutive rain days
    rainmask = np.where(PR > 0, 1, 0).astype('int32')
    cumsum = np.cumsum(rainmask)
    reset = np.where(rainmask == 0, cumsum, np.nan)
    reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset))
    rr = cumsum - reset

    # Categorize rain days
    cat = np.where(rr >= 3, 5, rr)
    consec_day2 = np.where(rr == 2)[0].astype('int32')
    consec_day1 = consec_day2 - 1
    cat[consec_day2] = 5
    cat[consec_day1] = 5
    cat = np.where(cat == 5, 2, cat)

    # Pnet calculation
    acc_thresh = 0.2  # inches
    pnet = np.copy(PR)
    pnet[cat == 1] = pnet[cat == 1] - acc_thresh
    pnet = np.where(pnet < 0, 0, pnet)

    # Adjust for consecutive rain days
    consec_inds = np.where(cat == 2)[0].astype('int32')
    accpr = 0.0
    thresh_flag = False

    for i, ind in enumerate(consec_inds):
        accpr += PR[ind]
        if accpr <= acc_thresh and not thresh_flag:
            pnet[ind] = 0
        elif accpr > acc_thresh and not thresh_flag:
            accpr -= acc_thresh
            pnet[ind] = accpr
            thresh_flag = True
        else:
            pnet[ind] = PR[ind]
        if i != len(consec_inds) - 1 and consec_inds[i + 1] != consec_inds[i] + 1:
            accpr = 0.0
            thresh_flag = False

    # Mean annual precipitation (approximate)
    days_per_year = 365
    n_years = len(PR) // days_per_year
    ann_pr = []
    for i in range(n_years):
        year_data = PR[i * days_per_year:(i + 1) * days_per_year]
        if np.count_nonzero(~np.isnan(year_data)) >= 360:
            ann_pr.append(np.nansum(year_data))
    if len(ann_pr) == 0:
        return np.full_like(PR, np.nan)
    mean_ann_pr = np.mean(ann_pr)

    # KBDI calculation
    KBDI = np.full_like(PR, np.nan)
    if day_int < len(PR):
        KBDI[day_int] = 0
    else:
        return KBDI

    denominator = 1 + 10.88 * np.exp(-0.0441 * mean_ann_pr)
    for it in range(day_int + 1, len(PR)):
        Q = max(0, KBDI[it - 1] - pnet[it] * 100)
        numerator = (800 - Q) * (0.968 * np.exp(0.0486 * T[it]) - 8.3)
        KBDI[it] = Q + (numerator / denominator) * 1e-3
    
    return KBDI

In [None]:
def calc_kbdi(T,PR,time_axis):
    # determine if the grid is land or water
    # landmask=1 if np.isfinite(PR.mean('time')) else 0
    landmask=1 if np.isfinite(PR.mean(axis=time_axis)) else 0
    
    # # sum precip in 7 day rolling windows
    # ndays=7
    # pr_thresh=8 # inches
    # # pr_weeksum=PR.rolling(time=ndays,min_periods=ndays,center=False).sum()
    # pr_weeksum = np.convolve(PR.data, np.ones(ndays), mode='valid').astype('float32')
    # pr_weeksum = np.concatenate([np.full(ndays - 1, np.nan), pr_weeksum]).astype('float32')
    
    
    # # get the first index time where the weekly sum meets the threshold
    # # this is index t-1 for the KBDI calc where we'll set it to 0
    # # day_int = np.argwhere(pr_weeksum.data>pr_thresh).flatten()[0] 
    # # day_int = np.argwhere(pr_weeksum>pr_thresh).flatten()[0] 
    # try:
    #     day_int = np.where(pr_weeksum > pr_thresh)[0][0].astype('int32')
    # except IndexError:
    #     return np.full_like(PR.data, np.nan)
    
    # # calculate number of consecutive rain days at each time step
    # # I got this code to interrupt a cumulative sum here: https://stackoverflow.com/questions/61753567/convert-cumsum-output-to-binary-array-in-xarray
    # # rainmask=xr.where(PR>0,1,0) # 1/0 rain/no rain mask
    # # rr = (rainmask.cumsum()-rainmask.cumsum().where(rainmask == 0).ffill(dim='time').fillna(0)).data
    # rainmask = np.where(PR.data > 0, 1, 0).astype('int32')
    # cumsum = np.cumsum(rainmask)
    # reset = np.where(rainmask == 0, cumsum, np.nan)
    # reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset))
    # rr = cumsum - reset


    
    # # categorize rainfall days: consecutive rain days (2), single rain days (1), no rain days (0)
    # # label all days that are at least the 3rd consecutive rainfall day with a 5
    # cat=np.where(rr>=3,5,rr)
    # # find the indexes of all second consecutive rainfall days
    # consec_day2 = np.argwhere(rr==2).flatten()
    # # find the indexes of all the first consecutive rainfall days
    # consec_day1 = consec_day2-1
    # # label all consecutive rainfall days with a 2 
    # cat[consec_day2]=5 # first put a 5
    # cat[consec_day1]=5 # first put a 5
    # cat = np.where(cat==5,2,cat) # convert to 2's
    # # should be left with only 2's (consecutive rain days), 1's (single rain days), and 0's (no rain days)
    # # np.unique(cat) # change to assert
    
    
    # # Calc Pnet 
    # acc_thresh = 0.2 # inches
    # # Category 0 (no rain days)
    # pr_np = PR.copy().data
    # pnet = pr_np.copy()
    # # Category 1 (single rain days)
    # pnet = np.where(cat==1,pnet-acc_thresh,pnet)
    # pnet = np.where(pnet<0,0,pnet)
    # # Category 2 (consecutive rain days) 
    # consec_inds = np.argwhere(cat==2).flatten()
    # # initializations
    # thresh_flag=False
    # end_event=False
    # accpr=0. 
    # # loop through days in each multi-day rain event
    # for i,ind in enumerate(consec_inds): 
    #     # accumulated precip per rain event
    #     accpr=accpr+pr_np[ind].item() 
    #     # if not over the threshold yet, Pnet is 0
    #     if accpr<=acc_thresh and not thresh_flag:
    #         pnet[ind]=0        
    #     # on the day the threshold is met, subtract the threshold amount and change flag    
    #     elif accpr>acc_thresh and not thresh_flag:
    #         accpr=accpr-acc_thresh # accumulate precip and subtract threshold
    #         pnet[ind]=accpr
    #         thresh_flag=True        
    #     # any days after the threshold is met, precip will remain unchanged
    #     else:
    #         pnet[ind]=pr_np[ind]     
    #     # reset accumulation and flag for the next consecutive rain event
    #     if i != len(consec_inds)-1:
    #         if (consec_inds[i+1] != consec_inds[i]+1): 
    #             accpr=0.
    #             thresh_flag=False
        
    # # mean annual precip (inches)
    # mean_ann_pr = PR.groupby('time.year').sum(min_count=360).mean().item()
    
    # # KBDI initialization
    # KBDI = np.full(PR.shape,np.nan) # set all to nan
    # Q = KBDI.copy()
    # KBDI[day_int]=0   # set to 0 at saturation day t-1
    # # print(np.unique(KBDI)) # change to assert

    # # convert all inputs to scalar or numpy for speed
    # T_np = T.data
    
    
    # # KBDI calculation (inches, Fahrenheit)
    # denominator = 1 + 10.88 * np.exp(-0.0441*mean_ann_pr)
    # # looping in time, save memory
    # for it in range(day_int+1,KBDI.shape[0]):
    #     Q = max(0,KBDI[it-1] - pnet[it]*100)
    #     numerator = (800 - Q) * (0.968 * np.exp(0.0486*T_np[it]) - 8.3)
    #     KBDI[it] = Q + (numerator/denominator)*1E-3  
    #     KBDI[it] =  min(max(KBDI[it], 0.0), 800.)
    #     del numerator,Q
        
    # # convert to xarray
    # KBDI = xr.DataArray(KBDI, coords = {'time':('time',PR.time.data)})
    # KBDI.coords['time_index'] = ('time',time_index)
    # return KBDI
    return landmask

In [None]:
landmask =0
if landmask:
    print(landmask)
else:
    print(False)

In [None]:
def kbdi(T,PR,pr_ann):
    # this is a single grid calculation on 1D numpy timeseries arrays
    # in the aggregated result (on the full grid) apply_ufunc reorders the core dimension (time) to the last axis
    
    landmask=1 if np.isfinite(PR.mean()) else 0

    if landmask:
        # sum precip in 7 day rolling windows
        ndays=7
        pr_thresh=8 # inches
        pr_weeksum = np.convolve(PR, np.ones(ndays), mode='valid')#.astype('float32')
        pr_weeksum = np.concatenate([np.full(ndays - 1, np.nan), pr_weeksum])#.astype('float32')
    
        # get the first index time where the weekly sum meets the threshold
        # this is where we'll set KBDI to 0
        try:
            day_int = np.where(pr_weeksum > pr_thresh)[0][0].astype('int32')
        except:
            day_int = (-100000).astype('int32')
            # return np.full_like(PR, -1,dtype='int32')
    
        # calculate number of consecutive rain days at each time step
        # I got this code to interrupt a cumulative sum here: https://stackoverflow.com/questions/61753567/convert-cumsum-output-to-binary-array-in-xarray
        rainmask = np.where(PR > 0, 1, 0).astype('int32')
        cumsum = np.cumsum(rainmask)
        reset = np.where(rainmask == 0, cumsum, np.nan)
        reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset))
        rr = cumsum - reset
    
        # categorize rainfall days: consecutive rain days (2), single rain days (1), no rain days (0)
        # label all days that are at least the 3rd consecutive rainfall day with a 5
        cat=np.where(rr>=3,5,rr)
        # find the indexes of all second consecutive rainfall days
        consec_day2 = np.argwhere(rr==2).flatten()
        # find the indexes of all the first consecutive rainfall days
        consec_day1 = consec_day2-1
        # label all consecutive rainfall days with a 2 
        cat[consec_day2]=5 # first put a 5
        cat[consec_day1]=5 # first put a 5
        cat = np.where(cat==5,2,cat) # convert to 2's
        # should be left with only 2's (consecutive rain days), 1's (single rain days), and 0's (no rain days)
        # np.unique(cat) # change to assert    
        
        # Calc Pnet 
        acc_thresh = 0.2 # inches
        # Category 0 (no rain days)
        # pr_np = PR.copy().data
        pnet = PR.copy()
        # Category 1 (single rain days)
        pnet = np.where(cat==1,pnet-acc_thresh,pnet)
        pnet = np.where(pnet<0,0,pnet)
        # Category 2 (consecutive rain days) 
        consec_inds = np.argwhere(cat==2).flatten()
        # initializations
        thresh_flag=False
        end_event=False
        accpr=0. 
        # loop through days in each multi-day rain event
        for i,ind in enumerate(consec_inds): 
            # accumulated precip per rain event
            accpr=accpr+PR[ind]
            # if not over the threshold yet, Pnet is 0
            if accpr<=acc_thresh and not thresh_flag:
                pnet[ind]=0        
            # on the day the threshold is met, subtract the threshold amount and change flag    
            elif accpr>acc_thresh and not thresh_flag:
                accpr=accpr-acc_thresh # accumulate precip and subtract threshold
                pnet[ind]=accpr
                thresh_flag=True        
            # any days after the threshold is met, precip will remain unchanged
            else:
                pnet[ind]=PR[ind]     
            # reset accumulation and flag for the next consecutive rain event
            if i != len(consec_inds)-1:
                if (consec_inds[i+1] != consec_inds[i]+1): 
                    accpr=0.
                    thresh_flag=False   
                    
        # KBDI initialization
        KBDI = np.full(PR.shape,np.nan).astype('float32') # set all to nan
        KBDI[day_int]=0   # set to 0 at saturation day t-1    
        
        # KBDI calculation (inches, Fahrenheit)
        denominator = 1 + 10.88 * np.exp(-0.0441*pr_ann)
        # looping in time, save memory
        for it in range(day_int+1,KBDI.shape[-1]):
            Q = max(0,KBDI[it-1] - pnet[it]*100)
            numerator = (800 - Q) * (0.968 * np.exp(0.0486*T[it]) - 8.3)
            KBDI[it] = Q + (numerator/denominator)*1E-3  
            KBDI[it] =  min(max(KBDI[it], 0.0), 800.)
            del numerator,Q
    else:
        KBDI = (np.nan).astype('float32')
    
    return KBDI

def test(T,PR,pr_ann,dim):
    return xr.apply_ufunc(
    kbdi,
    T,
    PR,
    pr_ann,
    input_core_dims=[[dim], [dim],[]],
    output_core_dims=[[dim]],
    vectorize=True, # apply to all grid cells
    dask="parallelized", # because we're inputting chunked dask arrays
    output_dtypes=[np.float32])

# def test(T,PR,pr_ann,dim):
#     return xr.apply_ufunc(
#     kbdi,
#     T,
#     PR,
#     pr_ann,
#     input_core_dims=[[dim], [dim],[]],
#     output_core_dims=[[dim]],
#     output_dtypes=[np.float32])

### xarray inputs in memory

In [None]:
# lazy
# pr = xr.open_dataset(pr_file, chunks=-1).prcp.sel(time=slice(year_start, year_end)),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
# pr = xr.open_dataset(pr_file, chunks={'time':-1,'lat':24,'lon':24}).prcp.sel(time=slice(year_start, year_end))#.isel(lat=slice(0,96),lon=slice(0,96))
# pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end)).isel(lat=slice(0,96),lon=slice(0,96)).chunk({'time':-1,'lat':24,'lon':24})
# chunks = {'time':-1,'lat':24,'lon':24}
chunks = -1
pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)
# pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end)).chunk(chunks)
pr = (pr / 25.4).round(2)  # convert to inches

# # tmax = xr.open_dataset(tmax_file, chunks=-1).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
# tmax = xr.open_dataset(tmax_file, chunks={'time':-1,'lat':48,'lon':48}).tmax.sel(time=slice(year_start, year_end))
tmax = xr.open_dataset(tmax_file).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit

pr

In [None]:
%%time
# pr.isel(time=0).plot()
# pr.groupby('time.year').sum(min_count=360).mean('year')


In [None]:
%%time
mean_ann_pr = pr.groupby('time.year').sum(min_count=360).mean('year').compute() 
mean_ann_pr


In [None]:
# %%time
# # time on a single grid
# test1 = tmax.sel(time=slice(year_start,year_end)).sel(lat=32,lon=-90, method='nearest').load()
# test2 = pr.sel(time=slice(year_start,year_end)).sel(lat=32,lon=-90, method='nearest').load()
# result1 = test(test1, test2, mean_ann_pr.sel(lat=32,lon=-90, method='nearest').item(), "time").compute()
# result1

In [None]:
%%time
# 48x48 in 4 chunks 3m29s
# 48x48 in 1 chunk 3m29s
# full grid time est = 6.5 hours
result = test(tmax, pr, mean_ann_pr, "time").compute()
result

In [None]:
# write file

In [None]:
year_start='1951'
year_end='2024'
lat = 32
lon = -90

test_result = result.sel(time=slice(year_start,year_end)).sel(lat=lat,lon=lon, method='nearest')
# look at min/max
print(test_result.min().item(),test_result.max().item())


# plot KBDI

test_result.plot(figsize=(20,3))
plt.axhline(200,color='lightgreen',ls='dashed')
plt.axhline(400,color='gold',ls='dashed')
plt.axhline(600,color='firebrick',ls='dashed')

In [None]:
test_result.to_netcdf('kbdi_singlegrid_applyufuncresult.nc')

In [None]:
%%time
# compute, hold in memory
pr = pr.compute()
tmax = tmax.compute()

In [None]:
# pr

In [None]:
%%time

# Apply to all grid points

# kbdi = xr.apply_ufunc(
#     kbdi_single_grid,
#     tmax,
#     pr,
#     input_core_dims=[["time"], ["time"]],
#     output_core_dims=[["time"]],
#     vectorize=True,
#     dask="parallelized",
#     output_dtypes=[float],
# )


kbdi = xr.apply_ufunc(
    test,
    tmax,
    pr,
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[["time"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[float],
)

# kbdi = xr.apply_ufunc(
#     test,
#     tmax,
#     pr,
#     input_core_dims=[["time"], ["time"]],
#     vectorize=True,
#     dask="parallelized",
#     output_dtypes=[float],
# )

In [None]:
%%time 
kbdi.compute()

***54s for data inputs in unchunked xarray data arrays of shape (26907, 24, 24), 59MB*** This is about 1/500th of the grid, so this is too slow. We need an order of magnitude faster

***3m 18s for data inputs in chunked xarray data arrays of shape (26907, 48, 48), 236MB (59x4)*** It is not scaling linearly


### dask inputs lazy


In [None]:
del pr,tmax,kbdi

In [None]:
# lazy
chunks = {'time':-1,'lat':24,'lon':24}
pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)#.round(2)
pr = (pr / 25.4).round(2)  # convert to inches

tmax = xr.open_dataset(tmax_file).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)#.round(2)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit
pr

In [None]:
# lazy
# Apply to all grid points

# kbdi = xr.apply_ufunc(
#     kbdi_single_grid,
#     tmax,
#     pr,
#     input_core_dims=[["time"], ["time"]],
#     output_core_dims=[["time"]],
#     vectorize=True,
#     dask="parallelized",
#     output_dtypes=[float],
# )

kbdi = xr.apply_ufunc(
    calc_kbdi,
    tmax,
    pr,
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[["time"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[float],
)

In [None]:
%%time
kbdi.compute()

***46s for data inputs in a single chunk dask array of shape (26907, 24, 24), 59MB*** 

***3m 9s for data inputs in a single chunk dask array of shape (26907, 48, 48), 236MB*** this scales approxiately linearly

### the same method but using LocalCluster

In [None]:
del pr,tmax,kbdi

In [None]:
from dask.distributed import Client,LocalCluster

nworkers=20
cluster=LocalCluster(n_workers=nworkers,threads_per_worker=1) # a cluster where each thread is a separate process or "worker"
client=Client(cluster)  # connect to your compute cluster
client.wait_for_workers(n_workers=nworkers,timeout=10) # wait up to 10s for the cluster to be fully ready, error if not ready in 10s
client # print info

In [None]:
# lazy
pr = xr.open_dataset(pr_file, chunks=-1).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
pr = (pr / 25.4).round(2)  # convert to inches

tmax = xr.open_dataset(tmax_file, chunks=-1).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))#.round(2)
tmax = ((tmax * 9 / 5) + 32).round(2)  # convert to Fahrenheit

In [None]:
# lazy
# Apply to all grid points

kbdi = xr.apply_ufunc(
    kbdi_single_grid,
    tmax,
    pr,
    input_core_dims=[["time"], ["time"]],
    output_core_dims=[["time"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[float],
)

In [None]:
%%time
kbdi.compute()

***untested for data inputs in a single chunk dask array of shape (26907, 24, 24), 59MB*** 

***memory crash for data inputs in a single chunk dask array of shape (26907, 48, 48), 236MB*** 

# METHOD: dask delayed

This method sends chunks to dask delayed and performs vectorized compute on each chunk. Everything that can be computed outside of the time loop should be in a separate function.

In [None]:
# units conversion functions
def prep_pr(p):
    p = (p / 25.4).round(2)
    p.coords['time_index']=('time',time_index)
    return p

def convert_t_units(t):
    t = ((t * 9 / 5) + 32).round(2)
    return t

In [None]:
# mean annual precip function
def mean_ann_precip(p):
    mean_ann_pr = p.groupby('time.year').sum(min_count=360).mean('year')
    return mean_ann_pr

In [None]:
landmask = np.where(np.isfinite(mean_ann_pr),1,0).astype('int32')

In [None]:
# initialization function
# def init_date(ndays,thresh,p,landmask):
def init_date(p,lm):
    ndays=7
    thresh=8 # inches
    # find number of grids with data
    # mask=xr.where(np.isfinite(p.mean('time')),1,0)
    # ngrids = mask.sum().data

    # rolling sum
    p_rollsum=p.rolling(time=ndays,min_periods=ndays,center=False).sum('time')   
    # pr_rollsum = np.convolve(p, np.ones(ndays), mode='valid').astype('float32')
    # pr_rollsum = np.concatenate([np.full(ndays - 1, np.nan), pr_rollsum]).astype('float32')

    # quantify how many grids are never saturated
    threshmask = xr.where((p_rollsum>=thresh).sum('time')>0,1,0) # 1=init date found, 0=no init date found
    nbad = xr.where((mask)&(threshmask==0),1,0).sum().data # 1=grid on land with data but no init date found
    # threshmask = np.where((p_rollsum>=thresh).sum(axis=0)>0,1,0) # 1=init date found, 0=no init date found
    # nbad = np.where((lm)&(threshmask==0),1,0).sum().data # 1=grid on land with data but no init date found
    
    if nbad == 0:      
        return xr.where(p_rollsum>thresh,p_rollsum.time_index,np.nan).min('time')
        # return np.where(p_rollsum>thresh,p_rollsum.time_index,np.nan).min('time')
        # return p_rollsum[p_rollsum>thresh].isel(time=0).time_index.item()
    else:
        return nbad        


In [None]:
pr_delayed = pr.to_delayed().ravel()
pr_delayed

In [None]:
# rain category function

def rain_cat(p):

    # Rain mask and consecutive rain days
    # rainmask = np.where(p > 0, 1, 0).astype('int32')
    # cumsum = np.cumsum(rainmask, axis=0)
    # reset = np.where(rainmask == 0, cumsum, np.nan)
    # reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset))
    # rr = cumsum - reset

    # rainmask=xr.where(p>0,1,0).astype('int32')
    # rr=rainmask.cumsum('time')-rainmask.cumsum('time').where(rainmask == 0).ffill(dim='time').fillna(0)    
    rainmask=xr.where(p>0,1,0).astype('int8')
    temp=rainmask.cumsum('time').astype('int8')
    rr=temp-temp.where(rainmask == 0).ffill(dim='time').fillna(0).astype('int8')

    cat = np.where(rr >= 3, 5, rr)
    consec_day2 = np.where(rr.data == 2)
    consec_day2 = [arr.astype('int32') for arr in consec_day2]
    consec_day1 = [arr-1 for arr in consec_day2]
    cat[consec_day2] = 5
    cat[consec_day1] = 5
    cat = np.where(cat == 5, 2, cat)
    # # Categorize rain days
    # cat = np.where(rr >= 3, 5, rr)
    # consec_day2 = np.where(rr == 2)[0].astype('int32')
    # consec_day1 = consec_day2 - 1
    # cat[consec_day2] = 5
    # cat[consec_day1] = 5
    # cat = np.where(cat == 5, 2, cat)
    return cat

In [None]:
# pnet function

In [None]:
# kbdi function

In [None]:
%%time
# main code

# %%time these
#1) get one cell for the time dim, create time index, delay it
time = xr.open_dataset(pr_file).time.sel(time=slice(year_start, year_end))
time_index = np.arange(len(time),dtype='float32')
timeind_delay = dask.delayed(time_index)

#2) lazy read pr into chunked object
chunks = {'time':-1,'lat':24,'lon':24}
pr = xr.open_dataset(pr_file).prcp.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2)).chunk(chunks)

#3) call units function
pr_inch = prep_pr(pr)
pr_inch = pr_inch.compute()
# pr_inch
# %%time these
#5) call delayed initialization function
day_int = init_date(7,8,pr_inch)

#6) call dealyed rain category function (time this because it may be faster to start at day_int)
cat = rain_cat(pr_inch)#.compute()
cat

#7) call delayed pnet function
# clean up
# lazy read t into chunk object
# tmax = xr.open_dataset(tmax_file, chunks=chunks).tmax.sel(time=slice(year_start, year_end),lat=slice(lat1,lat2), lon=slice(lon1,lon2))
# 
#8) call delayed kbdi function


In [None]:
# rainmask=xr.where(pr_inch>0,1,0).astype('int8')
# rr=rainmask.cumsum('time')-rainmask.cumsum('time').where(rainmask == 0).ffill(dim='time').fillna(0)

rainmask=xr.where(pr_inch>0,1,0).astype('int8')
temp=rainmask.cumsum('time').astype('int8')
rr=temp-temp.where(rainmask == 0).ffill(dim='time').fillna(0).astype('int8')
rr


In [None]:
%%time
# rr = rr.compute()
# rr

In [None]:
%%time
# input to this needs to be in-memory xarray object rr, less than 2s to compute, 45 with dask arrays
cat = np.where(rr >= 3, 5, rr)
consec_day2 = np.where(rr.data == 2)
consec_day2 = [arr.astype('int32') for arr in consec_day2]
consec_day1 = [arr-1 for arr in consec_day2]
cat[consec_day2] = 5
cat[consec_day1] = 5
cat = np.where(cat == 5, 2, cat)

In [None]:
%%time
consec_day2 = np.where(rr.data == 2)

In [None]:
%%time
consec_day2 = [arr.astype('int32') for arr in consec_day2]


In [None]:
%%time
consec_day1 = [arr-1 for arr in consec_day2]


In [None]:
%%time
cat[consec_day2] = 5


In [None]:
%%time
cat[consec_day1] = 5


In [None]:
%%time
cat = np.where(cat == 5, 2, cat)

In [None]:
%%time
# Pnet calculation
acc_thresh = 0.2  # inches
pnet = np.copy(pr_inch)
pnet[cat == 1] = pnet[cat == 1] - acc_thresh
pnet = np.where(pnet < 0, 0, pnet)
pnet
# # Adjust for consecutive rain days
# consec_inds = np.where(cat == 2)[0].astype('int32')
# accpr = 0.0
# thresh_flag = False

# for i, ind in enumerate(consec_inds):
#     accpr += PR[ind]
#     if accpr <= acc_thresh and not thresh_flag:
#         pnet[ind] = 0
#     elif accpr > acc_thresh and not thresh_flag:
#         accpr -= acc_thresh
#         pnet[ind] = accpr
#         thresh_flag = True
#     else:
#         pnet[ind] = PR[ind]
#     if i != len(consec_inds) - 1 and consec_inds[i + 1] != consec_inds[i] + 1:
#         accpr = 0.0
#         thresh_flag = False

In [None]:
np.unique(cat)

In [None]:
[arr-1 for arr in test2]

In [None]:
def kbdi_single_grid(tmax_1d, pr_1d):
    # Ensure inputs are NumPy arrays
    T = np.asarray(tmax_1d)
    PR = np.asarray(pr_1d)

    if not np.all(np.isfinite(PR)):
        return np.full_like(PR, np.nan)

    # Create a time index
    time_index = np.arange(len(PR),dtype='float32')

    # 7-day rolling precipitation sum
    ndays = 7
    pr_thresh = 8.0  # inches
    pr_weeksum = np.convolve(PR, np.ones(ndays), mode='valid').astype('float32')
    pr_weeksum = np.concatenate([np.full(ndays - 1, np.nan), pr_weeksum]).astype('float32')

    try:
        day_int = np.where(pr_weeksum > pr_thresh)[0][0].astype('int32')
    except IndexError:
        return np.full_like(PR, np.nan)

    # Rain mask and consecutive rain days
    rainmask = np.where(PR > 0, 1, 0).astype('int32')
    cumsum = np.cumsum(rainmask)
    reset = np.where(rainmask == 0, cumsum, np.nan)
    reset = np.maximum.accumulate(np.where(np.isnan(reset), -1, reset))
    rr = cumsum - reset

    # Categorize rain days
    cat = np.where(rr >= 3, 5, rr)
    consec_day2 = np.where(rr == 2)[0].astype('int32')
    consec_day1 = consec_day2 - 1
    cat[consec_day2] = 5
    cat[consec_day1] = 5
    cat = np.where(cat == 5, 2, cat)

    # Pnet calculation
    acc_thresh = 0.2  # inches
    pnet = np.copy(PR)
    pnet[cat == 1] = pnet[cat == 1] - acc_thresh
    pnet = np.where(pnet < 0, 0, pnet)

    # Adjust for consecutive rain days
    consec_inds = np.where(cat == 2)[0].astype('int32')
    accpr = 0.0
    thresh_flag = False

    for i, ind in enumerate(consec_inds):
        accpr += PR[ind]
        if accpr <= acc_thresh and not thresh_flag:
            pnet[ind] = 0
        elif accpr > acc_thresh and not thresh_flag:
            accpr -= acc_thresh
            pnet[ind] = accpr
            thresh_flag = True
        else:
            pnet[ind] = PR[ind]
        if i != len(consec_inds) - 1 and consec_inds[i + 1] != consec_inds[i] + 1:
            accpr = 0.0
            thresh_flag = False

    # # Mean annual precipitation (approximate)
    # days_per_year = 365
    # n_years = len(PR) // days_per_year
    # ann_pr = []
    # for i in range(n_years):
    #     year_data = PR[i * days_per_year:(i + 1) * days_per_year]
    #     if np.count_nonzero(~np.isnan(year_data)) >= 360:
    #         ann_pr.append(np.nansum(year_data))
    # if len(ann_pr) == 0:
    #     return np.full_like(PR, np.nan)
    # mean_ann_pr = np.mean(ann_pr)

    # KBDI calculation
    KBDI = np.full_like(PR, np.nan)
    if day_int < len(PR):
        KBDI[day_int] = 0
    else:
        return KBDI

    denominator = 1 + 10.88 * np.exp(-0.0441 * mean_ann_pr)
    for it in range(day_int + 1, len(PR)):
        Q = max(0, KBDI[it - 1] - pnet[it] * 100)
        numerator = (800 - Q) * (0.968 * np.exp(0.0486 * T[it]) - 8.3)
        KBDI[it] = Q + (numerator / denominator) * 1e-3
    
    return KBDI