In [None]:
# ------------------------------------------------------------------------
#
# TITLE - jeans_through_time.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstring:
'''Examine candidate remnants and z=0 stellar halos in the context of the Jeans 
equation
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb
import h5py
import glob
import copy
import dill as pickle
from astropy import units as apu

## Matplotlib
import matplotlib
from matplotlib import pyplot as plt

## Galpy
from galpy import orbit, potential, df
import galpy.util

## Scipy
import scipy.interpolate

sys.path.insert(0,'../../src/')
from tng_dfs import util as putil
from tng_dfs import tree as ptree
from tng_dfs import cutout as pcutout
from tng_dfs.util import get

### Notebook setup
%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

In [None]:
# Keywords: read the project configuration and unpack the data directory,
# the galpy scale units ro (kpc) and vo (km/s), ZO and LITTLE_H
keywords = ['DATA_DIR','RO','VO','ZO','LITTLE_H']
cdict = putil.load_config_to_dict()
(data_dir, ro, vo, zo, h) = putil.parse_config_dict(cdict, keywords)

### Functions

In [None]:
def _mean_or_nan(values):
    '''_mean_or_nan:

    Mean of an array, returning NaN for an empty array (the same value
    np.mean gives) but without numpy's "Mean of empty slice" RuntimeWarning.

    Args:
        values (np.ndarray) - Values to average

    Returns:
        mean (float) - Mean of values, or np.nan if values is empty
    '''
    if np.size(values) == 0:
        return np.nan
    return np.mean(values)

def calculate_spherical_jeans_quantities(orbs,pe,r_range=[0,100],n_bin=10,
    norm_by_galpy_scale_units=False,calculate_pe_with_pot=False,ro=ro,vo=vo):
    '''calculate_spherical_jeans_quantities:
    
    Calculate the binned quantities used in the spherical Jeans equation.

    The radial range is split into n_bin+1 "derivative" bins in which 
    nu*<vr^2> and the mean potential are measured and then finite-differenced.
    The centers of those bins are the edges of the n_bin "data" bins, so the 
    derivatives land on the data bin centers, where nu and the velocity 
    second moments are measured.
    
    Args:
        orbs (Orbits) - Orbits object containing particles / kinematic sample
        pe (Quantity) - Potential energy of each particle in the sample, with
            units convertible to km^2/s^2
        r_range (optional, list) - Range of radii to consider, in kpc 
            [default: [0,100]]
        n_bin (optional, int) - Number of bins to use in calculating Jeans
            equation, note derivative quantities will be calculated with 
            n_bin+1 bins [default: 10]
        norm_by_galpy_scale_units (optional, bool) - If True, normalize the
            Jeans equation quantities by galpy scale units [default: False]
        calculate_pe_with_pot (optional, bool) - Evaluating the potential at
            the bin centers (rather than averaging pe over the particles in 
            each bin) is not supported, since no potential is passed in; 
            True raises NotImplementedError [default: False]
        ro (optional, float) - Distance scale in kpc [default: config RO]
        vo (optional, float) - Velocity scale in km/s [default: config VO]
    
    Returns:
        qs (tuple) - Binned quantities, in order:
            dnuvr2dr - d(nu <vr^2>)/dr at the data bin centers
            dphidr - dPhi/dr at the data bin centers
            nu - number density, divided by the number of orbits
            vr2 - <vr^2>
            vp2 - <vtheta^2> (from Orbit.vtheta)
            vt2 - <vT^2> (from Orbit.vT)
            rs - data bin centers
            Physical units (kpc, km/s) unless norm_by_galpy_scale_units. Bins
            containing no particles give NaN for averaged quantities.

    Raises:
        NotImplementedError - if calculate_pe_with_pot is True
    '''
    if calculate_pe_with_pot:
        # This branch used to read an undefined variable (pe_bin_cents) and
        # always failed with a NameError; fail early with a clear message
        raise NotImplementedError('calculate_pe_with_pot=True requires '
            'evaluating a potential at the bin centers, but no potential is '
            'passed to calculate_spherical_jeans_quantities')

    orbs = copy.deepcopy(orbs)
    orbs.turn_physical_on(ro=ro,vo=vo)

    ## Determine bins for kinematic properties
    
    # Derivative bins: one more bin than for the data itself, since 
    # derivatives are finite differences between adjacent bin centers
    n_dr_bin = n_bin+1
    dr_bin_edge = np.linspace(r_range[0],r_range[1],n_dr_bin+1)
    dr_bin_cents = (dr_bin_edge[1:]+dr_bin_edge[:-1])/2

    # Data bins: one fewer bin. The edges are the derivative bin centers, so
    # the finite differences are centered on the data bin centers
    bin_edge = copy.deepcopy(dr_bin_cents)
    bin_cents = (bin_edge[1:]+bin_edge[:-1])/2

    # Per-particle quantities, evaluated once rather than once per bin
    rs = orbs.r(use_physical=True).to(apu.kpc).value
    vr = orbs.vr(use_physical=True).to(apu.km/apu.s).value
    vtheta = orbs.vtheta(use_physical=True).to(apu.km/apu.s).value
    vT = orbs.vT(use_physical=True).to(apu.km/apu.s).value
    pe = pe.to(apu.km**2/apu.s**2).value

    # Derivative quantities: nu*<vr^2> and the mean potential per bin
    nuvr2 = np.zeros_like(dr_bin_cents)
    phi = np.zeros_like(dr_bin_cents)
    for i in range(len(dr_bin_cents)):
        bin_mask = (rs>=dr_bin_edge[i]) & (rs<dr_bin_edge[i+1])
        bin_vol = 4*np.pi/3*(dr_bin_edge[i+1]**3-dr_bin_edge[i]**3)
        dr_nu = np.sum(bin_mask)/bin_vol
        dr_vr2 = _mean_or_nan(vr[bin_mask]**2.)
        phi[i] = _mean_or_nan(pe[bin_mask])
        nuvr2[i] = dr_nu*dr_vr2
    dphidr = np.diff(phi)/np.diff(dr_bin_cents)
    dnuvr2dr = np.diff(nuvr2)/np.diff(dr_bin_cents)

    # Non-derivative quantities: density and velocity second moments
    nu = np.zeros_like(bin_cents)
    vr2 = np.zeros_like(bin_cents)
    vp2 = np.zeros_like(bin_cents)
    vt2 = np.zeros_like(bin_cents)
    for i in range(len(bin_cents)):
        bin_mask = (rs>=bin_edge[i]) & (rs<bin_edge[i+1])
        bin_vol = 4*np.pi/3*(bin_edge[i+1]**3-bin_edge[i]**3)
        nu[i] = np.sum(bin_mask)/bin_vol
        vr2[i] = _mean_or_nan(vr[bin_mask]**2.)
        vp2[i] = _mean_or_nan(vtheta[bin_mask]**2.)
        vt2[i] = _mean_or_nan(vT[bin_mask]**2.)
    
    # Normalize densities by number of orbits so they're proper number 
    # densities (both terms of J that carry nu are scaled alike)
    nu /= len(orbs)
    dnuvr2dr /= len(orbs)

    if norm_by_galpy_scale_units:
        nu = nu*(ro**3)
        vr2 = vr2/(vo**2)
        vp2 = vp2/(vo**2)
        vt2 = vt2/(vo**2)
        bin_cents = bin_cents/ro
        dphidr = dphidr*ro/(vo**2)
        dnuvr2dr = dnuvr2dr*(ro**4)/(vo**2)

    return dnuvr2dr,dphidr,nu,vr2,vp2,vt2,bin_cents

def calculate_spherical_jeans(orbs,pe,n_bootstrap=1,r_range=[0,100],n_bin=10,
    norm_by_galpy_scale_units=False,norm_by_nuvr2_r=True,
    calculate_pe_with_pot=False,return_kinematics=True,ro=ro,vo=vo,rng=None):
    '''calculate_spherical_jeans:

    Calculate the spherical Jeans equation for a given kinematic sample,
        J = nu*(dPhi/dr + (2<vr^2> - <vp^2> - <vt^2>)/r) + d(nu <vr^2>)/dr,
    which vanishes for a steady-state spherical system.

    Args:
        orbs (Orbits) - Orbits object containing particles / kinematic sample
        pe (Quantity) - Potential energy of each particle in the sample
        n_bootstrap (optional, int) - Number of bootstrap samples to calculate 
            the Jeans equation for. Each sample draws len(orbs) particles 
            with replacement. If 1, then don't bootstrap [default: 1]
        r_range (optional, list) - Range of radii to consider, in kpc 
            [default: [0,100]]
        n_bin (optional, int) - Number of bins to use in calculating Jeans
            equation, note derivative quantities will be calculated with 
            n_bin+1 bins [default: 10]
        norm_by_galpy_scale_units (optional, bool) - If True, normalize the
            Jeans equation by galpy scale units [default: False]
        norm_by_nuvr2_r (optional, bool) - If True, normalize the Jeans equation
            by nu*vr^2/r. The ratio is unit-independent, so this is applied
            in both physical and galpy units [default: True]
        calculate_pe_with_pot (optional, bool) - Passed to 
            calculate_spherical_jeans_quantities [default: False]
        return_kinematics (optional, bool) - If True, return the kinematics
            used to calculate the Jeans equation [default: True]
        ro (optional, float) - Distance scale in kpc [default: config RO]
        vo (optional, float) - Velocity scale in km/s [default: config VO]
        rng (optional, np.random.Generator) - Random generator used for 
            bootstrap resampling; a fresh np.random.default_rng() if None
            [default: None]
    
    Returns:
        J (np.ndarray) - Jeans equation, may be normalized. Shape (n_bin,), 
            or (n_bootstrap,n_bin) when bootstrapping
        rs (np.ndarray) - Radii (bin centers) at which Jeans equation is 
            calculated, shape (n_bin,)
        qs (tuple or np.ndarray) - Kinematic quantities used to calculate the
            Jeans equation, output from calculate_spherical_jeans_quantities,
            in order: dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs. When bootstrapping,
            an array of shape (7,n_bootstrap,n_bin)
    '''
    # Compute the quantities for the spherical Jeans equation
    if n_bootstrap>1:
        if rng is None:
            rng = np.random.default_rng()
        n_orbs = len(orbs)
        qs = np.zeros((7,n_bootstrap,n_bin))
        for i in range(n_bootstrap):
            # Bootstrap sample: n_orbs particles drawn with replacement
            indx = rng.integers(0,n_orbs,size=n_orbs)
            qs[:,i,:] = calculate_spherical_jeans_quantities(orbs[indx],
                pe[indx],r_range=r_range,n_bin=n_bin,
                norm_by_galpy_scale_units=norm_by_galpy_scale_units,
                calculate_pe_with_pot=calculate_pe_with_pot,ro=ro,vo=vo)
    else:
        qs = calculate_spherical_jeans_quantities(orbs,pe,r_range=r_range,
            n_bin=n_bin,norm_by_galpy_scale_units=norm_by_galpy_scale_units,
            calculate_pe_with_pot=calculate_pe_with_pot,ro=ro,vo=vo)

    dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs = qs
    if n_bootstrap>1:
        # Bin centers are identical for every bootstrap sample; keep one row
        # so rs is 1D (it broadcasts against the 2D quantities below)
        rs = rs[0]

    # Compute the Jeans equation
    J = nu*(dphidr + (2*vr2-vp2-vt2)/rs) + dnuvr2dr

    # Normalize by nu*vr^2/r if desired. Note that this returns the same 
    # answer regardless of whether using physical or galpy units.
    if norm_by_nuvr2_r:
        J = J/(nu*vr2/rs)

    if return_kinematics:
        return J,rs,qs
    else:
        return J,rs

def plot_jeans_diagnostics(Js,rs,qs,norm_by_nuvr2_r=True):
    '''plot_jeans_diagnostics:

    Plot the Jeans equation and the kinematic quantities that enter it.

    Panels: J (top left), density nu with fiducial r^-alpha power laws (top
    middle), anisotropy beta (top right), <vr^2> (bottom left), <vtheta^2> 
    and <vphi^2> (bottom middle, upper/lower), dPhi/dr (bottom right). The
    median is drawn as a line; if Js is 2D (e.g. bootstrap samples), the 
    16th-84th percentile span is shaded.

    Args:
        Js (np.ndarray) - Jeans equation, shape (n_bin,) or (n_samples,n_bin)
        rs (np.ndarray) - Bin radii in kpc, shape (n_bin,)
        qs (tuple or np.ndarray) - Quantities as returned by 
            calculate_spherical_jeans: dnuvr2dr,dphidr,nu,vr2,vp2,vt2,rs, 
            where vp2 is <vtheta^2> (polar) and vt2 is <vT^2> (azimuthal)
        norm_by_nuvr2_r (optional, bool) - Whether Js was normalized by 
            nu*vr^2/r, only used for the y-axis label [default: True]

    Returns:
        fig (matplotlib.figure.Figure) - The figure
        axs (np.ndarray) - Array of the 7 axes
    '''
    data_color = 'Black'
    data_linewidth = 2.
    # Only shade percentile spans when there are multiple samples
    plot_spans = np.ndim(Js) != 1

    fig = plt.figure(figsize=(12,8))
    gs = fig.add_gridspec(nrows=4,ncols=3)
    axs = np.array([fig.add_subplot(gs[:2,0]),
                    fig.add_subplot(gs[:2,1]),
                    fig.add_subplot(gs[:2,2]),
                    fig.add_subplot(gs[2:,0]),
                    fig.add_subplot(gs[2,1]),
                    fig.add_subplot(gs[3,1]),
                    fig.add_subplot(gs[2:,2])
                    ])

    # J in the first panel
    lJ,mJ,uJ = np.percentile(np.atleast_2d(Js), [16,50,84], axis=0)
    axs[0].plot(rs, mJ, color=data_color, linewidth=data_linewidth)
    if plot_spans:
        axs[0].fill_between(rs, lJ, uJ, color='Black', alpha=0.25)
    axs[0].axhline(0, color='Black', linestyle='--', linewidth=0.5)
    axs[0].set_xlim(0,50)
    axs[0].set_xlabel('r [kpc]')
    if norm_by_nuvr2_r:
        axs[0].set_ylabel(r'$J / (\nu \bar{v_{r}^{2}} / r)$')
    else:
        axs[0].set_ylabel('$J$')

    # Density in the second upper panel, with power laws r^-alpha anchored
    # at the innermost bin for reference
    fiducial_alphas = [1,2,3,4,5]
    lnu,mnu,unu = np.percentile(np.atleast_2d(qs[2]), [16,50,84], axis=0)
    axs[1].plot(rs, mnu, color=data_color, linewidth=data_linewidth)
    if plot_spans:
        axs[1].fill_between(rs, unu, lnu, color='Black', alpha=0.25)
    for alpha in fiducial_alphas:
        _norm = mnu[0]/rs[0]**-alpha
        axs[1].plot(rs, _norm*rs**-alpha, color='DodgerBlue', 
            linewidth=0.5, linestyle='--')
    axs[1].set_xscale('log')
    axs[1].set_yscale('log')
    axs[1].set_xlabel(r'r [kpc]')
    axs[1].set_ylabel(r'$\nu$')

    # Beta in the third panel
    beta = 1 - (qs[4]+qs[5])/(2*qs[3])
    lbeta,mbeta,ubeta = np.percentile(np.atleast_2d(beta), [16,50,84], axis=0)
    axs[2].plot(rs, mbeta, color=data_color, linewidth=data_linewidth)
    if plot_spans:
        axs[2].fill_between(rs, ubeta, lbeta, color='Black', alpha=0.25)
    axs[2].axhline(0, color='Black', linestyle='--', linewidth=0.5)
    axs[2].set_xlim(0,50)
    axs[2].set_xlabel(r'r [kpc]')
    axs[2].set_ylabel(r'$\beta$')

    # Radial velocity second moment in the fourth panel, polar and azimuthal
    # in the fifth upper/lower panels. qs[4] is <vtheta^2> and qs[5] is 
    # <vT^2> = <vphi^2>, so the labels follow that order. Each panel shows
    # its own component bold and the other two dashed for comparison.
    colors = ['DodgerBlue','Crimson','DarkOrange']
    v2_names = [r'$\bar{v_{r}^{2}}$',
                r'$\bar{v_{\theta}^{2}}$',
                r'$\bar{v_{\phi}^{2}}$',]
    for i in range(3):
        for j in range(3):
            lv2,mv2,uv2 = np.percentile(np.atleast_2d(qs[j+3]), [16,50,84], 
                axis=0)
            if i == j:
                axs[i+3].plot(rs, mv2, color=colors[j], 
                    linewidth=data_linewidth+2, zorder=2)
                if plot_spans:
                    axs[i+3].fill_between(rs, uv2, lv2, color=colors[i], 
                        alpha=0.25, zorder=1)
            else:
                axs[i+3].plot(rs, mv2, color=colors[j], alpha=1., 
                    linestyle='--', linewidth=1., zorder=3)
        axs[i+3].set_xlim(0,50)
        if i in [0,2]:
            axs[i+3].set_xlabel(r'r [kpc]')
        axs[i+3].set_ylabel(v2_names[i])
        axs[i+3].set_yscale('log')
    
    # dphi/dr in the seventh panel
    ldphidr,mdphidr,udphidr = np.percentile(np.atleast_2d(qs[1]), [16,50,84], 
        axis=0)
    axs[6].plot(rs, mdphidr, color=data_color, linewidth=data_linewidth)
    if plot_spans:
        axs[6].fill_between(rs, udphidr, ldphidr, color='Black', alpha=0.25)
    axs[6].set_xlim(0,50)
    axs[6].set_xlabel(r'r [kpc]')
    axs[6].set_ylabel(r'$\mathrm{d}\Phi/\mathrm{d}r$')
    axs[6].set_yscale('log')
    
    return fig,axs

### API Setup & Milky Way Analogs from TNG50-1

In [None]:
# Base URL of the IllustrisTNG web API. Use HTTPS: `get` presumably sends an
# API key with each request (TODO confirm in tng_dfs.util), which should not
# travel over plain HTTP
baseURL = 'https://www.tng-project.org/api/'
# Get list of simulations
r = get(baseURL)
sim_names = [sim['name'] for sim in r['simulations']]
# Positions of TNG50-1 ... TNG50-4 in the simulation list
tng50_indices = [sim_names.index('TNG50-'+str(i+1)) for i in range(4)]
tng50_urls = [r['simulations'][i]['url'] for i in tng50_indices]
# Choose TNG50-1 (index 0), the highest resolution TNG50 run
tng50_url = tng50_urls[0]

# Get the simulation, snapshots, snapshot redshifts
sim = get( tng50_url )
snaps = get( sim['snapshots'] )

In [None]:
# Load the MW analogs and their major-merger progenitors (presumably written
# by the parse_sublink_trees notebooks). NOTE: this unpickles with dill, which
# can execute arbitrary code -- only load files produced by this project.
with open('../parse_sublink_trees/data/all_major_list.pkl','rb') as f:
    all_major_list = pickle.load(f)

# Number of primary MW analogs under consideration
n_mw = len(all_major_list)

# Highest MW-analog index analysed for now. NOTE: the loops below skip
# analogs with `i > n_do`, so indices 0..n_do inclusive (n_do+1 analogs)
# are processed
n_do = 3

In [None]:
# Check which cutout files exist on disk for each major-merger progenitor 
# (index 0 is the primary) of the analysed MW analogs (indices 0..n_do)
for i in range(n_mw):
    if i > n_do: continue

    # Get primary particle properties
    major_dict = all_major_list[i]
    major_list = major_dict['major_list']
    n_major = major_dict['n_major']

    # has_data[j] is True only if every snapshot of major j has a cutout file
    has_data = np.zeros(n_major+1,dtype=bool)
    for j in range(n_major+1):
        if j == 0:
            assert major_list[j]['is_primary'], 'Index 0 not primary'
        
        major_snaps = major_list[j]['snaps']
        major_subfind_ids = major_list[j]['subfind_ids']
        has_file = np.zeros(len(major_snaps),dtype=bool)
        for k in range(len(major_snaps)):
            snap_path = data_dir+'cutouts/snap_'+str(major_snaps[k])+'/'
            snap_filename = snap_path+'cutout_'+str(major_subfind_ids[k])+\
                '.hdf5'
            has_file[k] = os.path.isfile(snap_filename)
        
        # Report the max and min snapshot for which we have data
        print('Primary '+str(i)+', major '+str(j))
        if not has_file.any():
            print('No data for this major')
            continue
        print('Max snapshot: '+str(major_snaps.max()))
        print('Min snapshot: '+str(major_snaps.min()))
        print('Max snapshot saved: '+str(major_snaps[has_file].max()))
        print('Min snapshot saved: '+str(major_snaps[has_file].min()))
        print('--------\n')

        if has_file.all(): has_data[j] = True
    print('--------------------------------')
    if has_data.all():
        print('Primary '+str(i)+' has data for all majors')
    print('--------------------------------\n')
    

In [None]:
# Load the primary's (major_list index 0) cutout at the first entry of its
# snapshot list
major_dict = all_major_list[0]
primary_entry = major_dict['major_list'][0]
subfind_id, snap = primary_entry['subfind_ids'][0], primary_entry['snaps'][0]
primary_path = data_dir+'cutouts/snap_'+str(snap)+'/'
primary_filename = primary_path+'cutout_'+str(subfind_id)+'.hdf5'
cutout1 = pcutout.TNGCutout(primary_filename)

In [None]:
# Display the header of the first primary cutout
cutout1.header

In [None]:
# Load the primary's (major_list index 0) cutout at the second entry of its
# snapshot list
major_dict = all_major_list[0]
primary_major_entry = major_dict['major_list'][0]
snap = primary_major_entry['snaps'][1]
subfind_id = primary_major_entry['subfind_ids'][1]
primary_path = data_dir+'cutouts/snap_'+str(snap)+'/'
primary_filename = primary_path+'cutout_'+str(subfind_id)+'.hdf5'
cutout2 = pcutout.TNGCutout(primary_filename)

In [None]:
# The cutout request is a '+'-separated list of 'key=v1,v2,...' fields;
# print each key with its list of comma-separated values
request_fields = cutout2.header['CutoutRequest'].split('+')
for field in request_fields:
    field_parts = field.split('=')
    print(field_parts[0], field_parts[1].split(','))

## Spherical Jeans equation

Compute the spherical Jeans equation for the stellar halo of each Milky Way analog (at z=0, and below at earlier snapshots) to see how well it is satisfied.

In [None]:
# Output directory for the figures saved below
os.makedirs('./fig', exist_ok=True)

# Plot the distribution of z=0 primary star particles in Jz/Jcirc vs. 
# normalized energy, with the halo / bulge / disk selection boundaries
for i in range(n_mw):
    if i > n_do: continue

    # Get primary particle properties
    major_dict = all_major_list[i]
    primary_z0_subfind_id = major_dict['primary_z0_subfind_id']
    primary_path = data_dir+'cutouts/snap_99/'
    primary_filename = primary_path+'cutout_'+str(primary_z0_subfind_id)+'.hdf5'
    
    # Make the TNGCutout instance
    cutout = pcutout.TNGCutout(primary_filename)
    # Bounding radii for centering / rectifying in kpc
    vcen_rmin = 0.
    vcen_rmax = 5.
    rot_rmin = 2.
    rot_rmax = 10.
    # Center and rectify
    cutout.center_and_rectify(cen_ptype='PartType4', vcen_ptype='PartType4',
        vcen_rmin=vcen_rmin, vcen_rmax=vcen_rmax, rot_ptype='PartType4', 
        rot_rmin=rot_rmin, rot_rmax=rot_rmax)
    
    # Get properties, energy, angular momentum
    orbs = cutout.get_orbs('PartType4')
    rs = orbs.r().value
    vels = cutout.get_velocities('PartType4', physical=True)
    pot = cutout.get_potential_energy('PartType4', physical=True)
    kin = 0.5*np.sum(np.square(vels),axis=1)
    E = (pot+kin)
    J,Jz,Jp = cutout.get_J_Jz_Jp('PartType4',physical=True)
    # Builds the E -> Jcirc spline used by cutout.Jcirc below
    cutout.get_E_Jcirc_spline('PartType4',angmom='J')
    Jcirc = cutout.Jcirc(E)
    Enorm = E/np.abs(E).max()
    Jz_Jcirc = Jz / Jcirc
    Jp_Jcirc = Jp / Jcirc
    
    # Make the figure
    fig = plt.figure(figsize=(4,4))
    ax = fig.add_subplot(111)
    
    ### Jz/Jc - Enorm
    ax.set_aspect('auto')
    hist,_,_ = np.histogram2d(Jz_Jcirc, Enorm, bins=20, range=[[-1,1],[-1,0]])
    # rot90 puts Enorm on the vertical axis (increasing upwards) for imshow
    hist = np.log10(np.rot90(hist))
    im = ax.imshow(hist, cmap='Blues', extent=(-1,1,-1,0), aspect='auto',
                       vmin=2, vmax=5)
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('log N')
    ax.set_xlabel(r'$j_{z}/j_{circ}$', fontsize=16)
    ax.set_ylabel(r'$E/ \vert E_{norm} \vert$', fontsize=16)
    
    # Lines
    Jz_Jcirc_halo_bound = 0.5
    Jz_Jcirc_disk_bound = 0.8
    Enorm_bulge_bound = -0.75
    ax.axvline(Jz_Jcirc_halo_bound, color='Black', linestyle='dashed')
    ax.axvline(Jz_Jcirc_disk_bound, color='Black', linestyle='dashed')
    ax.plot([-1,Jz_Jcirc_halo_bound], [Enorm_bulge_bound,Enorm_bulge_bound], 
            color='Black', linestyle='dashed')
    
    # Labels
    ax.annotate('Halo', xy=(0.02,0.95), xycoords='axes fraction', 
                    fontsize=12)
    ax.annotate('Bulge', xy=(0.02,0.2), xycoords='axes fraction', 
                    fontsize=12)
    ax.annotate('Thick Disk', xy=(0.77,0.75), xycoords='axes fraction', 
                fontsize=12, rotation='vertical')
    ax.annotate('Thin Disk', xy=(0.92,0.75), xycoords='axes fraction', 
                fontsize=12, rotation='vertical')
    
    fig.suptitle('primary z=0 subfind id: '+str(primary_z0_subfind_id))
    figname = './fig/Enorm_Jcirc_primary_halo_'+str(primary_z0_subfind_id)+'.png'
    fig.savefig(figname, dpi=300)

In [None]:
nbins = 10
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False

# Output directory for the figures saved below
os.makedirs('./fig', exist_ok=True)

# Jeans-equation diagnostics for the z=0 stellar halo of each primary
for i in range(n_mw):
    if i > n_do: continue

    # Get primary particle properties
    major_dict = all_major_list[i]
    primary_z0_subfind_id = major_dict['primary_z0_subfind_id']
    primary_path = data_dir+'cutouts/snap_99/'
    primary_filename = primary_path+'cutout_'+str(primary_z0_subfind_id)+'.hdf5'
    
    # Make the TNGCutout instance
    cutout = pcutout.TNGCutout(primary_filename)
    # Bounding radii for centering / rectifying in kpc
    vcen_rmin = 0.
    vcen_rmax = 5.
    rot_rmin = 2.
    rot_rmax = 10.
    # Center and rectify
    cutout.center_and_rectify(cen_ptype='PartType4', vcen_ptype='PartType4',
        vcen_rmin=vcen_rmin, vcen_rmax=vcen_rmax, rot_ptype='PartType4', 
        rot_rmin=rot_rmin, rot_rmax=rot_rmax)
    
    # Get properties, energy, angular momentum
    orbs = cutout.get_orbs('PartType4')
    rs = orbs.r().value
    vels = cutout.get_velocities('PartType4', physical=True)
    pot = cutout.get_potential_energy('PartType4', physical=True)
    kin = 0.5*np.sum(np.square(vels),axis=1)
    E = (pot+kin)
    J,Jz,Jp = cutout.get_J_Jz_Jp('PartType4',physical=True)
    cutout.get_E_Jcirc_spline('PartType4',angmom='J')
    Jcirc = cutout.Jcirc(E)
    Enorm = E/np.abs(E).max()
    Jz_Jcirc = Jz / Jcirc
    Jp_Jcirc = Jp / Jcirc
    
    # Mask the halo: low Jz/Jcirc (not disk) and not tightly bound (not bulge)
    Jz_Jcirc_halo_bound = 0.5
    Enorm_bulge_bound = -0.75
    halo_mask = (Jz_Jcirc < Jz_Jcirc_halo_bound) &\
                (Enorm > Enorm_bulge_bound)
    
    # Compute the Jeans equation quantities
    Js,rs,qs = calculate_spherical_jeans(orbs[halo_mask],pot[halo_mask],
        n_bootstrap=1, r_range=[0,50], n_bin=nbins, 
        norm_by_nuvr2_r=norm_by_nuvr2_r,
        norm_by_galpy_scale_units=norm_by_galpy_scale_units)

    # Make the Jeans equation figure; pass the normalization flag so the
    # J-panel y-label matches how Js was computed
    fig,axs = plot_jeans_diagnostics(Js,rs,qs,norm_by_nuvr2_r=norm_by_nuvr2_r)
    fig.suptitle('primary z=0 subfind id: '+str(primary_z0_subfind_id))
    fig.tight_layout()
    figname = './fig/jeans_diagnostics_primary_halo_'+str(primary_z0_subfind_id)+'.png'
    fig.savefig(figname, dpi=300)
    fig.show()
    

In [None]:
nbins = 10
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False
# Snapshots at which the primary's halo is analysed
snap_analyze = np.array([99,91,84,78,72,67,59,50,40,33,25,21])
# Nominal redshifts of snap_analyze, for reference only; the plotted 
# redshifts are read from each cutout header below
redshift_analyze = np.array([0.,0.1,0.2,0.3,0.4,0.5,0.7,1.0,1.5,2.0,3.0,4.0])
n_snap_analyze = len(snap_analyze)

# Track how well the Jeans equation is satisfied by each primary's stellar
# halo as a function of redshift
for i in range(n_mw):
    if i > n_do: continue

    # Get primary particle properties
    major_dict = all_major_list[i]
    major_list = major_dict['major_list']
    analyze_mask = np.isin(major_list[0]['snaps'],snap_analyze)
    # array_equal also catches a length mismatch (missing snapshots)
    assert np.array_equal(major_list[0]['snaps'][analyze_mask],snap_analyze),\
        'Primary snapshots do not match snap_analyze'
    primary_snaps = major_list[0]['snaps'][analyze_mask]
    primary_redshift = np.zeros(n_snap_analyze)
    primary_subfind_ids = major_list[0]['subfind_ids'][analyze_mask]
    primary_z0_subfind_id = major_dict['primary_z0_subfind_id']
    print('Analyzing primary z=0 subfind id: '+str(primary_z0_subfind_id))
    print('------------------------------------------')

    # Initialize arrays
    J_abs_weighted = np.zeros(n_snap_analyze)
    J_values = np.zeros((n_snap_analyze,nbins))

    for j in range(n_snap_analyze):
        primary_path = data_dir+'cutouts/snap_'+str(primary_snaps[j])+'/'
        primary_filename = primary_path+'cutout_'+str(primary_subfind_ids[j])+'.hdf5'
        print('Analyzing primary subfind id: '+str(primary_subfind_ids[j]))
    
        # Make the TNGCutout instance
        cutout = pcutout.TNGCutout(primary_filename)
        primary_redshift[j] = cutout.header['Redshift']
        # Bounding radii for centering / rectifying in kpc
        vcen_rmin = 0.
        vcen_rmax = 5.
        rot_rmin = 2.
        rot_rmax = 10.
        # Center and rectify
        cutout.center_and_rectify(cen_ptype='PartType4', vcen_ptype='PartType4',
            vcen_rmin=vcen_rmin, vcen_rmax=vcen_rmax, rot_ptype='PartType4', 
            rot_rmin=rot_rmin, rot_rmax=rot_rmax)
    
        # Get properties, energy, angular momentum
        orbs = cutout.get_orbs('PartType4')
        rs = orbs.r().value
        vels = cutout.get_velocities('PartType4', physical=True)
        pot = cutout.get_potential_energy('PartType4', physical=True)
        kin = 0.5*np.sum(np.square(vels),axis=1)
        E = (pot+kin)
        J,Jz,Jp = cutout.get_J_Jz_Jp('PartType4',physical=True)
        cutout.get_E_Jcirc_spline('PartType4',angmom='J')
        Jcirc = cutout.Jcirc(E)
        Enorm = E/np.abs(E).max()
        Jz_Jcirc = Jz / Jcirc
        Jp_Jcirc = Jp / Jcirc
        
        # Mask the halo using these quantities
        Jz_Jcirc_halo_bound = 0.5
        Enorm_bulge_bound = -0.75
        halo_mask = (Jz_Jcirc < Jz_Jcirc_halo_bound) &\
                    (Enorm > Enorm_bulge_bound)
    
        # Compute the Jeans equation quantities
        Js,rs,qs = calculate_spherical_jeans(orbs[halo_mask],pot[halo_mask],
            n_bootstrap=1, r_range=[0,50], n_bin=nbins, 
            norm_by_nuvr2_r=norm_by_nuvr2_r,
            norm_by_galpy_scale_units=norm_by_galpy_scale_units)
        
        J_weights = qs[2]*rs**2 # density * r^2
        # Empty bins give NaN in Js (with zero weight); exclude non-finite
        # entries so they don't turn the whole weighted average into NaN
        finite = np.isfinite(Js) & np.isfinite(J_weights)
        J_abs_weighted[j] = np.sum(np.abs(Js[finite])*J_weights[finite])/\
            np.sum(J_weights[finite])
        J_values[j,:] = Js

    # Make the Jeans equation figure: weighted |J| vs. redshift (top) and
    # J(r) at each snapshot, colored by snapshot order (bottom)
    fig = plt.figure()
    axs = fig.subplots(nrows=2,ncols=1)
    axs[0].plot(primary_redshift,J_abs_weighted,'o-')
    axs[0].axhline(0.,color='k',ls='--')
    axs[0].set_xlabel('Redshift')
    axs[0].set_ylabel(r'weighted $|J_{0}|$')

    for j in range(n_snap_analyze):
        axs[1].plot(rs,J_values[j,:],'-',color=matplotlib.cm.rainbow(j/n_snap_analyze))
    axs[1].axhline(0.,color='k',ls='--')
    axs[1].set_xlabel('r [kpc]')
    axs[1].set_ylabel(r'$J_{0}$')

    fig.suptitle('primary z=0 subfind id: '+str(primary_z0_subfind_id))
    fig.tight_layout()
    fig.show()

# Kinematics and the Jeans equation for the debris of individual major-merger remnants

In [None]:
# Get all particle IDs that belong with each major merger progenitor:
# unique_particle_ids[i][j-1] is the sorted array of unique star particle
# IDs found in any cutout of major j of MW analog i

unique_particle_ids = []

for i in range(n_mw): # Loop over MW analogs
    if i > n_do: continue
    
    major_dict = all_major_list[i]
    major_list = major_dict['major_list']
    n_major = major_dict['n_major']
    
    # Hold star particle IDs for this MW analog, one array per major
    analog_unique_particle_ids = []
    
    for j in range(n_major+1): # Loop over majors + primary (index 0)
        if j == 0:
            assert major_list[j]['is_primary'], 'Index 0 not primary'
            continue
        major_snaps = major_list[j]['snaps']
        major_subfind_ids = major_list[j]['subfind_ids']
        major_nsnaps = len(major_snaps)
        
        # Sorted, unique star particle IDs for this major merger progenitor
        major_unique_particle_ids = np.array([],dtype=int) 
        
        for k in range(major_nsnaps):
            snap_filename = data_dir+'cutouts/snap_'+str(major_snaps[k])+\
                '/cutout_'+str(int(major_subfind_ids[k]))+'.hdf5'
            # Context manager closes the file even if reading raises
            with h5py.File(snap_filename,'r') as f:
                try:
                    snap_ids = np.array(f['PartType4']['ParticleIDs'],
                        dtype=int)
                except KeyError:
                    # PartType4/ParticleIDs missing (presumably no star 
                    # particles in this cutout); nothing to add
                    continue
            # union1d returns the sorted, unique union of the two arrays
            major_unique_particle_ids = np.union1d(major_unique_particle_ids,
                snap_ids)
            
        analog_unique_particle_ids.append(major_unique_particle_ids)
    unique_particle_ids.append(analog_unique_particle_ids)

In [None]:
nbins = 10
norm_by_nuvr2_r = True
norm_by_galpy_scale_units = False
# Snapshots at which the merger debris in the primary is analysed
snap_analyze = np.array([99,91,84,78,72,67,59,50])
# Nominal redshifts of snap_analyze, for reference only; the plotted 
# redshifts are read from each cutout header below
redshift_analyze = np.array([0.,0.1,0.2,0.3,0.4,0.5,0.7,1.0])
n_snap_analyze = len(snap_analyze)

# For each primary and snapshot, select the primary's star particles that
# ever belonged to each major-merger progenitor (from unique_particle_ids)
# and evaluate the Jeans equation for that debris
for i in range(n_mw):
    if i > n_do: continue

    # Get primary particle properties
    major_dict = all_major_list[i]
    major_list = major_dict['major_list']
    analyze_mask = np.isin(major_list[0]['snaps'],snap_analyze)
    # array_equal also catches a length mismatch (missing snapshots)
    assert np.array_equal(major_list[0]['snaps'][analyze_mask],snap_analyze),\
        'Primary snapshots do not match snap_analyze'
    primary_snaps = major_list[0]['snaps'][analyze_mask]
    primary_redshift = np.zeros(n_snap_analyze)
    primary_subfind_ids = major_list[0]['subfind_ids'][analyze_mask]
    primary_z0_subfind_id = major_dict['primary_z0_subfind_id']
    n_major = major_dict['n_major']
    print('Analyzing primary z=0 subfind id: '+str(primary_z0_subfind_id))
    print('------------------------------------------')

    # Initialize arrays, one row per major merger (primary excluded)
    J_abs_weighted = np.zeros((n_major,n_snap_analyze))
    J_values = np.zeros((n_major,n_snap_analyze,nbins))
    # Jeans-equation bin radii, updated each time the Jeans equation is 
    # computed; NaN (nothing plotted) if no major ever had particles
    jeans_rs = np.full(nbins,np.nan)

    for j in range(n_snap_analyze):
        primary_path = data_dir+'cutouts/snap_'+str(primary_snaps[j])+'/'
        primary_filename = primary_path+'cutout_'+str(primary_subfind_ids[j])+'.hdf5'
        print('Analyzing primary subfind id: '+str(primary_subfind_ids[j]))
    
        # Make the TNGCutout instance
        cutout = pcutout.TNGCutout(primary_filename)
        primary_redshift[j] = cutout.header['Redshift']
        # Bounding radii for centering / rectifying in kpc
        vcen_rmin = 0.
        vcen_rmax = 5.
        rot_rmin = 2.
        rot_rmax = 10.
        # Center and rectify
        cutout.center_and_rectify(cen_ptype='PartType4', vcen_ptype='PartType4',
            vcen_rmin=vcen_rmin, vcen_rmax=vcen_rmax, rot_ptype='PartType4', 
            rot_rmin=rot_rmin, rot_rmax=rot_rmax)
        # Read the primary's star particle IDs explicitly read-only (older
        # h5py defaults to append mode). Cast to int to match the dtype of
        # unique_particle_ids so the ID comparisons below are exact.
        with h5py.File(primary_filename,'r') as f_primary:
            primary_particle_ids = np.array(
                f_primary['PartType4']['ParticleIDs'],dtype=int)
        # Sort once per snapshot; reused for every major below
        primary_particle_ids_argsort = np.argsort(primary_particle_ids)
        primary_particle_ids_sorted = primary_particle_ids[
            primary_particle_ids_argsort]
    
        # Get properties. NOTE(review): assumes the particle order of
        # cutout.get_orbs('PartType4') matches the ParticleIDs order in the
        # HDF5 file -- TODO confirm in tng_dfs.cutout
        orbs = cutout.get_orbs('PartType4')
        rs = orbs.r().value
        vels = cutout.get_velocities('PartType4', physical=True)
        pot = cutout.get_potential_energy('PartType4', physical=True)

        # Now parse secondary major mergers
        for k in range(n_major+1): # Loop over majors + primary (index 0)
            if k == 0:
                assert major_list[k]['is_primary'], 'Index 0 not primary'
                continue
            
            this_unique_ids = unique_particle_ids[i][k-1]
            this_mlpid = major_list[k]['mlpid']
            
            # Locate this major's particle IDs in the primary: insertion
            # point in the sorted IDs, mapped back to unsorted indices, then
            # keep only exact matches (IDs absent from the primary drop out)
            if len(primary_particle_ids) > 0:
                where_this_unique_ids_sorted = np.searchsorted(
                    primary_particle_ids_sorted, this_unique_ids)
                where_this_unique_ids = np.take(primary_particle_ids_argsort, 
                    where_this_unique_ids_sorted, mode='clip')
                is_match = primary_particle_ids[where_this_unique_ids] ==\
                    this_unique_ids
                where_merg = where_this_unique_ids[is_match].astype(int)
            else:
                where_merg = np.array([],dtype=int)

            if len(where_merg) == 0:
                # No debris from this major in the primary at this snapshot
                print('No particles from '+str(this_mlpid)+' in primary, skipping')
                J_abs_weighted[k-1,j] = np.nan
                J_values[k-1,j,:] = np.nan
                continue

            print('Plotting '+str(this_mlpid)+'...')
            # Compute the Jeans equation quantities
            Js,rs,qs = calculate_spherical_jeans(orbs[where_merg],pot[where_merg],
                n_bootstrap=1, r_range=[0,50], n_bin=nbins, 
                norm_by_nuvr2_r=norm_by_nuvr2_r,
                norm_by_galpy_scale_units=norm_by_galpy_scale_units)
            jeans_rs = rs
            
            J_weights = qs[2]*rs**2 # density * r^2
            # Exclude empty (NaN) bins so they don't make the average NaN
            finite = np.isfinite(Js) & np.isfinite(J_weights)
            J_abs_weighted[k-1,j] = np.sum(np.abs(Js[finite])*\
                J_weights[finite])/np.sum(J_weights[finite])
            J_values[k-1,j,:] = Js

    for l in range(n_major):
        # Make the Jeans equation figure: weighted |J| vs. redshift (top) 
        # and J(r) at each snapshot, colored by snapshot order (bottom)
        fig = plt.figure()
        axs = fig.subplots(nrows=2,ncols=1)
        axs[0].plot(primary_redshift,J_abs_weighted[l],'o-')
        axs[0].axhline(0.,color='k',ls='--')
        axs[0].set_xlabel('Redshift')
        axs[0].set_ylabel(r'weighted $|J_{0}|$')

        for m in range(n_snap_analyze):
            axs[1].plot(jeans_rs,J_values[l,m,:],'-',color=matplotlib.cm.rainbow(m/n_snap_analyze))
        axs[1].axhline(0.,color='k',ls='--')
        axs[1].set_xlabel('r [kpc]')
        axs[1].set_ylabel(r'$J_{0}$')

        fig.suptitle('primary z=0 subfind id: '+str(primary_z0_subfind_id)+', major merger: '+str(l+1))
        fig.tight_layout()
        fig.show()