In [8]:
# libraries

import seaborn as sns
import xarray as xr
import numpy as np
import pandas as pd
import intake
import fsspec
import dask
import metpy.calc as mpcalc
from metpy.units import units
import warnings
from tqdm import tqdm
from xmip.preprocessing import combined_preprocessing
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# data

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    url = "https://storage.googleapis.com/cmip6/pangeo-cmip6.json"
    col = intake.open_esm_datastore(url)

In [3]:
# set up the local filecache path for faster reading (~5 min w 50 cores and 50 GB memory)

cache_path = '/scratch/fld1/cmip_cache'

storage_options = {
    'filecache': {
        'cache_storage': cache_path,
        'target_protocol': 'gs',
    }
}

# load 6-hourly data with meridional and zonal wind, air temperature, surface pressure, and specific humidity

query = dict(source_id='MRI-ESM2-0',
             table_id='6hrLev',
             experiment_id=['historical', 'ssp585'],
             variable_id=['va', 'ua', 'ta']
)

cat = col.search(**query)

# load data into dictionary use print(dset_dict.keys()) for keys

warnings.filterwarnings("ignore")

z_kwargs = {'consolidated': True, 'decode_times':True}

with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    dset_dict = cat.to_dataset_dict(
        zarr_kwargs=z_kwargs,
        storage_options=storage_options,
        preprocessing=combined_preprocessing
)

print(dset_dict.keys())


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.table_id.grid_label'


dict_keys(['CMIP.MRI.MRI-ESM2-0.historical.6hrLev.gn', 'ScenarioMIP.MRI.MRI-ESM2-0.ssp585.6hrLev.gn'])


In [4]:
# select one year of 6-hourly ssp585 data for testing (with 50 cores and 50 GB memory)

warm = dset_dict['ScenarioMIP.MRI.MRI-ESM2-0.ssp585.6hrLev.gn'].sel(time=slice('2060', '2060'), lat=slice(-60, 10), lon=slice(260, 330))
hist = dset_dict['CMIP.MRI.MRI-ESM2-0.historical.6hrLev.gn'].sel(time=slice('2060', '2060'), lat=slice(-60, 10), lon=slice(260, 330))

Functions for computation (finding LLJs, computing vorticity, etc)

In [21]:
# THIS DOES NOT WORK YET
# function to find LLJs

def detect_low_level_jets(ds, wind_var="wind", lev_var="lev"):
    """
    Detect low-level jets: core speed > 10, core level between 700–1000 hPa.
    Stores core wind speed and core level into dataset.
    """
    wind = ds[wind_var]

    # Find wind max along vertical dimension
    jet_core = wind.max(dim=lev_var)
    jet_core_idx = wind.argmax(dim=lev_var)

    # Convert index to actual pressure level values (force compute to avoid chunk issue)
    jet_core_lev = (
        wind[lev_var]
        .isel({lev_var: jet_core_idx.compute()})
        .compute()
    )

    # Criteria
    is_llj = (jet_core > 10) & (jet_core_lev >= 700) & (jet_core_lev <= 1000)

    ds["llj_core_speed"] = jet_core.where(is_llj)
    ds["llj_core_level"] = jet_core_lev.where(is_llj)
    ds["is_llj"] = is_llj

    return ds


In [20]:
# calculate vorticity

def compute_vorticity(ds, u_name="ua", v_name="va", lat_name="lat", lon_name="lon"):
    """
    Compute relative and planetary vorticity using MetPy, applied across all time/lev.
    Ensures winds have proper units.
    """
    # Grab coords
    lat = ds[lat_name]
    lon = ds[lon_name]

    # Grid spacing (in meters)
    dx, dy = mpcalc.lat_lon_grid_deltas(lon.values, lat.values)

    # Extract winds and enforce units
    u = ds[u_name]
    v = ds[v_name]

    # If units are missing, assign them
    if not hasattr(u, "metpy_unit"):
        u = u * units("m/s")
    else:
        u = u.metpy.convert_units("m/s")

    if not hasattr(v, "metpy_unit"):
        v = v * units("m/s")
    else:
        v = v.metpy.convert_units("m/s")

    # Apply vorticity slice by slice
    rel_vort = xr.apply_ufunc(
        lambda uu, vv: mpcalc.vorticity(uu, vv, dx=dx, dy=dy),
        u, v,
        input_core_dims=[[lat_name, lon_name], [lat_name, lon_name]],
        output_core_dims=[[lat_name, lon_name]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
    )

    # Planetary vorticity (Coriolis parameter f)
    f = mpcalc.coriolis_parameter(lat.values * units.degrees).to("1/s")

    # Add to dataset
    ds["rel_vort"] = rel_vort
    ds["f"] = (lat.dims, f.magnitude)  # store as plain array, keep lat dims

    return ds


In [22]:
# run compute functions!!

subset = (warm
    # calculate wind speed on pressure levels
    .assign(wind=lambda x: np.sqrt(x['ua']**2 + x['va']**2))
    # calculate vorticity
    .pipe(compute_vorticity, u_name="ua", v_name="va", lat_name="lat", lon_name="lon")
    # find low level jets
    #.pipe(detect_low_level_jets, wind_var="wind", lev_var="lev")
    .squeeze()
    .compute())

subset

ValueError: This function changed in version(s) ['1.0']--double check that the function is being called properly.
`vorticity` given arguments with incorrect units: `u` requires "[speed]" but given "none", `v` requires "[speed]" but given "none"
A xarray DataArray or numpy array `x` can be assigned a unit as follows:
    from metpy.units import units
    x = x * units("m/s")
For more information see the Units Tutorial: https://unidata.github.io/MetPy/latest/tutorials/unit_tutorial.html

In [None]:
# quiver plot during LLJ time (time = 225)

dst = subset.isel(time=225).sel(lev=0.8, method='nearest')

In [None]:
import cartopy.feature as cfeature
import cartopy.crs as ccrs

# cartopy mapping things
proj = ccrs.PlateCarree()
fig = plt.subplots(figsize=(14, 6), ncols=2, gridspec_kw={"width_ratios": [2.5, 1]})
fig, (ax1, ax2) = fig
sns.set_style('white')
sns.despine(ax=ax2, top=True, bottom=True, right=True, left=True)
sns.despine(ax=ax1, top=True, bottom=True, right=True, left=True)

# plot number one
ax1 = plt.subplot(1, 2, 1, projection=proj)

# add background wind speed as color shading
speed_plot = ax1.pcolormesh(
    dst.lon, dst.lat, dst.wind,
    transform=proj, cmap='plasma', shading='auto'
)

# add colorbar
cbar = plt.colorbar(speed_plot, ax=ax1, orientation='vertical', pad=0.02)
cbar.set_label('Wind Speed (m/s)')

# plot wind vectors
q = ax1.quiver(
    dst.lon, dst.lat,
    dst.uas, dst.vas,
    transform=proj, scale=500, width=0.002, color='k'
)

# add quiver key
ax1.quiverkey(q, 0.9, -0.1, 10, "10 m/s", labelpos='E')
ax1.set_xlabel('Longitude (degrees east)')
ax1.set_ylabel('Latitude (degrees north)')

# add Cartopy features
ax1.coastlines()
ax1.add_feature(cfeature.BORDERS, linewidth=0.5)
ax1.add_feature(cfeature.LAND, facecolor='lightgray', zorder=0)
ax1.add_feature(cfeature.OCEAN, facecolor='lightblue', zorder=0)
# ax1.gridlines(draw_labels=True, linewidth=0.3, color='gray', alpha=0.5)

# title
ax1.set_title('Surface Wind Vectors with Wind Speed (m/s)', fontsize=14)

# left panel
ax = subset.sel(lon=slice(280, 320)).sel(lat=-20, method='nearest').isel(time=225).va.plot(ax=ax2)
ax.axes.set_ylim(1, 0.4)
ax.axes.set_title('Meridional Wind Profile')

In [None]:
jets.where('llj_class'!='C3')