# Rotational and divergent components

In [None]:
import warnings

warnings.filterwarnings("ignore")  # noqa

In [None]:
# Standard library
import multiprocessing.popen_spawn_posix

# Data analysis and viz libraries
import dask
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from dask.distributed import Client

In [None]:
# Local modules
import mypaths
import names
from calc import (
    wind_rot_div,
)
from grid import (
    spatial_mean,
)
from load_thai import LOAD_CONF
from model_exocam import calc_pres_exocam
from plot_func import (
    KW_AUX_TTL,
    KW_AXGR,
    KW_MAIN_TTL,
    KW_SBPLT_LABEL,
    draw_scalar,
    draw_vector,
    figsave,
    linspace_pm1,
    make_map_figure,
)

In [None]:
# Use the "paper.mplstyle" style sheet for every figure drawn below.
plt.style.use(style="paper.mplstyle")

In [None]:
# Optional: uncomment to start a local dask.distributed cluster with 4 worker
# processes (1 thread each) and display the resulting client.
# client = Client(processes=True, n_workers=4, threads_per_worker=1)
# client

## Choose case

In [None]:
# THAI case to analyse; must be one of the keys of `height_constraints` below
# ("Ben1", "Ben2", "Hab1", "Hab2").
THAI_case = "Hab1"

## Constants

Import the atmospheric and planetary constants for the chosen case (Ben1 and Hab1 share one set, Ben2 and Hab2 the other). Note that the planet's radius is in meters!

In [None]:
# Ben1/Hab1 share one set of atmospheric/planetary constants, Ben2/Hab2 the other.
# Match the case name explicitly so that a typo fails here instead of silently
# falling back to the Ben2/Hab2 constants.
if THAI_case in ("Ben1", "Hab1"):
    import const_ben1_hab1 as const
elif THAI_case in ("Ben2", "Hab2"):
    import const_ben2_hab2 as const
else:
    raise ValueError(
        f"Unknown THAI case {THAI_case!r}; expected one of Ben1, Ben2, Hab1, Hab2"
    )

## Loading the data

Lazily load the output of every model into one nested dictionary, keyed first by THAI case and then by model name.

In [None]:
# Load data into a nested dictionary: datasets[case][model_key] holds the
# object returned by that model's loader for the given case.
datasets = {}
# The outer loop runs over THAI cases; it can be swapped with the inner loop
# over models if needed.
for case in (THAI_case,):
    case_datasets = datasets[case] = {}
    for model_key, model_conf in LOAD_CONF.items():
        case_datasets[model_key] = model_conf["loader"](case)

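Regrid the ExoCAM and ROCKE3D data to be compatible with `windspharm`: on an equally spaced latitude grid with an even number of points, the latitudes must not include the poles. Both datasets are therefore interpolated onto 50 equally spaced latitudes from 88.2° N to 88.2° S, extrapolating where needed.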

In [None]:
# Target grid: `nlat` equally spaced, cell-centred latitudes ordered north to
# south. An even count with a half-spacing offset excludes the poles, which is
# what `windspharm` expects for such a grid. The grid is the same for every
# model, so build it once.
nlat = 50  # new number of latitudes
delta_lat = 180 / nlat
new_lats = np.linspace(90 - 0.5 * delta_lat, -90 + 0.5 * delta_lat, nlat)

for model_key in ["ExoCAM", "ROCKE3D"]:
    model_names = getattr(names, model_key.lower())
    ds_orig = datasets[THAI_case][model_key]
    new_ds = {}
    for d, vrbl in ds_orig.data_vars.items():
        if model_names.y in vrbl.dims:
            # Linear interpolation onto the new latitudes; extrapolate wherever
            # a new latitude falls outside the original grid's range.
            new_ds[d] = vrbl.interp(
                {model_names.y: new_lats},
                kwargs={"fill_value": "extrapolate"},
            )
        else:
            # Variables without a latitude dimension are kept unchanged
            new_ds[d] = vrbl
    # Keep the dataset-level attributes, which a bare xr.Dataset(new_ds) would drop
    datasets[THAI_case][model_key] = xr.Dataset(new_ds, attrs=ds_orig.attrs)

In [None]:
# For each model, decompose the time-mean horizontal wind into rotational and
# divergent parts (via `wind_rot_div`) and add a time-mean vertical wind "w".
wind_cmpnts = {}
for model_key in LOAD_CONF.keys():
    ds = datasets[THAI_case][model_key]
    model_names = getattr(names, model_key.lower())
    u_tm = ds[model_names.u].mean(model_names.t)
    v_tm = ds[model_names.v].mean(model_names.t)

    wind_cmpnts[model_key] = wind_rot_div(u_tm, v_tm, truncation=None, const=const)
    if model_key == "ExoCAM":
        # Convert the pressure velocity omega (Pa/s) to vertical velocity (m/s)
        # using the hydrostatic approximation w ~ -omega / (rho * g), with the
        # density from the ideal gas law rho = p / (R * T).
        pres_ExoCAM = calc_pres_exocam(ds)
        # Item access rather than `ds.T`: on older xarray versions `Dataset.T`
        # is a (deprecated) transpose alias, not the temperature variable.
        rho_ExoCAM = pres_ExoCAM / (const.rgas * ds["T"])
        wind_cmpnts[model_key]["w"] = (
            ds[model_names.w] * (-1 / (const.gplanet * rho_ExoCAM))
        ).mean(model_names.t)
    elif model_key in ["ROCKE3D", "LMDG"]:
        # NOTE(review): sign flipped for these models, presumably to match the
        # sign convention of the other models' vertical wind - confirm.
        wind_cmpnts[model_key]["w"] = -1 * ds[model_names.w].mean(model_names.t)
    else:
        wind_cmpnts[model_key]["w"] = ds[model_names.w].mean(model_names.t)

## Diagnostics and plots

In [None]:
# Per-component panel settings: title, reference wind speed (used for the
# quiver key, arrow scale and colour normalisation) and plot keyword arguments.
WIND_CMPNT_META = {
    "total": {
        "title": "Total",
        "qk_ref_wspd": 30,
        "kw_plt": {"color": "#222222"},
    },
    "rot_zm": {
        "title": "Zonal mean rotational",
        "qk_ref_wspd": 30,
        "kw_plt": {"color": "C1"},
    },
    "rot_eddy": {
        "title": "Eddy rotational",
        "qk_ref_wspd": 15,
        "kw_plt": {"color": "C2"},
    },
    "div": {
        "title": "Divergent",
        "qk_ref_wspd": 15,
        "kw_plt": {"color": "C0"},
    },
}

# Base keyword arguments handed to `draw_vector` as `kw_quiver`;
# "scale" is overridden per panel.
KW_QUIVER = {
    "scale_units": "inches",
    "scale": 100,
    "cmap": "magma_r",
    "edgecolors": "#EEEEEE",
    "linewidths": 0.15,
    "width": 0.004,
    "headaxislength": 4,
}

# Keyword arguments handed to `draw_vector` as `kw_quiverkey`
KW_QUIVERKEY = {
    "labelpos": "N",
    "labelsep": 0.05,
    "coordinates": "axes",
    "color": "#444444",
    "fontproperties": {"size": "small"},
}

# Filled-contour settings for the vertical wind: levels are linspace_pm1(5)
# scaled by 0.1 (presumably -0.1..0.1 m/s - confirm against linspace_pm1).
KW_W_CNTRF = {"cmap": "RdBu_r", "levels": linspace_pm1(5) * 0.1, "extend": "both"}

# Arrow thinning per model as (x stride, y stride); the UM uses larger strides.
skips = {model: (4, 3) for model in ("ExoCAM", "LMDG", "ROCKE3D")}
skips["UM"] = (8, 6)

In [None]:
# Level for the ExoCAM selection; also shown in the figure title and filename.
P_LEVEL = 250

# Case-dependent parts of the level selection (ROCKE3D `lev` label, UM `z` value).
_ROCKE3D_LEV = {"Ben1": 19, "Ben2": 20, "Hab1": 19, "Hab2": 20}
_UM_Z = {"Ben1": 11_500, "Ben2": 7_000, "Hab1": 11_000, "Hab2": 7_000}

# `.sel()` keyword arguments per case and model. The ExoCAM and LMDG
# selections are identical for every case.
height_constraints = {
    case: {
        "ExoCAM": dict({names.exocam.lev: P_LEVEL}, method="nearest"),
        "LMDG": dict({names.lmdg.z: 7}, method="nearest"),
        "ROCKE3D": dict({names.rocke3d.lev: _ROCKE3D_LEV[case]}),
        "UM": dict({names.um.z: _UM_Z[case]}, method="nearest"),
    }
    for case in ("Ben1", "Ben2", "Hab1", "Hab2")
}

In [None]:
%%time
ncols = 4
nrows = 4
fig, axgr = make_map_figure(ncols, nrows, **KW_AXGR)
# cbar_axes_col = np.array(axgr.cbar_axes).reshape((3, 2)).T
cax = axgr.cbar_axes[0]
fig.suptitle(
    f"{THAI_case}\n~{P_LEVEL} hPa",
    y=0.925,
)

for model_key, axcol in zip(LOAD_CONF.keys(), axgr.axes_column):
    model_names = getattr(names, model_key.lower())
    lev_sel = height_constraints[THAI_case][model_key]
    if model_key != "ExoCAM":
        print(
            model_key,
            float(
                spatial_mean(
                    datasets[THAI_case][model_key][model_names.pres]
                    .sel(**lev_sel)
                    .mean(model_names.t)
                )
            ),
        )
    for ax, (wind_key, wind_meta) in zip(axcol, WIND_CMPNT_META.items()):
        ax.set_title(wind_meta["title"], **KW_MAIN_TTL)
        ax.set_title(model_key, **KW_AUX_TTL)
        draw_scalar(
            wind_cmpnts[model_key]["w"].sel(**lev_sel),
            ax,
            method="contourf",
            cax=cax,
            tex_units="$m$ $s^{-1}$",
            cbar_ticks=None,
            use_cyclic=False,
            model_names=model_names,
            **KW_W_CNTRF,
        )
        u, v = (
            wind_cmpnts[model_key][f"u_{wind_key}"].sel(**lev_sel),
            wind_cmpnts[model_key][f"v_{wind_key}"].sel(**lev_sel),
        )
        _wspd = (u ** 2 + v) ** 0.5
        QK_REF_WSPD = wind_meta["qk_ref_wspd"]
        _kw_quiv = {
            **KW_QUIVER,
            "scale": QK_REF_WSPD * 4,
            "norm": plt.Normalize(0, QK_REF_WSPD),
        }
        draw_vector(
            u,
            v,
            ax,
            # cax=cax,
            tex_units="$m$ $s^{-1}$",
            # cbar_ticks=None,
            mag=(_wspd,),
            xstride=skips[model_key][0],
            ystride=skips[model_key][1],
            qk_ref_wspd=QK_REF_WSPD,
            kw_quiver=_kw_quiv,
            kw_quiverkey=KW_QUIVERKEY,
            model_names=model_names,
            quiverkey_xy=(0.1, 0.935),
        )

In [None]:
# Save the figure to the plot directory; the file stem encodes case and level.
fig_stem = f"{THAI_case}__rot_div_vert_wind_map__{int(P_LEVEL)}hpa"
figsave(fig, mypaths.plotdir / fig_stem)

In [None]:
# NOTE(review): Disabled scratch cell, kept for reference only. It selects the
# UM dataset, builds 2-D longitude/latitude arrays with np.meshgrid and defines
# a UM level selection at z = 11.5e3 (in the units of the UM `z` coordinate)
# for the commented-out single-model figure prototype that follows.
# model_key = "UM"
# model_names = getattr(names, model_key.lower())
# ds = datasets[THAI_case][model_key]
# lon2d, lat2d = np.meshgrid(ds[model_names.x], ds[model_names.y])
# lev_sel = dict({model_names.z: 11.5e3}, method="nearest")

In [None]:
# NOTE(review): Disabled prototype of the wind map for a single model, kept for
# reference only; it will not run as-is. Issues:
#   - It uses names that do not appear to be defined in this notebook:
#     `iletters`, `YLOCS`, `XLOCS`, `fmt_lonlat`, `lons`, `lats`, `w_plev`,
#     `lev_constr`, `cm`, `clev101`, `SKIP`.
#   - It zips over `axgr.axes_column`, which elsewhere yields columns of axes,
#     not individual axes.
#   - It calls `ax.is_first_col()` / `ax.is_last_row()`, which are deprecated
#     on Axes in recent Matplotlib versions.
#   - The `100500` hPa in the suptitle looks like a placeholder.
# %%time
# ncols = 4
# nrows = 4
# fig, axgr = make_map_figure(ncols, nrows, **KW_AXGR)
# # cbar_axes_col = np.array(axgr.cbar_axes).reshape((3, 2)).T
# cax = axgr.cbar_axes[0]

# for ax, (wind_key, wind_meta) in zip(axgr.axes_column, WIND_CMPNT_META.items()):
#     ax.set_title(wind_meta["title"], **KW_MAIN_TTL)
#     ax.set_title(f"({next(iletters)})", **KW_SBPLT_LABEL)
#     ax.set_ylim(-90, 90)
#     ax.set_yticks(YLOCS)
#     ax.set_yticklabels([fmt_lonlat(i, "lat", True) for i in YLOCS])
#     ax.set_xlim(-180, 180)
#     ax.set_xticks(XLOCS)
#     ax.set_xticklabels([fmt_lonlat(i, "lon", True) for i in XLOCS])
#     if ax.is_first_col():
#         ax.set_ylabel("Latitude [$\degree$]")
#     if ax.is_last_row():
#         ax.set_xlabel("Longitude [$\degree$]")
#     ax.contourf(
#         lons.points,
#         lats.points,
#         w_plev.extract(lev_constr).data,
#         cmap=cm.vik,
#         levels=clev101(11) * 5e-2,
#         extend="both",
#     )
#     u, v = (
#         wind_cmpnts[model_key][f"u_{wind_key}"].sel(**lev_sel),
#         wind_cmpnts[model_key][f"v_{wind_key}"].sel(**lev_sel),
#     )
#     QK_REF_WSPD = wind_meta["qk_ref_wspd"]
#     _kw_quiv = {**KW_QUIVER, **{"scale": QK_REF_WSPD * 4}}
#     quiv = ax.quiver(
#         lon2d[SKIP], lat2d[SKIP],
#         u[SKIP],
#         v[SKIP],
#         (u[SKIP] ** 2 + v[SKIP] ** 2) ** 0.5,
#         norm=plt.Normalize(0, wind_meta["qk_ref_wspd"]),
#         **_kw_quiv,
#     )
#     qk = ax.quiverkey(
#         quiv,
#         *(0.125, 0.9),
#         QK_REF_WSPD,
#         fr"${QK_REF_WSPD}$" + r" $m$ $s^{-1}$",
#         **KW_QUIVERKEY,
#     )
#     fig.colorbar(quiv, ax=ax)
# fig.suptitle(
#     f'{THAI_case}\n{int(100500):d} hPa',
#     y=0.95,
# )