In [1]:
import dataclasses
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union, Literal
import xarray as xr
import numpy as np
import dacite
from scores.continuous import mae, rmse
import polars as pl
import pandas as pd
from extremeweatherbench import case, metrics, regions, utils
import copy
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [2]:
!uv pip install sparse

[2mUsing Python 3.13.1 environment at: /Users/taylor/code/ExtremeWeatherBench/.venv[0m
[2mAudited [1m1 package[0m [2min 20ms[0m[0m


# Notes

1. An easy? solution to dealing with variable mapping:
    - Create a list of the names of each variable used in EWB that users can extend
    - Each Observation should have a default mapping; if it's our observation default we will know the variable names coming in. Users will need to know the names of theirs as well.
    - Forecasts will be able to access this mapping as well and can be provided in the ExtremeWeatherBench call
2. ~~AppliedMetrics will define exactly what observations, observation variables, and forecast variables to use~~
    - I think this is wrong now, it should be defined at the Event level.
3. Defining the order of orchestration is still an open question. I guess this will come as I wire the components together... We're running each case as its own entity, thus:
    - Each case will have multiple metrics that might or might not reuse Observations and Forecasts. Do metrics inside cases define the variables and DerivedVariables? Or should the EventType? Leaning towards EventType. DerivedVariables hold information on what core variables are needed.
        1. Observation(s) should be built first which defines the dimensions for Forecasts. **DONE**
        2. Observation(s) derived variables should be processed next. **DONE**
        3. Forecasts should then be spatiotemporally subset to the Observation(s).
        4. Forecasts derived variables should then be processed.
        5. Might be worth then temporally subsetting the Observation(s) to the Forecasts
        6. Finally, run applied metrics that use the Observation(s) and Forecasts.


# Misc:

In [3]:
def convert_longitude_to_180(longitude: float) -> float:
    """Convert longitude from 0-360 to -180-180 degrees."""
    return (longitude + 180) % 360 - 180

def extract_coordinates_from_sparse_coo(dataset: xr.Dataset, data_var: str = 'report_type') -> pd.DataFrame:
    """
    Extract latitude and longitude pairs from a sparse COO array without densifying.
    """
    # Get the sparse COO array
    sparse_array = dataset[data_var].data
    try:
        # Get the coordinates from the COO array
        coords = sparse_array.coords
    except AttributeError:
        print(f"Warning: {data_var} is not a sparse COO array.")
    
    # For dimensions (valid_time, latitude, longitude):
    # coords[0] = valid_time indices
    # coords[1] = latitude indices  
    # coords[2] = longitude indices
    
    # Extract latitude and longitude indices
    lat_indices = coords[1]  # Second dimension
    lon_indices = coords[2]  # Third dimension
    
    # Map indices to actual coordinate values
    lats = dataset.latitude.values[lat_indices]
    lons = dataset.longitude.values[lon_indices]
    
    # Create DataFrame with unique coordinate pairs
    coords_df = pd.DataFrame({
        'latitude': lats,
        'longitude': lons
    }).drop_duplicates()
    
    return coords_df

@dataclasses.dataclass
class LocationCoords:
    latitude: xr.DataArray
    longitude: xr.DataArray

def practically_perfect_hindcast(
    ds: xr.Dataset,
    output_bounds: regions.Region,
    resolution: float = 0.25,
    report_type: Union[Literal["all"], list[Literal["tor", "hail", "wind"]]] = "all",
    sigma: float = 1.5,
) -> xr.Dataset:
    """Compute the Practically Perfect Hindcast (PPH) using storm report data using latitude/longitude grid spacing
    instead of the NCEP 212 Eta Lambert Conformal projection; based on the method described in Hitchens et al 2013,
    https://doi.org/10.1175/WAF-D-12-00113.1

    Args:
        ds: An xarray Dataset containing the storm report data as a sparse (COO) array.
        output_interpolation: A LocationCoords object containing the latitude and longitude of the output interpolation.
        resolution: The resolution of the grid to use. Default is 0.25 degrees.
        report_type: The type of report to use. Default is all. Currently only supports all.
        sigma: The sigma (standard deviation) of the gaussian filter to use. Default is 1.5.
    Returns:
        pph: An xarray DataArray containing the PPH around the storm report data.
    """
    coords_df = extract_coordinates_from_sparse_coo(ds)
    coords_df['longitude'] = utils.convert_longitude_to_360(coords_df['longitude'])
    valid_lats = coords_df['latitude'].values
    valid_lons = coords_df['longitude'].values

    # Create coordinates from region, not reports, fixed at 0.25 degree intervals
    min_lat_fixed = np.ceil(output_bounds.latitude_min * 4) / 4  # Round up to nearest 0.25
    max_lat_fixed = np.floor(output_bounds.latitude_max * 4) / 4  # Round down to nearest 0.25
    min_lon_fixed = np.ceil(output_bounds.longitude_min * 4) / 4  # Round up to nearest 0.25
    max_lon_fixed = np.floor(output_bounds.longitude_max * 4) / 4  # Round down to nearest 0.25

    # Create the grid coordinates
    grid_lats = np.arange(min_lat_fixed, max_lat_fixed + resolution, resolution)
    grid_lons = np.arange(min_lon_fixed, max_lon_fixed + resolution, resolution)

    # Initialize an empty grid
    grid = np.zeros((len(grid_lats), len(grid_lons)))

    # Mark grid cells that contain reports
    for lat, lon in zip(valid_lats, valid_lons):
        # Skip reports that are outside the grid bounds
        if lat < min_lat_fixed or lat > max_lat_fixed or lon < min_lon_fixed or lon > max_lon_fixed:
            continue
        # Find the nearest grid indices
        lat_idx = np.abs(grid_lats - lat).argmin()
        lon_idx = np.abs(grid_lons - lon).argmin()
        grid[lat_idx, lon_idx] = 1

    pph_ds = xr.Dataset(
        data_vars={"reports": (["latitude", "longitude"], grid)},
        coords={"latitude": grid_lats, "longitude": grid_lons},
    )

    # Apply bilinear interpolation to smooth the field
    # First, create a gaussian kernel for smoothing
    smoothed_grid = gaussian_filter(grid, sigma=sigma)

    # Combine the data into a Dataset
    pph_ds['practically_perfect'] = xr.DataArray(smoothed_grid, dims=["latitude", "longitude"])

    return pph_ds

def plot_practically_perfect_hindcast(test_output, ax=None):
    if ax is None:
        fig = plt.figure(figsize=(12, 8))
        ax = plt.axes(projection=ccrs.PlateCarree())

    # Add map features
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.STATES, linewidth=0.5)
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)

    try:
        # Get reports data
        reports = test_output[0].reports
    except KeyError as e:
        if "reports" in str(e):
            reports = test_output.reports
        else:
            raise e

    # Create practically perfect forecast from reports using Gaussian smoothing
    reports_array = reports.values
    sigma = 1.5  # Smoothing parameter - adjust as needed
    practically_perfect = gaussian_filter(reports_array, sigma=sigma)

    # Create contour plot of practically perfect forecast
    contour_levels = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    contour = ax.contourf(reports.longitude, reports.latitude, practically_perfect, 
                         levels=contour_levels, cmap='plasma', alpha=0.7,
                         transform=ccrs.PlateCarree())

    # Add colorbar
    cbar = plt.colorbar(contour, ax=ax, label='Percentile')

    # Also plot the original report locations as small dots
    lats, lons = np.where(reports > 0)
    lat_values = reports.latitude.values[lats]
    lon_values = reports.longitude.values[lons]
    ax.scatter(lon_values, lat_values, c='black', s=10, alpha=0.5,
               transform=ccrs.PlateCarree(), marker='.')

    # Set extent based on the dataset coordinates
    lon_min, lon_max = reports.longitude.min().item(), reports.longitude.max().item()
    lat_min, lat_max = reports.latitude.min().item(), reports.latitude.max().item()
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
    gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.top_labels = False
    gl.right_labels = False
    ax.set_title('Practically Perfect Hindcast from Observation Reports', loc='left')
    return ax

# EWB Variable Names

In [4]:
EWB_VARIABLES = {
    'data_vars': 
                 [
                     'surface_air_temperature',
                     'air_pressure_at_mean_sea_level',
                     'surface_air_pressure',
                     'surface_wind_speed',
                     'surface_wind_from_direction',
                     'surface_air_temperature',
                     'surface_dew_point_temperature',
                     'surface_relative_humidity',
                     'surface_eastward_wind',
                     'surface_northward_wind',
                     'accumulated_1_hour_precipitation',
                     'pressure_level',
                     'air_temperature',
                     'dewpoint_temperature',
                     'relative_humidity',
                     'specific_humidity',
                     'geopotential',
                     'geopotential_height',
                     'potential_temperature',
                     'vertical_velocity',
                     'eastward_wind',
                     'northward_wind',
                     ],
                     'coords': [
                 'valid_time',
                 'init_time',
                 'lead_time',
                 'latitude',
                 'longitude',
                 'elevation',
                 'station_id',
                 'station_long_name',
                 'case_id'
                 ]
                 }

# Create mapping from EWB variable names to ERA5 variable names
ERA5_MAPPING = {
    'surface_air_temperature': '2m_temperature',
    'surface_dew_point_temperature': '2m_dewpoint_temperature',
    'air_pressure_at_mean_sea_level': 'mean_sea_level_pressure',
    'surface_air_pressure': 'surface_pressure',
    'accumulated_1_hour_precipitation': 'total_precipitation',
    'air_temperature': 'temperature',
    'relative_humidity': 'relative_humidity',
    'specific_humidity': 'specific_humidity',
    'geopotential': 'geopotential',
    'geopotential_height': 'geopotential_height',
    'potential_temperature': 'potential_temperature',
    'vertical_velocity': 'vertical_velocity',
    'eastward_wind': 'u_component_of_wind',
    'northward_wind': 'v_component_of_wind',
    'surface_eastward_wind': '10m_u_component_of_wind',
    'surface_northward_wind': '10m_v_component_of_wind'
}

# DerivedVariables

In [5]:
class DerivedVariable(ABC):
    """A base class defining the interface for ExtremeWeatherBench derived variables.
    
    A DerivedVariable is any variable that requires extra computation, not derived in an
    observation or forecast raw dataset. Some examples include the practically perfect hindcast,
    MLCAPE, IVT, or atmospheric river masks.
    
    Attributes:
        name: The name of the variable.
        input_variables: A list of variables that are used to compute the variable.
    """

    def __init__(self, name: str, input_variables: List[str]):
        self._name = name
        self._input_variables = input_variables

    @property
    def name(self) -> str:
        """Get the name of the variable."""
        return self._name

    @property
    def input_variables(self) -> List[str]:
        """Get the input variables for the variable."""
        return self._input_variables

    def _check_variables(self, data: xr.Dataset, variables: Optional[List[str]] = None) -> List[str]:
        """Check that the variables are in the dataset."""
        if variables is None:
            variables = list(data.data_vars)
        for variable in variables:
            if variable not in data.data_vars:
                raise ValueError(f"Variable {variable} not found in dataset.")

    @abstractmethod
    def compute(self, case: case.IndividualCase, data: xr.Dataset, variables: Optional[List[str]] = None) -> xr.Dataset:
        """Compute the variable from the input variables."""


class PracticallyPerfectHindcast(DerivedVariable):
    """A derived variable that computes the practically perfect hindcast."""

    def __init__(self):
        super().__init__(name = "practically_perfect_hindcast", input_variables = ['report_type'])

    def compute(self, single_case: case.IndividualCase, data: xr.Dataset, variables: Optional[List[str]] = None) -> xr.Dataset:
        """Compute the practically perfect hindcast."""
        pph = practically_perfect_hindcast(data[self.input_variables], output_bounds = single_case.location, report_type = ['tor', 'hail'])
        return pph
    
class CravenSignificantSevereParameter(DerivedVariable):
    """A derived variable that computes the Craven significant severe parameter."""

    def __init__(self):
        super().__init__(
            name = "craven_significant_severe_parameter", 
            input_variables = ['air_temperature', 
                               'dewpoint_temperature',
                               'relative_humidity',
                               'eastward_wind',
                               'northward_wind',
                               'surface_eastward_wind',
                               'surface_northward_wind'])

    def compute(self, case: case.IndividualCase, data: xr.Dataset, variables: Optional[List[str]] = None) -> xr.Dataset:
        """Compute the Craven significant severe parameter."""
        cbss_ds = calc.craven_brooks_significant_severe(
            case, data[self.input_variables])
        return data
    

# Observations:

In [81]:
#: Storage/access options for gridded observation datasets.
ARCO_ERA5_FULL_URI = (
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
)

#: Storage/access options for default point observation dataset.
DEFAULT_GHCN_URI = "gs://extremeweatherbench/datasets/ghcnh.parq"

#: Storage/access options for local storm report (LSR) tabular data.
LSR_URI = "gs://extremeweatherbench/datasets/lsr_01012020_04302025.parq"

IBTRACS_URI = "https://www.ncei.noaa.gov/data/international-best-track-archive-for-climate-stewardship-ibtracs/v04r01/access/csv/ibtracs.ALL.list.v04r01.csv"  # noqa: E501

# type hint for the data input to the observation classes
ObservationDataInput = Union[
    xr.Dataset, xr.DataArray, pl.LazyFrame, pd.DataFrame, np.ndarray
]


class Observation(ABC):
    """
    Abstract base class for all observation types.

    An Observation is data that acts as the "truth" for a case. It can be a gridded dataset,
    a point observation dataset, or any other reference dataset. Observations in EWB
    are not required to be the same variable as the forecast dataset, but they must be in the
    same coordinate system for evaluation.
    """

    source: str

    @abstractmethod
    def _open_data_from_source(
        self, storage_options: Optional[dict] = None, **kwargs
    ) -> ObservationDataInput:
        """
        Open the observation data from the source, opting to avoid loading the entire dataset into memory if possible.

        Args:
            source: The source of the observation data, which can be a local path or a remote URL.
            storage_options: Optional storage options for the source if the source is a remote URL.

        Returns:
            The observation data with a type determined by the user.
        """

    @abstractmethod
    def _subset_data_to_case(
        self,
        data: ObservationDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
        **kwargs,
    ) -> ObservationDataInput:
        """
        Subset the observation data to the case information provided in IndividualCase.

        Time information, spatial bounds, and variables are captured in the case metadata
        where this method is used to subset.

        Args:
            data: The observation data to subset, which should be a xarray dataset, xarray dataarray, polars lazyframe,
            pandas dataframe, or numpy array.
            variables: The variables to include in the observation. Some observations may not have variables, or
            only have a singular variable; thus, this is optional.

        Returns:
            The observation data with the variables subset to the case metadata.
        """

    @abstractmethod
    def _maybe_convert_to_dataset(self, data: ObservationDataInput, **kwargs) -> xr.Dataset:
        """
        Convert the observation data to an xarray dataset if it is not already.

        If this method is used prior to _subset_data_to_case, OOM errors are possible
        prior to subsetting.

        Args:
            data: The observation data already run through _subset_data_to_case.

        Returns:
            The observation data as an xarray dataset.
        """

    @abstractmethod
    def _maybe_map_variable_names(self, data: ObservationDataInput, variable_mapping: Optional[dict] = None, **kwargs) -> ObservationDataInput:
        """
        Map the variable names to the observation data, if required.
        """
    
    def _maybe_derive_variables(
        self, data: xr.Dataset, case: case.IndividualCase, variables: list[str | DerivedVariable], **kwargs
    ) -> xr.Dataset:
        """
        Derive variables from the observation data if any exist in variables.

        Args:
            data: The observation data already run through _subset_data_to_case.
            variables: The variables to derive.

        Returns:
            The observation data with the derived variables.
        """
        for v in variables:
            # there should only be strings or derived variables in the list
            if not isinstance(v, str):
                if not issubclass(v, DerivedVariable):
                    raise ValueError(f"Expected str or DerivedVariable, got {type(v)}")
                derived_data = v().compute(data=data, single_case=case, variables=variables)
                return derived_data
        return data

    def run_pipeline(
        self,
        case: case.IndividualCase,
        storage_options: Optional[dict] = None,
        variables: Optional[list[str | DerivedVariable]] = None,
        variable_mapping: dict = {},
        **kwargs,
    ) -> xr.Dataset:
        """
        Shared method for running the observation pipeline.

        Args:
            source: The source of the observation data, which can be a local path or a remote URL.
            storage_options: Optional storage options for the source if the source is a remote URL.
            variables: The variables to include in the observation. Some observations may not have variables, or
            only have a singular variable; thus, this is optional.
            variable_mapping: A dictionary of variable names to map to the observation data.
            **kwargs: Additional keyword arguments to pass in as needed.

        Returns:
            The observation data with a type determined by the user.
        """

        
        # Open data and process through pipeline steps
        data = (
            self._open_data_from_source(
                storage_options=storage_options,
                **kwargs,
            )
            .pipe(
                self._maybe_map_variable_names, 
                variable_mapping=variable_mapping,
                **kwargs,
            )
            .pipe(
                self._subset_data_to_case,
                case=case,
                variables=variables,
                **kwargs,
            )
            .pipe(self._maybe_convert_to_dataset, **kwargs)
            .pipe(
                self._maybe_derive_variables,  
                case=case, 
                variables=variables or [], 
                **kwargs,
                )
        )
        return data


class ERA5(Observation):
    """
    Observation class for ERA5 gridded data.

    The easiest approach to using this class
    is to use the ARCO ERA5 dataset provided by Google for a source. Otherwise, either a
    different zarr source or modifying the _open_data_from_source method to open the data
    using another method is required.
    """

    source: str = ARCO_ERA5_FULL_URI

    def _open_data_from_source(
        self, 
        storage_options: Optional[dict] = None, 
        chunks: dict = {'time': 48, 'latitude': 721, 'longitude': 1440},
        **kwargs,
    ) -> ObservationDataInput:
        data = xr.open_zarr(
            self.source,
            storage_options=storage_options,
            chunks=None,
        )
        return data

    def _subset_data_to_case(
        self,
        data: ObservationDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
        **kwargs,
    ) -> ObservationDataInput:

        if not isinstance(data, (xr.Dataset, xr.DataArray)):
            raise ValueError(f"Expected xarray Dataset or DataArray, got {type(data)}")

        # subset time first to avoid OOM masking issues
        subset_time_data = data.sel(time=slice(case.start_date, case.end_date))

        # check that the variables are in the observation data
        if variables is not None and any(
            var not in subset_time_data.data_vars for var in variables
        ):
            raise ValueError(f"Variables {variables} not found in observation data")
        # subset the variables
        elif variables is not None:
            subset_time_variable_data = subset_time_data[variables]
        else:
            raise ValueError("Variables not defined for ERA5. Please list at least one variable to select.")
        # # calling chunk here to avoid loading subset_data into memory
        chunks = kwargs.get('chunks', {'time': 48, 'latitude': 721, 'longitude': 1440})
        subset_time_variable_data = subset_time_variable_data.chunk(chunks) 
        # mask the data to the case location
        fully_subset_data = case.location.mask(subset_time_variable_data, drop=True)

        return fully_subset_data

    def _maybe_convert_to_dataset(self, data: ObservationDataInput, **kwargs):
        if isinstance(data, xr.DataArray):
            data = data.to_dataset()
        return data

    def _maybe_map_variable_names(self, data: ObservationDataInput, variable_mapping: Optional[dict] = None, **kwargs) -> ObservationDataInput:
        """
        Map the variable names to the observation data, if required.
        """
        if variable_mapping is None:
            return data
        # Filter the mapping to only include variables that exist in the dataset
        filtered_mapping = {v: k for k, v in variable_mapping.items() if v in data.data_vars}
        if filtered_mapping:
            data = data.rename(filtered_mapping)
        return data


class GHCN(Observation):
    """
    Observation class for GHCN tabular data.

    Data is processed using polars to maintain the lazy loading
    paradigm in _open_data_from_source and to separate the subsetting
    into _subset_data_to_case.
    """

    source: str = DEFAULT_GHCN_URI

    def _open_data_from_source(
        self, storage_options: Optional[dict] = None, **kwargs
    ) -> ObservationDataInput:
        observation_data: pl.LazyFrame = pl.scan_parquet(
            self.source, storage_options=storage_options
        )

        return observation_data

    def _subset_data_to_case(
        self,
        observation_data: ObservationDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
        **kwargs,
    ) -> ObservationDataInput:
        # Create filter expressions for LazyFrame
        time_min = case.start_date - pd.Timedelta(days=2)
        time_max = case.end_date + pd.Timedelta(days=2)

        if not isinstance(observation_data, pl.LazyFrame):
            raise ValueError(f"Expected polars LazyFrame, got {type(observation_data)}")

        # Apply filters using proper polars expressions
        subset_observation_data = observation_data.filter(
            (pl.col("time") >= time_min)
            & (pl.col("time") <= time_max)
            & (pl.col("latitude") >= case.location.latitude_min)
            & (pl.col("latitude") <= case.location.latitude_max)
            & (pl.col("longitude") >= case.location.longitude_min)
            & (pl.col("longitude") <= case.location.longitude_max)
        )

        # Add time, latitude, and longitude to the variables, polars doesn't do indexes
        if variables is None:
            all_variables = ["time", "latitude", "longitude"]
        else:
            all_variables = variables + ["time", "latitude", "longitude"]

        # check that the variables are in the observation data
        schema_fields = [field for field in subset_observation_data.collect_schema()]
        if variables is not None and any(
            var not in schema_fields for var in all_variables
        ):
            raise ValueError(f"Variables {all_variables} not found in observation data")

        # subset the variables
        if variables is not None:
            subset_observation_data = subset_observation_data.select(all_variables)

        return subset_observation_data

    def _maybe_convert_to_dataset(self, data: ObservationDataInput, **kwargs):
        if isinstance(data, pl.LazyFrame):
            data = data.collect().to_pandas()
            data = data.set_index(["time", "latitude", "longitude"])
            # GHCN data can have duplicate values right now, dropping here if it occurs
            try:
                data = data.to_xarray()
            except ValueError as e:
                if "non-unique" in str(e):
                    pass
                data = data.drop_duplicates().to_xarray()
            return data
        else:
            raise ValueError(f"Data is not a polars LazyFrame: {type(data)}")
        
    def _maybe_map_variable_names(self, data: ObservationDataInput, variable_mapping: dict, **kwargs) -> ObservationDataInput:
        """
        Map the variable names to the observation data, if required.
        """
        # Filter the mapping to only include variables that exist in the dataset
        filtered_mapping = {v: k for k, v in variable_mapping.items() if v in data.columns}
        if filtered_mapping:
            data = data.rename(filtered_mapping)
        return data

class LSR(Observation):
    """
    Observation class for local storm report (LSR) tabular data.

    run_pipeline() returns a dataset with LSRs and practically perfect hindcast gridded
    probability data. IndividualCase date ranges for LSRs should ideally be
    12 UTC to the next day at 12 UTC to match SPC methods.
    """

    source: str = LSR_URI

    def _open_data_from_source(
        self, storage_options: Optional[dict] = None, **kwargs
    ) -> ObservationDataInput:
        
        # force LSR to use anon token to prevent google reauth issues for users
        observation_data = pd.read_parquet(self.source, storage_options={'token': 'anon'})

        return observation_data

    def _subset_data_to_case(
        self,
        observation_data: ObservationDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
        **kwargs,
    ) -> ObservationDataInput:
        if not isinstance(observation_data, pd.DataFrame):
            raise ValueError(f"Expected pandas DataFrame, got {type(observation_data)}")

        # latitude, longitude are strings by default, convert to float
        observation_data["lat"] = observation_data["lat"].astype(float)
        observation_data["lon"] = observation_data["lon"].astype(float)
        observation_data["time"] = pd.to_datetime(observation_data["time"])

        filters = (
            (observation_data["time"] >= case.start_date)
            & (observation_data["time"] <= case.end_date)
            & (observation_data["lat"] >= case.location.latitude_min)
            & (observation_data["lat"] <= case.location.latitude_max)
            & (observation_data["lon"] >= convert_longitude_to_180(case.location.longitude_min))
            & (observation_data["lon"] <= convert_longitude_to_180(case.location.longitude_max))
        )

        subset_observation_data = observation_data.loc[filters]

        subset_observation_data = subset_observation_data.rename(
            columns={"lat": "latitude", "lon": "longitude", "time": "valid_time"}
        )

        return subset_observation_data

    def _maybe_convert_to_dataset(self, data: ObservationDataInput, **kwargs):
        if isinstance(data, pd.DataFrame):
            data = data.set_index(["valid_time", "latitude", "longitude"])
            data = xr.Dataset.from_dataframe(data[~data.index.duplicated(keep='first')],sparse=True)
            return data
        else:
            raise ValueError(f"Data is not a pandas DataFrame: {type(data)}")

    def _maybe_map_variable_names(self, data: ObservationDataInput, variable_mapping: dict, **kwargs) -> ObservationDataInput:
        """
        Map the variable names to the observation data, if required.
        """
        # Filter the mapping to only include variables that exist in the dataset
        filtered_mapping = {v: k for k, v in variable_mapping.items() if v in data.columns}
        if filtered_mapping:
            data = data.rename(filtered_mapping)
        return data

class IBTrACS(Observation):
    """
    Observation class for IBTrACS data.
    """

    source: str = IBTRACS_URI

    def _open_data_from_source(
        self, storage_options: Optional[dict] = None, **kwargs
    ) -> ObservationDataInput:
        # not using storage_options in this case due to NetCDF4Backend not supporting them
        observation_data: pl.LazyFrame = pl.scan_csv(
            self.source, storage_options=storage_options
        )
        return observation_data

    def _subset_data_to_case(
        self,
        observation_data: ObservationDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
        **kwargs,
    ) -> ObservationDataInput:
        # Create filter expressions for LazyFrame
        year = case.start_date.year

        if not isinstance(observation_data, pl.LazyFrame):
            raise ValueError(f"Expected polars LazyFrame, got {type(observation_data)}")

        # Apply filters using proper polars expressions
        subset_observation_data = observation_data.filter(
            (pl.col("NAME") == case.title.upper())
        )

        all_variables = [
            "SEASON",
            "NUMBER",
            "NAME",
            "ISO_TIME",
            "LAT",
            "LON",
            "WMO_WIND",
            "USA_WIND",
            "WMO_PRES",
            "USA_PRES",
        ]
        # Get the season (year) from the case start date, cast as string as polars is interpreting the schema as strings
        season = str(year)

        # First filter by name to get the storm data
        subset_observation_data = observation_data.filter(
            (pl.col("NAME") == case.title.upper())
        )

        # Create a subquery to find all storm numbers in the same season
        matching_numbers = (
            subset_observation_data.filter(pl.col("SEASON") == season)
            .select("NUMBER")
            .unique()
        )

        # Apply the filter to get all data for storms with the same number in the same season
        # This maintains the lazy evaluation
        subset_observation_data = observation_data.join(
            matching_numbers, on="NUMBER", how="inner"
        ).filter((pl.col("NAME") == case.title.upper()) & (pl.col("SEASON") == season))

        # check that the variables are in the observation data
        schema_fields = [field for field in subset_observation_data.collect_schema()]
        if variables is not None and any(
            var not in schema_fields for var in all_variables
        ):
            raise ValueError(f"Variables {all_variables} not found in observation data")

        # subset the variables
        if variables is not None:
            subset_observation_data = subset_observation_data.select(all_variables)

        return subset_observation_data

    def _maybe_convert_to_dataset(self, data: ObservationDataInput, **kwargs):
        if isinstance(data, pl.LazyFrame):
            data = data.collect().to_pandas()
            data = data.set_index(["ISO_TIME"])
            try:
                data = data.to_xarray()
            except ValueError as e:
                if "non-unique" in str(e):
                    pass
                data = data.drop_duplicates().to_xarray()
            return data
        else:
            raise ValueError(f"Data is not a polars LazyFrame: {type(data)}")
        
    def _maybe_map_variable_names(self, data: ObservationDataInput, variable_mapping: dict, **kwargs) -> ObservationDataInput:
        """
        Map the variable names to the observation data, if required.
        """
        # Filter the mapping to only include variables that exist in the dataset
        filtered_mapping = {v: k for k, v in variable_mapping.items() if v in data.columns}
        if filtered_mapping:
            data = data.rename(filtered_mapping)
        return data


# Cases:

In [82]:
#TODO implement this in case.py
@dataclasses.dataclass
class BaseCaseMetadataCollection:
    cases: List[case.IndividualCase]

    def subset_cases_by_event_type(self, event_type: str) -> List[case.IndividualCase]:
        """Subset the cases in the collection by event type."""
        return [c for c in self.cases if c.event_type == event_type]

In [83]:
@dataclasses.dataclass
class CaseOperator:
    """A class which stores the graph to process an individual case."""
    
    case: case.IndividualCase
    metrics: list[metrics.Metric]
    observations: list[Observation]
    
    def evaluate_case(self, forecast: xr.Dataset):
        """Process a case."""
        self.process_metrics(forecast)
        
    def process_metrics(self, forecast: xr.Dataset):
        """Process the metrics."""
        for metric in self.metrics:
            metric.process_metric(forecast, self.observations)

    def build_observations(self) -> xr.Dataset:
        """Build observation xarray Datasets from the observation sources."""
        observation_datasets = []
        for observation in self.observation_sources:
            obs_dataset = observation.run_pipeline()
            observation_datasets.append(obs_dataset)
        
        # Combine all observation datasets into a single dataset
        if len(observation_datasets) >= 1:
            combined_obs = xr.merge(observation_datasets)
            return combined_obs
        else:
            raise ValueError("No observations provided or observations failed to run, check the observation sources.")

    

# Metrics:

It seems logical to split up simple metrics, e.g. MAE, and more complicated split-apply-combine or other methods that requires multiple steps to prepare for said simple metric. These more complicated metrics are "AppliedMetrics", and can optionally include the simple metrics. Some metrics part of EWB don't have a simple metric downstream, such as categorical thresholds and contingency table metrics. This is a bit of a WIP so will update as the orchestration becomes more clear.

In [84]:
class BaseMetric(ABC):
    @abstractmethod
    def compute(self, forecast: xr.Dataset, observation: xr.Dataset):
        pass

class AppliedMetric(ABC):
    def __init__(
            self, 
            metric: BaseMetric, 
            observation_sources: list[Observation],
            variables: list[str | DerivedVariable],
            
            ):
        self.metric = metric
        self.observation_sources = observation_sources
        self.variables = variables

    def compute_metric(self, forecast: xr.Dataset):
        return self.metric.compute(forecast, self.observation_sources)


class MAE(BaseMetric):

    def compute(self, forecast: xr.Dataset, observation: xr.Dataset, **kwargs):
        return mae(forecast, observation, **kwargs)

class RMSE(BaseMetric):
    def compute(self, forecast: xr.Dataset, observation: xr.Dataset, **kwargs):
        return rmse(forecast, observation, **kwargs)

class MaximumMAE(AppliedMetric):

    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        maximum_timestep = observation.mean(["latitude", "longitude"]).idxmax("valid_time").values
        maximum_value = observation.mean(["latitude", "longitude"]).sel(valid_time=maximum_timestep).values
        forecast_spatial_mean = forecast.mean(["latitude", "longitude"])
        filtered_max_forecast = forecast_spatial_mean.mean(['latitude','longitude']).where(
            (forecast_spatial_mean.valid_time >= maximum_timestep - np.timedelta64(48, 'h')) & 
            (forecast_spatial_mean.valid_time <= maximum_timestep + np.timedelta64(48, 'h')),
            drop=True
        ).max('valid_time')
        return self.metric().compute(filtered_max_forecast, maximum_value)
    
class RegionalRMSE(AppliedMetric):
    def __init__(self, metric: BaseMetric = RMSE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        return self.metric.compute(forecast, observation, preserve_dims='lead_time')
    
# Dummy metric classes for different event types

class MaxMinMAE(AppliedMetric):
    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Dummy implementation for finding both max and min values
        return self.metric().compute(forecast, observation)

class OnsetME(AppliedMetric):
    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Dummy implementation for onset mean error
        return self.metric().compute(forecast, observation)

class DurationME(AppliedMetric):
    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Dummy implementation for duration mean error
        return self.metric().compute(forecast, observation)

class CSI(AppliedMetric):
    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Dummy implementation for Critical Success Index
        return self.metric().compute(forecast, observation)

class LeadTimeDetection(AppliedMetric):
    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Dummy implementation for lead time detection
        return self.metric().compute(forecast, observation)

class RegionalHitsMisses(AppliedMetric):
    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Dummy implementation for regional hits and misses
        return self.metric().compute(forecast, observation)

class HitsMisses(AppliedMetric):
    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Dummy implementation for hits and misses
        return self.metric().compute(forecast, observation)


# Events:

In [85]:
def maybe_expand_variable_lists(variable_list: List[str | DerivedVariable]) -> List[str]:
    """Build a list of core variables for the event, given the forecast and observation variables."""

    def iterator(variables: List[str | DerivedVariable]) -> List[str]:
        for variable in variables:
            if isinstance(variable, str):
                pass
            elif issubclass(variable, DerivedVariable):
                variables.extend([n for n in variable().input_variables])
        return variables

    return iterator(variable_list)

class EventType(ABC):
    """A base class defining the interface for ExtremeWeatherBench event types.

    An Event in ExtremeWeatherBench defines a specific weather event type, such as a heat wave,
    severe convective weather, or atmospheric rivers. These events encapsulate a set of cases and
    derived behavior for evaluating those cases. These cases will share common metrics, observations,
    and variables while each having unique dates and locations.

    Attributes:
        event_type: The type of event.
        forecast_variables: A list of variables that are used to forecast the event.
        observation_variables: A list of variables that are used to observe the event.
        case_metadata: A dictionary or yaml file with guiding metadata.
        metrics: A list of Metrics that are used to evaluate the cases.
        observations: A list of Observations that are used as targets for the metrics.
    """

    def __init__(
        self,
        event_type: str,
        forecast_variables: List[str | DerivedVariable],
        observation_variables: List[str | DerivedVariable],
        case_metadata: dict[str, Any],
        metrics: List[metrics.Metric],
        observations: List[Observation],
    ):
        self.event_type = event_type
        self.forecast_variables = maybe_expand_variable_lists(forecast_variables)
        self.observation_variables = maybe_expand_variable_lists(observation_variables)
        self.case_metadata = case_metadata
        self.metrics = metrics
        self.observations = observations

    def _build_base_case_metadata_collection(self) -> BaseCaseMetadataCollection:
        """Build a list of IndividualCases from the case_metadata."""
        cases = dacite.from_dict(
            data_class=BaseCaseMetadataCollection, 
            data=self.case_metadata, 
            config=dacite.Config(
                    type_hooks={regions.Region: regions.map_to_create_region},
                ),
        )
        cases = BaseCaseMetadataCollection(cases=[c for c in cases.cases if c.event_type == self.event_type])
        return cases
    
    def build_case_operator(self) -> list[CaseOperator]:
        """Build a CaseOperator from the event type."""
        case_metadata_collection = self._build_base_case_metadata_collection()
        case_operators = [
            CaseOperator(
                case = case,
                metrics = self.metrics,
                observations = self.observations,
                ) 
                for case in case_metadata_collection.cases
                ]
        return case_operators

class HeatWave(EventType):
    def __init__(self, case_metadata: dict[str, Any], 
                 forecast_variables: List[str | DerivedVariable] = ['surface_air_temperature'],
                 observation_variables: List[str | DerivedVariable] = ['surface_air_temperature'],
                 metrics: List[metrics.Metric] = [MaximumMAE,
                                                  MaxMinMAE, 
                                                  RegionalRMSE,
                                                  OnsetME,
                                                  DurationME], 
                 observations: List[Observation] = [ERA5]
                 ):
        super().__init__(event_type='heat_wave', 
                         forecast_variables=forecast_variables,
                         observation_variables=observation_variables,
                         case_metadata=case_metadata, 
                         metrics=metrics, 
                         observations=observations)

class SevereConvection(EventType):
    def __init__(self, case_metadata: dict[str, Any], 
                 forecast_variables: List[str | DerivedVariable] = [
                     CravenSignificantSevereParameter,
                 ],
                 observation_variables: List[str | DerivedVariable] = [
                     PracticallyPerfectHindcast,
                 ],
                 metrics: List[metrics.Metric] = [CSI, 
                                                  LeadTimeDetection, 
                                                  RegionalHitsMisses, 
                                                  HitsMisses
                                                  ], 
                 observations: List[Observation] = [LSR]
                 ):
        super().__init__(event_type='severe_convection',
                         forecast_variables=forecast_variables,
                         observation_variables=observation_variables,
                         case_metadata=case_metadata, 
                         metrics=metrics, 
                         observations=observations)

class AtmosphericRiver(EventType):
    def __init__(self, case_metadata: dict[str, Any], 
                 forecast_variables: List[str | DerivedVariable] = [],
                 observation_variables: List[str | DerivedVariable] = [],
                 metrics: List[metrics.Metric] = [CSI, 
                                                  LeadTimeDetection,
                                                  ], 
                 observations: List[Observation] = [ERA5]
                 ):
        super().__init__(event_type='atmospheric_river', 
                         forecast_variables=forecast_variables,
                         observation_variables=observation_variables,
                         case_metadata=case_metadata, 
                         metrics=metrics, 
                         observations=observations)

@dataclasses.dataclass    
class EventOperator:
    events: List[EventType]
    pre_composed_metrics: dict[EventType, List[metrics.Metric]] = dataclasses.field(default_factory=list, init=False, repr=True)
    pre_composed_observations: dict[EventType, List[Observation]] = dataclasses.field(default_factory=list, init=False, repr=True)
    pre_composed_forecast_variables: dict[EventType, List[str | DerivedVariable]] = dataclasses.field(default_factory=list, init=False, repr=True)
    pre_composed_observation_variables: dict[EventType, List[str | DerivedVariable]] = dataclasses.field(default_factory=list, init=False, repr=True)
    pre_composed_case_operators: List[CaseOperator] = dataclasses.field(default_factory=list, init=False, repr=True)
    
    def __post_init__(self):
        # Unravel attributes from composed event types
        self.pre_composed_metrics = {}
        self.pre_composed_observations = {}
        self.pre_composed_forecast_variables = {}
        self.pre_composed_observation_variables = {}
        self.pre_composed_case_operators = []
        
        # Collect attributes from each event type
        for event in self.events:
            self.pre_composed_metrics[event.event_type] = event.metrics
            self.pre_composed_observations[event.event_type] = event.observations
            self.pre_composed_forecast_variables[event.event_type] = event.forecast_variables
            self.pre_composed_observation_variables[event.event_type] = event.observation_variables
            self.pre_composed_case_operators.extend(event.build_case_operator())



# Orchestration

In [92]:
def _process_observations(case_operator: CaseOperator, variables: dict[str, list[str | DerivedVariable]], copy_data = True, **kwargs):
    event_type = case_operator.case.event_type
    observation_data_list = []
    if copy_data:
        modified_case_operator = copy.deepcopy(case_operator)
    else:
        modified_case_operator = case_operator
    for observation in modified_case_operator.observations:
        observation_data = observation().run_pipeline(case=modified_case_operator.case, variables=variables[event_type], **kwargs)
        observation_data_list.append(observation_data)
    return observation_data_list

class ExtremeWeatherBench:
    def __init__(self, event_operator: EventOperator, forecast_dir: str):
        self.event_operator = copy.deepcopy(event_operator)
        self.forecast_dir = forecast_dir

    def run(self, **kwargs):
        '''Runs the workflow'''
        
        for case_operator in self.event_operator.pre_composed_case_operators:
            print(self.process_case(case_operator, **kwargs))
        pass

    def process_case(self, case_operator: CaseOperator, **kwargs):
        observation_ds = self.process_observations(case_operator, **kwargs)
        return observation_ds

    def process_observations(self, case_operator: CaseOperator, **kwargs):
        pre_derived_observation_ds = _process_observations(case_operator, variables=self.event_operator.pre_composed_observation_variables, **kwargs)
        return pre_derived_observation_ds



## Simple Event Operator:

In [93]:
case_yaml = utils.read_event_yaml('/Users/taylor/code/ExtremeWeatherBench/src/extremeweatherbench/data/events.yaml')
forecast_dir = 'gs://extremeweatherbench/virtualizarr/fcn_v3.parq'

In [94]:
heat_waves = HeatWave(case_metadata=case_yaml)
severe = SevereConvection(case_metadata=case_yaml)


In [95]:
# With the EventOperator class
simple_event_operator = EventOperator(events=[heat_waves, severe])
ewb = ExtremeWeatherBench(simple_event_operator, forecast_dir)

In [96]:
ewb.event_operator

EventOperator(events=[<__main__.HeatWave object at 0x1198d6650>, <__main__.SevereConvection object at 0x1198d69e0>], pre_composed_metrics={'heat_wave': [<class '__main__.MaximumMAE'>, <class '__main__.MaxMinMAE'>, <class '__main__.RegionalRMSE'>, <class '__main__.OnsetME'>, <class '__main__.DurationME'>], 'severe_convection': [<class '__main__.CSI'>, <class '__main__.LeadTimeDetection'>, <class '__main__.RegionalHitsMisses'>, <class '__main__.HitsMisses'>]}, pre_composed_observations={'heat_wave': [<class '__main__.ERA5'>], 'severe_convection': [<class '__main__.LSR'>]}, pre_composed_forecast_variables={'heat_wave': ['surface_air_temperature'], 'severe_convection': [<class '__main__.CravenSignificantSevereParameter'>, 'air_temperature', 'dewpoint_temperature', 'relative_humidity', 'eastward_wind', 'northward_wind', 'surface_eastward_wind', 'surface_northward_wind', 'air_temperature', 'dewpoint_temperature', 'relative_humidity', 'eastward_wind', 'northward_wind', 'surface_eastward_wind'

In [97]:
# args here will apply to all observations and forecasts for all events;
# these should also be able to be passed to each event type individually
ewb.run(storage_options=dict(token="anon"), chunks={'time': 48, 'latitude': 721, 'longitude': 1440}, variable_mapping=ERA5_MAPPING)    

[<xarray.Dataset> Size: 503kB
Dimensions:                  (time: 313, latitude: 20, longitude: 20)
Coordinates:
  * latitude                 (latitude) float32 80B 50.0 49.75 ... 45.5 45.25
  * longitude                (longitude) float32 80B 235.2 235.5 ... 239.8 240.0
  * time                     (time) datetime64[ns] 3kB 2021-06-20 ... 2021-07-03
Data variables:
    surface_air_temperature  (time, latitude, longitude) float32 501kB dask.array<chunksize=(48, 20, 20), meta=np.ndarray>
Attributes:
    last_updated:           2025-07-25 01:52:37.403950+00:00
    valid_time_start:       1940-01-01
    valid_time_stop:        2025-04-30
    valid_time_stop_era5t:  2025-07-19]
[<xarray.Dataset> Size: 388kB
Dimensions:                  (time: 241, latitude: 20, longitude: 20)
Coordinates:
  * latitude                 (latitude) float32 80B 44.25 44.0 ... 39.75 39.5
  * longitude                (longitude) float32 80B 270.0 270.2 ... 274.5 274.8
  * time                     (time) datetime6

# Tests:

In [25]:
severe = SevereConvection(case_metadata=case_yaml)

# With the EventOperator class
simple_event_operator = EventOperator(events=[severe])
ewb = ExtremeWeatherBench(simple_event_operator, forecast_dir)

In [None]:
for n in ewb.event_operator.pre_composed_case_operators:
    test_output = ewb.process_observations(n)
    ax = plot_practically_perfect_hindcast(test_output)
    plt.show()

Exception ignored in: <function WeakSet.__init__.<locals>._remove at 0x100bd85e0>
Traceback (most recent call last):
  File "/Users/taylor/.local/share/uv/python/cpython-3.13.1-macos-aarch64-none/lib/python3.13/_weakrefset.py", line 39, in _remove
    def _remove(item, selfref=ref(self)):
KeyboardInterrupt: 
