# Dataset Creation
This notebook generates the benchmark data for the EUPPBench station post-processing benchmark, which consists of a Reforecast-to-Reforecast (*R2R*) and a Reforecast-to-Forecast (*R2F*) task.

## Prerequisites
Download the [EUPPBench dataset](https://zenodo.org/records/7708362) and unzip it; the loader below expects the extracted Zarr stores in `data/EUPPBench-stations`:
```shell
unzip EUPPBench-stations.zip
rm EUPPBench-stations.zip
```

## Data Split
Reforecasts:
- Train: [1997-2013]
- Test: [2014-2017]

Forecasts:
- Test: [2017-2018]

In [1]:
from typing import Dict, List, Tuple, Union

import numpy as np
import pandas as pd
import xarray as xr

In [2]:
class ZarrLoader:
    """
    A class for loading EUPPBench station data from Zarr files.

    Args:
        data_path (str): The path to the data directory.

    Attributes:
        data_path (str): The path to the data directory.
        countries (List[str]): The list of countries to load data for (set by ``load_data``).
        features (List[str]): The requested feature names, prefixed with "number" (set by ``load_data``).
        stations_f (pd.DataFrame): Station table of the forecasts (set by ``load_data``).
        stations_rf (pd.DataFrame): Station table of the reforecasts (set by ``load_data``).

    Methods:
        get_stations(arr: xr.Dataset) -> pd.DataFrame:
            Get the stations information from the dataset.

        load_data(countries: Union[str, List[str]] = "all",
        features: Union[str, List[str]] = "all")
        -> Tuple[Dict[str, xr.Dataset], Dict[str, xr.Dataset]]:
            Load predictions and observation targets from Zarr files.

        validate_stations() -> bool:
            Validate if the station IDs match between forecasts and reforecasts.
    """

    def __init__(self, data_path: str) -> None:
        self.data_path = data_path

    def get_stations(self, arr: xr.Dataset) -> pd.DataFrame:
        """
        Get the stations information from the dataset.

        Args:
            arr (xr.Dataset): The dataset containing station information.

        Returns:
            pd.DataFrame: One row per station (station_id, lat, lon, altitude,
            name), sorted by ``station_id`` with a fresh 0..n-1 index.
        """
        stations = pd.DataFrame(
            {
                "station_id": arr.station_id.values,
                "lat": arr.station_latitude.values,
                "lon": arr.station_longitude.values,
                "altitude": arr.station_altitude.values,
                "name": arr.station_name.values,
            }
        )
        stations = stations.sort_values("station_id").reset_index(drop=True)
        return stations

    def _fix_time(self, ds: xr.Dataset) -> xr.Dataset:
        """
        Flatten the reforecast ``(year, time)`` dimensions into one ``time`` dimension.

        Every ``time`` value is shifted back by ``year`` calendar years
        (``pd.DateOffset(years=year)``), so each (year, time) pair becomes a
        single absolute date. Use for reforecasts data.

        Args:
            ds (xr.Dataset): The input dataset with a 'year' and 'time' dimension.

        Returns:
            xr.Dataset: The dataset with 'year' and 'time' replaced by a single
            'time' dimension, sorted ascending.
        """
        # The base dates do not depend on the year offset, so build them once.
        base_dates = pd.DatetimeIndex(ds.time.values)
        # Year-major order (outer loop over years) matches the ('year', 'time')
        # stacking order below, so values line up with the stacked entries.
        new_times = np.concatenate(
            [(base_dates - pd.DateOffset(years=year)).values for year in ds.year.values]
        )
        # Create a new dataset with the combined time dimension
        ds_new = ds.stack(new_time=('year', 'time'))
        # Replace the stacked MultiIndex (and its levels) by the shifted dates
        ds_new = ds_new.drop_vars(['new_time', 'year', 'time'])
        ds_new = ds_new.assign_coords(new_time=new_times)
        # Sort the dataset by the new time coordinate
        ds_new = ds_new.sortby('new_time')
        # Rename 'new_time' to 'time'
        ds_new = ds_new.rename({'new_time': 'time'})
        return ds_new

    def load_data(
        self, countries: Union[str, List[str]] = "all", features: Union[str, List[str]] = "all"
    ) -> Tuple[Dict[str, xr.Dataset], Dict[str, xr.Dataset]]:
        """
        Load ensemble predictions and station observations for the given countries.

        For each of ``'forecasts'`` and ``'reforecasts'``, the surface,
        post-processed surface and 500/700/850 hPa pressure-level Zarr stores are
        merged per country and concatenated along ``station_id``. The matching
        observation stores are concatenated the same way. Reforecasts and their
        targets are flattened from ``(year, time)`` to a single ``time``
        dimension via ``_fix_time``. Also sets ``self.stations_f`` and
        ``self.stations_rf``.

        Args:
            countries (Union[str, List[str]]): ``"all"`` (austria, belgium,
                france, germany, netherlands) or a list of country names as they
                appear in the Zarr file names. Default is "all".
            features (Union[str, List[str]]): ``"all"`` or a list of feature
                names. Default is "all". NOTE(review): the resulting
                ``self.features`` list is currently NOT used to subset the
                loaded variables; all variables from the stores are returned
                (apart from the dropped ``valid_time``).

        Returns:
            Tuple[Dict[str, xr.Dataset], Dict[str, xr.Dataset]]:
                - forecasts_xrs: predictions keyed by 'forecasts' / 'reforecasts'.
                - targets_xrs: observations keyed the same way.

        Raises:
            ValueError: If ``countries`` or ``features`` is neither "all" nor a list.
        """
        if countries == "all":
            print("[INFO] Loading data for all countries")
            self.countries = ["austria", "belgium", "france", "germany", "netherlands"]
        elif isinstance(countries, list):
            print(f"[INFO] Loading data for {countries}")
            self.countries = countries
        else:
            raise ValueError("countries must be a list of strings or 'all'")

        if features == "all":
            print("[INFO] Loading all features")
            self.features = ["number"] + [
                "station_id",
                "time",
                "cape",
                "model_orography",
                "sd",
                "station_altitude",
                "station_latitude",
                "station_longitude",
                "stl1",
                "swvl1",
                "t2m",
                "tcc",
                "tcw",
                "tcwv",
                "u10",
                "u100",
                "v10",
                "v100",
                "vis",
                "cp6",
                "mn2t6",
                "mx2t6",
                "p10fg6",
                "slhf6",
                "sshf6",
                "ssr6",
                "ssrd6",
                "str6",
                "strd6",
                "tp6",
                "z",
                "q",
                "u",
                "v",
                "t",
            ]
        elif isinstance(features, list):
            print(f"[INFO] Loading features: {features}")
            self.features = ["number"] + features
        else:
            raise ValueError("features must be a list of strings or 'all'")

        # Load Data from Zarr ####
        all_countries = {'forecasts': [], 'reforecasts': []}
        targets_all_countries = {'forecasts': [], 'reforecasts': []}

        forecasts_xrs = {}
        targets_xrs = {}

        for pred in ['forecasts', 'reforecasts']:
            print(f"[INFO] Loading {pred}")
            for country in self.countries:
                # Ensemble predictions (forecasts or reforecasts, depending on `pred`)
                surface_xr = xr.open_zarr(f"{self.data_path}/stations_ensemble_{pred}_surface_{country}.zarr")
                surface_pp_xr = xr.open_zarr(f"{self.data_path}/stations_ensemble_{pred}_surface_postprocessed_{country}.zarr")
                pressure_500_xr = xr.open_zarr(f"{self.data_path}/stations_ensemble_{pred}_pressure_500_{country}.zarr")
                pressure_700_xr = xr.open_zarr(f"{self.data_path}/stations_ensemble_{pred}_pressure_700_{country}.zarr")
                pressure_850_xr = xr.open_zarr(f"{self.data_path}/stations_ensemble_{pred}_pressure_850_{country}.zarr")
                # Station observations used as targets
                obs_xr = xr.open_zarr(f"{self.data_path}/stations_{pred}_observations_surface_{country}.zarr")

                forecasts = [surface_xr, surface_pp_xr, pressure_500_xr, pressure_700_xr, pressure_850_xr]
                # Drop valid_time and squeeze length-1 dims before merging the
                # stores into one dataset per country.
                forecasts = [forecast.drop_vars("valid_time").squeeze(drop=True) for forecast in forecasts]
                forecasts = xr.merge(forecasts)
                all_countries[pred].append(forecasts)

                targets = obs_xr.squeeze(drop=True)
                targets_all_countries[pred].append(targets)

            forecasts_xrs[pred] = xr.concat(all_countries[pred], dim="station_id")
            targets_xrs[pred] = xr.concat(targets_all_countries[pred], dim="station_id")

            if pred == 'reforecasts':
                # Reforecasts carry a separate 'year' dimension; fold it into 'time'.
                tmp = self._fix_time(forecasts_xrs[pred])
                forecasts_xrs[pred] = tmp.transpose('station_id', 'number',  'time', 'step')
                tmp = self._fix_time(targets_xrs[pred])
                targets_xrs[pred] = tmp.transpose('station_id', 'time', 'step')

        print(
            f"[INFO] Data loaded successfully. Forecasts shape: {forecasts_xrs['forecasts'].t2m.shape}, Reforecasts shape: {forecasts_xrs['reforecasts'].t2m.shape}"
        )

        # Extract Stations ####
        self.stations_f = self.get_stations(forecasts_xrs['forecasts'])
        self.stations_rf = self.get_stations(forecasts_xrs['reforecasts'])
        return forecasts_xrs, targets_xrs

    def validate_stations(self) -> bool:
        """
        Check that forecasts and reforecasts contain the same stations in the same order.

        Must be called after ``load_data``. Both station tables come from
        ``get_stations``, which sorts by ``station_id`` and resets the index, so
        equal-length tables can be compared element-wise.

        Returns:
            bool: True if the station IDs are identical, False otherwise
            (including when the number of stations differs).
        """
        ids_f = self.stations_f.station_id
        ids_rf = self.stations_rf.station_id
        # pandas raises ValueError when comparing differently-labeled Series,
        # so a length mismatch must be handled before the element-wise check.
        if len(ids_f) != len(ids_rf):
            return False
        return bool((ids_f == ids_rf).all())


In [3]:
# Assumes the unzipped EUPPBench station archive is in data/EUPPBench-stations
loader2 = ZarrLoader(data_path="data/EUPPBench-stations")
fc, targets = loader2.load_data(countries="all", features="all")

# The R2R / R2F splits assume forecasts and reforecasts cover the same stations
# in the same order; fail early instead of writing misaligned benchmark data.
if not loader2.validate_stations():
    raise RuntimeError("Station IDs differ between forecasts and reforecasts")

[INFO] Loading data for all countries
[INFO] Loading all features
[INFO] Loading forecasts
[INFO] Loading reforecasts
[INFO] Data loaded successfully. Forecasts shape: (122, 51, 730, 21), Reforecasts shape: (122, 11, 4180, 21)


In [4]:
# Training / test splits.
# NOTE: xarray's label-based ``sel(time=slice(start, stop))`` includes BOTH
# endpoints, so each stop label is the last day of its split. Using the first
# day of the next split as the stop would put any date initialized on that
# day (e.g. 2014-01-01) into two splits at once (train/test leakage).
train = fc["reforecasts"].sel(time=slice('1997-01-01', '2013-12-31'))  # [1997-2013]
test_rf = fc["reforecasts"].sel(time=slice('2014-01-01', '2017-12-31'))  # [2014-2017]
test_f = fc['forecasts'].sel(time=slice('2017-01-01', '2018-12-31'))  # [2017-2018]

# Targets (same time windows as the corresponding predictions)
train_targets = targets["reforecasts"].sel(time=slice('1997-01-01', '2013-12-31'))  # [1997-2013]
test_rf_targets = targets["reforecasts"].sel(time=slice('2014-01-01', '2017-12-31'))  # [2014-2017]
test_f_targets = targets['forecasts'].sel(time=slice('2017-01-01', '2018-12-31'))  # [2017-2018]

# Save to disk
train.to_zarr("shared/train.zarr", mode="w")
test_rf.to_zarr("shared/test_rf.zarr", mode="w")
test_f.to_zarr("shared/test_f.zarr", mode="w")
train_targets.to_zarr("shared/train_targets.zarr", mode="w")
test_rf_targets.to_zarr("shared/test_rf_targets.zarr", mode="w")
test_f_targets.to_zarr("shared/test_f_targets.zarr", mode="w")

<xarray.backends.zarr.ZarrStore at 0x165475340>