In [None]:
import xarray as xr
import xclim
import xarray as xr
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from datetime import datetime

# Calculating daily SPI from precipitation data

In [None]:
# Install the climate-indices package into the environment reached by the shell `pip`.
# NOTE(review): the `process_climate_indices` command run further down presumably comes
# from this package - confirm, and consider pinning a version for reproducibility.
!pip install climate-indices

In [None]:
# Directory prefix for the NetCDF files read and written below. The helper functions
# build file paths by plain string concatenation (path_files + file name), so the
# trailing slash must be kept.
path_files='prec/'

In [None]:
def prepare_input(path_files, name_input, name_output, var_name="TOT_PREC"):
    """Reorder a precipitation variable to (lat, lon, time) and save it as NetCDF.

    The variable is read from ``path_files + name_input``, transposed to the
    dimension order ``("lat", "lon", "time")``, tagged with ``units = 'mm'``
    and written to ``path_files + name_output``.

    Note that the ``units`` attribute is simply overwritten: the values are
    NOT converted, so the input must already be expressed in millimetres.

    Parameters
    ----------
    path_files : str
        Directory prefix. It is concatenated verbatim with the file names, so
        it must end with a path separator (e.g. ``'prec/'``).
    name_input : str
        Name of the NetCDF file to read.
    name_output : str
        Name of the NetCDF file to write.
    var_name : str, optional
        Name of the precipitation variable in the input file. Defaults to
        ``"TOT_PREC"``, the name that was previously hard-coded.

    Returns
    -------
    xarray.DataArray
        The transposed precipitation variable that was written to disk.

    Raises
    ------
    KeyError
        Propagated from the dataset lookup when ``var_name`` is not present.
    """
    # NOTE(review): the dataset is deliberately left open. The returned array is
    # presumably still backed lazily by this file, so closing it here could make
    # the return value unreadable - confirm before adding a `with` block.
    ds = xr.open_dataset(path_files + name_input)
    prec = ds[var_name]

    preferred_dims = ("lat", "lon", "time")
    transposed_prec = prec.transpose(*preferred_dims)

    # Label (not convert) the data as millimetres before writing it out.
    transposed_prec.attrs["units"] = 'mm'
    transposed_prec.to_netcdf(path_files + name_output)
    return transposed_prec

In [None]:
def prepare_output(path_files, name_input, name_output, name_mask, threshold, var_name=None):
    """Reorder an SPI file to (time, lat, lon) and derive a binary drought mask.

    Steps:
    1. Open ``path_files + name_input`` and transpose every variable to the
       dimension order ``("time", "lat", "lon")``.
    2. Save the reordered dataset to ``path_files + name_output``.
    3. Build a mask that is 1 where SPI < ``threshold`` and 0 everywhere else.
    4. Save the mask, named ``"drought_mask"``, to ``path_files + name_mask``.

    Parameters
    ----------
    path_files : str
        Directory prefix, concatenated verbatim with the file names (so it
        must end with a path separator).
    name_input : str
        Name of the SPI NetCDF file to read.
    name_output : str
        Name of the reordered SPI file to write.
    name_mask : str
        Name of the drought-mask file to write.
    threshold : float
        SPI values strictly below this number are flagged as drought.
    var_name : str, optional
        Name of the SPI variable. When omitted, ``'SPI_365'`` is used if it
        exists, otherwise the first data variable of the file is taken.

    Returns
    -------
    tuple
        Always the empty tuple ``()``; kept unchanged for backward
        compatibility with existing callers.
    """
    # The context manager guarantees the input file handle is released even if
    # one of the writes below fails.
    with xr.open_dataset(path_files + name_input) as ds_raw:
        # Make sure dimensions are ordered properly
        preferred_dims_out = ("time", "lat", "lon")
        ds_spi = ds_raw.transpose(*preferred_dims_out)

        # Save the reordered SPI file
        ds_spi.to_netcdf(path_files + name_output)

        # Pick the SPI variable: explicit name first, then the conventional
        # 'SPI_365', and finally whatever data variable comes first in the file.
        if var_name is not None:
            spi_data = ds_spi[var_name]
        elif 'SPI_365' in ds_spi.data_vars:
            spi_data = ds_spi['SPI_365']
        else:
            spi_data = list(ds_spi.data_vars.values())[0]

        # Create drought mask: 1 where SPI < threshold, 0 otherwise.
        # NOTE(review): cells with missing SPI presumably fail the comparison
        # and are therefore written as 0 ("no drought") - confirm this is intended.
        mask = xr.where(spi_data < threshold, 1, 0)

        # Save the mask
        mask.name = "drought_mask"
        mask.to_netcdf(path_files + name_mask)

    print(f"✅ Saved reordered SPI to: {path_files + name_output}")
    print(f"✅ Saved drought mask to: {path_files + name_mask}")

    return ()

In [None]:
# Transpose the daily precipitation to (lat, lon, time) and write it to
# prec_daily_transposed.nc, the file the SPI command below reads.
prec_transposed=prepare_input(path_files,'prec_daily.nc','prec_daily_transposed.nc')

In [None]:
# Compute the SPI from the TOT_PREC variable of the transposed precipitation file:
# daily periodicity, 365-day scale, calibration period 1991-2022, multiprocessing set to `all`.
# Output files are written under path_files with the base name `daily`.
!process_climate_indices --index spi  --periodicity daily --netcdf_precip $path_files/prec_daily_transposed.nc --var_name_precip TOT_PREC --output_file_base $path_files/daily --scales 365 --calibration_start_year 1991 --calibration_end_year 2022 --multiprocessing all

In [None]:
# Reorder the SPI output to (time, lat, lon), save it as drought_daily.nc, and write a
# drought mask (1 where SPI < -2, 0 otherwise) to 02_drought_mask.nc.
prepare_output(path_files,'daily_spi_365_gamma.nc','drought_daily.nc','02_drought_mask.nc',-2)

# Adding duration filters to heatwaves and drought events

In [None]:
def _keep_long_runs(is_event, counter):
    """Return an int mask that keeps only runs of True (along axis 0) longer than ``counter``."""
    n_steps = is_event.shape[0]
    run_length = np.zeros(is_event.shape, dtype=int)
    if n_steps == 0:
        return run_length

    # Forward pass: length of the run ending at each time step (0 outside events).
    run_length[0] = is_event[0]
    for t in range(1, n_steps):
        run_length[t] = (run_length[t - 1] + 1) * is_event[t]

    # Backward pass: copy each run's final length back over the whole run, so
    # every step of an event knows the total duration of that event.
    for t in range(n_steps - 2, -1, -1):
        same_run = is_event[t] & is_event[t + 1]
        run_length[t] = np.where(same_run, run_length[t + 1], run_length[t])

    return (is_event & (run_length > counter)).astype(int)


def label_consecutive_ones(file_path, output_path, counter, name):
    """Remove events that are too short from a binary hazard mask (drought, heatwave).

    The mask variable ``name`` is read from ``file_path``. For every cell, runs
    of consecutive 1s along the first dimension are measured; a run is kept
    (left at 1) only when its length is strictly greater than ``counter``,
    otherwise it is set to 0. The filtered mask is written to ``output_path``
    under the same variable name, with the same coordinates and dimensions.

    NOTE(review): the earlier description said runs "less than" ``counter`` are
    removed, but the code has always used a strict ``>`` comparison, so a run
    of exactly ``counter`` steps is removed as well. That behaviour is kept
    here - confirm which one is intended.

    Parameters
    ----------
    file_path : str
        NetCDF file containing the binary mask.
    output_path : str
        NetCDF file to write the filtered mask to.
    counter : int
        Duration threshold; only runs longer than this many steps are kept.
    name : str
        Name of the mask variable inside the file.

    Notes
    -----
    The first dimension of the variable is treated as time, as before. Any
    value other than 1 ends a run.
    """
    # The context manager releases the input file once the result is written.
    with xr.open_dataset(file_path) as ds:
        # Selecting the hazard mask var
        binary_var = ds[name]

        # Read the whole mask once, then filter every cell in a single
        # vectorised pass instead of looping over cells and time steps.
        is_event = np.asarray(binary_var.values) == 1
        result = _keep_long_runs(is_event, counter)

        # Wrap the result with the original coordinates and dimensions
        result_da = xr.DataArray(result, coords=binary_var.coords, dims=binary_var.dims)
        result_ds = xr.Dataset({name: result_da})

        # Save the result to a new NetCDF file
        result_ds.to_netcdf(output_path)

    print(f"Processed data saved to {output_path}")