In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 6_higher_order_moments.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Compute the higher order moments of the data and the DF realizations
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle, logging, time

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu
from astropy import constants as apc

## Analysis
import scipy.stats
import scipy.interpolate

## galpy
from galpy import orbit
from galpy import potential

## Project-specific
src_path = 'src/'
while True:
    if os.path.exists(src_path): break
    # Directory currently being searched, i.e. the parent of the candidate
    # src/. realpath(src_path) itself always ends in 'src', so the stopping
    # test must look at this parent directory instead.
    search_dir = os.path.realpath(os.path.dirname(os.path.normpath(src_path)))
    # Give up at the project root (src/ should have been found there) or at
    # the filesystem root (the only directory that is its own parent).
    if os.path.basename(search_dir) == 'tng-dfs' or \
        search_dir == os.path.dirname(search_dir):
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import densprofile as pdens
from tng_dfs import fitting as pfit
from tng_dfs import kinematics as pkin
from tng_dfs import plot as pplot
from tng_dfs import util as putil

### Notebook setup

# Render figures inline in the notebook
%matplotlib inline
# Apply the project matplotlib style from src/mpl/. The note below says its
# position matters (presumably it must follow the backend magic - unverified)
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
# High-DPI ("retina") inline figures
%config InlineBackend.figure_format = 'retina'
# Re-import edited project modules (tng_dfs) automatically before each cell
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: read the project configuration and unpack the values used here
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR', 'MW_ANALOG_DIR', 'FIG_DIR_BASE', 'FITTING_DIR_BASE',
            'RO', 'VO', 'ZO', 'LITTLE_H', 'MW_MASS_RANGE']
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
    ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW analog sample, together with the variables it was selected on
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure output: a local ./fig/ directory and one under the project figure tree
local_fig_dir = './fig/'
fig_dir = os.path.join(
    fig_dir_base,
    'notebooks/5_compare_distribution_functions/6_higher_order_moments/',
)
os.makedirs(local_fig_dir, exist_ok=True)
os.makedirs(fig_dir, exist_ok=True)
show_plots = False

# Merger-tree data: the MW analog primaries and their major mergers
tree_primary_filename = os.path.join(
    mw_analog_dir, 'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(
    mw_analog_dir, 'major_mergers/tree_major_mergers.pkl')
with open(tree_primary_filename, 'rb') as handle:
    tree_primaries = pickle.load(handle)
with open(tree_major_mergers_filename, 'rb') as handle:
    tree_major_mergers = pickle.load(handle)
n_mw = len(tree_primaries)

### Functions

In [None]:
def compute_vmoment(orbs, moment, bin_edges, n_bootstrap=100, 
    stdev_normalization=True, pearson_correction=True):
    '''compute_vmoment:
    
    Dispatch to compute_standard_vmoment() or compute_mean_std_vmoment()
    depending on the requested moment.

    Args:
        orbs (galpy.orbit.Orbit): Orbits
        moment (int or str): Moment to compute. 1, 2, 3 or 4 selects the 
            nth (optionally standardized) central moment computed by 
            compute_standard_vmoment(); 'mean', 'std' or 'meansq' selects
            compute_mean_std_vmoment().
        bin_edges (np.ndarray): Radial bin edges [kpc]
        n_bootstrap (int): Number of times to bootstrap the orbits to compute
            the moment for error estimation
        stdev_normalization (bool): Passed to compute_standard_vmoment(); 
            ignored for the string moments.
        pearson_correction (bool): Passed to compute_standard_vmoment(); 
            ignored for the string moments.

    Returns:
        vmom (np.ndarray): Bootstrapped velocity moments, shape 
            (3, n_bootstrap, len(bin_edges)-1)

    Raises:
        ValueError: If moment is not one of the supported values (previously
            an unsupported moment silently returned None).
    '''
    if moment in [1,2,3,4]:
        return compute_standard_vmoment(orbs, moment, bin_edges, 
            n_bootstrap=n_bootstrap, stdev_normalization=stdev_normalization,
            pearson_correction=pearson_correction)
    if moment in ['mean','std','meansq']:
        return compute_mean_std_vmoment(orbs, moment, bin_edges, 
            n_bootstrap=n_bootstrap)
    raise ValueError(f"moment must be one of 1, 2, 3, 4, 'mean', 'std' or "
                     f"'meansq', got {moment!r}")

def compute_standard_vmoment(orbs, n, bin_edges, n_bootstrap=100, 
    stdev_normalization=True, pearson_correction=True):
    '''compute_standard_vmoment:
    
    Compute the nth central moment (optionally standardized) of the 
    spherical velocity components in radial bins, bootstrapping the orbits
    for error estimation.
    
    Args:
        orbs (galpy.orbit.Orbit): Orbits
        n (int): Moment to compute
        bin_edges (np.ndarray): Radial bin edges [kpc]. Orbits lying exactly
            on an edge are excluded (strict inequalities).
        n_bootstrap (int): Number of times to bootstrap the orbits to compute
            the moment for error estimation
        stdev_normalization (bool): If True, divide the nth central moment by
            the standard deviation to the nth power (standardized moment).
        pearson_correction (bool): If True, subtract 3 from the 4th 
            standardized moment so that it gives excess kurtosis. Only applies
            when n == 4 and stdev_normalization is True.
    
    Returns:
        vmom (np.ndarray): Velocity moments, shape 
            (3, n_bootstrap, len(bin_edges)-1), first axis ordered 
            (vr, vphi, vtheta). Bins that are empty in a given bootstrap 
            sample are NaN.
    '''
    n_bins = len(bin_edges)-1
    vmom = np.zeros((3,n_bootstrap,n_bins))

    # Trivial cases that hold by definition: the first central moment is 0 and
    # the second standardized moment is 1. Return before touching the orbits.
    if n == 1:
        return vmom
    if n == 2 and stdev_normalization:
        vmom[:] = 1.
        return vmom

    r = orbs.r().to_value(apu.kpc)
    vr = orbs.vr().to_value(apu.km/apu.s)
    vt = orbs.vtheta().to_value(apu.km/apu.s)
    vp = orbs.vT().to_value(apu.km/apu.s)
    n_orbs = len(orbs)

    for i in range(n_bootstrap):
        idx = np.random.randint(0,n_orbs,n_orbs)
        # Resample once per bootstrap instead of once per bin and component
        r_bs = r[idx]
        v_bs = [vr[idx], vp[idx], vt[idx]]

        for j in range(n_bins):
            bidx = (r_bs > bin_edges[j]) & (r_bs < bin_edges[j+1])
            if not np.any(bidx):
                vmom[:,i,j] = np.nan
                continue
            for k,v in enumerate(v_bs):
                vbin = v[bidx]
                _vmom = np.mean((vbin-np.mean(vbin))**n)
                if stdev_normalization:
                    _vmom /= np.std(vbin)**n
                vmom[k,i,j] = _vmom
    
    if n == 4 and stdev_normalization and pearson_correction:
        vmom -= 3.

    return vmom

def compute_mean_std_vmoment(orbs, moment, bin_edges, n_bootstrap=100):
    '''compute_mean_std_vmoment:
    
    Compute the usual binned spherical velocity mean, standard deviation or
    mean square, with the same interface and output layout as 
    compute_standard_vmoment().
    
    Args:
        orbs (galpy.orbit.Orbit): Orbits
        moment (str): 'mean', 'std', 'meansq'
        bin_edges (np.ndarray): Radial bin edges [kpc]. Orbits lying exactly
            on an edge are excluded (strict inequalities).
        n_bootstrap (int): Number of times to bootstrap the orbits to compute
            the moment for error estimation
    
    Returns:
        vmom (np.ndarray): Velocity statistic, shape 
            (3, n_bootstrap, len(bin_edges)-1), first axis ordered 
            (vr, vphi, vtheta). Bins that are empty in a given bootstrap 
            sample are NaN.
    '''
    assert moment in ['mean','std','meansq'], \
        f"moment must be one of 'mean', 'std', 'meansq', got {moment!r}"
    # Statistic applied to the velocities falling in each bin
    statistic = {'mean': np.mean,
                 'std': np.std,
                 'meansq': lambda v: np.mean(v**2)}[moment]

    r = orbs.r().to_value(apu.kpc)
    vr = orbs.vr().to_value(apu.km/apu.s)
    vt = orbs.vtheta().to_value(apu.km/apu.s)
    vp = orbs.vT().to_value(apu.km/apu.s)
    n_bins = len(bin_edges)-1
    vmom = np.zeros((3,n_bootstrap,n_bins))
    n_orbs = len(orbs)

    for i in range(n_bootstrap):
        idx = np.random.randint(0,n_orbs,n_orbs)
        # Resample once per bootstrap instead of once per bin and component
        r_bs = r[idx]
        v_bs = [vr[idx], vp[idx], vt[idx]]

        for j in range(n_bins):
            bidx = (r_bs > bin_edges[j]) & (r_bs < bin_edges[j+1])
            if not np.any(bidx):
                vmom[:,i,j] = np.nan
                continue
            for k,v in enumerate(v_bs):
                vmom[k,i,j] = statistic(v[bidx])

    return vmom

def compute_mass_error_weighted_deviation_vmoment(nbody_orbs, sample_orbs, 
    nbody_mass, n, n_bs=100, adaptive_binning_kwargs=None, 
    raise_inverse_power=False, stdev_normalization=True, pearson_correction=True):
    '''compute_mass_error_weighted_deviation_vmoment:
    
    Compute the mass- and uncertainty-weighted deviation (MEWD) between the 
    N-body and DF-sample radial profiles of a velocity moment, separately for
    vr, vphi and vtheta.

    For each bootstrap realization the deviation is
        sum_bins( |m_NB - m_DF| * M_bin / err_NB ) / sum_bins( M_bin )
    where m is the moment in a radial bin, M_bin is the N-body mass in that 
    bin and err_NB is the 16th-84th percentile range of the bootstrapped 
    N-body moment in that bin.

    For the binning scheme the default kwargs are:
    - n: min(500, number of N-body particles//10)
    - rmin: 0.
    - rmax: max(N-body particle radii)
    - bin_mode: 'exact numbers'
    - bin_equal_n: True
    - end_mode: 'ignore'
    - bin_cents_mode: 'median'

    Args:
        nbody_orbs (galpy.orbit.Orbit): N-body orbits
        sample_orbs (galpy.orbit.Orbit): DF samples
        nbody_mass (np.ndarray): N-body particle masses
        n (int or str): Moment to compute, see compute_vmoment()
        n_bs (int): Number of times to bootstrap the DF/N-body samples
            to compute the deviation statistic for error estimation
        adaptive_binning_kwargs (dict or None): kwargs for 
            pkin.get_radius_binning(). Missing keys are filled with the 
            defaults listed above. The dict passed in is not modified.
        raise_inverse_power (bool): If True, raise each moment to the 
            inverse power of n (requires numeric n; negative moments give NaN).
        stdev_normalization (bool): If True, compute the standardized moment, 
            i.e. divide by the standard deviation to the nth power
        pearson_correction (bool): If True, apply Pearson's correction to the
            4th standardized moment such that it gives excess kurtosis.
    
    Returns:
        mewd (np.ndarray): MEWD for (vr, vphi, vtheta), shape (3, n_bs)
    '''
    rs = nbody_orbs.r().to_value(apu.kpc)

    # Fill in binning defaults on a copy: mutating the argument would leak
    # values (e.g. rmax) into the caller's dict or a shared default.
    binning_kwargs = dict(adaptive_binning_kwargs or {})
    binning_kwargs.setdefault('n', np.min([500, len(nbody_orbs)//10]))
    binning_kwargs.setdefault('rmin', 0.)
    binning_kwargs.setdefault('rmax', np.max(rs))
    binning_kwargs.setdefault('bin_mode', 'exact numbers')
    binning_kwargs.setdefault('bin_equal_n', True)
    binning_kwargs.setdefault('end_mode', 'ignore')
    binning_kwargs.setdefault('bin_cents_mode', 'median')

    bin_edges, bin_cents, _ = pkin.get_radius_binning(nbody_orbs, 
        **binning_kwargs)

    # Compute velocity moments for the N-body data and DF samples, 
    # each of shape (3, n_bs, n_bins)
    nbody_vmom = compute_vmoment(nbody_orbs, n, bin_edges, n_bootstrap=n_bs,
        stdev_normalization=stdev_normalization, 
        pearson_correction=pearson_correction)
    sample_vmom = compute_vmoment(sample_orbs, n, bin_edges, n_bootstrap=n_bs,
        stdev_normalization=stdev_normalization, 
        pearson_correction=pearson_correction)

    if raise_inverse_power:
        nbody_vmom = nbody_vmom**(1/n)
        sample_vmom = sample_vmom**(1/n)

    # N-body mass in each bin, using the same strict-inequality bin membership
    # as the moment calculation
    mass_profile = np.zeros(len(bin_cents))
    for i in range(len(bin_cents)):
        mass_profile[i] = np.sum(nbody_mass[(rs > bin_edges[i]) &\
                                            (rs < bin_edges[i+1])])

    # Inter-sigma (16th-84th percentile) range of the bootstrapped N-body 
    # moments, used as the error; shape (3, n_bins)
    nbody_err = np.percentile(nbody_vmom, 84, axis=1) - \
                np.percentile(nbody_vmom, 16, axis=1)

    # Mass-error-weighted deviation between the N-body and DF sample trends,
    # pairing bootstrap realizations one-to-one; shape (3, n_bs)
    mewd = np.sum( np.abs(nbody_vmom - sample_vmom)*\
                    mass_profile/nbody_err[:,np.newaxis,:], axis=2 )/\
           np.sum(mass_profile)

    return mewd

def plot_velocity_moments(orbs, sample, bin_edges, bin_cents, moms, n_bs=100,
                         plot_log=False, stdev_normalization=True, 
                         pearson_correction=True):
    '''plot_velocity_moments:

    Plot radial profiles of velocity moments for the N-body orbits (black) 
    and the DF sample (red). There is one column per moment and one row per
    velocity component (vr, vphi, vtheta). Lines are bootstrap medians and 
    shaded bands are the 16th-84th percentile range.

    Args:
        orbs (galpy.orbit.Orbit): N-body orbits
        sample (galpy.orbit.Orbit): DF sample orbits
        bin_edges (np.ndarray): Radial bin edges [kpc]
        bin_cents (np.ndarray): Radial bin centers [kpc], used as x values
        moms (list): Moments to plot, each accepted by compute_vmoment()
        n_bs (int): Number of bootstrap samples
        plot_log (bool or list of bool): Log y-axis; per moment if a list
        stdev_normalization (bool or list of bool): Passed to 
            compute_vmoment(); per moment if a list
        pearson_correction (bool or list of bool): Passed to 
            compute_vmoment(); per moment if a list

    Returns:
        fig (matplotlib.figure.Figure): Figure
        axs (np.ndarray): Axes, always of shape (3, len(moms))
    '''
    n_mom = len(moms)
    if isinstance(stdev_normalization, bool):
        stdev_normalization = [stdev_normalization]*n_mom
    if isinstance(pearson_correction, bool):
        pearson_correction = [pearson_correction]*n_mom
    if isinstance(plot_log, bool):
        plot_log = [plot_log]*n_mom

    vtext = [r'v_{r}', r'v_{\phi}', r'v_{\theta}']

    fig = plt.figure(figsize=(n_mom*4,10))
    # squeeze=False keeps axs 2D even for a single moment, so axs[row,col]
    # indexing also works when len(moms) == 1
    axs = fig.subplots(nrows=3, ncols=n_mom, squeeze=False)

    for k,m in enumerate(moms):
        vm_nbody = compute_vmoment(orbs, m, bin_edges, n_bootstrap=n_bs,
            stdev_normalization=stdev_normalization[k],
            pearson_correction=pearson_correction[k])
        vm_sample = compute_vmoment(sample, m, bin_edges, n_bootstrap=n_bs,
            stdev_normalization=stdev_normalization[k],
            pearson_correction=pearson_correction[k])

        for row in range(3):
            ax = axs[row,k]
            for vm_all, color in ((vm_nbody, 'Black'), (vm_sample, 'Red')):
                v_lo, v_med, v_hi = np.percentile(vm_all[row], [16,50,84], 
                    axis=0)
                ax.plot(bin_cents, v_med, color=color, alpha=1.0)
                ax.fill_between(bin_cents, v_lo, v_hi, color=color, 
                    alpha=0.3)

            ax.set_xscale('log')
            if plot_log[k]: ax.set_yscale('log')

            if m == 'mean':
                ax.set_ylabel(r'$\overline{'+vtext[row]+r'}$')
            elif m == 'std':
                ax.set_ylabel(r'$\sigma_{'+vtext[row]+r'}$')
            elif m == 'meansq':
                ax.set_ylabel(r'$\overline{'+vtext[row]+r'^2}$')
            else:
                ax.set_ylabel(r'$\mu^{'+str(m)+r'}_{'+vtext[row]+r'}$')

        axs[2,k].set_xlabel(r'$r\,(\mathrm{kpc})$')
    
    return fig, axs
    

### Compute the mean, standard deviation, and 3rd and 4th order velocity moments, and the mass-error-weighted deviations (MEWD) between the N-body data and the DF samples

In [None]:
verbose = True
make_moment_plot = False

# Pathing
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function')
analysis_version = 'v1.1'
analysis_dir = os.path.join(mw_analog_dir,'analysis',analysis_version)

# Moment calculation keywords
n_bs = 10 # Number of bootstrap realizations per moment profile
raise_inverse_power = False
# moms = ['mean','std',3,4]
moms = ['std',]
# (stdev_normalization, pearson_correction) for each supported moment. This is
# keyed by moment so the per-moment lists below always line up with moms,
# whichever subset is selected. Only the 3rd and 4th moments are standardized,
# and only the 4th gets Pearson's correction (i.e. excess kurtosis).
moment_settings = {'mean': (False, False),
                   'std': (False, False),
                   3: (True, False),
                   4: (True, True)}
stdev_normalization = [moment_settings[m][0] for m in moms]
pearson_correction = [moment_settings[m][1] for m in moms]
moms_str = '_'.join([str(m) for m in moms]) # Tag used in output filenames

# Get the sample orbits, one set per DF
sample_data_cb = np.load(os.path.join(analysis_dir,'sample_data_cb.npy'),
    allow_pickle=True)
sample_data_om = np.load(os.path.join(analysis_dir,'sample_data_om.npy'),
    allow_pickle=True)
sample_data_om2 = np.load(os.path.join(analysis_dir,'sample_data_om2.npy'),
    allow_pickle=True)
sample_data_arr = [sample_data_cb,
                   sample_data_om,
                   sample_data_om2]
sample_data_suffix = ['cb','om','om2']
velocity_component_names = ['vr','vp','vt']

# Begin logging
os.makedirs('./log/',exist_ok=True)
log_filename = './log/6_higher_order_moments.log'
if os.path.exists(log_filename):
    os.remove(log_filename)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning higher order moment computation. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

mewd_data = []
mewd_data_dtype_list = []
make_mewd_data_dtype_list = True # The dtype is built on the first merger only

for i in range(n_mw):
    # Get the primary and its z=0 star particles
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()
    pid = co.get_property('stars','ParticleIDs')

    for j in range(n_major):
        merger_number = j+1

        if verbose: 
            msg = f'Calculating moment profiles for MW analog '+\
                  f'{i+1}/{n_mw}, merger {merger_number}/{n_major}'
            logging.info(msg)
            print(msg, end='\r')

        # Get the z=0 star particles that belong to this major merger
        major_merger = primary.tree_major_mergers[j]
        major_acc_sid = major_merger.subfind_id[0]
        major_mlpid = major_merger.secondary_mlpid
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        star_indx = np.where(np.isin(pid,upid))[0]
        orbs = co.get_orbs('stars')[star_indx]
        n_star = len(orbs)
        star_mass = co.get_masses('stars')[star_indx].to_value(apu.Msun)
        r = orbs.r().to_value(apu.kpc)
        r_softening = putil.get_softening_length('stars', z=0, physical=True)
        rmin = np.max([np.min(r), r_softening])

        # Define the adaptive binning keyword dict and then bin
        n_bin = np.min([500, n_star//10]) # n per bin
        adaptive_binning_kwargs = {'n':n_bin,
                                   'rmin':rmin,
                                   'rmax':np.max(r),
                                   'bin_mode':'exact numbers',
                                   'bin_equal_n':True,
                                   'end_mode':'ignore',
                                   'bin_cents_mode':'median'}
        bin_edges, bin_cents, _ = pkin.get_radius_binning(orbs, 
            **adaptive_binning_kwargs)

        _data = (
            z0_sid,
            major_acc_sid,
            major_mlpid,
            merger_number,
        )
        if make_mewd_data_dtype_list:
            mewd_data_dtype_list.extend([
                ('z0_sid',int),
                ('major_acc_sid',int),
                ('major_mlpid',int),
                ('merger_number',int),
            ])

        # Loop over each set of DF samples and compute moments and statistics
        for k in range(len(sample_data_arr)):
            sample_data = sample_data_arr[k]
            suffix = sample_data_suffix[k]

            # Find the DF sample matching this merger
            mask = (sample_data['z0_sid'] == z0_sid) &\
                   (sample_data['major_acc_sid'] == major_acc_sid) &\
                   (sample_data['major_mlpid'] == major_mlpid) &\
                   (sample_data['merger_number'] == merger_number)
            sample_indx = np.where(mask)[0]
            assert len(sample_indx) == 1, \
                f'Expected exactly one {suffix} DF sample for z0_sid='+\
                f'{z0_sid}, merger {merger_number}, found {len(sample_indx)}'
            sample = sample_data[sample_indx[0]]['sample']

            # Make plots of the moments
            if make_moment_plot:
                plot_log = [False]*len(moms)
                fig,axs = plot_velocity_moments(orbs, sample, bin_edges, 
                    bin_cents, moms, n_bs=n_bs, plot_log=plot_log, 
                    stdev_normalization=stdev_normalization, 
                    pearson_correction=pearson_correction)

                fig.tight_layout()
                this_fig_dir = os.path.join(fig_dir, str(z0_sid), 
                    'merger_'+str(merger_number))
                os.makedirs(this_fig_dir, exist_ok=True)
                this_figname = os.path.join(this_fig_dir,
                    'moments_'+moms_str+'_'+suffix+'.png')
                fig.savefig(this_figname, dpi=300, bbox_inches='tight')
                if not show_plots: plt.close(fig)
                else: plt.show()
            
            # Loop over each moment and compute the N-body vs DF MEWD
            for im in range(len(moms)):
                mewd = compute_mass_error_weighted_deviation_vmoment(orbs,
                    sample, star_mass, moms[im], n_bs=n_bs, 
                    adaptive_binning_kwargs=adaptive_binning_kwargs,
                    raise_inverse_power=raise_inverse_power, 
                    stdev_normalization=stdev_normalization[im], 
                    pearson_correction=pearson_correction[im])
                _data += (mewd[0], mewd[1], mewd[2])
                if make_mewd_data_dtype_list:
                    mewd_data_dtype_list.extend(
                        [('mewd_'+suffix+'_'+vname+str(moms[im]),object)
                         for vname in velocity_component_names])
        
        # N-body vs N-body (bootstrap-only) MEWD, the noise floor
        for im in range(len(moms)):
            mewd_self = compute_mass_error_weighted_deviation_vmoment(
                orbs, orbs, star_mass, moms[im], n_bs=n_bs, 
                adaptive_binning_kwargs=adaptive_binning_kwargs,
                raise_inverse_power=raise_inverse_power, 
                stdev_normalization=stdev_normalization[im], 
                pearson_correction=pearson_correction[im])
            _data += (mewd_self[0], mewd_self[1], mewd_self[2])
            if make_mewd_data_dtype_list:
                mewd_data_dtype_list.extend(
                    [('mewd_self_'+vname+str(moms[im]),object)
                     for vname in velocity_component_names])

        mewd_data.append(_data)
        make_mewd_data_dtype_list = False

# Save the data as a pickle
header = [f'keywords: stdev_normalization={stdev_normalization}, '+\
          f'pearson_correction={pearson_correction}, '+\
          f'raise_inverse_power={raise_inverse_power}, n_bs={n_bs}',
    [name for name, _ in mewd_data_dtype_list]
]
mewd_data_filename = os.path.join(analysis_dir, 'mewd_data_vmom_'+moms_str)
with open(mewd_data_filename+'.pkl','wb') as handle:
    pickle.dump([header,mewd_data], handle)

# Also save as a structured array
mewd_data_dtype = np.dtype(mewd_data_dtype_list)
mewd_data = np.array(mewd_data, dtype=mewd_data_dtype)
np.save(mewd_data_filename+'.npy', mewd_data)

### Load the saved MEWD data

In [None]:
analysis_version = 'v1.1'
analysis_dir = os.path.join(mw_analog_dir,'analysis',analysis_version)

# Load the MEWD structured array. moms here must match the moms used when the
# file was written; the computation cell may have been run with a different 
# subset, in which case this file comes from an earlier run.
# moms = [1,2,3,4]
moms = ['mean','std',3,4]
mewd_data_filename = os.path.join(analysis_dir,
    'mewd_data_vmom_'+'_'.join([str(m) for m in moms])+'.npy')
mewd_data = np.load(mewd_data_filename, allow_pickle=True)

# Load the merger information
merger_data = np.load(os.path.join(analysis_dir,'merger_data.npy'), 
    allow_pickle=True)

# Check that the MEWD rows and the merger rows describe the same mergers in
# the same order
checks = True
if checks:
    assert len(mewd_data) == len(merger_data), \
        f'MEWD data has {len(mewd_data)} rows but merger data has '+\
        f'{len(merger_data)}'
    for field in ['z0_sid','major_acc_sid','major_mlpid']:
        assert np.all( mewd_data[field] == merger_data[field] ), \
            f'MEWD data and merger data disagree on {field}'

### Compare the N-body vs. DF MEWD values with the N-body vs. N-body MEWD values for the mean, standard deviation, and 3rd and 4th moments

In [None]:
columnwidth, textwidth = pplot.get_latex_columnwidth_textwidth_inches()

df_names = ['cb','om','om2']
velocity_names = ['vr','vp','vt']
velocity_labels = [r'v_r',r'v_\phi',r'v_\theta']

n_merger = len(mewd_data)
xlims = [[0.4,2],
         [0.4,2],
         [0.4,2],
         [0.4,2]]
ylims = [[0.4,10],
         [0.4,10],
         [0.4,10],
         [0.4,10]]
marker_size = 4
marker_alpha = 0.5
ticklabel_fs = 8
xaxis_label_fs = 10
yaxis_label_fs = 10
panel_label_fs = 8
title_fs = 10

# One figure per DF: for each merger, the median (over bootstrap realizations)
# N-body vs DF MEWD against the median N-body vs N-body MEWD
for g in range(len(df_names)):
    df_name = df_names[g]

    fig = plt.figure(figsize=(textwidth,5))
    # One column per moment; squeeze=False keeps axs 2D for a single moment
    axs = fig.subplots(nrows=3, ncols=len(moms), squeeze=False)

    for i in range(len(moms)):
        for j in range(len(velocity_names)):
            ax = axs[j,i]

            x = mewd_data['mewd_self_'+velocity_names[j]+str(moms[i])]
            y = mewd_data['mewd_'+df_name+'_'+velocity_names[j]+str(moms[i])]
            xm = np.array([np.percentile(x[k], 50) for k in range(n_merger)])
            ym = np.array([np.percentile(y[k], 50) for k in range(n_merger)])

            ax.scatter(xm, ym, s=marker_size, color='Black', 
                alpha=marker_alpha)
            # One-to-one line, drawn once per panel
            ax.axline([0.1,0.1],[1,1], color='Black', linestyle='--')

            # ax.set_xscale('log')
            ax.set_yscale('log')
            # ax.set_xlim(xlims[i])
            # ax.set_ylim(ylims[i])

            if i != 0:
                ax.tick_params(labelleft=False)
            else:
                ax.set_ylabel(r'$\delta_{NB-DF}$', fontsize=yaxis_label_fs)
                ax.tick_params(labelsize=ticklabel_fs)

            if j != len(velocity_names)-1:
                ax.xaxis.set_major_formatter(plt.NullFormatter())
                ax.xaxis.set_minor_formatter(plt.NullFormatter())
                ax.tick_params(labelbottom=False)
            else:
                ax.set_xlabel(r'$\delta_{NB-NB}$', fontsize=xaxis_label_fs)
                ax.tick_params(labelsize=ticklabel_fs)

            if moms[i] == 'mean':
                vtext = r'$\overline{'+velocity_labels[j]+r'}$'
            elif moms[i] == 'std':
                vtext = r'$\sigma_{'+velocity_labels[j]+r'}$'
            elif moms[i] == 'meansq':
                vtext = r'$\overline{'+velocity_labels[j]+r'^2}$'
            else:
                vtext = r'$\mu^{'+str(moms[i])+r'}_{'+velocity_labels[j]+r'}$'

            ax.text(0.95,0.15, vtext,
                ha='right', va='top', transform=ax.transAxes, 
                fontsize=panel_label_fs)

    fig.suptitle(df_name.upper()+' DF', fontsize=title_fs)
    fig.tight_layout()
    this_figname = os.path.join(local_fig_dir,
        'mewd_vmom_comparison_'+df_name+'.png')
    fig.savefig(this_figname, dpi=300, bbox_inches='tight')
    if not show_plots: plt.close(fig)


### Histograms of the N-body vs. DF deviation divided by the N-body vs. N-body deviation

In [None]:
columnwidth, textwidth = pplot.get_latex_columnwidth_textwidth_inches()

fig = plt.figure(figsize=(textwidth,5))
# One column per moment; squeeze=False keeps axs 2D for a single moment
axs = fig.subplots(nrows=3, ncols=len(moms), squeeze=False)

df_names = ['cb','om','om2']
df_colors = ['DodgerBlue','Red','Black']
df_linewidths = [4.0, 2.0, 1.0]
df_zorders = [1,2,3]
df_linestyles = ['solid','solid','dashed']
velocity_names = ['vr','vp','vt']
velocity_labels = [r'v_r',r'v_\phi',r'v_\theta']
n_merger = len(mewd_data)
marker_size = 4
marker_alpha = 0.5
ticklabel_fs = 8
xaxis_label_fs = 8
yaxis_label_fs = 10
panel_label_fs = 8
title_fs = 10

# Step histograms of log10(delta_NB-DF / delta_NB-NB), pooling every bootstrap
# realization of every merger; one histogram per DF in each panel
for g in range(len(df_names)):
    df_name = df_names[g]
    for i in range(len(moms)):
        for j in range(len(velocity_names)):
            x = np.concatenate(
                mewd_data['mewd_self_'+velocity_names[j]+str(moms[i])]
                )
            y = np.concatenate(
                mewd_data['mewd_'+df_name+'_'+velocity_names[j]+str(moms[i])]
                )
            axs[j,i].hist(np.log10(y/x), bins=15, histtype='step', 
                range=(-0.5,1.5), color=df_colors[g], 
                zorder=df_zorders[g], linestyle=df_linestyles[g],
                linewidth=df_linewidths[g], 
                density=True)

# Axis formatting and panel labels. Done once per panel rather than once per
# DF, so that each panel label is only drawn once.
for i in range(len(moms)):
    for j in range(len(velocity_names)):
        ax = axs[j,i]

        if i != 0:
            ax.tick_params(labelleft=False)
        else:
            ax.set_ylabel('Density', fontsize=yaxis_label_fs)
            ax.tick_params(labelsize=ticklabel_fs)

        if j != len(velocity_names)-1:
            ax.tick_params(labelbottom=False)
        else:
            ax.set_xlabel(
                r'$\log_{10}[ \delta_{NB-DF}/\delta_{NB-NB} ]$', 
                fontsize=xaxis_label_fs)
            ax.tick_params(labelsize=ticklabel_fs)

        if moms[i] == 'mean':
            vtext = r'$\overline{'+velocity_labels[j]+r'}$'
        elif moms[i] == 'std':
            vtext = r'$\sigma_{'+velocity_labels[j]+r'}$'
        elif moms[i] == 'meansq':
            vtext = r'$\overline{'+velocity_labels[j]+r'^2}$'
        else:
            vtext = r'$\mu^{'+str(moms[i])+r'}_{'+velocity_labels[j]+r'}$'

        ax.text(0.95,0.95, vtext,
            ha='right', va='top', transform=ax.transAxes, 
            fontsize=panel_label_fs)

        # Turn on gridlines
        ax.xaxis.set_minor_locator(plt.MultipleLocator(0.1))
        ax.grid(True, linestyle='solid', alpha=1)
        ax.grid(which='minor', linestyle='--', alpha=0.5)

# Make the legend
for i in range(len(df_names)):
    axs[0,0].plot([],[],color=df_colors[i],label=df_names[i].upper(),
        linestyle=df_linestyles[i], linewidth=df_linewidths[i])
axs[0,0].legend(loc='center right', fontsize=ticklabel_fs)

fig.tight_layout()
this_figname = os.path.join(local_fig_dir,
    'mewd_vmom_hist.png')
fig.savefig(this_figname, dpi=300, bbox_inches='tight')
if not show_plots: plt.close(fig)