# Custom 1-in-X Analysis for Effective Temperature

## Table of Contents
- [Overview](#overview)
- [Key Features](#key-features)
- [Requirements](#requirements)
- [Configuration](#configuration)
- [Implementation Details](#implementation-details)
- [Usage Examples](#usage-examples)
- [Performance Considerations](#performance-considerations)
- [Troubleshooting](#troubleshooting)

## Overview

This notebook demonstrates how to calculate 1-in-X year return values for custom metrics, such as effective temperature, using extreme value analysis.

The workflow follows the `cava_data()` patterns from `climakitae.explore.vulnerability`, using a
purpose-built function, `calculate_1_in_x_custom()`.

The function currently takes an `xr.DataArray` or `xr.Dataset` of gridded data and a `pd.DataFrame`
of (lat, lon) locations at which to calculate the 1-in-X values.

> [!WARNING]
> **Thread Safety**: `DataInterface` does not play nicely with multi-threading. If you want to multi-thread your calls to `get_data()`, please use code cell 2 to monkey patch the `__init__()` method.

> [!TIP]
> For quiet output, set `VERBOSE = False` in the configuration section.

> [!IMPORTANT]
> This notebook is optimized for speed over precision. The number of bootstrap runs is set to 0, so the confidence intervals are not meaningful. Increase `bootstrap_runs` in `_process_one_simulation()` to obtain usable confidence intervals.

## Requirements

### Python Packages
- `climakitae` (latest version)
- `xarray`
- `pandas`
- `numpy`
- `concurrent.futures` (standard library; used for parallelization)

### System Requirements
- **Memory**: ~5 GB RAM per worker
- **Recommended**: JupyterHub with 30+ GB RAM for parallel processing
- **Runtime**: ~3.5 hours per batch of 50 locations, per worker

### Custom Metric

The custom metric used as the example in this notebook is effective temperature calculated 
as:  
`Teff = 0.7*Tmax0 + 0.003*Tmin0*Tmax1 + 0.002*Tmin1*Tmax2`

### Parallelization

`DataInterface` does not play nicely with multi-threading. If you want to multi-thread your calls to `get_data()`, either use code cell 2 to monkey-patch the `__init__()` method or update `climakitae` by installing the latest version from the main repository on GitHub.

### Run Time & Memory

On a JupyterHub machine with 30 GB of RAM, I can run at most 4 or 5 workers, since each
worker uses about 5 GB of RAM. Each worker takes around 3.5 hours to finish a batch of 50 locations.
The test set contains ~1,600 locations, or roughly 32 batches. With 4 batches running at a time,
that is 8 rounds of about 3.5 hours each, for a total run time of about 28 hours.
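The run-time estimate above is simple arithmetic; the sketch below reproduces it so it can be redone for a different number of locations or workers. `estimate_runtime_hours` is a hypothetical helper, not part of `climakitae`, and it assumes every batch takes roughly the same 3.5 hours.

```python
import math


def estimate_runtime_hours(n_locations, batch_size=50, n_workers=4, hours_per_batch=3.5):
    """Rough wall-clock estimate for batched, parallel processing (hypothetical helper).

    Assumes every batch takes about the same time and that workers only sit
    idle in the final, partially filled round.
    """
    n_batches = math.ceil(n_locations / batch_size)  # 1600 / 50 -> 32 batches
    n_rounds = math.ceil(n_batches / n_workers)      # 32 / 4 -> 8 rounds
    return n_rounds * hours_per_batch                # 8 * 3.5 -> 28 hours


print(estimate_runtime_hours(1600))               # 4 workers
print(estimate_runtime_hours(1600, n_workers=5))  # 5 workers: ceil(32 / 5) = 7 rounds
```

With 5 workers the estimate drops to about 24.5 hours, at the cost of roughly 25 GB of RAM in use.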

In [None]:
import concurrent.futures  # `import concurrent` alone does not load the `futures` submodule
import traceback
import pandas as pd
import xarray as xr
import numpy as np
from climakitae.core.data_interface import get_data, DataInterface
from climakitae.core.data_export import export

# Import functions for 1-in-X event calculations
from climakitae.explore.threshold_tools import (
    get_block_maxima,
    get_ks_stat,
    get_return_value,
)
from climakitae.core.constants import UNSET
from climakitae.util.utils import add_dummy_time_to_wl, get_closest_gridcells

from lat_lons import lat_lons

VERBOSE = True

data_interface = DataInterface() # create to avoid race conditions
# Verify everything loaded correctly
if VERBOSE:
    print("✓ DataInterface initialized successfully!")
    print(f"  Data catalog loaded: {len(data_interface.data_catalog.df)} entries")
    print(f"  Variable descriptions loaded: {len(data_interface.variable_descriptions)} variables")
    print(f"  Stations loaded: {len(data_interface.stations)} stations")
    print(f"  Boundaries loaded:")
    print(f"    - US States: {len(data_interface.geographies._us_states)} states")
    print(f"    - CA Counties: {len(data_interface.geographies._ca_counties)} counties")
    print(f"    - CA Watersheds: {len(data_interface.geographies._ca_watersheds)} watersheds")
    print(f"    - CA Utilities: {len(data_interface.geographies._ca_utilities)} utilities")
    print()
    print("🎉 Now safe to use multi-threading with get_data()!")
    print("="*60)


## Monkey Patch `DataInterface.__init__` for Thread Safety

When running parallel batch processing (e.g., with `ThreadPoolExecutor`), multiple threads may attempt to instantiate the `DataInterface` at the same time. The original `__init__` method is not thread-safe and can lead to race conditions, causing errors or inconsistent state if two threads initialize it simultaneously.

**Monkey-patching** `__init__` with a thread lock ensures that initialization runs only once, no matter how many threads try to create a `DataInterface` instance. The lock serializes the first initialization, and every later call returns immediately, which prevents duplicate resource loading and partially initialized state in multi-threaded workflows.
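The sketch below demonstrates the init-once pattern on a hypothetical stand-in class, `ExpensiveResource`, rather than on `DataInterface` itself, so it runs without `climakitae`. The `_demo_*` names are illustrative and deliberately differ from the names used in the real patch cell so the two cannot collide if both are run in the same kernel.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class ExpensiveResource:
    """Hypothetical stand-in for a class whose __init__ is slow and not thread-safe."""

    init_calls = 0  # counts how many times the real initialization ran

    def __init__(self):
        ExpensiveResource.init_calls += 1
        self.catalog = {"entries": 42}  # pretend this was loaded from disk or network


_demo_original_init = ExpensiveResource.__init__
_demo_lock = threading.Lock()
_demo_initialized = False


def _demo_patched_init(self, *args, **kwargs):
    global _demo_initialized
    with _demo_lock:            # only one thread may be in here at a time
        if _demo_initialized:   # later callers skip the expensive work entirely
            return
        _demo_original_init(self, *args, **kwargs)
        _demo_initialized = True


ExpensiveResource.__init__ = _demo_patched_init

# 32 concurrent constructions from 8 threads
with ThreadPoolExecutor(max_workers=8) as pool:
    instances = list(pool.map(lambda _: ExpensiveResource(), range(32)))

print(ExpensiveResource.init_calls)  # the real __init__ ran exactly once
```

Only the first instance runs the real `__init__`; the other 31 skip it and never get a `catalog` attribute. The pattern is therefore only safe when the class shares its loaded state across instances, which is worth confirming for your installed version of `DataInterface`.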

## Additional Resolutions

The same fix applies to any attempt to parallelize data-access calls in `climakitae`.
Include the following cell wherever data access needs to be multi-threaded, including
`cava_data()`, `get_data()`, `DataParameters.retrieve()`, etc.

In [None]:
# ADVANCED: Monkey-patch to prevent re-initialization
# Only use this if kernel restart + STEP 1 doesn't work
import threading

# Save the original __init__
_original_init = DataInterface.__init__

# Create a lock for thread safety
_init_lock = threading.Lock()
_initialized = False

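def _patched_init(self, *args, **kwargs):
    """Patched __init__ that runs the original initializer only once.

    Later calls return without initializing `self`, so this assumes
    DataInterface shares its loaded state across instances (e.g. as a
    singleton); otherwise those later objects would be left uninitialized.
    """
    global _initialized
    
    with _init_lock:
        if _initialized:
            # Already initialized: skip reloading catalogs and boundaries
            return
        
        # First call: run the original __init__ while holding the lock
        _original_init(self, *args, **kwargs)
        _initialized = True
        if VERBOSE:
            print("✓ DataInterface initialized (first time only)")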

# Apply the patch
DataInterface.__init__ = _patched_init

if VERBOSE:
    print("✓ Monkey-patch applied - DataInterface will only initialize once")
    print("Now run STEP 1 to initialize it:")
    print("="*60)

## Custom Metric Definition

The effective temperature (T<sub>eff</sub>) is calculated using a weighted combination of current and lagged temperature values:

**Formula:**
```
T_eff = 0.7 × T_max(0) + 0.003 × T_min(0) × T_max(1) + 0.002 × T_min(1) × T_max(2)
```

**Where:**
| Variable | Description |
|----------|-------------|
| T_max(0) | Current day maximum temperature |
| T_min(0) | Current day minimum temperature |
| T_max(1) | Previous day maximum temperature (1-day lag) |
| T_min(1) | Previous day minimum temperature (1-day lag) |
| T_max(2) | Maximum temperature from 2 days ago (2-day lag) |
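To make the lag alignment concrete, the sketch below evaluates the formula on four days of synthetic temperatures using `xarray`'s `shift`, the same mechanism `calculate_teff()` uses. The values are made up and treated as °C purely for illustration; because the formula multiplies temperatures together, its result depends on the input unit, so confirm which unit your data use.

```python
import numpy as np
import xarray as xr

# Four days of synthetic daily temperatures (illustrative values only)
days = np.arange(4)
tmax = xr.DataArray([30.0, 31.0, 32.0, 33.0], dims="time", coords={"time": days})
tmin = xr.DataArray([15.0, 16.0, 17.0, 18.0], dims="time", coords={"time": days})

# shift(time=k) moves values k steps later, so position t holds the value from day t-k;
# the first k positions are filled with NaN
tmax1, tmin1, tmax2 = tmax.shift(time=1), tmin.shift(time=1), tmax.shift(time=2)

teff = 0.7 * tmax + 0.003 * tmin * tmax1 + 0.002 * tmin1 * tmax2

# Days 0 and 1 are NaN because the lagged inputs are missing.
# Day 2: 0.7*32 + 0.003*17*31 + 0.002*16*30 = 22.4 + 1.581 + 0.96 = 24.941
# Day 3: 0.7*33 + 0.003*18*32 + 0.002*17*31 = 23.1 + 1.728 + 1.054 = 25.882
print(teff.values)
```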

In [None]:
def calculate_teff(min_temp, max_temp):
    """
    Calculate effective temperature index using min and max temperature data.
    
    Teff = 0.7*Tmax0 + 0.003*Tmin0*Tmax1 + 0.002*Tmin1*Tmax2
    
    Where:
    - Tmax0: current day max temperature
    - Tmin0: current day min temperature
    - Tmax1: 1-day lag max temperature
    - Tmin1: 1-day lag min temperature
    - Tmax2: 2-day lag max temperature
    
    Parameters
    ----------
    min_temp : xr.DataArray or xr.Dataset
        Minimum temperature data with time or time_delta dimension
    max_temp : xr.DataArray or xr.Dataset
        Maximum temperature data with time or time_delta dimension
        
    Returns
    -------
    xr.DataArray or xr.Dataset
        Effective temperature index (Teff)
        
    Notes
    -----
    The first two time steps will contain NaN values due to lagging.
    """
    # Determine which temporal dimension is present
    if 'time_delta' in max_temp.dims:
        time_dim = 'time_delta'
    elif 'time' in max_temp.dims:
        time_dim = 'time'
    else:
        raise ValueError("Data must have either 'time' or 'time_delta' dimension")
    
    # Create lagged versions using the appropriate dimension
    tmax0 = max_temp  # Current day
    tmin0 = min_temp  # Current day
    tmax1 = max_temp.shift({time_dim: 1})  # 1-day lag
    tmin1 = min_temp.shift({time_dim: 1})  # 1-day lag
    tmax2 = max_temp.shift({time_dim: 2})  # 2-day lag
    
    # Calculate effective temperature
    teff = 0.7 * tmax0 + 0.003 * tmin0 * tmax1 + 0.002 * tmin1 * tmax2
    
    return teff

## Understanding the 1-in-X Custom Calculation Functions

This section explains the two main functions used for calculating 1-in-X year return values for custom climate metrics.

---

### `calculate_1_in_x_custom()` - Main Calculation Function

This function calculates 1-in-X year return values (e.g., 1-in-10 year, 1-in-100 year events) for custom climate metrics using extreme value analysis. It processes multiple locations and climate model simulations efficiently in batch mode.

#### **What It Does**

The function follows a 5-step process:

1. **Validation & Setup**: Validates input data structure and parameters
2. **Spatial Extraction**: Extracts data for all requested locations using `get_closest_gridcells()`
3. **Data Preparation**: Converts time dimensions if needed (e.g., `time_delta` → `time`)
4. **Simulation Processing**: Iterates through each simulation (and warming level if present), computing return values
5. **Output Structuring**: Assembles results into a well-structured xarray Dataset

#### **Key Design Principles**

- **Batch Processing**: Processes all locations together rather than one at a time for efficiency
- **Lazy Evaluation**: Keeps data as dask arrays (not computed) as long as possible to minimize memory usage
- **Memory Management**: Only computes one simulation at a time into memory, then releases it
- **Flexible Dimensions**: Handles both historical (`time`) and warming level (`time_delta`) data
- **Multi-scenario Support**: Can process data with optional `warming_level` dimension

#### **Input Parameters**

| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `custom_data` | xr.DataArray or xr.Dataset | Gridded climate metric with spatial (`lat`, `lon`), temporal (`time` or `time_delta`), and `simulation` dimensions | Required |
| `input_locations` | pd.DataFrame | DataFrame with `lat` and `lon` columns for analysis locations | Required |
| `return_periods` | list of int | Return periods in years (e.g., `[10, 100]` for 1-in-10 and 1-in-100 year events) | `[10, 100]` |
| `metric` | str | Type of extreme: `'max'` for maxima or `'min'` for minima | `'max'` |
| `distr` | str | Distribution to fit: `'gev'` (Generalized Extreme Value) or `'gumbel'` | `'gev'` |
| `event_duration` | tuple | Event duration as `(value, unit)`, e.g., `(1, "day")`; multi-day durations are not yet supported, while hourly data accepts `(n, "hour")` | `(1, "day")` |

#### **Output Structure**

Returns an xarray Dataset with:

- **`return_value`**: Return values for each location, simulation, and return period
  - Dimensions: `(location, simulation, [warming_level,] one_in_x)`
  - Units: Same as input data
  
- **`p_values`**: Kolmogorov-Smirnov goodness-of-fit p-values
  - Dimensions: `(location, simulation, [warming_level])`
  - Interpretation: Values > 0.05 mean the KS test does not reject the fitted distribution at the 5% significance level (this does not by itself prove a good fit)
  
- **Coordinates**:
  - `location_lat`: Latitude for each location
  - `location_lon`: Longitude for each location
  - `simulation`: Model simulation identifiers
  - `warming_level`: Global warming levels (if applicable)
  - `one_in_x`: Return period values

#### **Example Usage**

```python
# Calculate 1-in-10 and 1-in-100 year maximum effective temperature
results = calculate_1_in_x_custom(
    teff_data,                 # Your custom metric (e.g., effective temperature)
    locations_df,              # DataFrame with lat/lon columns
    return_periods=[10, 100],  # 1-in-10 and 1-in-100 year events
    metric="max",              # Maximum events
    distr="gev"                # Use GEV distribution
)

# Access results
return_vals = results['return_value']  # Shape: (n_locations, n_simulations, n_return_periods)
p_values = results['p_values']         # Shape: (n_locations, n_simulations)
```
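A common next step is to discard return values whose distribution fit the KS test rejects. The sketch below builds a small synthetic Dataset with the documented `return_value`/`p_values` layout (no `warming_level`, and without the `location_lat`/`location_lon` coordinates the real output carries) and masks it with `where`. The 0.05 threshold and the ensemble median are illustrative choices, not part of `calculate_1_in_x_custom()`.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in shaped like the documented output (no warming_level dimension)
rng = np.random.default_rng(0)
n_loc, n_sim = 3, 2
demo_results = xr.Dataset(
    {
        "return_value": (("location", "simulation", "one_in_x"),
                         rng.normal(40.0, 2.0, size=(n_loc, n_sim, 2))),
        "p_values": (("location", "simulation"),
                     np.array([[0.40, 0.01], [0.20, 0.30], [0.03, 0.60]])),
    },
    coords={"location": np.arange(n_loc), "simulation": ["sim_a", "sim_b"], "one_in_x": [10, 100]},
)

# Keep return values only where the KS test did not reject the fit at the 5% level;
# p_values broadcasts across the one_in_x dimension automatically
accepted = demo_results["return_value"].where(demo_results["p_values"] > 0.05)

# Summarize across simulations, ignoring rejected fits (now NaN)
ensemble_median = accepted.median(dim="simulation", skipna=True)
print(ensemble_median.sel(one_in_x=100).values)
```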

---

### `_process_one_simulation()` - Helper Function

This internal helper function processes a single climate simulation (or simulation-warming_level combination) to calculate return values and goodness-of-fit statistics.

#### **What It Does**

1. **Dimension Reshaping**: Converts 1D `location` dimension into 2D spatial grid (`y`, `x`)
   - Required because `threshold_tools` functions expect 2D spatial data
   - Creates a virtual grid where each location occupies a unique (y, x) position
   
2. **Block Maxima Extraction**: Extracts annual maxima (or minima) using `get_block_maxima()`
   - For daily data with 1-day duration, automatically extracts annual maxima
   - Respects the `extremes_type` parameter ('max' or 'min')
   
3. **Return Value Calculation**: Fits extreme value distribution and calculates return values
   - Uses `get_return_value()` to fit GEV or Gumbel distribution
   - Calculates values for specified return periods
   
4. **Goodness-of-Fit Testing**: Performs Kolmogorov-Smirnov test using `get_ks_stat()`
   - Tests whether the fitted distribution matches the data well
   - Returns p-values for statistical validation
   
5. **Dimension Restoration**: Reshapes results back to 1D `location` dimension
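The sketch below applies the same technique (annual block maxima, GEV fit, return levels, KS test) to one synthetic location using SciPy's `genextreme` instead of `climakitae`'s `threshold_tools`. It shows the method, not what `get_block_maxima()`, `get_return_value()` and `get_ks_stat()` do internally. Two caveats: SciPy's shape parameter `c` has the opposite sign of the common ξ convention, and a KS p-value computed against parameters fitted to the same data tends to be optimistic.

```python
import numpy as np
import pandas as pd
import xarray as xr
from scipy import stats

# Synthetic daily series for one location: 30 years of noisy seasonal temperatures
rng = np.random.default_rng(42)
dates = pd.date_range("1990-01-01", "2019-12-31", freq="D")
doy = dates.dayofyear.to_numpy()
seasonal = 25 + 8 * np.sin(2 * np.pi * doy / 365.25)
daily = xr.DataArray(seasonal + rng.normal(0, 2, dates.size), dims="time", coords={"time": dates})

# 1. Block maxima: one maximum per calendar year (use .min() for metric="min")
annual_max = daily.groupby("time.year").max().values

# 2. Fit a GEV distribution to the annual maxima
c, loc, scale = stats.genextreme.fit(annual_max)

# 3. The 1-in-T year return level is the (1 - 1/T) quantile of the fitted GEV
demo_periods = [10, 100]
return_levels = [stats.genextreme.ppf(1 - 1 / T, c, loc=loc, scale=scale) for T in demo_periods]

# 4. Kolmogorov-Smirnov goodness-of-fit test against the fitted distribution
ks = stats.kstest(annual_max, "genextreme", args=(c, loc, scale))

print(dict(zip(demo_periods, np.round(return_levels, 2))), round(float(ks.pvalue), 3))
```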

#### **Why the Reshaping?**

The `threshold_tools` functions (`get_block_maxima`, `get_return_value`, `get_ks_stat`) were designed to work with gridded climate data having 2D spatial dimensions (`y`, `x`). However, after extracting specific locations, our data has a 1D `location` dimension. 

The reshaping process:
- **Before**: `(time, location)` where `location` is 1D
- **Intermediate**: `(time, y, x)` where `y=1` and `x=n_locations` (virtual 2D grid)
- **After**: `(location, one_in_x)` for return values, `(location)` for p-values

This allows us to leverage the existing `multiple_points=True` functionality in `threshold_tools` while working with point locations rather than full grids.
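The reshaping can be sketched with plain `xarray` operations. The code for `_process_one_simulation()` is not shown in this section, so this only illustrates the idea of round-tripping `(time, location)` through a virtual `(time, y, x)` grid on synthetic data; it is not the helper's actual implementation.

```python
import numpy as np
import xarray as xr

# Synthetic point data: 5 timesteps at 3 locations, shaped (time, location)
n_time, n_loc = 5, 3
points = xr.DataArray(
    np.arange(n_time * n_loc, dtype=float).reshape(n_time, n_loc),
    dims=("time", "location"),
    coords={"location": np.arange(n_loc)},
)

# Forward: (time, location) -> (time, y, x) with y=1 and x=n_locations
grid = (
    points.rename({"location": "x"})  # each location becomes one x column
    .drop_vars("x")                   # plain positional x, like a grid index
    .expand_dims(y=1, axis=1)         # insert a length-1 y axis after time
)
print(dict(grid.sizes))

# Backward: (time, y, x) -> (time, location), restoring the original labels
restored = grid.squeeze("y", drop=True).rename({"x": "location"})
restored = restored.assign_coords(location=points["location"].values)
print(restored.equals(points))
```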

#### **Parameters**

| Parameter | Type | Description |
|-----------|------|-------------|
| `one_sim_computed` | xr.DataArray | Computed data for one simulation with dims `(time, location)` |
| `return_periods` | list | List of return period values |
| `metric` | str | `'max'` or `'min'` for extreme type |
| `distr` | str | Distribution name (`'gev'` or `'gumbel'`) |
| `groupby` | value or UNSET | Groupby parameter for `get_block_maxima` |
| `duration` | value or UNSET | Duration parameter for `get_block_maxima` |

#### **Returns**

Tuple of `(ret_val, p_val)`:
- `ret_val`: Return values with dimensions `(location, one_in_x)`
- `p_val`: P-values with dimension `(location)`

#### **Internal Workflow Diagram**

```
Input: (time, location)
       ↓
[Reshape to 2D grid]
       ↓
Data: (time, y, x)  where y=1, x=n_locations
       ↓
[Extract block maxima]
       ↓
Annual Maxima: (year, y, x)
       ↓
[Fit distribution & calculate return values]
       ↓
Return Values: (y, x, one_in_x)
       ↓
[Reshape back to 1D]
       ↓
Output: (location, one_in_x)
```

---

### **Performance Considerations**

- **Memory**: Each simulation is computed separately and released, keeping memory usage manageable
- **Speed**: Batch processing all locations together is faster than iterating individually
- **Parallelization**: The notebook demonstrates parallel batch processing using `ThreadPoolExecutor`
- **Lazy Evaluation**: Keeps data as dask arrays until computation is actually required
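A minimal sketch of the batch-parallel pattern, assuming the location table is split into batches of 50 and each batch is handled by one thread. `demo_process_batch` is a hypothetical placeholder for the real per-batch work (for example, calling `get_data()`, `calculate_teff()` and `calculate_1_in_x_custom()`), and the synthetic coordinates are made up.

```python
import concurrent.futures

import pandas as pd


def demo_process_batch(batch_id, batch_df):
    """Hypothetical worker; the real one would retrieve data and compute 1-in-X values."""
    return batch_id, len(batch_df)


# Synthetic location table with the required 'lat'/'lon' columns
demo_locations = pd.DataFrame({
    "lat": [34.0 + 0.01 * i for i in range(120)],
    "lon": [-118.0 - 0.01 * i for i in range(120)],
})

batch_size = 50
batches = [demo_locations.iloc[start:start + batch_size]
           for start in range(0, len(demo_locations), batch_size)]

batch_results = {}
# max_workers caps peak memory: at roughly 5 GB per worker, 4 workers use about 20 GB
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = {executor.submit(demo_process_batch, i, df): i for i, df in enumerate(batches)}
    for future in concurrent.futures.as_completed(futures):
        try:
            batch_id, n_done = future.result()
            batch_results[batch_id] = n_done
        except Exception as exc:  # report the failure and keep collecting other batches
            print(f"Batch {futures[future]} failed: {exc}")

print(sorted(batch_results.items()))
```

Wrapping each `future.result()` in `try/except` lets one failed batch be reported without discarding the results of the others.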

In [None]:
def calculate_1_in_x_custom(
    custom_data,          # User's custom metric (gridded, with lat/lon)
    input_locations,      # DataFrame with 'lat' and 'lon' columns
    return_periods=[10, 100],
    metric="max",
    distr="gev",
    event_duration=(1, "day"),
):
    """
    Calculate 1-in-X year return values for custom metrics.
    
    Processes all locations and simulations together in batch mode,
    following the cava_data(batch_mode=True) pattern. Keeps operations
    lazy (dask arrays) as long as possible.
    
    Parameters
    ----------
    custom_data : xr.DataArray or xr.Dataset
        Gridded custom metric with lat/lon dimensions and simulation dimension.
        Can have either 'time' or 'time_delta' dimension for temporal data.
        Can optionally have 'warming_level' dimension which will be preserved.
    input_locations : pd.DataFrame
        DataFrame with 'lat' and 'lon' columns specifying locations to analyze.
    return_periods : list of int, optional
        Return periods for 1-in-X year events (e.g., [10, 100] for 1-in-10 
        and 1-in-100 year events). Default is [10, 100].
    metric : str, optional
        Type of extreme to calculate: 'max' or 'min'. Default is 'max'.
    distr : str, optional
        Distribution to fit: 'gev' (Generalized Extreme Value) or 'gumbel'.
        Default is 'gev'.
    event_duration : tuple of (int, str), optional
        Duration of events as (value, unit) where unit is 'day' or 'hour'.
        Default is (1, "day").
    
    Returns
    -------
    xr.Dataset
        Dataset containing:
        - 'return_value': Return values for each location, simulation, and 
          return period. Dimensions: (location, simulation, [warming_level,] one_in_x)
        - 'p_values': Kolmogorov-Smirnov goodness-of-fit p-values for each
          location and simulation. Dimensions: (location, simulation, [warming_level])
    """
    
    # ============================================================
    # STEP 1: VALIDATE AND SETUP
    # ============================================================
    
    if VERBOSE:
        print("=" * 60)
        print("BATCH MODE 1-IN-X CALCULATION")
        print("=" * 60)
    
    # Validate input_locations
    if not isinstance(input_locations, pd.DataFrame):
        raise TypeError(
            f"input_locations must be a pandas DataFrame, got {type(input_locations)}"
        )
    
    if 'lat' not in input_locations.columns:
        raise ValueError(
            "input_locations DataFrame must contain 'lat' column"
        )
    
    if 'lon' not in input_locations.columns:
        raise ValueError(
            "input_locations DataFrame must contain 'lon' column"
        )
    
    # Check for empty DataFrame
    if len(input_locations) == 0:
        raise ValueError(
            "input_locations DataFrame is empty. Please provide at least one location."
        )
    
    # Validate custom_data has required dimensions
    if not isinstance(custom_data, (xr.DataArray, xr.Dataset)):
        raise TypeError(
            f"custom_data must be xarray DataArray or Dataset, got {type(custom_data)}"
        )
    
    # Check for spatial dimensions
    if 'lat' not in custom_data.dims and 'latitude' not in custom_data.dims:
        raise ValueError(
            "custom_data must have 'lat' or 'latitude' dimension for spatial extraction"
        )
    
    if 'lon' not in custom_data.dims and 'longitude' not in custom_data.dims:
        raise ValueError(
            "custom_data must have 'lon' or 'longitude' dimension for spatial extraction"
        )
    
    # Check for temporal dimension
    has_time = 'time' in custom_data.dims
    has_time_delta = 'time_delta' in custom_data.dims
    
    if not (has_time or has_time_delta):
        raise ValueError(
            "custom_data must have either 'time' or 'time_delta' dimension for temporal analysis"
        )
    
    # Check for simulation dimension
    if 'simulation' not in custom_data.dims:
        raise ValueError(
            "custom_data must have 'simulation' dimension for batch processing"
        )
    
    # Check for warming_level dimension
    has_warming_level = 'warming_level' in custom_data.dims
    
    # Convert return_periods to list if needed
    if not isinstance(return_periods, list):
        return_periods = [return_periods]
    
    # Validate return_periods
    if len(return_periods) == 0:
        raise ValueError("return_periods must contain at least one value")
    
    for rp in return_periods:
        if not isinstance(rp, (int, float)) or rp <= 0:
            raise ValueError(
                f"All return periods must be positive numbers, got {rp}"
            )
    
    # Print summary
    num_locations = len(input_locations)
    num_simulations = len(custom_data.simulation)
    num_warming_levels = len(custom_data.warming_level) if has_warming_level else 1

    if VERBOSE:
        print(f"\nConfiguration:")
        print(f"  Locations: {num_locations}")
        print(f"  Simulations: {num_simulations}")
        if has_warming_level:
            print(f"  Warming Levels: {num_warming_levels}")
        print(f"  Return periods: {return_periods}")
        print(f"  Metric: {metric}")
        print(f"  Distribution: {distr}")
        print(f"  Event duration: {event_duration[0]} {event_duration[1]}")
        print(f"  Temporal dimension: {'time_delta' if has_time_delta else 'time'}")
        print()
    
        print("\n--- Step 2: Extracting Gridcells ---")
        print(f"Extracting closest gridcells for {num_locations} location(s)...")
    
    # Extract lat/lon arrays from DataFrame
    lats = input_locations['lat'].values
    lons = input_locations['lon'].values
    
    # Extract all locations at once (batch mode)
    # This returns data with 'points' dimension stacking all locations
    custom_data_batch = get_closest_gridcells(
        custom_data, 
        lats, 
        lons
    )
    
    # Rename 'points' to 'location' for clarity
    custom_data_batch = custom_data_batch.rename({'points': 'location'})
    
    # Keep data lazy (don't compute yet)
    if VERBOSE:
        print(f"✓ Extracted gridcells successfully")
        print(f"  Data type: {type(custom_data_batch).__name__}")
        print(f"  Dimensions: {list(custom_data_batch.dims)}")
        print(f"  Sizes: {dict(custom_data_batch.sizes)}")
        # xr.Dataset has no .shape attribute, so fall back gracefully
        print(f"  Shape: {getattr(custom_data_batch, 'shape', 'n/a (Dataset)')}")
        if hasattr(custom_data_batch, 'data'):
            print(f"  Data remains lazy: {hasattr(custom_data_batch.data, 'dask')}")
        print()
    
        print("\n--- Step 3: Preparing Data ---")
    # Check if data has time_delta dimension (warming level data)
    if has_time_delta:
        if VERBOSE: print("Converting time_delta to time dimension for resampling...")
        custom_data_batch = add_dummy_time_to_wl(custom_data_batch)
        if VERBOSE:
            print(f"✓ Converted to time dimension")
            print(f"  New dimensions: {list(custom_data_batch.dims)}")
            if hasattr(custom_data_batch, 'data'):
                print(f"  Data remains lazy: {hasattr(custom_data_batch.data, 'dask')}")
        
            # NOTE: We will drop NaN values inside the simulation loop
            # to keep data lazy as long as possible
            print("  Note: NaN values will be dropped during simulation processing")
    else:
        if VERBOSE: print("Data already has 'time' dimension, no conversion needed")

    # Defined outside the VERBOSE block so it always exists
    total_iterations = num_simulations * num_warming_levels
    if VERBOSE:
        print("\n--- Step 4: Processing Simulations ---")
        print(f"Processing {num_simulations} simulations" + 
              (f" × {num_warming_levels} warming levels = {total_iterations} total iterations..." 
               if has_warming_level else "..."))
    
    # Prepare event duration parameters based on data frequency
    # For daily data from warming levels, we use annual maxima directly
    if event_duration[1] == "day":
        if event_duration[0] == 1:
            # For 1-day events with daily data, don't pass groupby or duration
            # get_block_maxima will automatically extract annual maxima
            groupby = UNSET
            duration = UNSET
        else:
            raise ValueError(
                f"Multi-day duration events ({event_duration[0]} days) not yet supported. "
                "Use duration=(1, 'day') for daily data."
            )
    elif event_duration[1] == "hour":
        # Hourly data can use duration parameter
        groupby = UNSET
        duration = event_duration
    else:
        raise ValueError(f"Unsupported duration unit: {event_duration[1]}. Use 'day' or 'hour'.")
    
    return_vals_list = []
    p_vals_list = []
    
    # Template variables for creating NaN-filled results
    ret_val_template = None
    p_val_template = None
    
    # Progress counter; iterations run over simulations first, then warming levels if present
    iteration_count = 0
    
    # Loop over simulations
    for i, sim in enumerate(custom_data_batch.simulation.values):
        
        # Select this simulation (keeps all locations and warming levels!)
        one_sim = custom_data_batch.sel(simulation=sim)
        
        # If warming_level dimension exists, loop over it
        if has_warming_level:
            for j, wl in enumerate(custom_data_batch.warming_level.values):
                iteration_count += 1
                if VERBOSE: print(f"  Processing simulation {i+1}/{num_simulations}, " +
                      f"warming level {j+1}/{num_warming_levels} ({wl}°C) " +
                      f"[{iteration_count}/{total_iterations}]...", end=" ")
                
                # Select this warming level
                # After selection, warming_level becomes a scalar coordinate
                one_sim_wl = one_sim.sel(warming_level=wl)
                
                # Drop the warming_level coordinate if it exists (it's now scalar, not a dimension)
                if 'warming_level' in one_sim_wl.coords and 'warming_level' not in one_sim_wl.dims:
                    one_sim_wl = one_sim_wl.drop_vars('warming_level')
                
                # NOW compute to memory for this single simulation-warming_level combo
                one_sim_computed = one_sim_wl.compute()
                
                # Drop NaN values AFTER computing
                original_time = one_sim_computed.sizes['time']
                one_sim_computed = one_sim_computed.dropna(dim='time')
                final_time = one_sim_computed.sizes['time']
                
                if original_time != final_time:
                    if VERBOSE: print(f"(dropped {original_time - final_time} NaN timesteps)", end=" ")
                
                # Check if there's any data left after dropping NaN values
                if final_time == 0:
                    if VERBOSE: print("⚠️  WARNING: No valid data remaining - creating NaN results")
                    
                    # Use template or create from scratch
                    if ret_val_template is not None:
                        # Copy structure from template and fill with NaNs
                        ret_val = ret_val_template.copy(deep=True)
                        ret_val.values[:] = np.nan
                        p_val = p_val_template.copy(deep=True)
                        p_val.values[:] = np.nan
                    else:
                        # Fallback: create simple structure (only for first failure before any success)
                        n_locs = one_sim_computed.sizes['location']
                        ret_val = xr.DataArray(
                            np.full((n_locs, len(return_periods)), np.nan),
                            dims=('location', 'one_in_x'),
                            coords={
                                'location': np.arange(n_locs),
                                'one_in_x': return_periods
                            }
                        )
                        p_val = xr.DataArray(
                            np.full(n_locs, np.nan),
                            dims=('location',),
                            coords={'location': np.arange(n_locs)}
                        )
                else:
                    # Process this simulation-warming_level combination
                    ret_val, p_val = _process_one_simulation(
                        one_sim_computed, 
                        return_periods, 
                        metric, 
                        distr, 
                        groupby, 
                        duration
                    )
                    
                    # Save as template for future NaN results
                    if ret_val_template is None:
                        ret_val_template = ret_val.copy(deep=True)
                        p_val_template = p_val.copy(deep=True)
                
                # Clean up coordinates to ensure consistency
                ret_val = ret_val.drop_vars([c for c in ret_val.coords 
                                            if c not in ['location', 'one_in_x']], errors='ignore')
                p_val = p_val.drop_vars([c for c in p_val.coords 
                                        if c != 'location'], errors='ignore')
                
                return_vals_list.append(ret_val)
                p_vals_list.append(p_val)
                
                if VERBOSE: print("✓")
        else:
            # No warming level dimension - process simulation directly
            iteration_count += 1
            if VERBOSE: print(f"  Processing simulation {i+1}/{num_simulations} [{iteration_count}/{total_iterations}]: {sim}...", end=" ")
            
            # NOW compute to memory for this single simulation
            one_sim_computed = one_sim.compute()
            
            # Drop NaN values AFTER computing
            original_time = one_sim_computed.sizes['time']
            one_sim_computed = one_sim_computed.dropna(dim='time')
            final_time = one_sim_computed.sizes['time']
            
            if original_time != final_time:
                if VERBOSE: print(f"(dropped {original_time - final_time} NaN timesteps)", end=" ")
            
            # Check if there's any data left after dropping NaN values
            if final_time == 0:
                if VERBOSE: print("⚠️  WARNING: No valid data remaining - creating NaN results")
                
                # Use template or create from scratch
                if ret_val_template is not None:
                    # Copy structure from template and fill with NaNs
                    ret_val = ret_val_template.copy(deep=True)
                    ret_val.values[:] = np.nan
                    p_val = p_val_template.copy(deep=True)
                    p_val.values[:] = np.nan
                else:
                    # Fallback: create simple structure (only for first failure before any success)
                    n_locs = one_sim_computed.sizes['location']
                    ret_val = xr.DataArray(
                        np.full((n_locs, len(return_periods)), np.nan),
                        dims=('location', 'one_in_x'),
                        coords={
                            'location': np.arange(n_locs),
                            'one_in_x': return_periods
                        }
                    )
                    p_val = xr.DataArray(
                        np.full(n_locs, np.nan),
                        dims=('location',),
                        coords={'location': np.arange(n_locs)}
                    )
            else:
                # Process this simulation
                ret_val, p_val = _process_one_simulation(
                    one_sim_computed, 
                    return_periods, 
                    metric, 
                    distr, 
                    groupby, 
                    duration
                )
                
                # Save as template for future NaN results
                if ret_val_template is None:
                    ret_val_template = ret_val.copy(deep=True)
                    p_val_template = p_val.copy(deep=True)
            
            # Clean up coordinates to ensure consistency
            ret_val = ret_val.drop_vars([c for c in ret_val.coords 
                                        if c not in ['location', 'one_in_x']], errors='ignore')
            p_val = p_val.drop_vars([c for c in p_val.coords 
                                    if c != 'location'], errors='ignore')
            
            return_vals_list.append(ret_val)
            p_vals_list.append(p_val)
            
            if VERBOSE: print("✓")
    
    if VERBOSE: 
        print(f"✓ Completed all {total_iterations} iterations")
        print()
    
    # ============================================================
    # STEP 5: STRUCTURE OUTPUT
    # ============================================================
    if VERBOSE: 
        print("\n--- Step 5: Structuring Output ---")
        print("Assembling results into xarray Dataset...")
    
    # Stack return values across simulations (and warming levels if present)
    # Each element in return_vals_list has dims: (location, one_in_x)
    # After concat, will have dims: (sim_wl_combo, location, one_in_x)
    return_vals_stacked = xr.concat(
        return_vals_list, 
        dim="sim_wl_combo",
        coords='minimal',
        compat='override'
    )
    
    # Stack p-values across simulations (and warming levels if present)
    p_vals_stacked = xr.concat(
        p_vals_list,
        dim="sim_wl_combo",
        coords='minimal',
        compat='override'
    )
    
    # Reshape to separate simulation and warming_level dimensions
    if has_warming_level:
        # Create MultiIndex for sim_wl_combo dimension
        sim_index = np.repeat(custom_data_batch.simulation.values, num_warming_levels)
        wl_index = np.tile(custom_data_batch.warming_level.values, num_simulations)
        
        # Assign coordinates
        return_vals_stacked = return_vals_stacked.assign_coords(
            simulation=("sim_wl_combo", sim_index),
            warming_level=("sim_wl_combo", wl_index)
        )
        p_vals_stacked = p_vals_stacked.assign_coords(
            simulation=("sim_wl_combo", sim_index),
            warming_level=("sim_wl_combo", wl_index)
        )
        
        # Set multi-index and unstack
        return_vals_stacked = return_vals_stacked.set_index(
            sim_wl_combo=["simulation", "warming_level"]
        ).unstack("sim_wl_combo")
        
        p_vals_stacked = p_vals_stacked.set_index(
            sim_wl_combo=["simulation", "warming_level"]
        ).unstack("sim_wl_combo")
        
        # Transpose to desired order: (location, simulation, warming_level, one_in_x)
        return_vals_final = return_vals_stacked.transpose("location", "simulation", "warming_level", "one_in_x")
        p_vals_final = p_vals_stacked.transpose("location", "simulation", "warming_level")
        
    else:
        # No warming level - simpler structure
        # Assign simulation coordinate values
        return_vals_stacked = return_vals_stacked.assign_coords(
            simulation=("sim_wl_combo", custom_data_batch.simulation.values)
        )
        p_vals_stacked = p_vals_stacked.assign_coords(
            simulation=("sim_wl_combo", custom_data_batch.simulation.values)
        )
        
        # Rename dim
        return_vals_final = return_vals_stacked.rename({"sim_wl_combo": "simulation"})
        p_vals_final = p_vals_stacked.rename({"sim_wl_combo": "simulation"})
        
        # Transpose to desired order: (location, simulation, one_in_x)
        return_vals_final = return_vals_final.transpose("location", "simulation", "one_in_x")
        p_vals_final = p_vals_final.transpose("location", "simulation")
    
    # Create output Dataset
    result = xr.Dataset({
        "return_value": return_vals_final,
        "p_values": p_vals_final
    })
    
    # Add lat/lon coordinates for each location
    location_lats = input_locations['lat'].values
    location_lons = input_locations['lon'].values
    
    result = result.assign_coords({
        "location_lat": ("location", location_lats),
        "location_lon": ("location", location_lons)
    })
    
    # Add comprehensive metadata
    # Convert return_periods to a string representation for NetCDF compatibility
    return_periods_str = str(return_periods)
    
    result.attrs.update({
        "processing_mode": "batch",
        "num_locations": int(num_locations),
        "num_simulations": int(num_simulations),
        "num_warming_levels": int(num_warming_levels),
        "return_periods": return_periods_str,  # list stored as a string for NetCDF attributes
        "fitted_distribution": distr,
        "extremes_type": metric,
        "event_duration_value": int(event_duration[0]),
        "event_duration_unit": event_duration[1],
        "created_with": "calculate_1_in_x_custom",
        "temporal_input_dimension": "time_delta" if has_time_delta else "time",
        "has_warming_level_dimension": int(has_warming_level),
    })
    
    # Add variable-level metadata
    result["return_value"].attrs.update({
        "long_name": f"{metric.capitalize()} {event_duration[0]}-{event_duration[1]} return values",
        "units": "same as input data",
        "description": f"Return values for {return_periods_str} year return periods using {distr} distribution",
    })
    
    result["p_values"].attrs.update({
        "long_name": "Kolmogorov-Smirnov goodness-of-fit p-values",
        "description": f"P-values for {distr} distribution fit quality. Values > 0.05 mean the fit is not rejected at the 5% significance level.",
    })

    if VERBOSE: 
        print("✓ Output Dataset created successfully")
        print(f"  Dimensions: {list(result.dims)}")
        print(f"  Data variables: {list(result.data_vars)}")
        print(f"  Coordinates: {list(result.coords)}")
        print()
        
        print("=" * 60)
        print("BATCH MODE CALCULATION COMPLETE!")
        print("=" * 60)
        print("\nResults summary:")
        print(f"  Locations analyzed: {num_locations}")
        print(f"  Simulations processed: {num_simulations}")
        if has_warming_level:
            print(f"  Warming levels: {num_warming_levels}")
        print(f"  Return periods: {return_periods}")
        print(f"  Output shape: {result['return_value'].shape}")
        print()
    
    return result


def _process_one_simulation(one_sim_computed, return_periods, metric, distr, groupby, duration):
    """
    Helper function to process a single simulation (or simulation-warming_level combo).
    
    Parameters
    ----------
    one_sim_computed : xr.DataArray
        Computed data for one simulation, with dims (time, location)
    return_periods : list
        Return periods in years (e.g., [10, 100])
    metric : str
        'max' or 'min'
    distr : str
        Distribution name
    groupby : value or UNSET
        Groupby parameter for get_block_maxima
    duration : value or UNSET
        Duration parameter for get_block_maxima
    
    Returns
    -------
    tuple of (ret_val, p_val)
        Return values and p-values for this simulation
    """
    # Reshape data to have separate x and y dimensions
    # threshold_tools expects (time, y, x) not (time, location)
    # We create a 2D grid where each location is at a unique (y, x) position
    n_locations = one_sim_computed.sizes['location']
    
    # Create a virtual 2D grid (can be a row or column)
    # Use a row layout: a single row (y=0 for every location), with x indexing the locations
    one_sim_reshaped = one_sim_computed.assign_coords({
        'x': ('location', np.arange(n_locations)),
        'y': ('location', np.zeros(n_locations, dtype=int))  # All same y
    })
    
    # Expand location into y and x dimensions
    one_sim_reshaped = one_sim_reshaped.set_index(location=['y', 'x']).unstack('location')
    
    # Now data has dims: (time, y, x)
    # where y has length 1 and x has length n_locations
    
    # Get block maxima (handles multiple locations via y, x dimensions)
    ams = get_block_maxima(
        one_sim_reshaped,
        extremes_type=metric,
        duration=duration,
        groupby=groupby,
        check_ess=False,
    )
    
    # Calculate return values (multiple_points=True handles y, x dimensions)
    ret_val = get_return_value(
        ams,
        return_period=return_periods,
        multiple_points=True,
        distr=distr,
        bootstrap_runs=0,  # increase for better estimates on confidence intervals
    )["return_value"]
    
    # Reshape back to location dimension
    # Stack x and y back into a single location dimension
    ret_val = ret_val.stack(location=['y', 'x']).reset_index('location', drop=True)
    ret_val = ret_val.assign_coords(location=np.arange(n_locations))
    
    # Drop any extra coordinates that might cause concat issues
    # Keep only location and one_in_x coordinates
    coords_to_keep = ['location', 'one_in_x']
    extra_coords = [c for c in ret_val.coords if c not in coords_to_keep]
    if extra_coords:
        ret_val = ret_val.drop_vars(extra_coords)
    
    # Goodness of fit
    ks_result = get_ks_stat(ams, distr=distr, multiple_points=True)
    
    # Reshape p_values back to location dimension
    p_val = ks_result.p_value.stack(location=['y', 'x']).reset_index('location', drop=True)
    p_val = p_val.assign_coords(location=np.arange(n_locations))
    
    # Drop any extra coordinates from p_val
    # Keep only location coordinate
    extra_coords = [c for c in p_val.coords if c != 'location']
    if extra_coords:
        p_val = p_val.drop_vars(extra_coords)
    
    return ret_val, p_val
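Step 5 of `calculate_1_in_x_custom` labels the stacked `sim_wl_combo` entries with `np.repeat` (simulations) and `np.tile` (warming levels). That labeling is only correct if the results were appended with simulation as the outer loop and warming level as the inner loop. The loop itself is outside this excerpt, so the toy check below assumes that order; the simulation names are made up, and each slot is filled with the code `10*i + j` so you can confirm where each pair lands after `unstack`.

```python
import numpy as np
import xarray as xr

# Hypothetical inputs: 3 simulations x 2 warming levels, 4 locations, 2 return periods
sims = np.array(["simA", "simB", "simC"])
wls = np.array([2.0, 2.5])
return_periods = [10, 100]
n_locs = 4

# One (location, one_in_x) array per (simulation, warming_level) pair,
# appended in sim-outer / wl-inner order (the order Step 5 assumes)
arrays = []
for i, sim in enumerate(sims):
    for j, wl in enumerate(wls):
        arrays.append(xr.DataArray(
            np.full((n_locs, len(return_periods)), 10 * i + j, dtype=float),
            dims=("location", "one_in_x"),
            coords={"location": np.arange(n_locs), "one_in_x": return_periods},
        ))

# Same bookkeeping as Step 5: concat, label with repeat/tile, then unstack
stacked = xr.concat(arrays, dim="sim_wl_combo")
stacked = stacked.assign_coords(
    simulation=("sim_wl_combo", np.repeat(sims, len(wls))),
    warming_level=("sim_wl_combo", np.tile(wls, len(sims))),
)
final = (
    stacked.set_index(sim_wl_combo=["simulation", "warming_level"])
    .unstack("sim_wl_combo")
    .transpose("location", "simulation", "warming_level", "one_in_x")
)

# simB (i=1) at 2.5 degC (j=1) should hold the code 11
print(float(final.sel(simulation="simB", warming_level=2.5, location=0, one_in_x=10)))
```

If the real loop ran warming level as the outer loop, the `repeat`/`tile` roles would need to be swapped.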


## Example: Calculate 1-in-X Events for Effective Temperature

Now you can calculate the Effective Temperature and then compute 1-in-X year return values:

In [None]:
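# Location loading
num_pts = -1  # -1 uses every point in lat_lons; otherwise use the first num_pts points
selected = lat_lons if num_pts == -1 else lat_lons[:num_pts]  # lat_lons[:-1] would drop the last point
locations = pd.DataFrame({
    'lat': [x[0] for x in selected],
    'lon': [x[1] for x in selected]
})

# Split into consecutive batches of at most batch_size locations
batch_size = 5
batches = [locations.iloc[i:i + batch_size] for i in range(0, len(locations), batch_size)]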

## Defining parameters
The following cell defines the parameters used to fetch the required data.

| Variable | Type / Example | Meaning |
|---|---|---|
| space | float (0.02) | Padding in degrees added around batch bounding boxes when requesting data (lat/lon buffer). |
| min_temp_var | str ("Minimum air temperature at 2m") | Name/description of the minimum temperature variable to request from the data API. |
| max_temp_var | str ("Maximum air temperature at 2m") | Name/description of the maximum temperature variable to request from the data API. |
| downscaling | str ("Statistical") | Downscaling method to request (e.g., "Statistical" or "Dynamical"). |
| resolution | str ("3 km") | Spatial resolution of the downscaled data to request. |
| timescale | str ("daily") | Temporal aggregation of the data (e.g., "daily", "hourly"). |
| approach | str ("Warming Level") | Retrieval approach: e.g., warming level ("Warming Level") vs. time-slice/historical. |
| units | str ("degF") | Output units to request from the API (ensures consistent units across variables). |
| wls | list of float ([2.0, 2.5]) | Warming levels (°C) to request when using the warming-level approach. |
| return_periods | list of int ([10, 100]) | Return periods (in years) for which to compute 1-in-X events. |
| time_slice | tuple (2000, 2029) | Start and end years for historical/time-slice data requests. |
| scenario | list of str (["Historical Climate", "SSP 3-7.0"]) | Scenario names to request (historical and future scenario(s)). |
| num_pts | int (-1) | Number of points from lat_lons to include (-1 means use all). |
| batch_size | int (5) | Number of locations per processing batch. |
| batches | list of DataFrame | List of DataFrame batches created by splitting locations for parallel processing. |

In [None]:
# specific inputs for grabbing data
space = 0.02 # you probably don't need to change this
min_temp_var = "Minimum air temperature at 2m"
max_temp_var = "Maximum air temperature at 2m"
downscaling = "Statistical"
resolution = "3 km"
timescale = "daily"
approach = "Warming Level"
units = "degF"
wls = [2.0, 2.5]
return_periods = [10, 100]
time_slice = (2000, 2029)
scenario = ["Historical Climate", "SSP 3-7.0"]
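The `return_periods = [10, 100]` setting asks for the 10- and 100-year return values. For annual block maxima, the 1-in-T-year return value is the `1 - 1/T` quantile of the fitted annual-maximum distribution. The sketch below illustrates that idea with SciPy on synthetic data, assuming a GEV fit to annual maxima; it is not climakitae's implementation, and `get_block_maxima`, `get_return_value`, and `get_ks_stat` may differ in fitting method and parameterization. Note that SciPy's `genextreme` shape parameter `c` has the opposite sign of the shape parameter ξ used in many extreme value texts.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Synthetic stand-in for one location's daily metric: 40 "years" x 365 days
daily = rng.normal(loc=80.0, scale=5.0, size=(40, 365))

# Annual block maxima (analogous to extremes_type="max")
ams = daily.max(axis=1)

# Fit a GEV distribution to the block maxima
shape, loc, scale = stats.genextreme.fit(ams)

# 1-in-T-year return value = (1 - 1/T) quantile of the fitted annual-maximum distribution
return_periods = [10, 100]
return_levels = {
    T: stats.genextreme.ppf(1 - 1 / T, shape, loc=loc, scale=scale) for T in return_periods
}

# Kolmogorov-Smirnov goodness of fit against the fitted GEV
ks = stats.kstest(ams, "genextreme", args=(shape, loc, scale))

print(return_levels, ks.pvalue)
```

Because the GEV parameters are estimated from the same sample the KS test is run on, the resulting p-values tend to be too high, so a p-value above 0.05 means only that the fit is not rejected, not that it is confirmed.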

## Purpose of the next cell: job functions for concurrent futures

The next cell defines two worker functions intended to be submitted to a `ThreadPoolExecutor`:

- Names: `process_batch_time` and `process_batch_gwl`
- Purpose: fetch min/max temperature data for a batch of locations, compute the effective-temperature metric, run the batch 1-in-X calculation, export results, and return the Dataset.
- Signature: both take a single tuple argument `(idx, batch)`, where `idx` is the batch index and `batch` is a pandas DataFrame with `lat` and `lon` columns; they return the computed xarray Dataset, or `None` on failure.
- Behavior:
    - Compute the bounding box from `batch` and call `get_data(...)` (with `time_slice` for time-slice mode, or with `approach` and `warming_levels` for warming-level mode).
    - Call `calculate_teff()` then `calculate_1_in_x_custom()` with shared globals.
    - Export result with `export(..., filename=...)`.
    - Catch exceptions, print detailed traceback and return `None`.
- Dependencies (must exist in the kernel):
    - `get_data`, `calculate_teff`, `calculate_1_in_x_custom`, `export`
    - Configuration globals: `min_temp_var`, `max_temp_var`, `space`, `downscaling`, `resolution`, `timescale`, `time_slice`, `scenario`, `units`, `approach`, `wls`, `return_periods`
    - `VERBOSE`, plus the thread-safety monkey patch for `DataInterface` if running multiple threads
- Usage:
    - Submit using `executor.submit(process_batch_time, (idx, batch))` or `executor.submit(process_batch_gwl, (idx, batch))`
    - Inspect returned value in futures; exported files follow the patterns `teff_batch_time_{idx}` and `teff_batch_gwl_{idx}`
- Notes and tips:
    - Functions print progress and detailed tracebacks to help debugging.
    - Keep `n_workers` small enough for available memory (~5 GB RAM per worker) and API rate limits.
    - Apply the `DataInterface` monkey patch (code cell 2) before creating multiple threads.
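The bounding-box step both workers perform inline can be sketched as a small helper. `padded_bbox` is a hypothetical name, not part of climakitae; it mirrors the `min - space` / `max + space` arithmetic used in the worker functions. Like the workers, it assumes the batch does not straddle the antimeridian.

```python
import pandas as pd


def padded_bbox(batch: pd.DataFrame, space: float = 0.02) -> tuple[tuple[float, float], tuple[float, float]]:
    """Return (lat_range, lon_range) covering every point in batch, padded by `space` degrees."""
    lat_range = (float(batch["lat"].min()) - space, float(batch["lat"].max()) + space)
    lon_range = (float(batch["lon"].min()) - space, float(batch["lon"].max()) + space)
    return lat_range, lon_range


# Example batch with three hypothetical locations
batch = pd.DataFrame({"lat": [34.05, 34.50, 33.90], "lon": [-118.25, -117.80, -118.10]})
lat_range, lon_range = padded_bbox(batch)
print(lat_range, lon_range)
```

The two tuples can be passed directly as `latitude=` and `longitude=` in the `get_data` calls.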

In [None]:
import traceback

def process_batch_time(args: tuple[int, pd.DataFrame]) -> xr.Dataset | None:
    """
    Process a batch of locations for time-based data retrieval and analysis.

    Parameters
    ----------
    args : tuple
        Tuple containing (idx, batch)
        - idx : int
            Batch index used for logging and output file naming.
        - batch : pandas.DataFrame
            DataFrame containing at least 'lat' and 'lon' columns that define the
            set of locations to process.

    Returns
    -------
    xarray.Dataset or None
        The Dataset produced by calculate_1_in_x_custom for the batch,
        or None if an exception occurred during processing.

    Raises
    ------
    None
        All exceptions are caught internally; details are printed and the function
        returns None on error.

    Notes
    -----
    The function:
    - Determines the spatial bounding box from batch lat/lon values.
    - Fetches min/max temperature data for that bounding box.
    - Computes teff and 1-in-x statistics.
    - Exports the result to a file named "teff_batch_time_{idx}".
    """
    idx, batch = args
    
    print(f"[Batch {idx}] Starting processing...")
    
    try:
        min_lat = float(batch['lat'].min())
        max_lat = float(batch['lat'].max())
        min_lon = float(batch['lon'].min())
        max_lon = float(batch['lon'].max())
        
        min_temp_df = get_data(
            variable=min_temp_var,
            latitude=(min_lat - space, max_lat + space),
            longitude=(min_lon - space, max_lon + space),
            downscaling_method=downscaling,
            timescale=timescale,
            time_slice=time_slice,
            scenario=scenario,
            resolution=resolution,
            units=units
        )
        max_temp_df = get_data(
            variable=max_temp_var,
            latitude=(min_lat - space, max_lat + space),
            longitude=(min_lon - space, max_lon + space),
            downscaling_method=downscaling,
            timescale=timescale,
            time_slice=time_slice,
            scenario=scenario,
            resolution=resolution,
            units=units
        )
        
        teff_batch = calculate_teff(min_temp_df, max_temp_df)
        result = calculate_1_in_x_custom(
            teff_batch,
            batch,
            return_periods=return_periods,
        )

        export(result, filename=f"teff_batch_time_{idx}")        
        return result
        
    except Exception as exc:
        print(f"\n[Batch {idx}] ❌❌❌ EXCEPTION OCCURRED:")
        print(f"[Batch {idx}] Exception type: {type(exc).__name__}")
        print(f"[Batch {idx}] Exception message: {exc}")
        print(f"[Batch {idx}] Full traceback:")
        print("-" * 60)
        traceback.print_exc()
        print("-" * 60)
        return None


def process_batch_gwl(args: tuple[int, pd.DataFrame]) -> xr.Dataset | None:
    """
    Process a batch of locations for global warming level (GWL) based data retrieval
    and analysis.

    Parameters
    ----------
    args : tuple
        Tuple containing (idx, batch)
        - idx : int
            Batch index used for logging and output file naming.
        - batch : pandas.DataFrame
            DataFrame containing at least 'lat' and 'lon' columns that define the
            set of locations to process.

    Returns
    -------
    xarray.Dataset or None
        The Dataset produced by calculate_1_in_x_custom for the batch,
        or None if an exception occurred during processing.

    Raises
    ------
    None
        All exceptions are caught internally; details are printed and the function
        returns None on error.

    Notes
    -----
    - Uses get_data with `approach` and `warming_levels` (wls) parameters to fetch
      min/max temperature data for the spatial bounding box of the batch.
    - Computes teff and 1-in-x statistics via calculate_teff and
      calculate_1_in_x_custom.
    - Exports the result to a file named "teff_batch_gwl_{idx}".
    """
    idx, batch = args
    
    print(f"[Batch {idx}] Starting processing...")
    
    try:
        min_lat = float(batch['lat'].min())
        max_lat = float(batch['lat'].max())
        min_lon = float(batch['lon'].min())
        max_lon = float(batch['lon'].max())
        
        min_temp_df = get_data(
            variable=min_temp_var,
            latitude=(min_lat - space, max_lat + space),
            longitude=(min_lon - space, max_lon + space),
            downscaling_method=downscaling,
            timescale=timescale,
            scenario=scenario,
            resolution=resolution,
            units=units,
            approach=approach,
            warming_levels=wls
        )
        max_temp_df = get_data(
            variable=max_temp_var,
            latitude=(min_lat - space, max_lat + space),
            longitude=(min_lon - space, max_lon + space),
            downscaling_method=downscaling,
            timescale=timescale,
            scenario=scenario,
            resolution=resolution,
            units=units,
            approach=approach,
            warming_levels=wls
        )
        
        teff_batch = calculate_teff(min_temp_df, max_temp_df)
        result = calculate_1_in_x_custom(
            teff_batch,
            batch,
            return_periods=return_periods,
        )

        export(result, filename=f"teff_batch_gwl_{idx}")        
        return result
        
    except Exception as exc:
        print(f"\n[Batch {idx}] ❌❌❌ EXCEPTION OCCURRED:")
        print(f"[Batch {idx}] Exception type: {type(exc).__name__}")
        print(f"[Batch {idx}] Exception message: {exc}")
        print(f"[Batch {idx}] Full traceback:")
        print("-" * 60)
        traceback.print_exc()
        print("-" * 60)
        return None
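The two workers differ only in the arguments they pass to `get_data`: time-slice mode adds `time_slice`, while warming-level mode adds `approach` and `warming_levels`. A hypothetical helper such as `build_get_data_kwargs` (not part of climakitae) could keep that difference in one place. The keyword names are copied from the `get_data` calls above; the `config` dict is an assumption standing in for the notebook's configuration globals.

```python
from typing import Any


def build_get_data_kwargs(mode: str, variable: str, lat_range: tuple, lon_range: tuple,
                          config: dict[str, Any]) -> dict[str, Any]:
    """Assemble get_data keyword arguments for 'time' or 'gwl' mode (hypothetical helper)."""
    # Arguments shared by both modes
    kwargs = {
        "variable": variable,
        "latitude": lat_range,
        "longitude": lon_range,
        "downscaling_method": config["downscaling"],
        "timescale": config["timescale"],
        "scenario": config["scenario"],
        "resolution": config["resolution"],
        "units": config["units"],
    }
    # Mode-specific arguments
    if mode == "time":
        kwargs["time_slice"] = config["time_slice"]
    elif mode == "gwl":
        kwargs["approach"] = config["approach"]
        kwargs["warming_levels"] = config["wls"]
    else:
        raise ValueError(f"Unknown mode {mode!r}; expected 'time' or 'gwl'")
    return kwargs


# Values copied from the configuration cell
config = {
    "downscaling": "Statistical", "resolution": "3 km", "timescale": "daily",
    "units": "degF", "scenario": ["Historical Climate", "SSP 3-7.0"],
    "time_slice": (2000, 2029), "approach": "Warming Level", "wls": [2.0, 2.5],
}
bbox = ((33.88, 34.52), (-118.27, -117.78))
time_kwargs = build_get_data_kwargs("time", "Minimum air temperature at 2m", *bbox, config)
gwl_kwargs = build_get_data_kwargs("gwl", "Minimum air temperature at 2m", *bbox, config)
```

Inside a worker, a call would then look like `get_data(**build_get_data_kwargs("gwl", min_temp_var, lat_range, lon_range, config))`, so one worker body could serve both modes.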

### Purpose
Run the batch processing workers in parallel (with thread pools) to compute and collect 1-in-X results for two processing modes:
- time-slice retrieval (`process_batch_time`)
- global warming level retrieval (`process_batch_gwl`)

### What the cell does
- Sets the thread pool size via `n_workers` (here 2).
- Submits the first two batches (`batches[:2]`) to `process_batch_time` using `ThreadPoolExecutor`, collects futures as they complete, and stores each result with its batch index in `results_time` (recording `None` and printing the error if a future raises).
- Repeats the same pattern for `process_batch_gwl`, storing results in `results_gwl`.
- Sorts each list by batch index, drops the indices, and concatenates them into `results`: all time-slice results first, then all warming-level results.

### Key behaviors and notes
- Only the first two batches are processed because of `batches[:2]`; change that slice to process more.
- Each submitted job receives `(idx, batch)`, matching the worker function signatures.
- The workers catch their own exceptions and return `None`, so the `except` branch here only fires for errors raised outside the workers' `try` blocks (for example, while unpacking `args`). Either way, a failed batch appears as `None`.
- The two modes are kept in separate lists because both reuse batch indices 0, 1, ...; a single list sorted by index would interleave time-slice and warming-level datasets.
- Tune `n_workers` for available memory (~5 GB RAM per worker) and API rate limits, and apply the `DataInterface` monkey patch (code cell 2) before using multiple threads to avoid race conditions.

> [!NOTE]
> Comment out the mode you aren't using for the current run. If you're using both, leave both uncommented.

In [None]:
import concurrent.futures

n_workers = 2  # Adjust based on available memory and API rate limits
results_time, results_gwl = [], []

# Time-slice retrieval
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
    # Map each future back to its batch index
    future_to_batch = {executor.submit(process_batch_time, args): args[0] for args in enumerate(batches[:2])}
    for future in concurrent.futures.as_completed(future_to_batch):
        idx = future_to_batch[future]
        try:
            results_time.append((idx, future.result()))
        except Exception as exc:
            print(f"[Batch {idx}] generated an exception: {exc}")
            results_time.append((idx, None))

# Global warming level retrieval
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
    future_to_batch = {executor.submit(process_batch_gwl, args): args[0] for args in enumerate(batches[:2])}
    for future in concurrent.futures.as_completed(future_to_batch):
        idx = future_to_batch[future]
        try:
            results_gwl.append((idx, future.result()))
        except Exception as exc:
            print(f"[Batch {idx}] generated an exception: {exc}")
            results_gwl.append((idx, None))

# as_completed yields futures in completion order: restore batch order, then drop the indices
results_time = [r for _, r in sorted(results_time, key=lambda x: x[0])]
results_gwl = [r for _, r in sorted(results_gwl, key=lambda x: x[0])]
results = results_time + results_gwl
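Because each real batch can take hours, it can help to dry-run the threading pattern with a stand-in worker first. `fake_worker` below is hypothetical and only sleeps for a random interval. It shows that `as_completed` yields futures in completion order rather than submission order, which is why the results are sorted by batch index afterwards.

```python
import concurrent.futures
import random
import time


def fake_worker(args):
    """Stand-in for process_batch_*: sleep a random amount, then return a label."""
    idx, _batch = args
    time.sleep(random.uniform(0.0, 0.05))
    return f"result-{idx}"


batches_demo = ["a", "b", "c", "d", "e"]
completion_order, indexed = [], []
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    future_to_idx = {executor.submit(fake_worker, args): args[0] for args in enumerate(batches_demo)}
    for future in concurrent.futures.as_completed(future_to_idx):
        idx = future_to_idx[future]
        completion_order.append(idx)  # order in which batches finished
        indexed.append((idx, future.result()))

# Restore submission order, then drop the indices
ordered = [r for _, r in sorted(indexed, key=lambda x: x[0])]
print("completed:", completion_order)
print("ordered:  ", ordered)
```

`completed` typically varies between runs, while `ordered` always follows the batch index.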