The aim is to read the model output, then compute and plot the Energy Flux Equator (EFE).

1) Adam, O., T. Bischoff, and T. Schneider, 2016: Seasonal and Interannual Variations of the Energy Flux Equator and ITCZ. Part I: Zonally Averaged ITCZ Position. J. Climate, 29, 3219–3230, https://doi.org/10.1175/JCLI-D-15-0512.1.
2) Energy flux prime meridian (EFPM) — not used at present; the reference is kept for completeness: Boos, W., Korty, R. Regional energy budget control of the intertropical convergence zone and application to mid-Holocene rainfall. Nature Geosci 9, 892–897 (2016). https://doi.org/10.1038/ngeo2833


In [None]:
import numpy as np
import xarray as xr
# our local module:
import itcz

import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
import xarray as xr
from pathlib import Path
import myfunctions as mf

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
# from xarrayutils import divergence_spherical, helmholtz_decomposition_spectral  # placeholder functions

In [None]:
# All inputs come from the same CMIP6 experiment/member on the BADC archive:
# UKESM1-0-LL, historical, r1i1p1f2, monthly atmosphere (Amon), native grid (gn).
CMIP6_AMON_DIR = "/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/historical/r1i1p1f2/Amon"


def _open_amon(var):
    """Open every netCDF file of CMIP6 variable ``var`` as one dataset.

    Files are combined along their coordinates (``combine="by_coords"``),
    opened in parallel, and times are decoded as cftime objects.
    """
    return xr.open_mfdataset(
        f"{CMIP6_AMON_DIR}/{var}/gn/latest/*.nc",
        combine="by_coords", parallel=True, decode_times=True, use_cftime=True)


ds_t = _open_amon("ta")
print("T in units of K", ds_t)

ds_q = _open_amon("hus")
print("q in units of 1", ds_q)

ds_z = _open_amon("zg")
print("z in unit of meter", ds_z)

ds_u = _open_amon("ua")
print("u in unit of meter/sec", ds_u)

ds_v = _open_amon("va")
# Label fixed: this is the meridional wind v (was mislabelled "z").
print("v in unit of meter/sec", ds_v)

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import myfunctions as mf

# ----------------------------
# CONSTANTS
# ----------------------------
Cp = 1004.0        # specific heat of dry air at constant pressure, J kg-1 K-1
Lv = 2.5e6         # latent heat of vaporisation, J kg-1
g  = 9.81          # gravitational acceleration, m s-2
R  = 6.371e6       # Earth radius (m) -- NOT the gas constant

# ----------------------------
# BLOCK 1: SEASONAL (TIME) MEAN
# ----------------------------
# Month bounds passed to mf.seasonal_mean_by_year. Presumably an inclusive
# month range (6, 9 -> June-September) -- confirm against myfunctions.
# For an all-months climatology use e.g. ds_t.ta.mean("time") instead.
season_start, season_end = 6, 9

ta_mean = mf.seasonal_mean_by_year(ds_t.ta, season_start, season_end).mean("year")
q_mean  = mf.seasonal_mean_by_year(ds_q.hus, season_start, season_end).mean("year")
z_mean  = mf.seasonal_mean_by_year(ds_z.zg, season_start, season_end).mean("year")
v_mean  = mf.seasonal_mean_by_year(ds_v.va, season_start, season_end).mean("year")

# ----------------------------
# BLOCK 2: SELECT LONGITUDE SECTOR
# ----------------------------
# 0-360 keeps every longitude (full zonal mean). For a regional sector set
# e.g. lon_min, lon_max = 50, 120 (50E-120E).
lon_min, lon_max = 0, 360

ta_sec = ta_mean.sel(lon=slice(lon_min, lon_max))
q_sec  = q_mean.sel(lon=slice(lon_min, lon_max))
z_sec  = z_mean.sel(lon=slice(lon_min, lon_max))
v_sec  = v_mean.sel(lon=slice(lon_min, lon_max))

# ----------------------------
# BLOCK 2a: BAROTROPIC MASS CORRECTION
# ----------------------------
# Pressure levels are used as-is: no hPa -> Pa conversion is applied, i.e.
# plev is assumed to already be in Pa (confirm via v_sec.plev.attrs).
plev_pa = v_sec.plev
v_sec = v_sec.assign_coords(plev=plev_pa)

# Mass-weighted vertical mean. Each level is weighted by its pressure
# thickness dp (column mass per unit area is dp/g), so that the corrected
# wind below carries no net column mass flux. A plain mean over the unevenly
# spaced levels would over-weight the closely spaced upper levels.
# DataArray.weighted skips missing values (e.g. levels below the surface)
# and excludes their weights from the normalisation.
dp_weights = xr.DataArray(
    np.abs(np.gradient(v_sec.plev.values.astype(float))),
    coords={"plev": v_sec.plev.values},
    dims="plev",
)
v_vert_mean = v_sec.weighted(dp_weights).mean("plev")

# Subtract the mass-weighted vertical mean to get the mass-corrected wind.
Vc = v_sec - v_vert_mean

# ----------------------------
# BLOCK 3: MOIST STATIC ENERGY
# ----------------------------
# h = Cp*T + Lv*q + g*z  (J kg-1)
h = Cp * ta_sec + Lv * q_sec + g * z_sec

# Interpolate h onto the wind grid so it can be multiplied point-wise with Vc.
# NOTE(review): xarray's linear interp does not extrapolate by default, so
# target points outside h's grid become NaN; these are zero-filled later.
h_on_Vc = h.interp(
    lat=Vc.lat,
    lon=Vc.lon,
    plev=Vc.plev,
    method="linear"
)

# ----------------------------
# BLOCK 4: MERIDIONAL MSE FLUX
# ----------------------------
# Product of time-mean fields: this is the mean-circulation transport only;
# transient-eddy covariances are not included.
vh = Vc * h_on_Vc

# ----------------------------
# BLOCK 5: ZONAL MEAN OVER THE SELECTED SECTOR
# ----------------------------
vh_zm = vh.mean("lon")

# ----------------------------
# BLOCK 6: PREPARE FOR VERTICAL INTEGRATION
# ----------------------------
# Sort plev ascending: top of the atmosphere first, surface last.
vh_sorted = vh_zm.sortby("plev")


In [None]:
import matplotlib.pyplot as plt

# Sanity check of the interpolated MSE: load it into memory, average it
# (unweighted) over pressure levels and longitude, and plot what remains
# (presumably latitude only -- depends on the dims left by BLOCK 1).
h_com = h_on_Vc.compute()
h_com_mean = h_com.mean(dim=["plev", "lon"])

h_com_mean.plot()
plt.show()

In [None]:
# Level-to-level pressure steps in stored order (before sorting).
print(np.diff(vh_zm["plev"].values))
plt.plot(np.diff(vh_zm["plev"].values))

In [None]:
# Level-to-level pressure steps after sortby("plev"); all should be positive.
print(np.diff(vh_sorted["plev"].values))
plt.plot(np.diff(vh_sorted["plev"].values))

In [None]:
# NaNs were seen in the diagnostics above. Load the lazy result into memory,
# then replace (not drop) every NaN with 0, so that missing points contribute
# zero flux to the vertical integral.
vh_computed = vh_sorted.compute()
vh_clean = vh_computed.fillna(value=0.0)

In [None]:
# Put plev in ascending order with a float-typed coordinate before integrating.
# NOTE: argsort orders the levels but does not drop duplicates, so the result
# is strictly monotonic only if the input levels are unique.
plev_vals = vh_clean["plev"].values.astype(float)
order = plev_vals.argsort()
vh_clean = vh_clean.isel(plev=order).assign_coords(plev=plev_vals[order])


In [None]:
#Integrate
# Trapezoidal integral of [v h] over the plev coordinate (dp in Pa if plev is in Pa).
vh_int = vh_clean.integrate("plev")
# Quick range check of the column-integrated flux.
print(vh_int.min().values, vh_int.max().values)

In [None]:
# Total northward MSE transport across each latitude circle of the selected
# longitude sector:
#     F(phi) = (dlambda * R * cos(phi) / g) * integral([v h] dp)
# where dlambda is the sector width in radians (2*pi for 0-360). Without
# dlambda the result would be a transport per radian of longitude.
# The zero crossing (and hence the EFE) is unaffected by this factor.
sector_width = np.deg2rad(lon_max - lon_min)
F_lat = (sector_width * R * np.cos(np.deg2rad(vh_int.lat)) / g) * vh_int
F_lat.attrs["units"] = "W"  # J s-1, assuming plev is in Pa
F_lat.plot()

In [None]:
# Peak northward MSE transport (displayed as the cell output).
F_lat.max()

In [None]:
# ----------------------------
# BLOCK 7: PLOT MERIDIONAL ENERGY TRANSPORT
# ----------------------------
# Axis units come from F_lat.attrs["units"] when set. Otherwise the label
# assumes F_lat = (R cos(lat) / g) * integral([v h] dp) with plev in Pa, which
# is a transport per radian of longitude (W rad-1), not J m^-1 s^-1.
f_units = F_lat.attrs.get("units", "W per radian of longitude")

plt.figure(figsize=(8, 4))
plt.plot(F_lat.lat, F_lat, label="Meridional MSE transport")
plt.axhline(0, linestyle="--", color="k")
plt.xlabel("Latitude")
plt.ylabel(f"Meridional MSE transport [{f_units}]")
# The curve is a zonal (longitude) mean over lon_min..lon_max, not a global mean.
plt.title(f"Zonal-mean ({lon_min}-{lon_max}E) SUMMER vertically integrated meridional MSE transport")
plt.grid(True)
plt.legend()
plt.show()

In [None]:
# ----------------------------
# BLOCK 8: COMPUTE EFE
# ----------------------------
# The EFE is the latitude where the vertically integrated meridional MSE
# transport changes sign (Adam et al. 2016). Only 20S-20N is searched.
# slice(-20, 20) assumes latitude is stored in ascending order.
F_lat_tropics = F_lat.sel(lat=slice(-20, 20))

lat_vals = F_lat_tropics.lat.values
F_vals   = F_lat_tropics.values

# Drop non-finite points first: np.sign(NaN) is NaN, and np.diff involving
# NaN is non-zero, so a NaN would otherwise be reported as a sign change.
finite = np.isfinite(F_vals)
lat_vals, F_vals = lat_vals[finite], F_vals[finite]

# Index i marks a sign change between points i and i+1. Take the first one
# (the southernmost, given ascending latitude).
sign_changes = np.flatnonzero(np.diff(np.sign(F_vals)))
if sign_changes.size == 0:
    raise ValueError(
        "No sign change of F_lat found between -20 and 20 degrees latitude; "
        "cannot locate the Energy Flux Equator."
    )
sign_change = sign_changes[0]

lat1, lat2 = lat_vals[sign_change], lat_vals[sign_change + 1]
F1, F2     = F_vals[sign_change],  F_vals[sign_change + 1]

# Linear interpolation to the zero crossing. F1 != F2 is guaranteed because
# the two points differ in sign (at most one of them is exactly zero).
EFE = lat1 - F1 * (lat2 - lat1) / (F2 - F1)

print("Energy Flux Equator (EFE) latitude:", EFE)

In [None]:
# ----------------------------
# BLOCK 7b: PLOT MERIDIONAL ENERGY TRANSPORT, 30S-30N
# ----------------------------
# Axis units come from F_lat.attrs["units"] when set. Otherwise the label
# assumes F_lat = (R cos(lat) / g) * integral([v h] dp) with plev in Pa, which
# is a transport per radian of longitude (W rad-1), not J m^-1 s^-1.
f_units = F_lat.attrs.get("units", "W per radian of longitude")

F_lat_30 = F_lat.sel(lat=slice(-30, 30))  # assumes ascending latitude
plt.figure(figsize=(8, 4))
plt.plot(F_lat_30.lat, F_lat_30, label="Meridional MSE transport")
plt.axhline(0, linestyle="--", color="k")
plt.xlabel("Latitude")
plt.ylabel(f"Meridional MSE transport [{f_units}]")
# The curve is a zonal (longitude) mean over lon_min..lon_max, not a global mean.
plt.title(f"Zonal-mean ({lon_min}-{lon_max}E) vertically integrated meridional MSE transport")
plt.grid(True)
plt.legend()
plt.show()