In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 4.1_anisotropic_df_jeans_plots.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Plot Jeans-equation diagnostics and summary statistics for the
anisotropic distribution-function comparison of MW-analog merger remnants.

Loads precomputed Jeans quantities (J_vals.npy) and merger metadata
(merger_data.npy), writes per-merger diagnostic figures, and compares the
mean and dispersion of J measured from the N-body data against the
CB / OM / OM2 DF samples.
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle, copy
import pdb
import shutil

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Analysis
import scipy.interpolate

## galpy
from galpy import potential

## Project-specific
# Walk up from the notebook's directory until a src/ directory is found, then
# put it on sys.path so the tng_dfs package can be imported.
src_path = 'src/'
while True:
    if os.path.exists(src_path): break
    # Directory in which we just looked for src/. Stop at the repository root
    # (tng-dfs) or at the filesystem root (its own parent); otherwise the
    # '..' prefix would keep growing forever.
    search_dir = os.path.dirname(os.path.realpath(src_path))
    if os.path.basename(search_dir) == 'tng-dfs' or \
            search_dir == os.path.dirname(search_dir):
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import densprofile as pdens
from tng_dfs import fitting as pfit
from tng_dfs import kinematics as pkin
from tng_dfs import plot as pplot
from tng_dfs import util as putil

### Notebook setup

%matplotlib inline
# NOTE(review): the project style is applied right after `%matplotlib inline`,
# presumably because activating the inline backend can reset rcParams --
# confirm before moving this line.
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
# Render inline figures at high (HiDPI) resolution
%config InlineBackend.figure_format = 'retina'
# Automatically re-import edited modules (e.g. tng_dfs) before each cell runs
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: read the project config and unpack the entries used below
cdict = putil.load_config_to_dict()
keywords = [
    'DATA_DIR', 'MW_ANALOG_DIR', 'FIG_DIR_BASE', 'FITTING_DIR_BASE',
    'RO', 'VO', 'ZO', 'LITTLE_H', 'MW_MASS_RANGE',
]
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
 ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW analogs, selected with the configured mass range and bulge/disk cuts
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure paths: a local directory plus the per-notebook directory under
# fig_dir_base (created here if missing)
local_fig_dir = './fig/'
fig_dir = os.path.join(
    fig_dir_base,
    'notebooks/5_compare_distribution_functions/'
    '4.1_anisotropic_df_jeans_plots/',
)
os.makedirs(fig_dir, exist_ok=True)
show_plots = False

# Merger-tree data. NOTE: pickle.load can execute code stored in the file;
# these are assumed to be project-generated files.
tree_primary_filename = os.path.join(mw_analog_dir,
                                     'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(mw_analog_dir,
                                           'major_mergers/tree_major_mergers.pkl')
with open(tree_primary_filename, 'rb') as handle:
    tree_primaries = pickle.load(handle)
with open(tree_major_mergers_filename, 'rb') as handle:
    tree_major_mergers = pickle.load(handle)
n_mw = len(tree_primaries)

### Wrangle the data

In [None]:
analysis_version = 'v1.1'
analysis_dir = os.path.join(mw_analog_dir, 'analysis', analysis_version)

# Precomputed Jeans quantities and per-merger metadata. allow_pickle is
# required because the saved arrays (presumably) carry object-typed fields.
J_vals, merger_data = (
    np.load(os.path.join(analysis_dir, npy_name), allow_pickle=True)
    for npy_name in ('J_vals.npy', 'merger_data.npy')
)

### For each merger remnant, plot Jeans-equation diagnostics (including $J_{\rm norm}$)

In [None]:
verbose = True

# Per-merger figures to make. The two Jeans-diagnostics plots need the
# sphericalized potential and the stellar-halo density fit; those are only
# loaded from disk when at least one of these plots is enabled.
make_jeans_diagnostics_plot = False
make_extra_jeans_diagnostics_plot = False
make_jnorm_plot = True
load_potentials = (make_jeans_diagnostics_plot or
                   make_extra_jeans_diagnostics_plot)

# Pathing
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function')

# Potential interpolator version
interpot_version = 'all_star_dm_enclosed_mass'

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower_softening'
stellar_halo_density_ncut = 500

# Define density profiles
stellar_halo_densfunc = pdens.TwoPowerSpherical()


def _shade_J_mean(ax, J_row):
    '''Draw the mean of J (dashed) and a +/- J2_mean band on a diagnostics axis.

    Args:
        ax: matplotlib axis; its current x-limits set the extent of the band.
        J_row: one record of J_vals; its 'J_mean' and 'J2_mean' fields are used.
    '''
    J_mean, J2_mean = J_row['J_mean'], J_row['J2_mean']
    ax.axhline(J_mean, color='DodgerBlue', ls='--')
    ax.fill_between(ax.get_xlim(), J_mean-J2_mean, J_mean+J2_mean,
        color='DodgerBlue', alpha=0.1)


def _save_and_close(fig, out_dir, filename):
    '''Tighten fig's layout, save it as out_dir/filename, then close it.

    Closing matters because this runs once per merger inside a loop.
    '''
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, filename))
    plt.close(fig)


ctr = 0
for i in range(n_mw):
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers

    # Sphericalized potential interpolator for this MW analog
    if load_potentials:
        interpolator_filename = os.path.join(dens_fitting_dir,
            'spherical_interpolated_potential/',interpot_version,
            str(z0_sid),'interp_potential.pkl')
        with open(interpolator_filename,'rb') as handle:
            interpot = pickle.load(handle)

    for j in range(n_major):
        if verbose: print(f'Calculating Jeans equation for MW analog '
                          f'{i+1}/{n_mw}, merger {j+1}/{n_major}', end='\r')

        # J_vals rows must follow the same (MW analog, merger) ordering as
        # this loop; fail loudly rather than plot the wrong merger.
        J_row = J_vals[ctr]
        assert J_row['z0_sid'] == z0_sid, \
            f'J_vals row {ctr} has z0_sid {J_row["z0_sid"]}, expected {z0_sid}'
        assert J_row['major_merger'] == j+1, \
            (f'J_vals row {ctr} has major_merger {J_row["major_merger"]}, '
             f'expected {j+1}')

        # Jeans data and the radius profile (median over samples)
        Js = J_row['Js']
        rs = J_row['rs']
        qs = J_row['qs']
        res = np.median(rs, axis=0)
        Jnorm = qs[2]*qs[3]/qs[6]

        # Stellar halo density profile (denspot for the DF)
        if load_potentials:
            stellar_halo_density_filename = os.path.join(dens_fitting_dir,
                'stellar_halo/',stellar_halo_density_version,str(z0_sid),
                'merger_'+str(j+1)+'/', 'sampler.pkl')
            denspot = pfit.construct_pot_from_fit(
                stellar_halo_density_filename, stellar_halo_densfunc,
                stellar_halo_density_ncut, ro=ro, vo=vo)

        # Plotting
        this_fig_dir = os.path.join(fig_dir, str(z0_sid), 'merger_'+str(j+1))
        os.makedirs(this_fig_dir,exist_ok=True)

        # Jeans diagnostics
        if make_jeans_diagnostics_plot:
            fig,axs = pplot.plot_jeans_diagnostics(Js,res,qs,
                pot=interpot, denspot=denspot)
            _shade_J_mean(axs[0], J_row)
            _save_and_close(fig, this_fig_dir, 'jeans_diagnostics.png')

        # More Jeans diagnostics: J1 and J2 components, both scaled by Jnorm
        if make_extra_jeans_diagnostics_plot:
            # Division already returns a new array, so no copy of qs[0] needed
            J1 = qs[0]/Jnorm
            J2 = qs[2]*(qs[1] + (2*qs[3]-qs[4]-qs[5])/qs[6])/Jnorm
            fig,axs = pplot.plot_jeans_diagnostics2(Js,res,qs,J1=J1,J2=J2,
                pot=interpot, denspot=denspot)
            _shade_J_mean(axs[0,0], J_row)
            _save_and_close(fig, this_fig_dir, 'jeans_diagnostics_extra.png')

        # Jnorm profile: median and 16-84 percentile band over samples
        if make_jnorm_plot:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            lJn,mJn,uJn = np.percentile(Jnorm, [16,50,84], axis=0)
            ax.plot(res, mJn, color='Black', linewidth=2.)
            ax.fill_between(res, lJn, uJn, color='Black', alpha=0.2)
            ax.set_xlabel(r'$r$ [kpc]')
            ax.set_ylabel(r'$J_{\rm norm}$')
            ax.set_xscale('log')
            ax.set_yscale('log')
            _save_and_close(fig, this_fig_dir, 'Jnorm.png')

        # Up counter
        ctr += 1

# Terminate the carriage-return progress line
if verbose: print()
if ctr != len(J_vals):
    print(f'Warning: iterated over {ctr} mergers but J_vals has '
          f'{len(J_vals)} rows')

### Create plots showing mean J

In [None]:
Jm_samples = [J_vals['J_mean_cb'],
              J_vals['J_mean_om'],
              J_vals['J_mean_om2']]
sample_suffix = ['CB','OM','OM2']
Jm_lim = [-1.1,1.1]
# Only samples with |J_mean| below this enter the mu/sigma annotations of the
# right-hand histograms, so a few extreme values do not dominate them
Jm_stats_clip = 5

s=10
facecolor='none'
edgecolor='Black'
alpha=0.5
annotate_fs = 8


def _mean_std_label(values):
    '''Return a two-line "mu = ... / sigma = ..." annotation for values.

    Both statistics are rounded to two decimal places.
    '''
    return (r'$\mu = $'+str(round(np.mean(values),2))+'\n'+
            r'$\sigma = $'+str(round(np.std(values),2)))


# Make the figure
fig = plt.figure(figsize=(4,7))
gs = mpl.gridspec.GridSpec(nrows=10,ncols=4,figure=fig)

# Main axes: data J_mean vs. each DF sample
axs = [fig.add_subplot(gs[1:4,0:3]),
       fig.add_subplot(gs[4:7,0:3]),
       fig.add_subplot(gs[7:10,0:3])]
# Top histogram axes
tax = fig.add_subplot(gs[0,0:3])
# Right histogram axes
raxs = [fig.add_subplot(gs[1:4,3]),
        fig.add_subplot(gs[4:7,3]),
        fig.add_subplot(gs[7:10,3])]

# Loop over main axes
for i in range(3):
    axs[i].scatter(J_vals['J_mean'], Jm_samples[i], s=s, facecolor=facecolor,
        edgecolor=edgecolor, alpha=alpha)
    axs[i].set_xlim(Jm_lim)
    axs[i].set_ylim(Jm_lim)
    if i == 2:
        axs[i].set_xlabel(r'$\overline{\mathcal{J}}_{\mathrm{data}}$')
    else:
        axs[i].tick_params(labelbottom=False)
    axs[i].set_ylabel(r'$\overline{\mathcal{J}}_{\mathrm{'+sample_suffix[i]+'}}$')
    axs[i].axhline(0., color='Black', linestyle='--')
    axs[i].axvline(0., color='Black', linestyle='--')

# Do the top histogram (statistics over all data values)
tax.hist(J_vals['J_mean'], bins=21, range=Jm_lim, histtype='step',
    color='Black', orientation='vertical')
tax.tick_params(labelbottom=False)
tax.set_xlim(Jm_lim)
tax.set_ylabel(r'$N$')
tax.axvline(0, color='Black', linestyle='--')
tax.text(0.1, 0.5, _mean_std_label(J_vals['J_mean']),
    transform=tax.transAxes, fontsize=annotate_fs)

# Do the right histograms; the top one carries its N label/ticks on top
for i in range(3):
    raxs[i].hist(Jm_samples[i], bins=21, range=Jm_lim, histtype='step',
        color='Black', orientation='horizontal')
    raxs[i].tick_params(labelleft=False)
    raxs[i].set_xlabel(r'$N$')
    if i == 0:
        raxs[i].xaxis.set_label_position('top')
        raxs[i].tick_params(labelbottom=False, labeltop=True)
    raxs[i].set_ylim(Jm_lim)
    raxs[i].axhline(0, color='Black', linestyle='--')
    in_range = np.abs(Jm_samples[i]) < Jm_stats_clip
    raxs[i].text(0.25, 0.8, _mean_std_label(Jm_samples[i][in_range]),
        transform=raxs[i].transAxes, fontsize=annotate_fs)

# Set ticks properly
ticks = [-1,-0.5,0,0.5,1]
tax.set_xticks(ticks)
for i in range(3):
    axs[i].set_xticks(ticks)
    axs[i].set_yticks(ticks)
    raxs[i].set_yticks(ticks)

fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.1)
# ./fig/ is never created elsewhere; make sure savefig has somewhere to write
os.makedirs(local_fig_dir, exist_ok=True)
fig.savefig(os.path.join(local_fig_dir,'J_mean.pdf'))
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend
plt.show()

### Create plots showing the dispersion of J

In [None]:
Jwd_samples = [J_vals['J2_mean_cb'],
               J_vals['J2_mean_om'],
               J_vals['J2_mean_om2']]
sample_suffix = ['CB','OM','OM2']
Jwd_lim = [0.5,50]

s=10
facecolor='none'
edgecolor='Black'
alpha=0.5

# Make the figure
fig = plt.figure(figsize=(4,7))
axs = fig.subplots(nrows=3, ncols=1)

# Loop over main axes
for i in range(3):
    axs[i].scatter(J_vals['J2_mean'], Jwd_samples[i], s=s, facecolor=facecolor,
        edgecolor=edgecolor, alpha=alpha)
    axs[i].set_xlim(Jwd_lim)
    axs[i].set_ylim(Jwd_lim)
    if i == 2:
        axs[i].set_xlabel(r'$\sigma(\mathcal{J}_{\mathrm{data}})$')
    else:
        axs[i].tick_params(labelbottom=False)
    axs[i].set_ylabel(r'$\sigma(\mathcal{J}_{\mathrm{'+sample_suffix[i]+'}})$')
    # Both axes share Jwd_lim and log scaling, so the axes diagonal is y = x
    axs[i].axline((0,0),(1,1), color='Black', linestyle='--', alpha=0.5,
        transform=axs[i].transAxes)
    # Count points above (DF > data) and below (DF < data) the 1:1 line and
    # print each count on its own side of the diagonal
    n_above = np.sum(Jwd_samples[i] > J_vals['J2_mean'])
    n_below = np.sum(Jwd_samples[i] < J_vals['J2_mean'])
    axs[i].text(0.75, 0.90, str(n_above), transform=axs[i].transAxes)
    axs[i].text(0.90, 0.75, str(n_below), transform=axs[i].transAxes)
    axs[i].set_xscale('log')
    axs[i].set_yscale('log')

fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.05)
os.makedirs(local_fig_dir, exist_ok=True)
fig.savefig(os.path.join(local_fig_dir,'J_dispersion.pdf'))
plt.show()

### Same figure, but shown as the fractional difference from the data rather than as a one-to-one relation

In [None]:
Jwd_samples = [J_vals['J2_mean_cb'],
               J_vals['J2_mean_om'],
               J_vals['J2_mean_om2']]
sample_suffix = ['CB','OM','OM2']
Jwd_lim = [0.5,50]

s=10
facecolor='none'
edgecolor='Black'
alpha=0.5

# Make the figure
fig = plt.figure(figsize=(4,7))
axs = fig.subplots(nrows=3, ncols=1)

# Loop over main axes: (sigma_DF - sigma_data) / sigma_data vs sigma_data
for i in range(3):
    difference = (Jwd_samples[i] - J_vals['J2_mean']) / J_vals['J2_mean']
    axs[i].scatter(J_vals['J2_mean'], difference, s=s,
        facecolor=facecolor, edgecolor=edgecolor, alpha=alpha)
    axs[i].set_xlim(Jwd_lim)
    if i == 2:
        axs[i].set_xlabel(r'$\sigma(\mathcal{J}_{\mathrm{data}})$')
    else:
        axs[i].tick_params(labelbottom=False)
    axs[i].set_ylabel(r'fractional $\Delta \sigma(\mathcal{J}_{\mathrm{'+sample_suffix[i]+'}})$')
    axs[i].axhline(0, color='Black', linestyle='--', alpha=0.5)
    axs[i].set_xscale('log')
    axs[i].set_ylim(-1,2.5)

fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.05)
# Separate file name so this figure does not overwrite J_dispersion.pdf
os.makedirs(local_fig_dir, exist_ok=True)
fig.savefig(os.path.join(local_fig_dir,'J_dispersion_fractional.pdf'))
plt.show()

## Same figures, but coloured by (or plotted against) merger-remnant properties

In [None]:
Jwd_samples = [J_vals['J2_mean_cb'],
               J_vals['J2_mean_om'],
               J_vals['J2_mean_om2']]
sample_suffix = ['CB','OM','OM2']
Jwd_lim = [0.5,50]

s=10
edgecolor='Black'
alpha=0.8

# Points are coloured by a merger_data field, so its rows must line up with
# J_vals (a bare np.all(...) here would be silently discarded)
assert len(J_vals) == len(merger_data) and \
    np.all(J_vals['z0_sid'] == merger_data['z0_sid']), \
    'J_vals and merger_data rows are not aligned'
c = np.log10(merger_data['star_mass'])
vmin = np.min(c)
vmax = np.max(c)
colorbar_label = r'$\log_{10} \left( M_{\star} / M_{\odot} \right)$'

# Make the figure
fig = plt.figure(figsize=(4,7))
axs = fig.subplots(nrows=3, ncols=1)

# Loop over main axes
for i in range(3):
    pts = axs[i].scatter(J_vals['J2_mean'], Jwd_samples[i], s=s, c=c,
        edgecolor=edgecolor, alpha=alpha, vmin=vmin, vmax=vmax)
    axs[i].set_xlim(Jwd_lim)
    axs[i].set_ylim(Jwd_lim)
    if i == 2:
        axs[i].set_xlabel(r'$\sigma(\mathcal{J}_{\mathrm{data}})$')
    else:
        axs[i].tick_params(labelbottom=False)
    axs[i].set_ylabel(r'$\sigma(\mathcal{J}_{\mathrm{'+sample_suffix[i]+'}})$')
    # Both axes share Jwd_lim and log scaling, so the axes diagonal is y = x
    axs[i].axline((0,0),(1,1), color='Black', linestyle='--', alpha=0.5,
        transform=axs[i].transAxes)
    axs[i].set_xscale('log')
    axs[i].set_yscale('log')

# Lay out the panels first; the colorbar axes are placed by hand afterwards,
# so tight_layout never sees an Axes it cannot handle
fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.05, right=0.9)
# All panels share vmin/vmax, so the last scatter defines the colour scale
cax = fig.add_axes([0.92, 0.1, 0.02, 0.8])
cbar = fig.colorbar(pts, cax=cax)
cbar.set_label(colorbar_label, fontsize=8)
plt.show()

In [None]:
Jwd_samples = [J_vals['J2_mean'],
               J_vals['J2_mean_cb'],
               J_vals['J2_mean_om'],
               J_vals['J2_mean_om2']]
sample_suffix = ['Data','CB','OM','OM2']
Jwd_lim = [0.5,50]

s=10
edgecolor='Black'
alpha=0.8

# The y-values come from merger_data, so its rows must line up with J_vals
assert len(J_vals) == len(merger_data) and \
    np.all(J_vals['z0_sid'] == merger_data['z0_sid']), \
    'J_vals and merger_data rows are not aligned'
y = np.log10(merger_data['star_mass'])
ylabel = r'$\log_{10} \left( M_{\star} / M_{\odot} \right)$'

# Make the figure: one panel per J sample, sharing the stellar-mass y-axis
fig = plt.figure(figsize=(10,4))
axs = fig.subplots(nrows=1, ncols=4)

for i in range(4):
    axs[i].scatter(Jwd_samples[i], y, s=s, edgecolor=edgecolor, alpha=alpha)
    axs[i].set_xlim(Jwd_lim)
    axs[i].set_xlabel(r'$\sigma(\mathcal{J}_{\mathrm{'+sample_suffix[i]+r'}})$')
    if i == 0:
        axs[i].set_ylabel(ylabel)
    else:
        axs[i].tick_params(labelleft=False)
    axs[i].set_xscale('log')

fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.05)
plt.show()

In [None]:
# Stellar-mass bin edges in log10(M_star / M_sun). For each bin, copy every
# per-merger figure (from the Jeans loop above) of the mergers in that bin
# into one directory, tagging file names with z0_sid and merger number.
ms = [0.,8.,9.,12.]

for m_lo, m_hi in zip(ms[:-1], ms[1:]):
    this_fig_dir = os.path.join(fig_dir,'stellar_mass_divisions',
        'mass_range_1e'+str(m_lo)+'_1e'+str(m_hi))
    os.makedirs(this_fig_dir,exist_ok=True)
    print(m_lo,m_hi,'\n--------')
    for j in range(len(merger_data)):
        # Half-open bins [lo, hi) so a mass exactly on an edge is not dropped
        if not (10**m_lo <= merger_data['star_mass'][j] < 10**m_hi):
            continue
        sid = merger_data['z0_sid'][j]
        merger_number = merger_data['merger_number'][j]
        print(sid, merger_number)
        targ_fig_dir = os.path.join(fig_dir,str(sid),
            'merger_'+str(merger_number))
        for fname in os.listdir(targ_fig_dir):
            stem, ext = os.path.splitext(fname)
            if ext not in ('.pdf','.png'):
                continue
            new_fname = stem+'_'+str(sid)+'_merger_'+str(merger_number)+ext
            # shutil.copy raises on failure (os.system('cp ...') ignored it)
            # and is safe with spaces in paths
            shutil.copy(os.path.join(targ_fig_dir,fname),
                        os.path.join(this_fig_dir,new_fname))

In [None]:
fig = plt.figure(figsize=(5,10))
axs = fig.subplots(nrows=3, ncols=1)

# merger_data field on the x-axis, its axis label, and fields best shown on
# a log x-axis. Unknown keys fall back to the raw key as the label.
xkey = 'anisotropy'
xlabels = {'anisotropy': r'$\beta$', 'star_mass': r'$M_{\rm star}$'}
log_xkeys = {'star_mass'}

# J_vals row i is plotted against merger_data row i
assert len(J_vals) == len(merger_data) and \
    np.all(J_vals['z0_sid'] == merger_data['z0_sid']), \
    'J_vals and merger_data rows are not aligned'

for i in range(len(merger_data)):
    # 2D array of J samples for this merger; |J|/sigma_J is taken from its
    # first, middle and last column (presumably radial positions)
    Js_i = J_vals['Js'][i]
    nJ = Js_i.shape[1]
    J2 = J_vals['J2_mean'][i]
    slices = [Js_i[:,0], Js_i[:,nJ//2], Js_i[:,-1]]
    xval = merger_data[xkey][i]

    for ax, J_slice in zip(axs, slices):
        val = np.abs(J_slice / J2)
        lv,mv,uv = np.percentile(val,[16,50,84])
        ax.errorbar(xval,mv,yerr=[[mv-lv],[uv-mv]],
            fmt='o',color='k',alpha=0.5)

for ax in axs:
    ax.set_yscale('log')
    if xkey in log_xkeys:
        ax.set_xscale('log')

axs[0].set_ylabel(r'$\vert J_{\rm start}/\sigma_J \vert$')
axs[1].set_ylabel(r'$\vert J_{\rm mid}/\sigma_J \vert$')
axs[2].set_ylabel(r'$\vert J_{\rm end}/\sigma_J \vert$')
axs[2].set_xlabel(xlabels.get(xkey, xkey))

fig.tight_layout()
fig.savefig(os.path.join(fig_dir,'J_slices_'+xkey+'.png'))
plt.show()