In [None]:
import matplotlib.pyplot as plt
import numpy as np
import warnings

# Silence every warning category for the rest of the kernel session.
# NOTE(review): presumably meant to hide numpy's divide-by-zero RuntimeWarning from
# `1 / freq` in the spectrum plots (np.fft.fftfreq includes frequency 0); a blanket
# filter also hides genuine numerical problems — consider narrowing it or using np.errstate.
warnings.filterwarnings("ignore")

In [None]:
def psi(x, y, k1, k2, l, a1, a2):
    """
    Streamfunction of two zonal waves sharing one meridional mode,
    psi(x, y) = [a1 sin(k1 x) + a2 sin(k2 x)] * sin(l y).

    :param x: zonal grid points (1d array)
    :param y: meridional grid points (1d array)
    :param k1: zonal wavenumber of the first wave (float)
    :param k2: zonal wavenumber of the second wave (float)
    :param l: meridional wavenumber shared by both waves (float)
    :param a1: amplitude of the first wave (float)
    :param a2: amplitude of the second wave (float)
    :return: psi on the grid, shape (len(x), len(y)) (2d array)
    """
    zonal_profile = a1 * np.sin(k1 * x) + a2 * np.sin(k2 * x)
    meridional_profile = np.sin(l * y)
    # separable field: outer product of the x-profile and the y-profile
    return np.outer(zonal_profile, meridional_profile)


def u(x, y, k1, k2, l, a1, a2):
    """
    Zonal velocity of the two-wave flow, u = -d(psi)/dy
    = -[a1 sin(k1 x) + a2 sin(k2 x)] * l cos(l y).

    :param x: zonal grid points (1d array)
    :param y: meridional grid points (1d array)
    :param k1: zonal wavenumber of the first wave (float)
    :param k2: zonal wavenumber of the second wave (float)
    :param l: meridional wavenumber shared by both waves (float)
    :param a1: amplitude of the first wave (float)
    :param a2: amplitude of the second wave (float)
    :return: u on the grid, shape (len(x), len(y)) (2d array)
    """
    x_profile = a1 * np.sin(k1 * x) + a2 * np.sin(k2 * x)
    dy_profile = np.cos(l * y) * l  # derivative of sin(l y)
    return np.outer(-x_profile, dy_profile)


def v(x, y, k1, k2, l, a1, a2):
    """
    Meridional velocity of the two-wave flow, v = d(psi)/dx
    = [k1 a1 cos(k1 x) + k2 a2 cos(k2 x)] * sin(l y).

    :param x: zonal grid points (1d array)
    :param y: meridional grid points (1d array)
    :param k1: zonal wavenumber of the first wave (float)
    :param k2: zonal wavenumber of the second wave (float)
    :param l: meridional wavenumber shared by both waves (float)
    :param a1: amplitude of the first wave (float)
    :param a2: amplitude of the second wave (float)
    :return: v on the grid, shape (len(x), len(y)) (2d array)
    """
    dx_profile = k1 * a1 * np.cos(k1 * x) + k2 * a2 * np.cos(k2 * x)
    return np.outer(dx_profile, np.sin(l * y))


def zeta(x, y, k1, k2, l, a1, a2):
    """
    Relative vorticity of the two-wave flow, zeta = laplacian(psi),
    split into the d2/dx2 part and the d2/dy2 part.

    :param x: zonal grid points (1d array)
    :param y: meridional grid points (1d array)
    :param k1: zonal wavenumber of the first wave (float)
    :param k2: zonal wavenumber of the second wave (float)
    :param l: meridional wavenumber shared by both waves (float)
    :param a1: amplitude of the first wave (float)
    :param a2: amplitude of the second wave (float)
    :return: zeta on the grid, shape (len(x), len(y)) (2d array)
    """
    sin1 = np.sin(k1 * x)
    sin2 = np.sin(k2 * x)
    sin_ly = np.sin(l * y)
    xx_profile = -(k1 ** 2) * a1 * sin1 - k2 ** 2 * a2 * sin2  # d2/dx2 of the x-profile
    x_profile = a1 * sin1 + a2 * sin2
    d2x_term = np.outer(xx_profile, sin_ly)
    d2y_term = np.outer(x_profile, sin_ly * l ** 2)
    return d2x_term - d2y_term


def dx_zeta(x, y, k1, k2, l, a1, a2):
    """
    Zonal derivative of the vorticity, d(zeta)/dx, for the two-wave flow.

    :param x: zonal grid points (1d array)
    :param y: meridional grid points (1d array)
    :param k1: zonal wavenumber of the first wave (float)
    :param k2: zonal wavenumber of the second wave (float)
    :param l: meridional wavenumber shared by both waves (float)
    :param a1: amplitude of the first wave (float)
    :param a2: amplitude of the second wave (float)
    :return: d(zeta)/dx on the grid, shape (len(x), len(y)) (2d array)
    """
    cos1 = np.cos(k1 * x)
    cos2 = np.cos(k2 * x)
    sin_ly = np.sin(l * y)
    xxx_profile = -(k1 ** 3) * a1 * cos1 - k2 ** 3 * a2 * cos2  # d3/dx3 of the x-profile
    x_profile_dx = k1 * a1 * cos1 + k2 * a2 * cos2  # d/dx of the x-profile
    first = np.outer(xxx_profile, sin_ly)
    second = np.outer(x_profile_dx, sin_ly * l ** 2)
    return first - second


def dy_zeta(x, y, k1, k2, l, a1, a2):
    """
    Meridional derivative of the vorticity, d(zeta)/dy, for the two-wave flow.

    :param x: zonal grid points (1d array)
    :param y: meridional grid points (1d array)
    :param k1: zonal wavenumber of the first wave (float)
    :param k2: zonal wavenumber of the second wave (float)
    :param l: meridional wavenumber shared by both waves (float)
    :param a1: amplitude of the first wave (float)
    :param a2: amplitude of the second wave (float)
    :return: d(zeta)/dy on the grid, shape (len(x), len(y)) (2d array)
    """
    sin1 = np.sin(k1 * x)
    sin2 = np.sin(k2 * x)
    cos_ly = np.cos(l * y)
    xx_profile = -(k1 ** 2) * a1 * sin1 - k2 ** 2 * a2 * sin2
    x_profile = a1 * sin1 + a2 * sin2
    from_d2x = np.outer(xx_profile, cos_ly * l)
    from_d2y = np.outer(x_profile, cos_ly * l ** 3)
    return from_d2x - from_d2y


def adv(x, y, k1, k2, l, a1, a2):
    """
    Vorticity advection by the two-wave flow, u d(zeta)/dx + v d(zeta)/dy.

    :param x: zonal grid points (1d array)
    :param y: meridional grid points (1d array)
    :param k1: zonal wavenumber of the first wave (float)
    :param k2: zonal wavenumber of the second wave (float)
    :param l: meridional wavenumber shared by both waves (float)
    :param a1: amplitude of the first wave (float)
    :param a2: amplitude of the second wave (float)
    :return: advection term on the grid, shape (len(x), len(y)) (2d array)
    """
    args = (x, y, k1, k2, l, a1, a2)
    zonal_term = u(*args) * dx_zeta(*args)
    meridional_term = v(*args) * dy_zeta(*args)
    return zonal_term + meridional_term


def spectral_power(field, x, y, y_pos=None):
    """
    Compute the zonal power spectrum of ``field`` along one row of constant y.

    :param field: field for spectral analysis, indexed as field[ix, iy]
        (2d array of shape (len(x), len(y)))
    :param x: x grid in km, assumed uniformly spaced (1d array)
    :param y: y grid in km (1d array)
    :param y_pos: y position in km of the zonal cross-section to analyse (float, optional);
        the nearest grid row is used. Defaults to the middle row, index ``len(y) // 2``.
    :return: frequencies in cycles per km, in ``np.fft.fftfreq`` order (1d array), and
        spectral power |FFT|^2 / len(x) (1d array)
    :raises ValueError: if the first axis of ``field`` does not match ``len(x)``
    """
    nx, ny = len(x), len(y)
    n_rows = np.shape(field)[0]
    if n_rows != nx:
        # otherwise freq and psd would silently have different lengths
        raise ValueError(f"field has {n_rows} points along x but x has {nx}")
    dx = x[1] - x[0]  # sample spacing of the frequency axis
    if y_pos is None:
        iy = ny // 2
    else:
        iy = int(np.argmin(np.abs(np.asarray(y) - y_pos)))
    freq = np.fft.fftfreq(nx, d=dx)
    series = field[:, iy]  # zonal series at the selected y row
    psd = np.square(np.abs(np.fft.fft(series))) / nx
    return freq, psd


def plot_triad_interaction(
    x,
    y,
    l,
    a1,
    a2,
    psi,
    adv,
    freq_psi,
    psd_psi,
    freq_adv,
    psd_adv,
    plot_xlim=1_000,
    max_wavelength=800,
):
    """
    Plot the streamfunction and advection term next to their zonal power spectra.

    Left column: filled contours of ``psi`` (top) and ``adv`` (bottom) over the first
    ``plot_xlim`` km of the x grid, each with a dotted horizontal marker line.
    Right column: spectral power against zonal wavelength ``1 / freq``.

    :param x: x grid in km, uniformly spaced (1d array)
    :param y: y grid in km (1d array)
    :param l: meridional wavenumber (float); not used in the plot
    :param a1: amplitude of wave 1 (float); not used in the plot
    :param a2: amplitude of wave 2 (float); not used in the plot
    :param psi: streamfunction, indexed as psi[ix, iy] (2d array)
    :param adv: advection term, indexed as adv[ix, iy] (2d array)
    :param freq_psi: frequencies of the psi spectrum in cycles per km (1d array)
    :param psd_psi: spectral power of psi (1d array)
    :param freq_adv: frequencies of the adv spectrum in cycles per km (1d array)
    :param psd_adv: spectral power of adv (1d array)
    :param plot_xlim: x-extent in km shown in the contour panels (float)
    :param max_wavelength: upper wavelength limit in km of the spectrum panels (float)
    :return: the figure and the bottom-right (adv spectrum) axes
    """
    fig, axes = plt.subplots(2, 2, figsize=(7, 4), gridspec_kw={"width_ratios": [2, 1]})

    # number of grid points covering plot_xlim km; int() because slice bounds must be
    # integers and plot_xlim // dx is a float on a float-valued x grid
    dx = x[1] - x[0]
    idx = slice(0, int(plot_xlim // dx))

    ax = axes[0, 0]
    ax.contourf(x[idx], y, psi.T[:, idx], levels=18, cmap="RdBu")
    ax.axhline(np.max(y) / 2, color="black", ls=":", lw=3)
    ax.set_title(r"$\psi$", loc="center")
    ax.set_ylabel("$y$ [km]")
    ax.set_xlabel("$x$ [km]")

    ax = axes[1, 0]
    ax.contourf(x[idx], y, adv.T[:, idx], levels=22, cmap="RdBu")
    # NOTE(review): this marker sits at max(y)/4, whereas spectral_power samples row
    # len(y)//2 by default — confirm which cross-section the adv spectrum should use
    ax.axhline(np.max(y) / 4, color="black", ls=":", lw=3)
    ax.set_title(r"$u \partial_x \zeta + v \partial_y \zeta $", loc="center")
    ax.set_ylabel("$y$ [km]")
    ax.set_xlabel("$x$ [km]")

    spectra = ((axes[0, 1], freq_psi, psd_psi), (axes[1, 1], freq_adv, psd_adv))
    for ax, freq, psd in spectra:
        freq, psd = np.asarray(freq), np.asarray(psd)
        # positive frequencies only: freq == 0 (the mean) has infinite wavelength and
        # negative frequencies would map to negative wavelengths
        positive = freq > 0
        ax.plot(1 / freq[positive], psd[positive], marker="x", c="black")
        ax.set_xlim(0, max_wavelength)
        ax.set_xlabel(r"$\lambda_x$ [km]")
        ax.set_yticklabels([])
    axes[0, 1].set_title("Spectral Power")

    for ax in axes.flatten():
        ax.minorticks_on()

    fig.tight_layout()
    # after the loop ``ax`` is the bottom-right panel, matching the original return value
    return fig, ax

In [None]:
# set up domain
# TODO: replace the `...` placeholders (x, y, k1, k2, l) with concrete values; as
# written `...` is Ellipsis, so the psi/adv calls below raise TypeError.
dx, dy = (5, 5)  # grid spacing x,y [km]
x = ...  # x grid [km]; spectral_power uses x[1] - x[0] as FFT spacing, so keep it uniform (step dx)
y = ...  # y grid [km]; presumably built with step dy

# set up two waves
# the wavenumbers enter as sin(k * x) with x in km, i.e. angular wavenumbers [rad/km];
# the spectra are in cycles/km, so a wave with wavenumber k appears at lambda_x = 2*pi/k
k1 = ...  # zonal wavenumber wave 1
k2 = ...  # zonal wavenumber wave 2
l = ...  # meridional wavenumber
a1, a2 = (1, 1)  # amplitudes of the two waves

# compute streamfunction and advection term on the grid; both arrays are indexed [ix, iy]
psi_ = psi(x, y, k1, k2, l, a1, a2)
adv_ = adv(x, y, k1, k2, l, a1, a2)

# compute zonal power spectra along the middle y row (index len(y) // 2)
# NOTE(review): the plot marks y = max(y)/4 for the advection term, not the middle row — confirm
freq_psi, psd_psi = spectral_power(psi_, x, y)
freq_adv, psd_adv = spectral_power(adv_, x, y)

# plot
fig, ax = plot_triad_interaction(x, y, l, a1, a2, psi_, adv_, freq_psi, psd_psi, freq_adv, psd_adv)
# fig.savefig("sheet10_triads.pdf", bbox_inches="tight")
plt.show()