### Visualizing the Mean Absolute Error between forecast runs

In [None]:
import logging

import matplotlib.pyplot as plt

from jua.client import JuaClient
from jua.weather.models import Models
from jua.weather.variables import Variables

# Configure the root logger to emit INFO-level (and above) records, so
# informational log messages are visible in the notebook output.
logging.basicConfig(level=logging.INFO)

In [None]:
# Request parameters shared by both forecast loads below.
# NOTE(review): bounds presumably in degrees latitude / longitude (a box
# roughly covering Europe) — confirm against the jua API.
europe_lat_slice, europe_lon_slice = slice(36, 71), slice(-15, 50)

# Lead-time window, in hours.
forecast_window = slice(0, 48)

# Variables fetched from each forecast run (and plotted, one panel each).
variables = [
    Variables.AIR_TEMPERATURE_AT_HEIGHT_LEVEL_2M,
    Variables.WIND_SPEED_AT_HEIGHT_LEVEL_100M,
]

In [None]:
# Set up the client and the model to use.
# NOTE(review): JuaClient() takes no arguments here, so credentials are
# presumably picked up from the environment / local config — confirm.
client = JuaClient()
# Handle to the EPT1_5 weather model; its .forecast API is used below.
model = client.weather.get_model(Models.EPT1_5)

In [None]:
# Load two consecutive forecast runs: the second- and third-most-recent.
# This will take some time, depending on the speed of the internet
# connection (usually roughly 2 minutes in total).
# Index 0 (presumably the most recent init time) is skipped so that we always
# use runs that have a full forecast available.
# The init times are queried ONCE so that both indices refer to the same
# snapshot: calling get_available_init_times() twice could shift the list if
# a new run is published in between, making A and B the same run.
available_init_times = model.forecast.get_available_init_times()
forecast_a_init_time = available_init_times[1]
forecast_b_init_time = available_init_times[2]


def _load_forecast(init_time):
    """Fetch the configured region, variables and lead-time window for one run."""
    print(f"Loading forecast for init_time={init_time}")
    return model.forecast.get_forecast(
        init_time=init_time,
        variables=variables,
        latitude=europe_lat_slice,
        longitude=europe_lon_slice,
        prediction_timedelta=forecast_window,
    )


forecast_a = _load_forecast(forecast_a_init_time)
forecast_b = _load_forecast(forecast_b_init_time)

In [None]:
# Index both forecasts by total_time instead of prediction_timedelta
# (presumably init time + lead time), so the two runs can be compared at the
# same points in time. This adds a new total_time dimension to each dataset.
forecast_a_ds = forecast_a.to_xarray().jua.to_total_time()
forecast_b_ds = forecast_b.to_xarray().jua.to_total_time()

# Drop the `time` variable from both datasets.
# NOTE(review): presumably it holds each run's init time, which differs
# between A and B and would otherwise interfere with combining them — confirm.
forecast_a_ds = forecast_a_ds.drop_vars("time")
forecast_b_ds = forecast_b_ds.drop_vars("time")

# Difference between the two runs. xarray aligns operands on shared
# coordinates (inner join by default), so only total_time values present in
# both forecasts are compared.
delta = forecast_a_ds - forecast_b_ds
delta_abs = abs(delta)
# Mean absolute difference over the overlapping times; the remaining
# dimensions (presumably latitude/longitude) are kept for plotting.
delta_mean = delta_abs.mean(dim="total_time")
print(delta_abs)

# One panel per variable, stacked vertically. squeeze=False keeps `axs` a 2-D
# array even when only one variable is configured; otherwise plt.subplots
# returns a bare Axes object and indexing it would fail.
num_vars = len(variables)
_, axs = plt.subplots(num_vars, 1, figsize=(10, num_vars * 5), squeeze=False)
for ax, var in zip(axs[:, 0], variables):
    delta_mean[var].plot(ax=ax)
    ax.set_title(var)
plt.show()