In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 4_anisotropic_df_jeans.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Compute Jeans quantities and summary statistics.
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle
import pdb
import warnings

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Analysis
import scipy.interpolate

## galpy
from galpy import potential

## Project-specific
# Walk up from the working directory until a src/ directory is found, then
# put it on sys.path so the tng_dfs package can be imported.
src_path = 'src/'
while True:
    if os.path.exists(src_path): break
    # Directory in which src/ was just looked for. Stop once the project root
    # (tng-dfs) or the filesystem root has been searched without success;
    # testing the last component of realpath(src_path) would always give
    # 'src' and the search would never terminate.
    search_dir = os.path.dirname(os.path.realpath(src_path))
    if os.path.basename(search_dir) == 'tng-dfs' or \
        search_dir == os.path.dirname(search_dir):
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import densprofile as pdens
from tng_dfs import fitting as pfit
from tng_dfs import kinematics as pkin
from tng_dfs import util as putil

### Notebook setup

# Render figures inline at retina resolution and apply the project style sheet.
%matplotlib inline
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
%config InlineBackend.figure_format = 'retina'
# Automatically reload edited project modules (e.g. tng_dfs) before each cell runs.
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: read the project config and unpack the entries this notebook uses
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR', 'MW_ANALOG_DIR', 'FIG_DIR_BASE', 'FITTING_DIR_BASE',
            'RO', 'VO', 'ZO', 'LITTLE_H', 'MW_MASS_RANGE']
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
 ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW analog subhalos (mass-range and bulge/disk-fraction cuts are requested
# through the keyword arguments)
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure paths: a local ./fig/ directory, plus a per-notebook directory under
# the configured figure base (only the latter is created here)
local_fig_dir = './fig/'
fig_dir = os.path.join(fig_dir_base,
    'notebooks/5_compare_distribution_functions/4_anisotropic_df_jeans/')
os.makedirs(fig_dir, exist_ok=True)
show_plots = False

# Merger-tree products: the primary trees and their major mergers
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
with open(tree_primary_filename, 'rb') as f:
    tree_primaries = pickle.load(f)
with open(tree_major_mergers_filename, 'rb') as f:
    tree_major_mergers = pickle.load(f)

# Number of MW analogs (one primary tree each)
n_mw = len(tree_primaries)

### Compute Jeans equation summary statistics

In [None]:
n_bootstrap = 10
verbose = True

# Pathing
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function')
analysis_version = 'v1.1'
analysis_dir = os.path.join(mw_analog_dir,'analysis',analysis_version)

# DF samples: one structured array per DF type (constant beta,
# Osipkov-Merritt, Osipkov-Merritt combination); rows are matched to
# individual mergers inside the loop below
sample_data_cb = np.load(os.path.join(analysis_dir,'sample_data_cb.npy'),
    allow_pickle=True)
sample_data_om = np.load(os.path.join(analysis_dir,'sample_data_om.npy'),
    allow_pickle=True)
sample_data_om2 = np.load(os.path.join(analysis_dir,'sample_data_om2.npy'),
    allow_pickle=True)

# Potential interpolator version
interpot_version = 'all_star_dm_enclosed_mass'

# Stellar halo density / rotation / beta settings.
# NOTE(review): none of these are read by this cell.
stellar_halo_density_version = 'poisson_twopower_softening'
stellar_halo_density_ncut = 500
stellar_halo_rotation_dftype = 'tanh_rotation'
stellar_halo_rotation_version = 'asymmetry_fit'
stellar_halo_rotation_ncut = 500
beta_ncut = 500

# Define density profiles (also not read by this cell)
stellar_halo_densfunc = pdens.TwoPowerSpherical()


def _find_df_sample(sample_data, z0_sid, major_acc_sid, major_mlpid,
    merger_number):
    '''Return the DF sample in sample_data belonging to one major merger.

    Args:
        sample_data (np.ndarray) - structured array with 'z0_sid',
            'major_acc_sid', 'major_mlpid', 'merger_number' and 'sample'
            fields
        z0_sid, major_acc_sid, major_mlpid, merger_number - identifiers of
            the merger to look up

    Returns:
        sample - the 'sample' entry of the single matching row

    Raises:
        AssertionError - if the number of matching rows is not exactly one
    '''
    mask = (sample_data['z0_sid'] == z0_sid) &\
           (sample_data['major_acc_sid'] == major_acc_sid) &\
           (sample_data['major_mlpid'] == major_mlpid) &\
           (sample_data['merger_number'] == merger_number)
    indx = np.where(mask)[0]
    assert len(indx) == 1, (f'Expected exactly one DF sample for z0_sid '
        f'{z0_sid}, merger {merger_number}; found {len(indx)}')
    return sample_data['sample'][indx[0]]


def _df_sample_potential_energy(sample, interpot):
    '''Evaluate interpot at the sample's (R, z) positions, in km^2/s^2.'''
    return potential.evaluatePotentials(interpot, sample.R(),
        sample.z()).to_value(apu.km**2/apu.s**2)


def _spherical_jeans_summary(orbs, pe, adaptive_binning_kwargs, n_bootstrap):
    '''Compute spherical Jeans quantities and their weighted average.

    Args:
        orbs - orbits (N-body star particles or a DF sample)
        pe (np.ndarray) - potential energy of each orbit in km^2/s^2
        adaptive_binning_kwargs (dict) - passed to calculate_spherical_jeans
        n_bootstrap (int) - number of bootstrap resamples

    Returns:
        Jw, Jwd - the two values returned by calculate_weighted_average_J
            (stored as 'J_mean' / 'J2_mean'); both NaN when qs[2] does not
            have a positive sum
        Js, rs, qs - the outputs of calculate_spherical_jeans
    '''
    Js,rs,qs = pkin.calculate_spherical_jeans(orbs, pe=pe,
        n_bootstrap=n_bootstrap, rs_is_bin_mean_r=True,
        adaptive_binning=adaptive_binning_kwargs)
    # Only average when qs[2] has a positive total (a NaN total also fails)
    if np.sum(qs[2]) > 0.:
        Jw,Jwd = pkin.calculate_weighted_average_J(
            Js,rs,qs=qs,handle_nans=True)
    else:
        Jw,Jwd = np.nan,np.nan
    return Jw,Jwd,Js,rs,qs


# The softening length does not depend on the merger, so compute it once
r_softening = putil.get_softening_length('stars', z=0, physical=True)

# DF sample arrays, in the order their results are stored in J_vals
df_sample_data = [sample_data_cb, sample_data_om, sample_data_om2]

J_vals = []

for i in range(n_mw):
    # Get the primary and its cutout at primary.snapnum[0]
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()
    pid = co.get_property('stars','ParticleIDs')

    # Star data for the whole cutout. Each merger below selects a subset of
    # it, so read it once per primary rather than once per merger.
    orbs_all = co.get_orbs('stars')
    masses_all = co.get_masses('stars')
    pe_all = co.get_potential_energy('stars')

    # Load the interpolator for the sphericalized potential
    interpolator_filename = os.path.join(dens_fitting_dir,
        'spherical_interpolated_potential/',interpot_version,
        str(z0_sid),'interp_potential.pkl')
    with open(interpolator_filename,'rb') as handle:
        interpot = pickle.load(handle)

    for j in range(n_major):
        merger_number = j+1
        if verbose: print(f'Calculating Jeans equation for MW analog '
                          f'{i+1}/{n_mw}, merger {j+1}/{n_major}', end='\r')

        # Get the major merger and select its star particles in the cutout
        major_merger = primary.tree_major_mergers[j]
        major_acc_sid = major_merger.subfind_id[0]
        major_mlpid = major_merger.secondary_mlpid
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        indx = np.where(np.isin(pid,upid))[0]

        orbs = orbs_all[indx]
        masses = masses_all[indx].to_value(apu.Msun)
        rs = orbs.r().to_value(apu.kpc)
        pe = pe_all[indx].to_value(apu.km**2/apu.s**2)

        # Only keep orbits with radius greater than the softening length
        mask = rs > r_softening
        orbs, masses, rs, pe = orbs[mask], masses[mask], rs[mask], pe[mask]

        # Adaptive binning kwargs, shared by the N-body data and every DF
        # sample of this merger. 'n' is the number per bin: 500, or a tenth
        # of the stars for small samples.
        # NOTE(review): 'rmin' is 0. even though orbits inside r_softening
        # were removed above; confirm whether max(min(rs), r_softening) was
        # intended instead.
        n_bin = np.min([500, len(orbs)//10])
        adaptive_binning_kwargs = {
            'n':n_bin,
            'rmin':0.,
            'rmax':np.max(rs),
            'bin_mode':'exact numbers',
            'bin_equal_n':True,
            'end_mode':'ignore',
            'bin_cents_mode':'median',
        }

        # Spherical Jeans quantities for the N-body star particles
        nb_result = _spherical_jeans_summary(orbs, pe,
            adaptive_binning_kwargs, n_bootstrap)

        # ... and for each DF sample matched to this merger, with potential
        # energies from the sphericalized potential interpolator
        df_results = []
        for sample_data in df_sample_data:
            sample = _find_df_sample(sample_data, z0_sid, major_acc_sid,
                major_mlpid, merger_number)
            pe_sample = _df_sample_potential_energy(sample, interpot)
            df_results.extend(_spherical_jeans_summary(sample, pe_sample,
                adaptive_binning_kwargs, n_bootstrap))

        # Save the values (field order matches dt below)
        J_vals.append(
            (*nb_result, *df_results,
             z0_sid, merger_number, major_mlpid, np.sum(masses))
            )

# Terminate the carriage-return progress line
if verbose: print()

dt = np.dtype([ ('J_mean',float),
                ('J2_mean',float),
                ('Js',object),
                ('rs',object),
                ('qs',object),
                ('J_mean_cb',float),
                ('J2_mean_cb',float),
                ('Js_cb',object),
                ('rs_cb',object),
                ('qs_cb',object),
                ('J_mean_om',float),
                ('J2_mean_om',float),
                ('Js_om',object),
                ('rs_om',object),
                ('qs_om',object),
                ('J_mean_om2',float),
                ('J2_mean_om2',float),
                ('Js_om2',object),
                ('rs_om2',object),
                ('qs_om2',object),
                ('z0_sid',int),
                ('major_merger',int),
                ('major_mlpid',int),
                ('star_mass',float)
                ])
J_vals = np.array(J_vals, dtype=dt)

np.save(os.path.join(analysis_dir,'J_vals.npy'), J_vals)

### Wrangle the data

In [None]:
# Location of the saved analysis products
analysis_version = 'v1.1'
analysis_dir = os.path.join(mw_analog_dir, 'analysis', analysis_version)

# Reload the structured arrays: the Jeans summary statistics (J_vals.npy,
# saved by the Jeans cell) and the merger information (merger_data.npy), so
# the plotting cells below can run without recomputing them
J_vals, merger_data = (
    np.load(os.path.join(analysis_dir, fname), allow_pickle=True)
    for fname in ('J_vals.npy', 'merger_data.npy')
)

### Create plots showing mean J

In [None]:
# Compare the weighted-average J of the N-body data (x) with each DF sample
# (y), with marginal histograms annotated by mean and standard deviation.
Jm_samples = [J_vals['J_mean_cb'],
              J_vals['J_mean_om'],
              J_vals['J_mean_om2']]
sample_suffix = ['CB','OM','OM2']
Jm_lim = [-1.1,1.1]
# Only values strictly inside this range enter the annotated mean / std.
# Comparisons with NaN are False, so NaN entries (stored when no weighted
# average could be computed) are excluded as well.
Jm_stat_lim = [-5,5]

s=10
facecolor='none'
edgecolor='Black'
alpha=0.5
annotate_fs = 8


def _annotate_mean_std(ax, values, x, y, stat_lim=Jm_stat_lim,
    fontsize=annotate_fs):
    '''Write the mean and standard deviation of values onto ax.

    Args:
        ax (matplotlib Axes) - axes to annotate
        values (np.ndarray) - values to summarize
        x, y (float) - text position in axes coordinates
        stat_lim (list) - [low, high]; only values strictly inside are used
        fontsize (float) - annotation font size

    Returns:
        mu, sigma (float) - the annotated mean and standard deviation
    '''
    values = np.asarray(values)
    in_range = (values > stat_lim[0]) & (values < stat_lim[1])
    mu, sigma = np.mean(values[in_range]), np.std(values[in_range])
    ax.text(x, y, r'$\mu = $'+str(round(mu,2))+'\n'+r'$\sigma = $'+str(round(sigma,2)),
        transform=ax.transAxes, fontsize=fontsize)
    return mu, sigma


# Make the figure
fig = plt.figure(figsize=(4,7))
gs = mpl.gridspec.GridSpec(nrows=10,ncols=4,figure=fig)

# Main axes
axs = [fig.add_subplot(gs[1:4,0:3]),
       fig.add_subplot(gs[4:7,0:3]),
       fig.add_subplot(gs[7:10,0:3])
        ]
# Top histogram axes
tax = fig.add_subplot(gs[0,0:3])
# Right histogram axes
raxs = [fig.add_subplot(gs[1:4,3]),
        fig.add_subplot(gs[4:7,3]),
        fig.add_subplot(gs[7:10,3])
        ]

# Loop over main axes
for i in range(3):
    axs[i].scatter(J_vals['J_mean'], Jm_samples[i], s=s, facecolor=facecolor,
        edgecolor=edgecolor, alpha=alpha)
    axs[i].set_xlim(Jm_lim)
    axs[i].set_ylim(Jm_lim)
    if i == 2:
        axs[i].set_xlabel(r'$\overline{\mathcal{J}}_{\mathrm{data}}$')
    else:
        axs[i].tick_params(labelbottom=False)
    axs[i].set_ylabel(r'$\overline{\mathcal{J}}_{\mathrm{'+sample_suffix[i]+'}}$')
    axs[i].axhline(0., color='Black', linestyle='--')
    axs[i].axvline(0., color='Black', linestyle='--')

# Do the top histogram. Its statistics use the same range cut as the right
# histograms, so NaN entries cannot turn the annotation into nan.
tax.hist(J_vals['J_mean'], bins=21, range=Jm_lim, histtype='step',
    color='Black', orientation='vertical')
tax.tick_params(labelbottom=False)
tax.set_xlim(Jm_lim)
tax.set_ylabel(r'$N$')
tax.axvline(0, color='Black', linestyle='--')
_annotate_mean_std(tax, J_vals['J_mean'], 0.1, 0.5)

# Do the right histograms
for i in range(3):
    raxs[i].hist(Jm_samples[i], bins=21, range=Jm_lim, histtype='step',
        color='Black', orientation='horizontal')
    raxs[i].tick_params(labelleft=False)
    raxs[i].set_xlabel(r'$N$')
    if i == 0:
        # The top-right panel shows its count axis on its upper edge
        raxs[i].xaxis.set_label_position('top')
        raxs[i].tick_params(labelbottom=False, labeltop=True)
    raxs[i].set_ylim(Jm_lim)
    raxs[i].axhline(0, color='Black', linestyle='--')
    _annotate_mean_std(raxs[i], Jm_samples[i], 0.25, 0.8)

# Set ticks properly
ticks = [-1,-0.5,0,0.5,1]
tax.set_xticks(ticks)
for i in range(3):
    axs[i].set_xticks(ticks)
    axs[i].set_yticks(ticks)
    raxs[i].set_yticks(ticks)

fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.1)
# savefig does not create missing directories
os.makedirs(local_fig_dir, exist_ok=True)
fig.savefig(os.path.join(local_fig_dir,'J_mean.pdf'))
fig.show()

### Create plots comparing the dispersion of J (data vs. DF samples)

In [None]:
# Compare the J dispersion ('J2_mean' fields) of the N-body data (x) with
# each DF sample (y) on log-log axes, relative to the 1:1 line.
Jwd_samples = [J_vals['J2_mean_cb'],
               J_vals['J2_mean_om'],
               J_vals['J2_mean_om2']]
sample_suffix = ['CB','OM','OM2']
Jwd_lim = [0.5,50]

s=10
facecolor='none'
edgecolor='Black'
alpha=0.5

# Make the figure
fig = plt.figure(figsize=(4,7))
axs = fig.subplots(nrows=3, ncols=1)

# Loop over main axes
for i in range(3):
    axs[i].scatter(J_vals['J2_mean'], Jwd_samples[i], s=s, facecolor=facecolor,
        edgecolor=edgecolor, alpha=alpha)
    axs[i].set_xlim(Jwd_lim)
    axs[i].set_ylim(Jwd_lim)
    if i == 2:
        axs[i].set_xlabel(r'$\sigma(\mathcal{J}_{\mathrm{data}})$')
    else:
        axs[i].tick_params(labelbottom=False)
    axs[i].set_ylabel(r'$\sigma(\mathcal{J}_{\mathrm{'+sample_suffix[i]+'}})$')
    # Axes-coordinate diagonal; with equal x/y limits on log axes it is y=x
    axs[i].axline((0,0),(1,1), color='Black', linestyle='--', alpha=0.5,
        transform=axs[i].transAxes)
    # Count points where the sample dispersion is larger (above the 1:1
    # line) or smaller (below it) than the data, and write each count on
    # its own side of the diagonal
    n_above = np.sum(Jwd_samples[i] > J_vals['J2_mean'])
    n_below = np.sum(Jwd_samples[i] < J_vals['J2_mean'])
    axs[i].text(0.75, 0.90, str(n_above), transform=axs[i].transAxes)
    axs[i].text(0.90, 0.75, str(n_below), transform=axs[i].transAxes)
    axs[i].set_xscale('log')
    axs[i].set_yscale('log')

fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.05)
# savefig does not create missing directories
os.makedirs(local_fig_dir, exist_ok=True)
fig.savefig(os.path.join(local_fig_dir,'J_dispersion.pdf'))
fig.show()

### Same comparison, but shown as the fractional difference from the data rather than as a one-to-one relation

In [None]:
# Fractional difference between each DF sample's J dispersion and the N-body
# data's, (sample - data) / data, as a function of the data dispersion.
Jwd_samples = [J_vals['J2_mean_cb'],
               J_vals['J2_mean_om'],
               J_vals['J2_mean_om2']
                ]
sample_suffix = ['CB','OM','OM2']
Jwd_lim = [0.5,50]
# y-range of the fractional-difference panels
frac_diff_lim = [-1,2.5]

s=10
facecolor='none'
edgecolor='Black'
alpha=0.5

# Make the figure
fig = plt.figure(figsize=(4,7))
axs = fig.subplots(nrows=3, ncols=1)

# The points are meant to be coloured using merger_data, which is not done
# yet. Check that its rows line up with J_vals (same length, same z0_sid
# order) so that it can be. Warn rather than fail, because this plot does
# not use merger_data.
merger_data_aligned = (len(merger_data) == len(J_vals)) and \
    bool(np.all(J_vals['z0_sid'] == merger_data['z0_sid']))
if not merger_data_aligned:
    warnings.warn('merger_data is not row-aligned with J_vals (z0_sid mismatch)')

# Loop over main axes
for i in range(3):
    # 0 means the sample dispersion matches the data
    difference = (Jwd_samples[i] - J_vals['J2_mean']) / J_vals['J2_mean']
    axs[i].scatter(J_vals['J2_mean'], difference, s=s,
        facecolor=facecolor, edgecolor=edgecolor, alpha=alpha)
    axs[i].set_xlim(Jwd_lim)
    if i == 2:
        axs[i].set_xlabel(r'$\sigma(\mathcal{J}_{\mathrm{data}})$')
    else:
        axs[i].tick_params(labelbottom=False)
    axs[i].set_ylabel(r'fractional $\Delta \sigma(\mathcal{J}_{\mathrm{'+sample_suffix[i]+'}})$')
    axs[i].axhline(0, color='Black', linestyle='--', alpha=0.5)
    axs[i].set_xscale('log')
    axs[i].set_ylim(frac_diff_lim)

fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.05)
# Use a distinct file name so this figure does not overwrite the
# J_dispersion.pdf written by the previous cell; savefig does not create
# missing directories
os.makedirs(local_fig_dir, exist_ok=True)
fig.savefig(os.path.join(local_fig_dir,'J_dispersion_fractional.pdf'))
fig.show()