In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 1_interpolate_primaries.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Build interpolated spherical potentials for the Milky Way analog primaries.

For each primary, the star and dark matter particles of its cutout are
combined into a spherical enclosed-mass profile, converted to a radial force,
and wrapped in a galpy interpolated spherical potential that is saved to disk.
'''

__author__ = 'James Lane'

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, pdb, time, dill as pickle, logging

## Matplotlib
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Analysis
import scipy.interpolate

## galpy
from galpy import potential
from galpy import util as gputil

## Project-specific
src_path = 'src/'
while True:
    if os.path.exists(src_path): break
    if os.path.realpath(src_path).split('/')[-1] in ['tng-dfs','/']:
            raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import util as putil

### Notebook setup

# Render figures inline. The project style is applied immediately after this
# line; presumably activating the inline backend resets some rcParams, which
# is why the order matters (NOTE(review): confirm the exact reason).
%matplotlib inline
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
# High-resolution inline figure rendering
%config InlineBackend.figure_format = 'retina'
# Automatically re-import edited modules (e.g. tng_dfs from src/) before
# executing each cell
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: pull the project configuration values this notebook needs
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR','MW_ANALOG_DIR','FIG_DIR_BASE','FITTING_DIR_BASE',
            'RO','VO','ZO','LITTLE_H','MW_MASS_RANGE']
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
 ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW analog sample
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure output directories (local copy and project-wide figure tree)
local_fig_dir = './fig/'
fig_dir = os.path.join(fig_dir_base,
    'notebooks/3_fit_density_profiles/1_interpolate_primaries/')
for _out_dir in (local_fig_dir, fig_dir):
    os.makedirs(_out_dir, exist_ok=True)
show_plots = False

# Merger-tree data: the primaries and their major mergers (dill pickles)
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
with open(tree_primary_filename, 'rb') as f_primaries:
    tree_primaries = pickle.load(f_primaries)

tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
with open(tree_major_mergers_filename, 'rb') as f_mergers:
    tree_major_mergers = pickle.load(f_mergers)

# Number of MW analogs (one primary per analog)
n_mw = len(tree_primaries)

### Compute the spherical enclosed-mass profile of each primary and build an interpolated spherical potential from it

In [None]:
# Function to compute the radial force from enclosed mass
def galpy_radial_force_from_enclosed_mass(r, menc, ro=ro, vo=vo):
    '''galpy_radial_force_from_enclosed_mass:

    Convert an enclosed mass profile into the spherical radial force
    F_r = -M(<r) / r**2, expressed in internal (natural) galpy units.

    Args:
        r (np.ndarray) - radii, in the same length unit as ro (kpc as used
            in this notebook)
        menc (np.ndarray) - mass enclosed within each r [Msun]
        ro (float) - galpy length scale; defaults to the configured RO
            (bound when the function is defined)
        vo (float) - galpy velocity scale; defaults to the configured VO
            (bound when the function is defined)

    Returns:
        rforce (np.ndarray) - radial force at each r in internal galpy units
    '''
    # Express radius and mass in galpy's natural units before combining them
    r_internal = r/ro
    mass_internal = menc/gputil.conversion.mass_in_msol(vo,ro)
    return -mass_internal/r_internal**2

In [None]:
### Some keywords and properties
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile/')
interpot_version = 'all_star_dm_enclosed_mass'
n_thin = 100 # Keep every n_thin-th particle (in radius order) for the grid
verbose = True
force_interpolator = True # Rebuild interpolators even if they already exist


def _report(msg):
    '''_report:

    Write a progress message to the log and echo it to stdout, but only when
    the notebook-level ``verbose`` flag is set.
    '''
    if verbose:
        logging.info(msg)
        print(msg)


def _plot_enclosed_mass(rs_star, rs, rs_thin, menc_thin, interpot, figname):
    '''_plot_enclosed_mass:

    Plot the stellar radial distribution (top panel) and compare the particle
    enclosed mass profile with the interpolated potential's enclosed mass
    (bottom panel), then save the figure and close it.

    Args:
        rs_star (np.ndarray) - stellar particle radii [kpc]
        rs (np.ndarray) - radii of all particles in the profile [kpc], sets
            the plotted radial range
        rs_thin (np.ndarray) - thinned radial grid [kpc]
        menc_thin (np.ndarray) - enclosed mass on rs_thin [Msun]
        interpot (potential.interpSphericalPotential) - interpolated potential
        figname (str) - output image filename
    '''
    trs = np.logspace(np.log10(rs.min()), np.log10(rs.max()), 100)
    log_rlim = (np.log10(trs[0]), np.log10(trs[-1]))
    fig = plt.figure()
    ax1 = fig.add_subplot(211)
    ax2 = fig.add_subplot(212)
    ax1.hist(np.log10(rs_star), bins=100, range=log_rlim, log=True,
        histtype='step', edgecolor='Black')
    ax1.set_ylabel(r'$p(r_{\star})$')
    ax1.set_xlim(*log_rlim)

    # Wrapping in a Quantity makes this work whether or not galpy is
    # configured to return astropy Quantities
    interp_menc = apu.Quantity(interpot.mass(trs*apu.kpc), apu.Msun).value
    ax2.plot(np.log10(rs_thin), np.log10(menc_thin), color='k',
        label='Enclosed Mass')
    ax2.plot(np.log10(trs), np.log10(interp_menc), color='Red',
        linestyle='dashed', label='Interpolated Mass')
    ax2.set_xlabel('log radius [kpc]')
    ax2.set_ylabel('log enclosed mass [Msun]')
    ax2.set_xlim(*log_rlim)
    ax2.legend(loc='best')
    fig.savefig(figname,dpi=300)
    plt.close(fig) # Avoid accumulating open figures across the loop


# Begin logging
os.makedirs('./log',exist_ok=True)
log_filename = './log/1_interpolate_primaries.log'
if os.path.exists(log_filename):
    os.remove(log_filename)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w',
    force=True)
logging.info('Beginning potential interpolation. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

for i in range(n_mw):
    _report(f'Analyzing MW {i+1}/{n_mw}')

    # Get the primary; its z=0 subfind ID labels the output directories
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]

    # Fitting dir, check if already exists. This is done before the cutout is
    # read, so skipped analogs do not pay for loading and centering it.
    this_fitting_dir = os.path.join(dens_fitting_dir,
        'spherical_interpolated_potential/',interpot_version,
        str(z0_sid))
    os.makedirs(this_fitting_dir,exist_ok=True)
    interpolator_filename = os.path.join(this_fitting_dir,
        'interp_potential.pkl')
    if os.path.exists(interpolator_filename) and not force_interpolator:
        _report(f'Interpolator already exists for {z0_sid}, continuing')
        continue

    _report('loading data...')
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()
    orbs_star = co.get_orbs('stars')
    orbs_dm = co.get_orbs('dm')
    rs_star = orbs_star.r().to_value(apu.kpc)
    rs_dm = orbs_dm.r().to_value(apu.kpc)
    masses_star = co.get_masses('stars').to_value(apu.Msun)
    masses_dm = co.get_masses('dm').to_value(apu.Msun)
    pe_star = co.get_potential_energy('stars').to_value(apu.km**2/apu.s**2)
    pe_dm = co.get_potential_energy('dm').to_value(apu.km**2/apu.s**2)

    # Combine stars and DM (use only the *_star arrays for a star-only
    # profile), order by radius, and accumulate the enclosed mass
    rs = np.concatenate([rs_star,rs_dm])
    masses = np.concatenate([masses_star,masses_dm])
    pe = np.concatenate([pe_star,pe_dm])
    sort_idx = np.argsort(rs)
    rs = rs[sort_idx]
    masses = masses[sort_idx]
    pe = pe[sort_idx]
    menc = np.cumsum(masses)

    # Thin to every n_thin-th particle to form the interpolation grid
    rs_thin = rs[::n_thin]
    pe_thin = pe[::n_thin]
    menc_thin = menc[::n_thin]

    _report('Computing interpolated potential...')
    grforce = galpy_radial_force_from_enclosed_mass(rs_thin, menc_thin, ro, vo)
    rforce_interp = scipy.interpolate.interp1d(rs_thin/ro, grforce,
        kind='linear')
    # Phi0 anchors the potential at rgrid[0] (per galpy's docs); here it is
    # the innermost grid particle's potential energy [km^2/s^2] converted to
    # internal units by dividing by vo**2
    interpot = potential.interpSphericalPotential(rforce_interp,
        rgrid=rs_thin/ro, Phi0=pe_thin[0]/vo**2, ro=ro, vo=vo)

    # Plot
    this_fig_dir = os.path.join(fig_dir,str(z0_sid))
    os.makedirs(this_fig_dir,exist_ok=True)
    _plot_enclosed_mass(rs_star, rs, rs_thin, menc_thin, interpot,
        os.path.join(this_fig_dir,'enclosed_mass.png'))

    # Save the model (its directory, this_fitting_dir, was created above)
    with open(interpolator_filename,'wb') as handle:
        pickle.dump(interpot,handle)

    _report('Done with this analog')

### Test the interpolated potential: compare `n_thin=100` against `n_thin=1` in potential, radial force, and density

In [None]:
# NOTE(review): this comparison is disabled (every line is commented out).
# To use it, an interpolator built with n_thin = 1 must first be saved as
# 'interp_potential_nthin_1.pkl' in the same directory as the default
# 'interp_potential.pkl'. The loop above only writes the latter. The
# directory below is an absolute, machine-specific path for a single analog
# (394621); consider building it from dens_fitting_dir and interpot_version.
# For each quantity (potential, radial force, density) the left column
# overplots |value| for both versions, and the right column shows the
# per-cent difference of n_thin=100 relative to n_thin=1.
# this_fitting_dir =  '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
#                     'density_profile/spherical_interpolated_potential/'+\
#                     'all_star_dm_enclosed_mass/394621/'
# interp1_filename = os.path.join(this_fitting_dir, 'interp_potential_nthin_1.pkl')
# with open(interp1_filename,'rb') as handle:
#     interpot1 = pickle.load(handle)
# interp100_filename = os.path.join(this_fitting_dir, 'interp_potential.pkl')
# with open(interp100_filename,'rb') as handle:
#     interpot100 = pickle.load(handle)

# # Test and plot
# tgrid = np.geomspace(0.1, 100, 1001)*apu.kpc

# fig = plt.figure(figsize=(10,12))
# axs = fig.subplots(3,2)

# fns = [potential.evaluatePotentials,
#        potential.evaluateRforces,
#        potential.evaluateDensities,
#        ]
# labels = ['Potential', 'Radial Force', 'Density']

# for i in range(3):
#     fn = fns[i]
#     label = labels[i]
#     ax1 = axs[i,0]
#     ax2 = axs[i,1]
#     fn1 = fn(interpot1, tgrid, 0)
#     fn100 = fn(interpot100, tgrid, 0)

#     ax1.plot(tgrid, np.abs(fn1), color='Black', label='nthin=1')
#     ax1.plot(tgrid, np.abs(fn100), color='Red', linestyle='dashed', 
#         label='nthin=100')
#     ax1.set_xscale('log')
#     ax1.set_yscale('log')
#     ax1.set_xlabel('radius [kpc]')
#     ax1.set_ylabel(r'$\vert$'+label+r'$\vert$')
#     ax1.legend(loc='best')

#     ax2.plot(tgrid, 100*(fn100-fn1)/fn1, color='Black')
#     ax2.set_xscale('log')
#     ax2.set_xlabel('radius [kpc]')
#     ax2.set_ylabel(r'$\Delta$ '+label+' [per cent]')

# fig.tight_layout()
# fig.show()