## Tutorial: European Winterstorm Return Period Map


In this notebook, we download the reanalysis data for European windstorm footprints and build the return period map by fitting a generalized Pareto distribution (GPD) at each pixel. We use dask to fit the distributions in parallel, and apply some simple logic to clean up the data and to select the high thresholds above which the GPD can be considered valid.

To set up the Copernicus CDSAPI, see instructions [here](https://cds.climate.copernicus.eu/how-to-api).

In [8]:
import os
import re
from typing import List, Union, Tuple, Optional
from glob import glob
from pathlib import Path
import zipfile
import tempfile

import numpy as np
import pandas as pd
from requests import HTTPError

import cdsapi
import xarray as xr
from dask.distributed import Client, LocalCluster

from caseva.models import ThresholdExcessModel

In [10]:
# Notebook configuration: paths, dataset identifiers and tuning constants.

# Current working directory (resolved to an absolute path) and the data
# folder below it; the folder is created if it does not exist yet.
PATH_ROOT = Path().resolve()
PATH_DATA = PATH_ROOT / "data"
PATH_DATA.mkdir(exist_ok=True)

# Dataset name and product passed to the Copernicus CDS API requests.
COPERNICUS_DATASET = "sis-european-wind-storm-indicators"
COPERNICUS_PRODUCT = "windstorm_footprints"

# Chunk sizes used when opening the netCDF files.
LAT_CHNKS = 152   # The chunk size should divide the data dimensions evenly
LON_CHNKS = 232
EVENT_CHNKS = -1  # No chunking over the event/time dimension

# Return periods at which the fitted models are evaluated
# (presumably in years, matching `num_years` -- TODO confirm).
RETURN_PERIODS = np.array([10, 50, 100], dtype=float)

# Workers for fitting the distributions in parallel
N_WORKERS = 8
THREADS_PER_WORKER = 1

# Intermediate results and output names.
OUTPUT_VARNAME = "wind_footprint"
EVENT_SET_FILENAME = "wind_footprints.zarr"
RP_MAP_FILENAME = "return_period_map.zarr"

# Values above this value are considered erroneous
# (used in post-processing to flag pixels whose values reach this limit).
MAX_WIND_SPEED_MPS = 150

# If a pixel has fewer data points than this, the fitting is skipped.
# (`fit_gpd` counts strictly positive values and returns zeros for such pixels.)
MIN_NUM_DATA_TO_FIT = 15

## Prepare data requests

In [7]:
def pad(num: int) -> str:
    """Render an integer as a zero-padded string of at least two digits.

    This is the day/month format used in the API requests.

    Examples
    --------
    pad(2) -> "02"
    pad(10) -> "10"
    """
    two_digit_spec = "02d"
    return format(num, two_digit_spec)

# Day strings "01".."31" without the 9th, which is missing from the
# dataset (= no events).
days = [pad(day) for day in (*range(1, 9), *range(10, 32))]
winter_months = list(map(pad, (1, 2, 3, 10, 11, 12)))

# Years without any data records (= no events occurring) are not requested.
missing_yrs = [2003, 2004, 2010, 2018, 2019]
years = list(map(str, (yr for yr in range(1979, 2022) if yr not in missing_yrs)))

In [4]:
def download_data_files(
    years: List[str],
    months: List[str],
    days: List[str],
    save_dir: Union[str, Path]
) -> None:
    """Download the wind footprint archives from the Copernicus API.

    One request is issued per (year, month) pair and the response is stored
    as ``<year>_<month>.zip`` inside `save_dir`. A request that fails with an
    ``HTTPError`` is reported on stdout and skipped.

    Parameters
    ----------
    years : list of str
        Years to request.
    months : list of str
        Months to request.
    days : list of str
        Days requested within every month.
    save_dir : str | Path
        Directory receiving the downloaded zip files.
    """
    api = cdsapi.Client()

    for year in years:
        print(f"Processing year {year}...")

        for month in months:
            target = os.path.join(save_dir, f"{year}_{month}.zip")
            request = dict(
                product=[COPERNICUS_PRODUCT],
                variable="all",
                year=[year],
                month=[month],
                day=days,
            )

            try:
                api.retrieve(COPERNICUS_DATASET, request, target)
            except HTTPError:
                # Best effort: report the failure and move on to the next month.
                print(f"Download failed for {target}. Likely missing data.")


## Prepare data processing steps

In [5]:
def _extract_timestamp(fname: str) -> pd.Timestamp:
    """Extract timestamp from Copernicus filename.

    The Copernicus files have an 8-digit sequence for YYYMMDD.
    
    Parameters
    ----------
    fname : str
        Name of the netcdf file to process.

    Returns
    -------
    pd.Timestamp
        Date corresponding to the wind data of that file.
    """
    time_str = re.search(r'(\d{8})', fname).group(1)
    return pd.to_datetime(time_str)

In [6]:
def _infer_spatial_dims(ds: Union[xr.Dataset, xr.DataArray]) -> Tuple[str, str]:
    """Find the longitude and latitude dimension names of `ds`.

    Dimension names are compared case-insensitively against fixed lists of
    candidates. The names are returned with their original spelling, in the
    order ``(lon_name, lat_name)``.

    Raises
    ------
    ValueError
        If there is not exactly one match for each of lon and lat.
    """
    lon_candidates = ["longitude", "lon", "lng", "x"]
    lat_candidates = ["latitude", "lat", "y"]

    # Lower-cased dimension name -> original dimension name.
    original_by_lowercase = {}
    for dim in ds.dims:
        original_by_lowercase[dim.lower()] = dim

    lon_dims = [
        original
        for lowered, original in original_by_lowercase.items()
        if lowered in lon_candidates
    ]
    lat_dims = [
        original
        for lowered, original in original_by_lowercase.items()
        if lowered in lat_candidates
    ]

    if not (len(lon_dims) == 1 and len(lat_dims) == 1):
        raise ValueError("Could not find unique lon/lat names.")

    lon_name, lat_name = lon_dims[0], lat_dims[0]
    return (lon_name, lat_name)

In [7]:
def preprocess(ds: xr.Dataset) -> xr.Dataset:
    """Preprocess `ds` before it is combined by xr.open_mfdataset.

    Steps include:
    1. Extract time stamp from the input filename.
    2. Drop redundant coordinates ("time", "z"), if any.
    3. Squeeze length-one dimensions and add the time stamp as a new
       length-one "event" dimension.
    4. Rename spatial coordinates to "lat" and "lon".
    5. Rename data variable to `OUTPUT_VARNAME`.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset of a single netcdf file; ``ds.encoding['source']`` must hold
        the file name.

    Returns
    -------
    xr.Dataset
        Dataset with dimensions "event", "lat" and "lon".
    """

    # Parse time coordinate stamp from the .nc filename.
    filename = ds.encoding['source']
    timestamp = _extract_timestamp(filename)

    # Redundant coords to drop.
    drop_coords = [coord for coord in ("time", "z") if coord in ds.coords]
    ds = ds.drop_vars(drop_coords)

    # `varname` is sometimes 'FX', 'max_wind_speed', etc...
    varname = list(ds.data_vars)[0]
    xname, yname = _infer_spatial_dims(ds)

    # The new dimension must be called "event": the rest of the notebook
    # chunks, sorts and reduces over a dimension of that name.
    return (
        ds
        .squeeze()
        .expand_dims(event=[timestamp])
        .rename({varname: OUTPUT_VARNAME, xname: "lon", yname: "lat"})
    )

## Download the datasets

In [None]:
# Download, unpack and combine the raw files inside a temporary directory,
# then persist the combined event set as a zarr store under PATH_DATA.
with tempfile.TemporaryDirectory() as tmpdir:

    # Load zip files from Copernicus data store.
    download_data_files(years, winter_months, days, save_dir=tmpdir)

    # Unpack all zipped netcdf files.
    zipfiles = glob(os.path.join(tmpdir, "*.zip"))
    for zfile in zipfiles:
        with zipfile.ZipFile(zfile, "r") as file_ref:
            file_ref.extractall(tmpdir)

    # Combine all netcdf files into a single dataset.
    ncfiles = glob(os.path.join(tmpdir, "*.nc"))

    # Note: chunking is done before renaming, so the chunk keys below use the
    # raw dimension names (assumes the files call them "Latitude" and
    # "Longitude" -- TODO confirm). Also, chunking is done on a file-by-file
    # basis, meaning that the "event" dimension (corresponding to each netcdf
    # file) needs to be chunked separately afterwards.

    ds = xr.open_mfdataset(
        ncfiles,
        preprocess=preprocess,
        chunks={"Latitude": LAT_CHNKS, "Longitude": LON_CHNKS}
    )[OUTPUT_VARNAME].chunk({"event": EVENT_CHNKS})

    # Write to a zarr file. This happens inside the `with` block, while the
    # extracted netcdf files in the temporary directory still exist.
    ds.to_zarr(PATH_DATA / EVENT_SET_FILENAME, compute=True, mode="w")

### First glance at the full dataset

In [8]:
# Re-open the stored event set and order the events by their "event"
# coordinate (the date parsed from each file name).
ds = xr.open_zarr(PATH_DATA / EVENT_SET_FILENAME)[OUTPUT_VARNAME]
ds = ds.sortby("event")

In [None]:
# Map of the per-pixel maximum over all events.
ds.max("event").plot(figsize=(10, 6))

## Fit the extreme value distributions

### Pre-processing

A key assumption is that the extreme data are independent. Therefore, to rule out cases where multiple high values are caused by the same storm track, we add a simple time-window filter: any event that occurs less than 3 days after the preceding event is dropped, so that only the first event of each cluster is kept.

In [10]:
# Time gap between each event and the one before it. `ds` was sorted by
# "event" above, so the gaps are between consecutive dates.
time_diffs = ds.event.diff("event")
# Flag events that follow their predecessor by less than 3 days.
is_clustered_event = time_diffs < np.timedelta64(3, "D")

# The diff has no entry for the first event, so an explicit False is
# prepended for it: the first event is never excluded.
first_event_mask = ds.event[0].copy(data=False)
exclusion_mask = xr.concat([first_event_mask, is_clustered_event], dim="event")

# Drop the flagged events. NOTE: `ds` is rebound to the filtered event set.
ds = ds.where(~exclusion_mask, drop=True)

# Number of calendar years from the first to the last event year, inclusive.
all_years = pd.DatetimeIndex(ds.event).year
num_years = all_years.max() - all_years.min() + 1

In [11]:
def try_resolve_corner_solution(
    data: np.ndarray,
    init_thresh: float,
    num_years: int,
    step_size: float = 0.1,
    max_attempts: int = 10,
    min_data_size: int = 10,
) -> Optional[ThresholdExcessModel]:
    """Search for a threshold whose fit is not a corner solution.

    Starting from `init_thresh`, the threshold is first raised step by step
    and, if that does not help, lowered step by step (again starting from
    `init_thresh`). The first fit that the optimizer does not report as a
    corner solution is returned.

    This is a very ad-hoc way of automated threshold selection.

    Parameters
    ----------
    data : array-like
        The data to be fit by the model.
    init_thresh : float
        The initial threshold value.
    num_years : int
        Number of years for the model fit.
    step_size : float, optional, default=0.1
        The fractional increment/decrement applied to the threshold.
    max_attempts : int, optional, default=10
        The maximum number of attempts in each direction.
    min_data_size : int, optional, default=10
        The minimum number of data points above threshold required to fit.
        Only enforced while the threshold is being raised.

    Returns
    -------
    model : ThresholdExcessModel or None
        A fitted model if a corner solution is successfully resolved;
        otherwise None.
    """
    candidate = ThresholdExcessModel()

    # (multiplier applied per attempt, whether to enforce `min_data_size`):
    # upward search first, then downward search.
    search_plan = (
        (1 + step_size, True),
        (1 - step_size, False),
    )

    for factor, enforce_min_size in search_plan:
        threshold = init_thresh

        for _attempt in range(max_attempts):
            threshold *= factor

            # Raising the threshold further can only leave fewer points,
            # so this direction is abandoned as soon as too few remain.
            if enforce_min_size and data[data > threshold].size < min_data_size:
                break

            candidate.fit(data=data, threshold=threshold, num_years=num_years)

            if not candidate.optimizer.is_corner_solution:
                return candidate

    # Neither direction produced a non-corner fit.
    return None

In [12]:
def get_sqrt_n_threshold(data: np.ndarray) -> float:
    """Get the threshold leading to sqrt(n) values treated as extremes.

    NaN entries are ignored: `n` is the number of non-NaN values, and the
    threshold is placed just below the k-th largest of them, where
    ``k = ceil(sqrt(n))``.

    Parameters
    ----------
    data : np.ndarray
        Data values (flattened before use).

    Returns
    -------
    float
        The k-th largest non-NaN value minus a small offset (1e-6), so that
        the k-th largest value itself lies strictly above the threshold.

    Raises
    ------
    ValueError
        If `data` contains no non-NaN values.
    """
    # Work in float64 so that the small offset below is not rounded away
    # for lower-precision inputs.
    values = np.asarray(data, dtype=float).ravel()
    values = values[~np.isnan(values)]

    if values.size == 0:
        raise ValueError(
            "Cannot select a threshold: `data` has no non-NaN values."
        )

    k = int(np.ceil(np.sqrt(values.size)))

    # After partitioning, position -k holds the k-th largest value.
    kth_largest = np.partition(values, -k)[-k]

    return kth_largest - 1e-6

In [13]:
def fit_gpd(
    data: np.ndarray,
    num_years: int,
    resolve_corners: bool = True,
    debug: bool = False
) -> np.ndarray:
    """Fit a generalized Pareto distribution to a 1d data array.

    Parameters
    ----------
    data : np.ndarray
        Data array corresponding to a time series of one pixel on the map.
    num_years : int
        The number of years the `data` corresponds to.
    resolve_corners : bool, default=True
        Whether to try resolving corner solutions of the optimizer (often a
        sign of a poor-quality fit) by adjusting the threshold up and down.
        If the attempt fails, the initial fit is kept.
    debug : bool, default=False
        Whether to show a diagnostics plot of the fitting and to report
        resolved corner solutions on stdout.

    Returns
    -------
    np.ndarray
        The return levels evaluated at `RETURN_PERIODS`. The result is all
        NaN if `data` is all NaN or if the fit raises a ValueError, and all
        zeros if `data` has fewer than `MIN_NUM_DATA_TO_FIT` positive values.
    """

    if np.isnan(data).all():
        return np.full_like(RETURN_PERIODS, np.nan)
    elif data[data > 0].size < MIN_NUM_DATA_TO_FIT:
        return np.zeros_like(RETURN_PERIODS)

    model = ThresholdExcessModel()
    threshold = get_sqrt_n_threshold(data)

    try:
        model.fit(data=data, threshold=threshold, num_years=num_years)

        if resolve_corners and model.optimizer.is_corner_solution:

            # A failure while searching for a better threshold must not
            # discard the initial fit, which is still usable.
            try:
                new_model = try_resolve_corner_solution(
                    data, threshold, num_years
                )
            except ValueError:
                new_model = None

            if new_model is not None:
                model = new_model
                # Only reported in debug mode: this function runs once per
                # pixel, so an unconditional print floods the output.
                if debug:
                    print("Resolved a corner solution.")

        if debug:
            model.diagnostic_plot()
        return model.return_level(RETURN_PERIODS)["level"]

    except ValueError:
        return np.full_like(RETURN_PERIODS, np.nan)

### Test fit to a single pixel on the map

In [None]:
# Load the event series of the grid point nearest to lon=10, lat=45 and fit
# it with debug=True to get the diagnostic plot.
test_data = ds.sel(lon=10, lat=45, method="nearest").compute()
fit_gpd(test_data.values, num_years=num_years, debug=True)

In [15]:
def fit_and_eval_rp(
    ds: Union[xr.DataArray, xr.Dataset],
    num_years: int
) -> Union[xr.DataArray, xr.Dataset]:
    """Fit the distribution to a coarse chunked dataset in parallel.

    `fit_gpd` is applied to the "event" series of every pixel.

    Parameters
    ----------
    ds : xr.DataArray | xr.Dataset
        Event data with an "event" dimension.
    num_years : int
        Number of years covered by the events, forwarded to `fit_gpd`.

    Returns
    -------
    xr.DataArray | xr.Dataset
        Return levels per pixel. The "event" dimension of the result holds
        the values of `RETURN_PERIODS` (not dates).
    """

    # With dask="parallelized", a core dimension must not be split across
    # several chunks. Only dask-backed input is rechunked.
    if ds.chunks:
        ds = ds.chunk({"event": -1})

    fits = xr.apply_ufunc(
        fit_gpd,
        ds,
        num_years,
        input_core_dims=[["event"], []],
        output_core_dims=[["return_period"]],
        dask="parallelized",  # Process chunks in parallel
        vectorize=True,       # Apply the fit to each (lon/lat) within chunk
        dask_gufunc_kwargs={
            "output_sizes": {"return_period": len(RETURN_PERIODS)}
        },
    )

    # Label the new dimension with the return periods and rename it to
    # "event", the name under which later cells select and reduce it.
    fits = (
        fits
        .assign_coords({"return_period": RETURN_PERIODS})
        .rename({"return_period": "event"})
    )

    return fits

### Run the fitting

Set up a small local cluster to fit the distributions in parallel.

In [None]:
# Fit all pixels on a local dask cluster and store the result as zarr.
# NOTE(review): RP_MAP_FILENAME is a bare file name, so this store is written
# relative to the current working directory, not under PATH_DATA.
with LocalCluster(
    n_workers=N_WORKERS, threads_per_worker=THREADS_PER_WORKER
) as cluster:
    with Client(cluster) as client:

        fit_and_eval_rp(ds, num_years).to_zarr(
            RP_MAP_FILENAME,
            mode="w",
            compute=True
        )

## Post-processing

An indication of a poor fit is that the return levels are extremely high even for moderate return periods. Therefore, we first identify if there are any such pixels, try to replace them with valid neighbor values, but if none are available, we set the pixel to a NaN value.

In [49]:
def maybe_repalce_high_vals(chnk):
    """Process 3d chunks by replacing high values with neighbor vals or NaNs.

    A pixel is flagged when its maximum over "event" is at or above
    `MAX_WIND_SPEED_MPS`. Each flagged pixel is overwritten with the values
    of the first neighbor (within the chunk) whose maximum lies strictly
    between 0 and `MAX_WIND_SPEED_MPS`; if there is no such neighbor, the
    pixel is set to NaN. All other pixels are left unchanged.

    Parameters
    ----------
    chnk : xr.DataArray
        Block with dimensions "event", "lat" and "lon".

    Returns
    -------
    xr.DataArray
        The input itself if nothing is flagged, otherwise a modified copy.
    """

    maxevent = chnk.max("event")

    if maxevent.max() < MAX_WIND_SPEED_MPS or np.isnan(maxevent).all():
        return chnk

    # Work on a copy so that the block passed in is never mutated in place.
    chnk = chnk.copy()

    high_vals = (
        maxevent
        .where(maxevent >= MAX_WIND_SPEED_MPS, drop=True)
        .stack(point=("lat", "lon"))
    )

    for point in high_vals:

        # `where(..., drop=True)` keeps every lat/lon label that contains at
        # least one flagged pixel, so the stacked points also include
        # unflagged pixels (as NaN). Those must not be touched.
        if not (point >= MAX_WIND_SPEED_MPS):
            continue

        lon_indx = chnk.get_index("lon").get_loc(point.lon.item())
        lat_indx = chnk.get_index("lat").get_loc(point.lat.item())

        neighbors = []
        # loop over neighbors (clipped to the chunk boundaries)
        for nx in range(max(lon_indx-1, 0), min(lon_indx+2, chnk.lon.size)):
            for ny in range(max(lat_indx-1, 0), min(lat_indx+2, chnk.lat.size)):

                # Skip the flagged pixel itself.
                if nx == lon_indx and ny == lat_indx:
                    continue

                neighbors.append((nx, ny))

        neighbor_replacement_found = False
        for nghbr in neighbors:

            neighbor_val = chnk.isel(lon=nghbr[0], lat=nghbr[1]).compute()
            if 0 < neighbor_val.max() < MAX_WIND_SPEED_MPS:

                chnk.loc[{
                    "lat": point.lat.item(),
                    "lon": point.lon.item()
                }] = neighbor_val

                neighbor_replacement_found = True
                break

        # Set unresolved ones to NaN.
        if not neighbor_replacement_found:
            chnk.loc[{
                "lat": point.lat.item(),
                "lon": point.lon.item()
            }] = np.nan

    return chnk


In [53]:
# NOTE: `ds` is rebound here from the event set to the stored return-level
# map, which is read relative to the current working directory.
ds = xr.open_zarr(RP_MAP_FILENAME)[OUTPUT_VARNAME]

with LocalCluster(
    n_workers=N_WORKERS, threads_per_worker=THREADS_PER_WORKER
) as cluster:
    
    # The cleaned result has the same structure as its input, so a
    # zero-filled copy of `ds` serves as the template for map_blocks.
    template = xr.zeros_like(ds)
    with Client(cluster) as client:
        # Clean every block, then keep the entry whose "event" coordinate
        # equals 100 (one of RETURN_PERIODS) and load it into memory.
        ds_final = (
            xr.map_blocks(maybe_repalce_high_vals, ds, template=template)
            .sel(event=100)
            .compute()
        )

## Plot the final result

In [None]:
# Map of the cleaned values selected at event=100.
ds_final.plot(figsize=(10, 6))