In [42]:
!uv pip install virtualizarr[all_parsers, all_writers] sparse polars

1504.09s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


[2mUsing Python 3.13.1 environment at: /Users/taylor/code/ExtremeWeatherBench/.venv[0m
[2mAudited [1m3 packages[0m [2min 42ms[0m[0m


In [43]:
import dataclasses
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union, Callable
import xarray as xr
import numpy as np
import dacite
from scores.continuous import mae, rmse, mean_error
import scores.categorical as cat
import polars as pl
import pandas as pd
from extremeweatherbench import regions, case, utils
import copy
import matplotlib.pyplot as plt

from obstore.store import from_url
from virtualizarr.registry import ObjectStoreRegistry


import refactor_scripts as rs

# Notes

1. An easy? solution to dealing with variable mapping:
    - Create a list of the names of each variable used in EWB that users can extend
    - Each Observation should have a default mapping; for our default observations we will know the incoming variable names. Users will need to know the variable names of their own observations as well.
    - Forecasts will be able to access this mapping as well and can be provided in the ExtremeWeatherBench call
2. ~~AppliedMetrics will define exactly what observations, observation variables, and forecast variables to use~~
    - I think this is wrong now, it should be defined at the Event level.
3. Defining the order of orchestration is still an open question. I expect it to become clearer as I wire the components together. We're running each case as its own entity, thus:
    - Each case will have multiple metrics that might or might not reuse Observations and Forecasts. Do metrics inside cases define the variables and DerivedVariables? Or should the EventType? Leaning towards EventType. DerivedVariables hold information on what core variables are needed.
        1. Observation(s) should be built first, since they define the dimensions for Forecasts. **DONE**
        2. Observation(s) derived variables should be processed next. **DONE**
        3. Forecasts should then be spatiotemporally subset to the Observation(s). **DONE**
        4. Forecasts derived variables should then be processed. **DONE**
        5. Might be worth then temporally subsetting the Observation(s) to the Forecasts
        6. Finally, run applied metrics that use the Observation(s) and Forecasts.


# DerivedVariables

In [80]:
class DerivedVariable(ABC):
    """An abstract base class defining the interface for ExtremeWeatherBench derived variables.

    A DerivedVariable is any variable that requires more computation than what is
    provided in analysis or forecast data. Some examples include the practically perfect
    hindcast, MLCAPE, IVT, or atmospheric river masks.

    Attributes:
        name: The name used when the derived variable is added to a dataset. Defaults to
            the class name.
        input_variables: A list of variables that are used to build the variable.
        build: A method that builds the variable from the input variables. Build is kept
            distinct from compute: build is intended to return a lazy result, while
            compute calls ``.compute()`` on that result to load it into memory.
        compute: A method that builds the variable and eagerly computes it, returning the
            final product.
    """

    @property
    def name(self) -> str:
        """A name for the derived variable. Defaults to the class name."""
        return self.__class__.__name__

    @property
    @abstractmethod
    def input_variables(self) -> List[str]:
        """A list of variables that are used to compute the variable.

        Each derived variable is a product of one or more variables in an incoming dataset.
        The input variables should be the names of the variables in the incoming dataset,
        not the final product.
        """

    @abstractmethod
    def build(
        self,
        case: case.IndividualCase,
        data: xr.Dataset,
        variables: Optional[List[str]] = None,
    ) -> xr.DataArray:
        """Build the variable from the input variables.

        The parameter is named ``variables`` to match ``compute`` and the way
        ``maybe_derive_variables`` calls ``build(case=..., data=..., variables=...)``.

        Args:
            case: The case the variable is being derived for.
            data: The dataset containing the input variables.
            variables: Optional list of non-derived variable names available in ``data``.

        Returns:
            The derived variable as a DataArray.
        """

    def compute(
        self,
        case: case.IndividualCase,
        data: xr.Dataset,
        variables: Optional[List[str]] = None,
    ) -> xr.DataArray:
        """Build the variable and eagerly compute it (``.compute()`` on the build result)."""
        return self.build(case, data, variables).compute()

# TODO: assign LSRs to a 0.25 degree grid
class PracticallyPerfectHindcast(DerivedVariable):
    """Derived variable producing the practically perfect hindcast (PPH) from storm reports."""

    name = "practically_perfect_hindcast"
    input_variables = ['report_type']

    def build(self, case: case.IndividualCase, data: xr.Dataset, variables: Optional[List[str]] = None) -> xr.DataArray:
        """Build the practically perfect hindcast over the case location.

        ``variables`` is accepted for interface compatibility but is not used. The
        report types passed to the PPH routine are fixed to 'tor' and 'hail'.
        """
        report_data = data[self.input_variables]
        pph_dataset = rs.practically_perfect_hindcast(
            report_data,
            output_bounds=case.location,
            report_type=['tor', 'hail'],
        )
        return pph_dataset['practically_perfect']
    
class CravenSignificantSevereParameter(DerivedVariable):
    """A derived variable that computes the Craven significant severe parameter.

    NOTE: ``build`` is currently a placeholder that doubles the first available input
    variable; the real calculation is stubbed out in the commented line below.
    """

    name = "craven_significant_severe_parameter"
    input_variables = [
        '2m_temperature',
        '2m_dewpoint_temperature',
        '2m_relative_humidity',
        '10m_u_component_of_wind',
        '10m_v_component_of_wind',
        'surface_pressure',
        'geopotential'
    ]

    def build(self, case: case.IndividualCase, data: xr.Dataset, variables: Optional[List[str]] = None) -> xr.DataArray:
        """Build the Craven significant severe parameter.

        Args:
            case: The case being processed (unused by the placeholder).
            data: Dataset containing the input variables.
            variables: Names of variables available in ``data``. When None or empty
                (e.g. via the default ``compute(case, data)`` call), ``input_variables``
                is used instead of failing on ``variables[0]``.

        Returns:
            Placeholder DataArray: the first variable multiplied by 2.
        """
        # cbss_ds = calc.craven_brooks_sig_svr(data[variables],variable_mapping={'pressure':'level', 'dewpoint':'dewpoint_temperature','temperature':'air_temperature'})
        source_variables = variables or self.input_variables
        test_da = data[source_variables[0]] * 2
        return test_da
    
def maybe_derive_variables(
        ds: xr.Dataset,
        case: case.IndividualCase,
        variables: Optional[list[str | DerivedVariable]],
) -> xr.Dataset:
    """Derive variables from the data if any exist in a list of variables.

    Derived variables must maintain the same spatial dimensions as the original dataset.

    Args:
        ds: The dataset, ideally already subset in case of in memory operations in the derived variables.
        case: The case to derive the variables for.
        variables: The potential variables to derive, as a list of strings and/or
            DerivedVariables. A derived variable may be given as a DerivedVariable
            subclass (instantiated here) or as an instance. None or an empty list
            returns ``ds`` unchanged.

    Returns:
        A dataset with derived variables, if any exist, else the original dataset.
        The input dataset is not modified.
    """
    if not variables:
        return ds

    # Plain variable names already present in the dataset; handed to each derived variable.
    non_derived_variables = [v for v in variables if isinstance(v, str) and v in ds.data_vars]
    derived_variables = [v for v in variables if not isinstance(v, str)]

    for v in derived_variables:
        # Accept both a DerivedVariable class and an already-built instance.
        derived_variable = v() if isinstance(v, type) else v
        derived_data = derived_variable.build(case=case, data=ds, variables=non_derived_variables)
        # assign returns a new Dataset, so the caller's ds is not mutated; rebinding ds
        # lets later derived variables see earlier ones, as the in-place version did.
        ds = ds.assign({derived_variable.name: derived_data})

    return ds

# Metrics:

It seems logical to separate simple metrics (e.g. MAE) from more complicated split-apply-combine or other methods that require multiple steps to prepare the data for that simple metric. These more complicated metrics are "AppliedMetrics", and they can optionally include the simple metrics. Some metrics in EWB, such as categorical thresholds and contingency-table metrics, don't have a downstream simple metric. This is a bit of a WIP, so I will update it as the orchestration becomes clearer.

In [61]:
class BaseMetric(ABC):
    """Abstract foundation shared by every metric.

    A metric is a general operation applied between a forecast and an analysis
    (observation) xarray dataset. EWB metrics are meant to work with arbitrary
    forecast/analysis pairs, as long as their spatiotemporal dimensions match.
    """

    @property
    def name(self) -> str:
        """The metric's name, taken from its class name."""
        metric_cls = type(self)
        return metric_cls.__name__

    @abstractmethod
    def compute(self, forecast: xr.Dataset, observation: xr.Dataset):
        """Compute the metric between ``forecast`` and ``observation``."""

class AppliedMetric(ABC):
    """An applied metric is a derivative of a BaseMetric.

    It is a wrapper around one or more BaseMetrics that is intended for more complex rollups or aggregations.
    Typically, these metrics are used for one event type and are very specific. Temporal onset mean error,
    case duration mean error, and maximum temperature mean absolute error are all examples of applied metrics.

    Attributes:
        base_metric: The BaseMetric subclass (or list of subclasses) that the applied metric
            wraps. Subclasses set this as a class attribute.
        base_metrics: The wrapped BaseMetric subclasses as a list. Derived from
            ``base_metric`` unless a subclass overrides it.
        metric: The first entry of ``base_metrics``; subclasses call ``self.metric()``.
        variables: The variables the metric was constructed with, if any.
        compute_metric: A required method to compute the metric.
    """

    def __init__(self, variables: Optional[list[str | DerivedVariable]] = None):
        # Subclasses forward their constructor argument via super().__init__(variables);
        # object.__init__ would reject it.
        self.variables = variables

    @property
    def name(self) -> str:
        """The metric's name, taken from its class name."""
        return self.__class__.__name__

    @property
    def base_metrics(self) -> list[type[BaseMetric]]:
        """The wrapped BaseMetric classes.

        The default implementation is built from ``base_metric``, which subclasses set
        as a class attribute. It is no longer abstract, because subclasses define
        ``base_metric`` and would otherwise be uninstantiable.

        Raises:
            NotImplementedError: If the subclass defines neither ``base_metric`` nor
                ``base_metrics``.
        """
        base_metric = getattr(self, "base_metric", None)
        if base_metric is None:
            raise NotImplementedError(
                f"{self.name} must define `base_metric` or `base_metrics`."
            )
        if isinstance(base_metric, (list, tuple)):
            return list(base_metric)
        return [base_metric]

    @property
    def metric(self) -> type[BaseMetric]:
        """The primary (first) wrapped BaseMetric class."""
        return self.base_metrics[0]

    @abstractmethod
    def compute_metric(
        self,
        forecast: xr.Dataset,
        observation: xr.Dataset,
        forecast_variables: Optional[list[str]] = None,
        observation_variables: Optional[list[str]] = None,
        **kwargs,
    ):
        """Compute the applied metric between ``forecast`` and ``observation``.

        The variable arguments are optional because several subclasses implement
        ``compute_metric(forecast, observation)`` only.
        """

class MAE(BaseMetric):
    """Mean absolute error; delegates to ``scores.continuous.mae``."""

    def compute(self, forecast: xr.Dataset, observation: xr.Dataset, **kwargs):
        result = mae(forecast, observation, **kwargs)
        return result

class ME(BaseMetric):
    """Mean error; delegates to ``scores.continuous.mean_error``."""

    def compute(self, forecast: xr.Dataset, observation: xr.Dataset, **kwargs):
        result = mean_error(forecast, observation, **kwargs)
        return result

class RMSE(BaseMetric):
    """Root mean squared error; delegates to ``scores.continuous.rmse``."""

    def compute(self, forecast: xr.Dataset, observation: xr.Dataset, **kwargs):
        result = rmse(forecast, observation, **kwargs)
        return result

class BinaryContingencyTable(BaseMetric):
    """Builds a ``scores.categorical.BinaryContingencyManager`` for the pair."""

    def compute(self, forecast: xr.Dataset, observation: xr.Dataset, **kwargs):
        manager = cat.BinaryContingencyManager(forecast, observation, **kwargs)
        return manager

class MaximumMAE(AppliedMetric):
    """MAE between the observed spatial-mean peak and the forecast peak near that time.

    The observation is averaged over latitude/longitude and its maximum over
    ``valid_time`` is located. The forecast is averaged the same way, restricted to
    ``window`` on either side of the observed peak time, and its maximum is compared
    to the observed peak with MAE.
    """

    base_metric = MAE
    # Satisfies AppliedMetric's `base_metrics` contract so the class can be instantiated.
    base_metrics = [MAE]
    # Half-width of the search window around the observed maximum.
    window = np.timedelta64(48, 'h')

    def compute_metric(
        self,
        forecast: xr.Dataset,
        observation: xr.Dataset,
        forecast_variables: Optional[list[str]] = None,
        observation_variables: Optional[list[str]] = None,
    ):
        """Compute MAE of the windowed forecast maximum against the observed maximum.

        ``forecast_variables`` and ``observation_variables`` are currently unused.

        NOTE(review): ``.values`` is an array only for DataArray inputs; on an xr.Dataset it
        is the Mapping ``values`` method. Confirm that callers pass DataArrays here.
        """
        observation_spatial_mean = observation.mean(["latitude", "longitude"])
        maximum_timestep = observation_spatial_mean.idxmax("valid_time").values
        maximum_value = observation_spatial_mean.sel(valid_time=maximum_timestep).values

        # Spatial dims are reduced once here; averaging them again would fail because
        # they no longer exist.
        forecast_spatial_mean = forecast.mean(["latitude", "longitude"])
        in_window = (
            (forecast_spatial_mean.valid_time >= maximum_timestep - self.window)
            & (forecast_spatial_mean.valid_time <= maximum_timestep + self.window)
        )
        filtered_max_forecast = forecast_spatial_mean.where(in_window, drop=True).max('valid_time')
        return self.base_metric().compute(filtered_max_forecast, maximum_value)
    

class MaxMinMAE(AppliedMetric):
    """Placeholder applied metric meant to score both maximum and minimum values with MAE."""

    base_metric = MAE

    def __init__(self, variables: list[str | DerivedVariable]):
        super().__init__(variables)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Placeholder: delegates straight to the wrapped base metric.
        wrapped_metric = self.metric()
        return wrapped_metric.compute(forecast, observation)

class OnsetME(AppliedMetric):
    """Placeholder applied metric for the mean error of event onset."""

    base_metric = ME

    def __init__(self, variables: list[str | DerivedVariable]):
        super().__init__(variables)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Placeholder: delegates straight to the wrapped base metric.
        wrapped_metric = self.metric()
        return wrapped_metric.compute(forecast, observation)

class DurationME(AppliedMetric):
    """Placeholder applied metric for the mean error of event duration.

    NOTE(review): wraps MAE although the name suggests mean error (ME) — confirm.
    """

    base_metric = MAE

    def __init__(self, variables: list[str | DerivedVariable]):
        super().__init__(variables)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Placeholder: delegates straight to the wrapped base metric.
        wrapped_metric = self.metric()
        return wrapped_metric.compute(forecast, observation)

class CSI(AppliedMetric):
    """Placeholder applied metric for the Critical Success Index."""

    base_metric = BinaryContingencyTable

    def __init__(self, variables: list[str | DerivedVariable]):
        super().__init__(variables)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Placeholder: delegates straight to the wrapped base metric.
        wrapped_metric = self.metric()
        return wrapped_metric.compute(forecast, observation)

class LeadTimeDetection(AppliedMetric):
    """Placeholder applied metric for lead-time detection."""

    base_metric = MAE

    def __init__(self, variables: list[str | DerivedVariable]):
        super().__init__(variables)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Placeholder: delegates straight to the wrapped base metric.
        wrapped_metric = self.metric()
        return wrapped_metric.compute(forecast, observation)

class RegionalHitsMisses(AppliedMetric):
    """Placeholder applied metric for regional hits and misses."""

    base_metric = BinaryContingencyTable

    def __init__(self, variables: list[str | DerivedVariable]):
        super().__init__(variables)

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Placeholder: delegates straight to the wrapped base metric.
        wrapped_metric = self.metric()
        return wrapped_metric.compute(forecast, observation)

class HitsMisses(AppliedMetric):
    """Placeholder applied metric for hits and misses at a probability threshold.

    ``threshold`` is stored but not yet used by ``compute_metric``.
    """

    base_metric = BinaryContingencyTable

    def __init__(self, variables: list[str | DerivedVariable], threshold: float = 0.5):
        super().__init__(variables)
        self.threshold = threshold

    def compute_metric(self, forecast: xr.Dataset, observation: xr.Dataset):
        # Placeholder: delegates straight to the wrapped base metric.
        wrapped_metric = self.metric()
        return wrapped_metric.compute(forecast, observation)


# Observations:
Commenting out for now, on the assumption that the module is operating as intended with the same code.

In [62]:
#: Default source for the ERA5 observation class: Google's ARCO ERA5 zarr store.
ARCO_ERA5_FULL_URI = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"  # noqa: E501

#: Default source for the GHCN point observation class (parquet).
DEFAULT_GHCN_URI = "gs://extremeweatherbench/datasets/ghcnh.parq"

#: Default source for the local storm report (LSR) observation class (parquet).
LSR_URI = "gs://extremeweatherbench/datasets/lsr_01012020_04302025.parq"

#: Default source for the IBTrACS observation class (NCEI CSV over HTTPS).
IBTRACS_URI = "https://www.ncei.noaa.gov/data/international-best-track-archive-for-climate-stewardship-ibtracs/v04r01/access/csv/ibtracs.ALL.list.v04r01.csv"  # noqa: E501

# Type hint for the data passed between Observation pipeline steps before it is
# converted to an xr.Dataset.
IncomingDataInput = Union[xr.Dataset, xr.DataArray, pl.LazyFrame, pd.DataFrame, np.ndarray]


class Observation(ABC):
    """
    Abstract base class for all observation types.

    An Observation is data that acts as the "truth" for a case. It can be a gridded dataset,
    a point observation dataset, or any other reference dataset. Observations in EWB
    are not required to be the same variable as the forecast dataset, but they must be in the
    same coordinate system for evaluation.

    Subclasses set ``source`` and implement the three pipeline steps
    (_open_data_from_source, _subset_data_to_case, _maybe_convert_to_dataset), which
    run_pipeline chains together.
    """

    source: str

    @abstractmethod
    def _open_data_from_source(
        self, storage_options: Optional[dict] = None, **kwargs
    ) -> IncomingDataInput:
        """
        Open the observation data from ``self.source``, opting to avoid loading the entire
        dataset into memory if possible.

        The returned object must support ``.pipe`` (xarray, pandas and polars objects do),
        because run_pipeline chains the remaining steps with it.

        Args:
            storage_options: Optional storage options for the source if the source is a remote URL.
            **kwargs: Additional implementation-specific keyword arguments.

        Returns:
            The observation data with a type determined by the implementation.
        """

    @abstractmethod
    def _subset_data_to_case(
        self,
        data: IncomingDataInput,
        case: case.IndividualCase,
        observation_variables: Optional[list[str]] = None,
        **kwargs,
    ) -> IncomingDataInput:
        """
        Subset the observation data to the case information provided in IndividualCase.

        Time information, spatial bounds, and variables are captured in the case metadata
        where this method is used to subset.

        Args:
            data: The observation data to subset, which should be an xarray dataset, xarray
                dataarray, polars lazyframe, pandas dataframe, or numpy array.
            case: The case whose time window and location define the subset.
            observation_variables: Plain variable names to include in the observation. Some
                observations may not have variables, or only have a singular variable; thus,
                this is optional.

        Returns:
            The observation data with the variables subset to the case metadata.
        """

    @abstractmethod
    def _maybe_convert_to_dataset(self, data: IncomingDataInput, **kwargs) -> xr.Dataset:
        """
        Convert the observation data to an xarray dataset if it is not already.

        If this method is used prior to _subset_data_to_case, OOM errors are possible
        prior to subsetting.

        Args:
            data: The observation data already run through _subset_data_to_case.

        Returns:
            The observation data as an xarray dataset.
        """

    def run_pipeline(
        self,
        case: case.IndividualCase,
        storage_options: Optional[dict] = None,
        observation_variables: Optional[list[str | DerivedVariable]] = None,
        observation_variable_mapping: Optional[dict] = None,
        **kwargs,
    ) -> xr.Dataset:
        """
        Shared method for running the observation pipeline.

        Steps: open the source, map variable names, subset to the case, convert to an
        xarray Dataset, then build any derived variables.

        Args:
            case: The case to build the observation for.
            storage_options: Optional storage options for the source if the source is a remote URL.
            observation_variables: Variables to include, as plain names and/or DerivedVariables.
                Only plain names are passed to _subset_data_to_case; DerivedVariables are
                built afterwards from the subset data. Some observations may not have
                variables; thus, this is optional.
            observation_variable_mapping: A dictionary of variable names to map in the
                observation data. Defaults to no mapping.
            **kwargs: Additional keyword arguments forwarded to each pipeline step.

        Returns:
            The observation data as an xarray Dataset.
        """
        # Avoid a shared mutable default argument.
        if observation_variable_mapping is None:
            observation_variable_mapping = {}

        # The raw source can only select plain variable names; DerivedVariables are
        # produced later by maybe_derive_variables.
        if observation_variables is None:
            core_variables = None
        else:
            core_variables = [v for v in observation_variables if isinstance(v, str)]

        # Open data and process through pipeline steps
        data = (
            self._open_data_from_source(
                storage_options=storage_options,
                **kwargs,
            )
            .pipe(
                rs.maybe_map_variable_names,
                variable_mapping=observation_variable_mapping,
                **kwargs,
            )
            .pipe(
                self._subset_data_to_case,
                case=case,
                observation_variables=core_variables,
                **kwargs,
            )
            .pipe(self._maybe_convert_to_dataset, **kwargs)
            .pipe(
                maybe_derive_variables,
                case=case,
                # An empty list means "nothing to derive".
                variables=observation_variables or [],
            )
        )
        return data


class ERA5(Observation):
    """
    Observation class for ERA5 gridded data.

    The easiest approach to using this class
    is to use the ARCO ERA5 dataset provided by Google for a source. Otherwise, either a
    different zarr source or modifying the _open_data_from_source method to open the data
    using another method is required.
    """

    source: str = ARCO_ERA5_FULL_URI

    # Default dask chunking, shared by opening and subsetting so the two stay in sync.
    default_chunks: dict = {'time': 48, 'latitude': 721, 'longitude': 1440}

    def _open_data_from_source(
        self,
        storage_options: Optional[dict] = None,
        chunks: dict = default_chunks,
        **kwargs,
    ) -> IncomingDataInput:
        """Lazily open ``self.source`` as a zarr store with the given dask ``chunks``."""
        data = xr.open_zarr(
            self.source,
            storage_options=storage_options,
            chunks=chunks,
        )
        return data

    def _subset_data_to_case(
        self,
        data: IncomingDataInput,
        case: case.IndividualCase,
        observation_variables: Optional[list[str]] = None,
        **kwargs,
    ) -> IncomingDataInput:
        """Subset ERA5 data to the case's time window, variables and location.

        Raises:
            ValueError: If ``data`` is not xarray, if no variables are given, or if any
                requested variable is missing from the data.
        """
        if not isinstance(data, (xr.Dataset, xr.DataArray)):
            raise ValueError(f"Expected xarray Dataset or DataArray, got {type(data)}")

        if observation_variables is None:
            raise ValueError("Variables not defined for ERA5. Please list at least one variable to select.")

        # A DataArray has no data_vars to select from; treat it as a one-variable Dataset.
        if isinstance(data, xr.DataArray):
            data = data.to_dataset()

        # subset time first to avoid OOM masking issues
        subset_time_data = data.sel(time=slice(case.start_date, case.end_date))

        # check that the variables are in the observation data, reporting only the missing ones
        missing_variables = [
            var for var in observation_variables if var not in subset_time_data.data_vars
        ]
        if missing_variables:
            raise ValueError(f"Variables {missing_variables} not found in observation data")
        subset_time_variable_data = subset_time_data[observation_variables]

        # calling chunk here to avoid loading subset_data into memory
        chunks = kwargs.get('chunks', self.default_chunks)
        subset_time_variable_data = subset_time_variable_data.chunk(chunks)
        # mask the data to the case location
        fully_subset_data = case.location.mask(subset_time_variable_data, drop=True)

        return fully_subset_data

    def _maybe_convert_to_dataset(self, data: IncomingDataInput, **kwargs):
        """Return ``data`` as an xr.Dataset, converting a DataArray if needed."""
        if isinstance(data, xr.DataArray):
            data = data.to_dataset()
        return data


class GHCN(Observation):
    """
    Observation class for GHCN tabular data.

    Data is processed using polars to maintain the lazy loading
    paradigm in _open_data_from_source and to separate the subsetting
    into _subset_data_to_case.
    """

    source: str = DEFAULT_GHCN_URI

    def _open_data_from_source(
        self, storage_options: Optional[dict] = None, **kwargs
    ) -> IncomingDataInput:
        """Lazily scan the GHCN parquet source with polars."""
        observation_data: pl.LazyFrame = pl.scan_parquet(
            self.source, storage_options=storage_options
        )

        return observation_data

    def _subset_data_to_case(
        self,
        observation_data: IncomingDataInput,
        case: case.IndividualCase,
        observation_variables: Optional[list[str]] = None,
        **kwargs,
    ) -> IncomingDataInput:
        """Filter GHCN rows to the case's (two-day padded) time window and bounding box.

        When ``observation_variables`` is given, only those columns plus time/latitude/
        longitude are kept.

        Raises:
            ValueError: If the data is not a LazyFrame or requested columns are missing.
        """
        if not isinstance(observation_data, pl.LazyFrame):
            raise ValueError(f"Expected polars LazyFrame, got {type(observation_data)}")

        # Pad the case time window by two days on each side.
        time_min = case.start_date - pd.Timedelta(days=2)
        time_max = case.end_date + pd.Timedelta(days=2)

        # Apply filters using proper polars expressions
        subset_observation_data = observation_data.filter(
            (pl.col("time") >= time_min)
            & (pl.col("time") <= time_max)
            & (pl.col("latitude") >= case.location.latitude_min)
            & (pl.col("latitude") <= case.location.latitude_max)
            & (pl.col("longitude") >= case.location.longitude_min)
            & (pl.col("longitude") <= case.location.longitude_max)
        )

        if observation_variables is None:
            return subset_observation_data

        # Add time, latitude, and longitude to the variables, polars doesn't do indexes
        all_variables = list(observation_variables) + ["time", "latitude", "longitude"]

        # check that the variables are in the observation data, reporting only the missing ones
        schema_fields = list(subset_observation_data.collect_schema())
        missing_variables = [var for var in all_variables if var not in schema_fields]
        if missing_variables:
            raise ValueError(f"Variables {missing_variables} not found in observation data")

        return subset_observation_data.select(all_variables)

    def _maybe_convert_to_dataset(self, data: IncomingDataInput, **kwargs):
        """Collect the LazyFrame and convert it to a (time, latitude, longitude) Dataset.

        Raises:
            ValueError: If ``data`` is not a LazyFrame, or conversion fails for a reason
                other than a non-unique index.
        """
        if not isinstance(data, pl.LazyFrame):
            raise ValueError(f"Data is not a polars LazyFrame: {type(data)}")

        frame = data.collect().to_pandas().set_index(["time", "latitude", "longitude"])
        try:
            return frame.to_xarray()
        except ValueError as e:
            if "non-unique" not in str(e):
                raise
            # GHCN data can contain duplicate (time, latitude, longitude) entries. Drop
            # repeated index entries (keeping the first); DataFrame.drop_duplicates()
            # compares column values only and would keep index duplicates.
            return frame[~frame.index.duplicated(keep="first")].to_xarray()

class LSR(Observation):
    """
    Observation class for local storm report (LSR) tabular data.

    run_pipeline() returns a dataset with LSRs and, when PracticallyPerfectHindcast is
    among the observation variables, practically perfect hindcast gridded probability
    data. IndividualCase date ranges for LSRs should ideally be 12 UTC to the next day at
    12 UTC to match SPC methods.
    """

    source: str = LSR_URI

    def _open_data_from_source(
        self, storage_options: Optional[dict] = None, **kwargs
    ) -> IncomingDataInput:
        """Read the LSR parquet file with pandas.

        ``storage_options`` is ignored: an anonymous token is always used.
        """
        # force LSR to use anon token to prevent google reauth issues for users
        observation_data = pd.read_parquet(self.source, storage_options={'token': 'anon'})

        return observation_data

    def _subset_data_to_case(
        self,
        observation_data: IncomingDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
        **kwargs,
    ) -> IncomingDataInput:
        """Filter reports to the case time window and bounding box.

        Returns a new DataFrame with lat/lon/time renamed to latitude/longitude/valid_time;
        the input DataFrame is not modified. ``variables`` is currently unused.
        """
        if not isinstance(observation_data, pd.DataFrame):
            raise ValueError(f"Expected pandas DataFrame, got {type(observation_data)}")

        # latitude, longitude are strings by default. Convert them on a copy so the
        # caller's frame is left unchanged.
        typed_data = observation_data.assign(
            lat=observation_data["lat"].astype(float),
            lon=observation_data["lon"].astype(float),
            time=pd.to_datetime(observation_data["time"]),
        )

        # Case longitudes go through rs.convert_longitude_to_180, presumably to match a
        # -180..180 convention in the LSR data; confirm.
        longitude_min = rs.convert_longitude_to_180(case.location.longitude_min)
        longitude_max = rs.convert_longitude_to_180(case.location.longitude_max)

        filters = (
            (typed_data["time"] >= case.start_date)
            & (typed_data["time"] <= case.end_date)
            & (typed_data["lat"] >= case.location.latitude_min)
            & (typed_data["lat"] <= case.location.latitude_max)
            & (typed_data["lon"] >= longitude_min)
            & (typed_data["lon"] <= longitude_max)
        )

        subset_observation_data = typed_data.loc[filters].rename(
            columns={"lat": "latitude", "lon": "longitude", "time": "valid_time"}
        )

        return subset_observation_data

    def _maybe_convert_to_dataset(self, data: IncomingDataInput, **kwargs):
        """Convert the report DataFrame to a sparse (valid_time, latitude, longitude) Dataset.

        Duplicate index entries are dropped, keeping the first.
        """
        if isinstance(data, pd.DataFrame):
            data = data.set_index(["valid_time", "latitude", "longitude"])
            data = xr.Dataset.from_dataframe(data[~data.index.duplicated(keep='first')],sparse=True)
            return data
        else:
            raise ValueError(f"Data is not a pandas DataFrame: {type(data)}")

class IBTrACS(Observation):
    """
    Observation class for IBTrACS data.

    Storms are matched by ``case.title`` (upper-cased) against the NAME column and by the
    year of ``case.start_date`` against SEASON.
    """

    source: str = IBTRACS_URI

    def _open_data_from_source(
        self, storage_options: Optional[dict] = None, **kwargs
    ) -> IncomingDataInput:
        """Lazily scan the IBTrACS CSV with polars, forwarding ``storage_options``."""
        observation_data: pl.LazyFrame = pl.scan_csv(
            self.source, storage_options=storage_options
        )
        return observation_data

    def _subset_data_to_case(
        self,
        observation_data: IncomingDataInput,
        case: case.IndividualCase,
        variables: Optional[list[str]] = None,
        **kwargs,
    ) -> IncomingDataInput:
        """Select the track rows for the case's storm in the case's season.

        NOTE(review): the contents of ``variables`` are not used. Any non-None value
        selects the fixed column list below.

        Raises:
            ValueError: If the data is not a LazyFrame, or selected columns are missing.
        """
        if not isinstance(observation_data, pl.LazyFrame):
            raise ValueError(f"Expected polars LazyFrame, got {type(observation_data)}")

        storm_name = case.title.upper()
        # Get the season (year) from the case start date, cast as string as polars is interpreting the schema as strings
        season = str(case.start_date.year)

        # Columns kept when a variable selection is requested.
        all_variables = [
            "SEASON",
            "NUMBER",
            "NAME",
            "ISO_TIME",
            "LAT",
            "LON",
            "WMO_WIND",
            "USA_WIND",
            "WMO_PRES",
            "USA_PRES",
        ]

        # Storm numbers in this season for storms carrying the case's name
        matching_numbers = (
            observation_data.filter(
                (pl.col("NAME") == storm_name) & (pl.col("SEASON") == season)
            )
            .select("NUMBER")
            .unique()
        )

        # Apply the filter to get all data for storms with the same number in the same season
        # This maintains the lazy evaluation
        subset_observation_data = observation_data.join(
            matching_numbers, on="NUMBER", how="inner"
        ).filter((pl.col("NAME") == storm_name) & (pl.col("SEASON") == season))

        if variables is not None:
            # check that the variables are in the observation data, reporting only the missing ones
            schema_fields = list(subset_observation_data.collect_schema())
            missing_variables = [var for var in all_variables if var not in schema_fields]
            if missing_variables:
                raise ValueError(f"Variables {missing_variables} not found in observation data")
            subset_observation_data = subset_observation_data.select(all_variables)

        return subset_observation_data

    def _maybe_convert_to_dataset(self, data: IncomingDataInput, **kwargs):
        """Collect the LazyFrame and convert it to a Dataset indexed by ISO_TIME.

        Raises:
            ValueError: If ``data`` is not a LazyFrame, or conversion fails for a reason
                other than a non-unique index.
        """
        if not isinstance(data, pl.LazyFrame):
            raise ValueError(f"Data is not a polars LazyFrame: {type(data)}")

        frame = data.collect().to_pandas().set_index(["ISO_TIME"])
        try:
            return frame.to_xarray()
        except ValueError as e:
            if "non-unique" not in str(e):
                raise
            # Drop repeated index entries (keeping the first). DataFrame.drop_duplicates()
            # compares column values only and would keep index duplicates.
            return frame[~frame.index.duplicated(keep="first")].to_xarray()


# Cases:

In [63]:
#TODO implement this in case.py
@dataclasses.dataclass
class BaseCaseMetadataCollection:
    """A collection of IndividualCase objects with simple filtering helpers."""

    cases: List[case.IndividualCase]

    def subset_cases_by_event_type(self, event_type: str) -> List[case.IndividualCase]:
        """Return the cases whose ``event_type`` equals ``event_type``, preserving order."""
        return list(
            filter(lambda individual_case: individual_case.event_type == event_type, self.cases)
        )

In [64]:
@dataclasses.dataclass
class CaseOperator:
    """A class which stores the graph to process an individual case.

    Attributes:
        case: The case being evaluated.
        metrics: The metrics to evaluate for the case.
        observations: Observation classes; each is instantiated in build_observations.
        observation_variables: Variables (names or DerivedVariables) requested from each observation.
        forecast_variables: Variables (names or DerivedVariables) requested from the forecast.
    """

    case: case.IndividualCase
    metrics: list[BaseMetric]
    observations: list[Observation]
    observation_variables: list[str | DerivedVariable]
    forecast_variables: list[str | DerivedVariable]

    def evaluate_case(self, forecast: xr.Dataset):
        """Process a case by running its metrics against ``forecast``."""
        self.process_metrics(forecast)

    def process_metrics(self, forecast: xr.Dataset):
        """Process the metrics.

        NOTE(review): this calls ``metric.process_metric``, which none of the metric classes
        in this notebook define (BaseMetric has ``compute``; AppliedMetric has
        ``compute_metric``). It also passes the Observation classes rather than built
        datasets. This is WIP; confirm the intended API.
        """
        for metric in self.metrics:
            metric.process_metric(forecast, self.observations)

    def build_observations(self, **kwargs) -> list[xr.Dataset]:
        """Build observation xarray Datasets from the observation sources.

        Keyword Args:
            observation_storage_options: Storage options passed to every observation.
            observation_variable_mapping: Variable-name mapping passed to every
                observation. Defaults to an empty mapping (no renaming).

        Returns:
            One Dataset per entry in ``self.observations``, in the same order.
        """
        observation_storage_options = kwargs.get(
            'observation_storage_options',
            {'remote_protocol': 's3', 'remote_options': {'anon': True}},
        )
        # Default to no renaming; {'anon': True} is a storage option, not a variable mapping.
        observation_variable_mapping = kwargs.get('observation_variable_mapping', {})
        #TODO: need to pipe in storage options here
        return [
            observation().run_pipeline(
                case=self.case,
                storage_options=observation_storage_options,
                observation_variables=self.observation_variables,
                observation_variable_mapping=observation_variable_mapping,
            )
            for observation in self.observations
        ]
    

# Events:

In [65]:
class EventType(ABC):
    """A base class defining the interface for ExtremeWeatherBench event types.

    An Event in ExtremeWeatherBench defines a specific weather event type, such as a heat wave,
    severe convective weather, or atmospheric rivers. These events encapsulate a set of cases and
    derived behavior for evaluating those cases. These cases will share common metrics, observations,
    and variables while each having unique dates and locations.

    On construction, the ``input_variables`` of every entry in ``forecast_variables`` and
    ``observation_variables`` that defines them (i.e. derived variables) are appended to the
    respective list, each at most once.

    Attributes:
        event_type: The type of event; used to select matching cases from ``case_metadata``.
        forecast_variables: A list of variables that are used to forecast the event.
        observation_variables: A list of variables that are used to observe the event.
        case_metadata: A dictionary (e.g. loaded from yaml) with a ``cases`` list of guiding metadata.
        metrics: A list of Metrics that are used to evaluate the cases.
        observations: A list of Observations that are used as targets for the metrics.
    """

    def __init__(
        self,
        case_metadata: dict[str, Any],
    ):
        self.case_metadata = case_metadata
        self._maybe_expand_not_derived_variables()

    @property
    @abstractmethod
    def event_type(self) -> str:
        pass

    @property
    @abstractmethod
    def forecast_variables(self) -> List[str | DerivedVariable]:
        pass

    @property
    @abstractmethod
    def observation_variables(self) -> List[str | DerivedVariable]:
        pass

    @property
    @abstractmethod
    def metrics(self) -> List[AppliedMetric]:
        pass

    @property
    @abstractmethod
    def observations(self) -> List[Observation]:
        pass

    @staticmethod
    def _append_input_variables(
        variables: List[str | DerivedVariable],
    ) -> List[str | DerivedVariable]:
        """Return a copy of ``variables`` followed by any not-yet-listed input variables."""
        expanded = list(variables)
        for variable in variables:
            for input_variable in getattr(variable, 'input_variables', ()):
                # Several derived variables can share inputs; list each input only once.
                # A linear membership test is used because entries may be unhashable.
                if input_variable not in expanded:
                    expanded.append(input_variable)
        return expanded

    def _maybe_expand_not_derived_variables(self) -> None:
        """Expand the variables to include the input variables of any derived variables."""
        expanded_forecast = self._append_input_variables(self.forecast_variables)
        # Only assign when something was added, so subclasses that expose these as
        # read-only properties keep working when there is nothing to expand.
        if len(expanded_forecast) != len(self.forecast_variables):
            self.forecast_variables = expanded_forecast

        expanded_observation = self._append_input_variables(self.observation_variables)
        if len(expanded_observation) != len(self.observation_variables):
            self.observation_variables = expanded_observation

    def _build_base_case_metadata_collection(self) -> BaseCaseMetadataCollection:
        """Build the collection of IndividualCases in case_metadata that match this event type."""
        all_cases = dacite.from_dict(
            data_class=BaseCaseMetadataCollection,
            data=self.case_metadata,
            config=dacite.Config(
                type_hooks={regions.Region: regions.map_to_create_region},
            ),
        )
        return BaseCaseMetadataCollection(
            cases=all_cases.subset_cases_by_event_type(self.event_type)
        )

    def build_case_operators(self, observation_storage_options: dict[str, Any], observation_variable_mapping: dict[str, str]) -> list[CaseOperator]:
        """Build one CaseOperator per case of this event type.

        Args:
            observation_storage_options: Currently unused here; CaseOperator.build_observations
                reads its storage options from its own keyword arguments.
            observation_variable_mapping: Currently unused here, for the same reason.

        Returns:
            A list of CaseOperators sharing this event's metrics, observations and variables.
        """
        case_metadata_collection = self._build_base_case_metadata_collection()
        return [
            CaseOperator(
                case=individual_case,
                metrics=self.metrics,
                observations=self.observations,
                observation_variables=self.observation_variables,
                forecast_variables=self.forecast_variables,
            )
            for individual_case in case_metadata_collection.cases
        ]

class HeatWave(EventType):
    """Heat wave events: surface air temperature evaluated against ERA5."""

    event_type = 'heat_wave'
    forecast_variables = ['surface_air_temperature']
    observation_variables = ['surface_air_temperature']
    metrics = [
        MaximumMAE,
        MaxMinMAE,
        RMSE,
        OnsetME,
        DurationME,
    ]
    observations = [ERA5]

class SevereConvection(EventType):
    """Severe convection events: a Craven significant severe parameter forecast scored
    against a practically perfect hindcast built from local storm reports (LSR)."""

    event_type = 'severe_convection'
    forecast_variables = [CravenSignificantSevereParameter]
    observation_variables = [PracticallyPerfectHindcast]
    metrics = [
        CSI,
        LeadTimeDetection,
        RegionalHitsMisses,
        HitsMisses,
    ]
    observations = [LSR]


class AtmosphericRiver(EventType):
    """Atmospheric river events.

    TODO: placeholder -- no forecast or observation variables are configured yet.
    """

    event_type = 'atmospheric_river'
    forecast_variables = []
    observation_variables = []
    metrics = [
        CSI,
        LeadTimeDetection,
    ]
    observations = [ERA5]



# Forecasts

In [66]:
def open_and_preprocess_forecast_dataset(
    forecast_dir: str,
    forecast_variables: list[str | DerivedVariable],
    forecast_variable_mapping: dict[str, str],
    forecast_preprocess: Callable = utils._default_preprocess,
    forecast_storage_options: Optional[dict] = None,
    forecast_chunks: Optional[dict] = None,
) -> xr.Dataset:
    """Open the forecast dataset specified for evaluation.

    Paths containing "zarr" are opened with ``xr.open_zarr``; paths containing
    "parq"/"parquet" or "json" are opened as kerchunk references.

    Preprocessing examples:
        A typical preprocess function handles metadata changes:

        def _preprocess_cira_forecast_dataset(
            ds: xr.Dataset
        ) -> xr.Dataset:
            ds = ds.rename({"time": "lead_time"})
            return ds

        The preprocess function is applied before variable renaming occurs, so it should
        reference the original variable names in the forecast dataset, not the standardized
        names from ``forecast_variable_mapping``.

    Args:
        forecast_dir: Path or URI of the forecast data (zarr store or kerchunk reference).
        forecast_variables: Currently unused; subsetting is driven by
            ``forecast_variable_mapping``.
        forecast_variable_mapping: Maps source variable names to standardized names,
            e.g. ``{"t2": "surface_air_temperature"}``. Unmapped variables are dropped.
        forecast_preprocess: Function applied to the opened dataset before renaming.
        forecast_storage_options: fsspec options for kerchunk references. Defaults to
            ``{"remote_protocol": "s3", "remote_options": {"anon": True}}``.
        forecast_chunks: Chunk sizes. Defaults to
            ``{"time": 48, "latitude": 721, "longitude": 1440}``.

    Returns:
        The opened, preprocessed, renamed forecast dataset with numeric lead times.

    Raises:
        TypeError: If the path is not a zarr, parquet or json source.
    """
    # None sentinels avoid sharing one mutable default dict between calls.
    if forecast_storage_options is None:
        forecast_storage_options = {"remote_protocol": "s3", "remote_options": {"anon": True}}
    if forecast_chunks is None:
        forecast_chunks = {"time": 48, "latitude": 721, "longitude": 1440}

    if "zarr" in forecast_dir:
        forecast_ds = xr.open_zarr(forecast_dir, chunks=forecast_chunks)
    elif "parq" in forecast_dir or "json" in forecast_dir:  # "parq" also matches "parquet"
        forecast_ds = open_kerchunk_reference(
            forecast_dir,
            # Pass a copy: the json reference path adds keys to its storage options.
            storage_options=dict(forecast_storage_options),
            chunks=forecast_chunks,
        )
    else:
        raise TypeError(
            "Unknown file type found in forecast path, only json, parquet, and zarr are supported."
        )
    forecast_ds = forecast_preprocess(forecast_ds)
    forecast_ds = _maybe_rename_and_subset_forecast_dataset(
        forecast_ds, forecast_variable_mapping
    )
    forecast_ds = _maybe_convert_dataset_lead_time_to_int(forecast_ds)

    return forecast_ds


def open_kerchunk_reference(
    forecast_dir: str,
    storage_options: Optional[dict] = None,
    chunks: Optional[dict] = None,
) -> xr.Dataset:
    """Open a dataset from a kerchunked reference file in parquet or json format.

    This has been built primarily for the CIRA MLWP S3 bucket's data (https://registry.opendata.aws/aiwp/),
    but can work with other data in the future. Currently only supports CIRA data unless
    the schema is identical to the CIRA schema.

    Args:
        forecast_dir: Path or URI of the reference file. Paths containing "parq"/"parquet"
            are opened with the ``kerchunk`` engine; paths containing "json" through the
            fsspec ``reference://`` filesystem with the ``zarr`` engine.
        storage_options: fsspec options for the referenced data. Defaults to
            ``{"remote_protocol": "s3", "remote_options": {"anon": True}}``. Never modified.
        chunks: Currently unused; accepted so callers can pass it.

    Returns:
        The opened dataset. Parquet references are loaded eagerly with ``compute()``.

    Raises:
        TypeError: If ``forecast_dir`` is neither a parquet nor a json reference.
    """
    if storage_options is None:
        storage_options = {"remote_protocol": "s3", "remote_options": {"anon": True}}

    if "parq" in forecast_dir:  # also matches "parquet"
        kerchunk_ds = xr.open_dataset(
            forecast_dir,
            engine="kerchunk",
            storage_options=storage_options
        )
        # NOTE(review): compute() loads every variable into memory -- confirm this is
        # intended for large reference sets.
        kerchunk_ds = kerchunk_ds.compute()
    elif "json" in forecast_dir:
        # Build a new dict instead of setting "fo" on the caller's (or a shared default) dict.
        reference_options = {**storage_options, "fo": forecast_dir}
        kerchunk_ds = xr.open_dataset(
            "reference://",
            engine="zarr",
            backend_kwargs={
                "storage_options": reference_options,
                "consolidated": False,
            },
        )
    else:
        raise TypeError(
            "Unknown kerchunk file type found in forecast path, only json and parquet are supported."
        )
    return kerchunk_ds


def _maybe_rename_and_subset_forecast_dataset(
    forecast_ds: xr.Dataset, variable_mapping: dict[str, str]
) -> xr.Dataset:
    """Subset the forecast dataset to the mapped variables and rename them.

    Only keys of ``variable_mapping`` that are data variables of ``forecast_ds`` are
    kept; every other data variable is dropped.

    Args:
        forecast_ds: The forecast dataset to subset and rename.
        variable_mapping: Maps source variable names in ``forecast_ds`` to the names
            expected by the evaluation routines, e.g. ``{"t2": "surface_air_temperature"}``.

    Returns:
        A dataset holding only the mapped variables under their standardized names.
        If no key matches, the result has no data variables.
    """
    # Drop keys the dataset does not have: subsetting or renaming a missing name raises.
    present_mapping = {
        source_name: target_name
        for source_name, target_name in variable_mapping.items()
        if source_name in forecast_ds.data_vars
    }
    return forecast_ds[list(present_mapping)].rename(present_mapping)


def _maybe_convert_dataset_lead_time_to_int(
    dataset: xr.Dataset
) -> xr.Dataset:
    """Express the forecast lead time as numeric hours under the name ``lead_time``.

    The lead time is read from ``lead_time`` when the dataset has a variable of that
    name (data variable or coordinate), otherwise from ``time``. Then:

    * timedelta values (any unit) are converted to whole hours as int;
    * numeric values (int or float) are left unchanged, presumably already in hours;
    * anything else (e.g. datetime64 valid times) must be evenly spaced and is replaced
      by hours since the first step, ``0, dt, 2 * dt, ...`` (float).

    When the values came from ``time``, that name is renamed to ``lead_time``. The input
    dataset is not modified.

    Args:
        dataset: The input xarray Dataset that uses the schema's variable names.

    Returns:
        A shallow copy of ``dataset`` with the lead-time variable adjusted.

    Raises:
        KeyError: If the dataset has neither ``lead_time`` nor ``time``.
        ValueError: If datetime-like steps are not evenly spaced.
    """
    # Check all variables, not only data_vars: lead_time is normally a coordinate.
    source_name = "lead_time" if "lead_time" in dataset.variables else "time"
    lead_time = dataset[source_name]
    dataset = dataset.copy()  # shallow copy; assignments below don't touch the caller's dataset

    # timedelta64 subclasses numpy's integer types, so it must be tested first.
    if np.issubdtype(lead_time.dtype, np.timedelta64):
        hours = (lead_time.values / np.timedelta64(1, "h")).astype(int)
        dataset[source_name] = (lead_time.dims, hours)
    elif np.issubdtype(lead_time.dtype, np.number):
        # Already numeric, do nothing.
        pass
    else:
        unique_steps = np.unique(np.diff(lead_time.values))
        if unique_steps.size > 1:
            raise ValueError(
                f"Cannot infer lead times from unevenly spaced '{source_name}' values; "
                f"found steps {unique_steps}."
            )
        step_hours = unique_steps[0] / np.timedelta64(1, "h") if unique_steps.size else 0.0
        # arange(n) * step always yields exactly n values, unlike a float-stepped arange.
        dataset[source_name] = (
            lead_time.dims,
            np.arange(lead_time.shape[0]) * step_hours,
        )

    if source_name == "time":
        dataset = dataset.rename({"time": "lead_time"})
    return dataset

In [67]:
# Quick look at the kerchunk reference used by the demo runs below.
# (Listing the whole bucket was dropped: its result was discarded and it costs a network scan.)
bucket = "gs://extremeweatherbench"
storage_options = {
    "remote_protocol": "s3",
    "remote_options": {"anon": True},
}  # options passed to fsspec
kerchunk_ds = xr.open_dataset(
    f"{bucket}/FOUR_v200_GFS.parq",
    engine="kerchunk",
    storage_options=storage_options,
)
kerchunk_ds

# Orchestration

In [68]:
def _build_forecast(forecast_dir: str, case_operator: CaseOperator, **kwargs):
    """Open the forecast and cut it down to one case's time window and location.

    Args:
        forecast_dir: Path or URI of the forecast data.
        case_operator: Supplies the forecast variables, the case dates and the case
            location used for subsetting.
        **kwargs: ``forecast_variable_mapping``, ``forecast_storage_options`` and
            ``forecast_chunks`` are read from here, each with a fallback default.

    Returns:
        The forecast restricted to the selected init times and masked (``drop=True``)
        to the case location.
    """
    default_storage_options = {'remote_protocol': 's3', 'remote_options': {'anon': True}}
    default_chunks = {'time': 48, 'latitude': 721, 'longitude': 1440}
    opened_ds = open_and_preprocess_forecast_dataset(
        forecast_dir,
        forecast_variables=case_operator.forecast_variables,
        forecast_variable_mapping=kwargs.get('forecast_variable_mapping', {}),
        forecast_storage_options=kwargs.get('forecast_storage_options', default_storage_options),
        forecast_chunks=kwargs.get('forecast_chunks', default_chunks),
    )

    individual_case = case_operator.case
    init_time_indices = utils.derive_indices_from_init_time_and_lead_time(
        opened_ds, individual_case.start_date, individual_case.end_date
    )
    # np.unique sorts and de-duplicates the indices before selecting init times.
    case_window_ds = opened_ds.isel(init_time=np.unique(init_time_indices))
    return individual_case.location.mask(case_window_ds, drop=True)

class ExtremeWeatherBench:
    """Notebook prototype of the ExtremeWeatherBench orchestration loop."""

    def run(self, events: list[EventType], forecast_dir: str, cache_dir: Optional[str] = None, *args, **kwargs):
        '''Build the observation and forecast datasets for the first case of the given events.

        Events are visited in order. While the pipeline is under development this returns as
        soon as the first case operator has been processed, so its intermediate datasets can
        be inspected in the notebook.

        Args:
            events: The event types to run, in order.
            forecast_dir: Path or URI of the forecast data.
            cache_dir: Currently unused.
            *args: Currently unused.
            **kwargs: Forwarded to ``build_datasets``. ``observation_storage_options`` and
                ``observation_variable_mapping`` are also passed to ``build_case_operators``.

        Returns:
            ``(observation_datasets, forecast_dataset)`` for the first case found, or
            ``None`` if no event yields a case.
        '''
        observation_storage_options = kwargs.get(
            'observation_storage_options',
            {'remote_protocol': 's3', 'remote_options': {'anon': True}},
        )
        observation_variable_mapping = kwargs.get('observation_variable_mapping', {})

        for event in events:
            case_operators = event.build_case_operators(
                observation_storage_options=observation_storage_options,
                observation_variable_mapping=observation_variable_mapping,
            )
            # TODO: evaluate every case operator of every event and aggregate the results
            # instead of returning the first one.
            if case_operators:
                return self.build_datasets(case_operators[0], forecast_dir, **kwargs)
        return None

    def build_datasets(self, case_operator: CaseOperator, forecast_dir: str, **kwargs):
        """Return ``(observation_datasets, forecast_dataset)`` for one case operator."""
        observation_ds = case_operator.build_observations(**kwargs)
        forecast_ds = self.build_forecast(case_operator, forecast_dir, **kwargs)
        return observation_ds, forecast_ds

    def build_forecast(self, case_operator: CaseOperator, forecast_dir: str, **kwargs):
        """Open and subset the forecast for one case, then derive its forecast variables."""
        pre_derived_forecast_ds = _build_forecast(forecast_dir, case_operator, **kwargs)
        derived_forecast_ds = maybe_derive_variables(pre_derived_forecast_ds, case_operator, variables=case_operator.forecast_variables)
        return derived_forecast_ds



# Test: practically perfect hindcast (PPH)

In [69]:
def practically_perfect_hindcast(
    ds: xr.Dataset,
    resolution: float = 0.25,
    # TODO: add report type back in
    # report_type: Union[Literal["all"], list[Literal["tor", "hail", "wind"]]] = "all",
    sigma: float = 1.5,
) -> xr.Dataset:
    """Compute the Practically Perfect Hindcast (PPH) using storm report data using latitude/longitude grid spacing
    instead of the NCEP 212 Eta Lambert Conformal projection; based on the method described in Hitchens et al 2013,
    https://doi.org/10.1175/WAF-D-12-00113.1

    Args:
        ds: An xarray Dataset containing the storm report data as a sparse (COO) array.
        resolution: The resolution of the grid in degrees to use. Default is 0.25 degrees.
            Should evenly divide 180 and 360.
        sigma: The standard deviation of the gaussian filter to use. Default is 1.5.
    Returns:
        pph: An xarray Dataset containing the PPH and storm report data.

    Raises:
        ValueError: If ``resolution`` is not positive.
    """
    # Local import: gaussian_filter is not among the notebook's top-level imports.
    from scipy.ndimage import gaussian_filter

    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution}")

    # Global grid: latitudes -90..90 inclusive, longitudes 0..(360 - resolution) so 0 and 360
    # are not duplicated (721 x 1440 at the default 0.25 degrees).
    n_lat = int(round(180.0 / resolution)) + 1
    n_lon = int(round(360.0 / resolution))
    grid_lats = np.linspace(-90.0, 90.0, n_lat)
    grid_lons = np.arange(n_lon) * resolution

    # Create target coordinates for regridding
    target_coords = {
        "latitude": xr.DataArray(grid_lats, dims=["latitude"]),
        "longitude": xr.DataArray(grid_lons, dims=["longitude"]),
    }

    # Regrid the sparse dataset to the fixed global grid
    # First, ensure longitude is in 0-360 range
    if "longitude" in ds.coords:
        ds = ds.assign_coords(longitude=utils.convert_longitude_to_360(ds.longitude))

    # Interpolate the sparse data to the target grid
    regridded_ds = ds.interp(
        latitude=target_coords["latitude"],
        longitude=target_coords["longitude"],
        method="nearest",  # Use nearest neighbor for sparse data
    )

    if "reports" in regridded_ds.data_vars:
        # Cells without reports become 0 so the smoothing below sees no NaNs.
        reports = regridded_ds["reports"].fillna(0)
        regridded_ds["reports"] = reports
        # Assumes reports is 2-D (latitude, longitude) after regridding -- TODO confirm.
        smoothed_reports = gaussian_filter(reports.values, sigma=sigma)
        regridded_ds["practically_perfect"] = xr.DataArray(
            smoothed_reports,
            dims=["latitude", "longitude"],
            coords={"latitude": grid_lats, "longitude": grid_lons},
        )

    return regridded_ds

# Simple event example (heat wave):

In [70]:
import importlib.resources

# Locate events.yaml inside the installed extremeweatherbench package instead of a
# machine-specific absolute path.
EVENTS_YAML_PATH = importlib.resources.files("extremeweatherbench") / "data" / "events.yaml"
case_yaml = utils.read_event_yaml(str(EVENTS_YAML_PATH))
# Keep only the first case so the demo stays small.
test_yaml = {'cases': [case_yaml['cases'][0]]}
test_heat_wave = HeatWave(case_metadata=test_yaml)

In [71]:
# Forecast reference shared by the demo runs below (".parq" is opened as a kerchunk reference).
forecast_dir = "gs://extremeweatherbench/FOUR_v200_GFS.parq"

In [72]:
test_ewb = ExtremeWeatherBench()
test_ewb.run(
    events=[test_heat_wave],
    forecast_dir=forecast_dir,
    # The forecast pipeline reads the forecast_-prefixed keys; unprefixed
    # `storage_options` / `chunks` are not read anywhere and were silently ignored.
    forecast_storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
    forecast_chunks={'time': 48, 'latitude': 721, 'longitude': 1440},
    observation_variable_mapping={'2m_temperature': 'surface_air_temperature'},
    forecast_variable_mapping={'t2': 'surface_air_temperature'},
)

([<xarray.Dataset> Size: 503kB
  Dimensions:                  (time: 313, latitude: 20, longitude: 20)
  Coordinates:
    * latitude                 (latitude) float32 80B 50.0 49.75 ... 45.5 45.25
    * longitude                (longitude) float32 80B 235.2 235.5 ... 239.8 240.0
    * time                     (time) datetime64[ns] 3kB 2021-06-20 ... 2021-07-03
  Data variables:
      surface_air_temperature  (time, latitude, longitude) float32 501kB dask.array<chunksize=(48, 20, 20), meta=np.ndarray>
  Attributes:
      last_updated:           2025-08-01 01:47:15.698116+00:00
      valid_time_start:       1940-01-01
      valid_time_stop:        2025-04-30
      valid_time_stop_era5t:  2025-07-26],
 <xarray.Dataset> Size: 3MB
 Dimensions:                  (init_time: 45, lead_time: 41, latitude: 20,
                               longitude: 20)
 Coordinates:
   * init_time                (init_time) datetime64[ns] 360B 2021-06-10T12:00...
   * latitude                 (latitude) float

In [73]:



# Sketch of the intended multi-event API; kept as comments because it is not runnable yet:
# `severe`, `cache_dir` and `output_dir` are not defined at this point in the notebook, and
# ExtremeWeatherBench.run does not currently use `cache_dir`.
# future approach
# test_ewb.run(
#     events=[test_heat_wave, severe],
#     forecast_dir=forecast_dir,
#     forecast_storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}}, 
#     chunks={'time': 48, 'latitude': 721, 'longitude': 1440}, 
#     observation_variable_mapping={'2m_temperature':'surface_air_temperature'}, 
#     forecast_variable_mapping={'t2':'surface_air_temperature'},
#     cache_dir=cache_dir,
#     output_dir=output_dir
# )

# NOTE(review): `test_heat_wave_event_operator` belongs to the earlier EventOperator API.
# from tqdm.auto import tqdm
# for n in tqdm(test_heat_wave_event_operator.pre_composed_case_operators):
#     print(n)

In [82]:
# Run a single severe-convection case end to end.
severe_cases = [c for c in case_yaml['cases'] if c['event_type'] == 'severe_convection']
test_yaml = {'cases': [severe_cases[0]]}
severe = SevereConvection(case_metadata=test_yaml)
ewb = ExtremeWeatherBench()
ewb.run(
    events=[severe],
    forecast_dir=forecast_dir,
    # The forecast pipeline reads the forecast_-prefixed keys; unprefixed
    # `storage_options` / `chunks` are not read anywhere and were silently ignored.
    forecast_storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
    forecast_chunks={'time': 48, 'latitude': 721, 'longitude': 1440},
    observation_variable_mapping=rs.ERA5_MAPPING,
    forecast_variable_mapping={
        'msl': 'mean_sea_level_pressure',
        'r': 'relative_humidity',
        'sp': 'surface_pressure',
        't2': 'surface_air_temperature',
        'u10': 'surface_eastward_wind',
        'v10': 'surface_northward_wind',
        't': 'air_temperature',
        'u': 'eastward_wind',
        'v': 'northward_wind',
        'z': 'geopotential_height',
    },
)

([<xarray.Dataset> Size: 309kB
  Dimensions:                       (valid_time: 197, latitude: 176,
                                     longitude: 211)
  Coordinates:
    * valid_time                    (valid_time) datetime64[ns] 2kB 2024-07-13 ...
    * latitude                      (latitude) float64 1kB 37.92 38.18 ... 48.83
    * longitude                     (longitude) float64 2kB -113.0 ... -67.87
  Data variables:
      report_type                   (valid_time, latitude, longitude) object 3kB <COO: nnz=249, fill_value=nan>
      Scale                         (valid_time, latitude, longitude) object 3kB <COO: nnz=249, fill_value=nan>
      practically_perfect_hindcast  (latitude, longitude) float64 297kB nan ......],
 <xarray.Dataset> Size: 5GB
 Dimensions:                              (init_time: 21, lead_time: 41,
                                           latitude: 100, longitude: 200,
                                           level: 13)
 Coordinates:
   * init_time      

In [95]:
# With the EventOperator class
# NOTE(review): this cell targets an earlier API. `EventOperator` and `heat_waves` are not
# defined in this part of the notebook, and ExtremeWeatherBench (Orchestration section)
# defines no __init__, so it cannot accept these constructor arguments. TODO: port to
# `ExtremeWeatherBench().run(events=[...], forecast_dir=forecast_dir, ...)` or remove.
simple_event_operator = EventOperator(events=[heat_waves, severe])
ewb = ExtremeWeatherBench(simple_event_operator, forecast_dir)

In [96]:
# NOTE(review): the ExtremeWeatherBench class in the Orchestration section has no
# `event_operator` attribute; the output below appears to come from an earlier version.
ewb.event_operator

EventOperator(events=[<__main__.HeatWave object at 0x1198d6650>, <__main__.SevereConvection object at 0x1198d69e0>], pre_composed_metrics={'heat_wave': [<class '__main__.MaximumMAE'>, <class '__main__.MaxMinMAE'>, <class '__main__.RegionalRMSE'>, <class '__main__.OnsetME'>, <class '__main__.DurationME'>], 'severe_convection': [<class '__main__.CSI'>, <class '__main__.LeadTimeDetection'>, <class '__main__.RegionalHitsMisses'>, <class '__main__.HitsMisses'>]}, pre_composed_observations={'heat_wave': [<class '__main__.ERA5'>], 'severe_convection': [<class '__main__.LSR'>]}, pre_composed_forecast_variables={'heat_wave': ['surface_air_temperature'], 'severe_convection': [<class '__main__.CravenSignificantSevereParameter'>, 'air_temperature', 'dewpoint_temperature', 'relative_humidity', 'eastward_wind', 'northward_wind', 'surface_eastward_wind', 'surface_northward_wind', 'air_temperature', 'dewpoint_temperature', 'relative_humidity', 'eastward_wind', 'northward_wind', 'surface_eastward_wind'

# Tests:

In [25]:
# Builds SevereConvection over every case in events.yaml; cases of other event types are
# filtered out when case operators are built.
severe = SevereConvection(case_metadata=case_yaml)

# With the EventOperator class
# NOTE(review): earlier API -- ExtremeWeatherBench (Orchestration section) defines no
# __init__, so this constructor call cannot succeed as written. TODO: port or remove.
simple_event_operator = EventOperator(events=[severe])
ewb = ExtremeWeatherBench(simple_event_operator, forecast_dir)

In [None]:
# Plot the practically perfect hindcast for every severe-convection case.
# NOTE(review): `event_operator` and `process_observations` are not defined on the
# ExtremeWeatherBench class in the Orchestration section, and `plot_practically_perfect_hindcast`
# is not defined in this part of the notebook; the interrupted output below suggests it was
# never completed. TODO: port to the current CaseOperator / build_observations API.
for n in ewb.event_operator.pre_composed_case_operators:
    test_output = ewb.process_observations(n)
    ax = plot_practically_perfect_hindcast(test_output)
    plt.show()

Exception ignored in: <function WeakSet.__init__.<locals>._remove at 0x100bd85e0>
Traceback (most recent call last):
  File "/Users/taylor/.local/share/uv/python/cpython-3.13.1-macos-aarch64-none/lib/python3.13/_weakrefset.py", line 39, in _remove
    def _remove(item, selfref=ref(self)):
KeyboardInterrupt: 
