# Drought index

## Import libraries

In [None]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import download, plot, utils

# Apply the matplotlib "seaborn-v0_8-notebook" style to every figure drawn below.
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Time: ``start``/``stop`` bound the full download period (all of it feeds the
# day-of-year climatology); ``index_slice`` is the sub-period in which drought
# is assessed.
start, stop = "2000-01", "2022-12"
index_slice = slice("2022-06-01", "2022-09-30")

# Drought threshold: only standardized anomalies strictly below this value are
# kept when building the drought fields.
threshold = -1.5

# Space: bounding box ordered [north, west, south, east]; entries 0/2 are used
# as latitude bounds and 1/3 as longitude bounds further down.
area = [58, -10, 36, 30]

## Define requests

In [None]:
# Product label -> [collection id, base request]. The date keys are not set
# here; they are added later by ``download.update_request_date``.
requests = dict(
    Reanalysis=[
        "derived-era5-single-levels-daily-statistics",
        dict(
            product_type="reanalysis",
            # Top-layer soil water plus the land-sea mask, which is presumably
            # the ``lsm`` variable used to drop sea points in the transform.
            variable=["volumetric_soil_water_layer_1", "land_sea_mask"],
            daily_statistic="daily_mean",
            time_zone="utc+00:00",
            frequency="1_hourly",
            area=area,
        ),
    ],
    # No ``area`` key here: the region is cropped afterwards inside the
    # transform function (via ``utils.regionalise``).
    Satellite=[
        "satellite-soil-moisture",
        dict(
            variable=["volumetric_surface_soil_moisture"],
            type_of_sensor=["combined_passive_and_active"],
            time_aggregation=["day_average"],
            type_of_record=["cdr"],
            version=["v202312"],
        ),
    ],
)

## Define functions to be cached

In [None]:
def smooth(obj, window):
    """Gap-fill ``obj`` along ``time`` and return its rolling mean.

    Missing values are first interpolated along ``time`` (xarray's default
    linear method). A rolling mean over ``window`` steps is then taken. Because
    ``min_periods=1``, windows at the start of the series average whatever
    points are available instead of producing NaN.

    NOTE(review): the window is not centred (xarray default ``center=False``),
    so the smoothed series lags by roughly ``window // 2`` steps. Confirm this
    is intended, in particular for the 11-step climatology smoothing.
    """
    gap_filled = obj.interpolate_na("time")
    windows = gap_filled.rolling(time=window, min_periods=1)
    return windows.mean()


def compute_drought_index(ds, index_slice, lon_slice, lat_slice, threshold):
    """Compute soil-moisture anomaly drought fields for one product.

    Steps:
    1. Pick the soil-moisture variable.
    2. Crop it to the region and, when ``lsm`` is present, mask it to land.
    3. Standardize it against a day-of-year climatology of the 11-step
       smoothed series.
    4. Within ``index_slice``, smooth the anomaly over 3 steps and keep only
       values below ``threshold``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain exactly one of ``sm`` / ``swvl1`` and may contain
        ``lsm`` (land-sea mask with a ``time`` dimension).
    index_slice : slice
        Time selection over which the drought fields are computed.
    lon_slice, lat_slice : slice
        Regional bounds forwarded to ``utils.regionalise``.
    threshold : float
        Anomalies must be strictly below this value to count as drought.

    Returns
    -------
    xarray.Dataset
        ``sma``: the minimum (most negative) below-threshold anomaly in the
        period. ``severity``: the sum of the below-threshold anomalies. Both
        are NaN where the anomaly never drops below ``threshold``.

    Raises
    ------
    ValueError
        If ``ds`` contains neither or both of ``sm`` and ``swvl1`` (the tuple
        unpacking fails).
    """
    # Get raw data
    (var_name,) = set(ds.data_vars) & {"sm", "swvl1"}
    raw_data = utils.regionalise(ds[var_name], lon_slice=lon_slice, lat_slice=lat_slice)
    if (lsm := ds.get("lsm")) is not None:
        # Keep only points classified as land (mask > 0.5) at every time step.
        raw_data = raw_data.where((lsm > 0.5).all("time"))
    # Use a single chunk along time. This is presumably needed because
    # interpolate_na/rolling in smooth() operate along a dask-chunked dim.
    raw_data = raw_data.chunk(time=-1)

    # Get time-varying index: (raw - climatological mean) / climatological std,
    # where the climatology is built per day-of-year from the smoothed series.
    group_dim = "time.dayofyear"
    smooth_data_grouped = smooth(raw_data, 11).groupby(group_dim)
    index = raw_data.groupby(group_dim) - smooth_data_grouped.mean()
    index = index.groupby(group_dim) / smooth_data_grouped.std()
    index = index.sel(time=index_slice)
    index = smooth(index, 3)
    # Keep drought conditions only; everything else becomes NaN.
    index = index.where(index < threshold)

    # Compute severity and anomaly drought index.
    # min_count=1: points with no drought step become NaN instead of the 0 a
    # NaN-skipping sum of an all-NaN series would give. This replaces relying
    # on float truthiness in ``severity.where(severity)``.
    severity = index.sum("time", min_count=1)
    severity.attrs = {"long_name": "Severity", "units": "1"}
    index = index.min("time")
    index.attrs = {"long_name": "Anomaly drought index", "units": "1"}
    return xr.Dataset({"sma": index, "severity": severity})

## Download and transform

In [None]:
# For each product, request the full [start, stop] period and reduce it to the
# drought fields with compute_drought_index. Latitude bounds come from
# area[0]/area[2], longitude bounds from area[1]/area[3].
datasets = {
    product: download.download_and_transform(
        collection_id,
        download.update_request_date(request, start, stop, stringify_dates=True),
        chunks={"year": 1, "month": 1},
        transform_func=compute_drought_index,
        transform_func_kwargs=dict(
            index_slice=index_slice,
            lon_slice=slice(area[1], area[3]),
            lat_slice=slice(area[0], area[2]),
            threshold=threshold,
        ),
        transform_chunks=False,
    )
    for product, (collection_id, request) in requests.items()
}

## Plot maps

In [None]:
def plot_map(da, colors, levels, **kwargs):
    """Draw ``da`` on a projected map using a discrete, binned colour scale.

    When ``len(colors) == len(levels) - 1``, each colour fills one bin between
    consecutive ``levels``. Any extra keyword arguments are forwarded to
    ``plot.projected_map``, and its return value is passed through.
    """
    palette = mcolors.ListedColormap(colors)
    binning = mcolors.BoundaryNorm(levels, palette.N)
    return plot.projected_map(
        da,
        levels=levels,
        cmap=palette,
        norm=binning,
        **kwargs,
    )


# Per-variable keyword arguments for plot_map. Each ``levels`` list has 7
# boundaries for the 6 ``colors``. Both variables share the same palette,
# running from red (most negative bin) to light yellow.
plot_kwargs = dict(
    sma=dict(
        colors=["#fe0000", "#fc7f01", "#ff9f00", "#febd01", "#fee819", "#e4ff7a"],
        # NOTE: the upper bound is hard-coded to the drought threshold value
        # (-1.5); keep it in sync if ``threshold`` changes.
        levels=[-8.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.5],
    ),
    severity=dict(
        colors=["#fe0000", "#fc7f01", "#ff9f00", "#febd01", "#fee819", "#e4ff7a"],
        levels=[-300, -250, -200, -150, -100, -50, 0],
    ),
)

# One figure per (product, variable) pair, products in the outer loop.
# The loop names `product` and `variable` must stay as they are: the
# f-string `=` specifier prints them literally in the title.
for product, ds in datasets.items():
    for variable, style_kwargs in plot_kwargs.items():
        field = ds[variable]
        plot_map(field, **style_kwargs)
        title = f"{variable = }\n{product = }"
        plt.suptitle(title)
        plt.show()