# Code to generate Figure 2 (part 2), covering freeze events, for the paper

In [1]:
# setup all the imports
# NOTE(review): several modules are imported more than once in this cell
# (matplotlib.pyplot three times, seaborn twice, matplotlib.font_manager twice,
# and urllib both via `import urllib.request` and `import urllib`). The repeats
# are harmless but could be consolidated. Several names (e.g. yaml,
# NamedTemporaryFile, cartopy, make_axes_locatable, patches) are not referenced
# by the cells shown in this notebook.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import yaml
import urllib.request
import matplotlib.font_manager
# List of font names matplotlib can find; `flist` is not referenced elsewhere
# in this notebook.
flist = matplotlib.font_manager.get_font_names()
from tempfile import NamedTemporaryFile
import urllib
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LongitudeFormatter, LatitudeFormatter
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import matplotlib.patches as patches
from extremeweatherbench import evaluate, utils, cases, defaults, inputs, metrics
# Apply seaborn's 'whitegrid' theme to every figure drawn below.
sns.set_theme(style='whitegrid')
from shapely.geometry import Polygon
import shapely
from pathlib import Path
# Used to size the worker pool for the parallel benchmark runs.
import multiprocessing

# Root of the local ExtremeWeatherBench checkout - change this to your local path.
# Kept as a plain string ending in '/' because later cells build output paths
# with `basepath + 'docs/notebooks/...'`.
basepath = f"{Path.home() / 'ExtremeWeatherBench'}/"

In [3]:
# setup the templates to load in the data

# Variables every forecast is asked for, and the names the CIRA parquet
# references use for them.
# NOTE(review): the saved output of the FourCastNet run below is a KeyError saying
# one of these variables was not found in the forecast data; confirm the
# "t2"/"10u"/"10v" keys match the variable names inside the parquet files.
_FREEZE_VARIABLES = [
    "surface_air_temperature",
    "surface_eastward_wind",
    "surface_northward_wind",
]
_CIRA_VARIABLE_MAPPING = {
    "t2": "surface_air_temperature",
    "10u": "surface_eastward_wind",
    "10v": "surface_northward_wind",
}


def _cira_freeze_forecast(source):
    """Build a KerchunkForecast for one CIRA parquet reference file.

    Each call gets its own copies of the variable list and mapping, so the
    forecasts never share mutable configuration.

    Args:
        source: GCS URL of the kerchunk parquet reference.

    Returns:
        An ``inputs.KerchunkForecast`` reading the underlying data anonymously
        over S3 and preprocessed with the CIRA preprocessing hook.
    """
    return inputs.KerchunkForecast(
        source=source,
        variables=list(_FREEZE_VARIABLES),
        variable_mapping=dict(_CIRA_VARIABLE_MAPPING),
        storage_options={"remote_protocol": "s3", "remote_options": {"anon": True}},
        preprocess=defaults._preprocess_bb_cira_forecast_dataset,
    )


# Forecast Examples
cira_freeze_forecast_FOURv2 = _cira_freeze_forecast("gs://extremeweatherbench/FOUR_v200_GFS.parq")
cira_freeze_forecast_GC = _cira_freeze_forecast("gs://extremeweatherbench/GRAP_v100_IFS.parq")
cira_freeze_forecast_PANG = _cira_freeze_forecast("gs://extremeweatherbench/PANG_v100_IFS.parq")

# HRES comes from a Zarr store and uses the library's own variable mapping.
hres_forecast = inputs.ZarrForecast(
    source="gs://weatherbench2/datasets/hres/2016-2022-0012-1440x721.zarr",
    variables=list(_FREEZE_VARIABLES),
    variable_mapping=inputs.HRES_metadata_variable_mapping,
    storage_options={"remote_options": {"anon": True}},
)


# One freeze-event EvaluationObject list per (forecast model, verification target)
# pairing. Every pairing scores the same four metrics.
# NOTE(review): the targets are the `*_heatwave_target` defaults; confirm these are
# also the intended GHCN / ERA5 targets for freeze events.
_FREEZE_METRICS = (
    metrics.MinimumMAE,
    metrics.RMSE,
    metrics.OnsetME,
    metrics.DurationME,
)


def _freeze_evaluation_objects(target, forecast):
    """Return a one-element list holding a freeze EvaluationObject.

    Args:
        target: the verification target (GHCN or ERA5 default).
        forecast: the forecast input to evaluate.

    Returns:
        ``[inputs.EvaluationObject(...)]`` with its own fresh metric list.
    """
    return [
        inputs.EvaluationObject(
            event_type="freeze",
            metric_list=list(_FREEZE_METRICS),
            target=target,
            forecast=forecast,
        ),
    ]


FOURv2_GHCN_EVALUATION_OBJECTS = _freeze_evaluation_objects(defaults.ghcn_heatwave_target, cira_freeze_forecast_FOURv2)
FOURv2_ERA5_EVALUATION_OBJECTS = _freeze_evaluation_objects(defaults.era5_heatwave_target, cira_freeze_forecast_FOURv2)

GC_GHCN_EVALUATION_OBJECTS = _freeze_evaluation_objects(defaults.ghcn_heatwave_target, cira_freeze_forecast_GC)
GC_ERA5_EVALUATION_OBJECTS = _freeze_evaluation_objects(defaults.era5_heatwave_target, cira_freeze_forecast_GC)

PANG_GHCN_EVALUATION_OBJECTS = _freeze_evaluation_objects(defaults.ghcn_heatwave_target, cira_freeze_forecast_PANG)
PANG_ERA5_EVALUATION_OBJECTS = _freeze_evaluation_objects(defaults.era5_heatwave_target, cira_freeze_forecast_PANG)

HRES_GHCN_EVALUATION_OBJECTS = _freeze_evaluation_objects(defaults.ghcn_heatwave_target, hres_forecast)
HRES_ERA5_EVALUATION_OBJECTS = _freeze_evaluation_objects(defaults.era5_heatwave_target, hres_forecast)


In [4]:
# load in all of the events in the yaml file
# Every benchmark below receives all cases from the YAML together with one
# freeze EvaluationObject list.
# NOTE(review): this cell prints that utils.load_events_yaml is deprecated in favour
# of cases.load_ewb_events_yaml_into_case_collection / cases.read_incoming_yaml;
# migrate once the input format ExtremeWeatherBench expects is confirmed.
case_dict = utils.load_events_yaml()
freeze_test = dict(cases=case_dict["cases"])


def _freeze_benchmark(evaluation_objects):
    """Wrap all loaded cases and one EvaluationObject list in an ExtremeWeatherBench."""
    return evaluate.ExtremeWeatherBench(freeze_test, evaluation_objects)


ewb_fourv2_ghcn = _freeze_benchmark(FOURv2_GHCN_EVALUATION_OBJECTS)
ewb_fourv2_era5 = _freeze_benchmark(FOURv2_ERA5_EVALUATION_OBJECTS)

ewb_gc_ghcn = _freeze_benchmark(GC_GHCN_EVALUATION_OBJECTS)
ewb_gc_era5 = _freeze_benchmark(GC_ERA5_EVALUATION_OBJECTS)

ewb_pang_ghcn = _freeze_benchmark(PANG_GHCN_EVALUATION_OBJECTS)
ewb_pang_era5 = _freeze_benchmark(PANG_ERA5_EVALUATION_OBJECTS)

ewb_hres_ghcn = _freeze_benchmark(HRES_GHCN_EVALUATION_OBJECTS)
ewb_hres_era5 = _freeze_benchmark(HRES_ERA5_EVALUATION_OBJECTS)

This function is deprecated and will be removed in a future release. Please use cases.load_ewb_events_yaml_into_case_collection instead.
This function is deprecated and will be removed in a future release. Please use cases.read_incoming_yaml instead.


In [5]:
# Run the FourCastNet v2 freeze-event evaluations in parallel.
# Each of the other models is run in its own cell below, because running every
# model in a single cell takes a long time.

# Budget roughly n_threads_per_process CPUs for each worker process, keeping at
# least one process.
n_threads_per_process = 4
n_processes = max(1, multiprocessing.cpu_count() // n_threads_per_process)

# NOTE(review): the saved output of this cell is a KeyError saying one of
# surface_air_temperature / surface_eastward_wind / surface_northward_wind was
# not found in the forecast data. Check the variable_mapping keys used for the
# FOUR_v200_GFS parquet.
fourv2_ghcn_results = ewb_fourv2_ghcn.run(parallel=True, n_jobs=n_processes, pre_compute=True)
fourv2_era5_results = ewb_fourv2_era5.run(parallel=True, n_jobs=n_processes, pre_compute=True)

  0%|          | 0/14 [00:00<?, ?it/s]

KeyError: "One of the variables ['surface_air_temperature', 'surface_eastward_wind', 'surface_northward_wind'] not found in forecast data"

In [None]:
# GraphCast against GHCN, then against ERA5, with the shared worker count.
run_options = dict(parallel=True, n_jobs=n_processes, pre_compute=True)
gc_ghcn_results = ewb_gc_ghcn.run(**run_options)
gc_era5_results = ewb_gc_era5.run(**run_options)

In [None]:
# Pangu Weather against GHCN, then against ERA5, with the shared worker count.
run_options = dict(parallel=True, n_jobs=n_processes, pre_compute=True)
pang_ghcn_results = ewb_pang_ghcn.run(**run_options)
pang_era5_results = ewb_pang_era5.run(**run_options)

In [None]:
# HRES against GHCN, then against ERA5, with the shared worker count.
run_options = dict(parallel=True, n_jobs=n_processes, pre_compute=True)
hres_ghcn_results = ewb_hres_ghcn.run(**run_options)
hres_era5_results = ewb_hres_era5.run(**run_options)

In [None]:
# save the results so I don't have to keep re-running
# Each table is written to <basepath>docs/notebooks/figure2_part2_<model>_<target>_results.csv.
# GraphCast is included: its results are computed above and are used again by
# the regional subsets below, so it must be saved like the other models.
results_to_save = {
    'fourv2_ghcn': fourv2_ghcn_results,
    'fourv2_era5': fourv2_era5_results,
    'gc_ghcn': gc_ghcn_results,
    'gc_era5': gc_era5_results,
    'pang_ghcn': pang_ghcn_results,
    'pang_era5': pang_era5_results,
    'hres_ghcn': hres_ghcn_results,
    'hres_era5': hres_era5_results,
}
for results_name, results_table in results_to_save.items():
    results_table.to_csv(basepath + f'docs/notebooks/figure2_part2_{results_name}_results.csv')

In [None]:
# Inspect the raw per-case FourCastNet v2 results verified against ERA5.
fourv2_era5_results

In [None]:
# Inspect the raw per-case GraphCast results verified against GHCN.
gc_ghcn_results

In [None]:
# Average every metric over all cases at each lead time, giving one row per
# (metric, lead_time) pair for each model/target combination. Nothing is
# filtered here; the plotting cells below select individual metrics.
def _mean_by_metric_and_lead_time(results):
    """Return the case-averaged 'value' for each (metric, lead_time) pair.

    Args:
        results: a benchmark results DataFrame with at least the columns
            'metric', 'lead_time' and 'value'.

    Returns:
        A flat DataFrame with columns 'metric', 'lead_time', 'value', ordered
        by metric and then lead_time (groupby's default key sort).
    """
    return (
        results[['metric', 'lead_time', 'value']]
        .groupby(['metric', 'lead_time'])
        .mean()
        .reset_index()
    )


fourv2_ghcn_group = _mean_by_metric_and_lead_time(fourv2_ghcn_results)
fourv2_era5_group = _mean_by_metric_and_lead_time(fourv2_era5_results)

# GraphCast summaries are not computed here, matching the plotting cells below,
# which also omit GraphCast.

pang_ghcn_group = _mean_by_metric_and_lead_time(pang_ghcn_results)
pang_era5_group = _mean_by_metric_and_lead_time(pang_era5_results)

hres_ghcn_group = _mean_by_metric_and_lead_time(hres_ghcn_results)
hres_era5_group = _mean_by_metric_and_lead_time(hres_era5_results)



In [None]:
# Case-averaged MinimumMAE values for FourCastNet v2 vs GHCN, ordered by lead time.
# MinimumMAE is the MAE metric configured for freeze events. 'MaximumMAE' is not
# in the freeze metric list, so filtering on it returns an empty array.
fourv2_ghcn_group.loc[fourv2_ghcn_group['metric'] == 'MinimumMAE', 'value'].values

In [None]:
# Case-averaged Minimum MAE per lead time for each model: solid lines are
# verified against GHCN, '.-' lines against ERA5. MinimumMAE is the MAE metric
# configured for freeze events ('MaximumMAE' is never produced here).
# NOTE(review): x is the position in the lead_time-sorted summary, not the
# lead_time value itself. GraphCast is omitted, as in the summary cell.
metric_name = 'MinimumMAE'
fig, ax = plt.subplots()
for group, style, label in [
    (fourv2_ghcn_group, 'r', 'FourCastNet V2 GHCN'),
    (fourv2_era5_group, 'r.-', 'FourCastNet V2 ERA5'),
    (pang_ghcn_group, 'g', 'Pangu Weather GHCN'),
    (pang_era5_group, 'g.-', 'Pangu Weather ERA5'),
    (hres_ghcn_group, 'm', 'HRES GHCN'),
    (hres_era5_group, 'm.-', 'HRES ERA5'),
]:
    ax.plot(group.loc[group['metric'] == metric_name, 'value'].values, style, label=label)

ax.set_title('Minimum MAE for All Freeze Events')
ax.set_xlabel('Lead time index')
ax.set_ylabel(f'Mean {metric_name}')
ax.legend();

In [None]:
# Case-averaged RMSE per lead time for each model: solid lines are verified
# against GHCN, '.-' lines against ERA5.
# NOTE(review): x is the position in the lead_time-sorted summary, not the
# lead_time value itself. GraphCast is omitted, as in the summary cell.
metric_name = 'RMSE'
fig, ax = plt.subplots()
for group, style, label in [
    (fourv2_ghcn_group, 'r', 'FourCastNet V2 GHCN'),
    (fourv2_era5_group, 'r.-', 'FourCastNet V2 ERA5'),
    (pang_ghcn_group, 'g', 'Pangu Weather GHCN'),
    (pang_era5_group, 'g.-', 'Pangu Weather ERA5'),
    (hres_ghcn_group, 'm', 'HRES GHCN'),
    (hres_era5_group, 'm.-', 'HRES ERA5'),
]:
    ax.plot(group.loc[group['metric'] == metric_name, 'value'].values, style, label=label)

ax.set_title('RMSE for All Freeze Events')
ax.set_xlabel('Lead time index')
ax.set_ylabel(f'Mean {metric_name}')
ax.legend();

In [None]:
# Onset and duration mean errors: the two freeze metrics configured in the
# EvaluationObjects that are not plotted elsewhere. ('MaxMinMAE' is not in the
# freeze metric list, and the MAE of the predicted minimum is already shown in
# the MinimumMAE figure.) One figure per metric: solid lines are verified
# against GHCN, '.-' lines against ERA5.
# NOTE(review): x is the position in the lead_time-sorted summary, not the
# lead_time value itself. GraphCast is omitted, as in the summary cell.
model_lines = [
    (fourv2_ghcn_group, 'r', 'FourCastNet V2 GHCN'),
    (fourv2_era5_group, 'r.-', 'FourCastNet V2 ERA5'),
    (pang_ghcn_group, 'g', 'Pangu Weather GHCN'),
    (pang_era5_group, 'g.-', 'Pangu Weather ERA5'),
    (hres_ghcn_group, 'm', 'HRES GHCN'),
    (hres_era5_group, 'm.-', 'HRES ERA5'),
]
for metric_name, metric_title in [
    ('OnsetME', 'Mean error of the predicted onset'),
    ('DurationME', 'Mean error of the predicted duration'),
]:
    fig, ax = plt.subplots()
    for group, style, label in model_lines:
        ax.plot(group.loc[group['metric'] == metric_name, 'value'].values, style, label=label)
    ax.set_title(f'{metric_title} for All Freeze Events')
    ax.set_xlabel('Lead time index')
    ax.set_ylabel(f'Mean {metric_name}')
    ax.legend()
plt.show()

# Subset the results by region (North America, Europe, Australia)

In [None]:
# helper function to convert a bounding box tuple to a shapely Polygon
def get_polygon_from_bounding_box(bounding_box):
    """Build a closed rectangular shapely Polygon from a bounding box.

    Args:
        bounding_box: a 4-sequence ``(left_lon, right_lon, bot_lat, top_lat)``,
            or ``None``.

    Returns:
        A Polygon whose ring visits bottom-left, bottom-right, top-right and
        top-left, then closes back on bottom-left; or ``None`` when
        ``bounding_box`` is ``None``.

    Raises:
        ValueError: if ``bounding_box`` does not unpack into exactly four values.
    """
    if bounding_box is not None:
        west, east, south, north = bounding_box
        ring = [(west, south), (east, south), (east, north), (west, north)]
        ring.append(ring[0])  # close the ring explicitly, as the original did
        return Polygon(ring)
    return None

# Region boxes follow the (left_lon, right_lon, bot_lat, top_lat) order that
# get_polygon_from_bounding_box unpacks, with longitudes west of Greenwich
# written as negative degrees.

# North America
na_bounding_box = [-172, -45, 7, 85]
na_bounding_box_polygon = get_polygon_from_bounding_box(na_bounding_box)

# Europe bounding box: 15W to 50E, 15N to 75N. The longitudes are listed
# left-then-right so the box matches the documented order. The enclosed
# rectangle is the same one the reversed [50, -15, ...] form produced.
eu_bounding_box = [-15, 50, 15, 75]
eu_bounding_box_polygon = get_polygon_from_bounding_box(eu_bounding_box)

# australia bounding box
au_bounding_box = [110, 180, -50, -10]
au_bounding_box_polygon = get_polygon_from_bounding_box(au_bounding_box)


In [None]:
# Assign each case to the first region, checked in North America -> Europe ->
# Australia order, whose bounding box intersects the case's location geometry.
# Cases that intersect none of the boxes are left out.
region_polygons = {
    'North America': na_bounding_box_polygon,
    'Europe': eu_bounding_box_polygon,
    'Australia': au_bounding_box_polygon,
}
cases_by_region = {region: [] for region in region_polygons}

for case_operator in ewb_fourv2_era5.case_operators:
    metadata = case_operator.case_metadata
    case_geometry = metadata.location.geopandas.geometry[0]
    matched_region = next(
        (
            region
            for region, polygon in region_polygons.items()
            if shapely.intersects(case_geometry, polygon)
        ),
        None,
    )
    if matched_region is not None:
        cases_by_region[matched_region].append(metadata.case_id_number)

na_cases = cases_by_region['North America']
eu_cases = cases_by_region['Europe']
au_cases = cases_by_region['Australia']

print(f'North America Cases: {na_cases}')
print(f'Europe Cases: {eu_cases}')
print(f'Australia Cases: {au_cases}')


In [None]:
# make all the subsets: for each results table, keep only the rows whose
# case_id_number was assigned to a region in the previous cell.
# NOTE(review): the GraphCast regional subsets are not used by any later cell.
def _rows_in_cases(results, case_ids):
    """Return the rows of `results` whose 'case_id_number' appears in `case_ids`."""
    return results[results['case_id_number'].isin(case_ids)]


na_fourv2_era5_results = _rows_in_cases(fourv2_era5_results, na_cases)
eu_fourv2_era5_results = _rows_in_cases(fourv2_era5_results, eu_cases)
au_fourv2_era5_results = _rows_in_cases(fourv2_era5_results, au_cases)

na_fourv2_ghcn_results = _rows_in_cases(fourv2_ghcn_results, na_cases)
eu_fourv2_ghcn_results = _rows_in_cases(fourv2_ghcn_results, eu_cases)
au_fourv2_ghcn_results = _rows_in_cases(fourv2_ghcn_results, au_cases)

na_gc_ghcn_results = _rows_in_cases(gc_ghcn_results, na_cases)
eu_gc_ghcn_results = _rows_in_cases(gc_ghcn_results, eu_cases)
au_gc_ghcn_results = _rows_in_cases(gc_ghcn_results, au_cases)

na_gc_era5_results = _rows_in_cases(gc_era5_results, na_cases)
eu_gc_era5_results = _rows_in_cases(gc_era5_results, eu_cases)
au_gc_era5_results = _rows_in_cases(gc_era5_results, au_cases)


In [None]:
# do the groupby for each subset: within each region, average every metric over
# that region's cases at each lead time. Only FourCastNet v2 vs GHCN is summarised,
# because it is the only pairing the regional plot below draws.
def _regional_mean_by_metric_and_lead_time(regional_results):
    """Return the case-averaged 'value' per (metric, lead_time) as a flat DataFrame."""
    averaged = regional_results[['metric', 'lead_time', 'value']].groupby(['metric', 'lead_time']).mean()
    return averaged.reset_index()


na_fourv2_ghcn_group = _regional_mean_by_metric_and_lead_time(na_fourv2_ghcn_results)
eu_fourv2_ghcn_group = _regional_mean_by_metric_and_lead_time(eu_fourv2_ghcn_results)
au_fourv2_ghcn_group = _regional_mean_by_metric_and_lead_time(au_fourv2_ghcn_results)


In [None]:
# Case-averaged Minimum MAE per lead time for FourCastNet v2 vs GHCN: the average
# over all cases (dashed) against the North America, Europe and Australia subsets.
# MinimumMAE is the MAE metric configured for freeze events ('MaximumMAE' is
# never produced here, so filtering on it drew empty lines).
# NOTE(review): x is the position in the lead_time-sorted summary, not the
# lead_time value itself. A region with no cases draws no line.
metric_name = 'MinimumMAE'
fig, ax = plt.subplots()
for group, style, label in [
    (fourv2_ghcn_group, 'k--', 'FourCastNet V2 GHCN Global'),
    (na_fourv2_ghcn_group, 'r.-', 'FourCastNet V2 GHCN North America'),
    (eu_fourv2_ghcn_group, 'b.-', 'FourCastNet V2 GHCN Europe'),
    (au_fourv2_ghcn_group, 'm.-', 'FourCastNet V2 GHCN Australia'),
]:
    ax.plot(group.loc[group['metric'] == metric_name, 'value'].values, style, label=label)

ax.set_title('Minimum MAE by Region for Freeze Events')
ax.set_xlabel('Lead time index')
ax.set_ylabel(f'Mean {metric_name}')
ax.legend();