In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 2_plot_anisotropic_dfs.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Use density profile fits to construct anisotropic DFs
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle
import pdb, copy, glob, time

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu
from astropy import constants as apc

## Analysis
import scipy.stats
import scipy.interpolate

## galpy
from galpy import orbit
from galpy import potential
from galpy import actionAngle as aA
from galpy import df
from galpy import util as gputil

## Project-specific
src_path = 'src/'
while not os.path.exists(src_path):
    # Directory in which src/ was just looked for. The last component of
    # realpath(src_path) is always 'src', so the parent directory is what must
    # be checked: stop at the project root ('tng-dfs') or at the filesystem
    # root (the only directory that is its own parent) instead of looping.
    search_dir = os.path.dirname(os.path.realpath(src_path))
    if os.path.basename(search_dir) == 'tng-dfs' or \
            search_dir == os.path.dirname(search_dir):
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import densprofile as pdens
from tng_dfs import fitting as pfit
from tng_dfs import kinematics as pkin
from tng_dfs import util as putil

### Notebook setup

%matplotlib inline
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Read the project configuration and unpack the keywords used here
cdict = putil.load_config_to_dict()
keywords = [
    'DATA_DIR', 'MW_ANALOG_DIR', 'RO', 'VO', 'ZO', 'LITTLE_H', 'MW_MASS_RANGE',
]
(data_dir, mw_analog_dir, ro, vo, zo, h,
    mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# Milky Way analog sample, with the bulge/disk fraction cuts applied
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Plot display toggle
show_plots = False

# Merger-tree data pickled under MW_ANALOG_DIR/major_mergers/
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
with open(tree_primary_filename, 'rb') as handle:
    tree_primaries = pickle.load(handle)
with open(tree_major_mergers_filename, 'rb') as handle:
    tree_major_mergers = pickle.load(handle)

# Number of MW analogs (one primary per analog)
n_mw = len(tree_primaries)

## Plotting functions

In [None]:
def plot_ELz(nbody_orbs, sample_orbs, nbody_E, pot, interpot=None, 
    plot_hist=True, fig=None, axs=None):
    '''Plot energy versus z angular momentum for N-body data and DF samples.

    Draws three panels: the N-body particles, the DF samples with energies
    evaluated in ``pot`` ("Best-fit pot"), and the DF samples with energies
    evaluated in ``interpot`` ("interpolated pot"). In both plotting modes Lz
    is shown in units of 10^3 kpc km/s and E in units of 10^5 km^2/s^2.

    Args:
        nbody_orbs (galpy.orbit.Orbit): N-body orbits (only Lz is used)
        sample_orbs (galpy.orbit.Orbit): DF samples
        nbody_E (np.ndarray): N-body energies in km^2/s^2
        pot: Potential passed to ``sample_orbs.E`` for the "Best-fit pot" panel
        interpot: Potential passed to ``sample_orbs.E`` for the
            "interpolated pot" panel
        plot_hist (bool): If True draw 2D histograms, otherwise scatter plots
        fig, axs: Existing figure and (at least 3) axes to draw on; a new
            1x3 figure is created if either is None

    Returns:
        fig, axs
    '''

    # Some plotting kwargs
    Lz_range = [-3,3] # in units of 10^3 kpc km/s
    label_fs = 12
    annotate_bbox_kwargs = dict(facecolor='White', edgecolor='Black', 
        fill=True, alpha=0.5)

    # Set up figure
    if fig is None or axs is None:
        fig = plt.figure(figsize=(12,3))
        axs = fig.subplots(nrows=1, ncols=3)

    # N-body properties
    nbody_Lz = nbody_orbs.Lz().to_value(apu.kpc*apu.km/apu.s)

    # Sample properties
    sample_Lz = sample_orbs.Lz().to_value(apu.kpc*apu.km/apu.s)
    sample_potE = sample_orbs.E(pot=pot).to_value(apu.km**2/apu.s**2)
    sample_interE = sample_orbs.E(pot=interpot).to_value(apu.km**2/apu.s**2)

    # (Lz, E, label) per panel. Each label must name the potential its
    # energies were evaluated in: potE <- pot, interE <- interpot
    panels = [
        (nbody_Lz, nbody_E, 'N-body'),
        (sample_Lz, sample_potE, 'Samples (Best-fit pot)'),
        (sample_Lz, sample_interE, 'Samples (interpolated pot)'),
    ]

    if plot_hist:
        # Empty histogram cells are masked and drawn in white
        cmap = mpl.colormaps.get_cmap('viridis')
        cmap.set_bad(color='white')

    for ax, (Lz, E, label) in zip(axs, panels):
        # Scale once so both plotting modes match the axis labels and limits
        Lz_scaled = Lz/1e3
        E_scaled = E/1e5
        E_range = [np.nanmin(E_scaled), np.nanmax(E_scaled)]
        if plot_hist:
            H, xedges, yedges = np.histogram2d(Lz_scaled, E_scaled, 
                bins=[45,30], range=[Lz_range,E_range])
            # histogram2d indexes H as [x, y]; pcolormesh expects [y, x]
            H = H.T
            Hmasked = np.ma.masked_where(H==0,H)
            ax.pcolormesh(xedges,yedges,Hmasked,cmap=cmap)
        else:
            ax.scatter(Lz_scaled, E_scaled, s=1, color='Black', alpha=0.1)
        
        # Decorate
        ax.axvline(0, linestyle='dashed', linewidth=1., color='Grey')
        ax.annotate(label, xy=(0.05,0.05), xycoords='axes fraction',
            fontsize=8, bbox=annotate_bbox_kwargs)
        ax.set_xlabel(r'Lz [$10^{3}$ kpc km/s]', fontsize=label_fs)
        ax.set_ylabel(r'E [$10^{5}$ km$^{2}$/s$^{2}$]', fontsize=label_fs)
        
        ax.set_xlim(Lz_range[0],Lz_range[1])
        ax.set_ylim(E_range[0],E_range[1])

    fig.tight_layout()
    return fig, axs

def plot_beta_vdisp(nbody_orbs, sample_orbs, n_bs=100, fig=None, axs=None):
    '''Plot beta and the three velocity dispersions as a function of radius.

    The N-body data and the DF samples are binned in the same adaptive radial
    bins (derived from the N-body radii). For each, bootstrapped beta and
    velocity dispersions are drawn as a median line with a 16-84 percentile
    band: beta in the top panel, then sigma_r, sigma_phi and sigma_z.

    Args:
        nbody_orbs (galpy.orbit.Orbit): N-body orbits
        sample_orbs (galpy.orbit.Orbit): DF samples
        n_bs (int): Number of bootstrap realizations
        fig, axs: Existing figure and 4 axes to draw on; a new 4x1 figure is
            created if either is None

    Returns:
        fig, axs
    '''

    # Some kwargs for plotting
    nbody_color = 'Black'
    sample_color = 'DodgerBlue'

    # Set up figure
    if fig is None or axs is None:
        fig = plt.figure(figsize=(5,12))
        axs = fig.subplots(nrows=4, ncols=1)

    # Binning for velocity dispersions and betas, based on the N-body radii
    n_bin = np.min([500, len(nbody_orbs)//10]) # n per bin
    adaptive_binning_kwargs = {
        'n':n_bin,
        'rmin':0.,
        'rmax':np.max( nbody_orbs.r().to_value(apu.kpc) ),
        'bin_mode':'exact numbers',
        'bin_equal_n':True,
        'end_mode':'ignore',
        'bin_cents_mode':'median',
    }
    bin_edges, bin_cents, _ = pkin.get_radius_binning(nbody_orbs, 
        **adaptive_binning_kwargs)

    def _plot_median_band(ax, samples, color, label=None, transform=None):
        # Median over bootstrap realizations with a 16-84 percentile band;
        # transform (e.g. np.sqrt for squared dispersions) is applied after
        # taking percentiles
        med = np.median(samples, axis=0)
        lo = np.percentile(samples, 16, axis=0)
        hi = np.percentile(samples, 84, axis=0)
        if transform is not None:
            med, lo, hi = transform(med), transform(lo), transform(hi)
        ax.plot(bin_cents, med, color=color, label=label)
        ax.fill_between(bin_cents, lo, hi, color=color, alpha=0.25)

    # Raw strings: '\phi' in a normal string is an invalid escape sequence
    v_suffixes = ['r', r'\phi', 'z']
    datasets = [(nbody_orbs, nbody_color, 'N-body'),
                (sample_orbs, sample_color, 'DF Samples')]
    for orbs, color, label in datasets:
        beta, vr2, vp2, vz2 = pkin.compute_betas_bootstrap(orbs, bin_edges,
            n_bootstrap=n_bs, compute_betas_kwargs={'use_dispersions':True,
                                                    'return_kinematics':True})
        _plot_median_band(axs[0], beta, color, label=label)
        for k, v2 in enumerate([vr2, vp2, vz2]):
            _plot_median_band(axs[k+1], v2, color, transform=np.sqrt)
            axs[k+1].set_ylabel(r'$\sigma_'+v_suffixes[k]+r'$')

    # Labels
    axs[0].set_ylabel(r'$\beta$')
    axs[0].legend()
    for k in range(4):
        axs[k].set_xscale('log')
        axs[k].set_xlabel(r'$r$ [kpc]')
        if k > 0:
            axs[k].set_yscale('log')
    fig.tight_layout()
    
    return fig, axs

def compute_mass_error_weighted_deviation_beta_vdisp(nbody_orbs, sample_orbs, nbody_mass,
    n_bs=100, adaptive_binning_kwargs=None, velocity_quantities_squared=False):
    '''compute_mass_error_weighted_deviation_beta_vdisp:
    
    Compute the mass- and uncertainty-weighted deviation of the N-body 
    velocity dispersion / beta trend from the DF samples. 

    In each radial bin, the absolute difference between the bootstrapped
    N-body and DF-sample quantity is divided by the N-body 16-84 percentile
    range and weighted by the N-body mass in that bin. The weighted terms are
    summed over bins and normalized by the total binned N-body mass.

    For the binning scheme the default kwargs are:
    - n: min(500, number of N-body particles//10)
    - rmin: 0.
    - rmax: max(N-body particle radii)
    - bin_mode: 'exact numbers'
    - bin_equal_n: True
    - end_mode: 'ignore'
    - bin_cents_mode: 'median'

    Args:
        nbody_orbs (galpy.orbit.Orbit): N-body orbits
        sample_orbs (galpy.orbit.Orbit): DF samples
        nbody_mass (np.ndarray): N-body particle masses
        n_bs (int): Number of times to bootstrap the DF/N-body samples
            to compute the deviation statistic for error estimation
        adaptive_binning_kwargs (dict or None): kwargs for 
            get_radius_binning(). Keys that are not provided are filled with
            the defaults listed above. The passed dict is not modified.
        velocity_quantities_squared (bool): If True, use the squared velocity 
            dispersions/mean squares that are output from 
            pkin.compute_betas_bootstrap(). If False, take the square root of
            these quantities.
    
    Returns:
        mewd_beta, mewd_vr2, mewd_vp2, mewd_vt2 (np.ndarray): Mass-error-
            weighted deviations, summed over bins along axis 1 of the
            bootstrap arrays (presumably one value per bootstrap realization;
            assumes compute_betas_bootstrap returns (n_bootstrap, n_bins)
            arrays -- TODO confirm)
    '''

    rs = nbody_orbs.r().to_value(apu.kpc)

    # Binning for velocity dispersions and betas. Start from the defaults and
    # let caller-supplied values override them. Work on a fresh dict so that
    # neither the caller's dict nor a shared default object is mutated.
    binning_kwargs = {
        'n':np.min([500, len(nbody_orbs)//10]), # n per bin
        'rmin':0.,
        'rmax':np.max(rs),
        'bin_mode':'exact numbers',
        'bin_equal_n':True,
        'end_mode':'ignore',
        'bin_cents_mode':'median',
    }
    if adaptive_binning_kwargs is not None:
        binning_kwargs.update(adaptive_binning_kwargs)
    bin_edges, bin_cents, _ = pkin.get_radius_binning(nbody_orbs, 
        **binning_kwargs)

    # Bootstrapped beta and velocity quantities in the same bins, for the
    # N-body data and for the DF samples
    nbody_beta, nbody_vr2, nbody_vp2, nbody_vt2 = \
        pkin.compute_betas_bootstrap(nbody_orbs, bin_edges, n_bootstrap=n_bs, 
        compute_betas_kwargs={'use_dispersions':True,
                              'return_kinematics':True})
    sample_beta, sample_vr2, sample_vp2, sample_vt2 = \
        pkin.compute_betas_bootstrap(sample_orbs, bin_edges, n_bootstrap=n_bs, 
        compute_betas_kwargs={'use_dispersions':True,
                              'return_kinematics':True})

    if not velocity_quantities_squared:
        nbody_vr2 = np.sqrt(nbody_vr2)
        nbody_vp2 = np.sqrt(nbody_vp2)
        nbody_vt2 = np.sqrt(nbody_vt2)
        sample_vr2 = np.sqrt(sample_vr2)
        sample_vp2 = np.sqrt(sample_vp2)
        sample_vt2 = np.sqrt(sample_vt2)

    # N-body mass in each bin (particles strictly between the bin edges)
    mass_profile = np.array([
        np.sum(nbody_mass[(rs > bin_edges[i]) & (rs < bin_edges[i+1])])
        for i in range(len(bin_cents))], dtype=float)
    total_mass = np.sum(mass_profile)

    def _mewd(nbody_q, sample_q):
        # The N-body 16-84 percentile range across bootstrap realizations is
        # the per-bin error. NOTE(review): a bin with zero spread yields
        # inf/nan here.
        err = np.percentile(nbody_q, 84, axis=0) - \
            np.percentile(nbody_q, 16, axis=0)
        return np.sum( np.abs(nbody_q - sample_q)*mass_profile/err, axis=1 )/\
            total_mass

    mewd_beta = _mewd(nbody_beta, sample_beta)
    mewd_vr2 = _mewd(nbody_vr2, sample_vr2)
    mewd_vp2 = _mewd(nbody_vp2, sample_vp2)
    mewd_vt2 = _mewd(nbody_vt2, sample_vt2)

    return mewd_beta, mewd_vr2, mewd_vp2, mewd_vt2
    

### Sample the constant-beta DFs and compare them to the N-body kinematics

In [None]:
verbose = True
# Set True to also write E-Lz and velocity-dispersion figures for each merger
# to fig_dir. The fitted bulge/disk/halo potential and the N-body energies are
# only needed for those figures, so they are only computed in that case.
make_plots = False
epsen_dens_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'density_profile/'
epsen_df_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'distribution_function/'
fig_dir = './fig/constant_beta/'
os.makedirs(fig_dir,exist_ok=True)

# DM halo information
dm_halo_version = 'poisson_nfw'
dm_halo_ncut = 500

# Stellar bulge and disk information
stellar_bulge_disk_version = 'miyamoto_disk_pswc_bulge_tps_halo'
stellar_bulge_disk_ncut = 2000

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower'
stellar_halo_density_ncut = 500

# Stellar halo rotation information
stellar_halo_rotation_version = 'tanh_rotation'
stellar_halo_rotation_ncut = 500

# Beta information
beta_version = 'constant_beta'
beta_ncut = 500

# Define density profiles
dm_halo_densfunc = pdens.NFWSpherical()
disk_densfunc = pdens.MiyamotoNagaiDisk()
bulge_densfunc = pdens.SinglePowerCutoffSpherical()
stellar_halo_densfunc = pdens.TwoPowerSpherical()
stellar_bulge_disk_densfunc = pdens.CompositeDensityProfile(
    [disk_densfunc,
     bulge_densfunc,
     pdens.TwoPowerSpherical()]
     )

# NOTE(review): DF sampling and bootstrap resampling are not seeded here, so
# the MWD values can differ from run to run.
mwd_cb = []
mwd_cb_self = []

for i in range(n_mw):
    if verbose: print(f'Plotting MW {i+1}/{n_mw}')

    # Get the primary and its z=0 cutout
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    # Star-particle data for the whole cutout. It does not change between
    # mergers, so read it once here and index it per merger below.
    pid = co.get_property('stars','ParticleIDs')
    all_star_orbs = co.get_orbs('stars')
    all_star_mass = co.get_masses('stars').to_value(apu.Msun)

    if make_plots:
        # Get the dark halo
        dm_halo_filename = os.path.join(epsen_dens_fitting_dir,'dm_halo/',
            dm_halo_version, str(z0_sid), 'sampler.pkl')
        dm_halo_pot = pfit.construct_pot_from_fit(dm_halo_filename,
            dm_halo_densfunc, dm_halo_ncut, ro=ro, vo=vo)
        
        # Get the stellar bulge and disk
        stellar_bulge_disk_filename = os.path.join(epsen_dens_fitting_dir,
            'stellar_bulge_disk/',stellar_bulge_disk_version,str(z0_sid),
            'sampler.pkl')
        stellar_pots = pfit.construct_pot_from_fit(stellar_bulge_disk_filename,
            stellar_bulge_disk_densfunc, stellar_bulge_disk_ncut, ro=ro, vo=vo)
        fpot = [stellar_pots[1], stellar_pots[0], dm_halo_pot] # bulge, disk, halo

    # Load the interpolator for the sphericalized potential
    interpolator_filename = os.path.join(epsen_dens_fitting_dir,
        'spherical_interpolated_potential/',str(z0_sid),'interp_potential.pkl')
    with open(interpolator_filename,'rb') as handle:
        interpot = pickle.load(handle)

    _mwd = []
    _mwd_self = []

    for j in range(n_major):
        if verbose: print(f'Constructing DF for major merger {j+1}/{n_major}')
        merger_dir = 'merger_'+str(j+1)

        # Star particles belonging to this major merger
        major_merger = primary.tree_major_mergers[j]
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        indx = np.where(np.isin(pid,upid))[0]
        orbs = all_star_orbs[indx]
        rs = orbs.r().to(apu.kpc).value
        n_star = len(orbs)
        star_mass = all_star_mass[indx]

        # Get the stellar halo density profile (denspot for the DF)
        stellar_halo_density_filename = os.path.join(epsen_dens_fitting_dir,
            'stellar_halo/',stellar_halo_density_version,str(z0_sid),
            merger_dir, 'sampler.pkl')
        denspot = pfit.construct_pot_from_fit(
            stellar_halo_density_filename, stellar_halo_densfunc, 
            stellar_halo_density_ncut, ro=ro, vo=vo)
        
        # Get the stellar halo rotation kernel (open() raises
        # FileNotFoundError naming the path if the sampler is missing)
        stellar_halo_rotation_filename = os.path.join(epsen_dens_fitting_dir,
            'stellar_halo/',stellar_halo_rotation_version,str(z0_sid),
            merger_dir, 'sampler.pkl')
        with open(stellar_halo_rotation_filename,'rb') as handle:
            stellar_halo_rotation_sampler = pickle.load(handle)
        stellar_halo_rotation_samples = stellar_halo_rotation_sampler.get_chain(
            discard=stellar_halo_rotation_ncut, flat=True)
        # Params are frot, chi
        stellar_halo_frot, stellar_halo_chi = \
            np.median(stellar_halo_rotation_samples,axis=0)
        
        # Load the distribution function and wrangle
        df_filename = os.path.join(epsen_df_fitting_dir,beta_version,
            str(z0_sid),merger_dir,'df.pkl')
        with open(df_filename,'rb') as handle:
            dfcb = pickle.load(handle)
        dfcb = pkin.reconstruct_anisotropic_df(dfcb, interpot, denspot)
        
        # Draw as many samples as N-body stars, starting 10% inside the
        # innermost star, then apply the fitted rotation
        sample = dfcb.sample(n=n_star, rmin=rs.min()*apu.kpc*0.9)
        sample = pkin.rotate_df_samples(sample,stellar_halo_frot,stellar_halo_chi)

        # Deviation of the DF samples from the N-body data, and of the N-body
        # data from itself (a baseline for bootstrap noise)
        _mwd.append( compute_mass_error_weighted_deviation_beta_vdisp(orbs, 
            sample, star_mass, n_bs=10) )
        _mwd_self.append( compute_mass_error_weighted_deviation_beta_vdisp(orbs,
            orbs, star_mass, n_bs=10) )

        if make_plots:
            if verbose: print('Plotting')
            pe = co.get_potential_energy('stars')[indx].to_value(apu.km**2/apu.s**2)
            vels = co.get_velocities('stars')[indx].to_value(apu.km/apu.s)
            energy = pe + 0.5*np.linalg.norm(vels,axis=1)**2
            this_fig_dir = os.path.join(fig_dir, str(z0_sid), merger_dir)
            os.makedirs(this_fig_dir,exist_ok=True)

            fig,axs = plot_ELz(orbs, sample, energy, fpot, interpot)
            fig.tight_layout()
            fig.savefig(os.path.join(this_fig_dir,'energy_Lz.png'), dpi=300)
            plt.close(fig)

            fig,axs = plot_beta_vdisp(orbs, sample)
            fig.tight_layout()
            fig.savefig(os.path.join(this_fig_dir,'velocity_dispersions.png'),
                dpi=300)
            plt.close(fig)

    mwd_cb.append( _mwd )
    mwd_cb_self.append( _mwd_self )

os.makedirs('./data/', exist_ok=True)
with open('./data/mwd_cb.pkl','wb') as handle:
    pickle.dump(mwd_cb, handle)
with open('./data/mwd_cb_self.pkl','wb') as handle:
    pickle.dump(mwd_cb_self, handle)


### Sample the Osipkov-Merritt DFs and compare them to the N-body kinematics

In [None]:
verbose = True
# Set True to also write E-Lz and velocity-dispersion figures for each merger
# to fig_dir. The fitted bulge/disk/halo potential and the N-body energies are
# only needed for those figures, so they are only computed in that case.
make_plots = False
epsen_dens_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'density_profile/'
epsen_df_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'distribution_function/'
fig_dir = './fig/osipkov_merritt/'
os.makedirs(fig_dir,exist_ok=True)

# DM halo information
dm_halo_version = 'poisson_nfw'
dm_halo_ncut = 500

# Stellar bulge and disk information
stellar_bulge_disk_version = 'miyamoto_disk_pswc_bulge_tps_halo'
stellar_bulge_disk_ncut = 2000

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower'
stellar_halo_density_ncut = 500

# Stellar halo rotation information
stellar_halo_rotation_version = 'tanh_rotation'
stellar_halo_rotation_ncut = 500

# Beta information
beta_version = 'osipkov_merritt'
beta_ncut = 500

# Define density profiles
dm_halo_densfunc = pdens.NFWSpherical()
disk_densfunc = pdens.MiyamotoNagaiDisk()
bulge_densfunc = pdens.SinglePowerCutoffSpherical()
stellar_halo_densfunc = pdens.TwoPowerSpherical()
stellar_bulge_disk_densfunc = pdens.CompositeDensityProfile(
    [disk_densfunc,
     bulge_densfunc,
     pdens.TwoPowerSpherical()]
     )

# NOTE(review): DF sampling and bootstrap resampling are not seeded here, so
# the MWD values can differ from run to run.
mwd_om = []
mwd_om_self = []

for i in range(n_mw):
    if verbose: print(f'Plotting MW {i+1}/{n_mw}')

    # Get the primary and its z=0 cutout
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    # Star-particle data for the whole cutout. It does not change between
    # mergers, so read it once here and index it per merger below.
    pid = co.get_property('stars','ParticleIDs')
    all_star_orbs = co.get_orbs('stars')
    all_star_mass = co.get_masses('stars').to_value(apu.Msun)

    if make_plots:
        # Get the dark halo
        dm_halo_filename = os.path.join(epsen_dens_fitting_dir,'dm_halo/',
            dm_halo_version, str(z0_sid), 'sampler.pkl')
        dm_halo_pot = pfit.construct_pot_from_fit(dm_halo_filename,
            dm_halo_densfunc, dm_halo_ncut, ro=ro, vo=vo)
        
        # Get the stellar bulge and disk
        stellar_bulge_disk_filename = os.path.join(epsen_dens_fitting_dir,
            'stellar_bulge_disk/',stellar_bulge_disk_version,str(z0_sid),
            'sampler.pkl')
        stellar_pots = pfit.construct_pot_from_fit(stellar_bulge_disk_filename,
            stellar_bulge_disk_densfunc, stellar_bulge_disk_ncut, ro=ro, vo=vo)
        fpot = [stellar_pots[1], stellar_pots[0], dm_halo_pot] # bulge, disk, halo

    # Load the interpolator for the sphericalized potential
    interpolator_filename = os.path.join(epsen_dens_fitting_dir,
        'spherical_interpolated_potential/',str(z0_sid),'interp_potential.pkl')
    with open(interpolator_filename,'rb') as handle:
        interpot = pickle.load(handle)

    _mwd = []
    _mwd_self = []

    for j in range(n_major):
        if verbose: print(f'Constructing DF for major merger {j+1}/{n_major}')
        merger_dir = 'merger_'+str(j+1)

        # Star particles belonging to this major merger
        major_merger = primary.tree_major_mergers[j]
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        indx = np.where(np.isin(pid,upid))[0]
        orbs = all_star_orbs[indx]
        rs = orbs.r().to(apu.kpc).value
        n_star = len(orbs)
        star_mass = all_star_mass[indx]

        # Get the stellar halo density profile (denspot for the DF)
        stellar_halo_density_filename = os.path.join(epsen_dens_fitting_dir,
            'stellar_halo/',stellar_halo_density_version,str(z0_sid),
            merger_dir, 'sampler.pkl')
        denspot = pfit.construct_pot_from_fit(
            stellar_halo_density_filename, stellar_halo_densfunc, 
            stellar_halo_density_ncut, ro=ro, vo=vo)
        
        # Get the stellar halo rotation kernel (open() raises
        # FileNotFoundError naming the path if the sampler is missing)
        stellar_halo_rotation_filename = os.path.join(epsen_dens_fitting_dir,
            'stellar_halo/',stellar_halo_rotation_version,str(z0_sid),
            merger_dir, 'sampler.pkl')
        with open(stellar_halo_rotation_filename,'rb') as handle:
            stellar_halo_rotation_sampler = pickle.load(handle)
        stellar_halo_rotation_samples = stellar_halo_rotation_sampler.get_chain(
            discard=stellar_halo_rotation_ncut, flat=True)
        # Params are frot, chi
        stellar_halo_frot, stellar_halo_chi = \
            np.median(stellar_halo_rotation_samples,axis=0)

        # Load the distribution function and wrangle
        df_filename = os.path.join(epsen_df_fitting_dir,beta_version,
            str(z0_sid),merger_dir,'df.pkl')
        with open(df_filename,'rb') as handle:
            dfom = pickle.load(handle)
        dfom = pkin.reconstruct_anisotropic_df(dfom, interpot, denspot)
        
        # Draw as many samples as N-body stars, starting 10% inside the
        # innermost star, then apply the fitted rotation
        sample = dfom.sample(n=n_star, rmin=rs.min()*apu.kpc*0.9)
        sample = pkin.rotate_df_samples(sample,stellar_halo_frot,stellar_halo_chi)

        # Deviation of the DF samples from the N-body data, and of the N-body
        # data from itself (a baseline for bootstrap noise)
        _mwd.append( compute_mass_error_weighted_deviation_beta_vdisp(orbs, 
            sample, star_mass, n_bs=10) )
        _mwd_self.append( compute_mass_error_weighted_deviation_beta_vdisp(orbs,
            orbs, star_mass, n_bs=10) )

        if make_plots:
            if verbose: print('Plotting')
            pe = co.get_potential_energy('stars')[indx].to_value(apu.km**2/apu.s**2)
            vels = co.get_velocities('stars')[indx].to_value(apu.km/apu.s)
            energy = pe + 0.5*np.linalg.norm(vels,axis=1)**2
            this_fig_dir = os.path.join(fig_dir, str(z0_sid), merger_dir)
            os.makedirs(this_fig_dir,exist_ok=True)

            fig,axs = plot_ELz(orbs, sample, energy, fpot, interpot)
            fig.tight_layout()
            fig.savefig(os.path.join(this_fig_dir,'energy_Lz.png'), dpi=300)
            plt.close(fig)

            fig,axs = plot_beta_vdisp(orbs, sample)
            fig.tight_layout()
            fig.savefig(os.path.join(this_fig_dir,'velocity_dispersions.png'),
                dpi=300)
            plt.close(fig)
    
    mwd_om.append( _mwd )
    mwd_om_self.append( _mwd_self )

os.makedirs('./data/', exist_ok=True)
with open('./data/mwd_om.pkl','wb') as handle:
    pickle.dump(mwd_om, handle)
with open('./data/mwd_om_self.pkl','wb') as handle:
    pickle.dump(mwd_om_self, handle)

### Wrangle the data

In [None]:
# Reload the MWD results pickled by the constant-beta and Osipkov-Merritt
# cells above. This cell only reads them; nothing is overwritten.
def _load_pickle(filename):
    '''Unpickle and return the object stored in filename.'''
    with open(filename,'rb') as handle:
        return pickle.load(handle)

mwd_cb = _load_pickle('./data/mwd_cb.pkl')
mwd_cb_self = _load_pickle('./data/mwd_cb_self.pkl')

mwd_om = _load_pickle('./data/mwd_om.pkl')
mwd_om_self = _load_pickle('./data/mwd_om_self.pkl')

In [None]:
# Total stellar and DM mass of the particles that get_unique_particle_ids()
# attributes to each major merger. Cached in ./data/ and recomputed only when
# either cache file is missing.
if os.path.exists('./data/star_mass.pkl') and os.path.exists('./data/dm_mass.pkl'):
    with open('./data/star_mass.pkl','rb') as handle:
        star_mass = pickle.load(handle)
    with open('./data/dm_mass.pkl','rb') as handle:
        dm_mass = pickle.load(handle)
else:
    verbose = True
    star_mass, dm_mass = [], []

    for i in range(n_mw):
        if verbose: print(f'Getting MW {i+1}/{n_mw}')

        # z=0 cutout of the primary: particle IDs and masses of DM and stars
        primary = tree_primaries[i]
        n_major = primary.n_major_mergers
        co = pcutout.TNGCutout(primary.get_cutout_filename(mw_analog_dir,
            snapnum=primary.snapnum[0]))
        dmpid = co.get_property('dm','ParticleIDs')
        dmass = co.get_masses('dm').to_value(apu.Msun)
        spid = co.get_property('stars','ParticleIDs')
        smass = co.get_masses('stars').to_value(apu.Msun)

        # Sum the masses of each merger's particles
        merger_star_mass, merger_dm_mass = [], []
        for j in range(n_major):
            if verbose: print(f'Merger {j+1}/{n_major}')
            major_merger = primary.tree_major_mergers[j]
            dmupid = major_merger.get_unique_particle_ids('dm',data_dir=data_dir)
            supid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
            merger_star_mass.append( np.sum(smass[np.isin(spid, supid)]) )
            merger_dm_mass.append( np.sum(dmass[np.isin(dmpid, dmupid)]) )

        star_mass.append(merger_star_mass)
        dm_mass.append(merger_dm_mass)

    os.makedirs('./data/', exist_ok=True)
    with open('./data/star_mass.pkl','wb') as handle:
        pickle.dump(star_mass, handle)
    with open('./data/dm_mass.pkl','wb') as handle:
        pickle.dump(dm_mass, handle)

In [None]:
### Finally construct a structured numpy array

# Record layout: the four MWD entries are stored as Python objects, and the
# remaining fields are per-merger scalars
dt = np.dtype([('mwd_cb',object),
               ('mwd_cb_self',object),
               ('mwd_om',object),
               ('mwd_om_self',object),
               ('star_mass',float),
               ('dm_mass',float),
               ('star_mass_ratio',float),
               ('dm_mass_ratio',float),
               ('z_merger',float),
               ('z0_sid',int),
               ('major_merger',int),
               ('major_mlpid',int),
               ])

# One record per (MW analog, major merger) pair
data = []
for i in range(n_mw):
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers
    for j in range(n_major):
        major_merger = primary.tree_major_mergers[j]
        major_mlpid = major_merger.secondary_mlpid
        data.append((
            mwd_cb[i][j], mwd_cb_self[i][j],
            mwd_om[i][j], mwd_om_self[i][j],
            star_mass[i][j], dm_mass[i][j],
            major_merger.star_mass_ratio, major_merger.dm_mass_ratio,
            putil.snapshot_to_redshift( major_merger.merger_snapnum ),
            z0_sid, j+1, major_mlpid,
        ))

mwddata = np.array(data,dtype=dt)

os.makedirs('./data/', exist_ok=True)
np.save('./data/mwddata.npy', mwddata)

### Plot the mass-weighted $\sigma$ differences for each DF

In [None]:
def _self_deviation_percentiles(mwd_self, percentiles=(16,50,84), n_quantities=4):
    '''Pool the bootstrap MWD values of all mergers and return percentiles.

    Args:
        mwd_self (list): Nested list where mwd_self[i][j][k] holds the
            bootstrap MWD values of quantity k for merger j of MW analog i
        percentiles (sequence): Percentiles to compute
        n_quantities (int): Number of quantities per merger

    Returns:
        np.ndarray: Shape (n_quantities, len(percentiles))
    '''
    percs = np.zeros((n_quantities, len(percentiles)))
    for k in range(n_quantities):
        # Concatenate once instead of growing an array inside the loop. The
        # leading empty float array keeps the float64 promotion and makes an
        # empty input concatenate cleanly.
        pooled = np.concatenate([np.array([])] +
            [np.atleast_1d(mwd_merger[k]) for mwd_mw in mwd_self
                                           for mwd_merger in mwd_mw])
        percs[k] = np.percentile(pooled, list(percentiles))
    return percs

mwd_cb_self_percs = _self_deviation_percentiles(mwd_cb_self)
mwd_om_self_percs = _self_deviation_percentiles(mwd_om_self)


### Figure showing each MWD as a function of merger stellar mass

# Marker style and panel labels, shared with the figures that follow
facecolor='none'
edgecolor='Black'
s=10
ylabels = [r'$\beta$', r'$\sigma_{r}$', r'$\sigma_{\phi}$', r'$\sigma_{\theta}$']
# Running count of plotted mergers (continued by the later figures)
star_counter = 0

# One point per major merger: log10 of its stellar mass and, for each of the
# four quantities, the median constant-beta MWD over bootstrap realizations
log_mstar = []
median_mwd = []
for i in range(n_mw):
    primary = tree_primaries[i]
    for j in range(primary.n_major_mergers):
        log_mstar.append(np.log10(star_mass[i][j]))
        median_mwd.append([np.median(mwd_cb[i][j][k]) for k in range(4)])
        star_counter += 1
median_mwd = np.reshape(np.array(median_mwd, dtype=float), (-1,4))

fig = plt.figure(figsize=(12,4))
axs = fig.subplots(nrows=1, ncols=4)

for k in range(4):
    # A single scatter per panel rather than one artist per merger
    axs[k].scatter(log_mstar, median_mwd[:,k], 
                   facecolor=facecolor, edgecolor=edgecolor, s=s )
    axs[k].set_xlabel(r'$\log_{10} M_{\star}$')
    axs[k].set_ylabel(r'CB $\delta$'+ylabels[k])
    if k == 0:
        axs[k].set_ylim(0,5)
    else:
        axs[k].set_ylim(0,20)
    # Median and 16-84 range of the N-body self-comparison baseline
    axs[k].axhline(mwd_cb_self_percs[k,1], color='Black', ls='solid')
    axs[k].axhspan(mwd_cb_self_percs[k,0], mwd_cb_self_percs[k,2], 
        color='Black', alpha=0.25)

fig.tight_layout()
plt.show()

####################

fig = plt.figure(figsize=(12,4))
axs = fig.subplots(nrows=1, ncols=4)

for i in range(n_mw):
    # if i > 5: continue
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_snap = len(primary.snapnum)
    n_major = primary.n_major_mergers

    for j in range(n_major):

        for k in range(4):
            axs[k].scatter(np.log10(star_mass[i][j]), np.median(mwd_om[i][j][k]), 
                        facecolor=facecolor, edgecolor=edgecolor, s=s )

        star_counter += 1

for k in range(4):
    axs[k].set_xlabel(r'$\log_{10} M_{\star}$')
    axs[k].set_ylabel(r'OM $\delta$'+ylabels[k])
    if k == 0:
        axs[k].set_ylim(0,5)
    else:
        axs[k].set_ylim(0,20)
    axs[k].axhline(mwd_om_self_percs[k,1], color='Black', ls='solid')
    axs[k].axhspan(mwd_om_self_percs[k,0], mwd_om_self_percs[k,2], 
        color='Black', alpha=0.25)


fig.tight_layout()
plt.show()

####################

fig = plt.figure(figsize=(12,4))
axs = fig.subplots(nrows=1, ncols=4)

for i in range(n_mw):
    # if i > 5: continue
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_snap = len(primary.snapnum)
    n_major = primary.n_major_mergers

    for j in range(n_major):

        for k in range(4):
            axs[k].scatter(np.median(mwd_cb[i][j][k]), np.median(mwd_om[i][j][k]), 
                        facecolor=facecolor, edgecolor=edgecolor, s=s )

        star_counter += 1

for k in range(4):
    axs[k].set_xlabel(r'CB $\delta$'+ylabels[k])
    axs[k].set_ylabel(r'OM $\delta$'+ylabels[k])
    if k == 0:
        axs[k].set_xlim(0,5)
        axs[k].set_ylim(0,5)
    else:
        axs[k].set_xlim(0,10)
        axs[k].set_ylim(0,10)
    axs[k].axline(xy1 = [0,0], slope=1., color='k', ls='--')
    axs[k].axvline(mwd_cb_self_percs[k,1], color='Black', ls='solid')
    axs[k].axvspan(mwd_cb_self_percs[k,0], mwd_cb_self_percs[k,2], 
        color='Black', alpha=0.25)
    axs[k].axhline(mwd_om_self_percs[k,1], color='Black', ls='solid')
    axs[k].axhspan(mwd_om_self_percs[k,0], mwd_om_self_percs[k,2], 
        color='Black', alpha=0.25)

fig.tight_layout()
plt.show()

In [None]:
# Ratio of CB to OM per-merger median MWD for each quantity, plotted against
# merger properties. Grid layout: column i = merger property xs[i],
# row j = ratio ys[j]. The red curve is a boxcar running mean ordered by x.

window_size = 10  # running-mean width, in number of mergers


def _running_mean(x, y, width=window_size):
    """Sort finite (x, y) pairs by x and return the boxcar running means of both.

    Non-finite pairs (e.g. a ratio with a zero OM median) are dropped first so
    they do not propagate through the convolution. Returns two empty arrays
    when fewer than ``width`` finite pairs remain: np.convolve swaps its
    inputs when the kernel is longer than the data, so mode='valid' would
    otherwise return a meaningless curve.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    if width < 1 or x.size < width:
        return np.array([]), np.array([])
    order = np.argsort(x)
    kernel = np.ones(width) / width
    return (np.convolve(x[order], kernel, mode='valid'),
            np.convolve(y[order], kernel, mode='valid'))


this_star_mass = []
this_zmerge = []
this_star_mass_ratio = []
this_dm_mass_ratio = []
_cb_om_ratios = [[], [], [], []]  # beta, sigma_r, sigma_phi, sigma_theta

for i in range(n_mw):
    primary = tree_primaries[i]
    for j in range(primary.n_major_mergers):
        major_merger = primary.tree_major_mergers[j]

        # Properties of the merger
        this_star_mass.append( star_mass[i][j] )
        this_zmerge.append( putil.snapshot_to_redshift( major_merger.merger_snapnum ) )
        this_star_mass_ratio.append( major_merger.star_mass_ratio )
        this_dm_mass_ratio.append( major_merger.dm_mass_ratio )

        # CB/OM ratio of the per-merger median MWD, quantity by quantity
        for k, _ratio_list in enumerate(_cb_om_ratios):
            _ratio_list.append(
                np.median(mwd_cb[i][j][k]) / np.median(mwd_om[i][j][k]) )

this_star_mass = np.array(this_star_mass).flatten()
this_zmerge = np.array(this_zmerge).flatten()
# Inverted so the values match the primary/secondary (p/s) axis labels used
# below. NOTE(review): presumably the attributes are stored as
# secondary/primary; confirm against the merger-tree code.
this_star_mass_ratio = 1/np.array(this_star_mass_ratio).flatten()
this_dm_mass_ratio = 1/np.array(this_dm_mass_ratio).flatten()

ratio_beta, ratio_sigma_r, ratio_sigma_phi, ratio_sigma_theta = [
    np.array(_r).flatten() for _r in _cb_om_ratios]


# Every axis in a column plots the same x data and every axis in a row the
# same y data. Sharing them keeps the hidden inner tick labels accurate.
fig, _grid = plt.subplots(nrows=4, ncols=4, figsize=(12,12),
                          sharex='col', sharey='row')
axs = _grid.T  # axs[i, j]: column i (x quantity), row j (y quantity)

xlabels = [r'$M_{\star}$', r'$z_{\rm merger}$', 
           r'$M_{\rm \star,p}/M_{\rm \star,s}$', 
           r'$M_{\rm DM,p}/M_{\rm DM,s}$']

xs = [this_star_mass, this_zmerge, this_star_mass_ratio, this_dm_mass_ratio]
ys = [ratio_beta, ratio_sigma_r, ratio_sigma_phi, ratio_sigma_theta]

text_quantities = [r'$\beta$', r'$\sigma_{r}$', 
                   r'$\sigma_{\phi}$', r'$\sigma_{\theta}$']

_log_x_columns = (0, 2, 3)  # log x-axis for stellar mass and both mass ratios

for i in range(4):
    for j in range(4):
        ax = axs[i,j]
        ax.scatter(xs[i], ys[j], 
                   facecolor=facecolor, edgecolor=edgecolor, s=s )
        ax.axhline(1., color='Gray', linestyle='dashed')
        if i in _log_x_columns:
            ax.set_xscale('log')
        if i > 0:
            ax.tick_params(labelleft=False)
        if j < 3:
            ax.tick_params(labelbottom=False)

        ma_x, ma_y = _running_mean(xs[i], ys[j])
        if ma_x.size:  # skip when too few finite points for one window
            ax.plot(ma_x, ma_y, color='Red', ls='solid', lw=1, zorder=3)

for k in range(4):
    axs[k,3].set_xlabel(xlabels[k])
    axs[0,k].text(0.9,0.9, text_quantities[k], 
        transform=axs[0,k].transAxes)
    axs[0,k].set_ylabel(r'$\delta_{\rm CB}/\delta_{\rm OM}$')

fig.tight_layout()
plt.show()