In [None]:
%load_ext autoreload
%autoreload 2
import json
import numpy as np
import matplotlib.pyplot as plt
from dedalus import public as d3
import pathlib
import glob
import h5py
from scipy.fft import fft, fftfreq
from scipy.signal import find_peaks

# Default matplotlib colour cycle; colors[i] is used below to colour branch i consistently across figures
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

# Render figures inline at high (retina) resolution
%matplotlib inline
%config InlineBackend.figure_format='retina'

# Make helper modules living in ../IVP (e.g. plotting_setup) importable
import sys
sys.path.append(str(pathlib.Path("..").joinpath("IVP")))
import plotting_setup
# NOTE(review): presumably looks up a sans-serif font named "Avenir"; behaviour is defined in plotting_setup — confirm there
sans_name = plotting_setup.get_font("Avenir","sans")

# Output directories for saved figures (./figures and ./figures/wkb).
# mkdir(parents=True, exist_ok=True) is idempotent, so re-running this cell is safe
# and there is no check-then-create race between exists() and mkdir().
figure_path = pathlib.Path('figures').absolute()
figure_path.mkdir(parents=True, exist_ok=True)
wkb_fig_path = figure_path.joinpath('wkb')
wkb_fig_path.mkdir(parents=True, exist_ok=True)

Input parameters

In [None]:
# Input parameters
sim_name = 'sim27'

# The four EVP eigenfunction files differ only in the target eigenvalue index and
# the cos/sin parity; every other run parameter is encoded in the shared prefix and
# suffix below, so it only has to be edited in one place.
_evp_fname_prefix = "sparse_along_branch_target_eigfxn_exp_fulldiff_ky2p0pi_Gamma_0p1_eta0_Z0_"
_evp_fname_suffix = "_ky2p00e+00pi_Gamma1p00e-01_eta1p0e-08_Fr2p5e-02_1024_3072+x_diffusion_indep_eta+y_diffusion+z_diffusion.hdf5"

evan_igw_0_name = _evp_fname_prefix + "ind2_cos" + _evp_fname_suffix
evan_igw_1_name = _evp_fname_prefix + "ind4_sin" + _evp_fname_suffix
evan_sm_0_name = _evp_fname_prefix + "ind3_cos" + _evp_fname_suffix
evan_sm_1_name = _evp_fname_prefix + "ind5_sin" + _evp_fname_suffix

### Load last snapshot from sim and plot

Load data

In [None]:
# Locate the IVP run: parameters in ../IVP/params/<sim>.json and the final
# snapshot in ../IVP/data/<sim>/<sim>-last.hdf5
IVP_dir = pathlib.Path('..') / 'IVP'
params_path = IVP_dir / 'params' / (sim_name + '.json')
file_dir = IVP_dir / 'data' / sim_name / (sim_name + '-last.hdf5')

# Read every dataset of the snapshot file into memory
with h5py.File(file_dir, "r") as f:
    IVP_dict = {key: f[key][()] for key in f.keys()}

with open(params_path) as f:
    params_IVP = json.load(f)

# Unpack snapshot fields and coordinates
p_IVP, u_IVP, v_IVP, w_IVP = (IVP_dict[name] for name in ('p', 'u', 'v', 'w'))
rho_IVP = IVP_dict['rho']
Bx_IVP, By_IVP, Bz_IVP = (IVP_dict[name] for name in ('Bx', 'By', 'Bz'))
x_IVP, Z_IVP = IVP_dict['x'], IVP_dict['Z']

# Scalar run parameters
Fr = params_IVP['Fr']
Nx, NZ, LZ = params_IVP['Nx'], params_IVP['NZ'], params_IVP['LZ']
x = x_IVP

In [None]:
# 2-D coordinate arrays of shape (len(x_IVP), len(Z_IVP)), matching the field[x, Z] layout
ij_grids = np.meshgrid(x_IVP, Z_IVP, indexing='ij'); xx_IVP = ij_grids[0]; ZZ_IVP = ij_grids[1]

In [None]:
# Snapshot of the IVP fields at the last saved time, each on a colour scale symmetric about zero.
tasks = [p_IVP, u_IVP, v_IVP, w_IVP, Bx_IVP, By_IVP, Bz_IVP]
task_labels = ['$p$','$u$','$v$','$w$',"$b_x$","$b_y$","$b_z$"]
fontsize = 14
plt_strd = 1
# Colour limits are computed only from levels with Z below this height
Z_clim_max = 0.1
fig, axs = plt.subplots(1,len(tasks),figsize=(3*(len(tasks)),7.5))

for i, plot_field in enumerate(tasks):
    # pcolormesh takes scalar vmin/vmax: use +/- max|Re(field)| over Z < Z_clim_max
    vlim = np.max(np.abs(plot_field.real[:,Z_IVP<Z_clim_max]))
    pmesh = axs[i].pcolormesh(xx_IVP[::plt_strd,::plt_strd],ZZ_IVP[::plt_strd,::plt_strd],plot_field[::plt_strd,::plt_strd].real,cmap='RdBu_r',vmin=-vlim,vmax=vlim)
    plt.colorbar(pmesh,ax=axs[i],shrink=1,location='top',extend='both')
    axs[i].set_xlim(x_IVP[0],x_IVP[-1])
    axs[i].set_ylim(Z_IVP[0],Z_IVP[-1])
    axs[i].set_title(task_labels[i],fontsize=fontsize)
    axs[i].set_xlabel('$x$',fontsize=fontsize)
    axs[i].set_ylabel('$Z$',fontsize=fontsize)

plt.subplots_adjust(top=1,wspace=0.5)

### Check that sim is equilibrated

In [None]:
# Scalar time series written by the IVP run.
time_series_pattern = ((IVP_dir.joinpath('data')).joinpath(sim_name)).joinpath(sim_name+"-timeseries*.hdf5")
# glob returns matches in arbitrary order; sort so the chosen file is deterministic,
# and fail with a clear message instead of an IndexError when nothing matches.
time_series_matches = sorted(glob.glob(str(time_series_pattern)))
if not time_series_matches:
    raise FileNotFoundError(f"No time-series files match {time_series_pattern}")
# NOTE(review): only the first matching file is read; if the run wrote several
# time-series sets, later times are not included here.
time_series_path = time_series_matches[0]

# Load data: every dataset in the file, keyed by name
with h5py.File(time_series_path, "r") as f:
    timeseries_dict = {key: f[key][()] for key in f.keys()}

In [None]:
# Energy diagnostics (presumably total, magnetic and kinetic) versus time; a flat
# curve indicates the run has equilibrated.
ax = plt.gca()
for energy_key in ('E', 'ME', 'KE'):
    ax.plot(timeseries_dict['t'], timeseries_dict[energy_key], label=energy_key)
ax.legend()
ax.set_xlabel('$t$')
ax.grid(alpha=0.5)

### Check that sim is resolved

In [None]:
# Horizontal amplitude spectrum of v at the level nearest Z_plot, with the even-
# and odd-indexed positive-frequency bins shown in separate panels.
Z_plot = 0.16
Z_plot_idx = np.argmin(np.abs(np.array(Z_IVP) - Z_plot))

n_x = len(x_IVP)
n_x_half = n_x // 2
yf = fft(v_IVP[:, Z_plot_idx])
xf = fftfreq(n_x, np.diff(x_IVP)[0])[:n_x_half]
x_amplitude = 2.0 / n_x * np.abs(yf[:n_x_half])

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for panel, offset in zip(axs, (0, 1)):
    panel.loglog(xf[offset::2], x_amplitude[offset::2])
plt.suptitle(f"$Z$ = {Z_IVP[Z_plot_idx]:.3f}")

In [None]:
# Raw profile of Re(v) along x at the level Z_IVP[Z_plot_idx].
plt.figure(figsize=(15,4))
plt.plot(x_IVP,v_IVP[:,Z_plot_idx].real)
plt.scatter(x_IVP,v_IVP[:,Z_plot_idx].real,s=3)
plt.xlabel("$x$")
# The cut is taken at fixed Z, so the title reports Z
plt.title(f"$Z$ = {Z_IVP[Z_plot_idx]:.3f}")

In [None]:
# Vertical amplitude spectrum of v at the column nearest x_plot.
x_plot = 0.2
x_plot_idx = np.argmin(np.abs(np.array(x_IVP) - x_plot))

n_Z = len(Z_IVP)
yf = fft(v_IVP[x_plot_idx, :])
xf = fftfreq(n_Z, np.diff(Z_IVP)[0])[:n_Z // 2]
fig, ax = plt.subplots(1, 1, figsize=(7.5, 5))
ax.loglog(xf, 2.0 / n_Z * np.abs(yf[:n_Z // 2]))
plt.suptitle(f"$x$ = {x_IVP[x_plot_idx]:.3f}")

In [None]:
# Raw vertical profile of Re(v) at the column x_IVP[x_plot_idx].
v_column = v_IVP[x_plot_idx, :].real
plt.figure(figsize=(4, 15))
plt.plot(v_column, Z_IVP)
plt.scatter(v_column, Z_IVP, s=3)
plt.title(f"$x$ = {x_IVP[x_plot_idx]:.3f}")

# Construct WKB solution

### Load EVP data

In [None]:
# The four eigenfunction branches. Position i in these lists is the branch index
# used as the last axis of every per-branch array below.
full_branch_fnames = [evan_igw_0_name, evan_igw_1_name, evan_sm_0_name, evan_sm_1_name]
full_branch_names = ["evan-IGW-0", "evan-IGW-1", "evan-SM-0", "evan-SM-1"]
full_branch_ids = dict(zip(full_branch_names, range(len(full_branch_names))))
N_full_branches = len(full_branch_fnames)

In [None]:
# Get parameters: each EVP data file has a JSON parameter file with the same stem in ./params
full_branch_params = {}
for i in range(N_full_branches):
    full_branch_fname = full_branch_fnames[i]
    # Path.stem drops only the final ".hdf5" suffix; splitting on the first "." would
    # truncate any file name that contains another dot.
    full_branch_params_path = pathlib.Path('params').absolute().joinpath(pathlib.Path(full_branch_fname).stem+".json")
    with open(full_branch_params_path) as f: 
        full_branch_params[full_branch_names[i]] = json.load(f)

# Check dimensions: every branch must share evan-IGW-0's resolution

Nx_EVP = full_branch_params["evan-IGW-0"]['Nx']
NZ_EVP = full_branch_params["evan-IGW-0"]['NZ_total']

for nm in full_branch_names[1:]:
    if full_branch_params[nm]['Nx'] != Nx_EVP:
        raise RuntimeError(f"Nx={full_branch_params[nm]['Nx']} for {nm} while Nx={Nx_EVP} for evan-IGW-0.")
    if full_branch_params[nm]['NZ_total'] != NZ_EVP:
        raise RuntimeError(f"NZ={full_branch_params[nm]['NZ_total']} for {nm} while NZ={NZ_EVP} for evan-IGW-0.")
# EVP levels are subsampled every Z_stride_EVP levels so that NZ levels remain, matching the IVP level count
if NZ_EVP%NZ != 0:
    raise RuntimeError("NZ_EVP is not an integer multiple of NZ for the IVP")
else:
    Z_stride_EVP = NZ_EVP//NZ

# Horizontal grid of the EVP eigenfunctions. Below they are loaded into a periodic
# Dedalus ComplexFourier basis on [0, 1) of size Nx_EVP, whose grid points are
# k/Nx_EVP, so the right endpoint x = 1 is excluded.
x_EVP = np.linspace(0,1,Nx_EVP,endpoint=False)

In [None]:
# Load the EVP eigenfunctions, keeping every Z_stride_EVP-th level so that each
# branch has NZ levels. Field stacks are indexed [x, Z, branch]; kz is [Z, branch].
data_path = pathlib.Path('data').absolute()

evp_field_keys = ("p", "u", "v", "w", "bx", "by")
field_full_branches_stack_list = [np.zeros((Nx_EVP,NZ,N_full_branches),dtype=np.complex128) for _ in evp_field_keys]
(p_full_branches_stack, u_full_branches_stack, v_full_branches_stack,
 w_full_branches_stack, bx_full_branches_stack, by_full_branches_stack) = field_full_branches_stack_list

kz_full_branches_stack = np.zeros((NZ,N_full_branches),dtype=np.complex128)

for i, full_branch_fname in enumerate(full_branch_fnames):
    full_branch_path = data_path / full_branch_fname
    with h5py.File(full_branch_path, "r") as f:
        for field_key, field_stack in zip(evp_field_keys, field_full_branches_stack_list):
            field_stack[..., i] = f[field_key][:, ::Z_stride_EVP]
        kz_full_branches_stack[..., i] = f["kz"][::Z_stride_EVP]
        # Z is re-read from every file and the last one is kept (presumably identical across files)
        Z_EVP = f["Z"][::Z_stride_EVP].real

### Identify turning points

In [None]:
# Find turning points: prominent peaks (prominence=1) of |Fr dkz/dZ| on each branch
dkzdz_full_branches_stack = Fr * np.gradient(kz_full_branches_stack, Z_EVP, axis=0)

turning_pt_ids_stack = [
    find_peaks(np.abs(dkzdz_full_branches_stack[..., branch]), prominence=1)[0]
    for branch in range(N_full_branches)
]

# Set the first IGW and SM turning points equal to each other (in place on the peak-index arrays)
for sm_name, igw_name in (("evan-SM-0", "evan-IGW-0"), ("evan-SM-1", "evan-IGW-1")):
    turning_pt_ids_stack[full_branch_ids[sm_name]][0] = turning_pt_ids_stack[full_branch_ids[igw_name]][0]

# Heights Z of the turning points on each branch
turning_pt_stack = [Z_EVP[tp_ids] for tp_ids in turning_pt_ids_stack]

### Identify the point where evan-SM-1 intersects the Alfvén wavenumber boundary

In [None]:
# Alfvén boundary kz = exp(kb Z) / Gamma, and the level where Re(kz) on the
# evan-SM-1 branch comes closest to it.
kb = params_IVP['kb']
Gamma = params_IVP['Gamma']
alfven_bdry = 1/Gamma*np.exp(kb*Z_EVP)

kz_evan_sm1 = kz_full_branches_stack[..., full_branch_ids["evan-SM-1"]]
alfven_intsc_idx = np.argmin(np.abs(kz_evan_sm1.real - alfven_bdry))
alfven_intsc_Z = Z_EVP[alfven_intsc_idx]
alfven_intsc_kz = kz_evan_sm1[alfven_intsc_idx]

In [None]:
# kz along each branch: Re(kz) (full view and zoom), Im(kz), and |Fr dkz/dZ| with the
# detected turning points (dots). Dashed lines mark turning-point heights; the grey
# region on the Re(kz) panels lies between the Alfvén boundary and kz = 100.
fig,axs = plt.subplots(1,4,figsize=(20,5))
for i in range(N_full_branches):
    axs[0].plot(kz_full_branches_stack[...,i].real,Z_EVP,label=full_branch_names[i])
    axs[1].plot(kz_full_branches_stack[...,i].real,Z_EVP,label=full_branch_names[i])
    axs[2].plot(kz_full_branches_stack[...,i].imag,Z_EVP,label=full_branch_names[i])
    axs[3].plot(np.abs(dkzdz_full_branches_stack[...,i]),Z_EVP,label=full_branch_names[i])
    axs[3].scatter(np.abs(dkzdz_full_branches_stack[...,i])[turning_pt_ids_stack[i]],turning_pt_stack[i])
    for ax in axs:
        for Z_t in turning_pt_stack[i]:
            ax.axhline(Z_t,color=colors[i],linestyle='--',linewidth=0.75)
        ax.set_ylim(0,LZ)
xlims = axs[0].get_xlim()
axs[0].fill_betweenx(Z_EVP,alfven_bdry,100*np.ones(NZ),color='lightgray',zorder=0)
axs[1].fill_betweenx(Z_EVP,alfven_bdry,100*np.ones(NZ),color='lightgray',zorder=0)
# Restore the x-limits recorded before shading on the full view
axs[0].set_xlim(xlims)
# Zoomed view
axs[1].set_xlim(10.5,13)
axs[1].set_ylim(0.03,0.07)
axs[0].legend()
axs[0].set_xlabel(r"$\Re\{k_z\}$",fontsize=12)
axs[1].set_xlabel(r"$\Re\{k_z\}$",fontsize=12)
axs[2].set_xlabel(r"$\Im\{k_z\}$",fontsize=12)
axs[3].set_xlabel(r"$| \mathrm{d} k_z/\mathrm{d} z |$",fontsize=12)
axs[0].set_ylabel(r"$Z$",fontsize=12)
plt.show()

### Compute phase function $\Theta(Z)$

In [None]:
# Compute phase function Theta(Z) = integral of kz over Z as a running left-endpoint sum
# from the bottom level: Theta[0] = 0, Theta[n+1] = Theta[n] + (Z[n+1] - Z[n]) * kz[n].
# (cumsum accumulates sequentially, matching an explicit level-by-level loop.)
theta_full_branches_stack = np.zeros(np.shape(kz_full_branches_stack),dtype=np.complex128)
dZ_EVP = np.diff(Z_EVP)
theta_full_branches_stack[1:] = np.cumsum(dZ_EVP[:, np.newaxis] * kz_full_branches_stack[:-1], axis=0)

igw0_col = full_branch_ids['evan-IGW-0']
igw1_col = full_branch_ids['evan-IGW-1']
sm0_col = full_branch_ids['evan-SM-0']
sm1_col = full_branch_ids['evan-SM-1']

# Zero out phases of IGW-1 and IGW-0 at forcing height Z0
Z0_idx_theta = np.argmin(np.abs(Z_EVP - params_IVP['Z0']))
theta_IGW0_at_forcing_level = theta_full_branches_stack[Z0_idx_theta, igw0_col]
theta_IGW1_at_forcing_level = theta_full_branches_stack[Z0_idx_theta, igw1_col]
theta_full_branches_stack[..., igw0_col] -= theta_IGW0_at_forcing_level
theta_full_branches_stack[..., igw1_col] -= theta_IGW1_at_forcing_level

# Shift each SM branch so its phase equals the parent IGW phase at their first turning points
theta_IGW0_turning_pt = theta_full_branches_stack[turning_pt_ids_stack[igw0_col][0], igw0_col]
theta_SM0_turning_pt = theta_full_branches_stack[turning_pt_ids_stack[sm0_col][0], sm0_col]
theta_full_branches_stack[..., sm0_col] = theta_full_branches_stack[..., sm0_col] - theta_SM0_turning_pt + theta_IGW0_turning_pt

theta_IGW1_turning_pt = theta_full_branches_stack[turning_pt_ids_stack[igw1_col][0], igw1_col]
theta_SM1_turning_pt = theta_full_branches_stack[turning_pt_ids_stack[sm1_col][0], sm1_col]
theta_full_branches_stack[..., sm1_col] = theta_full_branches_stack[..., sm1_col] - theta_SM1_turning_pt + theta_IGW1_turning_pt

In [None]:
# Re(Theta) versus Z for every branch: full view (left) and zoom (right).
# Dashed lines mark each branch's first turning point.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for i in range(N_full_branches):
    first_tp_Z = turning_pt_stack[i][0]
    for ax in axs:
        ax.plot(theta_full_branches_stack[:, i].real, Z_EVP, label=full_branch_names[i])
        ax.axhline(first_tp_Z, color=colors[i], linestyle='--', linewidth=0.75)
for ax in axs:
    ax.set_xlabel(r"$\Re\{\Theta(Z)\}$")
    ax.set_ylabel(r"$Z$")
    ax.grid(alpha=0.5)
axs[0].legend()
axs[1].set_xlim(-1.8,-1.5)
axs[1].set_ylim(0.04,0.06)

### Normalize and match eigenfunctions

In [None]:
# Normalize eigenfunctions
# Each EVP eigenfunction is only defined up to a complex constant at every Z level;
# this cell fixes that constant level by level so the stored modes vary smoothly in Z.
# Create Dedalus domain and fields
dealias = 1 #3/2
coords = d3.CartesianCoordinates('x')
dist = d3.Distributor(coords, dtype=np.complex128)
# Periodic basis on x in [0, 1) with Nx_EVP modes; p_mode_field and integ_x are used
# afterwards to compute x-integrals of p at the forcing height
xbasis = d3.ComplexFourier(coords['x'], size=Nx_EVP, bounds=(0, 1), dealias=dealias)
xgrid = dist.local_grid(xbasis)
p_mode_field = dist.Field(name='p_mode_field', bases=xbasis)
integ_x = lambda a: d3.Integrate(a, ('x'))

# Normalized copies of every field, indexed [x, Z, branch]
p_full_branches_normed_stack = np.zeros((Nx_EVP,NZ,N_full_branches),dtype=np.complex128)
u_full_branches_normed_stack = np.zeros((Nx_EVP,NZ,N_full_branches),dtype=np.complex128)
v_full_branches_normed_stack = np.zeros((Nx_EVP,NZ,N_full_branches),dtype=np.complex128)
w_full_branches_normed_stack = np.zeros((Nx_EVP,NZ,N_full_branches),dtype=np.complex128)
bx_full_branches_normed_stack = np.zeros((Nx_EVP,NZ,N_full_branches),dtype=np.complex128)
by_full_branches_normed_stack = np.zeros((Nx_EVP,NZ,N_full_branches),dtype=np.complex128)

# Same order as field_full_branches_stack_list; element 0 is p
field_full_branches_normed_stack_list = [p_full_branches_normed_stack,u_full_branches_normed_stack,v_full_branches_normed_stack,
                                         w_full_branches_normed_stack,bx_full_branches_normed_stack,by_full_branches_normed_stack]

for i in range(N_full_branches):
    for j in range(NZ):
        pmode = p_full_branches_stack[:,j,i]

        # Align phase
        # Divide by the largest-magnitude sample, so that sample becomes exactly 1
        max_p = pmode[np.argmax(np.abs(pmode))]
        pmode_over_max = pmode/max_p

        # Match sign over Z
        if j == 0:
            sign = 1
        else:
            # Previous level as already stored (includes the continuity rescaling below)
            pmode_over_max_jm1 = p_full_branches_normed_stack[:,j-1,i]
            # Flip sign to match previous level
            # (profiles that nearly cancel when added are taken to have opposite sign)
            if np.mean(np.abs(pmode_over_max+pmode_over_max_jm1)) < 1:
                sign = -1
            else:
                sign = 1
        sign_corrected_pmode_over_max = sign*pmode_over_max
        
        # Make eigfxn continuous over Z at a specific value of x
        # Rescale so that p at x ~ 1/3 keeps the value it had at the bottom level (j == 0).
        # NOTE(review): algebraically the sign and max_p factors cancel in this rescaling, so the
        # stored profile equals pmode * base / pmode[x ~ 1/3]; the sign/phase steps above only affect
        # rounding. If p vanishes at x ~ 1/3 on some level, this division produces inf/nan.
        sign_corrected_pmode_peak = sign_corrected_pmode_over_max[np.argmin(np.abs(x_EVP - 1/3))]
        if j == 0:
            sign_corrected_pmode_peak_base = sign_corrected_pmode_peak
        p_full_branches_normed_stack[:,j,i] = sign_corrected_pmode_over_max * sign_corrected_pmode_peak_base/sign_corrected_pmode_peak

        # Scale all other fields accordingly
        # (same max_p, sign and continuity factors as p; index 0 of the list is p itself, hence range from 1)
        for l in range(1,len(field_full_branches_normed_stack_list)):
            field_mode = field_full_branches_stack_list[l][:,j,i]
            field_mode_over_max = field_mode/max_p
            sign_corrected_field_mode_normed = sign*field_mode_over_max
            field_full_branches_normed_stack_list[l][:,j,i] = sign_corrected_field_mode_normed * sign_corrected_pmode_peak_base/sign_corrected_pmode_peak

# Normalize IGW-0 at the forcing height
# Scale every IGW-0 field by sqrt(integral over x of |p|^2) evaluated at the level nearest Z0
forcing_Z_idx = np.argmin(np.abs(Z_EVP - params_IVP['Z0']))
p_mode_field['g'] = p_full_branches_normed_stack[:,forcing_Z_idx,full_branch_ids["evan-IGW-0"]]
norm_p_IGW0_at_forcing_level = (np.sqrt(integ_x(np.conjugate(p_mode_field)*p_mode_field))).evaluate()['g'][0]

for l in range(len(field_full_branches_normed_stack_list)):
    field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-IGW-0"]] = 1/norm_p_IGW0_at_forcing_level * field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-IGW-0"]]

# Fix phase of IGW-0 at forcing height
# Divide by p(x ~ 0, Z0) so that p_IGW0 = 1 exactly there. This fixes magnitude as well as
# phase, so the result does not depend on the L2-norm scaling just applied.
peak_IGW0_at_forcing_level = p_full_branches_normed_stack[np.argmin(np.abs(x_EVP - 0)),forcing_Z_idx,full_branch_ids["evan-IGW-0"]]

for l in range(len(field_full_branches_normed_stack_list)):
    field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-IGW-0"]] = 1/peak_IGW0_at_forcing_level * field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-IGW-0"]]

# Normalize IGW-1 at the forcing height
# (same L2-norm scaling as for IGW-0)
forcing_Z_idx = np.argmin(np.abs(Z_EVP - params_IVP['Z0']))
p_mode_field['g'] = p_full_branches_normed_stack[:,forcing_Z_idx,full_branch_ids["evan-IGW-1"]]
norm_p_IGW1_at_forcing_level = (np.sqrt(integ_x(np.conjugate(p_mode_field)*p_mode_field))).evaluate()['g'][0]

for l in range(len(field_full_branches_normed_stack_list)):
    field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-IGW-1"]] = 1/norm_p_IGW1_at_forcing_level * field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-IGW-1"]]

# Fix phase of IGW-1 at forcing height
# Divide by p(x ~ 1/4, Z0) so that p_IGW1 = 1 exactly there (again overriding the norm scaling)
peak_IGW1_at_forcing_level = p_full_branches_normed_stack[np.argmin(np.abs(x_EVP - 1/4)),forcing_Z_idx,full_branch_ids["evan-IGW-1"]]

for l in range(len(field_full_branches_normed_stack_list)):
    field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-IGW-1"]] = 1/peak_IGW1_at_forcing_level * field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-IGW-1"]]

# Match SM-0 to IGW-0 at turning point
# Complex rescaling of every SM-0 field so that p_SM0 equals p_IGW0 at x ~ 0 at the first turning point
p_IGW0_at_tp = p_full_branches_normed_stack[:,turning_pt_ids_stack[full_branch_ids["evan-IGW-0"]][0],full_branch_ids["evan-IGW-0"]]
peak_p_IGW0_at_tp = p_IGW0_at_tp[np.argmin(np.abs(x_EVP - 0))]
p_SM0_at_tp = p_full_branches_normed_stack[:,turning_pt_ids_stack[full_branch_ids["evan-SM-0"]][0],full_branch_ids["evan-SM-0"]]
peak_p_SM0_at_tp = p_SM0_at_tp[np.argmin(np.abs(x_EVP - 0))]

for l in range(len(field_full_branches_normed_stack_list)):
    field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-SM-0"]] = peak_p_IGW0_at_tp/peak_p_SM0_at_tp * field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-SM-0"]]

# Match SM-1 to IGW-1 at turning point
# Same as above, matching p at x ~ 1/4
p_IGW1_at_tp = p_full_branches_normed_stack[:,turning_pt_ids_stack[full_branch_ids["evan-IGW-1"]][0],full_branch_ids["evan-IGW-1"]]
peak_p_IGW1_at_tp = p_IGW1_at_tp[np.argmin(np.abs(x_EVP - 1/4))]
p_SM1_at_tp = p_full_branches_normed_stack[:,turning_pt_ids_stack[full_branch_ids["evan-SM-1"]][0],full_branch_ids["evan-SM-1"]]
peak_p_SM1_at_tp = p_SM1_at_tp[np.argmin(np.abs(x_EVP - 1/4))]

for l in range(len(field_full_branches_normed_stack_list)):
    field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-SM-1"]] = peak_p_IGW1_at_tp/peak_p_SM1_at_tp * field_full_branches_normed_stack_list[l][...,full_branch_ids["evan-SM-1"]]

Plot eigenfunctions

In [None]:
# Re{p_0(x, Z)} on every branch after normalization, on one shared colour scale symmetric about zero.
plot_field = p_full_branches_normed_stack.real
plt_strd = 4

# pcolormesh takes scalar vmin/vmax
vlim = np.max(np.abs(plot_field))

fig,axs = plt.subplots(1,N_full_branches,figsize=(12,4))
for i in range(N_full_branches):
    axs[i].pcolormesh(x_EVP[::plt_strd],Z_EVP[::plt_strd],plot_field[::plt_strd,::plt_strd,i].T,shading='nearest',cmap='RdBu_r',vmin=-vlim,vmax=vlim)
    axs[i].set_title(full_branch_names[i])
    axs[i].set_xlabel("$x$")
axs[0].set_ylabel("$Z$")
plt.suptitle(r"$\Re\{p_0(x,Z)\}$")

Check that $p_{0,\text{IGW-0}} + i\, p_{0,\text{IGW-1}} \approx \exp(i 2\pi x)$ at the forcing height $Z_0$ (both the real and imaginary parts are compared).

In [None]:
# At the forcing height, p_IGW-0 + i p_IGW-1 should reproduce exp(i 2 pi x)
# (blue: eigenfunction sum, red: target; solid = real part, dashed = imaginary part).
plot_field = p_full_branches_normed_stack
forcing_Z_idx = np.argmin(np.abs(Z_EVP - params_IVP['Z0']))

# Look the branch columns up by name rather than assuming their positions
IGWforcingsum = (plot_field[:,forcing_Z_idx,full_branch_ids["evan-IGW-0"]]
                 + 1j*plot_field[:,forcing_Z_idx,full_branch_ids["evan-IGW-1"]])
plt.plot(x_EVP,IGWforcingsum.real,color='tab:blue',label=r"$\Re\{p_{0,\text{IGW-0}} + i p_{0,\text{IGW-1}}\}$")
plt.plot(x_EVP,IGWforcingsum.imag,color='tab:blue',linestyle='--',label=r"$\Im\{p_{0,\text{IGW-0}} + i p_{0,\text{IGW-1}}\}$")

plt.plot(x_EVP,(np.exp(1j*2*np.pi*x_EVP)).real,color='tab:red',label=r"$\Re\{\exp(i 2\pi x)\}$")
plt.plot(x_EVP,(np.exp(1j*2*np.pi*x_EVP)).imag,color='tab:red',linestyle='--',label=r"$\Im\{\exp(i 2\pi x)\}$")
plt.legend()
plt.xlabel(r"$x$")

Check that $p_0$ is constant over $Z$ at $x=1/3$

In [None]:
# p_0 at x ~ x_crosssection versus Z on each branch (solid: real part, dashed: imaginary part).
# The per-level normalization holds this value fixed in Z, so each curve should be vertical.
x_crosssection = 1/3
plot_field = p_full_branches_normed_stack
x_crosssection_idx = np.argmin(np.abs(x_EVP - x_crosssection))

fig,axs = plt.subplots(1,N_full_branches,figsize=(12,4))
for i in range(N_full_branches):
    axs[i].plot(plot_field[x_crosssection_idx,:,i].real,Z_EVP,color=colors[i])
    axs[i].plot(plot_field[x_crosssection_idx,:,i].imag,Z_EVP,color=colors[i],linestyle='--')
    axs[i].set_title(full_branch_names[i])
    axs[i].set_xlabel(f"$p_0(x={x_crosssection:.2f})$")
axs[0].set_ylabel("$Z$")
# Proxy artists for a shared legend; the dashed lines are the imaginary part
axs[0].plot([],[],color='gray',label=r'$\Re\{p_0\}$')
axs[0].plot([],[],color='gray',linestyle='--',label=r'$\Im\{p_0\}$')
axs[0].legend(ncols=2,bbox_to_anchor=(2.8,-0.2))

Check that $p_0$ eigenfunctions match across IGW $\rightarrow$ SM turning points.

In [None]:
# p_0 of each IGW branch and its SM partner at their first turning point; after matching,
# the curves of a pair should coincide (solid: real part, dashed: imaginary part).
plot_field = p_full_branches_normed_stack

fig,axs = plt.subplots(2,1)
for i in range(2):
    for mdnm in ['IGW','SM']:
        branch_idx = full_branch_ids[f"evan-{mdnm}-{i}"]
        p_at_tp = plot_field[:,turning_pt_ids_stack[branch_idx][0],branch_idx]
        axs[i].plot(x_EVP, p_at_tp.real,color=colors[branch_idx],label=r"$\Re\{p_{0,\text{" + f"{mdnm}-{i}" +r"}}\}$")
        axs[i].plot(x_EVP, p_at_tp.imag,color=colors[branch_idx],linestyle='--',label=r"$\Im\{p_{0,\text{" + f"{mdnm}-{i}" +r"}}\}$")
    axs[i].grid(alpha=0.3)
    axs[i].set_xlim(0,params_IVP['Lx'])
    axs[i].legend(bbox_to_anchor=(1.05,0.8))
    # Turning-point indices refer to the EVP levels, so read the height from Z_EVP.
    # Computed outside the f-string: reusing the same quote type inside an f-string
    # is a SyntaxError before Python 3.12.
    Z_t_igw = Z_EVP[turning_pt_ids_stack[full_branch_ids[f"evan-IGW-{i}"]][0]]
    axs[i].set_title(r"$Z = Z_{t,\text{IGW}-"+str(i)+r"}"+f"= {Z_t_igw:.4f}$")
    axs[i].set_xlabel(r"$x$")

plt.subplots_adjust(hspace=0.5)

Compare amplitudes of $u_0$ and $v_0$ above the second turning point (where conversion to AW modes occurs).

In [None]:
# |u_0| and |v_0| on the two SM branches. Both panels of a row share the colour limit
# max|u_0| so the amplitudes are directly comparable (|v_0| above it saturates, hence
# extend='max' on its colour bar). The dashed white line marks the branch's second
# turning point.
plot_branch_names = ['evan-SM-0', 'evan-SM-1']
plt_strd = 4

fig,axs = plt.subplots(2,2,figsize=(7.5,8))
for i,plot_branch_name in enumerate(plot_branch_names):
    plot_branch_idx = full_branch_ids[plot_branch_name]
    # Slice first, then take the magnitude, to avoid computing |.| over the whole stack
    plot_u_field = np.abs(u_full_branches_normed_stack[::plt_strd,::plt_strd,plot_branch_idx])
    plot_v_field = np.abs(v_full_branches_normed_stack[::plt_strd,::plt_strd,plot_branch_idx])
    vmax_u = np.max(plot_u_field)
    # Attach every colour bar to its own axes explicitly
    pm=axs[i,0].pcolormesh(x_EVP[::plt_strd],Z_EVP[::plt_strd],plot_u_field.T,shading='nearest',cmap='Spectral_r',vmin=0,vmax=vmax_u)
    plt.colorbar(pm,ax=axs[i,0],label=r"$\lvert u_0 \rvert$")
    pm=axs[i,1].pcolormesh(x_EVP[::plt_strd],Z_EVP[::plt_strd],plot_v_field.T,shading='nearest',cmap='Spectral_r',vmin=0,vmax=vmax_u)
    plt.colorbar(pm,ax=axs[i,1],extend='max',label=r"$\lvert v_0 \rvert$")
    axs[i,0].set_title(plot_branch_name+r", $\lvert u_0 \rvert$")
    axs[i,1].set_title(plot_branch_name+r", $\lvert v_0 \rvert$")
    Z_t = turning_pt_stack[plot_branch_idx][1]
    for j in range(2):
        axs[i,j].set_xlabel("$x$")
        axs[i,j].set_ylabel("$Z$")
        axs[i,j].axhline(Z_t,color='white',linestyle='--',linewidth=1.75)
plt.subplots_adjust(wspace=0.4)
plt.subplots_adjust(hspace=0.3)

In [None]:
# Where |p_0| is effectively zero: every point with |p_0| above the threshold is set to 1,
# so only near-zero points show up at the low end of the colour scale.
p0_zero_threshold = 1e-4
# Set the stride here so the cell does not rely on a value left over from another cell
plt_strd = 4
plot_field = np.abs(p_full_branches_normed_stack)
plot_field[plot_field > p0_zero_threshold] = 1

fig,axs = plt.subplots(1,N_full_branches,figsize=(12,4))
for i in range(N_full_branches):
    axs[i].pcolormesh(x_EVP[::plt_strd],Z_EVP[::plt_strd],plot_field[::plt_strd,::plt_strd,i].T,shading='nearest',cmap='viridis_r',vmin=0,vmax=1)
    axs[i].set_title(full_branch_names[i])
    axs[i].set_xlabel("$x$")
axs[0].set_ylabel("$Z$")
plt.suptitle(r"Zeros of $\lvert p_0(x,Z)\rvert$")

### Compute amplitude functions, $A(Z) = \left[\int_0^1 \left(p_0 w_0 + \Gamma v_{A z} (b_{0 x} u_0 - b_{0 y} v_0)\right)\mathrm{d}x \right]^{-1/2}$

In [None]:
# Create arrays to store rescaled modes
# (normalized EVP eigenfunctions resampled in x from Nx_EVP to the IVP resolution Nx; indexed [x, Z, branch])
p_full_branches_rescaled_stack = np.zeros((Nx,NZ,N_full_branches),dtype=np.complex128)
u_full_branches_rescaled_stack = np.zeros((Nx,NZ,N_full_branches),dtype=np.complex128)
v_full_branches_rescaled_stack = np.zeros((Nx,NZ,N_full_branches),dtype=np.complex128)
w_full_branches_rescaled_stack = np.zeros((Nx,NZ,N_full_branches),dtype=np.complex128)
bx_full_branches_rescaled_stack = np.zeros((Nx,NZ,N_full_branches),dtype=np.complex128)
by_full_branches_rescaled_stack = np.zeros((Nx,NZ,N_full_branches),dtype=np.complex128)
field_full_branches_rescaled_stack_list = [p_full_branches_rescaled_stack,u_full_branches_rescaled_stack,v_full_branches_rescaled_stack,
                                           w_full_branches_rescaled_stack,bx_full_branches_rescaled_stack,by_full_branches_rescaled_stack]
# Parameters
ky = params_IVP['ky']
Lx = params_IVP['Lx']

# Create Dedalus domain and fields
# NOTE: rebinds coords/dist/xbasis/xgrid/integ_x from the normalization step; this basis spans [0, Lx) instead of [0, 1)
dealias = 1 #3/2
coords = d3.CartesianCoordinates('x')
dist = d3.Distributor(coords, dtype=np.complex128)
xbasis = d3.ComplexFourier(coords['x'], size=Nx_EVP, bounds=(0, Lx), dealias=dealias)
xgrid = dist.local_grid(xbasis)

# Create fields
# One field per quantity, overwritten in place for every (branch, Z level) below
p_n = dist.Field(name='p_n', bases=xbasis)
u_n = dist.Field(name='u_n', bases=xbasis)
v_n = dist.Field(name='v_n', bases=xbasis)
w_n = dist.Field(name='w_n', bases=xbasis)
bx_n = dist.Field(name='bx_n', bases=xbasis)
by_n = dist.Field(name='by_n', bases=xbasis)
vAz_n = dist.Field(name='vAz_n', bases=xbasis)

# Operators
# NOTE(review): dx and dy are not used in this cell
dx = lambda a: d3.Differentiate(a, coords['x'])
dy = lambda a: 1j*ky*a
integ_x = lambda a: d3.Integrate(a, ('x'))

# Make array to store amplitude and group velocity
A_full_branches_stack = np.zeros(np.shape(kz_full_branches_stack),dtype=np.complex128)
cg_full_branches_stack = np.zeros(np.shape(kz_full_branches_stack),dtype=np.complex128)

for i in range(N_full_branches):
    # Make intermediate vector for storing A
    A_vec = np.zeros(NZ,dtype=np.complex128)
    for n in range(NZ):
        # Get modes and eigenvalues
        Z_n = Z_EVP[n]
        kz_n = kz_full_branches_stack[n,i]
        # v_Az(x, Z_n) = cos(kb x) / exp(kb Z_n)
        vAz_n['g'] = np.cos(kb*xgrid)/np.exp(kb*Z_n)
        p_n['g'] = p_full_branches_normed_stack[:,n,i]
        u_n['g'] = u_full_branches_normed_stack[:,n,i]
        v_n['g'] = v_full_branches_normed_stack[:,n,i]
        w_n['g'] = w_full_branches_normed_stack[:,n,i]
        bx_n['g'] = bx_full_branches_normed_stack[:,n,i]
        by_n['g'] = by_full_branches_normed_stack[:,n,i]

        # Compute intermediate variables
        # rho = -i kz p at this level
        rho_n = -(1j*kz_n*p_n)
        uStar_n = np.conjugate(u_n)
        vStar_n = np.conjugate(v_n)
        rhoStar_n = np.conjugate(rho_n)
        pStar_n = np.conjugate(p_n)
        bxStar_n = np.conjugate(bx_n)
        byStar_n = np.conjugate(by_n)

        # E_n = (|u|^2 + |v|^2 + |rho|^2 + |bx|^2 + |by|^2) / 2
        E_n = (u_n*uStar_n + v_n*vStar_n + rho_n*rhoStar_n + bx_n*bxStar_n + by_n*byStar_n)/2
        # F_n: integrand whose real x-integral, divided by the x-integral of E_n, gives c_gz below
        F_n = -(Gamma*(bxStar_n*u_n + byStar_n*v_n)*vAz_n) + pStar_n*w_n
        integ_F_n = integ_x(F_n)
        integ_E_n = integ_x(E_n)

        # ScriptF_n: integrand of the amplitude normalization A = [integral of ScriptF dx]^(-1/2)
        ScriptF_n = p_n*w_n + Gamma*vAz_n*(bx_n*u_n - by_n*v_n)
        integ_ScriptF_n = integ_x(ScriptF_n)

        # Group velocity
        # c_gz = Re(integral of F dx) / (integral of E dx)
        real_integ_F_n = (integ_F_n + np.conjugate(integ_F_n))/2
        cg_n = real_integ_F_n/integ_E_n
        cg = cg_n.evaluate()
        cg_full_branches_stack[n,i] = cg['g'][0]

        # Amplitude
        # The complex square root fixes A only up to sign; take whichever sign lies closer
        # to the previous level so A varies continuously in Z
        A_pos = 1/np.sqrt(integ_ScriptF_n.evaluate()['g'][0])
        A_neg = -A_pos
        A_vec[n]= A_pos

        if n > 0:
            if np.abs(A_vec[n-1] - A_neg) < np.abs(A_vec[n-1] - A_pos):
                A_vec[n] = A_neg

        # Store rescaled eigenmodes
        # change_scales(Nx/Nx_EVP) makes field_n['g'] an Nx-point grid; scale 1 is restored afterwards
        for l,field_n in enumerate([p_n,u_n,v_n,w_n,bx_n,by_n]):
            field_n.change_scales(Nx/Nx_EVP)
            field_full_branches_rescaled_stack_list[l][:,n,i] = field_n['g']
            field_n.change_scales(1)

    # Store amplitude for each branch
    A_full_branches_stack[:,i] = A_vec

# Normalize amplitude of IGW-0 and IGW-1 at forcing height: A = 1 at the level nearest Z0
Z0_idx_A = np.argmin(np.abs(Z_EVP - params_IVP['Z0']))
for i in (full_branch_ids["evan-IGW-0"], full_branch_ids["evan-IGW-1"]):
    A_at_forcing_level = A_full_branches_stack[Z0_idx_A, i]
    A_full_branches_stack[..., i] = A_full_branches_stack[..., i] / A_at_forcing_level

# Rescale each SM branch so that its amplitude equals the parent IGW amplitude at the first turning point
igw0_col, sm0_col = full_branch_ids["evan-IGW-0"], full_branch_ids["evan-SM-0"]
A_IGW0_at_tp = A_full_branches_stack[turning_pt_ids_stack[igw0_col][0], igw0_col]
A_SM0_at_tp = A_full_branches_stack[turning_pt_ids_stack[sm0_col][0], sm0_col]
A_full_branches_stack[..., sm0_col] = A_full_branches_stack[..., sm0_col] / A_SM0_at_tp * A_IGW0_at_tp

igw1_col, sm1_col = full_branch_ids["evan-IGW-1"], full_branch_ids["evan-SM-1"]
A_IGW1_at_tp = A_full_branches_stack[turning_pt_ids_stack[igw1_col][0], igw1_col]
A_SM1_at_tp = A_full_branches_stack[turning_pt_ids_stack[sm1_col][0], sm1_col]
A_full_branches_stack[..., sm1_col] = A_full_branches_stack[..., sm1_col] / A_SM1_at_tp * A_IGW1_at_tp


In [None]:
# Amplitude A(Z) of each branch (solid: real, dashed: imaginary), plus every branch overlaid
# in the last panel. A is computed on the EVP levels, so it is plotted against Z_EVP.
fig,axs = plt.subplots(1,N_full_branches+1,figsize=(15,5))

for i in range(N_full_branches):
    ax = axs[i]
    ax.plot(A_full_branches_stack[:,i].real,Z_EVP,color=colors[i],label="real")
    ax.plot(A_full_branches_stack[:,i].imag,Z_EVP,color=colors[i],linestyle='--',label="imag.")
    ax.set_title(full_branch_names[i])
    ax.set_xlabel("$A(Z)$")
    ax.legend(loc="upper right")
    ax.grid()
    axs[-1].plot(A_full_branches_stack[:,i].real,Z_EVP,color=colors[i],label="real")
    axs[-1].plot(A_full_branches_stack[:,i].imag,Z_EVP,color=colors[i],linestyle='--',label="imag.")
axs[0].set_ylabel("$Z$")
axs[-1].set_xlabel("$A(Z)$")
axs[-1].grid(alpha=0.5)

In [None]:
# Vertical group velocity Re(c_gz) of each branch; computed on the EVP levels, so plotted against Z_EVP.
fig,axs = plt.subplots(1,N_full_branches,figsize=(15,5))

for i in range(N_full_branches):
    ax = axs[i]
    ax.plot(cg_full_branches_stack[:,i].real,Z_EVP,color=colors[i])
    ax.set_title(full_branch_names[i])
    ax.set_xlabel("$c_{g z}$")
    ax.grid()
axs[0].set_ylabel("$Z$")

### Construct WKB components

In [None]:
# Construct WKB components along full branch for all fields:
#   A(Z) * eigenfunction(x, Z) * exp(i Theta(Z) / Fr),
# with A and Theta broadcast along the leading x axis. The phase factor is the
# same for every field, so it is evaluated once.
wkb_phase_factor = np.exp(1j*theta_full_branches_stack[np.newaxis,:]/Fr)
field_wkb_comps_full_branches_stack_list = [
    A_full_branches_stack[np.newaxis,:] * rescaled_stack * wkb_phase_factor
    for rescaled_stack in field_full_branches_rescaled_stack_list
]

In [None]:
# Split each full-branch WKB component into height segments, one per mode:
#  - IGW branches: the part below the first turning point is stored as "evan-*",
#    the part above it as "IGW-*".
#  - SM-0: from its first turning point up to its second is "SM-0", above that "AW-0".
#  - SM-1: from its first turning point up to the Alfvén-boundary intersection is
#    "SM-1", from there up to its second turning point "SM-AW-1", above that "AW-1".
# The part of each SM branch below its first turning point is not kept.
mode_ids = {"IGW-0":0,"IGW-1":1,"evan-0":2,"evan-1":3,"SM-0":4,"AW-0":5,"SM-1":6,"SM-AW-1":7,"AW-1":8}
N_modes = len(mode_ids.keys())

def _zero_below(profile, z_idx):
    """Return a copy of `profile` with every Z level (last axis) below index `z_idx` set to zero."""
    masked = np.copy(profile)
    masked[..., :z_idx] = 0
    return masked

tp_ids_igw0 = turning_pt_ids_stack[full_branch_ids["evan-IGW-0"]]
tp_ids_igw1 = turning_pt_ids_stack[full_branch_ids["evan-IGW-1"]]
tp_ids_sm0 = turning_pt_ids_stack[full_branch_ids["evan-SM-0"]]
tp_ids_sm1 = turning_pt_ids_stack[full_branch_ids["evan-SM-1"]]

field_wkb_comps_stack_list = []
for field_wkb_comps_full_branches_stack in field_wkb_comps_full_branches_stack_list:
    igw0_full = field_wkb_comps_full_branches_stack[..., full_branch_ids["evan-IGW-0"]]
    igw1_full = field_wkb_comps_full_branches_stack[..., full_branch_ids["evan-IGW-1"]]
    sm0_full = field_wkb_comps_full_branches_stack[..., full_branch_ids["evan-SM-0"]]
    sm1_full = field_wkb_comps_full_branches_stack[..., full_branch_ids["evan-SM-1"]]

    igw0_part = _zero_below(igw0_full, tp_ids_igw0[0])
    igw1_part = _zero_below(igw1_full, tp_ids_igw1[0])
    sm0_aw0_part = _zero_below(sm0_full, tp_ids_sm0[0])
    aw0_part = _zero_below(sm0_aw0_part, tp_ids_sm0[1])
    sm1_smaw1_aw1_part = _zero_below(sm1_full, tp_ids_sm1[0])
    smaw1_aw1_part = _zero_below(sm1_smaw1_aw1_part, alfven_intsc_idx)
    aw1_part = _zero_below(smaw1_aw1_part, tp_ids_sm1[1])

    segments = {
        "IGW-0": igw0_part,
        "IGW-1": igw1_part,
        "evan-0": igw0_full - igw0_part,
        "evan-1": igw1_full - igw1_part,
        "SM-0": sm0_aw0_part - aw0_part,
        "AW-0": aw0_part,
        "SM-1": sm1_smaw1_aw1_part - smaw1_aw1_part,
        "SM-AW-1": smaw1_aw1_part - aw1_part,
        "AW-1": aw1_part,
    }

    field_wkb_comp_stack = np.zeros((Nx,NZ,N_modes),dtype=np.complex128)
    for mode_name, mode_idx in mode_ids.items():
        field_wkb_comp_stack[..., mode_idx] = segments[mode_name]

    field_wkb_comps_stack_list.append(field_wkb_comp_stack)

p_wkb_comp_stack,u_wkb_comp_stack,v_wkb_comp_stack,w_wkb_comp_stack,bx_wkb_comp_stack,by_wkb_comp_stack = field_wkb_comps_stack_list

In [None]:
# Split the kz of each full branch into the named modes of mode_ids.
# Each cut copies the branch, zeros every Z-index below a cut index, and
# subtracts, so each piece is non-zero only on its own Z-index range.
# With tp0, tp1 = turning_pt_ids_stack[branch][0], [1]:
#   evan-IGW-0 -> evan-0 below tp0, IGW-0 from tp0
#   evan-IGW-1 -> evan-1 below tp0, IGW-1 from tp0
#   evan-SM-0  -> SM-0 on [tp0, tp1), AW-0 from tp1 (below tp0 dropped)
#   evan-SM-1  -> SM-1 on [tp0, alfven_intsc_idx), SM-AW-1 on
#                 [alfven_intsc_idx, tp1), AW-1 from tp1 (below tp0 dropped)
kz_igw0 = np.copy(kz_full_branches_stack[...,full_branch_ids["evan-IGW-0"]])
kz_igw0[...,:turning_pt_ids_stack[full_branch_ids["evan-IGW-0"]][0]] = 0
kz_evan0 = kz_full_branches_stack[...,full_branch_ids["evan-IGW-0"]] - kz_igw0

kz_igw1 = np.copy(kz_full_branches_stack[...,full_branch_ids["evan-IGW-1"]])
kz_igw1[...,:turning_pt_ids_stack[full_branch_ids["evan-IGW-1"]][0]] = 0
kz_evan1 = kz_full_branches_stack[...,full_branch_ids["evan-IGW-1"]] - kz_igw1

kz_sm0_aw0 = np.copy(kz_full_branches_stack[...,full_branch_ids["evan-SM-0"]])
kz_sm0_aw0[...,:turning_pt_ids_stack[full_branch_ids["evan-SM-0"]][0]] = 0
kz_aw0 = np.copy(kz_sm0_aw0)
kz_aw0[...,:turning_pt_ids_stack[full_branch_ids["evan-SM-0"]][1]] = 0
kz_sm0 = kz_sm0_aw0 - kz_aw0

kz_sm1_smaw1_aw1 = np.copy(kz_full_branches_stack[...,full_branch_ids["evan-SM-1"]])
kz_sm1_smaw1_aw1[...,:turning_pt_ids_stack[full_branch_ids["evan-SM-1"]][0]] = 0
kz_smaw1_aw1 = np.copy(kz_sm1_smaw1_aw1)
kz_smaw1_aw1[...,:alfven_intsc_idx] = 0
kz_sm1 = kz_sm1_smaw1_aw1 - kz_smaw1_aw1
kz_aw1 = np.copy(kz_smaw1_aw1)
kz_aw1[...,:turning_pt_ids_stack[full_branch_ids["evan-SM-1"]][1]] = 0
kz_smaw1 = kz_smaw1_aw1 - kz_aw1

# Zero entries (outside a mode's Z range) become complex NaN so they are not
# drawn when plotted. This must run after all the subtractions above.
# NOTE(review): a genuine kz of exactly 0 would also be blanked.
for kz_mode in [kz_igw0,kz_igw1,kz_evan0,kz_evan1,kz_sm0,kz_aw0,kz_sm1,kz_smaw1,kz_aw1]:
    kz_mode[kz_mode==0] = np.nan + 1j*np.nan

# One column per mode, ordered by mode_ids: shape (NZ, N_modes)
kz_stack = np.zeros((NZ,N_modes),dtype=np.complex128)
kz_stack[...,mode_ids["IGW-0"]] = kz_igw0
kz_stack[...,mode_ids["IGW-1"]] = kz_igw1
kz_stack[...,mode_ids["evan-0"]] = kz_evan0
kz_stack[...,mode_ids["evan-1"]] = kz_evan1
kz_stack[...,mode_ids["SM-0"]] = kz_sm0
kz_stack[...,mode_ids["AW-0"]] = kz_aw0
kz_stack[...,mode_ids["SM-1"]] = kz_sm1
kz_stack[...,mode_ids["SM-AW-1"]] = kz_smaw1
kz_stack[...,mode_ids["AW-1"]] = kz_aw1

In [None]:
# Split the group velocity cg of each full branch into the named modes of
# mode_ids. The cuts are the same as for kz: copy the branch, zero every
# Z-index below a cut index (turning_pt_ids_stack / alfven_intsc_idx), and
# subtract, so each piece is non-zero only on its own Z-index range.
cg_igw0 = np.copy(cg_full_branches_stack[...,full_branch_ids["evan-IGW-0"]])
cg_igw0[...,:turning_pt_ids_stack[full_branch_ids["evan-IGW-0"]][0]] = 0
cg_evan0 = cg_full_branches_stack[...,full_branch_ids["evan-IGW-0"]] - cg_igw0

cg_igw1 = np.copy(cg_full_branches_stack[...,full_branch_ids["evan-IGW-1"]])
cg_igw1[...,:turning_pt_ids_stack[full_branch_ids["evan-IGW-1"]][0]] = 0
cg_evan1 = cg_full_branches_stack[...,full_branch_ids["evan-IGW-1"]] - cg_igw1

cg_sm0_aw0 = np.copy(cg_full_branches_stack[...,full_branch_ids["evan-SM-0"]])
cg_sm0_aw0[...,:turning_pt_ids_stack[full_branch_ids["evan-SM-0"]][0]] = 0
cg_aw0 = np.copy(cg_sm0_aw0)
cg_aw0[...,:turning_pt_ids_stack[full_branch_ids["evan-SM-0"]][1]] = 0
cg_sm0 = cg_sm0_aw0 - cg_aw0

cg_sm1_smaw1_aw1 = np.copy(cg_full_branches_stack[...,full_branch_ids["evan-SM-1"]])
cg_sm1_smaw1_aw1[...,:turning_pt_ids_stack[full_branch_ids["evan-SM-1"]][0]] = 0
cg_smaw1_aw1 = np.copy(cg_sm1_smaw1_aw1)
cg_smaw1_aw1[...,:alfven_intsc_idx] = 0
cg_sm1 = cg_sm1_smaw1_aw1 - cg_smaw1_aw1
cg_aw1 = np.copy(cg_smaw1_aw1)
cg_aw1[...,:turning_pt_ids_stack[full_branch_ids["evan-SM-1"]][1]] = 0
cg_smaw1 = cg_smaw1_aw1 - cg_aw1

# Zero entries (outside a mode's Z range) become complex NaN so they are not
# drawn when plotted. This must run after all the subtractions above.
# NOTE(review): assumes the cg arrays can hold a complex NaN (cg_stack below
# is complex128) — confirm the dtype of cg_full_branches_stack.
for cg_mode in [cg_igw0,cg_igw1,cg_evan0,cg_evan1,cg_sm0,cg_aw0,cg_sm1,cg_smaw1,cg_aw1]:
    cg_mode[cg_mode==0] = np.nan + 1j*np.nan

# One column per mode, ordered by mode_ids: shape (NZ, N_modes)
cg_stack = np.zeros((NZ,N_modes),dtype=np.complex128)
cg_stack[...,mode_ids["IGW-0"]] = cg_igw0
cg_stack[...,mode_ids["IGW-1"]] = cg_igw1
cg_stack[...,mode_ids["evan-0"]] = cg_evan0
cg_stack[...,mode_ids["evan-1"]] = cg_evan1
cg_stack[...,mode_ids["SM-0"]] = cg_sm0
cg_stack[...,mode_ids["AW-0"]] = cg_aw0
cg_stack[...,mode_ids["SM-1"]] = cg_sm1
cg_stack[...,mode_ids["SM-AW-1"]] = cg_smaw1
cg_stack[...,mode_ids["AW-1"]] = cg_aw1

In [None]:
# Plot colour and line style for each mode name (plus a transparent 'none').
# Modes of the same branch share a colour: IGW-0/evan-0, IGW-1/evan-1,
# SM-0/AW-0 and SM-1/SM-AW-1/AW-1. Line style separates the pieces of a branch.
color_dict = {'IGW-1':"#ed6f00",'IGW-0':"#dbb700",
              'SM-0':"#299100",'AW-0':"#299100",
              'SM-1':"#9839ad",'SM-AW-1':"#9839ad",'AW-1':"#9839ad",
              'evan-1':"#ed6f00",'evan-0':"#dbb700",'none':"#00000000"}
linestyle_dict = {'IGW-1':"-",'IGW-0':"-",
                  'SM-0':"-",'AW-0':"dotted",
                  'SM-1':"-",'SM-AW-1':"--",'AW-1':"dotted",
                  'evan-1':"dotted",'evan-0':"dotted",'none':'-'}

# Unused alternative colour kept for reference: "#038f78"

In [None]:
# Overview figure: Re{kz} and Re{cg} of every mode vs Z (left two panels), then
# one map of Re{p} per WKB mode component (remaining panels).
plt_strd = 4
plot_field = p_wkb_comp_stack.real
# Symmetric colour limit: 99.95th percentile of |Re p| over all modes
vmin = np.quantile(np.abs(plot_field),q=0.9995)

# Columns: kz, spacer, cg, spacer, then one column per mode name
fig,axs = plt.subplot_mosaic([['kz','.','cg','.'] + list(mode_ids.keys())], width_ratios=[2,0.25,2,0.25] + N_modes*[1],figsize=(1.75*(3 + 0.25 + N_modes),5))
for i,mode_nm in enumerate(mode_ids.keys()):
    ax = axs[mode_nm]
    # NOTE(review): a (min, max) tuple is passed as vmin; this relies on
    # matplotlib's set_clim unpacking it when vmax is None — confirm.
    ax.pcolormesh(x_IVP[::plt_strd],Z_IVP[::plt_strd],plot_field[::plt_strd,::plt_strd,mode_ids[mode_nm]].T,cmap="RdBu_r",shading='nearest',vmin=(-vmin,vmin))
    ax.set_title(mode_nm,color=color_dict[mode_nm])
    ax.set_ylim(0,LZ)
    ax.set_xlim(0,1)
    ax.set_xlabel("$x$",fontsize=14)
    ax.set_aspect(4.5 * 1/LZ)
    if i == 0:
        ax.set_ylabel("$Z$",fontsize=14)
axs['cg'].axvline(0,color='k',linewidth=0.5,zorder=0)
axs['cg'].grid(alpha=0.3)
for nm in mode_ids.keys():
    axs['kz'].plot(kz_stack[...,mode_ids[nm]].real,Z_IVP,color=color_dict[nm],linestyle=linestyle_dict[nm],linewidth=2,label=nm)
    axs['cg'].plot(cg_stack[...,mode_ids[nm]].real,Z_IVP,color=color_dict[nm],linestyle=linestyle_dict[nm],linewidth=2,label=nm,zorder=1)
# Shade from the Alfven boundary to a far-right value (100), then restore the
# x-limits set by the curves so the shading does not widen the axis
xlims = axs['kz'].get_xlim()
axs['kz'].fill_betweenx(Z_EVP,alfven_bdry,100*np.ones(NZ),color='lightgray',zorder=0)
axs['kz'].set_xlim(xlims)
for axkey in ['kz','cg']:
    axs[axkey].set_ylim(0,LZ)
    axs[axkey].set_ylabel("$Z$",fontsize=14)
axs['kz'].legend(loc="lower right")
axs['kz'].set_xlabel(r"$\Re\{k_z\}$",fontsize=14)
axs['cg'].set_xlabel(r"$c_g$",fontsize=14)

In [None]:
# Real (solid) and imaginary (dashed) parts of the IGW-0 / IGW-1 pressure WKB
# components across x at the forcing height.
ax_forcing = plt.gca()
for mode_nm in ["IGW-0","IGW-1"]:
    mode_p = p_wkb_comp_stack[::plt_strd,forcing_Z_idx,mode_ids[mode_nm]]
    mode_color = color_dict[mode_nm]
    ax_forcing.plot(x_IVP[::plt_strd],mode_p.real,color=mode_color,label=mode_nm)
    ax_forcing.plot(x_IVP[::plt_strd],mode_p.imag,color=mode_color,linestyle='--')
ax_forcing.set_title(f"WKB components at forcing height, $Z$ = {Z_IVP[forcing_Z_idx]:.3f}")
ax_forcing.set_ylabel(r"$A(Z)p_0(x,Z)\exp(i \Theta(Z))$",fontsize=14)
ax_forcing.set_xlabel(r"$x$",fontsize=14)
# Empty proxy lines that only provide the real/imag legend entries
ax_forcing.plot([],[],color='gray',label="real")
ax_forcing.plot([],[],color='gray',linestyle='--',label="imag")
ax_forcing.legend()
ax_forcing.grid(alpha=0.5)
plt.show()

### Fit WKB solution to IVP

In [None]:
def func_2d(C,q_wkb_comp_stack,sin_amp=1,cos_amp=1,evan_phase=0,sm_phase=-np.pi/2,mode_index=None):
    """Assemble a WKB field from its per-mode components.

    The cos-parity modes (IGW-0, evan-0, SM-0, AW-0) and the sin-parity modes
    (IGW-1, evan-1, SM-1, SM-AW-1, AW-1) are each summed with fixed relative
    phases. The two parities are combined as ``cos_amp*q_cos + 1j*sin_amp*q_sin``,
    and the result is multiplied by ``|C| * exp(1j*angle(C))``.

    Parameters
    ----------
    C : complex
        Overall complex amplitude.
    q_wkb_comp_stack : array
        WKB components with the mode index on the last axis.
    sin_amp, cos_amp : float
        Weights of the sin- and cos-parity parts. Set one to 0 to isolate the
        other parity.
    evan_phase : float
        Phase (radians) applied to evan-0 and evan-1. Default 0.
    sm_phase : float
        Phase (radians) applied to SM-0, AW-0, SM-1, SM-AW-1 and AW-1.
        Default -pi/2.
    mode_index : dict, optional
        Maps mode name to last-axis index. Defaults to the notebook's
        ``mode_ids``.

    Returns
    -------
    array
        Complex field with the shape of ``q_wkb_comp_stack[..., 0]``.
    """
    if mode_index is None:
        mode_index = mode_ids
    evan0_phase = evan1_phase = evan_phase
    sm0_phase = sm1_phase = sm_phase
    # Get parameters
    overall_amp, overall_phase = (np.abs(C), np.angle(C))

    # Get components
    igw0_comp = q_wkb_comp_stack[...,mode_index['IGW-0']]
    igw1_comp = q_wkb_comp_stack[...,mode_index['IGW-1']]
    evan0_comp = q_wkb_comp_stack[...,mode_index['evan-0']]
    evan1_comp = q_wkb_comp_stack[...,mode_index['evan-1']]
    sm0_comp = q_wkb_comp_stack[...,mode_index['SM-0']]
    aw0_comp = q_wkb_comp_stack[...,mode_index['AW-0']]
    sm1_comp = q_wkb_comp_stack[...,mode_index['SM-1']]
    smaw1_comp = q_wkb_comp_stack[...,mode_index['SM-AW-1']]
    aw1_comp = q_wkb_comp_stack[...,mode_index['AW-1']]

    # cosines (AW-0 continues the SM-0 branch and shares its phase)
    q_cos = igw0_comp
    q_cos = q_cos + evan0_comp*np.exp(1j*evan0_phase)
    q_cos = q_cos + sm0_comp*np.exp(1j*sm0_phase)
    q_cos = q_cos + aw0_comp*np.exp(1j*sm0_phase)

    # sines (SM-AW-1 and AW-1 continue the SM-1 branch and share its phase)
    q_sin = igw1_comp
    q_sin = q_sin + evan1_comp*np.exp(1j*evan1_phase)
    q_sin = q_sin + sm1_comp*np.exp(1j*sm1_phase)
    q_sin = q_sin + smaw1_comp*np.exp(1j*sm1_phase)
    q_sin = q_sin + aw1_comp*np.exp(1j*sm1_phase)

    # Combine parities
    q = cos_amp*q_cos + 1j*sin_amp*q_sin

    # Overall amplitude and phase
    q = overall_amp * q * np.exp(1j*overall_phase)

    return q

In [None]:
# Set fitting latitude and height (previously tried: x_fit = 0.45, Z_fit = 0.1875)
x_fit = 1/3
Z_fit = 0.175

# Choose field to use for fitting
fit_field_IVP = u_IVP
fit_field_wkb = u_wkb_comp_stack

# Grid indices nearest to the requested fitting point
x_fit_idx = np.argmin(np.abs(x_IVP - x_fit))
Z_fit_idx = np.argmin(np.abs(Z_IVP - Z_fit))

fit_field_IVP_pt = fit_field_IVP[x_fit_idx,Z_fit_idx]

# Unit-amplitude WKB value at the fitting point. func_2d only indexes the last
# (mode) axis, so it can be evaluated on this single point's mode vector
# instead of building the full 2-D solution and keeping one entry.
wkb_fit_pt = func_2d(1,fit_field_wkb[x_fit_idx,Z_fit_idx],sin_amp=1,cos_amp=1)
if wkb_fit_pt == 0:
    raise ValueError("WKB solution vanishes at the fitting point; choose a different x_fit/Z_fit")
# Complex amplitude that makes the WKB solution match the IVP at (x_fit, Z_fit)
complex_amp = fit_field_IVP_pt/wkb_fit_pt

In [None]:
# Fit WKB soln to IVP: every field's WKB components are scaled by the single
# complex amplitude fitted above.
(p_wkb_soln, u_wkb_soln, v_wkb_soln,
 w_wkb_soln, bx_wkb_soln, by_wkb_soln) = [
    func_2d(complex_amp, comp_stack)
    for comp_stack in (p_wkb_comp_stack, u_wkb_comp_stack, v_wkb_comp_stack,
                       w_wkb_comp_stack, bx_wkb_comp_stack, by_wkb_comp_stack)
]

In [None]:
def plot_comparison(field_IVP, field_wkb_soln, field_label, vmin_factor=1, x_prof=0.8, plt_strd=1):
    """Compare an IVP field with its fitted WKB solution in a five-panel figure.

    Panels:
      (a) Z-profile at the fitting latitude ``x_IVP[x_fit_idx]``;
      (b) Z-profile at the grid point nearest ``x_prof``;
      (c) IVP map; (d) WKB map; (e) IVP - WKB.
    Only real parts are shown. Horizontal lines mark turning points, the Alfven
    boundary intersection and the forcing height. Hatched bands cover
    [0, s] and [LZ - s, LZ] with ``s = params_IVP['s']`` (the damping layers).

    Parameters
    ----------
    field_IVP, field_wkb_soln : array
        Fields indexed ``[x, Z]`` on the (x_IVP, Z_IVP) grid.
    field_label : str
        Label used for axes and titles (may contain LaTeX).
    vmin_factor : float
        Colour limits of all maps are ``±vmin_factor * max|Re(field_IVP)|``.
    x_prof : float
        x-location of the second Z-profile.
    plt_strd : int
        Stride used to subsample the maps.

    Relies on notebook globals: x_IVP, Z_IVP, LZ, x_fit_idx, turning_pt_stack,
    full_branch_ids, alfven_intsc_Z, color_dict, linestyle_dict, params_IVP.
    """
    fig, axs = plt.subplots(1,5,figsize = (5*(3*10/16),8*10/16))
    title_fontsize = 12
    x_ind_IVP_color = '#0073ff'
    x_ind_IVP_extrap_color = '#d909b6'

    resid_field = field_IVP - field_wkb_soln

    # Shared symmetric colour limits for the three maps.
    # NOTE(review): a (min, max) tuple is passed as vmin; this relies on
    # matplotlib's set_clim unpacking it when vmax is None — confirm.
    vmax_abs = vmin_factor*np.max(np.abs(field_IVP.T.real))
    vmin = (-vmax_abs, vmax_abs)

    def plot_profile(ax, x_idx):
        # IVP vs WKB Z-profile at grid column x_idx, titled with the actual grid x
        ax.plot(field_IVP[x_idx,:].real,Z_IVP,linestyle='--',color='navy',lw=2,label='IVP')
        ax.plot(field_wkb_soln[x_idx,:].real,Z_IVP,linestyle='-',color='darkred',lw=1,label='WKB')
        ax.set_ylim(0,LZ)
        ax.set_title(f'$x/L$ = {x_IVP[x_idx]:.2f}', fontsize=title_fontsize)
        ax.set_xlabel(field_label, fontsize=title_fontsize)

    x_extrap_idx = np.argmin(np.abs(x_IVP - x_prof))
    plot_profile(axs[0], x_fit_idx)
    axs[0].legend(loc="lower right",framealpha=0.9,fontsize=title_fontsize)
    plot_profile(axs[1], x_extrap_idx)

    map_panels = [(axs[2], field_IVP, f'{field_label} (IVP)'),
                  (axs[3], field_wkb_soln, f'{field_label} (WKB)'),
                  (axs[4], resid_field, 'difference\n(IVP - WKB)')]
    for ax, data, title in map_panels:
        ax.pcolormesh(x_IVP[::plt_strd],Z_IVP[::plt_strd],data[::plt_strd,::plt_strd].T.real,cmap='RdBu_r',vmin=vmin)
        ax.set_title(title, fontsize=title_fontsize)
        ax.set_xlabel('$x/L$', fontsize=title_fontsize)

    panel_label = ['($a$)','($b$)','($c$)','($d$)','($e$)']
    for k,ax in enumerate(axs):
        if k>0:
            ax.set_yticks([])
        else:
            ax.set_ylabel('$z/L$', fontsize=title_fontsize)

        # Panel labels
        ax.text(-0.0, 1.1, panel_label[k], transform=ax.transAxes, fontsize=14, va='top', ha='right')

        # Turning points and Alfven boundary intersection
        for mode_nm in ["IGW-0","IGW-1"]:
            ax.axhline(turning_pt_stack[full_branch_ids[f"evan-{mode_nm}"]][0],color=color_dict[mode_nm],linestyle=linestyle_dict[mode_nm])
        ax.axhline(turning_pt_stack[full_branch_ids["evan-SM-0"]][1],color=color_dict["AW-0"],linestyle=linestyle_dict["AW-0"])
        ax.axhline(turning_pt_stack[full_branch_ids["evan-SM-1"]][1],color=color_dict["AW-1"],linestyle=linestyle_dict["AW-1"])
        ax.axhline(alfven_intsc_Z,color=color_dict["SM-AW-1"],linestyle=linestyle_dict["SM-AW-1"])

        # Forcing height
        ax.axhline(params_IVP['Z0'],linestyle='dotted',color='k',lw=2)
        # Capture limits before the hatched fills so they do not change them
        xlims = ax.get_xlim()
        ylims = ax.get_ylim()

        # Damping layers
        ax.fill_between([xlims[0],xlims[1]],[params_IVP['s'],params_IVP['s']], facecolor="none", hatch="//////", edgecolor="k", linewidth=0.5)
        ax.fill_between([xlims[0],xlims[1]],[LZ-params_IVP['s'],LZ-params_IVP['s']],[LZ,LZ], facecolor="none", hatch="//////", edgecolor="k", linewidth=0.5)

        # Z-profile locations
        if k in [2,3]:
            ax.axvline(x_IVP[x_fit_idx],color=x_ind_IVP_color,linewidth=1.25,linestyle='--')
            ax.axvline(x_IVP[x_extrap_idx],color=x_ind_IVP_extrap_color,linewidth=1.25,linestyle='--')
        ax.set_xlim(xlims)
        ax.set_ylim(ylims)

    # Colour the profile panels like their marker lines in panels (c), (d)
    for ax, color in [(axs[0], x_ind_IVP_color), (axs[1], x_ind_IVP_extrap_color)]:
        ax.tick_params(colors=color)
        for spine in ax.spines.values():
            spine.set_color(color)

    plt.subplots_adjust(top=1,wspace=0.2)

In [None]:
plot_comparison(p_IVP, p_wkb_soln, "$p$", x_prof=0.5, vmin_factor=1.2, plt_strd=4)

In [None]:
plot_comparison(v_IVP, v_wkb_soln, "$v$", x_prof=0.05, vmin_factor=1.2, plt_strd=4)

In [None]:
plot_comparison(u_IVP, u_wkb_soln, "$u$", x_prof=0.95, vmin_factor=1.2, plt_strd=4)

In [None]:
# Decompose by parity

def sin_cos_decomp(c,Lx,LZ):
    """Split a 2-D field into its sin- and cos-parity parts in x.

    Real and imaginary parts are each loaded into a Dedalus field with
    RealFourier bases in (x, Z). The x-coefficient rows at odd indices are
    zeroed to get the cos part, and those at even indices to get the sin part.
    This relies on Dedalus storing RealFourier coefficients as interleaved
    cos/sin pairs.

    Parameters
    ----------
    c : array, shape (Nx, NZ)
        Field on the (x, Z) grid; may be complex.
    Lx, LZ : float
        Domain lengths; the bases span [0, Lx] and [0, LZ].

    Returns
    -------
    (c_sin, c_cos)
        Complex arrays with the shape of ``c``.
    """
    Nx, NZ = c.shape
    coords = d3.CartesianCoordinates('x', 'Z')
    dist = d3.Distributor(coords, dtype=np.float64)
    bases = (
        d3.RealFourier(coords['x'], size=Nx, bounds=(0, Lx), dealias=3/2),
        d3.RealFourier(coords['Z'], size=NZ, bounds=(0, LZ), dealias=3/2),
    )

    def keep_parity(name, grid_values, zeroed_rows):
        # Transform to coefficients, zero the given x rows, return grid values
        fld = dist.Field(name=name, bases=bases)
        fld['g'] = grid_values
        fld['c'][zeroed_rows, :] = 0
        return fld['g']

    odd_rows = slice(1, None, 2)   # zeroed to keep the cos part
    even_rows = slice(0, None, 2)  # zeroed to keep the sin part

    c_cos = keep_parity('creal_cos', c.real, odd_rows) + 1j*keep_parity('cimag_cos', c.imag, odd_rows)
    c_sin = keep_parity('creal_sin', c.real, even_rows) + 1j*keep_parity('cimag_sin', c.imag, even_rows)

    return c_sin, c_cos

# Split each IVP field into its sin- and cos-parity parts in x
p_sin, p_cos = sin_cos_decomp(p_IVP,Lx,LZ)
u_sin, u_cos = sin_cos_decomp(u_IVP,Lx,LZ) 
v_sin, v_cos = sin_cos_decomp(v_IVP,Lx,LZ)
w_sin, w_cos = sin_cos_decomp(w_IVP,Lx,LZ)
bx_sin, bx_cos = sin_cos_decomp(Bx_IVP,Lx,LZ)
by_sin, by_cos = sin_cos_decomp(By_IVP,Lx,LZ)

# WKB parity parts: "sin" keeps only func_2d's sin-parity modes (sin_amp=1,
# cos_amp=0) and "cos" only its cos-parity modes. For u and bx the two are
# swapped, e.g. u_wkb_soln_sin is built from the cos-parity WKB modes.
p_wkb_soln_sin,p_wkb_soln_cos = [func_2d(complex_amp,p_wkb_comp_stack,sin_amp=1,cos_amp=0),func_2d(complex_amp,p_wkb_comp_stack,sin_amp=0,cos_amp=1)]
u_wkb_soln_sin,u_wkb_soln_cos = [func_2d(complex_amp,u_wkb_comp_stack,sin_amp=0,cos_amp=1),func_2d(complex_amp,u_wkb_comp_stack,sin_amp=1,cos_amp=0)] # opposite parity for u
v_wkb_soln_sin,v_wkb_soln_cos = [func_2d(complex_amp,v_wkb_comp_stack,sin_amp=1,cos_amp=0),func_2d(complex_amp,v_wkb_comp_stack,sin_amp=0,cos_amp=1)]
w_wkb_soln_sin,w_wkb_soln_cos = [func_2d(complex_amp,w_wkb_comp_stack,sin_amp=1,cos_amp=0),func_2d(complex_amp,w_wkb_comp_stack,sin_amp=0,cos_amp=1)]
bx_wkb_soln_sin,bx_wkb_soln_cos = [func_2d(complex_amp,bx_wkb_comp_stack,sin_amp=0,cos_amp=1),func_2d(complex_amp,bx_wkb_comp_stack,sin_amp=1,cos_amp=0)] # opposite parity for bx
by_wkb_soln_sin,by_wkb_soln_cos = [func_2d(complex_amp,by_wkb_comp_stack,sin_amp=1,cos_amp=0),func_2d(complex_amp,by_wkb_comp_stack,sin_amp=0,cos_amp=1)]

# Per-field bundle of the raw and parity-split IVP and WKB arrays, consumed by
# the plotting functions further down
sin_cos_dict = {}
sin_cos_dict["p"] = {"ivp":p_IVP, "ivp_sin":p_sin, "ivp_cos":p_cos, "wkb":p_wkb_soln, "wkb_sin":p_wkb_soln_sin, "wkb_cos":p_wkb_soln_cos}
sin_cos_dict["u"] = {"ivp":u_IVP, "ivp_sin":u_sin, "ivp_cos":u_cos, "wkb":u_wkb_soln, "wkb_sin":u_wkb_soln_sin, "wkb_cos":u_wkb_soln_cos}
sin_cos_dict["v"] = {"ivp":v_IVP, "ivp_sin":v_sin, "ivp_cos":v_cos, "wkb":v_wkb_soln, "wkb_sin":v_wkb_soln_sin, "wkb_cos":v_wkb_soln_cos}
sin_cos_dict["w"] = {"ivp":w_IVP, "ivp_sin":w_sin, "ivp_cos":w_cos, "wkb":w_wkb_soln, "wkb_sin":w_wkb_soln_sin, "wkb_cos":w_wkb_soln_cos}
sin_cos_dict["bx"] = {"ivp":Bx_IVP, "ivp_sin":bx_sin, "ivp_cos":bx_cos, "wkb":bx_wkb_soln, "wkb_sin":bx_wkb_soln_sin, "wkb_cos":bx_wkb_soln_cos}
sin_cos_dict["by"] = {"ivp":By_IVP, "ivp_sin":by_sin, "ivp_cos":by_cos, "wkb":by_wkb_soln, "wkb_sin":by_wkb_soln_sin, "wkb_cos":by_wkb_soln_cos}

In [None]:
om = 1

def crit_lat_ideal(kz,Z):
    """Ideal-MHD critical latitudes for vertical wavenumber ``kz`` at height ``Z``.

    These are the values of x where ``om**2 - Gamma**2 * vAz**2 * kz**2 == 0``.
    They are computed as the solutions of ``cos(kb*x) = ±arg`` with
    ``arg = exp(kb*Z) * om / (Re(kz) * Gamma)``.

    Parameters
    ----------
    kz : complex
        Vertical wavenumber. Only its real part is used; NaN means no mode here.
    Z : float
        Height.

    Returns
    -------
    numpy.ndarray, shape (4,)
        The four roots in ascending order (NaNs sort last). Every entry is NaN
        when ``kz`` is NaN or when ``|arg| > 1`` (no real solution).

    Reads the module-level ``kb``, ``Gamma``, ``Lx`` and ``om``.
    """
    if np.isnan(kz):
        return np.nan*np.ones(4)
    arg = np.exp(kb*Z) * om/(kz.real * Gamma)
    # arccos yields NaN for |arg| > 1; that NaN is the intended "no root" value
    with np.errstate(invalid='ignore'):
        acos_minus = np.arccos(-arg)
        acos_plus = np.arccos(arg)
    x_list = np.array([-acos_minus/kb, acos_minus/kb, -acos_plus/kb, acos_plus/kb])
    # Shift roots below 0 up, and roots above Lx down, by one period 2*pi/kb
    period = 2*np.pi/kb
    x_list = np.where(x_list < 0, x_list + period, np.where(x_list > Lx, x_list - period, x_list))
    return np.sort(x_list)

# Ideal-MHD critical latitudes of the SM-AW-1 mode at every height row.
# NOTE(review): heights come from Z_EVP while kz_stack rows are plotted against
# Z_IVP elsewhere — presumably the same grid; confirm.
smaw1_kz = kz_stack[:, mode_ids["SM-AW-1"]]
smaw1_crit_lats = np.zeros((NZ,4))
for row in range(NZ):
    smaw1_crit_lats[row, :] = crit_lat_ideal(smaw1_kz[row], Z_EVP[row])

In [None]:
label_dict = {"p":"p", "u":"u", "v":"v", "w":"w", "bx":"b_x", "by":"b_y"}

def plot_comp_sincos(sin_cos_dict, field, plt_strd=1, savefig=False, shading='nearest', rasterized=False, dpi=300, ext='pdf'):
    """Plot IVP vs WKB maps of ``field`` together with their cos/sin parity parts.

    Eight columns: full IVP and WKB maps, a hidden spacer, one parity pair, a
    hidden spacer, the other parity pair. For ``u`` and ``bx`` the cos and sin
    pairs swap columns. Columns 3-4 therefore always show the part built from
    the IGW-0/SM-0/AW-0 WKB modes (with their turning points), and columns 6-7
    the part built from the IGW-1/SM-1/SM-AW-1/AW-1 modes.

    Parameters
    ----------
    sin_cos_dict : dict
        ``sin_cos_dict[field]`` must provide "ivp", "ivp_cos", "ivp_sin", "wkb",
        "wkb_cos" and "wkb_sin" arrays indexed ``[x, Z]``.
    field : str
        Key into ``sin_cos_dict`` and ``label_dict``.
    plt_strd : int
        Stride used to subsample the images.
    savefig : bool
        If True, save the figure under ``wkb_fig_path``.
    shading : str
        Only used in the saved file name (maps are drawn with ``imshow``).
    rasterized : bool
        Currently unused; kept for call-site compatibility.
    dpi : int
        Resolution passed to ``savefig`` (also part of the file name).
    ext : str
        File extension / format of the saved figure.
    """
    gridspec = dict(hspace=0.0, width_ratios=[1, 1, 0.4, 1, 1, 0.4, 1, 1])
    fig, axs = plt.subplots(1,8,figsize = (8,5),gridspec_kw = gridspec)
    axs[2].set_visible(False)
    axs[5].set_visible(False)
    title_fontsize = 12
    # Symmetric colour limit shared by all six maps
    vmin = np.max(np.abs(sin_cos_dict[field]["ivp"].real))

    q_IVP = sin_cos_dict[field]["ivp"]
    q_cos = sin_cos_dict[field]["ivp_cos"]
    q_sin = sin_cos_dict[field]["ivp_sin"]
    q_wkb_soln = sin_cos_dict[field]["wkb"]
    q_wkb_soln_cos = sin_cos_dict[field]["wkb_cos"]
    q_wkb_soln_sin = sin_cos_dict[field]["wkb_sin"]

    # u and bx have the opposite parity, so their cos/sin pairs swap columns
    switch_axes = 1 if field in ["u", "bx"] else 0
    cos_cols = (3+3*switch_axes, 4+3*switch_axes)
    sin_cols = (6-3*switch_axes, 7-3*switch_axes)

    extent = [0,np.max(x_IVP),0,np.max(Z_IVP)]
    def show(ax, q):
        # NOTE(review): a (min, max) tuple is passed as vmin; this relies on
        # matplotlib's set_clim unpacking it when vmax is None — confirm.
        ax.imshow(q.T.real[::plt_strd,::plt_strd],cmap='RdBu_r',vmin=(-vmin,vmin),extent=extent,origin="lower")

    lbl = label_dict[field]
    show(axs[0], q_IVP)
    show(axs[1], q_wkb_soln)
    axs[0].set_title(f"${lbl}_{{\\text{{IVP}}}}$")
    axs[1].set_title(f"${lbl}_{{\\text{{WKB}}}}$")

    show(axs[cos_cols[0]], q_cos)
    show(axs[cos_cols[1]], q_wkb_soln_cos)
    axs[cos_cols[0]].set_title(f"${lbl}_{{\\text{{IVP}}}}$ (cos)")
    axs[cos_cols[1]].set_title(f"${lbl}_{{\\text{{WKB}}}}$ (cos)")

    show(axs[sin_cols[0]], q_sin)
    show(axs[sin_cols[1]], q_wkb_soln_sin)
    axs[sin_cols[0]].set_title(f"${lbl}_{{\\text{{IVP}}}}$ (sin)")
    axs[sin_cols[1]].set_title(f"${lbl}_{{\\text{{WKB}}}}$ (sin)")

    # Plot critical latitudes
    if field == "v":
        # Heights between the two turning points of the evan-SM-0 branch
        sm0_tp_ids = turning_pt_ids_stack[full_branch_ids["evan-SM-0"]]
        Zc = Z_IVP[sm0_tp_ids[0]:sm0_tp_ids[1]+1]
        for xc in [0,0.5,1]:
            for col in cos_cols:
                axs[col].plot(xc+0*Zc,Zc,color='k',linestyle='--',lw=0.75)
        for col in sin_cols:
            axs[col].plot(smaw1_crit_lats,Z_IVP,color='k',lw=0.75)

    panel_label = ['($a$)','($b$)','.','($c$)','($d$)','.','($e$)','($f$)']
    lw = 2
    for k in [0,1,3,4,6,7]:
        ax = axs[k]
        if k in [1,4,7]:
            ax.set_yticks([])
        else:
            ax.set_ylabel('$z/L$', labelpad=-5)

        # Panel labels
        ax.text(0.1, 1.11, panel_label[k], transform=ax.transAxes, fontsize=14, va='top', ha='right')

        # Forcing height
        ax.axhline(params_IVP['Z0'],linestyle='dotted',color='k',lw=1.5)

        # Turning points and Alfven boundary intersection
        if k in [0,1,3,4]:
            ax.axhline(turning_pt_stack[full_branch_ids["evan-IGW-0"]][0],color=color_dict["IGW-0"],linestyle=linestyle_dict["IGW-0"])
            ax.axhline(turning_pt_stack[full_branch_ids["evan-SM-0"]][1],color=color_dict["AW-0"],linestyle=linestyle_dict["AW-0"],lw=lw)
        if k in [0,1,6,7]:
            ax.axhline(turning_pt_stack[full_branch_ids["evan-IGW-1"]][0],color=color_dict["IGW-1"],linestyle=linestyle_dict["IGW-1"])
            ax.axhline(turning_pt_stack[full_branch_ids["evan-SM-1"]][1],color=color_dict["AW-1"],linestyle=linestyle_dict["AW-1"],lw=lw)
            ax.axhline(alfven_intsc_Z,color=color_dict["SM-AW-1"],linestyle=linestyle_dict["SM-AW-1"])

        # Damping layers
        ax.fill_between([0,1],[params_IVP['s'],params_IVP['s']], facecolor="none", hatch="//////", edgecolor="k", linewidth=0.5, zorder=10)
        ax.fill_between([0,1],[LZ-params_IVP['s'],LZ-params_IVP['s']],[LZ,LZ], facecolor="none", hatch="//////", edgecolor="k", linewidth=0.5, zorder=10)
        ax.set_ylim(0,LZ)
        ax.set_xlim(0,1)
        ax.set_xticks([0,0.5,1],['0','0.5','1'])
        ax.set_xlabel("$x/L$")
        ax.set_aspect(4.5 * 1/LZ)

    # Legend: NaN proxy lines on the last panel, placed below the whole figure
    leg_ax = axs[7]
    leg_ax.axhline(np.nan,color=color_dict["IGW-0"],linestyle=linestyle_dict["IGW-0"],label="IGW-0 $\\rightarrow$ SM-0")
    leg_ax.axhline(np.nan,color=color_dict["AW-0"],linestyle=linestyle_dict["AW-0"],lw=lw,label="SM-0 $\\rightarrow$ AW-0")
    leg_ax.axhline(np.nan,color=color_dict["IGW-1"],linestyle=linestyle_dict["IGW-1"],label="IGW-1 $\\rightarrow$ SM-1")
    leg_ax.axhline(np.nan,color=color_dict["SM-AW-1"],linestyle=linestyle_dict["SM-AW-1"],label="SM-1 $\\rightarrow$ SM-AW-1")
    leg_ax.axhline(np.nan,color=color_dict["AW-1"],linestyle=linestyle_dict["AW-1"],lw=lw,label="SM-AW-1 $\\rightarrow$ AW-1")
    leg_ax.axhline(np.nan,color='k',linestyle='dotted',lw=1.5,label="driving layer")
    if field == "v":
        leg_ax.axhline(np.nan,color='k',linestyle='--',lw=0.75,label="crit. lat. for  $k_z = k_{A B}$")
        leg_ax.axhline(np.nan,color='k',linestyle='-',lw=0.75,label="crit. lat. for SM-AW-1")
        ncols=4
        columnspacing=1.5
        handlelength=1.2
    else:
        ncols=3
        columnspacing=5
        handlelength=2
    leg_ax.legend(loc="center",bbox_to_anchor = (-3,-0.225), ncols=ncols, framealpha=0, columnspacing=columnspacing, borderpad=0.75, handlelength=handlelength)

    if savefig:
        save_path = wkb_fig_path.joinpath(f"{sim_name}_{field}_wkb_{shading}_dpi{dpi}.{ext}")
        plt.savefig(save_path,dpi=dpi,bbox_inches='tight')

In [None]:
plot_comp_sincos(sin_cos_dict, "p", savefig=True, dpi=200, plt_strd=1, shading='nearest', rasterized=True)

In [None]:
plot_comp_sincos(sin_cos_dict, "u", savefig=True, dpi=200, plt_strd=1, shading='nearest', rasterized=True)

In [None]:
plot_comp_sincos(sin_cos_dict, "v", savefig=True, dpi=200, plt_strd=1, shading='nearest', rasterized=True)

# Additional plots

In [None]:
# Load axisymmetric evp: classified eigenvalue-problem results whose file name
# encodes ky = 0 and Gamma = 0.1.
axisym_classified_evp_name = "evp_LZ=2p500e-01_ky=0p00e+00pi_Gamma1p00e-01_eta0p0e+00_96_512_evp_LZ=2p500e-01_ky=0p00e+00pi_Gamma1p00e-01_eta0p0e+00_128_512_classified.hdf5"
save_name = str(data_path.joinpath(axisym_classified_evp_name))

# Read the 'params' group and every mode sub-group of 'l=1' into plain dicts
axisym_params = {}
axisym_EVP_dict = {}
with h5py.File(save_name,'r') as f:
    params_grp = f['params']
    axisym_params.update({key: params_grp[key][()] for key in params_grp.keys()})
    modes_grp = f['l=1']
    for key in modes_grp.keys():
        mode_grp = modes_grp[key]
        axisym_EVP_dict[key] = {subkey: mode_grp[subkey][()] for subkey in mode_grp.keys()}

NZ_axisym_EVP = axisym_params['NZ']
LZ_axisym_EVP = axisym_params['LZ']
Z_axisym_EVP = np.linspace(0,LZ_axisym_EVP,NZ_axisym_EVP)
# Rename the IGW evanescent branch to 'evan-1' and drop the SM evanescent one
axisym_EVP_dict['evan-1'] = axisym_EVP_dict.pop('IGW-evan-1')
axisym_EVP_dict.pop('SM-evan-1',None)

axisym_EVP_mode_names = axisym_EVP_dict.keys()
axisym_mode_ids = {nm:idx for idx,nm in enumerate(axisym_EVP_mode_names)}
N_axisym_modes = len(axisym_EVP_mode_names)
# kz per mode on the uniform Z_axisym_EVP grid; NaN where a mode is absent
kz_stack_axisym = np.full((NZ_axisym_EVP,N_axisym_modes), np.nan + 1j*np.nan, dtype=np.complex128)

# Place each branch's kz between the grid points nearest its first and last Z
for nm, col in axisym_mode_ids.items():
    Z_each = axisym_EVP_dict[nm]['Z']
    start_idx = np.argmin(np.abs(Z_axisym_EVP - Z_each[0]))
    end_idx = np.argmin(np.abs(Z_axisym_EVP - Z_each[-1]))
    kz_stack_axisym[start_idx:end_idx+1, col] = axisym_EVP_dict[nm]['kz']


In [None]:
def get_kz_from_Z(mode_nm,Z):
    """Return kz of mode ``mode_nm`` at the Z_IVP grid point nearest to ``Z``."""
    nearest_idx = np.abs(Z_IVP - Z).argmin()
    return kz_stack[nearest_idx, mode_ids[mode_nm]]

def get_kz_from_Z_axisym(mode_nm,Z):
    """Return kz of axisymmetric mode ``mode_nm`` at the Z_axisym_EVP point nearest to ``Z``."""
    nearest_idx = np.abs(Z_axisym_EVP - Z).argmin()
    return kz_stack_axisym[nearest_idx, axisym_mode_ids[mode_nm]]

In [None]:
# Figure: Re{k_z}/Fr of every mode vs height for the axisymmetric (k_y = 0)
# eigenproblem (panel a), and Re/Im{k_z}/Fr of the modes of this run (panels b,
# c), with the Alfven continuum shaded. Saved as a PDF under wkb_fig_path.
fig,axs = plt.subplot_mosaic([['axisym','.','real','imag']],figsize=(8,3),width_ratios=[2,0.5,2,1])
lw = 1.5
for nm in axisym_mode_ids.keys():
    label = nm.replace('evan','evan.')
    axs['axisym'].plot(1/Fr*kz_stack_axisym[...,axisym_mode_ids[nm]].real,Z_axisym_EVP,color=color_dict[nm],linestyle=linestyle_dict[nm],linewidth=lw,label=label)

# SM-AW-1 gets an offset dash pattern here instead of its linestyle_dict entry
for nm in ['IGW-0', 'evan-0', 'IGW-1', 'evan-1', 'SM-0', 'AW-0', 'SM-1', 'SM-AW-1', 'AW-1']:
    if nm == 'SM-AW-1':
        linestyle = (3, (2.5, 1.5))
    else:
        linestyle = linestyle_dict[nm]
    axs['real'].plot(1/Fr*kz_stack[...,mode_ids[nm]].real,Z_IVP,color=color_dict[nm],linestyle=linestyle,linewidth=lw)
    axs['imag'].plot(1/Fr*kz_stack[...,mode_ids[nm]].imag,Z_IVP,color=color_dict[nm],linestyle=linestyle,linewidth=lw)

# Plot Alfven continuum
axs['axisym'].fill_betweenx(Z_EVP,1/Fr*alfven_bdry,1/Fr*100*np.ones(NZ),color='lightgray',zorder=0)
axs['real'].fill_betweenx(Z_EVP,1/Fr*alfven_bdry,1/Fr*100*np.ones(NZ),color='lightgray',zorder=0)

# Plot pure IGW wave kz
kx = 2*np.pi
kz_pure_IGW = np.sqrt(kx**2+ky**2)
axs['real'].axvline(1/Fr*kz_pure_IGW,ymin=0.,ymax=1,color='dimgray',linewidth=1,linestyle='dashdot',zorder=0,alpha=0.5)
kz_pure_IGW_axisym = np.sqrt(kx**2)
axs['axisym'].axvline(1/Fr*kz_pure_IGW_axisym,ymin=0.,ymax=1,color='dimgray',linewidth=1,linestyle='dashdot',zorder=0,alpha=0.4)

# Plot SM-1 cutoff height for ky=0
ky0_cutoff_Z = np.log(6*np.pi*Gamma)/kb
for axl in ['axisym','real']:
    axs[axl].axhline(ky0_cutoff_Z,xmin=0.7,color='dimgray',linewidth=1,linestyle='--')
    axs[axl].set_xlabel(r"$\Re\{k_z\}/L^{-1}$")
    axs[axl].set_ylabel("$z/L$")

# Arrows showing how kz evolves along selected modes between two heights
arrowstyle="-|>, head_width=0.2, head_length=0.65"
arrow_Z_dict = {"IGW-0": [0.11,0.08], "IGW-1":[0.11,0.08],"SM-1":[0.07,0.08],"SM-AW-1":[0.07,0.08],"SM-0":[0.1,0.12]}
for mode_nm in arrow_Z_dict.keys():
    Zstart,Zend = arrow_Z_dict[mode_nm]
    if mode_nm in axisym_mode_ids.keys():
        kzstart = get_kz_from_Z_axisym(mode_nm,Zstart).real
        kzend = get_kz_from_Z_axisym(mode_nm,Zend).real
        axs['axisym'].annotate("", xytext=(1/Fr*kzstart, Zstart), xy=(1/Fr*kzend, Zend),
                    arrowprops=dict(arrowstyle=arrowstyle,color=color_dict[mode_nm]))
    if mode_nm in mode_ids.keys():
        kzstart = get_kz_from_Z(mode_nm,Zstart).real
        kzend = get_kz_from_Z(mode_nm,Zend).real
        axs['real'].annotate("", xytext=(1/Fr*kzstart, Zstart), xy=(1/Fr*kzend, Zend),
                    arrowprops=dict(arrowstyle=arrowstyle,color=color_dict[mode_nm]))

# Legend: NaN proxy lines split into two groups of mode handles plus reference lines
handles_1 = []
handles_2 = []
for nm in ['IGW-0', 'SM-1', 'evan-0', 'SM-AW-1', 'IGW-1', 'AW-1', 'evan-1', 'SM-0', 'AW-0']:
    label = nm.replace('evan','evan.')
    if nm == 'SM-AW-1':
        linestyle = (3, (2.5, 1.5))
    else:
        linestyle = linestyle_dict[nm]
    handle = axs['real'].axvline(np.nan,color=color_dict[nm],linestyle=linestyle,linewidth=lw,label=label)
    if nm in ['IGW-0', 'evan-0', 'IGW-1', 'SM-1', 'SM-AW-1', 'AW-1']:
        handles_1.append(handle)
    else:
        handles_2.append(handle)

handles_3 = []
handle = axs['real'].axvline(np.nan,ymin=0.,ymax=1,color='dimgray',linewidth=1,linestyle='dashdot',zorder=0,alpha=0.5,label="pure IGW")
handles_3.append(handle)
handle = axs['real'].axhline(np.nan,color='dimgray',linewidth=1,linestyle='--',label="SM cutoff height ($k_y = 0$)")
handles_3.append(handle)

# Labels, limits. Each legend() call replaces the axes legend, so the first two
# are re-attached with add_artist; the last one stays as the axes legend.
legend_1 = axs['real'].legend(handles=handles_1,loc="center",bbox_to_anchor=(-0.75+0.05,-0.35),ncols=3,framealpha=0,columnspacing=2.,labelspacing=0.75)
axs['real'].add_artist(legend_1)
legend_2 = axs['real'].legend(handles=handles_2,handletextpad=1,loc="center",bbox_to_anchor=(0.85+0.05,-0.30),ncols=3,framealpha=0,columnspacing=2.5)
axs['real'].add_artist(legend_2)
legend_3 = axs['real'].legend(handles=handles_3,handletextpad=1,loc="center",bbox_to_anchor=(0.735+0.115+0.05,-0.4025),ncols=3,framealpha=0,columnspacing=1.4)

axs['imag'].set_xlabel(r"$\Im\{k_z\}/L^{-1}$")
axs['imag'].set_yticklabels([])
axs['axisym'].set_xlim(0.8*1/Fr*kz_pure_IGW_axisym,1/Fr*35)
axs['real'].set_xlim(0.8*1/Fr*kz_pure_IGW_axisym,1/Fr*35)
axs['imag'].set_xlim(-1/Fr*9.4,1/Fr*9.4)

# Titles
axs['axisym'].set_title("$k_y = 0$")
axs['real'].set_title(f"$k_y = {ky/np.pi:.0f}\\pi/L$")
axs['imag'].set_title(f"$k_y = {ky/np.pi:.0f}\\pi/L$")

panel_label = ['($a$)','($b$)','($c$)']

for k,axl in enumerate(['axisym','real','imag']):
    axs[axl].set_ylim(0,LZ)
    axs[axl].text(0.1, 1.11, panel_label[k], transform=axs[axl].transAxes, fontsize=14, va='top', ha='right')

# Turn off latex to use custom font for annotation, then turn it back on
# (in a finally so a drawing error cannot leave LaTeX disabled)
plotting_setup.usetex(False)
try:
    for axl in ['axisym','real']:
        axs[axl].text(650,0.025,"ALFVÉN WAVES",font=sans_name,rotation=20,fontsize=10,color='dimgray')
finally:
    plotting_setup.usetex(True)

plt.subplots_adjust(wspace=0.1)
save_path = wkb_fig_path.joinpath(f"kz_ky={ky/np.pi:.1f}pi_Gamma={Gamma:.2f}".replace('.','p')+".pdf")
plt.savefig(save_path,bbox_inches='tight')

In [None]:
# Plot IVP vs. WKB vertical-profile comparisons at several latitudes
def plot_profiles(sin_cos_dict, field_list, x_positions, savefig=False):

    title_fontsize = 11

    N_profs = len(x_positions)

    width_ratios = []
    for i in range(len(field_list)):
        if i == 0:
            width_ratios = width_ratios + N_profs*[1]
        else:
            width_ratios = width_ratios + [0.4] + N_profs*[1]
    
    N_subplots = len(width_ratios)
    gridspec = dict(hspace=0.0, width_ratios=width_ratios)

    fig, axs = plt.subplots(1,N_subplots,figsize=(6.5/6.8 * np.sum(width_ratios),5.75),gridspec_kw=gridspec)
    panel_label = [f"(${chr(i)}$)" for i in range(ord('a'),ord('z')+1)]
    shift = 0
    count = 0
    for j,field in enumerate(field_list):
        if j > 0:
            axs[shift-1].set_visible(False)
        q_IVP = sin_cos_dict[field]["ivp"]
        q_wkb_soln = sin_cos_dict[field]["wkb"]

        for i,x_pos in enumerate(x_positions):
            x_pos_idx = np.argmin(np.abs(x_IVP - x_pos))
            ax = axs[shift+i]
            ax.text(0.1, 1.12, panel_label[count], transform=ax.transAxes, fontsize=14, va='top', ha='right')
            ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)

            if i == 0:
                ax.set_ylabel('$z/L$', fontsize=title_fontsize, labelpad=-5)
            else:
                ax.set_yticks([])

            # Plot profiles
            ax.plot(q_IVP[x_pos_idx,:].real,Z_IVP,linestyle='--',color='navy',lw=1.5,label='IVP simulation')
            ax.plot(q_wkb_soln[x_pos_idx,:].real,Z_IVP,linestyle='-',color='darkred',lw=0.75,label='WKB theory')

            # Forcing height
            lw = 1.5
            ax.axhline(params_IVP['Z0'],linestyle='dotted',color='k',lw=lw)

            # Turning points and Alfven boundary intersection
            if ((((x_pos==0)|(x_pos==0.5)|(x_pos==1))&(field not in ['u','bx']))) | (((x_pos==0.25)|(x_pos==0.75))&(field in ['u','bx'])):
                ax.axhline(turning_pt_stack[full_branch_ids[f"evan-{"IGW-0"}"]][0],color=color_dict["IGW-0"],linestyle=linestyle_dict["IGW-0"],lw=lw)
                ax.axhline(turning_pt_stack[full_branch_ids[f"evan-SM-0"]][1],color=color_dict["AW-0"],linestyle=linestyle_dict["AW-0"],lw=lw)
            elif ((((x_pos==0)|(x_pos==0.5)|(x_pos==1))&(field in ['u','bx']))) | (((x_pos==0.25)|(x_pos==0.75))&(field not in ['u','bx'])):
                ax.axhline(turning_pt_stack[full_branch_ids[f"evan-{"IGW-1"}"]][0],color=color_dict["IGW-1"],linestyle=linestyle_dict["IGW-1"],lw=lw)
                ax.axhline(turning_pt_stack[full_branch_ids[f"evan-SM-1"]][1],color=color_dict["AW-1"],linestyle=linestyle_dict["AW-1"],lw=lw)
                ax.axhline(alfven_intsc_Z,color=color_dict["SM-AW-1"],linestyle=linestyle_dict["SM-AW-1"],lw=lw)
            else:
                ax.axhline(turning_pt_stack[full_branch_ids[f"evan-{"IGW-0"}"]][0],color=color_dict["IGW-0"],linestyle=linestyle_dict["IGW-0"],lw=lw)
                ax.axhline(turning_pt_stack[full_branch_ids[f"evan-SM-0"]][1],color=color_dict["AW-0"],linestyle=linestyle_dict["AW-0"],lw=lw)
                ax.axhline(turning_pt_stack[full_branch_ids[f"evan-{"IGW-1"}"]][0],color=color_dict["IGW-1"],linestyle=linestyle_dict["IGW-1"],lw=lw)
                ax.axhline(turning_pt_stack[full_branch_ids[f"evan-SM-1"]][1],color=color_dict["AW-1"],linestyle=linestyle_dict["AW-1"],lw=lw)
                ax.axhline(alfven_intsc_Z,color=color_dict["SM-AW-1"],linestyle=linestyle_dict["SM-AW-1"],lw=lw)

            ax.set_title(f'${label_dict[field]}$', fontsize=title_fontsize, pad=3.5)
            ax.set_xlabel(f'$x = {x_IVP[x_pos_idx]:.2f} L$', fontsize=title_fontsize)
            ax.tick_params(axis='x', which='major', pad=0)
            max_ivp = np.max(np.abs(q_IVP[x_pos_idx,:].real))
            xlims = (-1.2*max_ivp,1.2*max_ivp)
            ax.fill_between([xlims[1],xlims[0]],[params_IVP['s'],params_IVP['s']], facecolor="none", hatch="//////", edgecolor="k", linewidth=0.5, zorder=10)
            ax.fill_between([xlims[1],xlims[0]],[LZ-params_IVP['s'],LZ-params_IVP['s']],[LZ,LZ], facecolor="none", hatch="//////", edgecolor="k", linewidth=0.5, zorder=10)
            ax.set_ylim(0,LZ)
            ax.set_xlim(xlims)
            ax.set_aspect(5.55 * (xlims[1]-xlims[0])/(params_IVP['LZ']))

            count += 1
        
        shift+= N_profs + 1

    # Labels
    ax.axhline(np.nan,color=color_dict["IGW-0"],linestyle=linestyle_dict["IGW-0"],lw=lw,label="IGW-0 $\\rightarrow$ SM-0")
    ax.axhline(np.nan,color=color_dict["AW-0"],linestyle=linestyle_dict["AW-0"],lw=lw,label="SM-0 $\\rightarrow$ AW-0")
    ax.axhline(np.nan,color=color_dict["IGW-1"],linestyle=linestyle_dict["IGW-1"],lw=lw,label="IGW-1 $\\rightarrow$ SM-1")
    ax.axhline(np.nan,color=color_dict["SM-AW-1"],linestyle=linestyle_dict["SM-AW-1"],lw=lw,label="SM-1 $\\rightarrow$ SM-AW-1")
    ax.axhline(np.nan,color=color_dict["AW-1"],linestyle=linestyle_dict["AW-1"],lw=lw,label="SM-AW-1 $\\rightarrow$ AW-1")
    ax.axhline(np.nan,linestyle='dotted',color='k',lw=lw,label='driving layer')
    ax.legend(loc="center",bbox_to_anchor = (-(np.sum(width_ratios)-1/2)/2,-0.175), ncols=4, framealpha=0)

    if savefig:
        save_name = sim_name + "_"
        save_name = save_name + "_".join(field_list) 
        save_name = save_name+"_"+ "_".join([f"{xp:.2f}".replace(".","p") for xp in x_positions])
        save_name = save_name+".pdf"
        save_path = wkb_fig_path.joinpath(save_name)
        plt.savefig(save_path,bbox_inches='tight')


In [None]:
# u profiles at the 21 positions x = 0, 0.05, ..., 1
plot_profiles(sin_cos_dict, ["u"], np.arange(0, 1.05, 0.05), savefig=True)

In [None]:
# v profiles at the 21 positions x = 0, 0.05, ..., 1
plot_profiles(sin_cos_dict, ["v"], np.arange(0, 1.05, 0.05), savefig=True)

In [None]:
# u and v side by side at x_fit, 0.5, 0.75 and 0.95
plot_profiles(sin_cos_dict, ["u", "v"], [x_fit, 0.5, 0.75, 0.95], savefig=True)

# Integrate group velocities

In [None]:
# Trace one characteristic per full branch (0 and 1) by forward-Euler
# integration of dZ/dt = Fr * cg(Z), starting at the forcing height Z0 in the
# IGW mode. When a step would reach or cross the next hand-over height in
# Z_transition (in the current direction of travel), the characteristic switches
# to the next mode in mode_char_list and the same step is retried. Integration
# stops at the end of t_arr or after the last hand-over; entries never filled
# keep their NaN / 'none' initial values.
Z0 = params_IVP['Z0']
t_arr = timeseries_dict['t']
kz_char_stack = np.nan*np.ones((2,len(t_arr)),dtype=np.complex128)

Z_char_stack = np.nan*np.ones((2,len(t_arr)))
Z_char_stack[:,0] = Z0

# Modes each characteristic passes through, and (one per mode) the height at
# which that mode hands over to the next.
mode_char_list = [['IGW-0','SM-0'],['IGW-1','SM-1','SM-AW-1']]
mode_char_stack = np.empty((2,len(t_arr)),dtype='<U10')
mode_char_stack[...] = 'none'
mode_char_stack[0,0] = 'IGW-0'
mode_char_stack[1,0] = 'IGW-1'
t_transition = [[],[]]
Z_transition = [
    [turning_pt_stack[full_branch_ids['evan-IGW-0']][0], turning_pt_stack[full_branch_ids['evan-SM-0']][1]],
    [turning_pt_stack[full_branch_ids['evan-IGW-1']][0], alfven_intsc_Z, turning_pt_stack[full_branch_ids['evan-SM-1']][1]]
    ]

for full_branch_idx in range(2):
    kz_profile = kz_full_branches_stack[...,full_branch_ids[f"evan-IGW-{full_branch_idx}"]]
    kz_char_stack[full_branch_idx,0] = np.interp(Z0,Z_IVP,kz_profile)

    mode_list = mode_char_list[full_branch_idx]
    mode_list_idx = 0
    n = 0
    while n < len(t_arr)-1 and mode_list_idx < len(mode_list):
        Z_char_n = Z_char_stack[full_branch_idx,n]

        # IGW modes follow the evan-IGW branch; SM / SM-AW modes the evan-SM branch.
        branch_family = "IGW" if "IGW" in mode_list[mode_list_idx] else "SM"
        full_branch_key = f"evan-{branch_family}-{full_branch_idx}"

        cg_profile = cg_full_branches_stack[...,full_branch_ids[full_branch_key]].real
        cg_n = np.interp(Z_char_n,Z_IVP,cg_profile)  # already real
        current_direction = np.sign(cg_n)

        Z_char_np1 = Z_char_n + Fr * cg_n * (t_arr[n+1] - t_arr[n])

        if current_direction*(Z_char_np1 - Z_transition[full_branch_idx][mode_list_idx]) >= 0:
            # Step would reach/cross the hand-over height: switch to the next
            # mode and retry this time step (n is deliberately not advanced).
            mode_list_idx += 1
            t_transition[full_branch_idx].append(t_arr[n+1])
        else:
            Z_char_stack[full_branch_idx,n+1] = Z_char_np1
            mode_char_stack[full_branch_idx,n+1] = mode_list[mode_list_idx]

            kz_profile = kz_full_branches_stack[...,full_branch_ids[full_branch_key]]
            kz_char_stack[full_branch_idx,n+1] = np.interp(Z_char_np1,Z_IVP,kz_profile)

            n += 1

import pickle

# Save the WKB data and traced characteristics next to the simulation's
# last-snapshot file. NOTE: only unpickle this file from a trusted source;
# pickle.load can execute arbitrary code.
wkb_save_path = file_dir.parent / f"{sim_name}-wkb.pickle"
wkb_dict = {'mode_ids':mode_ids,'Z':Z_IVP,'t':t_arr,'kz':kz_stack,
            'kz_char':kz_char_stack,'Z_char':Z_char_stack,
            'Z_transition':Z_transition,'t_transition':t_transition,
            'mode_char':mode_char_stack,'color':color_dict,'linestyle':linestyle_dict}
with open(wkb_save_path, 'wb') as pickle_file:
    pickle.dump(wkb_dict, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Left: traced characteristics in (Re kz, z); right: their height versus time.
# Points are coloured by the mode the characteristic is in.
fig, axs = plt.subplots(1,2)

for branch_idx in range(2):
    # Entries never reached by the integration hold the placeholder 'none'
    # (with NaN coordinates); map unknown keys to matplotlib's transparent
    # 'none' colour instead of raising KeyError.
    point_colors = [color_dict.get(modenm, 'none') for modenm in mode_char_stack[branch_idx]]
    axs[0].scatter(kz_char_stack[branch_idx].real,Z_char_stack[branch_idx],c=point_colors,s=2)
    axs[1].scatter(t_arr,Z_char_stack[branch_idx],c=point_colors,s=2)

axs[0].set_xlabel(r"$\Re\{k_z\}$")
axs[1].set_xlabel(r"$t \omega$")

for ax in axs:
    ax.set_ylabel(r"$z/L$")
    ax.set_ylim(0,LZ)
    ax.axhline(Z0,color='k',linestyle='dotted')
plt.subplots_adjust(wspace=0.3)