# General Pre-processing

In [2]:
import xarray as xr

In [3]:
# Forecast output to benchmark, kept in a constant like the other data paths.
# NOTE(review): the file name suggests an ERA5-initialised run starting
# 2023-06-30 00Z with a 360 h lead -- confirm against the job that wrote it.
FORECAST_PATH = "./output_states/init_ERA5_20230630T00_lead_360.nc"
ds = xr.open_dataset(FORECAST_PATH)

In [4]:
# Normalise coordinate dtypes before any renaming or selection.

# Lead time: reinterpret as hour-resolution timedeltas, then widen to
# nanosecond resolution.
# NOTE(review): presumably `step` is stored as a count of hours -- confirm.
step_in_hours = ds["step"].astype("timedelta64[h]")
ds["step"] = step_in_hours.astype("timedelta64[ns]")

# Horizontal coordinates are both cast to single precision.
for coord_name in ("lat", "lon"):
    ds[coord_name] = ds[coord_name].astype("float32")

In [5]:
# Short names in the forecast file -> descriptive names used by the rest of
# the notebook (e.g. `prediction_timedelta`, `latitude`, `longitude`).
VARIABLE_NAMES = {
    "2t": "2m_temperature",
    "tp": "total_precipitation_6hr",
    "z_500": "geopotential_500",
}
COORDINATE_NAMES = {
    "lat": "latitude",
    "lon": "longitude",
    "step": "prediction_timedelta",
}

# A single rename call handles data variables and coordinates together.
# Note: this cell is not re-runnable on its own, because the short names no
# longer exist once `ds` has been renamed.
ds = ds.rename({**VARIABLE_NAMES, **COORDINATE_NAMES})

# Preparing a pressure-level variable for benchmarking

In [58]:
# Turn the single-level `geopotential_500` field into a generic
# `geopotential` variable carrying an explicit length-1 `level` dimension,
# so it can be matched by `level` against the ERA5 and climatology stores.
z500_only = ds[["geopotential_500"]]

# The new `level` axis is inserted at position 2 with the single value 500.0.
# NOTE(review): axis=2 assumes a particular dimension order in the forecast
# file -- confirm against the dataset's dims.
z500_with_level = z500_only.expand_dims(dim={"level": [500.0]}, axis=2)

ds_z500 = z500_with_level.rename({"geopotential_500": "geopotential"})

# Run Benchmark

In [None]:
import apache_beam as beam
from weatherbenchX.data_loaders import xarray_loaders
from weatherbenchX.metrics import deterministic
from weatherbenchX import aggregation
from weatherbenchX import weighting
from weatherbenchX import binning
from weatherbenchX import time_chunks
from weatherbenchX import beam_pipeline

In [59]:
# Keep only lead times whose seconds component is zero, i.e. leads that are
# an exact multiple of 24 h (`.dt.seconds` excludes the whole-day part).
# The boolean mask is passed straight to `.sel` as the indexer.
is_whole_day_lead = ds_z500.prediction_timedelta.dt.seconds == 0
ds_preds = ds_z500.sel(prediction_timedelta=is_whole_day_lead)

In [20]:
# ERA5 reanalysis (ARCO-ERA5 public bucket on Google Cloud Storage), used as
# the verification target.
ERA5_PATH = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"

# chunks=None opens the store without dask; nothing is read here beyond
# metadata -- the actual values are pulled by the `.sel(...).compute()` calls
# in later cells.
FULL_ERA5 = xr.open_zarr(store=ERA5_PATH, chunks=None)

In [63]:
# Selection keys derived from the forecast, used below to subset the ERA5
# target and the climatology.
init_times = ds_preds.time
lead_times = ds_preds.prediction_timedelta

variables = list(ds_preds.data_vars)
levels = ds_preds.level.values

# Valid times = initialisation time + lead time (NumPy broadcasting on the
# raw arrays).
# NOTE: the name `times` is rebound to a TimeChunks object in a later cell;
# re-run this cell before re-running the ERA5 target selection.
times = init_times.values + lead_times.values

# Hour of day of the first initialisation time only.
hour = init_times.dt.hour[0].values

# Day of year of every valid time; xarray broadcasting makes this
# (time, prediction_timedelta)-shaped.
dayofyear = (init_times + lead_times).dt.dayofyear.values

In [64]:
# Ground truth: ERA5 restricted to the forecast's variables, valid times and
# pressure levels, then loaded into memory (this is the remote read).
era5_selection = FULL_ERA5[variables].sel(time=times, level=levels)
target_ds = era5_selection.compute()

In [65]:
# ERA5 land-sea mask taken at the first forecast initialisation time; the
# leftover scalar `time` coordinate is dropped before loading into memory.
# NOTE(review): LSM is not referenced by any other cell shown here.
first_init_time = ds_z500.time[0]
LSM = (
    FULL_ERA5["land_sea_mask"]
    .sel(time=first_init_time)
    .drop_vars("time")
    .compute()
)

In [68]:
# ERA5 climatology from the WeatherBench 2 public bucket, passed to the ACC
# metric in a later cell.
# NOTE(review): per the path this is a 1990-2017, 6-hourly, 1440x721 product
# -- confirm it matches the forecast grid.
CLIM_PATH = (
    "gs://weatherbench2/datasets/era5-hourly-climatology/1990-2017_6h_1440x721.zarr"
)

# `dayofyear` is (time, prediction_timedelta)-shaped. Flatten it rather than
# squeeze it: with a single lead time `squeeze()` collapses to a 0-d array,
# and `.sel` would then drop the `dayofyear` dimension entirely. Flattening
# always yields a 1-D indexer, so the dimension is kept -- the same reason
# `hour` is wrapped in a list below. For the usual (1, N > 1) case the result
# is identical to `squeeze()`.
clim_dayofyear = dayofyear.reshape(-1)

# Subset to the forecast's variables, the (first) initialisation hour, the
# valid days of year and the pressure levels, then load into memory.
CLIM_DS = (
    xr.open_zarr(CLIM_PATH, chunks={})[variables]
    .sel(hour=[hour], dayofyear=clim_dayofyear, level=levels)
    .compute()
)

In [70]:
# Wrap the in-memory datasets in WeatherBench-X loaders. Both loaders are
# restricted to the same variable list so predictions and targets line up.
prediction_data_loader = xarray_loaders.PredictionsFromXarray(
    variables=variables,
    ds=ds_preds,
)
target_data_loader = xarray_loaders.TargetsFromXarray(
    variables=variables,
    ds=target_ds,
)


In [74]:
# Chunking of (init time, lead time) pairs for the evaluation pipeline:
# one initialisation time per chunk, with every lead time in the same chunk.
# NOTE: this rebinds `times`, which an earlier cell used for the array of
# valid times.
init_time_values = ds_preds.time.values
lead_time_values = ds_preds.prediction_timedelta.values

times = time_chunks.TimeChunks(
    init_time_values,
    lead_time_values,
    init_time_chunk_size=1,
    lead_time_chunk_size=len(lead_time_values),
)

In [75]:
# Metrics to evaluate; ACC is given the climatology dataset loaded above.
metrics = {
    "rmse": deterministic.RMSE(),
    "acc": deterministic.ACC(CLIM_DS),
}

# Weight grid cells by area when averaging.
weigh_by = [weighting.GridAreaWeighting()]

# Change these to participant regions.
# Format: name -> ((lat_min, lat_max), (lon_min, lon_max)). Longitudes are on
# a 0-360 scale, so western longitudes are written as 360 - degrees_west.
regions = {
    "global": ((-90, 90), (0, 360)),
    # NOTE(review): presumably North America -- confirm the intended extent.
    "na": ((24.08, 50), (360 - 126, 360 - 65)),
    # lon_min (350) is greater than lon_max (36): the box is meant to span
    # 0 degrees east.
    # NOTE(review): confirm binning.Regions treats lon_min > lon_max as a
    # wrap-around rather than an empty selection.
    "europe": ((35, 71), (360 - 10, 36)),
}
bin_by = [binning.Regions(regions)]

# Reduce over initialisation time and both horizontal dimensions, with one
# bin per region.
aggregator = aggregation.Aggregator(
    reduce_dims=["init_time", "latitude", "longitude"],
    bin_by=bin_by,
    weigh_by=weigh_by,
)

In [76]:
# Build and execute the evaluation pipeline locally with Beam's DirectRunner.
# Using the pipeline as a context manager runs it on exit and blocks until it
# has finished, so `./out.nc` is fully written before any later cell reads
# it -- independent of whether the chosen runner happens to be synchronous.
# It also avoids echoing the RunnerResult repr as cell output.
with beam.Pipeline(runner="DirectRunner") as root:
    beam_pipeline.define_pipeline(
        root=root,
        times=times,
        predictions_loader=prediction_data_loader,
        targets_loader=target_data_loader,
        metrics=metrics,
        aggregator=aggregator,
        out_path="./out.nc",
    )



<apache_beam.runners.portability.fn_api_runner.fn_runner.RunnerResult at 0x7d26012536d0>