In [None]:
from dotenv import load_dotenv
import os

import cdsapi
import xarray as xr
from xclim.testing import open_dataset
from xclim.indices.stats import standardized_index_fit_params
from xclim.indices import standardized_precipitation_index

import xarray as xr
import dask.array as da
from dask.distributed import Client
from dateutil.relativedelta import relativedelta
from calendar import monthrange
import pandas as pd

from distributed import Client

In [None]:
# Pull settings from a local .env file into the process environment.
load_dotenv()

# Prefix used for every file this notebook reads or writes. Paths are built as
# f'{data_path}<file name>', so the value is expected to end with a path separator.
data_path = os.getenv("data_path")
if not data_path:
    # Without this check the f-strings below would silently produce paths
    # that start with the literal text 'None'.
    raise RuntimeError(
        "Environment variable 'data_path' is not set; define it in the .env "
        "file (or the shell) before running this notebook."
    )

## Steps in the Code

1. **Environment Setup:** Loads environment variables, including the data path, using dotenv.
2. **CDS API Initialization:** Initializes the Climate Data Store (CDS) API client and retrieves seasonal monthly single-levels data for total precipitation from 1981 to 2023 for the specified months and lead times, saving it in GRIB format. The download was originally done separately through the web interface; it is rewritten here for reproducibility.
4. **Dask Client Setup:** Initializes a Dask distributed client for parallel computation.
5. **Data Loading and Transformation:**
    - Loads the GRIB file into an xarray dataset. Followed tutorial at [Seasonal Forecast Anomalies ECMWF tutorial](https://ecmwf-projects.github.io/copernicus-training-c3s/sf-anomalies.html#change-representation-of-forecast-lead-time)
    - Transforms the data to calculate total monthly precipitation in mm, adjusting for the number of days in each month.
6. **SPI Calculation Functions:** Defines functions `apply_spi3` and `apply_spi4` to calculate the Standardized Precipitation Index (SPI) for 3-month and 4-month windows, respectively.
7. **SPI Calculation and Concatenation:**
    - Applies SPI calculation for each lead time, iterating over ensemble members.
    - Concatenates SPI values across ensemble members and lead times.
8. **Final Dataset Creation and Saving:** Creates a final xarray dataset for SPI values and saves it as a NetCDF file.

---
**Notes**
1. The SPI is calculated by fitting a gamma distribution over the 1991-2018 climatology period with the [xclim library](https://xclim.readthedocs.io/en/stable/indices.html#xclim.indices.standardized_precipitation_index), using its approximate (`APP`) fitting method.
1. The code leverages parallel processing with Dask to improve computational efficiency.
1. The final output is an ensemble dataset of SPI values across different lead times that can be used for verification; it covers the whole of East Africa at 1° resolution.


In [None]:

# Target file for the seasonal monthly total-precipitation download.
grib_target = f'{data_path}ea_seas_v51_1981_2023.grib'

if os.path.exists(grib_target):
    # The request covers 43 years x 12 start months x 6 lead months, so an
    # existing copy is reused instead of being requested again on every run.
    # Delete the file to force a fresh retrieval (e.g. after an interrupted
    # download left a partial file behind).
    print(f'Using existing download: {grib_target}')
else:
    c = cdsapi.Client()

    c.retrieve(
        'seasonal-monthly-single-levels',
        {
            'originating_centre': 'ecmwf',
            'system': '51',
            'variable': [
                'total_precipitation',
            ],
            'product_type': 'monthly_mean',
            # '1981' ... '2023'
            'year': [str(year) for year in range(1981, 2024)],
            # '01' ... '12'
            'month': [f'{month:02d}' for month in range(1, 13)],
            # '1' ... '6'
            'leadtime_month': [str(lead) for lead in range(1, 7)],
            # Bounding box; presumably in the CDS order north, west, south,
            # east -- confirm against the CDS API documentation.
            'area': [
                25, 20, -15,
                55,
            ],
            'format': 'grib',
        },
        grib_target)


In [None]:
# Sizing of the local Dask cluster. Tune these numbers to the machine running
# the notebook; on a single machine one worker is usually the better choice.
cluster_options = {
    "n_workers": 3,
    "threads_per_worker": 4,
    "memory_limit": "2GB",
}
client = Client(**cluster_options)
client

In [3]:

# GRIB file holding the raw seasonal forecast download.
p1_input_data = f'{data_path}ea_seas_v51_1981_2023.grib'

# Ask cfgrib to expose 'forecastMonth' and 'time' as the two time dimensions.
cfgrib_options = {'time_dims': ('forecastMonth', 'time')}

db = xr.open_dataset(
    p1_input_data,
    engine='cfgrib',
    backend_kwargs=cfgrib_options,
    chunks={'time': 1},  # Adjust chunk size as needed
)

# Transformation applied to the data of one start date (one 'time' value).
def transform_data(data_at_time):
    """
    Scale the data of a single start date by the length of each forecast month.

    Every forecast month gets two extra coordinates: ``valid_time`` (the start
    date shifted by ``forecastMonth - 1`` months) and ``numdays`` (the number
    of days in that valid month). The data are then multiplied by ``numdays``
    and by 24 * 60 * 60 * 1000, and the result is labelled with units 'mm'.

    Parameters:
    - data_at_time (xarray.Dataset): Data for one start date; must provide a
      scalar ``time`` and a ``forecastMonth`` dimension.

    Returns:
    - xarray.Dataset: The scaled data, carrying the ``valid_time`` and
      ``numdays`` coordinates plus 'units' and 'long_name' attributes.
    """
    start_date = pd.to_datetime(data_at_time.time.values)

    # Month each forecast step is valid for: lead 1 is the start month itself.
    valid_time = []
    for fcmonth in data_at_time.forecastMonth:
        valid_time.append(start_date + relativedelta(months=fcmonth-1))

    # Calendar length of each valid month (leap years included).
    numdays = [monthrange(stamp.year, stamp.month)[1] for stamp in valid_time]

    tagged = data_at_time.assign_coords(
        valid_time=('forecastMonth', valid_time),
        numdays=('forecastMonth', numdays),
    )

    # days * hours * minutes * seconds * 1000; the factor order is kept as-is
    # so floating-point results stay bit-for-bit the same.
    monthly_total = tagged * tagged.numdays * 24 * 60 * 60 * 1000
    monthly_total.attrs['units'] = 'mm'
    monthly_total.attrs['long_name'] = 'Total precipitation'
    return monthly_total

# Run transform_data on every start date (a plain Python loop over 'time')
# and stack the per-date results back along the 'time' dimension.
per_start_date = [transform_data(db.sel(time=time_index)) for time_index in db.time]

# A single compute() materialises the result in memory; a preceding persist()
# adds nothing when the data are gathered immediately afterwards.
cont_db = xr.concat(per_start_date, dim='time').compute()

# The following cell reads this file, but nothing else in the notebook writes
# it. Create it when missing so a fresh-kernel run works; an existing file is
# left untouched (delete it to regenerate).
tprate_nc_path = f'{data_path}ea_seas51_tprate_1981_2023.nc'
if not os.path.exists(tprate_nc_path):
    cont_db.to_netcdf(tprate_nc_path)

client.close()

In [None]:
# Load the stored precipitation dataset and display its summary.
tprate_nc_file = f'{data_path}ea_seas51_tprate_1981_2023.nc'
cont_db = xr.open_dataset(tprate_nc_file)
cont_db

In [None]:
def apply_spi3(cont_db, lead_val, window=3, cal_start='1991-01-01', cal_end='2018-01-01'):
    """
    Calculates the Standardized Precipitation Index (SPI) for a specified lead
    time, one ensemble member at a time. By default this is the 3-month SPI.

    Parameters:
    - cont_db (xarray.Dataset): The input dataset containing the 'tprate'
      variable with 'forecastMonth' and 'number' coordinates.
    - lead_val (int): The lead time (forecastMonth value) to select.
    - window (int): Accumulation window passed to xclim. Defaults to 3.
    - cal_start (str): Start of the calibration period passed to xclim.
    - cal_end (str): End of the calibration period passed to xclim.
      NOTE(review): the default ends on 2018-01-01, so at most the first
      month of 2018 falls inside the calibration window -- confirm this is
      the intended climatology period.

    Returns:
    - cont_spi (list): A list of computed SPI results, one per ensemble
      member, in the order of ``number``.
    """
    lead_db = cont_db.sel(forecastMonth=lead_val)
    # Label the values as monthly totals before handing them to xclim
    # (presumably needed by its unit handling -- verify against xclim docs).
    lead_db['tprate'].attrs['units'] = 'mm/month'

    cont_spi = []
    for nsl in lead_db.number.values:
        # chunk(-1): put each member's data into a single chunk.
        member_pr = lead_db.sel(number=nsl).chunk(-1).tprate
        spi = standardized_precipitation_index(
            member_pr,
            freq="MS",
            window=window,
            dist="gamma",
            method="APP",
            cal_start=cal_start,
            cal_end=cal_end,
        )
        # Materialise each member right away so only finished results are kept.
        cont_spi.append(spi.compute())
        print(nsl)  # progress: ensemble member just completed
    return cont_spi

In [None]:
# Compute SPI-3 for lead times 1..6. Each lead yields one list of per-member
# results, which is stacked along a new 'member' dimension.
spi3_by_lead = []
for lead_val in range(1, 7):
    members = apply_spi3(cont_db, lead_val)
    spi3_by_lead.append(xr.concat(members, dim='member'))
    # Drop the per-member list once stacked so it does not stay in memory.
    del members
    print(f'lt{lead_val}')  # progress marker, now printed for every lead

# Keep the per-lead names that the following cells refer to.
t1d, t2d, t3d, t4d, t5d, t6d = spi3_by_lead
t1da, t2da, t3da, t4da, t5da, t6da = (
    spi3.to_dataset(name='spi3') for spi3 in spi3_by_lead
)

In [None]:
# Stack the six per-lead SPI-3 arrays along a new 'lead' dimension and wrap
# the result in a dataset with a single 'spi3' variable.
spi3_per_lead = [t1d, t2d, t3d, t4d, t5d, t6d]
ds_ea = xr.concat(spi3_per_lead, dim='lead').to_dataset(name='spi3')

# In-memory size in MiB. Printed explicitly: a bare expression in the middle
# of a cell is never displayed.
print(f'spi3 size: {ds_ea.spi3.nbytes / (1024*1024):.1f} MiB')

# 'input_path' is not defined anywhere in this notebook (NameError on a fresh
# kernel); write next to the other files under data_path instead.
ds_ea.to_netcdf(f'{data_path}ea_seas51_spi3_xclim_20240306.nc')

## SPI4

In [None]:
def apply_spi4(cont_db,lead_val):
    """
    Calculates the 4-month Standardized Precipitation Index (SPI) for a
    specified lead time, one ensemble member at a time.

    Parameters:
    - cont_db (xarray.Dataset): The input dataset containing the 'tprate'
      variable with 'forecastMonth' and 'number' coordinates.
    - lead_val (int): The lead time (forecastMonth value) to select.

    Returns:
    - cont_spi (list): A list of computed SPI results, one per ensemble
      member, in the order of ``number``.
    """
    lead_db = cont_db.sel(forecastMonth=lead_val)
    # Label the values as monthly totals before handing them to xclim
    # (presumably needed by its unit handling -- verify against xclim docs).
    lead_db['tprate'].attrs['units'] = 'mm/month'

    cont_spi = []
    for nsl in lead_db.number.values:
        # chunk(-1): put each member's data into a single chunk.
        member_pr = lead_db.sel(number=nsl).chunk(-1).tprate
        spi_4 = standardized_precipitation_index(
            member_pr,
            freq="MS",
            window=4,
            dist="gamma",
            method="APP",
            cal_start='1991-01-01',
            # NOTE(review): ends on 2018-01-01, so at most the first month of
            # 2018 is inside the calibration window -- confirm this is intended.
            cal_end='2018-01-01',
        )
        # Materialise each member right away so only finished results are kept.
        cont_spi.append(spi_4.compute())
        print(nsl)  # progress: ensemble member just completed
    return cont_spi

In [None]:
# Compute SPI-4 for lead times 1..6. Each lead yields one list of per-member
# results, which is stacked along a new 'member' dimension.
spi4_by_lead = []
for lead_val in range(1, 7):
    members = apply_spi4(cont_db, lead_val)
    spi4_by_lead.append(xr.concat(members, dim='member'))
    # Release the per-member list for every lead once it has been stacked.
    del members
    print(f'lt{lead_val}')  # progress marker carrying the actual lead number

# Keep the per-lead names that the statements below refer to.
t1d4, t2d4, t3d4, t4d4, t5d4, t6d4 = spi4_by_lead
t1da4 = t1d4.to_dataset(name='spi4')
t2da4 = t2d4.to_dataset(name='spi4')

# Stack the six per-lead SPI-4 arrays along a new 'lead' dimension and wrap
# the result in a dataset with a single 'spi4' variable.
spi4_per_lead = [t1d4, t2d4, t3d4, t4d4, t5d4, t6d4]
ds_ea4 = xr.concat(spi4_per_lead, dim='lead').to_dataset(name='spi4')

# In-memory size in MiB. Printed explicitly: a bare expression in the middle
# of a cell is never displayed.
print(f'spi4 size: {ds_ea4.spi4.nbytes / (1024*1024):.1f} MiB')

# 'input_path' is not defined anywhere in this notebook (NameError on a fresh
# kernel); write next to the other files under data_path instead.
ds_ea4.to_netcdf(f'{data_path}ea_seas51_spi4_xclim_20240306.nc')