In [1]:
import numpy as np
import math
from scipy.optimize import fsolve
from numba import njit, prange, float64, vectorize
import xarray as xr

In [2]:
import pandas as pd
import matplotlib.pyplot as plt

In [49]:
# @njit(float64[:](float64[:]))
def samlmom3(sample):
    """
    samlmom3 returns the first three L-moments of samples
    sample is the 1-d array
    n is the total number of the samples, j is the j_th sample
    """
    # Sort in descending order
    sorted_sample = np.sort(sample)[::-1]
    n = len(sorted_sample)
    
    # Calculate mean directly
    b0 = np.sum(sorted_sample) / n
    
    # Pre-calculate the coefficients for b1
    b1 = 0.0
    for j in range(n):
        b1_coef = (n - j - 1) / (n * (n - 1))
        b1 += b1_coef * sorted_sample[j]
    
    # Pre-calculate the coefficients for b2
    b2 = 0.0
    for j in range(n-1):  # Note: upper bound is n-1 for b2
        b2_coef = (n - j - 1) * (n - j - 2) / (n * (n - 1) * (n - 2))
        b2 += b2_coef * sorted_sample[j]
    
    # Calculate L-moments
    lmom1 = b0
    lmom2 = 2 * b1 - b0
    lmom3 = 6 * (b2 - b1) + b0

    return np.array([lmom1, lmom2, lmom3])

# @njit
def f_for_fsolve(x, t):
    """Helper function for pargev_fsolve."""
    return 2 * (1 - 3 ** (-x)) / (1 - 2 ** (-x)) - 3 - t

# @njit
def pargev(lmom):
    """
    pargev returns the parameters of the Generalized Extreme Value
    distribution given the L-moments of samples
    """
    t3 = lmom[2] / lmom[1]
    # Don't create a new array here - just use the values directly
    lmom_0 = lmom[0]
    lmom_1 = lmom[1]
    
    SMALL = 1e-5
    eps = 1e-6
    maxit = 20
    # Constants
    EU = 0.57721566
    DL2 = math.log(2)
    DL3 = math.log(3)
    # Coefficients for rational-function approximations
    A0 = 0.28377530
    A1 = -1.21096399
    A2 = -2.50728214
    A3 = -1.13455566
    A4 = -0.07138022
    B1 = 2.06189696
    B2 = 1.31912239
    B3 = 0.25077104
    C1 = 1.59921491
    C2 = -0.48832213
    C3 = 0.01573152
    D1 = -0.64363929
    D2 = 0.08985247
    
    # Check for valid L-moments
    if lmom_1 <= 0 or abs(t3) >= 1:
        # Create a new array for the result rather than modifying an input
        return np.array([np.nan, np.nan, np.nan], dtype=np.float64)
        
    para1 = 0.0
    para2 = 0.0
    para3 = 0.0
    
    if t3 <= 0:
        G = (A0 + t3 * (A1 + t3 * (A2 + t3 * (A3 + t3 * A4)))) / (
            1 + t3 * (B1 + t3 * (B2 + t3 * B3))
        )
        if t3 >= -0.8:
            para3 = G
            GAM = math.exp(math.lgamma(1 + G))
            para2 = lmom_1 * G / (GAM * (1 - 2**-G))
            para1 = lmom_0 - para2 * (1 - GAM) / G
            return np.array([para1, para2, para3], dtype=np.float64)
            
        elif t3 <= -0.97:
            G = 1 - math.log(1 + t3) / DL2
        T0 = (t3 + 3) * 0.5
        
        # Iteration loop
        for IT in range(maxit):
            X2 = 2**-G
            X3 = 3**-G
            XX2 = 1 - X2
            XX3 = 1 - X3
            T = XX3 / XX2
            DERIV = (XX2 * X3 * DL3 - XX3 * X2 * DL2) / (XX2**2)
            GOLD = G
            G -= (T - T0) / DERIV
            if abs(G - GOLD) <= eps * G:
                para3 = G
                GAM = math.exp(math.lgamma(1 + G))
                para2 = lmom_1 * G / (GAM * (1 - 2**-G))
                para1 = lmom_0 - para2 * (1 - GAM) / G
                return np.array([para1, para2, para3], dtype=np.float64)
                
        # If iteration doesn't converge, return NaN values
        return np.array([np.nan, np.nan, np.nan], dtype=np.float64)
        
    else:
        Z = 1 - t3
        G = (-1 + Z * (C1 + Z * (C2 + Z * C3))) / (1 + Z * (D1 + Z * D2))
        
        if abs(G) < SMALL:
            para2 = lmom_1 / DL2
            para1 = lmom_0 - EU * para2
            para3 = 0
        else:
            para3 = G
            GAM = math.exp(math.lgamma(1 + G))
            para2 = lmom_1 * G / (GAM * (1 - 2**-G))
            para1 = lmom_0 - para2 * (1 - GAM) / G
            
        return np.array([para1, para2, para3], dtype=np.float64)

In [46]:
def calculate_gev_params(data_array, dim):
    """
    Calculate GEV parameters for an xarray DataArray along a specified dimension.
    
    Parameters:
    -----------
    data_array : xarray.DataArray
        The input data array containing values to fit GEV distribution to
    dim : str
        The dimension name to calculate L-moments over (typically a time dimension)
        
    Returns:
    --------
    xarray.Dataset
        A dataset containing the three GEV parameters (location, scale, shape)
    """
    # Define wrapper functions for xarray's apply_ufunc
    
    # Step 1: Calculate L-moments
    def calc_lmoments_wrapper(data):
        # Handle all-NaN slices by returning NaN for all moments
        if np.any(np.isnan(data)):
            return np.array([np.nan, np.nan, np.nan])
            
        # Calculate L-moments
        return samlmom3(data.astype(np.float64))
    
    # Step 2: Calculate GEV parameters from L-moments
    def calc_gev_params_wrapper(lmoms):
        # Check if we have valid L-moments
        if np.any(np.isnan(lmoms)):
            return np.array([np.nan, np.nan, np.nan])
            
        # L-moment ratio validity check
        if lmoms[1] <= 0 or abs(lmoms[2]/lmoms[1]) >= 1:
            return np.array([np.nan, np.nan, np.nan])
            
        # Calculate GEV parameters
        return pargev(lmoms)
    
    # Apply the first function to calculate L-moments along the specified dimension
    lmoments = xr.apply_ufunc(
        calc_lmoments_wrapper,
        data_array,
        input_core_dims=[[dim]],
        output_core_dims=[["lmoment"]],
        vectorize=True,
        dask="forbidden",  # No Dask as requested
        output_sizes={"lmoment": 3}
    )
    
    # Apply the second function to calculate GEV parameters from L-moments
    gev_params = xr.apply_ufunc(
        calc_gev_params_wrapper,
        lmoments,
        input_core_dims=[["lmoment"]],
        output_core_dims=[["param"]],
        vectorize=True,
        dask="forbidden",  # No Dask as requested
        output_sizes={"param": 3}
    )
    
    # Create a labeled dataset with the parameters
    result = gev_params.to_dataset(name="gev_params")
    
    # Convert to separate data variables for convenience
    result = xr.Dataset({
        "location": result['gev_params'].isel(param=0),
        "scale": result['gev_params'].isel(param=1),
        "shape": result['gev_params'].isel(param=2)
    })

    return result

# Example usage:
def example_usage():
    # Create a sample dataset (replace with your actual data)
    # Here we create synthetic data with a known GEV distribution
    np.random.seed(42)
    
    # Create coordinates for the dataset
    time = pd.date_range("2000-01-01", periods=50, freq="D")
    lat = np.linspace(30, 40, 400)
    lon = np.linspace(-120, -110, 900)
    
    # Create a 3D array of random data
    # For demonstration, we'll generate different GEV distributions at each grid point
    shape = (len(time), len(lat), len(lon))
    data = np.zeros(shape)
    
    # For each lat/lon point, generate data from a different GEV distribution
    for i in range(len(lat)):
        for j in range(len(lon)):
            # Vary the GEV parameters across the grid
            loc = 10 + i + j/2  # Location parameter varies
            scale = 2 + i/1000    # Scale parameter varies slightly
            shape_param = -0.5 + j/1000  # Shape parameter varies slightly
            
            # Generate random samples from GEV distribution
            # Using simplified approximation for demonstration
            u = np.random.uniform(size=len(time))
            if abs(shape_param) < 1e-6:
                # Gumbel case (shape ≈ 0)
                data[:, i, j] = loc - scale * np.log(-np.log(u))
            else:
                # GEV case
                data[:, i, j] = loc + scale * ((-np.log(u))**(-shape_param) - 1) / shape_param
    
    # Create an xarray DataArray
    da = xr.DataArray(
        data,
        coords=[time, lat, lon],
        dims=["time", "lat", "lon"],
        name="values"
    )

    return da

In [47]:
%%time
da = example_usage()

CPU times: user 7.86 s, sys: 80.5 ms, total: 7.94 s
Wall time: 8.05 s


In [48]:
%%time
gev_params = calculate_gev_params(da, dim="time")

CPU times: user 11 s, sys: 48.7 ms, total: 11.1 s
Wall time: 11.1 s


In [50]:
%%time
gev_params = calculate_gev_params(da, dim="time")

CPU times: user 28.9 s, sys: 43.5 ms, total: 28.9 s
Wall time: 29.2 s
