In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from numba import jit
import dask
import dask.array as da
import matplotlib.pyplot as plt
import gc

In [None]:
# This version is also optimized to run in the memory available on my laptop

@jit(nopython=True)
def _calc_pnet_consecutive_numba(pr_data, cat_data, acc_thresh=0.2):
    """Net rainfall for days that belong to multi-day rain events.

    For every run of days flagged ``cat_data == 2`` the rainfall is summed
    day by day:
      * until the running total first exceeds ``acc_thresh`` the net rain is 0;
      * on the day it is exceeded the net rain is the excess
        (running total minus ``acc_thresh``);
      * later days of the same run keep their full rainfall.
    Days outside such runs are copied unchanged from ``pr_data``.

    Parameters
    ----------
    pr_data : 1-D array
        Daily precipitation.
    cat_data : 1-D integer array
        Rain-day category per day; only the value 2 is acted on here.
    acc_thresh : float
        Amount removed once per consecutive-rain run.

    Returns
    -------
    A copy of ``pr_data`` with the consecutive-run days replaced by net rain.
    """
    net = pr_data.copy()
    n_days = len(pr_data)

    day = 0
    while day < n_days:
        if cat_data[day] != 2:
            day += 1
            continue

        # Walk one run of consecutive rain days, resolving net rain as we go.
        running_total = np.float32(0)
        threshold_passed = False
        while day < n_days and cat_data[day] == 2:
            running_total += pr_data[day]
            if not threshold_passed and running_total <= acc_thresh:
                net[day] = np.float32(0)
            elif not threshold_passed and running_total > acc_thresh:
                running_total -= acc_thresh
                net[day] = running_total
                threshold_passed = True
            else:
                # Threshold already consumed earlier in this run (or the
                # running total is NaN): keep the raw rainfall.
                net[day] = pr_data[day]
            day += 1

    return net


@jit(nopython=True, fastmath={'nsz', 'arcp', 'contract', 'afn', 'reassoc'})
def _calc_kbdi_timeseries_numba(T_data, pnet_data, mean_ann_pr, day_int):
    """Integrate the KBDI recurrence for a single grid point.

    The result is NaN before ``day_int``, 0 on ``day_int``, and for each
    later day ``it``::

        Q        = max(0, KBDI[it-1] - 100 * pnet_data[it])
        KBDI[it] = Q + (800 - Q) * (0.968 * exp(0.0486 * T_data[it]) - 8.3)
                       * 1e-3 / (1 + 10.88 * exp(-0.0441 * mean_ann_pr))

    Parameters
    ----------
    T_data : 1-D float array
        Daily temperature (the caller in this notebook passes degrees F).
    pnet_data : 1-D float array
        Daily net rainfall (the caller passes inches; scaled by 100 here).
    mean_ann_pr : float
        Mean annual precipitation (inches, as passed by the caller).
    day_int : int
        Index of the start day. If it lies outside ``[0, len(T_data))`` the
        whole result is NaN.

    Returns
    -------
    1-D float32 array of length ``len(T_data)``.

    Notes
    -----
    ``fastmath`` is limited to flags that keep IEEE NaN/Inf semantics
    ('nnan' and 'ninf' are deliberately left out). The output is NaN-filled
    on purpose and the inputs may contain NaN (missing data). With the full
    ``fastmath=True``, LLVM may assume no NaNs exist, which makes those
    results unspecified. The recurrence is serial, so the dropped flags cost
    essentially nothing.
    """
    n_time = len(T_data)
    KBDI = np.full(n_time, np.float32(np.nan))

    if day_int >= 0 and day_int < n_time:
        KBDI[day_int] = np.float32(0)

        # The denominator depends only on mean_ann_pr, so compute it once;
        # the 1e-3 scale factor is folded into the reciprocal as well.
        denominator = np.float32(1 + 10.88 * np.exp(-0.0441 * mean_ann_pr))
        inv_denominator = np.float32(1e-3) / denominator

        for it in range(day_int + 1, n_time):
            # Remove today's net rain (x100) from yesterday's value, floored at 0.
            # NOTE(review): if KBDI[it-1] is NaN (e.g. a missing temperature
            # the day before), this max() appears to yield 0 rather than NaN,
            # silently restarting the index -- confirm that is intended.
            Q = max(np.float32(0), KBDI[it-1] - pnet_data[it] * np.float32(100))
            # NOTE(review): the temperature term (0.968*exp(0.0486*T) - 8.3)
            # is negative below roughly 44 degF, so KBDI[it] can drop below Q
            # (and below 0). Confirm whether it should be clamped.
            numerator = np.float32((800 - Q) * (0.968 * np.exp(0.0486 * T_data[it]) - 8.3))
            KBDI[it] = Q + numerator * inv_denominator

    return KBDI


@jit(nopython=True)
def _calculate_consecutive_rain_categories_numba(rainmask):
    """Label each day by the length of the rain run it belongs to.

    Returns an int8 array the same length as ``rainmask`` where
      0 -> not a rain day (value not > 0),
      1 -> an isolated rain day (run of length one),
      2 -> part of a run of two or more consecutive rain days.
    """
    n_days = len(rainmask)
    categories = np.zeros(n_days, dtype=np.int8)

    run_start = 0
    while run_start < n_days:
        # Phrased as `not (> 0)` so NaN entries count as dry, matching `> 0`.
        if not (rainmask[run_start] > 0):
            run_start += 1
            continue

        # Advance run_end to one past the last wet day of this run.
        run_end = run_start
        while run_end < n_days and rainmask[run_end] > 0:
            run_end += 1

        if run_end - run_start == 1:
            categories[run_start] = 1
        else:
            for day in range(run_start, run_end):
                categories[day] = 2

        run_start = run_end

    return categories


def calc_kbdi_vectorized_optimized(T, PR, ndays=7, pr_thresh=8.0, acc_thresh=0.2):
    """
    Optimized vectorized KBDI calculation for multiple grid points.

    Parameters:
    -----------
    T : xarray.DataArray
        Temperature data with dimensions (time, lat, lon) in Fahrenheit
    PR : xarray.DataArray
        Precipitation data with dimensions (time, lat, lon) in inches
    ndays : int, optional
        Length of the rolling precipitation window used to find the start
        day (default 7).
    pr_thresh : float, optional
        Rolling-window total (inches) that must be exceeded for KBDI to be
        initialised to 0 at that grid point (default 8.0).
    acc_thresh : float, optional
        Rainfall (inches) removed once per rain event before it counts as
        net rain (default 0.2).

    Returns:
    --------
    KBDI : xarray.DataArray
        float32 KBDI values on PR's time axis, with 'time' as the first
        dimension followed by the remaining (spatial) dimensions. Grid points
        whose mean annual precipitation is NaN, or whose rolling total never
        exceeds ``pr_thresh``, are all-NaN.
    """
    ndays = int(ndays)
    pr_thresh = np.float32(pr_thresh)
    acc_thresh = np.float32(acc_thresh)

    # Shared positional time index. The per-point computation below runs
    # along 'time_index' rather than 'time', so T and PR are matched by
    # position (equal length required) instead of by time label.
    time_index = np.arange(len(PR.time), dtype=np.int32)
    PR = PR.assign_coords(time_index=('time', time_index))
    T = T.assign_coords(time_index=('time', time_index))

    # Rolling ndays-day precipitation total (NaN until a full window exists).
    # Note: Rolling.sum takes no dimension argument (the first positional
    # parameter is keep_attrs), so none is passed.
    pr_weeksum = PR.rolling(time=ndays, min_periods=ndays, center=False).sum()

    # Mean of annual totals; years with fewer than 360 valid days sum to NaN
    # and are skipped by the mean.
    mean_ann_pr = PR.groupby('time.year').sum(min_count=360).mean('year')

    def find_first_saturation_day_optimized(pr_week_1d):
        """Index of the first day whose rolling total exceeds pr_thresh, or -1."""
        exceeds = pr_week_1d > pr_thresh  # NaN compares False
        if not exceeds.any():
            return -1
        return int(np.argmax(exceeds))

    # Start day per grid point.
    saturation_days = xr.apply_ufunc(
        find_first_saturation_day_optimized,
        pr_weeksum,
        input_core_dims=[['time']],
        output_dtypes=[np.int32],
        vectorize=True,
        dask='parallelized')

    def process_single_point_optimized(pr_1d, t_1d, mean_ann_pr_val, sat_day):
        """KBDI time series for one lat/lon point (all-NaN if it cannot start)."""
        if np.isnan(mean_ann_pr_val) or sat_day < 0:
            return np.full(len(pr_1d), np.float32(np.nan))

        # 1 on days with positive rainfall, else 0.
        rainmask = (pr_1d > 0).astype(np.int8)

        # 0 = dry, 1 = isolated rain day, 2 = part of a multi-day rain run.
        cat_1d = _calculate_consecutive_rain_categories_numba(rainmask)

        # Net rain on multi-day runs (acc_thresh removed once per run).
        pnet_1d = _calc_pnet_consecutive_numba(pr_1d, cat_1d, acc_thresh)

        # Net rain on isolated rain days: acc_thresh removed, floored at 0.
        single_mask = (cat_1d == 1)
        pnet_1d = np.where(single_mask, np.maximum(np.float32(0), pnet_1d - acc_thresh), pnet_1d)

        return _calc_kbdi_timeseries_numba(t_1d, pnet_1d, mean_ann_pr_val, sat_day)

    KBDI = xr.apply_ufunc(
        process_single_point_optimized,
        PR.swap_dims({'time': 'time_index'}),
        T.swap_dims({'time': 'time_index'}),
        mean_ann_pr,
        saturation_days,
        input_core_dims=[['time_index'], ['time_index'], [], []],
        output_core_dims=[['time_index']],
        output_dtypes=[np.float32],
        vectorize=True,
        dask='parallelized'
    )

    # Back to the original time coordinate.
    KBDI = KBDI.swap_dims({'time_index': 'time'})
    KBDI = KBDI.assign_coords(time=PR.time)

    # apply_ufunc places core dimensions last, giving (..., time). Restore
    # the time-first layout of the inputs and drop the helper coordinate so
    # it does not leak into saved files.
    KBDI = KBDI.transpose('time', ...).drop_vars('time_index', errors='ignore')

    return KBDI


def calc_kbdi_chunked_processing(T, PR, spatial_chunk_size=10):
    """
    Process KBDI calculation in spatial chunks to minimize memory usage.

    Each (lat, lon) tile is loaded into memory, converted to the units the
    KBDI formula uses, and passed to calc_kbdi_vectorized_optimized. The
    tile results are then concatenated back onto the full grid.

    Parameters:
    -----------
    T : xarray.DataArray
        Temperature data with 'time', 'lat' and 'lon' dimensions. It is
        converted with F = C * 9/5 + 32, so it is presumably supplied in
        degrees Celsius.
    PR : xarray.DataArray
        Precipitation data on the same lat/lon grid as T. It is divided by
        25.4, so it is presumably supplied in millimetres.
    spatial_chunk_size : int
        Number of lat and lon points per tile.

    Returns:
    --------
    KBDI : xarray.DataArray
        KBDI values named 'kbdi', assembled from all tiles.

    Raises:
    -------
    ValueError
        If T and PR do not have the same number of lat and lon points.

    Notes:
    ------
    Only the per-tile working set shrinks with ``spatial_chunk_size``. Every
    tile's result is kept in memory and concatenated at the end.
    """
    n_lat, n_lon = T.sizes['lat'], T.sizes['lon']
    if (PR.sizes['lat'], PR.sizes['lon']) != (n_lat, n_lon):
        raise ValueError(
            f"T and PR grids differ: T is {n_lat}x{n_lon}, "
            f"PR is {PR.sizes['lat']}x{PR.sizes['lon']} (lat x lon)")

    lat_starts = range(0, n_lat, spatial_chunk_size)
    lon_starts = range(0, n_lon, spatial_chunk_size)

    total_chunks = len(lat_starts) * len(lon_starts)
    processed_chunks = 0

    print(f"Processing {total_chunks} spatial chunks...")

    # One concatenated latitude band per entry.
    lat_bands = []

    for lat_start in lat_starts:
        lat_end = min(lat_start + spatial_chunk_size, n_lat)
        lat_slice = slice(lat_start, lat_end)

        lon_row_chunks = []

        for lon_start in lon_starts:
            lon_end = min(lon_start + spatial_chunk_size, n_lon)
            lon_slice = slice(lon_start, lon_end)

            print(f"Processing chunk {processed_chunks + 1}/{total_chunks} "
                  f"(lat: {lat_start}-{lat_end}, lon: {lon_start}-{lon_end})")

            # Pull only this tile (full time axis) into memory.
            T_chunk = T.isel(lat=lat_slice, lon=lon_slice).load()
            PR_chunk = PR.isel(lat=lat_slice, lon=lon_slice).load()

            # Round, downcast, then convert to the units the KBDI formula uses.
            T_chunk = T_chunk.round(2).astype('float32')
            T_chunk = T_chunk * np.float32(9/5) + np.float32(32.0)  # Convert to Fahrenheit

            PR_chunk = PR_chunk.round(2).astype('float32')
            PR_chunk = PR_chunk * np.float32(1/25.4)  # Convert to inches

            kbdi_chunk = calc_kbdi_vectorized_optimized(T_chunk, PR_chunk)
            lon_row_chunks.append(kbdi_chunk)

            # Release the tile inputs before loading the next tile.
            del T_chunk, PR_chunk, kbdi_chunk
            gc.collect()

            processed_chunks += 1

        # Stitch this latitude band together along longitude.
        lat_bands.append(xr.concat(lon_row_chunks, dim='lon'))

        del lon_row_chunks
        gc.collect()

    print("Combining results...")
    kbdi_result = xr.concat(lat_bands, dim='lat')
    kbdi_result.name = 'kbdi'

    return kbdi_result


def calc_kbdi_memory_efficient(pr_file, tmax_file, year_start='1951', year_end='2024',
                              spatial_chunk_size=20):
    """
    Memory-efficient KBDI calculation for large datasets.
    Note: KBDI calculation is cumulative over time, so temporal chunking is not possible.

    Parameters:
    -----------
    pr_file : str
        Path to precipitation NetCDF file (variable 'prcp')
    tmax_file : str
        Path to temperature NetCDF file (variable 'tmax')
    year_start : str
        Start year for processing
    year_end : str
        End year for processing
    spatial_chunk_size : int
        Size of spatial chunks for processing (default: 20)

    Returns:
    --------
    KBDI : xarray.DataArray
        KBDI values, fully loaded into memory (no references to the input
        files remain after return).
    """
    import contextlib

    # Dask options, applied only for the duration of this call instead of
    # permanently changing the kernel's global dask configuration. The
    # 'distributed.*' entries only take effect when a dask.distributed
    # cluster is in use.
    dask_settings = {
        'array.chunk-size': '50MB',   # Smaller chunks for better memory control
        'array.slicing.split_large_chunks': True,
        'distributed.worker.memory.target': 0.6,  # Use 60% of available memory
        'distributed.worker.memory.spill': 0.75,  # Spill to disk at 75%
        'distributed.worker.memory.pause': 0.85   # Pause at 85%
    }

    # Spatial chunks match the processing tiles; time stays a single chunk.
    load_chunks = {'lat': spatial_chunk_size, 'lon': spatial_chunk_size, 'time': -1}

    # Both files stay open until the result is in memory. The arrays selected
    # from them are lazy and are only read tile by tile during processing, so
    # closing the datasets earlier would leave them pointing at closed files.
    with contextlib.ExitStack() as stack:
        stack.enter_context(dask.config.set(dask_settings))

        print("Loading precipitation data...")
        ds_pr = stack.enter_context(xr.open_dataset(pr_file, chunks=load_chunks))
        pr = ds_pr.prcp.sel(time=slice(year_start, year_end))

        print("Loading temperature data...")
        ds_tmax = stack.enter_context(xr.open_dataset(tmax_file, chunks=load_chunks))
        tmax = ds_tmax.tmax.sel(time=slice(year_start, year_end))

        print(f"Data shape: {pr.shape}")
        print(f"Processing with spatial chunks of size {spatial_chunk_size}x{spatial_chunk_size}")
        print("Note: Time dimension kept intact due to cumulative nature of KBDI calculation")

        # Process only spatially chunked (time must remain intact).
        kbdi_result = calc_kbdi_chunked_processing(tmax, pr, spatial_chunk_size)
        # No-op if everything is already in memory; guarantees nothing lazy
        # still references the files once the stack closes them.
        kbdi_result = kbdi_result.load()

    kbdi_result.attrs = {
        'standard_name': 'keetch_byram_drought_index',
        'long_name': 'Keetch-Byram Drought Index',
        'units': 'dimensionless'
    }

    return kbdi_result




In [None]:
%%time 
# Main execution with memory-efficient approach
# (In a notebook kernel __name__ is always "__main__", so this always runs.)
if __name__ == "__main__":
    pr_file = r'D://data/nclimgrid_daily/prcp_nClimGridDaily_1951-2024_USsouth.nc'
    tmax_file = r'D://data/nclimgrid_daily/tmax_nClimGridDaily_1951-2024_USsouth.nc'

    year_start = '1951'
    year_end = '2024'

    # Candidate spatial tile sizes (grid points per side), largest first. A
    # MemoryError falls through to the next, smaller size.
    # NOTE(review): smaller tiles shrink the per-tile working set, but every
    # tile's result is kept and concatenated in memory. Peak usage is
    # therefore at least the full (time, lat, lon) float32 result. For data
    # of size (26907, 358, 753) that is 26907*358*753*4 bytes, about 29e9 bytes.
    chunk_sizes_to_try = [128, 64]

    for chunk_size in chunk_sizes_to_try:
        try:
            print(f"\nAttempting with spatial_chunk_size={chunk_size}...")
            kbdi_result = calc_kbdi_memory_efficient(
                pr_file, tmax_file, year_start, year_end,
                spatial_chunk_size=chunk_size
            )

            print("Calculation successful! Saving results...")
            output_file = f'kbdi_nclimgrid_{year_start}-{year_end}.nc'

            # Compressed, chunked on-disk storage. The chunk shape is built
            # from kbdi_result.sizes so it follows the array's actual
            # dimension order (a fixed (time, lat, lon) tuple fails when the
            # dimensions are ordered differently). Each entry is capped at
            # that dimension's length; unknown dimensions use their full length.
            target_chunks = {'time': 365, 'lat': chunk_size, 'lon': chunk_size}
            encoding = {
                kbdi_result.name: {
                    'zlib': True,
                    'complevel': 4,
                    'chunksizes': tuple(
                        min(target_chunks.get(dim, size), size)
                        for dim, size in kbdi_result.sizes.items()
                    ),
                }
            }

            kbdi_result.to_netcdf(output_file, encoding=encoding)
            print(f"Results saved to {output_file}")
            print(f"Final data shape: {kbdi_result.shape}")
            break  # Success - exit the retry loop

        except MemoryError as e:
            print(f"Memory error with chunk_size={chunk_size}: {e}")
            if chunk_size == chunk_sizes_to_try[-1]:  # Last attempt
                print("ERROR: All chunk sizes failed. Consider:")
                print("1. Reducing the time range (process fewer years)")
                print("2. Processing subregions separately")
                print("3. Using a machine with more RAM")
                raise
            else:
                print(f"Trying smaller chunk size...")
                # Force garbage collection before next attempt
                gc.collect()
                continue

        except Exception as e:
            print(f"Unexpected error: {e}")
            raise

In [None]:
# Re-save the KBDI result produced by the previous cell.
# NOTE(review): relies on `kbdi_result` and `output_file` defined in the
# cell above (fails on a fresh kernel), and overwrites the file that cell
# already wrote, this time with xarray's default encoding.
kbdi_result.to_netcdf(path=output_file)