#### IMPORTS

In [1]:
import os

import climate_indices
import geopandas as gpd
import numpy as np
import pandas as pd
import rioxarray as rxr
import xarray as xr
from climate_indices import indices
from pyproj import CRS
from rasterio.enums import Resampling
from shapely.geometry import mapping

#### SETTING PARAMETERS

In [None]:
############################################################################ SETTING PARAMETERS FOR THE SPI #################################################################
# Every name in this cell is read by later cells; the SPI settings are
# forwarded positionally to indices.spi in the "APPLY SPI FUNCTION" cell.

scale = 3
distribution = climate_indices.indices.Distribution.gamma  # Fixed
data_start_year = 1980
calibration_year_initial = 1980
calibration_year_final = 2023
periodicity = climate_indices.compute.Periodicity.monthly  # Fixed

######################################################################################### CRS ############################################################################
# CRS written onto both the ERA5 data and the mask layer.

crs_project = CRS.from_epsg(4326)  # WGS84

######################################################################################### INPUTS ############################################################################
# Placeholder paths: replace them with the real locations before running.

ERA5_input_path =  'C:\\Insert\\path\\to\\monthly\\ERA5_land\\data.nc'

shapefile_path = "C:\\Insert\\path\\to\\shapefile\\shapefile.shp"

######################################################################################### OUTPUTS ############################################################################

path_out = "C:\\Insert\\path\\to\\output\\directory\\"  # Output directory, written with a trailing path separator

SPI_ouput_file = "SPI_test.nc"  # Output file name

#### DEFINING FUNCTIONS

In [14]:
# Read the boundaries shapefile
def load_shape_file(filepath):
    """Read the shapefile that will later be used to mask a grid.

    Args:
        filepath: Path to the *.shp file.

    Returns:
        The object returned by ``gpd.read_file`` for ``filepath``.
    """
    loaded = gpd.read_file(filepath)
    # Same text as before, including the indentation of the second line.
    notice = (
        "Shapefile loaded. To prepare for masking, run the function\n"
        "        `select_shape`."
    )
    print(notice)
    return loaded


# Create the mask
def select_shape(shpfile, col_code="ISO3_CODE", country_codes=("ZAF", "LSO", "SWZ")):
    """Build a single-polygon mask from selected rows of a shapefile.

    Args:
        shpfile: Table loaded through `load_shape_file`.
            (Run print(shpfile) to see the available columns and values.)
        col_code: (str) Name of the column holding the country codes.
        country_codes: Iterable of codes to keep; rows whose `col_code`
            value is in this collection are merged into the mask.

    Returns:
        A one-row GeoDataFrame whose geometry is the union of the selected
        polygons. No CRS is set on the result; the caller is expected to
        assign one.

    Raises:
        ValueError: If no row of `shpfile` matches `country_codes`.
    """
    # Keep only the rows whose code column matches one of the requested codes
    selected_rows = shpfile[shpfile[col_code].isin(country_codes)]

    # Fail early instead of silently producing an empty mask
    if selected_rows.empty:
        raise ValueError(
            f"No rows found in column '{col_code}' matching {list(country_codes)}"
        )

    # Combine the selected polygons into a single polygon
    unioned_polygon = selected_rows.geometry.unary_union

    # Convert the unioned polygon to a geopandas dataframe with a single row
    mask_polygon = gpd.GeoDataFrame(geometry=[unioned_polygon])

    print("""Mask created.""")

    return mask_polygon


# Processing the data (masking and reshaping)
def proccessingNETCDF(
    data,
    mask=None,
    scale_factor=2.908522800670776e-07,
    add_offset=0.009530702520736942,
):
    """Process the data to serve as input to the SPI function.

    Args:
        data: Precipitation array read from the netcdf file, with `time`,
            `y` and `x` coordinates.
        mask: Mask layer providing `.geometry` and `.crs`, used to clip the
            data. When omitted, the module-level `mask_layer` is used, which
            keeps the original call `proccessingNETCDF(data)` working.
        scale_factor: Multiplier applied to the raw values. The default is
            the constant previously hard-coded here; it is specific to the
            input file, so pass your own value for another file.
        add_offset: Offset added after scaling (same remark as above).

    Returns:
        DataArrayGroupBy grouped over point (y and x coordinates).
    """
    if mask is None:
        # Backward-compatible fallback to the notebook-level mask
        mask = mask_layer

    num_days_month = data.time.dt.days_in_month

    # Rescaling the values
    data_precip = (data * scale_factor) + add_offset
    # The original units are meters, we change them to millimeters, and
    # multiply by the days of the month
    data_precip = data_precip * 1000 * num_days_month

    # Reverse the order of the Y coordinate values. The rename to lat/lon and
    # back is kept from the original, which flags it as a necessary step.
    data_precip = data_precip.rename({"y": "lat", "x": "lon"})
    data_precip = data_precip.reindex(lat=list(reversed(data_precip["lat"])))
    data_precip = data_precip.rename({"lat": "y", "lon": "x"})

    # Mask the country
    data_precip_masked = data_precip.rio.clip(
        mask.geometry.apply(mapping),
        crs=mask.crs,
        all_touched=True,
        from_disk=True,
    ).squeeze()

    # Stack y/x into a single "point" dimension and group over it
    data_grouped = data_precip_masked.stack(point=("y", "x")).groupby("point")
    print(
        """Data is prepared to serve
         as input for the SPI index."""
    )

    return data_grouped

#### ERA5 DATA

In [15]:
# Loading the data
data = rxr.open_rasterio(ERA5_input_path, masked=True)
# Giving a CRS
data.rio.write_crs(crs_project, inplace=True)
# Getting precipitation values.
# Per the rioxarray documentation, open_rasterio can return either a Dataset
# or a bare DataArray, so the variable is only selected by name when a
# Dataset came back.
if isinstance(data, xr.Dataset):
    data = data["tp"]

#### MASK LAYER

In [None]:
# Load the shapefile
shpfile = load_shape_file(shapefile_path)  # Boundaries

# Bring the boundaries into the project CRS before building the mask.
# The mask is labelled with crs_project below, so coordinates expressed in any
# other CRS would otherwise be mislabelled rather than reprojected.
# A shapefile without CRS information is left untouched, as before.
if shpfile.crs is not None:
    shpfile = shpfile.to_crs(crs_project)

# Create the mask layer
mask_layer = select_shape(shpfile)
# Giving a CRS
mask_layer.crs = crs_project

#### DATA PROCESSING

In [None]:
# Rescale, mask and reshape the precipitation data for the SPI computation.
# NOTE: proccessingNETCDF reads the module-level `mask_layer`, so the
# "MASK LAYER" cell must be run before this one.
data_grouped = proccessingNETCDF(data)

#### APPLY SPI FUNCTION

In [None]:
# Compute the SPI with the climate_indices package
# (https://github.com/monocongo/climate_indices).
# apply_ufunc calls indices.spi on the grouped data and forwards the settings
# defined in the parameters cell as positional arguments, in this order.
spi_values = xr.apply_ufunc(
    indices.spi,
    data_grouped,
    scale,
    distribution,
    data_start_year,
    calibration_year_initial,
    calibration_year_final,
    periodicity,
)

# Unstack the array back into original dimensions
spi_results = spi_values.unstack("point")

In [20]:
# Match the SPI result to the grid of the original `data` array, using
# bilinear resampling and NaN as the nodata value.
spi_results = spi_results.rio.reproject_match(
    data, resampling=Resampling.bilinear, nodata=np.nan
)

#### EXPORTING

In [24]:
# Write the SPI result to disk. os.path.join builds the same path whether or
# not path_out ends with a path separator.
spi_results.to_netcdf(os.path.join(path_out, SPI_ouput_file))