In [None]:
# ------------------------------------------------------------------------
#
# TITLE - fit_ossipkov_merrit_df.ipynb
# AUTHOR - James Lane
# PROJECT - sample_project
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Try to fit Osipkov-Merritt distribution functions to TNG MW analogs.

Plots the anisotropy profiles beta(r) of the Osipkov-Merritt model and of
the generalized (Cuddeford 1991) model, then measures bootstrapped, radially
binned beta(r) profiles at z=0 for star particles from each major merger and
for the stellar halo of the MW analogs, saving them under ./data/ for fitting.
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, glob, subprocess, warnings, dill as pickle

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu
from astropy.coordinates import SkyCoord

## Scipy
import scipy.optimize

## galpy
# from galpy import orbit
# from galpy import potential
# from galpy import actionAngle as aA
# from galpy import df

## Project-specific
sys.path.insert(0,'../../src/')
import tng_dfs.util as putil
import tng_dfs.kinematics as pkin
import tng_dfs.cutout as pcutout
import tng_dfs.tree as ptree

### Notebook setup

%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

In [None]:
# Keywords
# Keywords: pull the project-wide settings this notebook needs out of the
# project config file.
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR', 'MW_ANALOG_DIR', 'RO', 'VO', 'ZO', 'LITTLE_H',
            'MW_MASS_RANGE']
config_values = putil.parse_config_dict(cdict, keywords)
data_dir, mw_analog_dir, ro, vo, zo, h, mw_mass_range = config_values

# Figures written by this notebook live here (created if missing)
fig_dir = './fig/fit_ossipkov_merrit_dfs/'
os.makedirs(fig_dir, exist_ok=True)

# Display flag for figures
show_plots = False

## Make some plots of Osipkov–Merritt DF anisotropy

### Standard anisotropy as a function of radius

In [None]:
# Osipkov-Merritt anisotropy profiles beta(r) for a range of anisotropy radii
# r_a. Each curve is coloured by its r_a, and the dashed vertical line of the
# same colour marks that r_a, so the curves can be told apart.
fig, ax = plt.subplots()

rs = np.logspace(-2,2,num=1001)
ra = np.logspace(-1,1,num=10)
colors = plt.get_cmap('viridis')(np.linspace(0,1,len(ra)))

for i in range(len(ra)):
    beta = pkin.beta_ossipkov_merrit(rs,ra=ra[i])
    ax.plot(rs, beta, color=colors[i], label=f'$r_a={ra[i]:.2f}$ kpc')
    ax.axvline(ra[i], color=colors[i], linestyle='dashed')

ax.set_xlabel('r [kpc]')
ax.set_ylabel(r'$\beta$')
ax.set_xscale('log')
ax.set_title('Osipkov-Merritt anisotropy')
ax.legend(fontsize='small')

# plt.show() rather than fig.show(): the latter warns on the non-GUI
# inline backend.
plt.show()

### Flexible (Cuddeford 1991) anisotropy as a function of radius

In [None]:
# Generalized (Cuddeford 1991) anisotropy profiles beta(r) for a fixed r_a
# (dotted vertical line) and a range of alpha. Each curve is coloured by its
# alpha, and the dashed horizontal line of the same colour marks that alpha.
fig, ax = plt.subplots()

rs = np.logspace(-2,2,num=1001)
ra = 1
alpha = np.linspace(-0.5,0.5,num=10)
colors = plt.get_cmap('viridis')(np.linspace(0,1,len(alpha)))

for i in range(len(alpha)):
    beta = pkin.beta_any_alpha_cuddeford91(rs,ra=ra,alpha=alpha[i])
    ax.plot(rs, beta, color=colors[i], label=rf'$\alpha={alpha[i]:.2f}$')
    ax.axhline(alpha[i], color=colors[i], linestyle='dashed')

ax.set_xlabel('r [kpc]')
ax.set_ylabel(r'$\beta$')
ax.set_xscale('log')
ax.axvline(ra, color='Black', linestyle='dotted')
ax.set_title(f'Cuddeford (1991) anisotropy, $r_a={ra}$ kpc')
ax.legend(fontsize='small')

# plt.show() rather than fig.show(): the latter warns on the non-GUI
# inline backend.
plt.show()

### Load some data and try to do some fits at $z=0$

What likelihood should we use? To begin with, perhaps just use weighted least
squares on the binned data.


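Then compute the reduced chi-square statistic on the binned data:

$\chi^{2} = \sum_{i=1}^{n} \left(\beta_i - \beta_{\rm model}(r_i)\right)^{2}/\sigma_i^{2}$

$\chi^{2}_{\nu} = \chi^{2}/\nu$

$\nu = n - m$

where $n$ is the number of observations (radial bins), $m$ is the number of fitted parameters, and $\sigma_i$ is the uncertainty on $\beta_i$.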

In [None]:
# Load the SubLink tree products from ../parse_sublink_trees/data/.
# NOTE: `pickle` is dill here; only unpickle trusted, locally produced files,
# since unpickling can execute arbitrary code.
tree_pickle_files = ['../parse_sublink_trees/data/tree_primaries.pkl',
                     '../parse_sublink_trees/data/tree_major_mergers.pkl']
tree_objects = []
for tree_pickle_file in tree_pickle_files:
    with open(tree_pickle_file,'rb') as handle:
        tree_objects.append(pickle.load(handle))
tree_primaries, tree_major_mergers = tree_objects

# How many primary MW analogs are under consideration
n_mw = len(tree_primaries)

In [None]:
# Get all star particle IDs that belong to each major-merger progenitor of
# each MW analog: unique_particle_ids[i][j] is a sorted array of IDs for
# merger j of analog i.
# The earlier version of this cell read the cutouts with h5py from an
# `all_major_list` structure that is not defined anywhere in this notebook
# (and h5py is not imported), so it raised a NameError on a fresh kernel.
# It now uses the tree objects loaded above, the same way the next cells do.
unique_particle_ids = []

for i in range(n_mw): # Loop over MW analogs
    primary = tree_primaries[i]

    # Hold star particle IDs for this MW analog, one array per major merger
    analog_unique_particle_ids = []
    for j in range(primary.n_major_mergers): # Loop over major mergers
        major_merger = primary.tree_major_mergers[j]
        major_unique_particle_ids = np.asarray(
            major_merger.get_unique_particle_ids('stars',data_dir=data_dir))
        # Sorted, as the original per-snapshot accumulation produced
        analog_unique_particle_ids.append(np.sort(major_unique_particle_ids))
    unique_particle_ids.append(analog_unique_particle_ids)

In [None]:
# Inspect `secondary_mlpid` (used to name each merger's saved beta profile)
# for the first major merger of the first MW analog that has one.
# Previously this cell referenced `major_merger`, which is only defined by a
# loop further down, so it failed on a fresh kernel.
primaries_with_mergers = [p for p in tree_primaries if p.n_major_mergers > 0]
assert primaries_with_mergers, 'No MW analog has a major merger'
primaries_with_mergers[0].tree_major_mergers[0].secondary_mlpid

In [None]:
# Measure the anisotropy profile beta(r) of the z=0 star particles that came
# from each major merger of each MW analog, and save it to this_data_dir for
# fitting. No figures are currently made; the plotting / fitting code is kept
# below, commented out, for reference.
verbose = True
this_fig_dir = fig_dir+'major_merger_fits/'
os.makedirs(this_fig_dir,exist_ok=True)
this_data_dir = './data/betas/'
os.makedirs(this_data_dir,exist_ok=True)
r_range = [0,50] # Radial range for the beta(r) bins
n_bin = 7 # Number of radial bins
n_bs = 25 # Number of bootstrap samples
show_plots = True

unique_particle_ids = []

for i in range(n_mw):
    if verbose: print(f'Plotting MW {i+1}/{n_mw}')

    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_snap = len(primary.snapnum)
    n_major = primary.n_major_mergers

    # Load the primary's z=0 cutout once. The star-particle IDs, orbits and
    # potential energies do not depend on the merger, so fetch them here
    # rather than once per merger inside the loop below.
    co = pcutout.TNGCutout(
        primary.get_cutout_filename(mw_analog_dir,snapnum=primary.snapnum[0]))
    co.center_and_rectify()
    pid = co.get_property('stars','ParticleIDs')
    orbs_all = co.get_orbs('stars')
    pe_all = co.get_potential_energy('stars')

    # Loop over the major mergers, collect unique particle IDs
    _unique_particle_ids = []
    for j in range(n_major):
        if verbose: print(f'Plotting merger {j+1}/{n_major}')

        # Get the major merger
        major_merger = primary.tree_major_mergers[j]
        _unique_particle_ids.append(
            major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
            )

        # Indices of this merger's star particles in the z=0 snapshot
        indx = np.where(np.isin(pid,_unique_particle_ids[j]))[0]
        if len(indx) == 0:
            # Nothing to measure; don't feed an empty sample to the Jeans code
            warnings.warn(f'MW {i+1} merger {j+1}: no star particles found '
                          'at z=0, skipping beta profile')
            continue
        orbs = orbs_all[indx]
        pe = pe_all[indx]

        # NOTE(review): assumes qs[3], qs[4], qs[5] are the squared r, theta
        # and phi velocity dispersions per bootstrap sample and bin, so that
        # beta = 1 - (sigma_theta^2 + sigma_phi^2) / (2 sigma_r^2) — confirm
        # against pkin.calculate_spherical_jeans.
        Js,rs,qs = pkin.calculate_spherical_jeans(orbs,pot=None,pe=pe,
            n_bootstrap=n_bs,r_range=r_range,n_bin=n_bin)
        beta = 1-(qs[4]+qs[5])/(2*qs[3])
        # Keep one set of bin radii (assumes the first axis of rs indexes
        # bootstrap samples — TODO confirm)
        rs = rs[0]
        # Median and 16th/84th percentiles over the bootstrap axis
        lbeta,mbeta,ubeta = np.percentile(beta,[16,50,84],axis=0)
        sbeta = (ubeta-lbeta)/2 # Symmetrized 1-sigma-like uncertainty

        filename = this_data_dir+str(major_merger.secondary_mlpid)+'.pkl'
        with open(filename,'wb') as handle:
            pickle.dump([rs,mbeta,sbeta,lbeta,ubeta],handle)

        # # Make the figure
        # fig = plt.figure()
        # ax = fig.add_subplot(111)
        # ax.plot(rs, mbeta, color='Black')
        # ax.fill_between(rs, mbeta-sbeta, mbeta+sbeta, color='Black', alpha=0.5)

        # # Do the fit
        # try:
        #     popt,pcov = scipy.optimize.curve_fit(pkin.beta_any_alpha_cuddeford91,
        #         rs, mbeta, sigma=sbeta, absolute_sigma=True, p0=[1,0], 
        #         maxfev=1000)
        #     ra,alpha = popt
        #     ra_err,alpha_err = np.sqrt(np.diag(pcov))
        #     tbeta = pkin.beta_any_alpha_cuddeford91(rs,ra,alpha)
        #     ax.plot(rs, tbeta, color='Red', linestyle='dashed')
        # except RuntimeError:
        #     # Plot text saying fit did not converge
        #     ax.text(0.5,0.5,'Fit did not converge',transform=ax.transAxes,
        #         horizontalalignment='center',verticalalignment='center')

        # # fig.savefig(this_fig_dir+str(z0_sid)+'_'+str(j)+'.png',dpi=300)
        # if not show_plots: plt.close(fig)

    unique_particle_ids.append(_unique_particle_ids)

In [None]:
# Measure the anisotropy profile beta(r) of the stellar halo of each MW
# analog at z=0 and save it to this_data_dir for fitting. Halo stars are
# selected by Jz/Jcirc and by energy normalized to its largest magnitude.
# This cell no longer resets `unique_particle_ids`, which it never used but
# previously clobbered the result of the cell above.
verbose = True
this_fig_dir = fig_dir+'major_merger_fits/'
os.makedirs(this_fig_dir,exist_ok=True)
this_data_dir = './data/halo/betas/'
os.makedirs(this_data_dir,exist_ok=True)
r_range = [0,50] # Radial range for the beta(r) bins
n_bin = 7 # Number of radial bins
n_bs = 25 # Number of bootstrap samples
show_plots = True
max_mw_index = 10 # Only process analogs 0..max_mw_index for now

# Halo mask: keep stars with Jz/Jcirc below the halo bound and normalized
# energy above the bulge bound
Jz_Jcirc_halo_bound = 0.5
Enorm_bulge_bound = -0.75

for i in range(min(n_mw,max_mw_index+1)):
    if verbose: print(f'Plotting MW {i+1}/{n_mw}')

    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_snap = len(primary.snapnum)
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    # Get properties, energy, angular momentum
    orbs = co.get_orbs('PartType4')
    vels = co.get_velocities('PartType4', physical=True)
    pot = co.get_potential_energy('PartType4', physical=True)
    kin = 0.5*np.sum(np.square(vels),axis=1)
    E = (pot+kin)
    J,Jz,Jp = co.get_J_Jz_Jp('PartType4',physical=True)
    # Build the E -> Jcirc spline that co.Jcirc evaluates
    co.get_E_Jcirc_spline('PartType4',angmom='J')
    Jcirc = co.Jcirc(E)
    Enorm = E/np.abs(E).max()
    Jz_Jcirc = Jz / Jcirc
    Jp_Jcirc = Jp / Jcirc

    # Mask the halo using these quantities
    halo_mask = (Jz_Jcirc < Jz_Jcirc_halo_bound) &\
                (Enorm > Enorm_bulge_bound)

    orbs = orbs[halo_mask]
    # NOTE(review): pe comes from get_potential_energy(..., physical=True),
    # whereas the major-merger cell passes the default — confirm which units
    # calculate_spherical_jeans expects.
    pe = pot[halo_mask]

    # NOTE(review): assumes qs[3], qs[4], qs[5] are the squared r, theta and
    # phi velocity dispersions — confirm against pkin.calculate_spherical_jeans.
    Js,rs,qs = pkin.calculate_spherical_jeans(orbs,pot=None,pe=pe,
        n_bootstrap=n_bs,r_range=r_range,n_bin=n_bin)
    beta = 1-(qs[4]+qs[5])/(2*qs[3])
    rs = rs[0]
    # Median and 16th/84th percentiles over the bootstrap axis
    lbeta,mbeta,ubeta = np.percentile(beta,[16,50,84],axis=0)
    sbeta = (ubeta-lbeta)/2

    filename = this_data_dir+str(z0_sid)+'.pkl'
    with open(filename,'wb') as handle:
        pickle.dump([rs,mbeta,sbeta,lbeta,ubeta],handle)

    
