In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 5_close_look_merger_sample.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Investigate a subset of the major merger sample that will be used to investigate interesting trends.
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## galpy
from galpy import orbit
from galpy import potential

## Project-specific
# Locate the project's src/ directory by walking up from the current working
# directory, then put it at the front of sys.path so that the tng_dfs package
# can be imported below.
src_path = 'src/'
while not os.path.exists(src_path):
    # Directory that was just searched for a src/ sub-directory. The last
    # component of realpath(src_path) is always 'src', so the stop condition
    # must be evaluated on its parent, otherwise the loop never terminates
    # when src/ cannot be found.
    search_dir = os.path.dirname(os.path.realpath(src_path))
    at_project_root = os.path.basename(search_dir) == 'tng-dfs'
    at_filesystem_root = os.path.dirname(search_dir) == search_dir
    if at_project_root or at_filesystem_root:
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import kinematics as pkin
from tng_dfs import plot as pplot
from tng_dfs import util as putil

### Notebook setup

%matplotlib inline
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Read the project configuration and unpack the entries used by this notebook
cdict = putil.load_config_to_dict()
keywords = [
    'DATA_DIR','MW_ANALOG_DIR','FIG_DIR_BASE','FITTING_DIR_BASE',
    'RO','VO','ZO','LITTLE_H','MW_MASS_RANGE',
]
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
 ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict,keywords)

# Milky Way analog sample (plus the variables returned alongside it)
mwsubs,mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figures are written next to the notebook
local_fig_dir = './fig/'
os.makedirs(local_fig_dir,exist_ok=True)
# fig_dir = os.path.join(fig_dir_base, 'notebooks')
# os.makedirs(fig_dir,exist_ok=True)
# show_plots = False

# Merger tree data: the primaries and their major mergers
tree_primary_filename = os.path.join(
    mw_analog_dir, 'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(
    mw_analog_dir, 'major_mergers/tree_major_mergers.pkl')

with open(tree_primary_filename,'rb') as handle:
    tree_primaries = pickle.load(handle)
with open(tree_major_mergers_filename,'rb') as handle:
    tree_major_mergers = pickle.load(handle)

# Number of primaries in the tree data
n_mw = len(tree_primaries)

### Get some data

In [None]:
# Load the per-merger table for the chosen analysis version and list the
# fields it provides
analysis_version = 'v1.1'
analysis_dir = os.path.join(mw_analog_dir,'analysis',analysis_version)
merger_data_filename = os.path.join(analysis_dir,'merger_data.npy')
merger_data = np.load(merger_data_filename)
print(merger_data.dtype.names)

### Determine the GS/E case

In [None]:
# Four-panel overview of the strongly radially anisotropic mergers (the
# candidate GS/E analogs). Every panel has stellar mass on the x-axis and each
# point is tagged with '<z0_sid>-<merger_number>'.
fig = plt.figure(figsize=(16,4))
ax1,ax2,ax3,ax4 = fig.subplots(nrows=1, ncols=4)

# Per-merger columns; these names are also read by later cells
star_mass = merger_data['star_mass']
merger_snapnum = merger_data['merger_snapnum']
star_mass_ratio = merger_data['star_mass_ratio']
dm_mass_ratio = merger_data['dm_mass_ratio']

# Keep only mergers with anisotropy above 0.8
mask = merger_data['anisotropy'] > 0.8

# Masked quantities shared by all panels, computed once rather than per point
masked_star_mass = star_mass[mask]
point_labels = [str(sid)+'-'+str(num) for sid,num in
    zip(merger_data['z0_sid'][mask], merger_data['merger_number'][mask])]

# (axis, y values, y label) for each panel
panels = [
    (ax1, putil.snapshot_to_redshift(merger_snapnum[mask]), r'$z$'),
    (ax2, 1/star_mass_ratio[mask], r'$M_{\star,p}/M_{\star,s}$'),
    (ax3, merger_data['anisotropy'][mask], r'$\beta$'),
    (ax4, 1/dm_mass_ratio[mask], r'$M_{\rm DM,p}/M_{\rm DM,s}$'),
]
for ax, y_values, y_label in panels:
    ax.scatter(masked_star_mass, y_values, s=10, c='Black', alpha=1.0)
    ax.set_xscale('log')
    ax.set_xlabel(r'$M_{\star}$')
    ax.set_ylabel(y_label)
    for x_pt, y_pt, point_label in zip(masked_star_mass, y_values,
        point_labels):
        ax.annotate(point_label, (x_pt, y_pt), fontsize=8)

# Record the selection on the first panel
ax1.text(0.05, 0.05, r'$\beta > 0.8$', transform=ax1.transAxes)

fig.tight_layout()
# plt.show() rather than fig.show(): the latter only emits a warning with the
# non-GUI inline backend selected at the top of the notebook
plt.show()

Good candidates for GS/E analogs are:
- 552414-1
- 522530-3

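So we choose 552414, merger #1. It has a stellar mass of about $2\times10^{9}$, merged at about $z=0.75$, and has a mass ratio of 12:1. (Note: the cells below select 522530, merger #3, as the GS/E analog; 552414-1 only appears in a commented-out line.)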

### Get retrograde case

In [None]:
# Three-panel overview of the retrograde mergers (krot < -0.5). Every panel
# has stellar mass on the x-axis and each point is tagged with
# '<z0_sid>-<merger_number>'.
fig = plt.figure(figsize=(15,4))
ax1,ax2,ax3 = fig.subplots(nrows=1, ncols=3)

# Keep only mergers with krot below -0.5
mask = merger_data['krot'] < -0.5

# Read the columns straight from merger_data so this cell does not depend on
# notebook-level variables that other cells reassign (e.g. star_mass)
retro_star_mass = merger_data['star_mass'][mask]
point_labels = [str(sid)+'-'+str(num) for sid,num in
    zip(merger_data['z0_sid'][mask], merger_data['merger_number'][mask])]

# (axis, y values, y label) for each panel
panels = [
    (ax1, putil.snapshot_to_redshift(merger_data['merger_snapnum'][mask]),
        r'$z$'),
    (ax2, 1/merger_data['star_mass_ratio'][mask],
        r'$M_{\star,p}/M_{\star,s}$'),
    (ax3, merger_data['krot'][mask], r'$k$'),
]
for ax, y_values, y_label in panels:
    ax.scatter(retro_star_mass, y_values, s=10, c='Black', alpha=1.0)
    ax.set_xscale('log')
    ax.set_xlabel(r'$M_{\star}$')
    ax.set_ylabel(y_label)
    for x_pt, y_pt, point_label in zip(retro_star_mass, y_values,
        point_labels):
        ax.annotate(point_label, (x_pt, y_pt), fontsize=8)

fig.tight_layout()
# plt.show() rather than fig.show(): the latter only emits a warning with the
# non-GUI inline backend selected at the top of the notebook
plt.show()

Good candidates for retrograde mergers are:
- 531910-8
- 518682-4

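Looks like the best merger is 531910-8. It has intermediate mass, $3\times10^{8}$, and a mass ratio of 4:1. (Note: the cells below select 518682, merger #4, as the Sequoia analog; 531910-8 only appears in a commented-out line.)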

### Get properties for both

In [None]:
def _print_merger_summary(title, z0_sid, merger_number):
    """Print the merger_data row and headline numbers for a single merger.

    All values are read directly from merger_data, so the output does not
    depend on notebook-level variables (such as star_mass) that other cells
    reassign.

    Args:
        title (str): Heading printed before the summary.
        z0_sid (int): Value of the 'z0_sid' field identifying the primary.
        merger_number (int): Value of the 'merger_number' field.

    Returns:
        indx (int): Index of the first matching row of merger_data.

    Raises:
        ValueError: If no row of merger_data matches the requested pair.
    """
    matches = np.flatnonzero(
        (merger_data['z0_sid'] == z0_sid) &
        (merger_data['merger_number'] == merger_number))
    if matches.size == 0:
        raise ValueError('No entry in merger_data for z0_sid='+str(z0_sid)+
            ', merger_number='+str(merger_number))
    indx = matches[0]

    print(title)
    print(merger_data.dtype.names)
    print(merger_data[indx])
    print('stellar mass', merger_data['star_mass'][indx])
    print('stellar mass ratio', 1/merger_data['star_mass_ratio'][indx])
    print('accretion redshift',
        putil.snapshot_to_redshift(merger_data['merger_snapnum'][indx]))
    return indx

gse_indx = _print_merger_summary('GS/E analog', 522530, 3)
seq_indx = _print_merger_summary('\nSequoia analog', 518682, 4)

### Get orbit samples for the close-look cases

In [None]:
# Close-look cases, given element-wise as (z=0 subfind ID, merger number)
# z0_sids = np.array( [552414,531910] )
# merger_numbers = np.array( [1,8] )
z0_sids = np.array( [522530,518682] )
merger_numbers = np.array( [3,4] )
n_cases = len(z0_sids)
assert n_cases == len(merger_numbers)


# Constant beta and N-body
verbose = True
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function')
analysis_version = 'v1.1'
analysis_dir = os.path.join(mw_analog_dir,'analysis',analysis_version)

# Get the DF samples (constant beta, Osipkov-Merritt, and the second
# Osipkov-Merritt variant)
sample_data_cb = np.load(os.path.join(analysis_dir,'sample_data_cb.npy'),
    allow_pickle=True)
sample_data_om = np.load(os.path.join(analysis_dir,'sample_data_om.npy'),
    allow_pickle=True)
sample_data_om2 = np.load(os.path.join(analysis_dir,'sample_data_om2.npy'),
    allow_pickle=True)

# Interpolated potential version
interpot_version = 'all_star_dm_enclosed_mass'

# Per-case outputs; entry k of every list belongs to case k
interpots = []  # interpolated potentials
ms = []         # stellar particle masses of the merger remnant
pe = []         # stellar particle potential energies of the merger remnant
o = []          # N-body orbits of the merger remnant
o_cb = []       # constant beta DF samples
o_om = []       # Osipkov-Merritt DF samples
o_om2 = []      # second Osipkov-Merritt variant DF samples

for k in range(n_cases):
    n_loaded_before = len(o)
    # Index of the requested merger within the primary's list (numbers are
    # 1-based)
    j = int(merger_numbers[k]) - 1

    for i in range(n_mw):

        # Get the primary, skipping all but the requested one
        primary = tree_primaries[i]
        z0_sid = primary.subfind_id[0]
        if z0_sid != z0_sids[k]: continue

        # Skip primaries that do not have the requested merger
        n_major = primary.n_major_mergers
        if not (0 <= j < n_major): continue

        primary_filename = primary.get_cutout_filename(mw_analog_dir,
            snapnum=primary.snapnum[0])
        co = pcutout.TNGCutout(primary_filename)
        co.center_and_rectify()
        pid = co.get_property('stars','ParticleIDs')

        # Load the interpolator for the sphericalized potential
        interpolator_filename = os.path.join(dens_fitting_dir,
            'spherical_interpolated_potential/',interpot_version,
            str(z0_sid),'interp_potential.pkl')
        with open(interpolator_filename,'rb') as handle:
            interpot = pickle.load(handle)

        if verbose:
            print('z=0 SID: ', z0_sid, ', merger number: ', j+1)

        # Select the star particles of the cutout that belong to the merger
        major_merger = primary.tree_major_mergers[j]
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        indx = np.where(np.isin(pid,upid))[0]
        orbs = co.get_orbs('stars')[indx]
        # Named merger_star_mass so that the notebook-level star_mass array
        # taken from merger_data is not overwritten by this cell
        merger_star_mass = co.get_masses('stars')[indx].to_value(apu.Msun)
        pot_energy = co.get_potential_energy('stars')[indx].to_value(
            apu.km**2/apu.s**2)

        # NOTE(review): the row mask is built from sample_data_om and applied
        # to all three sample tables, which assumes they share the same row
        # order -- confirm against the code that writes them.
        o_mask = (sample_data_om['z0_sid'] == z0_sid) &\
                 (sample_data_om['merger_number'] == j+1)
        if not np.any(o_mask):
            raise ValueError('No DF samples found for z0_sid='+str(z0_sid)+
                ', merger_number='+str(j+1))
        sample_cb = sample_data_cb[o_mask]['sample'][0]
        sample_om = sample_data_om[o_mask]['sample'][0]
        sample_om2 = sample_data_om2[o_mask]['sample'][0]

        interpots.append( interpot )
        o.append( orbs )
        o_cb.append( sample_cb )
        o_om.append( sample_om )
        o_om2.append( sample_om2 )
        pe.append( pot_energy )
        ms.append( merger_star_mass )

    # A case that is silently skipped would misalign the per-case lists with
    # the per-case settings used by the plotting cells
    if len(o) == n_loaded_before:
        raise ValueError('Could not find z0_sid='+str(z0_sids[k])+
            ', merger_number='+str(merger_numbers[k])+' in tree_primaries')
            

### Examine velocity dispersions of samples

In [None]:
columnwidth, textwidth = pplot.get_latex_columnwidth_textwidth_inches()

fig = plt.figure(figsize=(12,6))
axs = fig.subplots(nrows=2, ncols=4)

# Line styling, ordered as N-body, CB, OM, OM2
nbody_color = 'Black'
sample_colors = ['DodgerBlue','Red', 'DarkOrange']
colors = [nbody_color] + sample_colors
labels = ['N-body', 'CB', 'OM', 'OM2']
linestyles = ['solid']*4
lws = [3,1,1,1]

# Per-row (per-case) settings
panel_labels = ['GS/E analog', 'Sequoia analog']
panel_label_fs = 9
sv_lims = [[1,3], [1.5,2.5]]

# Number of bootstrap samples, velocity component suffixes, shared x label
n_bs = 10
v_suffixes = [r'r',r'\phi',r'\theta']
log_r_label = r'$\log(r/\mathrm{kpc})$'

for i, nbody_orbs in enumerate(o):

    # Radial bins are defined on the N-body particles and reused for every
    # DF sample; the inner edge is never below the softening length
    r_softening = putil.get_softening_length('stars', z=0, physical=True)
    nbody_r = nbody_orbs.r().to_value(apu.kpc)
    bin_edges, bin_cents, bin_n = pkin.get_radius_binning(
        nbody_orbs,
        n=np.min([500, len(nbody_orbs)//10]), # n per bin
        rmin=np.max( [np.min(nbody_r), r_softening] ),
        rmax=np.max( nbody_r ),
        bin_mode='exact numbers',
        bin_equal_n=True,
        end_mode='ignore',
        bin_cents_mode='median',
    )
    log_r = np.log10(bin_cents)

    # Bootstrapped beta and velocity dispersions, in the order N-body, CB,
    # OM, OM2. Each entry is (beta, vr2, vp2, vt2).
    kinematics = [
        pkin.compute_betas_bootstrap(orbs_set, bin_edges, n_bootstrap=n_bs,
            compute_betas_kwargs={'use_dispersions':True,
                                  'return_kinematics':True})
        for orbs_set in (nbody_orbs, o_cb[i], o_om[i], o_om2[i])
    ]

    # Beta profiles in the first column
    for j, (beta, vr2, vp2, vt2) in enumerate(kinematics):
        axs[i,0].plot(log_r, np.median(beta, axis=0),
            color=colors[j], label=labels[j], linewidth=lws[j],
            linestyle=linestyles[j])

    # Velocity dispersion profiles (r, phi, theta) in the remaining columns
    for j, (beta, vr2, vp2, vt2) in enumerate(kinematics):
        for k, sv2 in enumerate((vr2, vp2, vt2)):
            disp_ax = axs[i,k+1]
            disp_ax.plot(log_r, np.log10(np.sqrt(np.median(sv2, axis=0))),
                color=colors[j], label=labels[j], linewidth=lws[j],
                linestyle=linestyles[j])
            disp_ax.set_ylabel(
                r'$\log( \sigma_'+v_suffixes[k]+r'/ \mathrm{[km/s]})$')

    # Axis labels and limits
    axs[i,0].set_xlabel(log_r_label)
    axs[i,0].set_ylabel(r'$\beta$')
    for disp_ax in axs[i,1:]:
        disp_ax.set_xlabel(log_r_label)
        disp_ax.set_ylim(sv_lims[i])

    # Name the case on the radial dispersion panel
    axs[i,1].text(0.99, 0.9, s=panel_labels[i], transform=axs[i,1].transAxes,
        fontsize=panel_label_fs, ha='right', va='center')

axs[0,0].legend(loc='best', fontsize=12)
fig.tight_layout()

fig.savefig(local_fig_dir+'beta_vdisp.pdf', bbox_inches='tight')

### Make a figure of energy and angular momentum including margins

In [None]:
columnwidth, textwidth = pplot.get_latex_columnwidth_textwidth_inches()

def _plot_elz_contour(ax, Lz, e_scaled, n_bins, xlim, ylim, level, color,
    linestyle, show_image=False):
    """Draw one contour of the binned (Lz, E) distribution on ax.

    Args:
        ax: Axis to draw on.
        Lz (array): z-component of the angular momentum.
        e_scaled (array): Energies, already divided by 1e5.
        n_bins (int): Number of histogram bins along each dimension.
        xlim, ylim (tuple): Histogram range in Lz and in energy.
        level (float): Histogram count at which the contour is drawn.
        color, linestyle: Contour styling.
        show_image (bool): If True, also show the 2D histogram as a greyscale
            image underneath the contour.

    Returns:
        None
    """
    H, xedges, yedges = np.histogram2d(Lz, e_scaled, bins=n_bins,
        range=[xlim,ylim])
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    if show_image:
        ax.imshow(H.T, extent=extent, origin='lower', aspect='auto',
            cmap='Greys')
    # No label keyword here: contour does not support it, the legend is built
    # from proxy lines instead
    ax.contour(H.T, extent=extent, levels=[level], origin='lower',
        colors=color, linewidths=2.0, linestyles=linestyle)

def _plot_elz_margins(margin_axes, Lz, e_scaled, n_bins, xlim, ylim,
    **line_kwargs):
    """Draw the normalized Lz and energy histograms on the margin axes.

    Args:
        margin_axes (list): [axis above the main panel, axis to its right].
        Lz (array): z-component of the angular momentum.
        e_scaled (array): Energies, already divided by 1e5.
        n_bins (int): Number of histogram bins.
        xlim, ylim (tuple): Histogram range in Lz and in energy.
        **line_kwargs: Passed to both plot calls.

    Returns:
        None
    """
    N, edges = np.histogram(Lz, bins=n_bins, range=xlim, density=True)
    cents = (edges[:-1] + edges[1:])/2
    margin_axes[0].plot(cents, N, **line_kwargs)

    N, edges = np.histogram(e_scaled, bins=n_bins, range=ylim, density=True)
    cents = (edges[:-1] + edges[1:])/2
    margin_axes[1].plot(N, cents, **line_kwargs)

# Make the axes: one main panel per case, each with a margin axis above it
# and one to its right
fig = plt.figure(figsize=(columnwidth,5))
gs = fig.add_gridspec(9, 5)
axs = [fig.add_subplot(gs[1:4,0:4]),
       fig.add_subplot(gs[6:9,0:4])]
maxs = [[fig.add_subplot(gs[0,0:4]),
         fig.add_subplot(gs[1:4,4])],
        [fig.add_subplot(gs[5,0:4]),
         fig.add_subplot(gs[6:9,4])]
        ]

# Styling, ordered as N-body, CB, OM, OM2
nbody_color = 'Black'
sample_colors = ['DodgerBlue','Red','DarkOrange']
colors = [nbody_color] + sample_colors
labels = ['N-body', 'CB', 'OM', 'OM2']
linestyles = ['solid','dashed','solid','dotted']

# Per-case settings
panel_labels = ['GS/E analog', 'Sequoia analog']
xlims = [(-3000,3000), (-5000,3000)]
ylims = [(-4,-2), (-1.25,0)]
bins = [20, 15]
levels = [500, 80]

# Font sizes
xaxis_label_fs = 10
yaxis_label_fs = 10
margin_axis_label_fs = 9
legend_fs = 6
tick_label_fs = 8
panel_label_fs = 7

# Do the Nbody data
for i in range(len(o)):
    # Compute the energies and angular momentum
    vmag = ((o[i].vR()**2 + o[i].vT()**2 + o[i].vz()**2)**0.5).to_value(
        apu.km/apu.s)
    e = pe[i] + 0.5*vmag**2
    Lz = o[i].Lz().to_value(apu.kpc*apu.km/apu.s)

    # Scale the N-body energies by the potential energy of interpot at the
    # stellar half-mass radius
    rs = o[i].r().to_value(apu.kpc)
    rhalf = pkin.half_mass_radius(rs, ms[i])
    rhalf_perc = 0.05
    rhalf_mask = np.abs(rs-rhalf) < rhalf_perc*rhalf
    if not np.any(rhalf_mask):
        # The median below would be NaN and silently blank the whole panel
        raise ValueError('No star particles within '+str(rhalf_perc)+
            ' of the half-mass radius for case '+str(i))
    rhalf_pe_star = np.median(pe[i][rhalf_mask])
    rhalf_pe_interpot = potential.evaluatePotentials(
        interpots[i], rhalf*apu.kpc, 0.).to_value(apu.km**2/apu.s**2)
    rhalf_pe_offset = rhalf_pe_star - rhalf_pe_interpot
    e -= rhalf_pe_offset

    # Main panel formatting, set once with units and explicit font sizes
    axs[i].set_xlabel(r'$L_{\mathrm{z}}$ [kpc km s$^{-1}$]',
        fontsize=xaxis_label_fs)
    axs[i].set_ylabel(r'$E/10^{5}$ [km$^2$/s$^2$]', fontsize=yaxis_label_fs)
    axs[i].set_xlim(xlims[i])
    axs[i].set_ylim(ylims[i])
    axs[i].tick_params(labelsize=tick_label_fs)

    # Greyscale histogram and contour
    _plot_elz_contour(axs[i], Lz, e/1e5, bins[i], xlims[i], ylims[i],
        levels[i], colors[0], linestyles[0], show_image=True)

    # Label the main panels
    axs[i].text(0.99, 0.9, s=panel_labels[i], transform=axs[i].transAxes,
        fontsize=panel_label_fs, ha='right', va='center')

    # Do the margins
    _plot_elz_margins(maxs[i], Lz, e/1e5, bins[i], xlims[i], ylims[i],
        color=nbody_color, linewidth=2., linestyle=linestyles[0])
    maxs[i][0].set_xlim(xlims[i])
    maxs[i][0].tick_params(labelbottom=False, labelleft=False)
    maxs[i][0].set_ylabel('Density', fontsize=margin_axis_label_fs)
    maxs[i][1].set_ylim(ylims[i])
    maxs[i][1].tick_params(labelbottom=False, labelleft=False)
    maxs[i][1].set_xlabel('Density', fontsize=margin_axis_label_fs)

# Now do the CB, OM, OM2 samples; i indexes the DF type, j the case
sample_sets = [o_cb, o_om, o_om2]
for i in range(len(sample_sets)):
    for j in range(len(sample_sets[i])):

        # Sample energies are evaluated in the interpolated potential
        e = sample_sets[i][j].E(pot=interpots[j]).to_value(
            apu.km**2/apu.s**2)
        Lz = sample_sets[i][j].Lz().to_value(apu.kpc*apu.km/apu.s)

        # Contours
        _plot_elz_contour(axs[j], Lz, e/1e5, bins[j], xlims[j], ylims[j],
            levels[j], colors[i+1], linestyles[i+1])

        # Do the margins
        _plot_elz_margins(maxs[j], Lz, e/1e5, bins[j], xlims[j], ylims[j],
            color=colors[i+1], linewidth=1., linestyle=linestyles[i+1])

# Mark Lz = 0 on each main panel
for ax in axs:
    ax.axvline(0, color='Gray', linestyle='dashed')

# Legend built from empty proxy lines carrying each style
for i in range(4):
    axs[1].plot([], [], color=colors[i], label=labels[i],
        linestyle=linestyles[i])
axs[1].legend(loc='upper left', fontsize=legend_fs)
fig.tight_layout()

fig.savefig(local_fig_dir+'energy_Lz.pdf', dpi=300, bbox_inches='tight')