In [None]:
!uv pip install dask distributed

: 

In [5]:
"""Helper functions to identify the date ranges of heat waves and freeze events."""

import xarray as xr
import numpy as np
import pandas as pd
from extremeweatherbench import utils, case
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from cartopy.mpl.gridliner import LongitudeFormatter, LatitudeFormatter
import seaborn as sns
from matplotlib import dates as mdates
import datetime
from extremeweatherbench import derived

sns.set_theme(style="whitegrid", context="talk")


def subset_event_and_mask_climatology(
    era5: xr.Dataset,
    climatology: xr.Dataset,
    actual_start_date: datetime.datetime,
    actual_end_date: datetime.datetime,
    single_case: case.IndividualCase,
):
    """Calculate the times where regional average of temperature exceeds the climatology."""
    era5_event = era5[["2m_temperature"]].sel(
        time=slice(actual_start_date, actual_end_date)
    )
    era5_event = era5_event.sel(time=utils.is_6_hourly(era5_event["time.hour"]))
    subset_climatology = utils.convert_day_yearofday_to_time(
        climatology, np.unique(era5_event.time.dt.year.values)[0]
    ).rename_vars({"2m_temperature": "2m_temperature_85th_percentile"})

    merged_dataset = xr.merge([subset_climatology, era5_event], join="inner")
    merged_dataset = utils.clip_dataset_to_bounding_box(
        merged_dataset,
        single_case.location_center,
        single_case.box_length_width_in_km,
    )
    merged_dataset = utils.remove_ocean_gridpoints(merged_dataset)
    time_averaged_merged_dataset = merged_dataset.mean(["latitude", "longitude"])

    mask = (
        time_averaged_merged_dataset["2m_temperature"]
        > time_averaged_merged_dataset["2m_temperature_85th_percentile"]
    )
    return mask.compute(), merged_dataset


def find_heatwave_events(
    era5: xr.Dataset,
    climatology: xr.Dataset,
    single_case: case.IndividualCase,
    plot: bool = True,
):
    """Find the start and end dates of heatwave events, stepping +- 6 hours until
    < climatology timesteps are located."""
    start_date = pd.to_datetime(single_case.start_date)
    end_date = pd.to_datetime(single_case.end_date)
    location_center = single_case.location
    era5_event = era5[["2m_temperature"]].sel(time=slice(start_date, end_date))
    era5_event = era5_event.sel(time=utils.is_6_hourly(era5_event["time.hour"]))
    subset_climatology = utils.convert_day_yearofday_to_time(
        climatology, np.unique(era5_event.time.dt.year.values)[0]
    ).rename_vars({"2m_temperature": "2m_temperature_85th_percentile"})

    mask, merged_dataset = subset_event_and_mask_climatology(
        era5, subset_climatology, start_date, end_date, single_case
    )
    before = True
    after = True
    while before or after:
        # Check if there are 48 hours before and after the event
        try:
            last_true_time = mask.where(mask, drop=True).time[-1].values
            event_end_duration = (mask.time[-1] - last_true_time).values.astype(
                "timedelta64[h]"
            )
            first_true_time = mask.where(mask, drop=True).time[0].values
            event_start_duration = (mask.time[0] - first_true_time).values.astype(
                "timedelta64[h]"
            )
            if np.datetime64(last_true_time, "D") == np.datetime64("2022-12-31"):
                after = False
            if abs(event_start_duration) >= np.timedelta64(6, "h"):
                before = False
            else:
                start_date -= pd.DateOffset(hours=6)
                mask, merged_dataset, time_based_merged_dataset = (
                    subset_event_and_mask_climatology(
                        era5, climatology, start_date, end_date, single_case
                    )
                )
            if abs(event_end_duration) >= np.timedelta64(6, "h"):
                after = False
            else:
                end_date += pd.DateOffset(hours=6)
                mask, merged_dataset, time_based_merged_dataset = (
                    subset_event_and_mask_climatology(
                        era5, climatology, start_date, end_date, single_case
                    )
                )
        except IndexError:
            print(f"No dates valid for {location_center}, {start_date}, {end_date}")
            before = False
            after = False
    start_date -= pd.DateOffset(hours=42)
    end_date += pd.DateOffset(hours=42)

    mask, merged_dataset, time_based_merged_dataset = subset_event_and_mask_climatology(
        era5, climatology, start_date, end_date, single_case
    )
    if plot:
        case_plot(merged_dataset, time_based_merged_dataset, single_case)
    return (
        mask,
        time_based_merged_dataset.time.min().values,
        time_based_merged_dataset.time.max().values,
    )


def case_plot(
    dataset: xr.Dataset,
    single_case: case.IndividualCase,
    variable: str
):
    """Plot the max timestep of the heatwave event, the average regional temperature time series,
    and the associated climatology."""
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(6, 10), gridspec_kw={"height_ratios": [1, 1]}
    )
    plt.subplots_adjust(hspace=0.3)
    ax1.remove()
    ax1 = plt.subplot(2, 1, 1, projection=ccrs.PlateCarree())
    subset_timestep = (
        dataset[variable].mean(["latitude", "longitude"])
        == dataset[variable].mean(["latitude", "longitude"]).max()
    )
    im = (
        (
            dataset[variable]
        )
        .sel(time=subset_timestep)
        .plot(
            ax=ax1,
            transform=ccrs.PlateCarree(),
            cmap="inferno",
            add_colorbar=False,
        )
    )
    # Add coastlines and gridlines
    ax1.coastlines()
    ax1.add_feature(cfeature.BORDERS, linestyle=":")
    ax1.add_feature(cfeature.LAND, edgecolor="black")
    ax1.add_feature(cfeature.LAKES, edgecolor="black")
    ax1.add_feature(cfeature.RIVERS, edgecolor="black")
    ax1.add_feature(cfeature.STATES, edgecolor="grey")
    # Add gridlines
    gl = ax1.gridlines(draw_labels=True)
    gl.top_labels = False
    gl.right_labels = False
    gl.xformatter = LongitudeFormatter()
    gl.yformatter = LatitudeFormatter()
    gl.xlabel_style = {"size": 12, "color": "k"}
    gl.ylabel_style = {"size": 12, "color": "k"}
    ax1.set_title(
        f"Event ID {case.id}: 2m Temperature, {dataset['time'].sel(time=subset_timestep).dt.strftime('%Y-%m-%d %Hz').values[0]}",
        fontsize=12,
    )
    # Add the location coordinate as a dot on the map
    ax1.plot(
        single_case.location_center.longitude,
        single_case.location_center.latitude,
        "ko",
        markersize=10,
        transform=ccrs.PlateCarree(),
    )
    # Create a colorbar with the same height as the plot
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes("right", size="5%", pad=0.1, axes_class=plt.Axes)
    cbar = fig.colorbar(im, cax=cax, label="Temp > 85th Percentile (C)")
    cbar.set_label("Temp > 85th Percentile (C)", size=14)
    lss = ["-.", "-"]
    lc = ["k", "tab:red"]
    lws = [0.75, 1.5]
    plt.show()

In [8]:
import distributed
client = distributed.Client(n_workers=10)
client


+---------+---------------------------------------------+-----------+----------+
| Package | Worker-762c4f9b-f210-455b-b576-7f3024bb5792 | Scheduler | Workers  |
+---------+---------------------------------------------+-----------+----------+
| dask    | 2025.2.0                                    | 2025.1.0  | 2025.2.0 |
+---------+---------------------------------------------+-----------+----------+

+---------+---------------------------------------------+-----------+----------+
| Package | Worker-38af6c2e-77fe-4953-b5c3-612205f3faf1 | Scheduler | Workers  |
+---------+---------------------------------------------+-----------+----------+
| dask    | 2025.2.0                                    | 2025.1.0  | 2025.2.0 |
+---------+---------------------------------------------+-----------+----------+

+---------+---------------------------------------------+-----------+----------+
| Package | Worker-e8476d32-01ab-4287-8477-e50f8714990e | Scheduler | Workers  |
+---------+--------------

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 10
Total threads: 10,Total memory: 62.79 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:36681,Workers: 10
Dashboard: http://127.0.0.1:8787/status,Total threads: 10
Started: Just now,Total memory: 62.79 GiB

0,1
Comm: tcp://127.0.0.1:40735,Total threads: 1
Dashboard: http://127.0.0.1:35671/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:34183,
Local directory: /tmp/dask-scratch-space/worker-vi6ovacs,Local directory: /tmp/dask-scratch-space/worker-vi6ovacs

0,1
Comm: tcp://127.0.0.1:45589,Total threads: 1
Dashboard: http://127.0.0.1:40967/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:42013,
Local directory: /tmp/dask-scratch-space/worker-slj_d2d3,Local directory: /tmp/dask-scratch-space/worker-slj_d2d3

0,1
Comm: tcp://127.0.0.1:37325,Total threads: 1
Dashboard: http://127.0.0.1:37751/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:44993,
Local directory: /tmp/dask-scratch-space/worker-23w67i_x,Local directory: /tmp/dask-scratch-space/worker-23w67i_x

0,1
Comm: tcp://127.0.0.1:38951,Total threads: 1
Dashboard: http://127.0.0.1:33125/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:42547,
Local directory: /tmp/dask-scratch-space/worker-v5oq_z7t,Local directory: /tmp/dask-scratch-space/worker-v5oq_z7t

0,1
Comm: tcp://127.0.0.1:45997,Total threads: 1
Dashboard: http://127.0.0.1:33997/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:33803,
Local directory: /tmp/dask-scratch-space/worker-nlb3a_h_,Local directory: /tmp/dask-scratch-space/worker-nlb3a_h_

0,1
Comm: tcp://127.0.0.1:33937,Total threads: 1
Dashboard: http://127.0.0.1:44711/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:36829,
Local directory: /tmp/dask-scratch-space/worker-j9hinh8p,Local directory: /tmp/dask-scratch-space/worker-j9hinh8p

0,1
Comm: tcp://127.0.0.1:38495,Total threads: 1
Dashboard: http://127.0.0.1:40689/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:40785,
Local directory: /tmp/dask-scratch-space/worker-hmrx2gfd,Local directory: /tmp/dask-scratch-space/worker-hmrx2gfd

0,1
Comm: tcp://127.0.0.1:41153,Total threads: 1
Dashboard: http://127.0.0.1:34721/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:43339,
Local directory: /tmp/dask-scratch-space/worker-whd98jt9,Local directory: /tmp/dask-scratch-space/worker-whd98jt9

0,1
Comm: tcp://127.0.0.1:44299,Total threads: 1
Dashboard: http://127.0.0.1:45493/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:36723,
Local directory: /tmp/dask-scratch-space/worker-290luj6p,Local directory: /tmp/dask-scratch-space/worker-290luj6p

0,1
Comm: tcp://127.0.0.1:41611,Total threads: 1
Dashboard: http://127.0.0.1:34227/status,Memory: 6.28 GiB
Nanny: tcp://127.0.0.1:35001,
Local directory: /tmp/dask-scratch-space/worker-d7mpahol,Local directory: /tmp/dask-scratch-space/worker-d7mpahol


In [24]:
yaml_event_case = utils.load_events_yaml()
severe_event_list = [n for n in yaml_event_case['cases'] if n['event_type']=='severe_day']

In [25]:
icase = severe_event_list[0]
icase['location'] = utils.Location(icase['location']['latitude'], icase['location']['longitude'])

In [26]:
# storage_options = {
#     "remote_options": {"anon": True},
#     "remote_protocol": "s3",
# }  # options passed to fsspec
# open_dataset_options: dict = {"chunks": {}}  # opens passed to xarray
# file = 'gcs://extremeweatherbench/PANG_v100_GFS_combined_all_small.parq'
# ds = xr.open_dataset(
#     file,
#     engine="kerchunk",
#     storage_options=storage_options,
#     open_dataset_options=open_dataset_options,
# )
# ds

In [27]:
era5_map = {v: k for k, v in utils.ERA5_MAPPING.items()}

In [28]:
era5 = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
era5_icase = era5.sel(
    time=slice(icase['start_date'], icase['end_date'])
)
era5_icase

era5_icase = era5_icase.rename(era5_map)[list(era5_map.values())]

In [29]:
era5_icase = utils.clip_dataset_to_bounding_box_degrees(
            era5_icase, icase['location'], icase['bounding_box_degrees']
        )

In [30]:
era5_icase = era5_icase.compute()

In [31]:
era5_icase = era5_icase.sel(level=slice(100, 1000))

In [32]:
from metpy.calc import (
    dewpoint_from_relative_humidity,
    mixed_layer_cape_cin,
    dewpoint_from_specific_humidity,
    relative_humidity_from_specific_humidity,
)
from metpy.units import units
# Use ProgressBar to visualize the computation
shear_0_6_km = np.sqrt(
    (era5_icase["eastward_wind"].sel(level=500) - era5_icase["surface_eastward_wind"]) ** 2
    + (era5_icase["northward_wind"].sel(level=500) - era5_icase["surface_northward_wind"]) ** 2
)
pressure_levels = era5_icase["level"] * units.hPa
temperature = era5_icase["air_temperature"] * units.degK

# temperature
if "dewpoint_temperature" in era5_icase.variables:
    dewpoint_temperature = (
        era5_icase["dewpoint_temperature"] * units.degK
    )
elif "relative_humidity" in era5_icase.variables:
    rh = era5_icase["relative_humidity"] * units.dimensionless
    dewpoint_temperature = dewpoint_from_relative_humidity(temperature, rh)
elif "specific_humidity" in era5_icase.variables:
    dewpoint_temperature = dewpoint_from_specific_humidity(
        pressure_levels,
        temperature,                                
        era5_icase["specific_humidity"] * units('g/kg')
        )
else:
    raise ValueError("No humidity variable found in dataset")

# Vectorize the CAPE calculation over all dimensions except pressure level
# Create empty DataArray with the same dimensions as temperature but without the level dimension
mlcape = xr.DataArray(
    np.zeros(temperature.shape[:1] + temperature.shape[2:]),
    coords={
        'valid_time': temperature.coords['valid_time'],
        'latitude': temperature.coords['latitude'],
        'longitude': temperature.coords['longitude']
    },
    dims=['valid_time', 'latitude', 'longitude']
)

# Iterate over each lat, lon, and time point to calculate CAPE
for t_idx, t in enumerate(temperature.coords['valid_time'].values):
    for lat_idx, lat in enumerate(temperature.coords['latitude'].values):
        for lon_idx, lon in enumerate(temperature.coords['longitude'].values):
            # Extract 1D profile for this location and time
            temp_profile = temperature.sel(valid_time=t, latitude=lat, longitude=lon).compute()
            dewpt_profile = dewpoint_temperature.sel(valid_time=t, latitude=lat, longitude=lon).compute()
            
            # Calculate CAPE for this profile
            cape_value, _ = mixed_layer_cape_cin(
                pressure_levels,
                temp_profile, 
                dewpt_profile,
                depth=100 * units.hPa
            )
            
            # Store the result
            mlcape.values[t_idx, lat_idx, lon_idx] = cape_value.magnitude

# Add units back to the result
mlcape = mlcape * units('J/kg')
sigsvr = mlcape * shear_0_6_km

  val = np.log(vapor_pressure / mpconsts.nounit.sat_pressure_0c)
  cape_value, _ = mixed_layer_cape_cin(
  var_interp = var[below] + (var[above] - var[below]) * ((x_array - xp[below])
  magnitude = magnitude_op(self._magnitude, other_magnitude)


In [34]:
mlcape

0,1
Magnitude,[[[0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] ... [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0]] [[0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] ... [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0]] [[0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] ... [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0]] ... [[0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] ... [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0]] [[0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] ... [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0]] [[0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] ... [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0] [0.0 0.0 0.0 ... 0.0 0.0 0.0]]]
Units,joule/kilogram


In [12]:
mixed_layer_cape_cin(
                pressure_levels.values,
                temp_profile.values, 
                dewpt_profile.values
            )

KilledWorker: Attempted to run task ('open_dataset-temperature-original-getitem-e8126bb67d0a90ed97e959e505eb6b5d', 0, 0, 0, 0) on 4 different workers, but all those workers died while running it. The last worker that attempt to run the task was tcp://127.0.0.1:53872. Inspecting worker logs is often a good next step to diagnose what went wrong. For more information see https://distributed.dask.org/en/stable/killed.html.

In [21]:
era5 = utils.clip_dataset_to_bounding_box(
    era5,
    utils.Location(icase['location']['latitude'], icase['location']['longitude']),
    icase['bounding_box_km'],
)
era5_icase = era5.sel(
    time=slice(icase['start_date'], icase['end_date'])
)
shear_0_6_km = np.sqrt(
    (era5_icase["eastward_wind"].sel(level=500) - era5_icase["surface_eastward_wind"]) ** 2
    + (era5_icase["northward_wind"].sel(level=500) - era5_icase["surface_northward_wind"]) ** 2
)

{'id': 37,
 'title': 'July 2024 South Dakota',
 'start_date': datetime.datetime(2024, 7, 13, 0, 0),
 'end_date': datetime.datetime(2024, 7, 14, 0, 0),
 'location': {'latitude': 44.3677, 'longitude': -100.3516},
 'bounding_box_km': 500,
 'event_type': 'severe'}

In [32]:
era5 = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
era5_icase = era5.sel(
    time=slice(icase['start_date'], icase['end_date'])
)
era5 = era5.rename_vars({"10m_u_component_of_wind":"surface_eastward_wind", 
"10m_v_component_of_wind":"surface_northward_wind", 
"u_component_of_wind": "eastward_wind", 
"v_component_of_wind": "northward_wind"})

# era5 = utils.clip_dataset_to_bounding_box(
#     era5_icase,
#     utils.Location(icase['location']['latitude'], icase['location']['longitude']),
#     icase['bounding_box_km'],
# )

KeyboardInterrupt: 

In [5]:
icase = severe_event_list[0]

In [12]:
era5 = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
era5 = era5[['10m_u_component_of_wind', '10m_v_component_of_wind', 'u_component_of_wind', 'v_component_of_wind']]
era5_icase = era5.sel(
    time=slice(icase['start_date'], icase['end_date']), level=500
)

In [14]:
era5_icase

In [15]:
era5_icase = utils.clip_dataset_to_bounding_box(
    era5_icase,
    utils.Location(icase['location']['latitude'], icase['location']['longitude']),
    icase['bounding_box_km'],
)
era5_icase = era5_icase.rename_vars({"10m_u_component_of_wind":"surface_eastward_wind", 
"10m_v_component_of_wind":"surface_northward_wind", 
"u_component_of_wind": "eastward_wind", 
"v_component_of_wind": "northward_wind"})


# for icase in severe_event_list:

#     era5_icase = era5.sel(
#         time=slice(icase['start_date'], icase['end_date'])
#     )
#     era5 = utils.clip_dataset_to_bounding_box(
#         era5,
#         utils.Location(icase['location']['latitude'], icase['location']['longitude']),
#         icase['bounding_box_km'],
#     )
#     era5_icase = era5_icase
#     shear = shear06km(era5_icase)
#     print(shear)
    # case_plot(era5_icase, time_based_merged_dataset, single_icase)

KeyboardInterrupt: 