# Drought index

## Import libraries

In [None]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Apply matplotlib's "seaborn-v0_8-notebook" style sheet to every figure below
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Time: period passed to the data requests, and the sub-period selected for the maps
start, stop = "2000-01", "2022-12"
index_slice = slice("2022-06-01", "2022-09-30")

# Drought threshold: only index values strictly below it are kept
threshold = -1.5

# Space: bounding box forwarded as the requests' "area"
# NOTE(review): presumably [north, west, south, east] in degrees — confirm
area = [58, -10, 36, 30]

## Define requests

In [None]:
# One (collection_id, request) pair per product; insertion order is the
# order in which the products are processed.
requests = {}

# ERA5 daily means of top-layer soil water, plus the land-sea mask used to
# discard non-land cells
requests["Reanalysis"] = (
    "derived-era5-single-levels-daily-statistics",
    {
        "product_type": "reanalysis",
        "variable": ["volumetric_soil_water_layer_1", "land_sea_mask"],
        "daily_statistic": "daily_mean",
        "time_zone": "utc+00:00",
        "frequency": "1_hourly",
        "area": area,
    },
)

# Satellite surface soil moisture (combined sensors, daily average, CDR);
# no "area" key is set on this request
requests["Satellite"] = (
    "satellite-soil-moisture",
    {
        "variable": ["volumetric_surface_soil_moisture"],
        "type_of_sensor": ["combined_passive_and_active"],
        "time_aggregation": ["day_average"],
        "type_of_record": ["cdr"],
        "version": ["v202312"],
    },
)

# A single ERA5 field over the same area, requested only to obtain the
# latitude/longitude grid used as regridding target
target_grid_request = (
    "reanalysis-era5-single-levels",
    {
        "product_type": ["reanalysis"],
        "variable": ["land_sea_mask"],
        "year": ["1940"],
        "month": ["01"],
        "day": ["01"],
        "time": ["00:00"],
        "data_format": "grib",
        "download_format": "unarchived",
        "area": area,
    },
)

## Define functions to cache

In [None]:
def smooth(obj, window):
    """Fill gaps along ``time`` and return a rolling mean.

    Parameters
    ----------
    obj
        Object with a ``time`` dimension supporting ``chunk``,
        ``interpolate_na`` and ``rolling`` (e.g. an xarray DataArray).
    window
        Size of the rolling window along ``time``.

    Returns
    -------
    The rolling mean of the gap-filled data. ``min_periods=1`` means a
    window yields a value as soon as it holds one valid sample.
    """
    # Put the whole time axis in one chunk before the time-wise operations
    # NOTE(review): presumably required by interpolate_na on dask arrays — confirm
    single_chunk = obj.chunk(time=-1)
    gap_filled = single_chunk.interpolate_na("time")
    windows = gap_filled.rolling(time=window, min_periods=1)
    return windows.mean()


def compute_anomaly_drought_index(ds, threshold, target_grid_request, **xesmf_kwargs):
    """Compute the standardised soil-moisture anomaly, keeping drought values only.

    The soil-moisture variable is optionally land-masked and regridded, then
    standardised per day of year against a smoothed climatology, smoothed
    again, and finally masked wherever it is not below ``threshold``.

    Parameters
    ----------
    ds
        Dataset holding exactly one of the variables ``"sm"`` or ``"swvl1"``
        and, optionally, ``"lsm"``. When ``"lsm"`` is present, only cells
        where it exceeds 0.5 at every time step are kept.
    threshold
        Index values that are not strictly below this value are masked.
    target_grid_request
        Falsy to skip regridding. Otherwise the positional arguments for
        ``download.download_and_transform``; the latitude/longitude of the
        result define the output grid.
    **xesmf_kwargs
        Forwarded to ``diagnostics.regrid``. Only allowed together with a
        ``target_grid_request``.

    Returns
    -------
    The masked index, named ``"sma"``, with ``long_name`` and ``units``
    attributes set.

    Raises
    ------
    ValueError
        If regridding keyword arguments are given without a
        ``target_grid_request``, or if ``ds`` does not contain exactly one
        of ``"sm"`` / ``"swvl1"``.
    """
    # Regridding options are meaningless without a target grid. Validate up
    # front with a real exception: an ``assert`` is stripped under ``python -O``.
    if not target_grid_request and xesmf_kwargs:
        raise ValueError(
            "Regridding keyword arguments require a target_grid_request, "
            f"got {sorted(xesmf_kwargs)}"
        )

    # Get raw data
    candidates = set(ds.data_vars) & {"sm", "swvl1"}
    if len(candidates) != 1:
        raise ValueError(
            "Expected exactly one of 'sm' or 'swvl1' in the dataset, "
            f"found {sorted(candidates)}"
        )
    (var_name,) = candidates
    raw_data = ds[var_name]

    # Mask: keep cells flagged as land (> 0.5) at every time step
    if (lsm := ds.get("lsm")) is not None:
        raw_data = raw_data.where((lsm > 0.5).all("time"))

    # Interpolate
    if target_grid_request:
        grid_out = download.download_and_transform(
            *target_grid_request, invalidate_cache=False
        )
        # Reduce the target to its bare latitude/longitude dimension coordinates
        grid_out = grid_out[["latitude", "longitude"]]
        grid_out = grid_out.drop_vars(set(grid_out.variables) - set(grid_out.dims))
        raw_data = diagnostics.regrid(raw_data, grid_out, **xesmf_kwargs)

    # Get time-varying index
    group_dim = "time.dayofyear"
    # Day-of-year statistics are taken from the 11-step smoothed series
    smooth_data_grouped = smooth(raw_data, 11).groupby(group_dim)
    sma = raw_data.groupby(group_dim) - smooth_data_grouped.mean()
    sma = sma.groupby(group_dim) / smooth_data_grouped.std()
    sma = smooth(sma, 3)
    # Keep only values strictly below the threshold
    sma = sma.where(sma < threshold)
    sma.attrs = {"long_name": "Anomaly drought index", "units": "1"}
    return sma.rename("sma")


def compute_severity(sma):
    """Accumulate the drought index over ``time``.

    Parameters
    ----------
    sma
        Drought index with a ``time`` dimension.

    Returns
    -------
    The sum over ``time``, named ``"severity"``, with ``long_name`` and
    ``units`` attributes set. Totals equal to zero are masked.
    """
    total = sma.sum("time")
    # The total doubles as its own mask: zero totals are dropped
    masked_total = total.where(total)
    masked_total.attrs = {"long_name": "Severity", "units": "1"}
    return masked_total.rename("severity")


def compute_timeseries(ds, threshold, target_grid_request, **xesmf_kwargs):
    """Return the spatially weighted mean of the drought index as a dataset.

    All arguments are passed on to ``compute_anomaly_drought_index``; the
    resulting index is reduced with ``diagnostics.spatial_weighted_mean``
    and converted with ``to_dataset``.
    """
    index = compute_anomaly_drought_index(
        ds, threshold, target_grid_request, **xesmf_kwargs
    )
    spatial_mean = diagnostics.spatial_weighted_mean(index)
    return spatial_mean.to_dataset()


def compute_maps(ds, threshold, index_slice, target_grid_request, **xesmf_kwargs):
    """Return maps of the minimum drought index and of its severity.

    Parameters
    ----------
    ds, threshold, target_grid_request, **xesmf_kwargs
        Passed on to ``compute_anomaly_drought_index``.
    index_slice
        Selection applied along ``time`` before the maps are computed.

    Returns
    -------
    The merge of the minimum index over ``time`` (attributes kept) and the
    output of ``compute_severity`` for the selected period.
    """
    index = compute_anomaly_drought_index(
        ds, threshold=threshold, target_grid_request=target_grid_request, **xesmf_kwargs
    )
    selected = index.sel(time=index_slice)
    minimum = selected.min("time", keep_attrs=True)
    return xr.merge([minimum, compute_severity(selected)])

## Download and transform

In [None]:
# One dataset per product is collected for each diagnostic
datasets_maps = []
datasets_timeseries = []
for product, (collection_id, request) in requests.items():
    print(f"{product = }")
    request = download.update_request_date(request, start, stop, stringify_dates=True)

    # Only the satellite product is given a target grid and a regridding method
    is_satellite = product == "Satellite"
    kwargs = {
        "threshold": threshold,
        "target_grid_request": target_grid_request if is_satellite else None,
    }
    if is_satellite:
        kwargs["method"] = "conservative"

    # Same download for both diagnostics; only the transform and its extra
    # keyword arguments differ (map first, then timeseries)
    for transform_func, extra_kwargs, collected in (
        (compute_maps, {"index_slice": index_slice}, datasets_maps),
        (compute_timeseries, {}, datasets_timeseries),
    ):
        ds = download.download_and_transform(
            collection_id,
            request,
            chunks={"year": 1, "month": 1},
            transform_func=transform_func,
            transform_func_kwargs=kwargs | extra_kwargs,
            transform_chunks=False,
        )
        collected.append(ds.expand_dims(product=[product]))

# Stack the per-product results along the new "product" dimension
ds_maps = xr.combine_by_coords(datasets_maps)
ds_timeseries = xr.combine_by_coords(datasets_timeseries)

## Plot timeseries

In [None]:
# One line per product on shared axes
sma_timeseries = ds_timeseries["sma"]
sma_timeseries.plot(hue="product")
plt.grid()

## Plot maps

In [None]:
# Both variables share one palette; only the bin edges differ
colors = ["#fe0000", "#fc7f01", "#ff9f00", "#febd01", "#fee819", "#e4ff7a"]
levels_by_variable = {
    "sma": [-8.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.5],
    "severity": [-300, -250, -200, -150, -100, -50, 0],
}

for variable, da in ds_maps.data_vars.items():
    # Products
    if (levels := levels_by_variable.get(variable)) is None:
        raise NotImplementedError(f"{variable = }")
    cmap = mcolors.ListedColormap(colors)
    norm = mcolors.BoundaryNorm(levels, cmap.N)
    plot.projected_map(da, levels=levels, cmap=cmap, norm=norm, col="product")
    plt.show()

    # Bias: difference between consecutive products, attributes kept
    with xr.set_options(keep_attrs=True):
        product_diff = da.diff("product")
        bias = product_diff.drop_vars("product")
    bias.attrs["long_name"] = "Bias of " + bias.long_name
    plot.projected_map(bias)
    # Title lists the products in reverse order, joined as a subtraction
    product_names = da["product"].values.tolist()
    plt.title(" - ".join(reversed(product_names)))
    plt.show()