In [60]:
import dataclasses
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union
import xarray as xr
import numpy as np
import dacite
from scores.continuous import mae, rmse
import polars as pl
import pandas as pd
from extremeweatherbench import case, metrics, observations,utils

# DerivedVariables

In [56]:
class DerivedVariable(ABC):
    """A base class defining the interface for ExtremeWeatherBench derived variables.
    
    A DerivedVariable is any variable that requires extra computation, not derived in an
    observation or forecast raw dataset. Some examples include the practically perfect hindcast,
    MLCAPE, IVT, or atmospheric river masks.
    
    Attributes:
        name: The name of the variable.
        input_variables: A list of variables that are used to compute the variable.
    """

    def __init__(self, name: str, input_variables: List[str]):
        self.name = name
        self.input_variables = input_variables

    @abstractmethod
    def compute(self, data: xr.Dataset) -> xr.Dataset:
        """Compute the variable from the input variables."""

# Observations:

In [63]:
# type hint for the data input to the observation classes
ObservationDataInput = Union[
    xr.Dataset, xr.DataArray, pl.LazyFrame, pd.DataFrame, np.ndarray
]

class Observation(ABC):
    """
    Abstract base class for all observation types.

    An Observation is data that acts as the "truth" for a case. It can be a gridded dataset,
    a point observation dataset, or any other reference dataset. Observations in EWB
    are not required to be the same variable as the forecast dataset, but they must be in the
    same coordinate system for evaluation.
    """

    @abstractmethod
    def _open_data_from_source(
        self, source: str, storage_options: Optional[dict] = None
    ) -> ObservationDataInput:
        """
        Open the observation data from the source, opting to avoid loading the entire dataset into memory if possible.

        Args:
            source: The source of the observation data, which can be a local path or a remote URL.
            storage_options: Optional storage options for the source if the source is a remote URL.

        Returns:
            The observation data with a type determined by the user.
        """

    @abstractmethod
    def _subset_data_to_case(
        self,
        data: ObservationDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
    ) -> ObservationDataInput:
        """
        Subset the observation data to the case information provided in IndividualCase.

        Time information, spatial bounds, and variables are captured in the case metadata
        where this method is used to subset.

        Args:
            data: The observation data to subset, which should be a xarray dataset, xarray dataarray, polars lazyframe,
            pandas dataframe, or numpy array.
            variables: The variables to include in the observation. Some observations may not have variables, or
            only have a singular variable; thus, this is optional.

        Returns:
            The observation data with the variables subset to the case metadata.
        """

    @abstractmethod
    def _maybe_convert_to_dataset(self, data: ObservationDataInput) -> xr.Dataset:
        """
        Convert the observation data to an xarray dataset if it is not already.

        If this method is used prior to _subset_data_to_case, OOM errors are possible
        prior to subsetting.

        Args:
            data: The observation data already run through _subset_data_to_case.

        Returns:
            The observation data as an xarray dataset.
        """

    def _maybe_derive_variables(
        self, data: xr.Dataset, variables: list[str | DerivedVariable]
    ) -> xr.Dataset:
        """
        Derive variables from the observation data if any exist in variables.

        Args:
            data: The observation data already run through _subset_data_to_case.
            variables: The variables to derive.

        Returns:
            The observation data with the derived variables.
        """

        for v in variables:
            if isinstance(v, DerivedVariable):
                derived_variable = v.compute(data)
                data[v.name] = derived_variable
        return data

    def run_pipeline(
        self,
        source: str,
        case: case.CaseOperator,
        storage_options: Optional[dict] = None,
    ) -> xr.Dataset:
        """
        Shared method for running the observation pipeline.

        Args:
            source: The source of the observation data, which can be a local path or a remote URL.
            storage_options: Optional storage options for the source if the source is a remote URL.
            variables: The variables to include in the observation. Some observations may not have variables, or
            only have a singular variable; thus, this is optional.

        Returns:
            The observation data with a type determined by the user.
        """

        # Open data and process through pipeline steps
        data = (
            self._open_data_from_source(
                source=source,
                storage_options=storage_options,
            )
            .pipe(
                self._subset_data_to_case,
                case=case,
                variables=[v for v in case.variables if isinstance(v, str)],
            )
            .pipe(self._maybe_convert_to_dataset)
            .pipe(self._maybe_derive_variables, variables=case.variables)
        )
        return data


class ERA5(Observation):
    """
    Observation class for ERA5 gridded data.

    The easiest approach to using this class
    is to use the ARCO ERA5 dataset provided by Google for a source. Otherwise, either a
    different zarr source or modifying the _open_data_from_source method to open the data
    using another method is required.
    """

    def _open_data_from_source(
        self, source: str, storage_options: Optional[dict] = None
    ) -> ObservationDataInput:
        data = xr.open_zarr(
            source,
            chunks=None,
            storage_options=dict(token="anon"),
        )
        return data

    def _subset_data_to_case(
        self,
        data: ObservationDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
    ) -> ObservationDataInput:
        # TODO: fix case to automatically apply these; currently stand-in for now
        case.latitude_min = case.location.latitude - case.bounding_box_degrees / 2
        case.latitude_max = case.location.latitude + case.bounding_box_degrees / 2
        case.longitude_min = np.mod(
            case.location.longitude - case.bounding_box_degrees / 2, 360
        )
        case.longitude_max = np.mod(
            case.location.longitude + case.bounding_box_degrees / 2, 360
        )

        if not isinstance(data, (xr.Dataset, xr.DataArray)):
            raise ValueError(f"Expected xarray Dataset or DataArray, got {type(data)}")

        subset_data = data.sel(
            time=slice(case.start_date, case.end_date),
            # latitudes are sliced from max to min
            latitude=slice(case.latitude_max, case.latitude_min),
            longitude=slice(case.longitude_min, case.longitude_max),
        )

        # check that the variables are in the observation data
        if variables is not None and any(
            var not in subset_data.data_vars for var in variables
        ):
            raise ValueError(f"Variables {variables} not found in observation data")

        # subset the variables
        if variables is not None:
            subset_data = subset_data[variables]

        return subset_data

    def _maybe_convert_to_dataset(self, data: ObservationDataInput):
        if isinstance(data, xr.DataArray):
            data = data.to_dataset()
        return data


class GHCN(Observation):
    """
    Observation class for GHCN tabular data.

    Data is processed using polars to maintain the lazy loading
    paradigm in _open_data_from_source and to separate the subsetting
    into _subset_data_to_case.
    """

    def _open_data_from_source(
        self, source: str, storage_options: Optional[dict] = None
    ) -> ObservationDataInput:
        observation_data: pl.LazyFrame = pl.scan_parquet(
            source, storage_options=storage_options
        )

        return observation_data

    def _subset_data_to_case(
        self,
        observation_data: ObservationDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
    ) -> ObservationDataInput:
        # Create filter expressions for LazyFrame
        time_min = case.start_date - pd.Timedelta(days=2)
        time_max = case.end_date + pd.Timedelta(days=2)

        # TODO: fix case to automatically apply these; currently stand-in for now
        case.latitude_min = case.location.latitude - case.bounding_box_degrees / 2
        case.latitude_max = case.location.latitude + case.bounding_box_degrees / 2
        case.longitude_min = case.location.longitude - case.bounding_box_degrees / 2
        case.longitude_max = case.location.longitude + case.bounding_box_degrees / 2

        if not isinstance(observation_data, pl.LazyFrame):
            raise ValueError(f"Expected polars LazyFrame, got {type(observation_data)}")

        # Apply filters using proper polars expressions
        subset_observation_data = observation_data.filter(
            (pl.col("time") >= time_min)
            & (pl.col("time") <= time_max)
            & (pl.col("latitude") >= case.latitude_min)
            & (pl.col("latitude") <= case.latitude_max)
            & (pl.col("longitude") >= case.longitude_min)
            & (pl.col("longitude") <= case.longitude_max)
        )

        # Add time, latitude, and longitude to the variables, polars doesn't do indexes
        if variables is None:
            all_variables = ["time", "latitude", "longitude"]
        else:
            all_variables = variables + ["time", "latitude", "longitude"]

        # check that the variables are in the observation data
        schema_fields = [field for field in subset_observation_data.collect_schema()]
        if variables is not None and any(
            var not in schema_fields for var in all_variables
        ):
            raise ValueError(f"Variables {all_variables} not found in observation data")

        # subset the variables
        if variables is not None:
            subset_observation_data = subset_observation_data.select(all_variables)

        return subset_observation_data

    def _maybe_convert_to_dataset(self, data: ObservationDataInput):
        if isinstance(data, pl.LazyFrame):
            data = data.collect().to_pandas()
            data = data.set_index(["time", "latitude", "longitude"])
            # GHCN data can have duplicate values right now, dropping here if it occurs
            try:
                data = data.to_xarray()
            except ValueError as e:
                if "non-unique" in str(e):
                    logger.warning(
                        "ValueError when converting to xarray due to duplicate indexes"
                    )
                data = data.drop_duplicates().to_xarray()
            return data
        else:
            raise ValueError(f"Data is not a polars LazyFrame: {type(data)}")

AttributeError: module 'extremeweatherbench.case' has no attribute 'CaseOperator'

# Cases:

In [58]:
#TODO implement this in case.py
@dataclasses.dataclass
class BaseCaseMetadataCollection:
    cases: List[case.IndividualCase]

    def subset_cases_by_event_type(self, event_type: str) -> List[case.IndividualCase]:
        """Subset the cases in the collection by event type."""
        return [c for c in self.cases if c.event_type == event_type]

In [None]:
@dataclasses.dataclass
class CaseOperator:
    """A class which stores the graph to process an individual case."""
    
    case: case.IndividualCase
    metrics: list[metrics.Metric]
    observations: list[observations.Observation]
    variable_mapping: dict[str | DerivedVariable, str | DerivedVariable] = None
    
    def evaluate_case(self, forecast: xr.Dataset):
        """Process a case."""
        self.process_metrics(forecast)
        
    def process_metrics(self, forecast: xr.Dataset):
        """Process the metrics."""
        for metric in self.metrics:
            metric.process_metric(forecast, self.observations)
    

# Events:

In [53]:
class EventType(ABC):
    """A base class defining the interface for ExtremeWeatherBench event types.

    An Event in ExtremeWeatherBench defines a specific weather event type, such as a heat wave,
    severe convective weather, or atmospheric rivers. These events encapsulate a set of cases and
    derived behavior for evaluating those cases. These cases will share common metrics, observations,
    and variables while each having unique dates and locations.

    Attributes:
        case_metadata: A dictionary or yaml file with guiding metadata.
        metrics: A list of Metrics that are used to evaluate the cases.
        observations: A list of Observations that are used as targets for the metrics.
    """

    def __init__(
        self,
        event_type: str,
        case_metadata: dict[str, Any],
        metrics: List[metrics.Metric],
        observations: List[observations.Observation],
        variable_mapping: Optional[
            dict[str | DerivedVariable, str | DerivedVariable]
        ] = None,
    ):
        self.event_type = event_type
        self.case_metadata = case_metadata
        self.metrics = metrics
        self.observations = observations
        self.variable_mapping = variable_mapping

    def _build_base_case_metadata_collection(self) -> BaseCaseMetadataCollection:
        """Build a list of IndividualCases from the case_metadata."""
        cases = dacite.from_dict(
            data_class=BaseCaseMetadataCollection, data=self.case_metadata
        )
        cases = BaseCaseMetadataCollection(cases=[c for c in cases.cases if c.event_type == self.event_type])
        return cases
    
    def build_case_operator(self) -> list[CaseOperator]:
        """Build a CaseOperator from the event type."""
        case_metadata_collection = self._build_base_case_metadata_collection()
        case_operators = [
            CaseOperator(
                case = case,
                metrics = self.metrics,
                observations = self.observations,
                variable_mapping = self.variable_mapping
                ) 
                for case in case_metadata_collection.cases
                ]
        return case_operators
    

class HeatWave(EventType):
    def __init__(self, case_metadata: dict[str, Any], 
                 metrics: List[metrics.Metric], 
                 observations: List[Observation],
                 variable_mapping: Optional[dict[str | DerivedVariable, str | DerivedVariable]] = None):
        super().__init__(event_type='heat_wave', 
                         case_metadata=case_metadata, 
                         metrics=metrics, 
                         observations=observations, 
                         variable_mapping=variable_mapping)

# Metrics:

In [54]:
class BaseMetric(ABC):
    @abstractmethod
    def compute(self, forecast: xr.Dataset, observation: xr.Dataset):
        pass

class AppliedMetric(ABC):
    def __init__(
            self, 
            metric: BaseMetric, 
            observation_sources: list[Observation],
            ):
        self.metric = metric
        self.observation_sources = observation_sources

    def compute_metric(self, forecast: xr.Dataset):
        return self.metric.compute(forecast, self.observation_sources)


class MAE(BaseMetric):

    def compute(self, forecast: xr.Dataset, observation: xr.Dataset, **kwargs):
        return mae(forecast, observation, **kwargs)

class RMSE(BaseMetric):
    def compute(self, forecast: xr.Dataset, observation: xr.Dataset, **kwargs):
        return rmse(forecast, observation, **kwargs)

class MaximumMAE(AppliedMetric):

    def __init__(self, metric: BaseMetric = MAE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        maximum_timestep = observation.mean(["latitude", "longitude"]).idxmax("valid_time").values
        maximum_value = observation.mean(["latitude", "longitude"]).sel(valid_time=maximum_timestep).values
        forecast_spatial_mean = forecast.mean(["latitude", "longitude"])
        filtered_max_forecast = forecast_spatial_mean.mean(['latitude','longitude']).where(
            (forecast_spatial_mean.valid_time >= maximum_timestep - np.timedelta64(48, 'h')) & 
            (forecast_spatial_mean.valid_time <= maximum_timestep + np.timedelta64(48, 'h')),
            drop=True
        ).max('valid_time')
        return self.metric().compute(filtered_max_forecast, maximum_value)
    
class RegionalRMSE(AppliedMetric):
    def __init__(self, metric: BaseMetric = RMSE, observation: Observation = ERA5):
        super().__init__(metric, observation)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        return self.metric.compute(forecast, observation, preserve_dims='lead_time')

In [55]:
case_yaml = utils.load_events_yaml()
heat_waves = HeatWave(case_metadata=case_yaml, 
                      metrics=[MaximumMAE, RegionalRMSE], 
                      observations=[ERA5(name='era5', variable='t2m', units='K')])

heat_waves._build_case_collection()

CaseCollection(cases=[IndividualCase(case_id_number=1, title='2021 Pacific Northwest', start_date=datetime.datetime(2021, 6, 20, 0, 0), end_date=datetime.datetime(2021, 7, 3, 0, 0), location=Location(latitude=47.6062, longitude=np.float64(237.6679)), bounding_box_degrees=5, event_type='heat_wave', data_vars=None, cross_listed=None), IndividualCase(case_id_number=2, title='2022 Upper Midwest', start_date=datetime.datetime(2022, 5, 7, 0, 0), end_date=datetime.datetime(2022, 5, 17, 0, 0), location=Location(latitude=41.8781, longitude=np.float64(272.3702)), bounding_box_degrees=5, event_type='heat_wave', data_vars=None, cross_listed=None), IndividualCase(case_id_number=3, title='2022 California', start_date=datetime.datetime(2022, 6, 7, 0, 0), end_date=datetime.datetime(2022, 6, 15, 0, 0), location=Location(latitude=34.0522, longitude=np.float64(241.7563)), bounding_box_degrees=5, event_type='heat_wave', data_vars=None, cross_listed=None), IndividualCase(case_id_number=4, title='2022 Texas