In [None]:
from io import StringIO
import os 
from dotenv import load_dotenv

import climpred
import xarray as xr
import xesmf as xe
import numpy as np
import pandas as pd
import regionmask
import geopandas as gp
from climpred import HindcastEnsemble
from datetime import datetime

import xhistogram.xarray as xhist
from sklearn.metrics import roc_auc_score
import pandas as pd

import xskillscore as xs
from xbootstrap import block_bootstrap
from dask.distributed import Client

In [None]:
load_dotenv()

# Base directory for all input/output files. File names are appended to it
# verbatim (f'{data_path}file.nc'), so it must end with a path separator.
data_path = os.getenv("data_path")
if data_path is None:
    # Fail early: otherwise every path silently becomes 'None<file name>'.
    raise RuntimeError(
        "Environment variable 'data_path' is not set; define it in the "
        "environment or in a .env file."
    )

## The Forecast Evaluation Process

The following process is used to assess the accuracy of drought forecasts relative to observed data.

The code uses the contingency functionality of [xskillscore](https://xskillscore.readthedocs.io/en/stable/api.html#dichotomous-only-yes-no-metrics) for dichotomous-only (yes/no) metrics and the [roc function](https://xskillscore.readthedocs.io/en/stable/api/xskillscore.roc.html#xskillscore.roc) for multi-category metrics (ROC and AUROC).

1. **Observation and Forecast Data**: The process begins with two primary datasets - observed data (`obs_data`) and ensemble forecast data (`ens_data`).

2. **Threshold Determination**: A predefined threshold is used to classify whether a drought event occurred, in both the observed and the forecast data. These are the season-specific threshold values discussed in the paper; the values used are defined in the function `get_threshold`.

3. **Dichotomous Events Creation**:
   - **Observed Events**: Bin edges for the observed events, as described in the xskillscore documentation, are passed spanning the SPI range -4 to 4. The "less than threshold" convention of Gabriela et al. (2023) is NOT followed here; the category logic of xskillscore is used instead.
   - **Forecasted Events**: Instead of averaging ensemble forecasts, an empirical probability is calculated. This reflects the proportion of ensemble members predicting values at or below the given threshold. 

4. **Probability Threshold for Forecasts**: The forecast data is then classified into dichotomous events based on a trigger value. This step converts the forecast probability into a binary outcome: whether or not a drought is forecast to occur. All trigger values between 0 and 1 are tested. The function `get_triggers_bin_edges` builds evenly spaced triggers with `np.linspace(0, 1, 101)` and returns, for each trigger, the bin edges in the form required by xskillscore. Each trigger is then evaluated in a for loop.

5. The **Contingency Table Construction** and **Skill Score Calculation** are created following the internal logic of xskillscore which uses `xhist.histogram`. The 2D histogram (contingency table) is created from the binary observed and forecasted events. This table quantifies the relationship between observed occurrences/non-occurrences and forecasted occurrences/non-occurrences of drought. From the contingency table, various skill scores are calculated, including hit rates, false alarm ratios, and others. These metrics provide a comprehensive view of forecast accuracy and reliability.

6. **AUROC Calculation**: The Area Under the Receiver Operating Characteristic curve (AUROC) is calculated. It measures the forecast's ability to discriminate between the occurrence and non-occurrence of drought events. Uncertainty is estimated with a bootstrap based on the [xbootstrap](https://pypi.org/project/xbootstrap/) library, applied with xskillscore [following this gist](https://gist.github.com/aaronspring/471e70f787aef6689825182e794421fb). It runs on dask and is therefore computationally fast.

7. **Final Output**: The process culminates in a DataFrame summarizing the calculated skill scores for each trigger value, which is then saved to a CSV file. This summary makes it possible to evaluate forecast performance across different probability thresholds.

---


## The summary of steps are as follows

1. **Load Datasets**: Based on the provided `season_str` length, either SPI3 or SPI4 datasets are loaded for both forecast (`kn_fct`) and observed (`kn_obs`) data from a specified `data_path`.

2. **Generate Region Masks**: Utilizes `ken_mask_creator()` to create masks for specified regions or districts using the `regionmask` library, facilitating region-specific analyses.

3. **Select Region**: Extracts geographical bounds from a combined GeoDataFrame and selects forecast and observation data within these bounds for the specified `region_id`.

4. **Initialize Hindcast Ensemble**: Creates a `HindcastEnsemble` object with the selected forecast data and adds the corresponding observations to it.

5. **Subset for Lead Time**: Subsets the forecast data for the given `lead_int`, ensuring analysis is conducted at the specified forecast lead time.

6. **Generate Seasonal Product Names**: Depending on the length of `season_str`, either `spi3_prod_name_creator` or `spi4_prod_name_creator` is called to generate a list of seasonal product names for both forecast and observed data, enabling season-specific filtering.

7. **Assign and Filter by Season**: Assigns generated seasonal product names as coordinates to both datasets and filters them to include only data corresponding to the specified `season_str`.

8. **Align Time Coordinates**: Aligns the observed dataset time coordinates with the forecast dataset's valid time coordinates, ensuring that both datasets are comparable in time for analysis.


This process ensures that the observed and forecasted datasets are correctly prepared and aligned for a specified region, season, and lead time, allowing for the accurate calculation of verification scores such as hit rates, false alarm ratios, and AUROC scores.

In [None]:
def ken_mask_creator():
    """
    Build region masks for Karamoja, Marsabit and Wajir with regionmask.

    Reads ``Karamoja_boundary_dissolved.shp`` and ``wajir_mbt_extent.shp``
    from the module-level ``data_path``, stacks their features (Karamoja
    file first) and labels them 0, 1, 2 as 'Karamoja', 'Marsabit', 'Wajir'
    in that row order.
    NOTE(review): assumes wajir_mbt_extent.shp lists Marsabit before Wajir.

    Returns
    -------
    the_mask : regionmask.Regions
        Regions built from the stacked geometries, numbered by 'region'.
    rl_dict : dict
        Mapping region number -> region name.
    mds2 : geopandas.GeoDataFrame
        The stacked features with columns 'geometry', 'region', 'region_name'.

    Raises
    ------
    ValueError
        If the two shapefiles together do not contain exactly three features.
    """
    region_names = ['Karamoja', 'Marsabit', 'Wajir']
    dis = gp.read_file(f'{data_path}Karamoja_boundary_dissolved.shp')
    reg = gp.read_file(f'{data_path}wajir_mbt_extent.shp')
    mds1 = pd.concat([dis, reg]).reset_index()
    if len(mds1) != len(region_names):
        raise ValueError(
            f"Expected {len(region_names)} region features, found {len(mds1)}"
        )
    mds1['region'] = list(range(len(region_names)))
    mds1['region_name'] = region_names
    mds2 = mds1[['geometry', 'region', 'region_name']]
    rl_dict = dict(zip(mds2.region, mds2.region_name))
    the_mask = regionmask.from_geopandas(mds2, numbers='region', overlap=True)
    return the_mask, rl_dict, mds2

def spi3_prod_name_creator(ds_ens,var_name):
    """
    Build the 3-month SPI product (season) label for every time step.

    Each time step is labelled with the initials of the three consecutive
    months ending in that step's month: a May step gives 'MAM', a February
    step gives 'DJF'. The label depends only on the calendar month, so it is
    correct across year boundaries and when the series has gaps or is not
    sorted. Cross-year seasons (NDJ, DJF) are therefore also labelled.

    Parameters
    ----------
    ds_ens : xarray.Dataset (or any mapping)
        ``ds_ens[var_name].values`` must be a 1-D sequence of datetimes:
        numpy datetime64 values or objects exposing ``.month`` (e.g. cftime
        dates such as climpred's valid_time).
    var_name : str
        Name of the time variable to read, e.g. 'time' or 'valid_time'.

    Returns
    -------
    spi_prod_list : list of str
        One label per time step, e.g. ['NDJ', 'DJF', 'JFM', 'FMA', 'MAM', ...].
    """
    month_initials = 'JFMAMJJASOND'
    # Wrapping in a Series boxes datetime64 values as Timestamps; cftime or
    # datetime objects pass through unchanged. Both expose ``.month``.
    months = [stamp.month for stamp in pd.Series(ds_ens[var_name].values)]
    # (month - lag - 1) % 12 is the 0-based index of the month `lag` months earlier.
    return [
        ''.join(month_initials[(month - lag - 1) % 12] for lag in (2, 1, 0))
        for month in months
    ]


def spi4_prod_name_creator(ds_ens,var_name):
    """
    Build the 4-month SPI product (season) label for every time step.

    Each time step is labelled with the initials of the four consecutive
    months ending in that step's month: a September step gives 'JJAS', a
    January step gives 'ONDJ'. The label depends only on the calendar month,
    so it is correct across year boundaries and when the series has gaps or
    is not sorted.

    Parameters
    ----------
    ds_ens : xarray.Dataset (or any mapping)
        ``ds_ens[var_name].values`` must be a 1-D sequence of datetimes:
        numpy datetime64 values or objects exposing ``.month`` (e.g. cftime
        dates such as climpred's valid_time).
    var_name : str
        Name of the time variable to read, e.g. 'time' or 'valid_time'.

    Returns
    -------
    spi_prod_list : list of str
        One SPI4 label per time step, e.g. ['ONDJ', 'NDJF', ..., 'JJAS', ...].
    """
    month_initials = 'JFMAMJJASOND'
    # Wrapping in a Series boxes datetime64 values as Timestamps; cftime or
    # datetime objects pass through unchanged. Both expose ``.month``.
    months = [stamp.month for stamp in pd.Series(ds_ens[var_name].values)]
    # (month - lag - 1) % 12 is the 0-based index of the month `lag` months earlier.
    return [
        ''.join(month_initials[(month - lag - 1) % 12] for lag in (3, 2, 1, 0))
        for month in months
    ]


def make_obs_fct_dataset(region_id,season_str,lead_int):
    """
    Prepare observed and forecast SPI subsets for one region, season and lead time.

    The SPI3 or SPI4 observed/forecast files are chosen from ``len(season_str)``.
    Both datasets are cropped to the bounding box of the region's geometry, and a
    climpred ``HindcastEnsemble`` is built from them. One lead is then selected,
    and only time steps whose SPI product label equals ``season_str`` are kept.
    Finally, the observations are restricted to times that are also forecast
    valid times.

    Parameters:
    - region_id (int): Row position in the GeoDataFrame returned by
      ``ken_mask_creator`` (0 = Karamoja, 1 = Marsabit, 2 = Wajir). Only the
      bounding box of that geometry is used. The regionmask mask itself is not
      applied, so grid cells inside the box but outside the region are kept.
    - season_str (str): Season label to keep. A 3-character string selects the
      SPI3 files (e.g. 'MAM', 'OND'); any other length selects SPI4 (e.g. 'JJAS').
      It is compared case-sensitively with the labels produced by
      ``spi3_prod_name_creator`` / ``spi4_prod_name_creator``, which are
      upper-case month initials, so pass e.g. 'MAM' rather than 'mam'.
    - lead_int (int): Positional index along the forecast 'lead' dimension.

    Returns:
    - obs_data (xarray.DataArray): The 'spi3'/'spi4' observations on dimension
      'time'. They are limited to the requested season and to dates that are
      also forecast valid times at this lead.
    - ens_data (xarray.DataArray): The 'spi3'/'spi4' forecasts at the selected
      lead for the requested season. They are still indexed by 'init', with
      'spi_prod' (and presumably 'valid_time') as coordinates.

    Notes:
    - Reads the module-level ``data_path``. File names are appended to it
      verbatim, so it must end with a path separator.
    - The lat/lon selection uses ``slice(min, max)``. This assumes the
      coordinates are stored in ascending order — TODO confirm for the files.

    Example Usage:
    >>> obs_data, ens_data = make_obs_fct_dataset(1, 'MAM', 0)

    This loads the SPI3 files for Marsabit (region 1) and keeps the 'MAM' time
    steps at lead index 0. The observations are aligned with the forecast
    valid times.
    """
    # SPI3 files for 3-letter seasons (e.g. 'MAM'), SPI4 files otherwise (e.g. 'JJAS').
    if len(season_str) == 3:
        kn_fct=xr.open_dataset(f'{data_path}kn_fct_spi3.nc')
        kn_obs=xr.open_dataset(f'{data_path}kn_obs_spi3.nc')
    else:
        kn_fct=xr.open_dataset(f'{data_path}kn_fct_spi4.nc')
        kn_obs=xr.open_dataset(f'{data_path}kn_obs_spi4.nc')
    # Only the region geometries are used below (their bounding boxes);
    # the_mask and rl_dict are not used further in this function.
    the_mask, rl_dict,mds1=ken_mask_creator()
    bounds = mds1.bounds
    # Bounding box (lon/lat extents) of the requested region's geometry.
    llon=bounds.iloc[region_id].minx
    llat=bounds.iloc[region_id].miny
    ulon=bounds.iloc[region_id].maxx
    ulat=bounds.iloc[region_id].maxy
    a_fc=kn_fct.sel(lon=slice(llon, ulon), lat=slice(llat,ulat))
    a_obs=kn_obs.sel(lon=slice(llon, ulon), lat=slice(llat,ulat))
    hindcast = HindcastEnsemble(a_fc)
    hindcast = hindcast.add_observations(a_obs)
    # Initialized forecasts from climpred, then a single lead selected by position.
    a_fc1=hindcast.get_initialized()
    a_fc2=a_fc1.isel(lead=lead_int)
    # Season labels: forecasts are labelled by their valid_time (the date they
    # verify on), observations by their own time coordinate.
    if len(season_str) == 3:
        spi_prod_list=spi3_prod_name_creator(a_fc2,'valid_time')
        obs_spi_prod_list=spi3_prod_name_creator(a_obs,'time')
    else:
        spi_prod_list=spi4_prod_name_creator(a_fc2,'valid_time')
        obs_spi_prod_list=spi4_prod_name_creator(a_obs,'time')
    # Attach the labels as coordinates and keep only the requested season.
    a_fc2 = a_fc2.assign_coords(spi_prod=('init',spi_prod_list))
    a_fc3=a_fc2.where(a_fc2.spi_prod==season_str, drop=True)
    # Observations: same labelling and season filter along 'time'.
    a_obs1 = a_obs.assign_coords(spi_prod=('time',obs_spi_prod_list))
    a_obs2=a_obs1.where(a_obs1.spi_prod==season_str, drop=True)
    # Unique forecast valid times at this lead. This uses a_fc2 (all seasons),
    # not a_fc3. Because a_obs2 is already season-filtered, the reindex +
    # dropna below should still keep only in-season observation times.
    valid_time_flattened = a_fc2.valid_time.to_dataframe().reset_index().drop_duplicates(subset='valid_time')['valid_time']
    # NOTE(review): ['valid_time'] above is expected to return a two-column
    # DataFrame (hence the two-name rename). Presumably to_dataframe() emits
    # both the data column and the valid_time coordinate under the same
    # label — confirm against the installed xarray version.
    valid_time_flattened.columns=['valid_time','cc']
    # Rebuild each valid time (presumably cftime objects from climpred) as a
    # plain datetime. Then round-trip through a string so pandas yields
    # datetime64 values comparable with the observation time axis.
    valid_time_flattened['dt1'] =valid_time_flattened['valid_time'].apply(lambda x: datetime(x.year, x.month, x.day,x.hour, x.minute, x.second))
    valid_time_flattened['dt1'] = valid_time_flattened['dt1'].dt.strftime('%Y-%m-%dT%H:%M:%S.%f')
    valid_time_flattened['dt1'] = pd.to_datetime(valid_time_flattened['dt1'])
    # Target time axis (named 'time') for the observations.
    valid_time_da = xr.DataArray(valid_time_flattened['dt1'], dims=['time'],coords=valid_time_flattened['dt1'])
    # Dates missing from a_obs2 (other seasons, no observation) become NaN and
    # are then dropped. NOTE(review): dropna defaults to how='any', so a time
    # step is also dropped if any grid cell is NaN at that time.
    a_obs3 = a_obs2.reindex(time=valid_time_da)
    a_obs3 = a_obs3.dropna(dim='time')
    # Return the SPI variable matching the files loaded above.
    if len(season_str) == 3:
        obs_data=a_obs3['spi3']
        ens_data=a_fc3['spi3']
    else:
        obs_data=a_obs3['spi4']
        ens_data=a_fc3['spi4']
    return obs_data, ens_data


def get_threshold(region_id, season, level):
    """
    Return the drought thresholds of one region and season.

    Parameters
    ----------
    region_id : int
        0 = Karamoja (kmj), 1 = Marsabit (mbt), 2 = Wajir (wjr).
    season : str
        Season code, e.g. 'mam', 'jjas' or 'ond'. Matched case-insensitively.
    level : str
        Drought level of interest ('mod' = moderate, 'sev' = severe,
        'ext' = extreme). It is not used for the lookup: all three levels are
        returned, and callers index the result with it, e.g.
        ``get_threshold(1, 'mam', 'mod')['mod']``.

    Returns
    -------
    dict
        ``{'mod': float, 'sev': float, 'ext': float}`` for the region/season,
        or an empty dict when the combination is not tabulated. A fresh dict
        is returned on every call.

    Example
    -------
    >>> get_threshold(1, 'mam', 'mod')
    {'mod': -0.14, 'sev': -0.38, 'ext': -0.8}
    """
    # (region_id, season) -> SPI threshold for each drought severity level.
    thresholds = {
        (0, 'mam'): {'mod': -0.03, 'sev': -0.56, 'ext': -0.99},   # Karamoja
        (0, 'jjas'): {'mod': -0.01, 'sev': -0.41, 'ext': -0.99},  # Karamoja
        (1, 'mam'): {'mod': -0.14, 'sev': -0.38, 'ext': -0.8},    # Marsabit
        (1, 'ond'): {'mod': -0.15, 'sev': -0.53, 'ext': -0.71},   # Marsabit
        (2, 'mam'): {'mod': -0.19, 'sev': -0.45, 'ext': -0.75},   # Wajir
        (2, 'ond'): {'mod': -0.29, 'sev': -0.76, 'ext': -0.9},    # Wajir
    }
    season_key = season.lower() if isinstance(season, str) else season
    return dict(thresholds.get((region_id, season_key), {}))


def get_triggers_bin_edges(num_edges=101):
    """
    Build one 3-edge forecast-probability bin specification per trigger value.

    Trigger values are ``np.linspace(0, 1, num_edges)``. For each trigger ``t``
    the entry is ``[previous trigger, t, next trigger]``. The first entry uses
    0 as its lower edge and the last uses 1 as its upper edge, so every entry
    stays within the [0, 1] probability range.

    Parameters
    ----------
    num_edges : int, optional
        Number of trigger values (default 101, i.e. steps of 0.01).

    Returns
    -------
    list of list of float
        ``num_edges`` entries ``[n1, n2, n3]``; empty when ``num_edges < 1``.
    """
    if num_edges < 1:
        return []
    triggers = np.linspace(0, 1, num_edges)
    lower = np.concatenate(([0.0], triggers[:-1]))
    upper = np.concatenate((triggers[1:], [1.0]))
    return np.column_stack((lower, triggers, upper)).tolist()

def get_thresholds_bin_edges(threshold_dict, lowest_bound=-4.0, highest_bound=4.0):
    """
    Build one 3-edge bin specification per threshold.

    The thresholds are sorted ascending and padded with ``lowest_bound`` in
    front and ``highest_bound`` at the end. Each threshold then becomes
    ``[value before it, threshold, value after it]`` in that padded sequence.

    Parameters:
    - threshold_dict (dict): Levels mapped to threshold values.
    - lowest_bound (float): Lower edge used for the smallest threshold.
    - highest_bound (float): Upper edge used for the largest threshold.

    Returns:
    - list of lists: One ``[lower_edge, threshold, upper_edge]`` per
      threshold, in ascending threshold order; empty for an empty dict.

    TODO
    merge the dict call on level and then return the specific bin edges for that level
    """
    padded = [lowest_bound, *sorted(threshold_dict.values()), highest_bound]
    # Sliding window of width 3 centred on each threshold.
    return [list(window) for window in zip(padded, padded[1:], padded[2:])]

In [None]:
region_id = 0
# Upper case: compared with the upper-case season labels attached in make_obs_fct_dataset.
season_str = 'MAM'
level = 'mod'
lead_int = 1
obs_data, ens_data = make_obs_fct_dataset(region_id, season_str, lead_int)

# get_threshold's table keys seasons in lower case.
sc_season_str = season_str.lower()
threshold_dict = get_threshold(region_id, sc_season_str, level)

obs_be = get_thresholds_bin_edges(threshold_dict, lowest_bound=-4.0, highest_bound=4.0)

fcst_be = get_triggers_bin_edges()

# Observed category edges for the selected level: a single split at that
# level's threshold, spanning the full SPI range -4..4. The previous
# ``obs_be[1]`` always picked the middle sorted threshold ('sev'), whatever
# ``level`` was. Its outer edges were also the neighbouring thresholds, so
# more extreme SPI values fell outside both categories.
level_be = [-4.0, threshold_dict[level], 4.0]

In [None]:
# Observed category edges do not change between triggers: convert them once.
obs_edges = np.array(level_be)
contingency_dims = ["lat", "lon", "member", 'init']

cont_db = []
for tr_be in fcst_be:
    # NOTE(review): obs_data is indexed by 'time' while ens_data is still
    # indexed by 'init' (they are only aligned later, in the AUROC cell), so
    # the two arrays are presumably broadcast against each other rather than
    # matched date by date. tr_be also holds probability edges in [0, 1],
    # but ens_data holds raw SPI member values, not the empirical probability
    # of SPI <= threshold described in the notes. Confirm both.
    contingency = xs.Contingency(obs_data, ens_data, obs_edges, np.array(tr_be), dim=contingency_dims)
    # Evaluate the Heidke score once; it also supplies the time labels.
    hss = contingency.heidke_score()
    scores = pd.DataFrame({
        'time': hss['time'].values,
        'bias_score': contingency.bias_score().values,
        'false_alarm_ratio': contingency.false_alarm_ratio().values,
        'hit_rate': contingency.hit_rate().values,
        'heidke_score': hss.values,
        'peirce_score': contingency.peirce_score().values,
    })
    # The middle edge is the trigger probability tested in this iteration.
    scores.insert(0, 'tr_be', tr_be[1])
    cont_db.append(scores)
    print(tr_be)

db = pd.concat(cont_db)

db.to_csv(f'{data_path}{region_id}_{season_str}_{level}_lt{lead_int}.csv')

## AUROC 

In [None]:
# Round-trip init through strings so the coordinate becomes numpy datetime64
# (presumably converting the cftime dates produced by climpred).
date_strings = ens_data['init'].dt.strftime('%Y-%m-%d %H:%M:%S').values
ens_data['init'] = pd.to_datetime(date_strings)

# Index the forecasts by the date they verify on, converted the same way and
# renamed to 'time' to match the observations.
ens_data = ens_data.swap_dims({'init': 'valid_time'})
vdate_strings = ens_data['valid_time'].dt.strftime('%Y-%m-%d %H:%M:%S').values
ens_data['valid_time'] = pd.to_datetime(vdate_strings)
ens_data = ens_data.rename({'valid_time': 'time'})

# The binary observed event and the forecast probability must describe the
# SAME event: SPI at or below the selected level's threshold (drought).
# Previously the observed event was `obs >= threshold` (no drought) while the
# forecast was P(SPI <= threshold), so ROC compared opposite events.
level_threshold = threshold_dict[level]
obs_event1 = obs_data <= level_threshold
obs_event1.name = obs_data.name

# Empirical forecast probability: fraction of members at or below the threshold.
dfp = (ens_data <= level_threshold).mean(dim='member')

In [None]:
# Local dask cluster used by the bootstrap / ROC cells below.
# Depending on your workstation specifications, you may need to adjust these values.
# On a single machine, n_workers=1 is usually better.
cluster_kwargs = {"n_workers": 3, "threads_per_worker": 4, "memory_limit": "2GB"}
client = Client(**cluster_kwargs)
client  # displayed by Jupyter (cluster summary)
# Call client.close() when finished to release the workers.

In [None]:
# Bootstrap the observed events and forecast probabilities together:
# blocks of length 1 along 'time', circular resampling, 1000 iterations.
bootstrap_inputs = (obs_event1.chunk(), dfp.chunk())
bs_obs_data, bs_ens_data = block_bootstrap(
    *bootstrap_inputs,
    blocks={"time": 1},
    n_iteration=1000,
    circular=True,
)

In [None]:
# 'all_as_tuple' returns (false positive rate, true positive rate, area under
# curve) as three arrays, which is what the three-name unpacking expects.
# 'all_as_metric_dim' returns ONE array stacked along a 'metric' dimension, so
# unpacking it relied on iterating that dimension and left the AUC broadcast
# along fpr/tpr's probability bins.
# NOTE(review): dim=["iteration"] reduces over bootstrap samples. The usual
# bootstrap-CI recipe computes ROC over 'time' and takes quantiles over
# 'iteration' — confirm which is intended.
fpr, tpr, aroc = xs.roc(bs_obs_data, bs_ens_data, bin_edges='continuous', dim=["iteration"], return_results='all_as_tuple')

In [None]:
# Summarise AUROC across grid cells: the spatial mean and the 2.5th / 97.5th
# spatial percentiles (the reduction is over lat/lon, not over 'iteration').
# Each summary gets its own column name so the three columns stay
# distinguishable after concatenation. The scalar 'quantile' coordinate added
# by .quantile() is dropped so it does not show up as an extra column.
spatial_dims = ["lat", "lon"]

maroc = aroc.mean(dim=spatial_dims)
mdb = maroc.to_dataframe(name='auroc_mean')

p2aroc = aroc.quantile(0.025, dim=spatial_dims).drop_vars('quantile')
p2db = p2aroc.to_dataframe(name='auroc_p2.5')

p9aroc = aroc.quantile(0.975, dim=spatial_dims).drop_vars('quantile')
p9db = p9aroc.to_dataframe(name='auroc_p97.5')

In [None]:
# Side-by-side AUROC summaries, written next to the other outputs in data_path.
auroc_csv_path = f'{data_path}auroc_{region_id}_{season_str}_{level}_lt{lead_int}_a.csv'
adb = pd.concat((mdb, p2db, p9db), axis=1)
adb.to_csv(auroc_csv_path)