This notebook attempts to reproduce the meridional energy transport plots from:

**Energy transport, polar amplification, and ITCZ shifts in the GeoMIP G1 ensemble**
**by Rick D. Russotto and Thomas P. Ackerman**
**(2018).**
**doi: https://doi.org/10.5194/acp-18-2287-2018**

In [1]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [2]:
# Physical constants (SI units)
cp = 1004.0   # specific heat of dry air at constant pressure, J kg-1 K-1
Lv = 2.5e6    # latent heat of vaporization, J kg-1
g = 9.81      # gravitational acceleration, m s-2

In [3]:
#Load Data
# All four fields come from the same UKESM1-0-LL historical member (r1i1p1f2,
# monthly Amon table, grid label "gn"); only the variable name differs in the path.
CMIP6_PATH_TEMPLATE = (
    "/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/historical/r1i1p1f2/"
    "Amon/{var}/gn/latest/*.nc"
)


def open_cmip6_var(var, path_template=CMIP6_PATH_TEMPLATE):
    """
    Lazily open every netCDF file of one CMIP6 variable as a single dataset.

    Parameters
    ----------
    var : str
        CMIP6 variable name substituted into ``path_template`` (e.g. "ta").
    path_template : str, optional
        Glob pattern containing a ``{var}`` placeholder.

    Returns
    -------
    xarray.Dataset
        Files combined by coordinates (opened in parallel), with times decoded
        to cftime objects.
    """
    return xr.open_mfdataset(
        path_template.format(var=var),
        combine="by_coords", parallel=True, decode_times=True, use_cftime=True,
    )


ds_t = open_cmip6_var("ta")
ds_q = open_cmip6_var("hus")
ds_z = open_cmip6_var("zg")
ds_v = open_cmip6_var("va")

In [4]:
# Inspect the lazily opened ta dataset (dimensions, coordinates, dask chunking).
ds_t

Unnamed: 0,Array,Chunk
Bytes,30.94 kiB,16 B
Shape,"(1980, 2)","(1, 2)"
Dask graph,1980 chunks in 5 graph layers,1980 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 30.94 kiB 16 B Shape (1980, 2) (1, 2) Dask graph 1980 chunks in 5 graph layers Data type object numpy.ndarray",2  1980,

Unnamed: 0,Array,Chunk
Bytes,30.94 kiB,16 B
Shape,"(1980, 2)","(1, 2)"
Dask graph,1980 chunks in 5 graph layers,1980 chunks in 5 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.35 MiB,2.64 MiB
Shape,"(1980, 144, 2)","(1200, 144, 2)"
Dask graph,2 chunks in 7 graph layers,2 chunks in 7 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 4.35 MiB 2.64 MiB Shape (1980, 144, 2) (1200, 144, 2) Dask graph 2 chunks in 7 graph layers Data type float64 numpy.ndarray",2  144  1980,

Unnamed: 0,Array,Chunk
Bytes,4.35 MiB,2.64 MiB
Shape,"(1980, 144, 2)","(1200, 144, 2)"
Dask graph,2 chunks in 7 graph layers,2 chunks in 7 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.80 MiB,3.52 MiB
Shape,"(1980, 192, 2)","(1200, 192, 2)"
Dask graph,2 chunks in 7 graph layers,2 chunks in 7 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 5.80 MiB 3.52 MiB Shape (1980, 192, 2) (1200, 192, 2) Dask graph 2 chunks in 7 graph layers Data type float64 numpy.ndarray",2  192  1980,

Unnamed: 0,Array,Chunk
Bytes,5.80 MiB,3.52 MiB
Shape,"(1980, 192, 2)","(1200, 192, 2)"
Dask graph,2 chunks in 7 graph layers,2 chunks in 7 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.87 GiB,2.00 MiB
Shape,"(1980, 19, 144, 192)","(1, 19, 144, 192)"
Dask graph,1980 chunks in 5 graph layers,1980 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.87 GiB 2.00 MiB Shape (1980, 19, 144, 192) (1, 19, 144, 192) Dask graph 1980 chunks in 5 graph layers Data type float32 numpy.ndarray",1980  1  192  144  19,

Unnamed: 0,Array,Chunk
Bytes,3.87 GiB,2.00 MiB
Shape,"(1980, 19, 144, 192)","(1, 19, 144, 192)"
Dask graph,1980 chunks in 5 graph layers,1980 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [5]:
# Analysis period.
# This is a historical run (1980 monthly steps = 165 years), so a future window
# such as slice("2071-01", "2100-12") matches no data here. None keeps the full
# record; set e.g. t_slice = slice("1985-01", "2014-12") for a 30-year climatology.
t_slice = None


def _select_period(da, period):
    """Return ``da`` restricted to ``period`` along time, or unchanged if period is None."""
    return da if period is None else da.sel(time=period)


T = _select_period(ds_t.ta, t_slice)
q = _select_period(ds_q.hus, t_slice)
z = _select_period(ds_z.zg, t_slice)
v = _select_period(ds_v.va, t_slice)

# No xr.align here: va sits on a different latitude grid (145 points vs 144 for
# ta/hus/zg), so an inner-join align would drop the non-matching latitudes.
# The energy fields are interpolated onto the va grid further down instead.

In [6]:
# Inspect ta: lazy float32 array of shape (1980, 19, 144, 192), one dask chunk per time step.
T

Unnamed: 0,Array,Chunk
Bytes,3.87 GiB,2.00 MiB
Shape,"(1980, 19, 144, 192)","(1, 19, 144, 192)"
Dask graph,1980 chunks in 5 graph layers,1980 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.87 GiB 2.00 MiB Shape (1980, 19, 144, 192) (1, 19, 144, 192) Dask graph 1980 chunks in 5 graph layers Data type float32 numpy.ndarray",1980  1  192  144  19,

Unnamed: 0,Array,Chunk
Bytes,3.87 GiB,2.00 MiB
Shape,"(1980, 19, 144, 192)","(1, 19, 144, 192)"
Dask graph,1980 chunks in 5 graph layers,1980 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [7]:
#Compute energy components
# Latent energy        Lv * q                       (J kg-1)
# Dry static energy    cp * T + g * z                (J kg-1; zg assumed in m)
# Moist static energy  DSE + latent energy
LAT = Lv * q
DSE = (cp * T) + (g * z)
MSE = DSE + LAT

In [8]:
# Inspect MSE (still lazy): same shape as ta, because DSE and LAT are built on the ta/hus/zg grid.
MSE

Unnamed: 0,Array,Chunk
Bytes,3.87 GiB,2.00 MiB
Shape,"(1980, 19, 144, 192)","(1, 19, 144, 192)"
Dask graph,1980 chunks in 20 graph layers,1980 chunks in 20 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.87 GiB 2.00 MiB Shape (1980, 19, 144, 192) (1, 19, 144, 192) Dask graph 1980 chunks in 20 graph layers Data type float32 numpy.ndarray",1980  1  192  144  19,

Unnamed: 0,Array,Chunk
Bytes,3.87 GiB,2.00 MiB
Shape,"(1980, 19, 144, 192)","(1, 19, 144, 192)"
Dask graph,1980 chunks in 20 graph layers,1980 chunks in 20 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [9]:
# Inspect va: 145 latitudes vs 144 for ta (compare shapes), i.e. a different horizontal grid.
v

Unnamed: 0,Array,Chunk
Bytes,3.90 GiB,2.02 MiB
Shape,"(1980, 19, 145, 192)","(1, 19, 145, 192)"
Dask graph,1980 chunks in 5 graph layers,1980 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.90 GiB 2.02 MiB Shape (1980, 19, 145, 192) (1, 19, 145, 192) Dask graph 1980 chunks in 5 graph layers Data type float32 numpy.ndarray",1980  1  192  145  19,

Unnamed: 0,Array,Chunk
Bytes,3.90 GiB,2.02 MiB
Shape,"(1980, 19, 145, 192)","(1, 19, 145, 192)"
Dask graph,1980 chunks in 5 graph layers,1980 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [10]:
#Interpolating to same grid as the meridional wind v

def to_v(energy, v, method="linear", **interp_options):
    """
    Interpolate a field onto the ``lat``/``lon``/``plev`` coordinates of ``v``.

    The energy fields and va are on different grids here (144 vs 145
    latitudes), so the energy fields are moved onto the va grid before
    forming ``v * energy`` products.

    Parameters
    ----------
    energy : xarray.DataArray
        Field with ``lat``, ``lon`` and ``plev`` coordinates.
    v : xarray.DataArray
        Target field whose ``lat``, ``lon`` and ``plev`` coordinates are used.
    method : str, optional
        Interpolation method passed to ``DataArray.interp`` (default "linear").
    **interp_options
        Extra keyword arguments forwarded to ``DataArray.interp``
        (e.g. ``kwargs={...}`` to control extrapolation). By default, target
        points outside the range of ``energy``'s coordinates come back as NaN.

    Returns
    -------
    xarray.DataArray
        ``energy`` sampled at ``v``'s coordinates.
    """
    return energy.interp(lat=v.lat, lon=v.lon, plev=v.plev, method=method, **interp_options)

# Put each energy field on the va grid before forming v * energy products.
DSE_on_v, LAT_on_v, MSE_on_v = (to_v(field, v) for field in (DSE, LAT, MSE))

In [11]:
#Meridional energy fluxes
# Pointwise northward energy flux density v * E on the va grid.
v_DSE, v_LAT, v_MSE = (v * energy for energy in (DSE_on_v, LAT_on_v, MSE_on_v))

In [12]:
# Inspect v_MSE: now on the va grid (145 latitudes).
v_MSE

Unnamed: 0,Array,Chunk
Bytes,3.90 GiB,2.02 MiB
Shape,"(1980, 19, 145, 192)","(1, 19, 145, 192)"
Dask graph,1980 chunks in 54 graph layers,1980 chunks in 54 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.90 GiB 2.02 MiB Shape (1980, 19, 145, 192) (1, 19, 145, 192) Dask graph 1980 chunks in 54 graph layers Data type float32 numpy.ndarray",1980  1  192  145  19,

Unnamed: 0,Array,Chunk
Bytes,3.90 GiB,2.02 MiB
Shape,"(1980, 19, 145, 192)","(1, 19, 145, 192)"
Dask graph,1980 chunks in 54 graph layers,1980 chunks in 54 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [13]:
# The 19 pressure levels in Pa, ordered from 100000 down to 100 (descending).
v_MSE.plev.values

array([100000.,  92500.,  85000.,  70000.,  60000.,  50000.,  40000.,
        30000.,  25000.,  20000.,  15000.,  10000.,   7000.,   5000.,
         3000.,   2000.,   1000.,    500.,    100.])

In [14]:
#Vertical integration function (pressure coordinates)
#Pressure is already in Pa (plev runs 100000 ... 100), so no unit conversion is needed.
plev = T.plev

def vertical_integral(flux, fill_value=0.0):
    """
    Mass-weighted vertical integral: (1/g) * integral of ``flux`` dp over plev.

    For a northward flux density v*E (m s-1 * J kg-1) with p in Pa the result
    is a column-integrated northward transport in W m-1.

    Parameters
    ----------
    flux : xarray.DataArray
        Field with a ``plev`` coordinate in Pa.
    fill_value : float or None, optional
        Value substituted for missing (NaN) entries before integrating.
        ``integrate`` does not skip NaNs, so a single masked level (e.g.
        pressure-level output masked below the surface) would otherwise make
        the whole column NaN. The default 0.0 treats missing levels as carrying
        no transport; pass None to keep NaNs.

    Returns
    -------
    xarray.DataArray
        ``flux`` integrated over ``plev`` (dimension removed) divided by g.
    """
    # Trapezoidal integration along plev; ascending pressure gives the
    # top-to-bottom integral a positive orientation.
    flux = flux.sortby("plev")
    if fill_value is not None:
        flux = flux.fillna(fill_value)
    return flux.integrate("plev") / g

In [15]:
#Compute vertical integration
# Column-integrated northward transports, one value per (time, lat, lon).
VI_DSE, VI_LAT, VI_MSE = map(vertical_integral, (v_DSE, v_LAT, v_MSE))

In [16]:
# Inspect VI_DSE: plev has been integrated out, leaving shape (1980, 145, 192).
VI_DSE

Unnamed: 0,Array,Chunk
Bytes,420.56 MiB,217.50 kiB
Shape,"(1980, 145, 192)","(1, 145, 192)"
Dask graph,1980 chunks in 56 graph layers,1980 chunks in 56 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 420.56 MiB 217.50 kiB Shape (1980, 145, 192) (1, 145, 192) Dask graph 1980 chunks in 56 graph layers Data type float64 numpy.ndarray",192  145  1980,

Unnamed: 0,Array,Chunk
Bytes,420.56 MiB,217.50 kiB
Shape,"(1980, 145, 192)","(1, 145, 192)"
Dask graph,1980 chunks in 56 graph layers,1980 chunks in 56 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [17]:
#Time mean (monthly → climatology)
# Average over all available time steps; still lazy until .compute().
VI_DSE_clim, VI_LAT_clim, VI_MSE_clim = (
    field.mean("time") for field in (VI_DSE, VI_LAT, VI_MSE)
)

In [18]:
# Inspect the (lazy) MSE transport climatology: shape (145, 192).
VI_MSE_clim

Unnamed: 0,Array,Chunk
Bytes,217.50 kiB,217.50 kiB
Shape,"(145, 192)","(145, 192)"
Dask graph,1 chunks in 70 graph layers,1 chunks in 70 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 217.50 kiB 217.50 kiB Shape (145, 192) (145, 192) Dask graph 1 chunks in 70 graph layers Data type float64 numpy.ndarray",192  145,

Unnamed: 0,Array,Chunk
Bytes,217.50 kiB,217.50 kiB
Shape,"(145, 192)","(145, 192)"
Dask graph,1 chunks in 70 graph layers,1 chunks in 70 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [19]:
#Plotting function (lat–lon maps)
def plot_map(data, title, vmax=None):
    """
    Plot a 2-D lat-lon field on a Robinson map with a zero-centred diverging colour scale.

    Parameters
    ----------
    data : xarray.DataArray
        2-D field with lat/lon coordinates (assumed in degrees, since it is
        drawn with a PlateCarree transform). The colourbar is labelled W m-1,
        i.e. intended for column-integrated transports.
    title : str
        Axes title.
    vmax : float, optional
        Colour limits are set to [-vmax, vmax]. If None (default), limits are
        chosen from the data, symmetric about zero.
    """
    # vmin=-vmax cannot be formed when vmax is None; let xarray pick symmetric limits.
    if vmax is None:
        limits = {"center": 0}
    else:
        limits = {"vmin": -vmax, "vmax": vmax}

    fig, ax = plt.subplots(figsize=(10, 4), subplot_kw={"projection": ccrs.Robinson()})

    data.plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap="RdBu_r",
        cbar_kwargs={"label": "W m$^{-1}$"},
        **limits,
    )

    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax.set_title(title)
    plt.show()

In [None]:
# Evaluate the three climatologies in a single dask pass. They share the same
# inputs (the ta/hus/zg/va reads and the interpolation), so one compute reuses
# those tasks instead of repeating them three times over multi-GB inputs.
vi_clim_ds = xr.Dataset(
    {"MSE": VI_MSE_clim, "LAT": VI_LAT_clim, "DSE": VI_DSE_clim}
).compute()
VI_MSE_clim_compute = vi_clim_ds["MSE"]
VI_LAT_clim_compute = vi_clim_ds["LAT"]
VI_DSE_clim_compute = vi_clim_ds["DSE"]

In [None]:
# Global maps of the three column-integrated northward transports.
for field, title, limit in (
    (VI_MSE_clim_compute, "Meridional Transport of MSE", 3e9),
    (VI_LAT_clim_compute, "Meridional Transport of Latent Energy", 2e9),
    (VI_DSE_clim_compute, "Meridional Transport of Dry Static Energy", 2e9),
):
    plot_map(field, title, vmax=limit)

In [None]:
# Zonal mean of the column-integrated northward latent heat transport.
lhfln_mean_zonal_mean_model = VI_LAT_clim_compute.mean(dim="lon")

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(
    lhfln_mean_zonal_mean_model.lat,
    lhfln_mean_zonal_mean_model,
    linewidth=2.5,
    label="Latent heat flux",
    color="tab:blue",
)
ax.axhline(0, linestyle="--", linewidth=1)  # zero-transport reference line

ax.set_xlabel("Latitude (°N)")
ax.set_ylabel("Northward flux (W m$^{-1}$)")
ax.set_title("Zonal-Mean Northward Latent Heat Flux (UKESM Climatology : Historical)")
ax.set_xlim(-90, 90)
ax.grid(True, linestyle=":", alpha=0.6)
ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Inspect the zonal-mean northward latent heat transport computed above.
lhfln_mean_zonal_mean_model

In [None]:
%store -r lhfln_mean_zonal_mean

lhfln_mean_zonal_mean_CDS=lhfln_mean_zonal_mean.interp(latitude=lhfln_mean_zonal_mean_model.lat, method="linear")

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 5))

ax.plot(
    lhfln_mean_zonal_mean_CDS.latitude,
    lhfln_mean_zonal_mean_CDS,
    linewidth=2.5,
    label="Latent heat flux : CDS 2022",
    color="tab:blue"
)

ax.plot(
    lhfln_mean_zonal_mean_model.lat,
    lhfln_mean_zonal_mean_model,
    linewidth=2.5,
    label="Latent heat flux : UKESM Clim Hist",
    color="tab:red"
)

# Zero line
ax.axhline(0, linestyle="--", linewidth=1)

# Labels
ax.set_xlabel("Latitude (°N)")
ax.set_ylabel("Northward flux (W m$^{-1}$)")
# ax.set_title("""Zonal-Mean Northward Latent Heat Flux (CDS data 2022)\n
#     This dataset provides monthly means of mass-consistent, vertically integrated,\n
#     atmospheric energy and moisture budget quantities derived from 1-hourly ERA5 data.""")
ax.set_title("""CDS dataset provides monthly means of mass-consistent, vertically integrated,\n
    atmospheric energy and moisture budget quantities derived from 1-hourly ERA5 data.""")

# Limits and grid
ax.set_xlim(-90, 90)
ax.grid(True, linestyle=":", alpha=0.6)
ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Inspect the computed (in-memory) latent heat transport climatology.
VI_LAT_clim_compute

In [None]:
# import numpy as np

# F_phi = VI_LAT_clim_compute  # vertical integral of meridional latent heat flux
# a = 6.371e6  # Earth radius in meters
# phi = np.deg2rad(F_phi.lat)  # convert latitude to radians

# # cos(phi) weighting (broadcast over lon)
# F_weighted = F_phi * np.cos(phi[:, np.newaxis])

# # derivative w.r.t latitude (axis=0)
# dF_dphi = np.gradient(F_weighted, phi, axis=0)  # d(F*cos)/dphi

# # divergence
# divergent_flux = dF_dphi / (a * np.cos(phi)[:, np.newaxis])

In [None]:
# import numpy as np
# import xarray as xr

# F_phi = VI_LAT_clim_compute  # vertical integral of meridional latent heat flux
# a = 6.371e6  # Earth radius in meters

# # convert latitude to radians
# phi = np.deg2rad(F_phi['lat'])

# # expand dims to allow broadcasting over lon
# cos_phi = np.cos(phi).expand_dims(lon=F_phi.lon)

# # cos(phi) weighting
# F_weighted = F_phi * cos_phi

# # derivative w.r.t latitude (axis=0)
# dF_dphi = F_weighted.differentiate('lat')  # xarray computes derivative in the coordinate units

# # if lat is in degrees, convert derivative to radians
# dF_dphi = dF_dphi / (np.pi/180)

# # divergence
# divergent_flux = dF_dphi / cos_phi

In [None]:
import numpy as np
import xarray as xr

def compute_meridional_flux_divergence(F_phi, radius=6.371e6):
    """
    Compute the meridional part of the horizontal divergence of a vertically
    integrated flux on the sphere:

        (1 / (a cos(phi))) * d(F cos(phi)) / d(phi)

    Only the northward component enters, so this is neither the full
    horizontal divergence (the zonal term is absent) nor a Helmholtz
    "divergent component".

    Parameters
    ----------
    F_phi : xarray.DataArray
        Vertically integrated northward flux (e.g. W m-1) with a ``lat``
        coordinate in degrees. Other dimensions (lon, time, ...) are
        broadcast, so lat-only (zonal-mean) input also works.
    radius : float, optional
        Planetary radius ``a`` in metres (default: Earth's mean radius).

    Returns
    -------
    divergent_flux : xarray.DataArray
        Meridional flux divergence with the same dims as ``F_phi`` (W m-2 for
        an input in W m-1). Rows where cos(phi) is ~0 (latitude +/-90) are
        NaN because the 1/cos(phi) factor is singular there.
    """
    phi = np.deg2rad(F_phi["lat"])
    cos_phi = np.cos(phi)  # indexed by lat; xarray broadcasts it over other dims

    # differentiate() gives d/d(lat) per degree; convert to per radian.
    dF_dphi = (F_phi * cos_phi).differentiate("lat") / (np.pi / 180)

    # The 1/a factor turns "W m-1 per radian" into W m-2.
    divergent_flux = dF_dphi / (radius * cos_phi)

    # Mask the poles, where cos(phi) ~ 1e-17 makes the division meaningless.
    return divergent_flux.where(np.abs(cos_phi) > 1e-10)

In [None]:
# Meridional divergence of each column-integrated transport climatology.
div_flux_LAT, div_flux_DSE, div_flux_MSE = (
    compute_meridional_flux_divergence(transport)
    for transport in (VI_LAT_clim_compute, VI_DSE_clim_compute, VI_MSE_clim_compute)
)


In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Map of the meridional divergence of the column-integrated MSE transport.
# Colour limits come from the data: robust=True uses the 2nd-98th percentiles
# and center=0 makes them symmetric. No hand-tuned rescaling factor is needed,
# and isolated extremes (e.g. next to the poles) do not wash out the map.
fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={"projection": ccrs.PlateCarree()})

im = div_flux_MSE.plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),  # data are on a regular lat-lon grid
    cmap="coolwarm",
    center=0,
    robust=True,
    add_colorbar=False,  # colorbar added explicitly below
)

ax.coastlines()
ax.gridlines(draw_labels=True)

cbar = plt.colorbar(im, ax=ax, orientation="vertical", pad=0.02)
cbar.set_label("Meridional flux divergence (W m$^{-2}$)")

ax.set_title("Meridional Divergence of Vertically Integrated MSE Transport")

plt.show()

In [None]:
# Zonal mean of the meridional divergence of the latent heat transport.
# Kept under its own name so it does not overwrite `lhfln_mean_zonal_mean_model`
# (the zonal-mean *transport*), which the CDS/ERA5 comparison cell reads.
div_LAT_zonal_mean_model = div_flux_LAT.mean(dim="lon")

fig, ax = plt.subplots(figsize=(8, 5))

ax.plot(
    div_LAT_zonal_mean_model.lat,
    div_LAT_zonal_mean_model,
    linewidth=2.5,
    label="Latent heat flux divergence",
    color="tab:blue",
)

# Zero line
ax.axhline(0, linestyle="--", linewidth=1)

# Labels: this is a divergence (W m-2), not a northward flux (W m-1).
ax.set_xlabel("Latitude (°N)")
ax.set_ylabel("Meridional flux divergence (W m$^{-2}$)")
ax.set_title("Zonal-Mean Divergence of Northward Latent Heat Flux (UKESM Climatology : Historical)")

# Limits and grid
ax.set_xlim(-90, 90)
ax.grid(True, linestyle=":", alpha=0.6)
ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Redraw the MSE transport map (same settings as the earlier map cell).
plot_map(VI_MSE_clim_compute, title="Meridional Transport of MSE", vmax=3e9)