The aim is to read the model data and plot the Energy Flux Equator (EFE).

1) Adam, O., T. Bischoff, and T. Schneider, 2016: Seasonal and Interannual Variations of the Energy Flux Equator and ITCZ. Part I: Zonally Averaged ITCZ Position. J. Climate, 29, 3219–3230, https://doi.org/10.1175/JCLI-D-15-0512.1.
2) EFPM: not used at present; the reference is kept here for completeness: Boos, W., Korty, R. Regional energy budget control of the intertropical convergence zone and application to mid-Holocene rainfall. Nature Geosci 9, 892–897 (2016). https://doi.org/10.1038/ngeo2833


In [None]:
import numpy as np
import xarray as xr
# our local module:
import itcz

import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
import xarray as xr
from pathlib import Path
import myfunctions as mf

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
# from xarrayutils import divergence_spherical, helmholtz_decomposition_spectral  # placeholder functions

In [None]:
# Open the five monthly (Amon) UKESM1-0-LL historical fields used below.
# Every call shares the same keyword arguments, so they are defined once.
_OPEN_KWARGS = dict(
    combine="by_coords",
    parallel=True,
    decode_times=True,
    use_cftime=True,
)

# ta
ds_t = xr.open_mfdataset(
    "/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/historical/r1i1p1f2/Amon/ta/gn/latest/*.nc",
    **_OPEN_KWARGS,
)
print("T in units of K", ds_t)

# hus
ds_q = xr.open_mfdataset(
    "/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/historical/r1i1p1f2/Amon/hus/gn/latest/*.nc",
    **_OPEN_KWARGS,
)
print("q in units of 1", ds_q)

# zg
ds_z = xr.open_mfdataset(
    "/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/historical/r1i1p1f2/Amon/zg/gn/latest/*.nc",
    **_OPEN_KWARGS,
)
print("z in unit of meter", ds_z)

# ua
ds_u = xr.open_mfdataset(
    "/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/historical/r1i1p1f2/Amon/ua/gn/latest/*.nc",
    **_OPEN_KWARGS,
)
print("u in unit of meter/sec", ds_u)

# va
ds_v = xr.open_mfdataset(
    "/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/historical/r1i1p1f2/Amon/va/gn/latest/*.nc",
    **_OPEN_KWARGS,
)
print("v in unit of meter/sec", ds_v)

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import myfunctions as mf
from windspharm.standard import VectorWind


# ----------------------------
# PHYSICAL CONSTANTS
# ----------------------------
Cp = 1004.0    # specific heat, J kg-1 K-1
Lv = 2.5e6     # latent heat, J kg-1
g = 9.81       # gravity, m s-2
R = 6.371e6    # Earth radius (m)

# ----------------------------
# BLOCK 1: TIME MEAN
# ----------------------------
# Each field goes through mf.seasonal_mean_by_year(field, 1, 12) and is then
# averaged over "year" (presumably months 1-12, i.e. an annual-mean
# climatology -- confirm against myfunctions). A plain .mean("time") was the
# earlier alternative.
ta_mean, q_mean, z_mean, v_mean = (
    mf.seasonal_mean_by_year(field, 1, 12).mean("year")
    for field in (ds_t.ta, ds_q.hus, ds_z.zg, ds_v.va)
)

# ----------------------------
# BLOCK 2: SELECT LONGITUDE SECTOR
# ----------------------------
# The full 0-360 range is selected at present; use (50, 120) for the
# 50E-120E sector instead.
lon_min, lon_max = 0, 360

ta_sec, q_sec, z_sec, v_sec = (
    field.sel(lon=slice(lon_min, lon_max))
    for field in (ta_mean, q_mean, z_mean, v_mean)
)

# ----------------------------
# BLOCK 2a: BAROTROPIC MASS CORRECTION
# ----------------------------
# plev is used exactly as stored (the hPa -> Pa factor of 100 is not applied);
# assumes the files already hold Pa -- TODO confirm.
plev_pa = v_sec.plev
v_sec = v_sec.assign_coords(plev=plev_pa)

# Plain (unweighted) average over the levels; dp weights can be added later.
v_vert_mean = v_sec.mean("plev")

# Meridional wind with its vertical mean removed.
Vc = v_sec - v_vert_mean

# ----------------------------
# BLOCK 3: MOIST STATIC ENERGY
# ----------------------------
h = (
    Cp * ta_sec
    + Lv * q_sec
    + g * z_sec
)

# Put h on the coordinates of the corrected wind so the product is consistent.
h_on_Vc = h.interp(
    coords={"lat": Vc.lat, "lon": Vc.lon, "plev": Vc.plev},
    method="linear",
)

# ----------------------------
# BLOCK 4: MERIDIONAL MSE FLUX
# ----------------------------
vh = Vc * h_on_Vc


In [None]:
# ----------------------------
# BLOCK 5: VERTICAL INTEGRATION (NO ZONAL MEAN)
# ----------------------------
# Prepare the MSE flux `vh` for the pressure integral: order the levels,
# evaluate the array, and replace missing values.

# Sort along the pressure coordinate.
vh_sorted = vh.sortby("plev")

# Evaluate the (possibly lazy) array so later steps work on concrete values.
vh_computed = vh_sorted.compute()
# Missing values become zero, so they add nothing to the vertical integral.
vh_clean = vh_computed.fillna(0.0)

In [None]:
# Put the pressure levels in ascending order, using a float copy of the
# coordinate both for sorting and as the new coordinate values.
plev_vals = np.array(vh_clean.plev.values, dtype=float)
order = np.argsort(plev_vals)
vh_clean = vh_clean.isel(plev=order).assign_coords(plev=plev_vals[order])

# Pressure integral of the MSE flux; the result is still lat x lon.
# NOTE(review): with plev in Pa this is (m s-1)(J kg-1)(Pa); it only becomes
# W m-1 after the division by g below -- confirm the intended units.
vh_int = vh_clean.integrate("plev")

# Scale by R*cos(lat)/g.
# TODO: check the difference with Simona's plot.
cos_lat = np.cos(np.deg2rad(vh_int.lat))
F_lat_lon = (R * cos_lat / g) * vh_int

In [None]:
# Quick-look map of the transport on plain lon/lat axes, with a diverging
# colour scale centred on zero.
plt.figure(figsize=(12, 5))
F_lat_lon.plot(x="lon", y="lat", levels=21, center=0, cmap="RdBu_r")
plt.title("Vertically integrated meridional MSE transport")
plt.show()


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# NOTE(review): F_lat_lon_tropics is not used by the plot below, which draws
# the full F_lat_lon field on a global extent.
F_lat_lon_tropics = F_lat_lon.sel(lat=slice(-30, 30))

plt.figure(figsize=(13, 5))

# Global map centred on 180 degrees longitude.
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=180))
ax.set_global()

# Shaded transport field; the data coordinates are plain lon/lat.
pcm = F_lat_lon.plot(
    x="lon",
    y="lat",
    ax=ax,
    transform=ccrs.PlateCarree(),
    levels=21,
    center=0,
    cmap="RdBu_r",
    add_colorbar=True,
    cbar_kwargs=dict(label="Meridional MSE transport (W m$^{-1}$)"),
)

# Map decorations, added in the same order as before.
ax.coastlines(linewidth=0.8)
for feature, feature_style in (
    (cfeature.BORDERS, {"linewidth": 0.4}),
    (cfeature.LAND, {"facecolor": "lightgray", "alpha": 0.3}),
    (cfeature.OCEAN, {"facecolor": "white"}),
):
    ax.add_feature(feature, **feature_style)

# Dashed gridlines, labelled on the left and bottom only.
gl = ax.gridlines(
    draw_labels=True,
    linestyle="--",
    linewidth=0.3,
    color="gray",
    alpha=0.5,
)
gl.top_labels = gl.right_labels = False

ax.set_title("Vertically Integrated Meridional MSE Transport")

plt.show()


In [None]:
# Longitude-mean profile of the transport. Latitude is passed explicitly as
# the x values; plotting the array on its own would put the sample index on
# the x-axis instead.
F_zonal_mean = F_lat_lon.mean("lon")
plt.plot(F_zonal_mean.lat, F_zonal_mean)
plt.xlabel("Latitude (°)")

In [None]:
# EFE per longitude: within 20S-20N, find the first sign change of the flux
# (in stored latitude order) and interpolate linearly to the zero crossing.
F_tropics = F_lat_lon.sel(lat=slice(-20, 20))
lat_vals = F_tropics.lat.values

efe_per_lon = []
for lon in F_lat_lon.lon.values:
    F_vals = F_tropics.sel(lon=lon).values
    signs = np.sign(F_vals)

    # Same sign everywhere in the band: no crossing at this longitude.
    if np.all(signs == signs[0]):
        efe_per_lon.append(np.nan)
        continue

    # Index of the grid point just before the first sign change.
    idx = np.flatnonzero(np.diff(signs))[0]
    lat1, F1 = lat_vals[idx], F_vals[idx]
    lat2, F2 = lat_vals[idx + 1], F_vals[idx + 1]

    # Linear interpolation between the two bracketing points.
    efe_per_lon.append(lat1 - F1 * (lat2 - lat1) / (F2 - F1))

EFE_lon = xr.DataArray(
    efe_per_lon,
    coords={"lon": F_lat_lon.lon},
    dims="lon",
    name="EFE",
)

# Line plot of the EFE latitude against longitude, with the equator marked.
plt.figure(figsize=(12, 4))
EFE_lon.plot(color="k")
plt.axhline(0, color="gray", linestyle="--")
plt.ylabel("Latitude (°)")
plt.xlabel("Longitude (°)")
plt.title("Energy Flux Equator (EFE) vs Longitude")
plt.grid(True)
plt.show()


In [None]:
def compute_EFE_per_lon(F_lat_lon, lat_bounds=(-20, 20)):
    """
    Compute the Energy Flux Equator (EFE) latitude for each longitude.

    For every longitude, the flux profile inside ``lat_bounds`` is scanned in
    the order the latitudes are stored. The first sign change between two
    neighbouring non-NaN values is located, and the zero crossing is
    estimated by linear interpolation between those two grid points.

    Parameters
    ----------
    F_lat_lon : xarray.DataArray
        Flux with ``lat`` and ``lon`` coordinates.
    lat_bounds : tuple, optional
        ``(start, stop)`` used as a label slice on ``lat``.
        Defaults to ``(-20, 20)``.

    Returns
    -------
    xarray.DataArray
        One value per longitude (dimension ``lon``, name ``"EFE"``). The value
        is NaN where the band is empty, all NaN, or contains no sign change.
    """
    # The latitude band is the same for every longitude, so select it once.
    F_band = F_lat_lon.sel(lat=slice(*lat_bounds))
    lat_vals = F_band.lat.values

    EFE_list = []

    for lon in F_lat_lon.lon.values:
        F_vals = F_band.sel(lon=lon).values
        signs = np.sign(F_vals)

        # np.sign(NaN) is NaN, which always compares unequal to its
        # neighbour. Pairs containing a NaN are excluded so that a data gap
        # is not mistaken for a zero crossing and does not hide a real
        # crossing further along the profile.
        has_data = ~np.isnan(F_vals)
        is_crossing = (
            has_data[:-1] & has_data[1:] & (signs[:-1] != signs[1:])
        )
        crossings = np.flatnonzero(is_crossing)

        if crossings.size == 0:
            EFE_list.append(np.nan)
            continue

        i = crossings[0]
        lat1, lat2 = lat_vals[i], lat_vals[i + 1]
        F1, F2 = F_vals[i], F_vals[i + 1]

        # Linear interpolation to F == 0; F2 != F1 because the signs differ.
        EFE_list.append(lat1 - F1 * (lat2 - lat1) / (F2 - F1))

    return xr.DataArray(
        EFE_list,
        coords={"lon": F_lat_lon.lon},
        dims="lon",
        name="EFE"
    )


In [None]:
# Recompute EFE_lon with the reusable function, using its default latitude bounds.
EFE_lon = compute_EFE_per_lon(F_lat_lon)


In [None]:
# Map the longitudes into [-180, 180) and sort them ascending.
lon = EFE_lon.lon.values
lon_wrap = (lon + 180) % 360 - 180
sort_idx = np.argsort(lon_wrap)

# Apply the same ordering to the longitudes and to the EFE values.
EFE_vals = EFE_lon.values
lon_sorted, EFE_sorted = lon_wrap[sort_idx], EFE_vals[sort_idx]


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt

# Bar chart of the EFE latitude at each longitude, drawn over a map of the
# 30S-30N band.
fig, ax = plt.subplots(
    1, 1,
    figsize=(14, 4),
    subplot_kw={"projection": ccrs.PlateCarree()},
)

# Fully transparent background mesh spanning the sorted longitudes and
# 30S-30N (the original note says this follows Adam's approach).
lat_dummy = np.linspace(-30, 30, 61)
lon2d, lat2d = np.meshgrid(lon_sorted, lat_dummy)
dummy = np.zeros_like(lon2d)

ax.pcolormesh(
    lon2d,
    lat2d,
    dummy,
    transform=ccrs.PlateCarree(),
    shading="auto",
    alpha=0,
)

# Dashed gridlines, labelled on the left and bottom only.
gl = ax.gridlines(
    crs=ccrs.PlateCarree(),
    draw_labels=True,
    linestyle="--",
    linewidth=0.5,
    color="gray",
    alpha=0.5,
)
gl.top_labels = gl.right_labels = False

# Fixed tick positions: every 60 degrees of longitude, 10 of latitude.
gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 60))
gl.ylocator = plt.FixedLocator(np.arange(-30, 31, 10))

# Label font sizes.
gl.xlabel_style = {"size": 10}
gl.ylabel_style = {"size": 10}

# One bar per longitude; the width is taken from the first two longitudes.
ax.bar(
    lon_sorted,
    EFE_sorted,
    width=(lon_sorted[1] - lon_sorted[0]),
    alpha=0.8,
    color="C0",
)

# Equator reference line.
ax.axhline(0, color="k", linewidth=0.6, linestyle="--")

# Land, ocean, and coastlines.
ax.add_feature(cfeature.LAND, facecolor="lightgray")
ax.add_feature(cfeature.OCEAN, facecolor="aliceblue")
ax.coastlines(linewidth=0.6)

# Restrict the view to the band covered by the bars.
ax.set_xlim(-180, 180)
ax.set_ylim(-30, 30)

ax.set_title("Energy Flux Equator (EFE) by Longitude")
ax.set_xlabel("Longitude (deg)")
ax.set_ylabel("Latitude (deg)")

plt.tight_layout()
plt.show()
