In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 1_anisotropic_dfs.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Use density profile fits to construct anisotropic DFs

For every Milky Way analog and each of its major mergers, build a
constant-anisotropy (constant beta) DF and an Osipkov-Merritt DF for the
accreted stellar debris. Each DF uses the analog's sphericalized interpolated
potential together with the fitted stellar-halo density and anisotropy
profiles of the debris. The DFs, and separately their f(E) / log f(Q)
interpolators, are pickled to disk.
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle
import pdb, copy, glob, time, warnings

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu
from astropy import constants as apc

## Analysis
import scipy.integrate
import scipy.interpolate
import scipy.stats

## galpy
from galpy import orbit
from galpy import potential
from galpy import actionAngle as aA
from galpy import df
from galpy import util as gputil

## Project-specific
# Walk up from the working directory until a src/ directory is found, then put
# it on sys.path. Give up once the project root (tng-dfs) or the filesystem
# root has been searched without success. The original check compared the
# last component of realpath('.../src/'), which is always 'src', so it could
# never fire and the loop spun forever when src/ was missing.
src_path = 'src/'
while not os.path.exists(src_path):
    search_dir = os.path.realpath(os.path.join(src_path, os.pardir))
    at_project_root = os.path.basename(search_dir) == 'tng-dfs'
    at_filesystem_root = search_dir == os.path.dirname(search_dir)
    if at_project_root or at_filesystem_root:
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..', src_path)
sys.path.insert(0, src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import densprofile as pdens
from tng_dfs import fitting as pfit
from tng_dfs import kinematics as pkin
from tng_dfs import util as putil

### Notebook setup

# Render figures inline in the notebook.
%matplotlib inline
# Apply the project matplotlib style. It has to come after `%matplotlib inline`.
# NOTE(review): presumably because activating the inline backend resets some
# rcParams, which would override the style if it were applied first.
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
# Render inline figures at high (retina) resolution.
%config InlineBackend.figure_format = 'retina'
# Automatically reload imported modules (e.g. the tng_dfs package) when their
# source changes, so edits in src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: read the project configuration values used by this notebook
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR', 'MW_ANALOG_DIR', 'RO', 'VO', 'ZO', 'LITTLE_H',
            'MW_MASS_RANGE']
(data_dir, mw_analog_dir, ro, vo, zo, h,
 mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW analog sample (bulge/disk fraction cuts applied)
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure output locations, created if they do not exist yet
fig_dir = './fig/sample/'
epsen_fig_dir = '/epsen_data/scr/lane/projects/tng-dfs/figs/notebooks/sample/'
for _fig_output_dir in (fig_dir, epsen_fig_dir):
    os.makedirs(_fig_output_dir, exist_ok=True)
show_plots = False

# Merger-tree data: the MW primaries and their major mergers
tree_primary_filename = os.path.join(mw_analog_dir,
                                     'major_mergers/tree_primaries.pkl')
with open(tree_primary_filename, 'rb') as fh:
    tree_primaries = pickle.load(fh)

tree_major_mergers_filename = os.path.join(mw_analog_dir,
                                           'major_mergers/tree_major_mergers.pkl')
with open(tree_major_mergers_filename, 'rb') as fh:
    tree_major_mergers = pickle.load(fh)

n_mw = len(tree_primaries)

### Load the density-profile and anisotropy fits and construct the constant-anisotropy (constant-$\beta$) distribution functions

In [None]:
# Build a constant-anisotropy (constant beta) DF for the accreted stellar debris
# of every major merger of every MW analog, then pickle it. Each DF uses the
# analog's sphericalized interpolated potential as `pot` and the fitted
# stellar-halo density profile of the debris as `denspot`.

force_df = True # Rebuild the DF even if df.pkl already exists
test_pickling = True # Round-trip each DF through pickle before writing it
verbose = True
epsen_dens_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'density_profile/'
epsen_df_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'distribution_function/'

# DM halo information
dm_halo_version = 'poisson_nfw'
dm_halo_ncut = 500

# Stellar bulge and disk information
stellar_bulge_disk_version = 'miyamoto_disk_pswc_bulge_tps_halo'
stellar_bulge_disk_ncut = 2000

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower'
stellar_halo_density_ncut = 500

# Stellar halo rotation information
stellar_halo_rotation_version = 'tanh_rotation'
stellar_halo_rotation_ncut = 500

# Beta information
beta_version = 'beta_constant'
beta_ncut = 500

# Define density profiles
dm_halo_densfunc = pdens.NFWSpherical()
disk_densfunc = pdens.MiyamotoNagaiDisk()
bulge_densfunc = pdens.SinglePowerCutoffSpherical()
stellar_halo_densfunc = pdens.TwoPowerSpherical()
stellar_bulge_disk_densfunc = pdens.CompositeDensityProfile(
    [disk_densfunc,
     bulge_densfunc,
     pdens.TwoPowerSpherical()]
     )

# Ignore some standard warnings. catch_warnings() restores the previous filter
# state when the block exits, including an early exit. warnings.resetwarnings()
# would instead also wipe the interpreter's default filters.
with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore', message='No particle IDs found', category=UserWarning)
    warnings.filterwarnings(action='ignore', message='maxiter', category=scipy.integrate.AccuracyWarning)
    warnings.filterwarnings(action='ignore', message='invalid value encountered', category=RuntimeWarning)

    for i in range(n_mw):
        if verbose: print(f'Analyzing MW {i+1}/{n_mw}')

        # Get the primary and the IDs of its z=0 star particles
        primary = tree_primaries[i]
        z0_sid = primary.subfind_id[0]
        major_mergers = primary.tree_major_mergers
        n_major = primary.n_major_mergers
        n_snap = len(primary.snapnum)
        primary_filename = primary.get_cutout_filename(mw_analog_dir,
            snapnum=primary.snapnum[0])
        co = pcutout.TNGCutout(primary_filename)
        co.center_and_rectify()
        pid = co.get_property('stars','ParticleIDs')

        # Load the interpolator for the sphericalized potential
        interpolator_filename = os.path.join(epsen_dens_fitting_dir,
            'spherical_interpolated_potential/',str(z0_sid),'interp_potential.pkl')
        with open(interpolator_filename,'rb') as handle:
            interpot = pickle.load(handle)

        for j in range(n_major):
            if verbose: print(f'  Constructing DF for major merger {j+1}/{n_major}')

            # Filename and pathing check
            this_fitting_dir = os.path.join(epsen_df_fitting_dir,'constant_beta',
                str(z0_sid),'merger_'+str(j+1))
            os.makedirs(this_fitting_dir,exist_ok=True)
            df_filename = os.path.join(this_fitting_dir,'df.pkl')
            if os.path.exists(df_filename) and not force_df:
                print('  Already have DF, continuing')
                continue

            # Get the major merger and the radii (kpc) of its accreted stars
            major_merger = primary.tree_major_mergers[j]
            major_acc_sid = major_merger.subfind_id[0]
            major_mlpid = major_merger.secondary_mlpid
            upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
            indx = np.where(np.isin(pid,upid))[0]
            orbs = co.get_orbs('stars')[indx]
            rs = orbs.r().to(apu.kpc).value

            # Get the stellar halo density profile (denspot for the DF)
            stellar_halo_density_filename = os.path.join(epsen_dens_fitting_dir,
                'stellar_halo/',stellar_halo_density_version,str(z0_sid),
                'merger_'+str(j+1)+'/', 'sampler.pkl')
            denspot = pfit.construct_pot_from_fit(
                stellar_halo_density_filename, stellar_halo_densfunc,
                stellar_halo_density_ncut, ro=ro, vo=vo)

            # Get the stellar halo beta: the median of the first fitted
            # parameter after discarding the first beta_ncut chain steps.
            # Values below -5 are raised to -5, and values >= 1 are lowered to 0.9.
            beta_dir = os.path.join(epsen_dens_fitting_dir,'stellar_halo/',beta_version,
                str(z0_sid),'merger_'+str(j+1)+'/')
            beta_filename = os.path.join(beta_dir,'sampler.pkl')
            assert os.path.exists(beta_filename)
            with open(beta_filename,'rb') as handle:
                beta_sampler = pickle.load(handle)
            beta_samples = beta_sampler.get_chain(discard=beta_ncut, flat=True)
            beta = np.median(beta_samples,axis=0)[0]
            if beta < -5:
                print('    Beta < -5, setting beta=-5')
                beta = -5
            if beta >= 1.:
                print('    Beta >= 1, setting beta=0.9')
                beta = 0.9

            # Construct the distribution function and do some dummy sampling
            # to set the interpolators. Then save.
            try:
                print(f'  Beta: {round(beta,3)}')
                print('  Building DF')
                dfcb = df.constantbetadf(pot=interpot, denspot=denspot, beta=beta, ro=ro,
                    vo=vo, rmax=rs.max()*apu.kpc*1.1)
                print('  Sampling DF')
                _ = dfcb.sample(n=100, rmin=rs.min()*apu.kpc*0.9)
            except Exception as e:
                # Actually skip. Falling through would pickle a dfcb that is
                # either undefined (first failure) or the DF of a previous
                # merger, which would then be saved under this merger's name.
                print('Caught an error:',e,'skipping...')
                continue

            # Make sure the DF survives a pickle round trip, then write it to
            # df_filename (built above).
            if test_pickling:
                try:
                    pickle.loads(pickle.dumps(dfcb))
                except RecursionError:
                    print('Caught recursion error when pickle/unpickling, quiting...')
                    sys.exit()
            with open(df_filename,'wb') as handle:
                pickle.dump(dfcb,handle)

### Also pickle the f(E) interpolator separately, because it is the expensive part of the DF to create

In [None]:
# Re-open each pickled constant-beta DF and write its f(E) interpolator to a
# separate pickle (fE_interp.pkl) in the same directory. That way the
# interpolator can be reused without reloading or rebuilding the whole DF.
for i in range(n_mw):
    if verbose:
        print(f'Analyzing MW {i+1}/{n_mw}')

    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers

    for j in range(n_major):
        if verbose:
            print(f'  Saving f(E) interpolator for major merger {j+1}/{n_major}')

        this_fitting_dir = os.path.join(
            epsen_df_fitting_dir, 'constant_beta', str(z0_sid), f'merger_{j+1}')
        df_filename = os.path.join(this_fitting_dir, 'df.pkl')

        if os.path.exists(df_filename):
            with open(df_filename, 'rb') as fh:
                dfcb = pickle.load(fh)
            fE_interp = dfcb._fE_interp
            fE_interp_filename = os.path.join(this_fitting_dir, 'fE_interp.pkl')
            with open(fE_interp_filename, 'wb') as fh:
                pickle.dump(fE_interp, fh)
        else:
            print('Pickled DF not found, skipping')

### Construct and save the Osipkov-Merritt distribution functions

In [None]:
# Build an Osipkov-Merritt DF for the accreted stellar debris of every major
# merger of every MW analog, then pickle it. Each DF uses the analog's
# sphericalized interpolated potential as `pot`, the fitted stellar-halo density
# profile of the debris as `denspot`, and the fitted anisotropy radius r_a.

force_df = True # Rebuild the DF even if df.pkl already exists
test_pickling = True # Round-trip each DF through pickle before writing it
verbose = True
epsen_dens_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'density_profile/'
epsen_df_fitting_dir = '/epsen_data/scr/lane/projects/tng-dfs/fitting/'+\
    'distribution_function/'

# DM halo information
dm_halo_version = 'poisson_nfw'
dm_halo_ncut = 500

# Stellar bulge and disk information
stellar_bulge_disk_version = 'miyamoto_disk_pswc_bulge_tps_halo'
stellar_bulge_disk_ncut = 2000

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower'
stellar_halo_density_ncut = 500

# Stellar halo rotation information
stellar_halo_rotation_version = 'tanh_rotation'
stellar_halo_rotation_ncut = 500

# Beta information (Osipkov-Merritt anisotropy-radius fits)
beta_version = 'beta_osipkov_merritt'
beta_ncut = 500

# Define density profiles
dm_halo_densfunc = pdens.NFWSpherical()
disk_densfunc = pdens.MiyamotoNagaiDisk()
bulge_densfunc = pdens.SinglePowerCutoffSpherical()
stellar_halo_densfunc = pdens.TwoPowerSpherical()
stellar_bulge_disk_densfunc = pdens.CompositeDensityProfile(
    [disk_densfunc,
     bulge_densfunc,
     pdens.TwoPowerSpherical()]
     )

# Ignore some standard warnings. catch_warnings() restores the previous filter
# state when the block exits, including an early exit. warnings.resetwarnings()
# would instead also wipe the interpreter's default filters.
with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore', message='No particle IDs found', category=UserWarning)
    warnings.filterwarnings(action='ignore', message='maxiter', category=scipy.integrate.AccuracyWarning)
    warnings.filterwarnings(action='ignore', message='invalid value encountered', category=RuntimeWarning)

    for i in range(n_mw):
        if verbose: print(f'Analyzing MW {i+1}/{n_mw}')

        # Get the primary and the IDs of its z=0 star particles
        primary = tree_primaries[i]
        z0_sid = primary.subfind_id[0]
        major_mergers = primary.tree_major_mergers
        n_major = primary.n_major_mergers
        n_snap = len(primary.snapnum)
        primary_filename = primary.get_cutout_filename(mw_analog_dir,
            snapnum=primary.snapnum[0])
        co = pcutout.TNGCutout(primary_filename)
        co.center_and_rectify()
        pid = co.get_property('stars','ParticleIDs')

        # Load the interpolator for the sphericalized potential
        interpolator_filename = os.path.join(epsen_dens_fitting_dir,
            'spherical_interpolated_potential/',str(z0_sid),'interp_potential.pkl')
        with open(interpolator_filename,'rb') as handle:
            interpot = pickle.load(handle)

        for j in range(n_major):
            if verbose: print(f'  Constructing DF for major merger {j+1}/{n_major}')

            # Filename and pathing check
            this_fitting_dir = os.path.join(epsen_df_fitting_dir,'osipkov_merritt',
                str(z0_sid),'merger_'+str(j+1))
            os.makedirs(this_fitting_dir,exist_ok=True)
            df_filename = os.path.join(this_fitting_dir,'df.pkl')
            if os.path.exists(df_filename) and not force_df:
                print('  Already have DF, continuing')
                continue

            # Get the major merger and the radii (kpc) of its accreted stars
            major_merger = primary.tree_major_mergers[j]
            major_acc_sid = major_merger.subfind_id[0]
            major_mlpid = major_merger.secondary_mlpid
            upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
            indx = np.where(np.isin(pid,upid))[0]
            orbs = co.get_orbs('stars')[indx]
            rs = orbs.r().to(apu.kpc).value

            # Get the stellar halo density profile (denspot for the DF)
            stellar_halo_density_filename = os.path.join(epsen_dens_fitting_dir,
                'stellar_halo/',stellar_halo_density_version,str(z0_sid),
                'merger_'+str(j+1)+'/', 'sampler.pkl')
            denspot = pfit.construct_pot_from_fit(
                stellar_halo_density_filename, stellar_halo_densfunc,
                stellar_halo_density_ncut, ro=ro, vo=vo)

            # Get the Osipkov-Merritt anisotropy radius r_a: the median of the
            # first fitted parameter after discarding the first beta_ncut chain
            # steps. It is passed to the DF in kpc below.
            # NOTE(review): this assumes the fit's first parameter is r_a in kpc.
            om_dir = os.path.join(epsen_dens_fitting_dir,'stellar_halo/',beta_version,
                str(z0_sid),'merger_'+str(j+1)+'/')
            om_filename = os.path.join(om_dir,'sampler.pkl')
            assert os.path.exists(om_filename)
            with open(om_filename,'rb') as handle:
                om_sampler = pickle.load(handle)
            ra_samples = om_sampler.get_chain(discard=beta_ncut, flat=True)
            ra = np.median(ra_samples,axis=0)[0]
            # r_a must be strictly positive: r_a = 0 is as unphysical as r_a < 0
            # (the OM model divides by r_a**2).
            if ra <= 0:
                print('    ra <= 0, skipping...')
                continue

            # Construct the distribution function and do some dummy sampling
            # to set the interpolators. Then save.
            try:
                print(f'  r_a: {round(ra,3)}')
                print('  Building DF')
                dfom = df.osipkovmerrittdf(pot=interpot, denspot=denspot,
                    ra=ra*apu.kpc, ro=ro, vo=vo, rmax=rs.max()*apu.kpc*1.1)
                print('  Sampling DF')
                _ = dfom.sample(n=100, rmin=rs.min()*apu.kpc*0.9)
            except Exception as e:
                # Actually skip. Falling through would pickle a dfom that is
                # either undefined (first failure) or the DF of a previous
                # merger, which would then be saved under this merger's name.
                print('Caught an error:',e,'skipping...')
                continue

            # Make sure the DF survives a pickle round trip, then write it to
            # df_filename (built above).
            if test_pickling:
                try:
                    pickle.loads(pickle.dumps(dfom))
                except RecursionError:
                    print('Caught recursion error when pickle/unpickling, quiting...')
                    sys.exit()
            with open(df_filename,'wb') as handle:
                pickle.dump(dfom,handle)

### Also pickle the $\log f(Q)$ interpolator separately, because it is the expensive part of the DF to create

In [None]:
# Re-open each pickled Osipkov-Merritt DF and write its f(Q) interpolator to a
# separate pickle (fQ_interp.pkl) in the same directory.
# NOTE(review): the attribute saved is `_logfQ_interp`. Judging by its name, it
# interpolates log f(Q) rather than f(Q), despite the file name. Consumers of
# fQ_interp.pkl should confirm which quantity they are reading.
for i in range(n_mw):
    if verbose:
        print(f'Analyzing MW {i+1}/{n_mw}')
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers

    for j in range(n_major):
        if verbose:
            print(f'  Saving f(Q) interpolator for major merger {j+1}/{n_major}')

        this_fitting_dir = os.path.join(
            epsen_df_fitting_dir, 'osipkov_merritt', str(z0_sid), f'merger_{j+1}')
        df_filename = os.path.join(this_fitting_dir, 'df.pkl')
        have_df = os.path.exists(df_filename)
        if not have_df:
            print('Pickled DF not found, skipping')
            continue

        with open(df_filename, 'rb') as df_file:
            dfom = pickle.load(df_file)
        fQ_interp = dfom._logfQ_interp
        fQ_interp_filename = os.path.join(this_fitting_dir, 'fQ_interp.pkl')
        with open(fQ_interp_filename, 'wb') as out_file:
            pickle.dump(fQ_interp, out_file)