# EWB paper figure 1: showing all the cases

We provide the exact code used to generate each figure so that the results are fully reproducible and so that others can quickly use EWB with their own models.
This is Figure 1, the overview figure that shows all of the cases, first globally and then zoomed in on individual regions.

In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
# setup all the imports
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.font_manager
flist = matplotlib.font_manager.get_font_names()
from tempfile import NamedTemporaryFile
from mpl_toolkits.axes_grid1 import make_axes_locatable
from extremeweatherbench import evaluate, utils, cases, defaults
from pathlib import Path

# make the basepath - change this to your local path
# basepath is kept as a plain string ending in '/' because the cells below
# build output file names with `basepath + 'docs/notebooks/figs/...'`.
basepath = str(Path.home() / 'ExtremeWeatherBench') + '/'

# ugly hack to load in our plotting scripts
import sys
# basepath already ends with '/', so no leading separator is needed here
notebooks_dir = basepath + "docs/notebooks/"
# guard so that re-running this cell does not keep appending the same entry
if notebooks_dir not in sys.path:
    sys.path.append(notebooks_dir)
import case_plotting as cp

In [4]:
# Read every event definition from the yaml file
case_dict = utils.load_events_yaml()

# Turn that dictionary into a list of case objects
ewb_cases = cases.load_individual_cases(case_dict)

# Build all of the expected data needed to evaluate the cases.
# This is not a 1-1 mapping with ewb_cases, because some cases are evaluated
# against multiple data sources: for example, a heat/cold case has a case
# operator for ERA-5 data and another for GHCN.
evaluation_objects = defaults.get_brightband_evaluation_objects()
case_operators = cases.build_case_operators(case_dict, evaluation_objects)

In [3]:
# useful for debugging: prints the object returned by
# cases.load_individual_cases so the loaded cases can be inspected
print(ewb_cases)


In [5]:
# One world map containing every case (all event types, fill_boxes=False)
world_map_file = basepath + 'docs/notebooks/figs/ewb_all.png'
cp.plot_all_cases(ewb_cases, event_type=None, filename=world_map_file, fill_boxes=False)

In [16]:
# Plot the individual cases for each event type, one figure per type.
# Maps each event type to the stem of the PNG file it is written to.
figure_stem_by_event = {
    'tropical_cyclone': 'ewb_tcs',
    'freeze': 'ewb_freeze',
    'heat_wave': 'ewb_heat',
    'atmospheric_river': 'ewb_ar',
    'severe_convection': 'ewb_convective',
}

for event_name, figure_stem in figure_stem_by_event.items():
    cp.plot_all_cases(
        ewb_cases,
        event_type=event_name,
        filename=basepath + 'docs/notebooks/figs/' + figure_stem + '.png',
        fill_boxes=True,
    )


In [12]:
# Zoomed-in map: North America
left_lon, right_lon = -172, -45
bot_lat, top_lat = 7, 85

# Box is ordered [left lon, right lon, bottom lat, top lat]
bounding_box = [left_lon, right_lon, bot_lat, top_lat]

# NOTE(review): plot_title is assigned but not passed to plot_all_cases in this cell
plot_title = 'ExtremeWeatherBench Cases in North America'

north_america_file = basepath + 'docs/notebooks/figs/extreme_weather_cases_NA.png'
cp.plot_all_cases(
    ewb_cases,
    event_type=None,
    bounding_box=bounding_box,
    filename=north_america_file,
    fill_boxes=True,
)

In [13]:
# plot Europe
bot_lat = 20
top_lat = 75
left_lon = -15
right_lon = 20

# Box is ordered [left lon, right lon, bottom lat, top lat], matching the
# other regional cells. (A leftover debug print that listed the values in a
# different order was removed.)
bounding_box = [left_lon, right_lon, bot_lat, top_lat]

# NOTE(review): plot_title is assigned but not passed to plot_all_cases in this cell
plot_title = 'ExtremeWeatherBench Cases in Europe'
cp.plot_all_cases(ewb_cases, event_type=None, bounding_box=bounding_box, filename=basepath + 'docs/notebooks/figs/extreme_weather_cases_Europe.png', fill_boxes=True)


In [14]:
# Zoomed-in map: Australia
left_lon, right_lon = 105, 160
bot_lat, top_lat = -45, -5

# Box is ordered [left lon, right lon, bottom lat, top lat]
bounding_box = [left_lon, right_lon, bot_lat, top_lat]

# NOTE(review): plot_title is assigned but not passed to plot_all_cases in this cell
plot_title = 'ExtremeWeatherBench Cases in Australia'

australia_file = basepath + 'docs/notebooks/figs/extreme_weather_cases_Aus.png'
cp.plot_all_cases(
    ewb_cases,
    event_type=None,
    bounding_box=bounding_box,
    filename=australia_file,
    fill_boxes=True,
)

In [16]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LongitudeFormatter, LatitudeFormatter
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches
from shapely.geometry import Polygon
import shapely
from matplotlib.patches import Patch
from extremeweatherbench import evaluate, utils, cases, defaults
import matplotlib.colors as mcolors
import xarray as xr
import matplotlib.dates as mdates
from mpl_toolkits.axes_grid1 import make_axes_locatable



In [None]:
# Example: hexbin density map of the case locations.
# Every case is represented by the centre of its bounding box; hexagons are
# coloured by the number of centres that fall inside them.

import numpy as np

# One (centre lon, centre lat, event type) record per case
case_records = [
    (
        (case.location.longitude_min + case.location.longitude_max) / 2,
        (case.location.latitude_min + case.location.latitude_max) / 2,
        case.event_type,
    )
    for case in ewb_cases.cases
]

# Centre longitudes above 180 are shifted by 360 into the -180..180 range
case_lons = np.array([lon - 360 if lon > 180 else lon for lon, _lat, _kind in case_records])
case_lats = np.array([lat for _lon, lat, _kind in case_records])
case_types = [kind for _lon, _lat, kind in case_records]

# Global map in a plate carree projection
fig = plt.figure(figsize=(12, 8))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.set_global()

# Base map: coastlines first, then borders, land and ocean (in that order)
ax.coastlines()
for map_feature, feature_style in (
    (cfeature.BORDERS, {'linestyle': ':'}),
    (cfeature.LAND, {'edgecolor': 'black'}),
    (cfeature.OCEAN, {'edgecolor': 'black', 'facecolor': 'white'}),
):
    ax.add_feature(map_feature, **feature_style)

# Labelled gridlines on the left and bottom edges only
gridline_style = dict(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
gl = ax.gridlines(**gridline_style)
gl.top_labels = gl.right_labels = False
gl.xformatter, gl.yformatter = LongitudeFormatter(), LatitudeFormatter()

# Hexbin of the case centres; the data are given in the same PlateCarree
# coordinates as the axes, and hexagons without any case are not drawn (mincnt=1)
hb = ax.hexbin(
    case_lons, case_lats,
    gridsize=30, cmap='viridis', mincnt=1,
    transform=ccrs.PlateCarree(),
    linewidths=0.5, edgecolors='black',
)

# Horizontal colorbar under the map
cbar = plt.colorbar(hb, ax=ax, orientation='horizontal', pad=0.05)
cbar.set_label('Number of Cases', size=12)

ax.set_title('Extreme Weather Cases - Hexbin Density Plot', fontsize=16, pad=20)

hexbin_file = basepath + 'docs/notebooks/figs/ewb_hexbin.png'
plt.savefig(hexbin_file, dpi=300, bbox_inches='tight')
plt.show()

print("Hexbin plot created! This shows the density of cases across the globe.")


In [None]:
# Overlapping hexbins: one semi-transparent, single-colour hexbin layer per
# event type, so that places where event types overlap remain visible.

# Colour used for each event type (also reused by the cells that follow)
event_colors = {
    'heat_wave': 'firebrick',
    'freeze': 'royalblue',
    'tropical_cyclone': 'darkorange',
    'severe_convection': 'mediumpurple',
    'atmospheric_river': 'mediumseagreen',
}

# Global map
fig = plt.figure(figsize=(16, 10))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.set_global()

# Base map: coastlines first, then borders, land and ocean (in that order)
ax.coastlines(linewidth=0.5)
for map_feature, feature_style in (
    (cfeature.BORDERS, {'linestyle': ':', 'linewidth': 0.5}),
    (cfeature.LAND, {'edgecolor': 'black', 'alpha': 0.2}),
    (cfeature.OCEAN, {'edgecolor': 'black', 'facecolor': 'white', 'alpha': 0.2}),
):
    ax.add_feature(map_feature, **feature_style)

# Labelled gridlines on the left and bottom edges only
gridline_style = dict(draw_labels=True, linewidth=0.5, color='gray', alpha=0.3, linestyle='--')
gl = ax.gridlines(**gridline_style)
gl.top_labels = gl.right_labels = False
gl.xformatter, gl.yformatter = LongitudeFormatter(), LatitudeFormatter()

# One hexbin layer per event type; types without any case are skipped
case_type_array = np.array(case_types)
for event_type, color in event_colors.items():
    mask = case_type_array == event_type
    if not np.any(mask):
        continue

    # Colormap whose two end points are the same colour, i.e. a flat colour
    cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', [color, color], N=256)

    ax.hexbin(
        case_lons[mask], case_lats[mask],
        gridsize=25, cmap=cmap,
        alpha=0.5, mincnt=1,
        transform=ccrs.PlateCarree(),
        linewidths=0.3, edgecolors='black',
    )

# Legend with one patch per event type, labelled e.g. 'Heat Wave'
legend_elements = []
for event_type, color in event_colors.items():
    pretty_name = event_type.replace('_', ' ').title()
    legend_elements.append(Patch(facecolor=color, alpha=0.5, label=pretty_name))
ax.legend(handles=legend_elements, loc='upper left', fontsize=10, framealpha=0.9)

ax.set_title('Extreme Weather Cases - Overlapping Hexbin Densities', fontsize=16, pad=20)

overlap_file = basepath + 'docs/notebooks/figs/ewb_overlapping_hexbins.png'
plt.savefig(overlap_file, dpi=300, bbox_inches='tight')
plt.show()

print("Overlapping hexbin plot created! Darker/mixed color areas show where case types overlap.")


In [23]:
# Alternative: one subplot per event type, each with its own hexbin.
# Separate panels make the individual patterns easier to see and compare.

fig = plt.figure(figsize=(20, 12))
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.1)

# Styling shared by every panel
panel_features = (
    (cfeature.BORDERS, {'linestyle': ':', 'linewidth': 0.5}),
    (cfeature.LAND, {'edgecolor': 'black', 'alpha': 0.3}),
    (cfeature.OCEAN, {'edgecolor': 'black', 'facecolor': 'lightblue', 'alpha': 0.2}),
)
panel_gridline_style = dict(draw_labels=True, linewidth=0.5, color='gray', alpha=0.3, linestyle='--')

case_type_array = np.array(case_types)

for idx, (event_type, color) in enumerate(event_colors.items()):
    # Panels fill the grid in flat index order
    ax = fig.add_subplot(gs[idx], projection=ccrs.PlateCarree())
    ax.set_global()

    # Base map
    ax.coastlines(linewidth=0.5)
    for map_feature, feature_style in panel_features:
        ax.add_feature(map_feature, **feature_style)

    # Labelled gridlines on the left and bottom edges only
    gl = ax.gridlines(**panel_gridline_style)
    gl.top_labels = gl.right_labels = False
    gl.xformatter, gl.yformatter = LongitudeFormatter(), LatitudeFormatter()

    # Cases belonging to this event type
    mask = case_type_array == event_type
    if np.any(mask):
        # Gradient from white to the event colour
        cmap = mcolors.LinearSegmentedColormap.from_list('custom', ['white', color, color], N=256)

        hb = ax.hexbin(
            case_lons[mask], case_lats[mask],
            gridsize=25, cmap=cmap, mincnt=1,
            transform=ccrs.PlateCarree(),
            linewidths=0.5, edgecolors='black',
        )

        # Each panel gets its own colorbar
        cbar = plt.colorbar(hb, ax=ax, orientation='horizontal', pad=0.05, shrink=0.7)
        cbar.set_label('Count', size=9)
        cbar.ax.tick_params(labelsize=8)

    # The title is set even for panels without data and reports the case count
    panel_title = f'{event_type.replace("_", " ").title()} (n={np.sum(mask)})'
    ax.set_title(panel_title, fontsize=12, pad=10)

# Title for the whole figure
fig.suptitle('Extreme Weather Cases - Hexbin Densities by Event Type', fontsize=16, y=0.98)

panel_file = basepath + 'docs/notebooks/figs/ewb_hexbin_panel.png'
plt.savefig(panel_file, dpi=300, bbox_inches='tight')
plt.show()

print("Panel of hexbin plots created! Each subplot shows one event type's distribution.")


In [29]:
# Hybrid view: grey hexbins for the total case density, with coloured points
# drawn on top to show the event type of each individual case.

fig = plt.figure(figsize=(16, 10))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.set_global()

# Base map: coastlines first, then borders, land and ocean (in that order)
ax.coastlines(linewidth=0.5)
for map_feature, feature_style in (
    (cfeature.BORDERS, {'linestyle': ':', 'linewidth': 0.5}),
    (cfeature.LAND, {'edgecolor': 'black', 'alpha': 0.2}),
    (cfeature.OCEAN, {'edgecolor': 'black', 'facecolor': 'white', 'alpha': 0.2}),
):
    ax.add_feature(map_feature, **feature_style)

# Labelled gridlines on the left and bottom edges only
gridline_style = dict(draw_labels=True, linewidth=0.5, color='gray', alpha=0.3, linestyle='--')
gl = ax.gridlines(**gridline_style)
gl.top_labels = gl.right_labels = False
gl.xformatter, gl.yformatter = LongitudeFormatter(), LatitudeFormatter()

# Background layer: density of all cases together, in greys
hb_total = ax.hexbin(
    case_lons, case_lats,
    gridsize=25, cmap='Greys', mincnt=1,
    transform=ccrs.PlateCarree(),
    linewidths=0.3, edgecolors='black', alpha=0.8,
)

# Foreground layer: one scatter call per event type, drawn above the hexbins
# (zorder=10) with a white outline; the labels feed the legend below
case_type_array = np.array(case_types)
for event_type, color in event_colors.items():
    mask = case_type_array == event_type
    if not np.any(mask):
        continue
    ax.scatter(
        case_lons[mask], case_lats[mask],
        c=color, s=40, alpha=0.8,
        transform=ccrs.PlateCarree(),
        edgecolors='white', linewidths=1,
        label=event_type.replace('_', ' ').title(),
        zorder=10,
    )

# Colorbar for the grey density layer
cbar_total = plt.colorbar(hb_total, ax=ax, orientation='horizontal', pad=0.05, aspect=30)
cbar_total.set_label('Total Case Density', size=11)

# Legend built from the scatter labels
ax.legend(loc='upper left', fontsize=10, framealpha=0.9, ncol=2)

ax.set_title(
    'Extreme Weather Cases - Density (Gray Hexbin) + Event Types (Colored Points)',
    fontsize=16, pad=20,
)

hybrid_file = basepath + 'docs/notebooks/figs/ewb_hexbin_with_types.png'
plt.savefig(hybrid_file, dpi=300, bbox_inches='tight')
plt.show()

print("This shows total case density (gray hexbins) with colored points indicating event types.")
print("Where points of different colors overlap, multiple event types occur nearby!")


In [31]:
# Create a regular hexagon grid over the whole world
# Count events in each bin by type
#
# Outputs used by later cells: world_extent, gridsize, hex_centers_lon,
# hex_centers_lat, hexbin_counts (bin index -> {event type: count}) and
# hexbin_total_counts (bin index -> count).

from collections import defaultdict

# Get center points of all cases (we'll use centers for counting which hexbin each case falls into)
case_centers_lons = []
case_centers_lats = []
case_centers_types = []
case_centers_ids = []

for case in ewb_cases.cases:
    # Get the center point of each case's bounding box
    center_lat = (case.location.latitude_min + case.location.latitude_max) / 2
    center_lon = (case.location.longitude_min + case.location.longitude_max) / 2
    # Convert longitude to -180 to 180 range if needed
    if center_lon > 180:
        center_lon -= 360
    case_centers_lons.append(center_lon)
    case_centers_lats.append(center_lat)
    case_centers_types.append(case.event_type)
    case_centers_ids.append(case.case_id_number)

case_centers_lons = np.array(case_centers_lons)
case_centers_lats = np.array(case_centers_lats)

# Create a fixed hexbin grid over the world
# Use the same gridsize and extent for consistency
world_extent = [-180, 180, -90, 90]  # Full world coverage
gridsize = 40  # Number of hexagons across the width

# Create temporary figure to get hexbin structure; it is closed again
# immediately because only the hexagon centres are needed, not the drawing.
# NOTE(review): this relies on hexbin returning every hexagon of the lattice
# (including empty ones) when mincnt is not given - confirm for the
# matplotlib version in use.
fig_temp, ax_temp = plt.subplots(figsize=(12, 6))
hb_world = ax_temp.hexbin(case_centers_lons, case_centers_lats,
                          gridsize=gridsize,
                          extent=world_extent)
plt.close(fig_temp)

# Get hexbin centers (column 0 = longitude, column 1 = latitude)
hex_center_offsets = hb_world.get_offsets()
hex_centers_lon = hex_center_offsets[:, 0]
hex_centers_lat = hex_center_offsets[:, 1]

# Now count cases by type in each hexbin
# We need to map each case to its hexbin
hexbin_counts = defaultdict(lambda: defaultdict(int))
hexbin_total_counts = defaultdict(int)

for lon, lat, event_type in zip(case_centers_lons, case_centers_lats, case_centers_types):
    # Find the nearest hexbin center for this point.
    # NOTE(review): this is a plain Euclidean distance in degrees; it may not
    # match matplotlib's own bin assignment for points close to a hexagon
    # edge - verify if exact agreement with hb_world matters.
    distances = np.sqrt((hex_centers_lon - lon)**2 + (hex_centers_lat - lat)**2)
    nearest_idx = int(np.argmin(distances))

    # Count by event type
    hexbin_counts[nearest_idx][event_type] += 1
    hexbin_total_counts[nearest_idx] += 1

print(f"Created {len(hex_centers_lon)} hexbins covering the whole world")
print(f"Hexbins with events: {len(hexbin_counts)}")
print()

# Print some statistics
for event_type in ['heat_wave', 'freeze', 'tropical_cyclone', 'severe_convection', 'atmospheric_river']:
    # .get() instead of indexing: indexing a defaultdict would insert a
    # zero-count entry for every (bin, event type) pair that was only looked up
    total_count = sum(type_counts.get(event_type, 0) for type_counts in hexbin_counts.values())
    print(f"{event_type}: {total_count} total cases")



In [32]:
# Visualize the hexbin grid with total counts
# Uses gridsize, world_extent, hex_centers_lon, hex_centers_lat and
# hexbin_total_counts from the previous cell.

fig = plt.figure(figsize=(16, 10))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.set_global()

# Add map features
ax.coastlines(linewidth=0.5)
ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=0.5)
ax.add_feature(cfeature.LAND, edgecolor='black', alpha=0.2)
ax.add_feature(cfeature.OCEAN, edgecolor='black', facecolor='white', alpha=0.2)

# Add gridlines
gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.3, linestyle='--')
gl.top_labels = False
gl.right_labels = False
gl.xformatter = LongitudeFormatter()
gl.yformatter = LatitudeFormatter()

# Prepare data for hexbin plot: keep only the hexagons that contain cases
plot_lons = []
plot_lats = []
count_values = []

for idx in range(len(hex_centers_lon)):
    # .get() so that looking up an empty bin does not add a zero entry to the
    # defaultdict built in the previous cell
    total_count = hexbin_total_counts.get(idx, 0)
    if total_count > 0:
        plot_lons.append(hex_centers_lon[idx])
        plot_lats.append(hex_centers_lat[idx])
        count_values.append(total_count)

# Create hexbin plot with counts
if len(plot_lons) > 0:
    # extent=world_extent puts these hexagons on the same lattice that the
    # counts were computed on; without it hexbin derives a new lattice from
    # the range of plot_lons/plot_lats, and the picture is no longer the
    # fixed world grid the title promises.
    # reduce_C_function=np.sum: the values are counts, so if several were to
    # share a hexagon they must add up (the default reducer is the mean).
    # mincnt is left unset because only non-empty hexagons are passed in.
    # NOTE(review): with C given, the meaning of mincnt=1 differs between
    # matplotlib releases - confirm before re-adding it.
    hb = ax.hexbin(plot_lons, plot_lats, C=count_values,
                   gridsize=gridsize, extent=world_extent,
                   reduce_C_function=np.sum, cmap='YlOrRd',
                   transform=ccrs.PlateCarree(),
                   linewidths=0.5, edgecolors='black')

    # Add colorbar
    cbar = plt.colorbar(hb, ax=ax, orientation='horizontal',
                       pad=0.05, aspect=30)
    cbar.set_label('Number of Cases', size=12)

ax.set_title('Regular Hexbin Grid - Total Case Counts\n(Fixed grid covering the whole world)',
             fontsize=16, pad=20)

plt.savefig(basepath + 'docs/notebooks/figs/ewb_regular_hexbin_total.png',
            dpi=300, bbox_inches='tight')
plt.show()

print(f"Hexbins with events: {len(plot_lons)} out of {len(hex_centers_lon)}")


In [19]:
# Alternative: one scatter layer per event type, each in its own colour
# from seaborn's tab10 palette.
sns_palette = sns.color_palette("tab10")
sns.set_style("whitegrid")

# Palette index used for each event type.
# Note: this replaces the event_colors mapping defined in an earlier cell.
palette_index_by_event = {
    'heat_wave': 3,
    'freeze': 0,
    'tropical_cyclone': 1,
    'severe_convection': 5,
    'atmospheric_river': 7,
}
event_colors = {
    event_name: sns_palette[palette_index]
    for event_name, palette_index in palette_index_by_event.items()
}

fig = plt.figure(figsize=(14, 10))
ax = plt.axes(projection=ccrs.PlateCarree())
ax.set_global()

# Base map: coastlines first, then borders, land and ocean (in that order)
ax.coastlines()
for map_feature, feature_style in (
    (cfeature.BORDERS, {'linestyle': ':'}),
    (cfeature.LAND, {'edgecolor': 'black', 'alpha': 0.3}),
    (cfeature.OCEAN, {'edgecolor': 'black', 'facecolor': 'lightblue', 'alpha': 0.3}),
):
    ax.add_feature(map_feature, **feature_style)

# Labelled gridlines on the left and bottom edges only
gridline_style = dict(draw_labels=True, linewidth=0.5, color='gray', alpha=0.3, linestyle='--')
gl = ax.gridlines(**gridline_style)
gl.top_labels = gl.right_labels = False
gl.xformatter, gl.yformatter = LongitudeFormatter(), LatitudeFormatter()

# Scatter the case centres of each event type separately so that every type
# gets its own colour and legend entry; types without cases are skipped
case_type_array = np.array(case_types)
for event_type, color in event_colors.items():
    mask = case_type_array == event_type
    if not np.any(mask):
        continue
    ax.scatter(
        case_lons[mask], case_lats[mask],
        s=50, c=[color], alpha=0.6,
        transform=ccrs.PlateCarree(),
        label=event_type.replace('_', ' ').title(),
    )

ax.set_title('Extreme Weather Cases by Type', fontsize=16)
ax.legend(loc='upper left', fontsize=10)

by_type_file = basepath + 'docs/notebooks/figs/ewb_by_type.png'
plt.savefig(by_type_file, dpi=300, bbox_inches='tight')
plt.show()


In [23]:
# Hexbin version of the all-cases map, drawn by the shared plotting module.
# Written to its own file: the previous target, 'ewb_all.png', is the output of
# the first plot_all_cases call in this notebook and was silently overwritten
# whenever the notebook was run top to bottom.
cp.plot_all_cases_hexbin(ewb_cases, event_type=None, filename=basepath + 'docs/notebooks/figs/ewb_all_hexbin.png', hexbin_size=[100,100])