# Drought index

## Import libraries

In [None]:
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply a matplotlib style sheet globally, so every figure created below uses it.
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Time
# Full period requested for every product (passed to update_request_date); the
# day-of-year climatology in compute_drought_index is built from all of it.
start = "2000-01"
stop = "2022-12"
# Sub-period for which the drought index is returned and plotted.
index_slice = slice("2022-06-01", "2022-09-30")

# Space
# Bounding box ordered [north, west, south, east]: area[0]/area[2] become the
# latitude slice and area[1]/area[3] the longitude slice in the download cell.
area = [58, -10, 36, 30]

## Define requests

In [None]:
# Product label -> [collection_id, request]. The date keys are not set here;
# they are filled in for start..stop by download.update_request_date.
requests = {
    # Reanalysis daily mean of soil-water layer 1; ``area`` is part of the request.
    "Reanalysis": [
        "derived-era5-single-levels-daily-statistics",
        {
            "product_type": "reanalysis",
            "variable": ["volumetric_soil_water_layer_1"],
            "daily_statistic": "daily_mean",
            "time_zone": "utc+00:00",
            "frequency": "1_hourly",
            "area": area,
        },
    ],
    # Daily-average satellite soil moisture. This request has no "area" key, so
    # the spatial subset is applied afterwards inside compute_drought_index.
    "Satellite": [
        "satellite-soil-moisture",
        {
            "variable": ["volumetric_surface_soil_moisture"],
            "type_of_sensor": ["combined_passive_and_active"],
            "time_aggregation": ["day_average"],
            "type_of_record": ["cdr"],
            "version": ["v202312"],
        },
    ],
}

## Define functions to cache

In [None]:
def smooth(obj, window):
    """Gap-fill along ``time`` and apply a rolling mean.

    Missing values are first interpolated along the ``time`` dimension. Then a
    rolling mean of ``window`` steps is taken, with xarray's default
    (non-centred) window alignment. ``min_periods=1`` lets windows that are
    only partly populated, such as at the start of the record, still produce
    a value.
    """
    gap_filled = obj.interpolate_na("time")
    windows = gap_filled.rolling(time=window, min_periods=1)
    return windows.mean()


def compute_drought_index(ds, index_slice, lon_slice, lat_slice):
    """Compute a standardised soil-moisture anomaly ("drought index").

    The soil-moisture variable (``sm`` or ``swvl1``, whichever is present) is
    cut to the requested region. Each time step is compared with a smoothed
    day-of-year climatology built from the whole record, and the difference
    is divided by the day-of-year standard deviation of the smoothed data.
    The index is then smoothed over 3 time steps and restricted to
    ``index_slice``.

    Parameters
    ----------
    ds : xarray.Dataset
        Input data; must contain exactly one of ``sm`` / ``swvl1``.
    index_slice : slice
        Time selection applied to the returned index.
    lon_slice, lat_slice : slice
        Spatial selection forwarded to ``utils.regionalise``.

    Returns
    -------
    xarray.Dataset
        Dataset holding a single ``drought_index`` variable.

    Raises
    ------
    ValueError
        If ``ds`` does not contain exactly one of ``sm`` / ``swvl1``.
    """
    # Get raw data
    candidates = set(ds.data_vars) & {"sm", "swvl1"}
    if len(candidates) != 1:
        raise ValueError(
            "Expected exactly one soil-moisture variable ('sm' or 'swvl1'), "
            f"found {sorted(candidates)!r} among {list(ds.data_vars)!r}"
        )
    (var_name,) = candidates
    raw_data = utils.regionalise(ds[var_name], lon_slice=lon_slice, lat_slice=lat_slice)
    # Put the whole time axis in one chunk before the time-wise rolling and
    # day-of-year groupby operations below.
    raw_data = raw_data.chunk(time=-1)

    # Smoothed day-of-year climatology. A *centred* 11-step window is used on
    # purpose. The trailing window of ``smooth`` would make the value for day
    # d the mean of days d-10..d, shifting the climatology back by ~5 days and
    # biasing anomalies whenever soil moisture has a seasonal trend.
    group_dim = "time.dayofyear"
    smooth_data = (
        raw_data.interpolate_na("time")
        .rolling(time=11, center=True, min_periods=1)
        .mean()
    )
    smooth_data_grouped = smooth_data.groupby(group_dim)

    # Standardised anomaly: (raw - climatological mean) / climatological std.
    index = raw_data.groupby(group_dim) - smooth_data_grouped.mean()
    index = index.groupby(group_dim) / smooth_data_grouped.std()

    # Light 3-step smoothing, done over the full record *before* selecting
    # the period of interest. This way the first steps of ``index_slice`` get
    # complete windows, and gaps near its edges can be interpolated from the
    # neighbouring values, instead of relying on partial windows.
    index = smooth(index, 3).sel(time=index_slice)

    # Attributes (arithmetic above does not carry meaningful attrs over)
    index.attrs = {"long_name": "Drought index"}
    return index.to_dataset(name="drought_index")

## Download and transform

In [None]:
# Download every product for start..stop and reduce it to the drought index
# over ``area`` via compute_drought_index.
datasets = {}
for product, (collection_id, request) in requests.items():
    dated_request = download.update_request_date(
        request, start, stop, stringify_dates=True
    )
    datasets[product] = download.download_and_transform(
        collection_id,
        dated_request,
        chunks={"year": 1, "month": 1},
        transform_func=compute_drought_index,
        transform_func_kwargs=dict(
            index_slice=index_slice,
            lon_slice=slice(area[1], area[3]),
            lat_slice=slice(area[0], area[2]),
        ),
        transform_chunks=False,
    )
# Reduce each product to a spatially weighted mean time series, tag it with a
# new length-1 "product" dimension, and merge all products into one Dataset.
ds_timeseries = xr.combine_by_coords(
    [
        diagnostics.spatial_weighted_mean(datasets[product]).expand_dims(
            product=[product]
        )
        for product in datasets
    ]
)

## Quick and dirty plot: Timeseries

In [None]:
# Regional-mean drought index against time, one line per product.
ds_timeseries.drought_index.plot(hue="product")
plt.gca().grid()

## Quick and dirty plot: Maps

In [None]:
# One map per product: the drought index averaged over the selected period.
for product in datasets:
    ds = datasets[product]
    time_mean = ds["drought_index"].mean(dim="time", keep_attrs=True)
    plot.projected_map(time_mean)
    plt.suptitle(f"{product = }")
    plt.show()