# Drought Metrics

In this notebook, we explore how to calculate the Palmer Drought Severity Index (PDSI) and the Evaporative Demand Drought Index (EDDI) from Cal-Adapt: Analytics Engine (AE) variables.

To start, we calculate potential evapotranspiration (PET), since that variable is needed for both PDSI and EDDI.

### Calculating Potential Evapotranspiration with the FAO-56 Penman-Monteith method (the most physically based option)

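**Variables needed:**
- `tasmin` (daily minimum air temperature)
- `tasmax` (daily maximum air temperature)
- `hurs` (relative humidity)
- Radiation fluxes:
    - `rsds` (downwelling shortwave)
    - `rsus` (upwelling shortwave)
    - `rlds` (downwelling longwave)
    - `rlus` (upwelling longwave)
- `wspd10mean` (10 m wind speed; it will be converted to 2 m during the PET calculation)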

### Imports

In [None]:
import xclim
import os
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot  as plt
from climakitae.core.data_interface import get_data
from climakitae.core.data_load import load
from climakitae.core.data_export import export
from climakitae.util.utils import add_dummy_time_to_wl

### Initial Setup

In [None]:
# Target point (presumably downtown Oakland, CA); data are later requested
# in a +/-0.1 degree box around it.
lat, lon = 37.805993, -122.273715

# Short variable key -> AE long variable name. Insertion order matters:
# variables are fetched (and appended to `datas`) in this order.
variables_dict = dict(
    tasmax="Maximum air temperature at 2m",
    tasmin="Minimum air temperature at 2m",
    hurs="Relative humidity",
    rsds="Instantaneous downwelling shortwave flux at bottom",
    rsus="Instantaneous upwelling shortwave flux at bottom",
    rlds="Instantaneous downwelling longwave flux at bottom",
    rlus="Instantaneous upwelling longwave flux at bottom",
    wspd10mean="Mean wind speed at 10m",
    precip="Precipitation (total)",
)

In [None]:
### Retrieving the different variables needed for PET
# Each variable is fetched from AE, aggregated to daily values, and cached as
# tmp_data/<variable>_daily.nc so re-running the notebook reads from disk.
datas = []

# Daily aggregation per variable; any variable not listed is averaged.
DAILY_AGG = {"tasmin": "min", "tasmax": "max", "precip": "sum"}
# These fluxes are requested at hourly timescale (presumably not offered daily
# in AE -- confirm) and aggregated to daily below.
HOURLY_VARS = {"rlus", "rsus"}

for variable, var_long_name in variables_dict.items():

    file_path = f"tmp_data/{variable}_daily.nc"

    if os.path.exists(file_path):
        print(f"Reading {variable} from file.")
        da = xr.open_dataarray(file_path)

    else:
        print(f"Computing {var_long_name}")
        da = get_data(
            variable=var_long_name,
            resolution='3 km',
            timescale='hourly' if variable in HOURLY_VARS else 'daily',
            latitude=(lat - 0.1, lat + 0.1),
            longitude=(lon - 0.1, lon + 0.1),
            approach="Warming Level",
            warming_level=[0.8, 1.5, 2.0, 3.0],
            downscaling_method="Dynamical",
        )
        da = load(add_dummy_time_to_wl(da), progress_bar=True)

        # Pick the daily reducer (min/max/sum/mean) for this variable.
        daily_resampler = da.squeeze().resample(time='D')
        da = getattr(daily_resampler, DAILY_AGG.get(variable, "mean"))()

        # to_netcdf does not create missing directories, so make sure the
        # cache directory exists on a fresh run before saving.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        da.to_netcdf(file_path)  # Save for reuse

    datas.append(da)

In [None]:
# Unpacking the (already daily) variables by name rather than by position, so
# reordering `variables_dict` cannot silently swap inputs. `datas` was filled
# in `variables_dict` iteration order.
daily = dict(zip(variables_dict, datas))

tasmin = daily["tasmin"]
tasmax = daily["tasmax"]
# Convert from % to fraction. xarray arithmetic drops attrs by default, so the
# result is re-tagged with dimensionless units for xclim.
hurs = daily["hurs"] / 100
new_hurs = hurs.assign_attrs(units='1')
rsds = daily["rsds"]
rsus = daily["rsus"]
rlds = daily["rlds"]
rlus = daily["rlus"]
sfcWind = daily["wspd10mean"]  # 10 m wind speed, passed to xclim as `sfcWind`
precip = daily["precip"]

In [None]:
# %cd xclim
# !git checkout v0.54.0
# !pip install -e .

In [None]:
# Daily PET from xclim's "FAO_PM98" (FAO-56 Penman-Monteith) formulation.
penman_monteith_inputs = dict(
    tasmin=tasmin,
    tasmax=tasmax,
    hurs=new_hurs,
    rsds=rsds,
    rsus=rsus,
    rlds=rlds,
    rlus=rlus,
    sfcWind=sfcWind,
)
pet_calc = xclim.indices.potential_evapotranspiration(
    **penman_monteith_inputs, method="FAO_PM98"
)

In [None]:
# Preserving a spatial mask for later: True wherever PET is defined in the
# first warming level / timestep / simulation.
first_pet_slice = pet_calc.isel(warming_level=0, time=0, simulation=0)
spatial_mask = first_pet_slice.notnull()

# PDSI

Here, we use the PDSI function from the `climate_indices` library, which is also the implementation referenced by drought.gov. We install a specific commit of the package because it is compatible with the AE hub environment.

In [None]:
# Pip install the specific commit of the package that supports AE package versions
!pip install git+https://github.com/monocongo/climate_indices.git@43c5451

In [None]:
# Making imports from `climate_indices` package
import climate_indices
from climate_indices.palmer import pdsi

In [None]:
# Resampling PET and precip to monthly totals in inches, since `pdsi` only
# takes monthly variables.
# PET is presumably a per-second flux (kg m-2 s-1 ~ mm/s) from xclim, so
# x86400 gives mm/day; precip is assumed to already be mm/day. /25.4 -> inches.
SECONDS_PER_DAY = 86400
MM_PER_INCH = 25.4

pet_daily_in = pet_calc * SECONDS_PER_DAY / MM_PER_INCH
precip_daily_in = precip / MM_PER_INCH

mon_pet = pet_daily_in.resample(time='1ME').sum()
mon_precip = precip_daily_in.resample(time='1ME').sum()

In [None]:
### Combining WL objects together: baseline WL mapped to the first half of a
### dummy timeline (2000-2029 for 30-year WLs), future WL to the second half (2030-2059)
def combine_wl_to_dummy_time(
    da: xr.DataArray,
    baseline_wl: float,
    future_wls: list[float],
    start_date: str = "2000-01-31",
) -> xr.DataArray:
    """
    Pair a baseline warming level with each future warming level along a new
    'combined_wl' dimension.

    For every future WL, the baseline slice and the future slice are
    concatenated along 'time' and re-stamped with one continuous month-end
    time axis starting at `start_date`. Calibration-based indices can then
    treat the pair as a single record: baseline first, future second.

    Parameters
    ----------
    da : xr.DataArray
        Data with dims including 'warming_level' and 'time'. Every warming
        level shares the same 'time' axis (one step per month).
    baseline_wl : float
        Warming level placed in the first half of each combined series.
    future_wls : list of float
        Warming levels placed, one per pairing, in the second half.
    start_date : str
        First timestamp of the combined month-end ('ME') time axis.

    Returns
    -------
    xr.DataArray
        Combined data with a new 'combined_wl' dimension labelled like
        "08_to_15" (each WL x10, zero-padded to two digits), plus a
        'warming_level_flag' coordinate along 'time' giving each step's
        source WL.

    Raises
    ------
    ValueError
        If `future_wls` is empty.
    """
    if not future_wls:
        raise ValueError("future_wls must contain at least one warming level")

    months_per_wl = da.sizes['time']
    new_time = pd.date_range(start_date, periods=2 * months_per_wl, freq='ME')

    # The baseline slice is identical for every pairing, so select it once.
    da_base = da.sel(warming_level=baseline_wl)

    combined_list = []
    combined_labels = []

    for fw in future_wls:
        da_future = da.sel(warming_level=fw)

        combined = xr.concat([da_base, da_future], dim='time')
        combined = combined.assign_coords(time=new_time)

        wl_flag = np.array([baseline_wl] * months_per_wl + [fw] * months_per_wl)
        combined = combined.assign_coords(warming_level_flag=('time', wl_flag))

        combined_list.append(combined)
        # Round (not truncate) so float noise in wl * 10 can't shift a label down.
        combined_labels.append(
            f"{int(round(baseline_wl * 10)):02d}_to_{int(round(fw * 10)):02d}"
        )

    combined_da = xr.concat(combined_list, dim='combined_wl')
    return combined_da.assign_coords(combined_wl=combined_labels)

In [None]:
# Creating one Dataset of PET and precip with WLs combined: the 0.8 WL is the
# baseline, and each future WL is appended after it.
BASELINE_WL = 0.8
FUTURE_WLS = [1.5, 2.0, 3.0]

mon_pet_transform = combine_wl_to_dummy_time(
    mon_pet, baseline_wl=BASELINE_WL, future_wls=FUTURE_WLS
)
mon_precip_transform = combine_wl_to_dummy_time(
    mon_precip, baseline_wl=BASELINE_WL, future_wls=FUTURE_WLS
)

# Applying spatial mask so cells without PET stay NaN in both variables
combined_ds = xr.Dataset(
    data_vars={'precip': mon_precip_transform, 'pet': mon_pet_transform}
).where(spatial_mask)

In [None]:
# Helper function to vectorize PDSI calculation across `combined_ds` dimensions.
def calc_pdsi(
    timeseries: xr.Dataset,
    awc: float = 5,
    data_start_year: int = 2000,
    calibration_year_initial: int = 2000,
    calibration_year_final: int = 2029,
) -> xr.DataArray:
    """
    Compute the Palmer Drought Severity Index (PDSI) for one monthly series.

    Parameters
    ----------
    timeseries : xarray.Dataset
        Dataset containing monthly 'precip' and 'pet' variables along 'time';
        any other length-1 dimensions are squeezed away.
    awc : float
        Available water capacity passed to `climate_indices.palmer.pdsi`
        (presumably in the same units as precip/PET, i.e. inches here).
    data_start_year : int
        Calendar year of the first month in `timeseries`.
    calibration_year_initial, calibration_year_final : int
        Inclusive first/last calibration years. The defaults cover only the
        30-year baseline segment (2000-2029). Ending in 2030 would pull the
        first year of the future warming level into the calibration.

    Returns
    -------
    xarray.DataArray
        PDSI along 'time', clipped to [-10, 10].
    """
    # Extracting precip and PET for this timeseries
    precip = timeseries['precip'].squeeze()
    pet = timeseries['pet'].squeeze()

    # `pdsi` returns several Palmer outputs; only the first (PDSI) is kept.
    pdsi_values = pdsi(
        precips=precip.values,
        pet=pet.values,
        awc=awc,
        data_start_year=data_start_year,
        calibration_year_initial=calibration_year_initial,
        calibration_year_final=calibration_year_final,
    )[0]
    retval = xr.DataArray(pdsi_values, coords={"time": precip.time.values}, dims=['time'])

    # Clipping PDSI to realistic values
    return retval.clip(min=-10, max=10)

In [None]:
# Applies the PDSI function across all dimensions so that a timeseries of PET/precip is always being passed into `pdsi`.
# `.map` is the current GroupBy API; `.apply` is only a backward-compatible alias.
pdsi_da = combined_ds.groupby([
    'combined_wl',
    'x',
    'y',
    'simulation'
]).map(calc_pdsi)

### Saving out the results

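The results will have the following dimensions:
- `time` (only the future-WL half of each combined series; the baseline half used for calibration is dropped)
- `wl` (labels such as `08_to_15`: the WL PDSI was calibrated on, followed by the WL PDSI was calculated on)
- `x`
- `y`
- `simulation`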

In [None]:
# Saving these results and cleaning the data.
# Each combined series is baseline WL followed by future WL (equal lengths),
# so keep only the second half instead of hard-coding month indices.
pdsi_baseline_len = pdsi_da.sizes['time'] // 2
final_pdsi = pdsi_da.isel(time=slice(pdsi_baseline_len, None))
final_pdsi = final_pdsi.rename({'combined_wl': 'wl'})
export(final_pdsi, filename='pdsi_wl')

## EDDI

Now, we will calculate EDDI using PET.

In [None]:
# Import `standardized_index` from xclim, which we will apply to our PET data object to generate EDDI
from xclim.indices.stats import standardized_index

In [None]:
def calc_eddi(
    timeseries: xr.DataArray,
    cal_start: str = "2000-01-01",
    cal_end: str = "2029-12-31",
):
    """
    Compute the Evaporative Demand Drought Index (EDDI) for a PET time series.

    PET is standardized with xclim's `standardized_index` using a gamma
    distribution fitted by maximum likelihood over the calibration window.

    Parameters
    ----------
    timeseries : xarray.DataArray
        Monthly PET (ET0) series along 'time'.
    cal_start, cal_end : str
        Calibration window as "YYYY-MM-DD". The defaults cover the 2000-2029
        baseline. The start is the first day of the month, so January 2000
        is included whether timestamps are month-start or month-end.

    Returns
    -------
    xarray.DataArray
        EDDI values clipped to [-2.5, 2.5]. Positive = dry (high evaporative
        demand), negative = wet.
    """
    eddi = standardized_index(
        da=timeseries,
        # NOTE(review): resampling to 'MS' presumably relabels the month-end
        # input to month-start timestamps (unlike the PDSI output) -- confirm.
        freq='MS',
        window=1,
        dist="gamma",
        method="ML",
        zero_inflated=False,
        fitkwargs={},
        cal_start=cal_start,
        cal_end=cal_end,
    )
    # Clipping EDDI to realistic values
    return eddi.clip(min=-2.5, max=2.5)

In [None]:
# Applies the `calc_eddi` function across all dimensions so that a timeseries of PET is always being passed into `calc_eddi`.
# `.map` is the current GroupBy API; `.apply` is only a backward-compatible alias.
eddi_da = combined_ds['pet'].groupby([
    'combined_wl',
    'x',
    'y',
    'simulation'
]).map(calc_eddi)

### Exporting the data

In [None]:
# Saving these results and cleaning the data.
# Each combined series is baseline WL followed by future WL (equal lengths),
# so keep only the second half instead of hard-coding month indices.
eddi_baseline_len = eddi_da.sizes['time'] // 2
final_eddi = eddi_da.isel(time=slice(eddi_baseline_len, None))
final_eddi = final_eddi.rename({'combined_wl': 'wl'})
export(final_eddi, filename='eddi_wl')