In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 1_construct_anisotropic_dfs.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Use density profile fits to construct anisotropic DFs
'''

__author__ = "James Lane"

In [None]:
# Control jax platform: pin JAX to the CPU backend for this notebook
import jax

_jax_platform = 'cpu'
jax.config.update('jax_platform_name', _jax_platform)

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle
import pdb, copy, glob, time, warnings, logging

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu
from astropy import constants as apc

## Analysis
import scipy.stats
import scipy.interpolate

## galpy
from galpy import orbit
from galpy import potential
from galpy import actionAngle as aA
from galpy import df
from galpy import util as gputil

## Project-specific
# Locate the project's src/ directory by walking up from the current working
# directory, then put it first on sys.path so tng_dfs can be imported. Raise
# once the search reaches the tng-dfs repository root (or the filesystem
# root) without finding src/.
src_path = 'src/'
while not os.path.exists(src_path):
    # The directory currently being searched for src/. Note that the last
    # component of realpath(src_path) itself is always 'src', so it cannot be
    # used for the stop condition.
    search_dir = os.path.dirname(os.path.realpath(src_path))
    at_filesystem_root = search_dir == os.path.dirname(search_dir)
    if os.path.basename(search_dir) == 'tng-dfs' or at_filesystem_root:
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import densprofile as pdens
from tng_dfs import fitting as pfit
from tng_dfs import io as pio
from tng_dfs import kinematics as pkin
from tng_dfs import util as putil

### Notebook setup

# Render matplotlib figures inline in the notebook
%matplotlib inline
# Apply the project-wide matplotlib style sheet found under src/mpl/
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
# Use high-resolution (retina) inline figures
%config InlineBackend.figure_format = 'retina'
# Automatically reload imported modules (e.g. tng_dfs) before executing code,
# so edits under src/ are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Read the project-wide settings from the config file
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR','MW_ANALOG_DIR','FIG_DIR_BASE','FITTING_DIR_BASE',
            'RO','VO','ZO','LITTLE_H','MW_MASS_RANGE']
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
    ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict,keywords)

# MW analog subhalos (return_vars=True also yields mwsubs_vars)
mwsubs,mwsubs_vars = putil.prepare_mwsubs(mw_analog_dir, h=h,
    mw_mass_range=mw_mass_range, return_vars=True, force_mwsubs=False,
    bulge_disk_fraction_cuts=True)

# Figure output locations: a local ./fig/ plus the project figure tree
local_fig_dir = './fig/'
fig_dir = os.path.join(fig_dir_base, 
    'notebooks/5_compare_distribution_functions/1_construct_anisotropic_dfs/')
for _fig_dirname in (local_fig_dir, fig_dir):
    os.makedirs(_fig_dirname, exist_ok=True)
show_plots = False

# Merger-tree data for the MW analogs (dill pickles)
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
with open(tree_primary_filename,'rb') as pkl_file: 
    tree_primaries = pickle.load(pkl_file)

tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
with open(tree_major_mergers_filename,'rb') as pkl_file:
    tree_major_mergers = pickle.load(pkl_file)

n_mw = len(tree_primaries)

## Construct and save the constant anisotropy distribution functions

In [None]:
### Some keywords and properties
# scipy.integrate is referenced below for its AccuracyWarning class; import it
# explicitly rather than relying on another scipy submodule having loaded it
import scipy.integrate

force_df = True # Rebuild the DF even if df.pkl already exists
test_pickling = True # Round-trip each DF through (un)pickling before saving
verbose = True
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile/')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function/')

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower_softening'
stellar_halo_density_ncut = 500
stellar_halo_densfunc = pdens.TwoPowerSpherical()

# Stellar halo rotation information
stellar_halo_rotation_version = 'tanh_rotation'
stellar_halo_rotation_ncut = 500

# Anisotropy information
df_type = 'constant_beta'
anisotropy_fit_version = 'anisotropy_params_softening'
anisotropy_ncut = 500

# DF versioning
df_version = 'df_density_softening'

# Begin logging
os.makedirs('./log/',exist_ok=True)
log_filename = './log/1_construct_constant_beta_dfs.log'
if os.path.exists(log_filename):
    os.remove(log_filename)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning constant anisotropy DF creation. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

# Ignore some standard warnings only while building the DFs. catch_warnings()
# restores the previous filter state on exit (also if an error is raised),
# whereas warnings.resetwarnings() would wipe Python's default filters too.
with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore', 
        message='No particle IDs found', category=UserWarning)
    warnings.filterwarnings(action='ignore', 
        message='maxiter', category=scipy.integrate.AccuracyWarning)
    warnings.filterwarnings(action='ignore', 
        message='invalid value encountered', category=RuntimeWarning)

    for i in range(n_mw):
        # NOTE: only the first MW analog is processed; remove this line to
        # build DFs for all of them
        if i > 0: continue

        # Get the primary and its z=0 cutout (centered and rectified)
        primary = tree_primaries[i]
        z0_sid = primary.subfind_id[0]
        n_major = primary.n_major_mergers
        primary_filename = primary.get_cutout_filename(mw_analog_dir,
            snapnum=primary.snapnum[0])
        co = pcutout.TNGCutout(primary_filename)
        co.center_and_rectify()
        pid = co.get_property('stars','ParticleIDs')

        # Load the interpolator for the sphericalized potential
        interpolator_filename = os.path.join(dens_fitting_dir,
            'spherical_interpolated_potential/',str(z0_sid),'interp_potential.pkl')
        with open(interpolator_filename,'rb') as handle:
            interpot = pickle.load(handle)

        for j in range(n_major):
            if verbose:
                msg = f'Analyzing major merger {j+1}/{n_major} for MW {i+1}/{n_mw}'
                logging.info(msg)
                print(msg)

            # Output directory and filename; skip if the DF already exists and
            # a rebuild is not forced
            this_fitting_dir = os.path.join(df_fitting_dir,df_type,df_version,
                str(z0_sid),'merger_'+str(j+1))
            os.makedirs(this_fitting_dir,exist_ok=True)
            df_filename = os.path.join(this_fitting_dir,'df.pkl')
            if os.path.exists(df_filename) and not force_df:
                print('  Already have DF, continuing')
                continue

            # Radii (kpc) of the z=0 star particles whose IDs are among this
            # major merger's unique star particle IDs
            major_merger = primary.tree_major_mergers[j]
            upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
            indx = np.where(np.isin(pid,upid))[0]
            orbs = co.get_orbs('stars')[indx]
            rs = orbs.r().to(apu.kpc).value

            # Get the stellar halo density profile (denspot for the DF)
            stellar_halo_density_filename = os.path.join(dens_fitting_dir,
                'stellar_halo/',stellar_halo_density_version,str(z0_sid),
                'merger_'+str(j+1)+'/', 'sampler.pkl')
            denspot = pfit.construct_pot_from_fit(
                stellar_halo_density_filename, stellar_halo_densfunc, 
                stellar_halo_density_ncut, ro=ro, vo=vo)

            # Get the stellar halo beta information
            anisotropy_param_dir = os.path.join(df_fitting_dir, df_type,
                    anisotropy_fit_version, str(z0_sid),'merger_'+str(j+1))
            anisotropy_filename = os.path.join(anisotropy_param_dir,'sampler.pkl')
            beta = pio.median_params_from_emcee_sampler(anisotropy_filename,
                ncut=anisotropy_ncut)[0][0]

            # Construct the distribution function and do some dummy sampling
            # to set the interpolators. rmax / rmin give a 10% margin beyond
            # the outermost / inside the innermost merger star.
            try:
                if verbose:
                    msg = f'Beta: {round(beta,3)}, building DF'
                    logging.info(msg)
                    print(msg)
                dfcb = df.constantbetadf(pot=interpot, denspot=denspot, beta=beta,
                    ro=ro, vo=vo, rmax=rs.max()*apu.kpc*1.1)
                if verbose:
                    msg = 'Sampling DF'
                    logging.info(msg)
                    print(msg)
                _ = dfcb.sample(n=100, rmin=rs.min()*apu.kpc*0.9)
            except Exception as e:
                msg = f'Failed to build DF, skipping. Error: {e}'
                logging.warning(msg)
                print(msg)
                # Really skip this merger: otherwise dfcb is either undefined
                # or the previous merger's DF, which would then be saved under
                # this merger's filename
                continue

            # Check the DF survives a pickle round trip before saving it
            if test_pickling:
                try:
                    pickle.loads(pickle.dumps(dfcb))
                except RecursionError:
                    if verbose:
                        msg = 'Caught recursion error when (un)pickling, quiting.'
                        logging.info(msg)
                        print(msg)
                    # Stop the cell; re-raising keeps the traceback, whereas
                    # sys.exit() only raises an opaque SystemExit in a kernel
                    raise
            with open(df_filename,'wb') as handle:
                pickle.dump(dfcb,handle)

            if verbose:
                msg = 'Done with merger'
                logging.info(msg)
                print(msg)

### Also pickle the f(E) interpolator separately, because it is the expensive part of the DF to create

In [None]:
# Re-load each pickled constant-beta DF and pickle its f(E) interpolator on
# its own. Uses df_fitting_dir, df_type, df_version and verbose as set in the
# constant-beta construction cell.
for i in range(n_mw):
    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers

    for j in range(n_major):
        if verbose:
            print(f'Saving f(E) interpolator for major merger {j+1}/{n_major}'
                  f' of MW {i+1}/{n_mw}')

        # Same directory layout the DF was written to, including df_version
        # (without it the pickled DFs are never found)
        this_fitting_dir = os.path.join(df_fitting_dir,df_type,df_version,
            str(z0_sid),'merger_'+str(j+1))
        df_filename = os.path.join(this_fitting_dir,'df.pkl')

        if not os.path.exists(df_filename):
            print('Pickled DF not found, skipping')
            continue

        with open(df_filename, 'rb') as handle:
            dfcb = pickle.load(handle)

        # Pull the f(E) interpolator (a private attribute of the galpy DF)
        fE_interp = dfcb._fE_interp
        fE_interp_filename = os.path.join(this_fitting_dir, 'fE_interp.pkl')
        with open(fE_interp_filename, 'wb') as handle:
            pickle.dump(fE_interp, handle)

## Construct and save the Osipkov-Merritt distribution functions

In [None]:
### Some keywords and properties
# scipy.integrate is referenced below for its AccuracyWarning class; import it
# explicitly rather than relying on another scipy submodule having loaded it
import scipy.integrate

force_df = True # Rebuild the DF even if df.pkl already exists
test_pickling = True # Round-trip each DF through (un)pickling before saving
verbose = True
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile/')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function/')

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower_softening'
stellar_halo_density_ncut = 500
stellar_halo_densfunc = pdens.TwoPowerSpherical()

# Stellar halo rotation information
stellar_halo_rotation_version = 'tanh_rotation'
stellar_halo_rotation_ncut = 500

# Anisotropy information
df_type = 'osipkov_merritt'
anisotropy_fit_version = 'anisotropy_params_softening'
anisotropy_ncut = 500

# DF versioning
df_version = 'df_density_softening'

# Begin logging (create ./log/ here as well, so this cell does not depend on
# another cell having created it)
os.makedirs('./log/',exist_ok=True)
log_filename = './log/1_construct_osipkov_merritt_dfs.log'
if os.path.exists(log_filename):
    os.remove(log_filename)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning Osipkov-Merritt DF creation. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

# Ignore some standard warnings only while building the DFs. catch_warnings()
# restores the previous filter state on exit (also if an error is raised),
# whereas warnings.resetwarnings() would wipe Python's default filters too.
with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore', 
        message='No particle IDs found', category=UserWarning)
    warnings.filterwarnings(action='ignore', 
        message='maxiter', category=scipy.integrate.AccuracyWarning)
    warnings.filterwarnings(action='ignore', 
        message='invalid value encountered', category=RuntimeWarning)

    for i in range(n_mw):
        # NOTE: only the first MW analog is processed; remove this line to
        # build DFs for all of them
        if i > 0: continue

        # Get the primary and its z=0 cutout (centered and rectified)
        primary = tree_primaries[i]
        z0_sid = primary.subfind_id[0]
        n_major = primary.n_major_mergers
        primary_filename = primary.get_cutout_filename(mw_analog_dir,
            snapnum=primary.snapnum[0])
        co = pcutout.TNGCutout(primary_filename)
        co.center_and_rectify()
        pid = co.get_property('stars','ParticleIDs')

        # Load the interpolator for the sphericalized potential
        interpolator_filename = os.path.join(dens_fitting_dir,
            'spherical_interpolated_potential/',str(z0_sid),'interp_potential.pkl')
        with open(interpolator_filename,'rb') as handle:
            interpot = pickle.load(handle)

        for j in range(n_major):
            if verbose:
                msg = f'Analyzing major merger {j+1}/{n_major} for MW {i+1}/{n_mw}'
                logging.info(msg)
                print(msg)

            # Output directory and filename; skip if the DF already exists and
            # a rebuild is not forced
            this_fitting_dir = os.path.join(df_fitting_dir,df_type,df_version,
                str(z0_sid),'merger_'+str(j+1))
            os.makedirs(this_fitting_dir,exist_ok=True)
            df_filename = os.path.join(this_fitting_dir,'df.pkl')
            if os.path.exists(df_filename) and not force_df:
                print('  Already have DF, continuing')
                continue

            # Radii (kpc) of the z=0 star particles whose IDs are among this
            # major merger's unique star particle IDs
            major_merger = primary.tree_major_mergers[j]
            upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
            indx = np.where(np.isin(pid,upid))[0]
            orbs = co.get_orbs('stars')[indx]
            rs = orbs.r().to(apu.kpc).value

            # Get the stellar halo density profile (denspot for the DF)
            stellar_halo_density_filename = os.path.join(dens_fitting_dir,
                'stellar_halo/',stellar_halo_density_version,str(z0_sid),
                'merger_'+str(j+1)+'/', 'sampler.pkl')
            denspot = pfit.construct_pot_from_fit(
                stellar_halo_density_filename, stellar_halo_densfunc, 
                stellar_halo_density_ncut, ro=ro, vo=vo)

            # Get the Osipkov-Merritt anisotropy radius ra (used in kpc below)
            anisotropy_param_dir = os.path.join(df_fitting_dir, df_type,
                    anisotropy_fit_version, str(z0_sid),'merger_'+str(j+1))
            anisotropy_filename = os.path.join(anisotropy_param_dir,'sampler.pkl')
            ra = pio.median_params_from_emcee_sampler(anisotropy_filename,
                ncut=anisotropy_ncut)[0][0]

            # Construct the distribution function and do some dummy sampling
            # to set the interpolators. rmax / rmin give a 10% margin beyond
            # the outermost / inside the innermost merger star.
            try:
                if verbose:
                    msg = f'ra: {round(ra,3)} kpc, building DF'
                    logging.info(msg)
                    print(msg)
                dfom = df.osipkovmerrittdf(pot=interpot, denspot=denspot, 
                    ra=ra*apu.kpc, ro=ro, vo=vo, rmax=rs.max()*apu.kpc*1.1)
                if verbose:
                    msg = 'Sampling DF'
                    logging.info(msg)
                    print(msg)
                _ = dfom.sample(n=100, rmin=rs.min()*apu.kpc*0.9)
            except Exception as e:
                msg = f'Failed to build DF, skipping. Error: {e}'
                logging.warning(msg)
                print(msg)
                # Really skip this merger: otherwise dfom is either undefined
                # or the previous merger's DF, which would then be saved under
                # this merger's filename
                continue

            # Check the DF survives a pickle round trip before saving it
            if test_pickling:
                try:
                    pickle.loads(pickle.dumps(dfom))
                except RecursionError:
                    if verbose:
                        msg = 'Caught recursion error when (un)pickling, quiting.'
                        logging.info(msg)
                        print(msg)
                    # Stop instead of writing a DF that may not be loadable
                    raise
            with open(df_filename,'wb') as handle:
                pickle.dump(dfom,handle)

            if verbose:
                msg = 'Done with merger'
                logging.info(msg)
                print(msg)

### Also pickle the f(Q) interpolator separately, because it is the expensive part of the DF to create

In [None]:
# Re-load each pickled Osipkov-Merritt DF and pickle its f(Q) interpolator on
# its own. Uses df_fitting_dir, df_type, df_version and verbose as set in the
# Osipkov-Merritt construction cell.
for i in range(n_mw):
    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers

    for j in range(n_major):
        if verbose:
            print(f'Saving f(Q) interpolator for major merger {j+1}/{n_major}'
                  f' of MW {i+1}/{n_mw}')

        # Same directory layout the DF was written to, including df_version
        # (without it the pickled DFs are never found)
        this_fitting_dir = os.path.join(df_fitting_dir,df_type,df_version,
            str(z0_sid),'merger_'+str(j+1))
        df_filename = os.path.join(this_fitting_dir,'df.pkl')

        if not os.path.exists(df_filename):
            print('Pickled DF not found, skipping')
            continue

        with open(df_filename, 'rb') as handle:
            dfom = pickle.load(handle)

        # Pull the f(Q) interpolator. The galpy private attribute is named
        # _logfQ_interp, so it presumably interpolates log f(Q) -- confirm
        # before using the saved object directly.
        fQ_interp = dfom._logfQ_interp
        fQ_interp_filename = os.path.join(this_fitting_dir, 'fQ_interp.pkl')
        with open(fQ_interp_filename, 'wb') as handle:
            pickle.dump(fQ_interp, handle)