# 🌍 Demo 4: Putting Your Forecast to the Test

## 🧠 Learning Objectives
- Use WeatherBench 2-style verification metrics (computed here with the `xskillscore` Python library) for forecast verification.
- Load custom forecast and ERA5 data for comparison.
- Compute standard metrics like RMSE and ACC.
- Focus evaluation on specific countries or regions.
- Visualize how forecast skill changes with lead time.
- Understand the power of localized hindcasting for tailored insights.


## 🎯 Objective

In this demo, we scientifically **grade the forecast** you created in **Demo 2**, following the approach of professional benchmarks like `weatherbench2` and computing the metrics with `xarray` and `xskillscore`. Instead of relying on generic operational forecast products, we’ll do a **custom local evaluation**.

**Theme:** *We are doing this locally to get custom answers that operational websites can't provide.*

You’ll compute metrics like RMSE and ACC over custom regions (like Kenya or Chile), visualize model skill, and learn why localized hindcasts are powerful.


In [20]:
import xarray as xr
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown
import xskillscore as xs  # For verification metrics

# Optional: Import warnings to suppress potential warnings
import warnings
warnings.filterwarnings('ignore')

In [21]:
# Free-text box where the user types the location of the NetCDF forecast
# produced in Demo 2; its `.value` is read later when the load button is clicked.
path_box_layout = widgets.Layout(width='80%')
forecast_path = widgets.Text(
    layout=path_box_layout,
    placeholder='Enter path to NetCDF forecast from Demo 2',
    description='Forecast file:',
)

display(forecast_path)

Text(value='', description='Forecast file:', layout=Layout(width='80%'), placeholder='Enter path to NetCDF for…

In [22]:
def load_forecast(path):
    """Open the forecast at *path* and report the outcome in the notebook.

    Args:
        path: Location handed straight to ``xr.open_dataset``.

    Returns:
        The opened dataset, or ``None`` when opening (or reporting success)
        raised; in that case the error text is rendered as Markdown instead
        of propagating.
    """
    opened = None
    try:
        opened = xr.open_dataset(path)
        display(Markdown("✅ Forecast data loaded successfully."))
    except Exception as err:
        # Notebook UI boundary: show the problem to the user rather than raising.
        display(Markdown(f"❌ Error loading forecast: {err}"))
        opened = None
    return opened

# Most recently loaded forecast dataset; stays None until a load succeeds.
forecast_ds = None

# Button to trigger loading, plus an output area for its status messages
load_button = widgets.Button(description="Load Forecast")
load_output = widgets.Output()


def on_load_clicked(b):
    """Load the forecast named in the path box when the button is clicked.

    Args:
        b: The clicked button (supplied by ipywidgets, unused).

    Side effects:
        Rebinds the module-level ``forecast_ds`` to the result of
        ``load_forecast`` (which is ``None`` when loading fails) and writes
        status messages into ``load_output``.
    """
    global forecast_ds
    with load_output:
        load_output.clear_output()
        # Strip once and reuse: pasted paths often carry trailing whitespace
        # or a newline, which would pass the emptiness check but fail to open.
        path = forecast_path.value.strip()
        if not path:
            display(Markdown("❌ Please enter a valid file path."))
            return
        forecast_ds = load_forecast(path)


load_button.on_click(on_load_clicked)
display(load_button, load_output)


Button(description='Load Forecast', style=ButtonStyle())

Output()

In [23]:
import xarray as xr

# Mapping from forecast variable names to ERA5 GCS variable names.
# Defined unconditionally so later cells can always look names up.
var_map = {
    "2t": "2m_temperature",
    # NOTE(review): the store name suggests 37 pressure levels, so this
    # variable presumably needs a level selection before comparison — confirm.
    "z500": "geopotential",  # adjust if needed
    "u10": "10m_u_component_of_wind",
    "v10": "10m_v_component_of_wind",
    # Add more mappings if needed
}

# Public Zarr store on GCS
era5_path = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"

# Always bind the name so downstream "is None" checks work even when this
# cell runs before a forecast has been loaded.
era5_ds = None

if forecast_ds is not None:
    forecast_time = forecast_ds.time
    forecast_var = list(forecast_ds.data_vars)[0]  # e.g., "2t"

    # Unknown forecast names fall back to being used as the ERA5 name directly.
    era5_var = var_map.get(forecast_var, forecast_var)

    full_era5 = xr.open_zarr(
        era5_path,
        chunks=None,
        storage_options={"token": "anon"}
    )[[era5_var]]

    # Select only the matching forecast times
    era5_ds = full_era5.sel(time=forecast_time)

    display(Markdown(f"✅ ERA5 ground truth variable `{era5_var}` loaded and time-aligned."))
else:
    # Make the skipped load visible instead of silently doing nothing.
    display(Markdown("❌ No forecast loaded yet. Click **Load Forecast** above, then re-run this cell."))


✅ ERA5 ground truth variable `2m_temperature` loaded and time-aligned.

In [34]:
# Bounding boxes keyed by region name; each tuple is unpacked elsewhere in the
# notebook as (south, north, west, east), in degrees.
region_bounds = {
    # Latitude bands spanning every longitude
    "Global": (-90, 90, 0, 360),
    "Northern Hemisphere": (0, 90, 0, 360),
    "Tropics": (-30, 30, 0, 360),
    # Country-sized boxes
    "Bangladesh": (20.5, 26.5, 88, 92.5),
    "Chile": (-56, -17, -76, -66),
    "Nigeria": (4, 14, 3, 15),
    "Ethiopia": (3, 15, 33, 48),
    "Kenya": (-5, 5, 33, 42),
}

# Single-choice region picker, defaulting to the whole globe.
region_selector = widgets.Dropdown(
    description="Region:",
    options=tuple(region_bounds),
    value="Global",
)

# Multi-choice metric picker with every supported metric pre-selected.
# Supported: rmse, acc (anomaly correlation)
supported_metrics = ("rmse", "acc")
metric_selector = widgets.SelectMultiple(
    description="Metrics:",
    options=supported_metrics,
    value=supported_metrics,
)

display(region_selector, metric_selector)

Dropdown(description='Region:', options=('Global', 'Northern Hemisphere', 'Tropics', 'Bangladesh', 'Chile', 'N…

SelectMultiple(description='Metrics:', index=(0, 1), options=('rmse', 'acc'), value=('rmse', 'acc'))

In [35]:
# -------------------------------
# Step 5: Spatial Subsetting with Flexible Coordinate Names
# -------------------------------

def standardize_spatial_dims(ds):
    """Return *ds* with spatial dims named 'latitude'/'longitude' and no negative longitudes.

    The first matching dimension name from a list of common aliases is renamed
    to the standard name. If any longitude is negative, all longitudes are
    wrapped with ``% 360`` (to match the 0-360 convention used for ERA5 in
    this notebook) and the dataset is re-sorted by longitude.

    Args:
        ds: Dataset whose dimensions include a latitude-like and a
            longitude-like name.

    Returns:
        A new dataset object; the input is not modified.

    Raises:
        ValueError: If no latitude/longitude dimension is recognised, or the
            renamed dimension has no coordinate values attached.
    """
    # Shallow copy so the caller's dataset object is never touched.
    ds = ds.copy()

    # Common aliases, in priority order
    lat_names = ('latitude', 'lat', 'g0_lat_1', 'y')
    lon_names = ('longitude', 'lon', 'g0_lon_2', 'x')

    lat_dim = next((name for name in lat_names if name in ds.dims), None)
    lon_dim = next((name for name in lon_names if name in ds.dims), None)

    if lat_dim is None or lon_dim is None:
        raise ValueError(f"Could not find latitude/longitude in dataset dimensions: {list(ds.dims)}")

    # Rename only what actually differs from the standard names.
    renames = {
        old: new
        for old, new in ((lat_dim, 'latitude'), (lon_dim, 'longitude'))
        if old != new
    }
    if renames:
        ds = ds.rename(renames)

    # A dimension can exist without coordinate values; label-based selection
    # needs the values, so fail early with a clear message.
    if 'latitude' not in ds.coords:
        raise ValueError("Latitude coordinate not found.")
    if 'longitude' not in ds.coords:
        raise ValueError("Longitude coordinate not found.")

    # Convert from -180..180 to 0..360 if needed. A single modulo does the
    # whole wrap; shifting by 180 first only adds floating-point rounding.
    if bool(ds['longitude'].min() < 0):
        ds = ds.assign_coords(longitude=(ds['longitude'] % 360))
        # Sort by longitude because wrapping may unsort
        ds = ds.sortby('longitude')

    return ds

def subset_region(ds, region_key):
    """Return the part of *ds* inside the bounding box registered for *region_key*.

    Args:
        ds: Dataset with latitude/longitude dimensions (any alias accepted by
            ``standardize_spatial_dims``).
        region_key: Key into ``region_bounds``; the value is unpacked as
            ``(south, north, west, east)``.

    Returns:
        The dataset restricted to the box, with standardized dimension names.

    Raises:
        NotImplementedError: If the region is the string ``"land"``.
        RuntimeError: If the spatial selection itself fails.
    """
    ds = standardize_spatial_dims(ds)  # Ensure consistent names
    bounds = region_bounds[region_key]

    if isinstance(bounds, str) and bounds == "land":
        raise NotImplementedError("Land mask not implemented here. Skipping.")

    south, north, west, east = bounds

    try:
        # Label slices must run in the same direction as the index, so follow
        # the dataset's own latitude ordering instead of assuming 90 -> -90.
        lat = ds['latitude']
        descending = lat.size > 1 and bool(lat[0] > lat[-1])
        lat_slice = slice(north, south) if descending else slice(south, north)
        # Apply the latitude window first so every longitude branch gets it.
        subset = ds.sel(latitude=lat_slice)

        # A box spanning the full circle keeps every longitude. This must be
        # checked before the modulo, which would turn (0, 360) into (0, 0)
        # and select a single column.
        if east - west >= 360:
            return subset

        # Make sure longitudes are in 0–360 to match the standardized dataset
        west = west % 360
        east = east % 360

        if west <= east:
            return subset.sel(longitude=slice(west, east))

        # Crosses 0° meridian (e.g., 350 to 10): stitch the two pieces together
        west_part = subset.sel(longitude=slice(west, 360))
        east_part = subset.sel(longitude=slice(0, east))
        return xr.concat([west_part, east_part], dim='longitude')
    except Exception as e:
        raise RuntimeError(f"Error during spatial selection: {e}") from e
# Widgets for the verification step: an output area that collects messages
# and plots, and the button that triggers the run.
output = widgets.Output()
run_button = widgets.Button(description="Run Verification")

def run_verification(forecast, truth, region, metrics):
    """Score *forecast* against *truth* over one region.

    Only the first data variable of *forecast* is evaluated; its ERA5
    counterpart is looked up through ``var_map`` (falling back to the same
    name). Both datasets are cut to the region, inner-aligned on their shared
    coordinates, and each metric is reduced over latitude and longitude, so
    every other dimension survives in the result.

    Args:
        forecast: Forecast dataset.
        truth: Reference dataset.
        region: Key into ``region_bounds``.
        metrics: Collection of metric names; ``"rmse"`` and ``"acc"`` are
            recognised, anything else is ignored.

    Returns:
        An ``xr.Dataset`` with one variable per computed metric.

    Raises:
        ValueError: If nothing is left to compare after subsetting/alignment.
    """
    # Get variable name
    var_name = list(forecast.data_vars)[0]
    era5_var = var_map.get(var_name, var_name)

    # Subset both datasets
    fc_sub = subset_region(forecast, region)
    tr_sub = subset_region(truth, region)

    # Keep only coordinate labels present in both (time and grid)
    fc_sub, tr_sub = xr.align(fc_sub, tr_sub, join="inner")

    # Extract data
    fc_var = fc_sub[var_name]
    tr_var = tr_sub[era5_var]

    # An empty overlap would otherwise surface as NaN curves or an opaque
    # library error; say what actually happened.
    if fc_var.size == 0 or tr_var.size == 0:
        raise ValueError(
            f"No overlapping data between forecast and truth for region '{region}'. "
            "Check that both share the same grid and time coordinates."
        )

    spatial_dims = ['latitude', 'longitude']
    results = xr.Dataset()

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        if "rmse" in metrics:
            results['rmse'] = xs.rmse(fc_var, tr_var, dim=spatial_dims, skipna=True)

        if "acc" in metrics:
            # No climatology is available here, so this is the spatial pattern
            # correlation of the raw fields rather than a climatology-based
            # ACC. Pearson correlation is unaffected by subtracting a constant,
            # so removing the spatial mean beforehand would change nothing.
            # skipna matches the RMSE call so both metrics treat gaps alike.
            results['acc'] = xs.pearson_r(fc_var, tr_var, dim=spatial_dims, skipna=True)

    return results

def plot_results(result):
    """Draw one figure per metric in *result*.

    Args:
        result: Dataset whose data variables are the metrics to plot.

    One-dimensional metrics are drawn as a line with point markers. Any other
    dimensionality is handed to xarray's default plot without the marker
    option, which only line plots understand.
    """
    for metric in result.data_vars:
        var = result[metric]
        plt.figure(figsize=(10, 4))

        if var.ndim == 1:
            var.plot(marker='o')
        else:
            var.plot()

        plt.title(f"{metric.upper()} vs Forecast Lead Time")
        plt.xlabel("Lead Time")
        plt.ylabel(metric.upper())
        plt.grid(True, alpha=0.5)
        plt.xticks(rotation=0)
        plt.tight_layout()
        plt.show()

def on_run_clicked(b):
    """Run the verification and plot it when the button is clicked.

    Args:
        b: The clicked button (supplied by ipywidgets, unused).

    All messages and figures are written into ``output``. Failures inside the
    verification or plotting are shown as a Markdown error instead of raising.
    """
    with output:
        output.clear_output()
        # Look the datasets up by name: they are bound by earlier cells and
        # may not exist yet, in which case a direct reference would raise
        # NameError instead of showing the message below.
        forecast = globals().get("forecast_ds")
        truth = globals().get("era5_ds")
        if forecast is None or truth is None:
            display(Markdown("❌ Please load both forecast and ERA5 data first."))
            return
        display(Markdown("⏳ Running verification..."))
        try:
            result = run_verification(
                forecast,
                truth,
                region_selector.value,
                list(metric_selector.value)
            )
            display(Markdown("✅ Verification complete."))
            plot_results(result)
        except Exception as e:
            # Notebook UI boundary: report the problem to the user.
            display(Markdown(f"❌ Error during verification: {e}"))

# Register the click handler, then render the button together with the output
# area that the handler writes into.
run_button.on_click(on_run_clicked)
display(run_button, output)

Button(description='Run Verification', style=ButtonStyle())

Output()

## 📊 Interpret Your Results

- **Look at the RMSE plot**: How does the error change as the forecast lead time increases? Why is this expected?
- **Try changing the region** from "Global" to "Kenya" and re-run the analysis.
    - Does the model's **ACC score** degrade faster or slower in Kenya?
    - What does this imply about model performance for **East African** weather?

Play with different regions and metrics to gain insights!


## 🔍 Key Takeaways

- Large operational centers provide broad forecasts — but can't offer **custom regional analysis**.
- By running a **local hindcast** and benchmarking it with WeatherBench-style metrics (the approach popularized by tools like `weatherbench2`, computed here with `xskillscore`), you can answer **targeted, high-impact questions**:
    - *How accurate is this model over Bangladesh during the monsoon?*
    - *How reliable is the 5-day forecast for heatwaves in Chile?*
- This is the true power of open science and AI: **empowering local experts** to run meaningful evaluations for their region.

🧪 Keep exploring. Try different models, dates, and regions!
