In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 2_fit_om_2_combination.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Fit Osipkov-Merritt linear combination models to data.

First construct a grid of Osipkov-Merritt DFs over a grid of anisotropy
radii r_a for each major merger's stellar halo. Then compute a grid of
radial and tangential velocity dispersions over (r_a, r) from those DFs,
fit a two-component combination (r_a1, r_a2, k) to the N-body anisotropy
profile using that grid, and finally construct DFs for the best-fitting
pair of anisotropy radii.
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle
import pdb, copy, glob, time, warnings, logging, multiprocessing

## Plotting
import matplotlib as mpl
from matplotlib import pyplot as plt
import corner

## Astropy
from astropy import units as apu
from astropy import constants as apc

## Analysis
import scipy.integrate
import scipy.interpolate
import scipy.optimize
import emcee

## galpy
from galpy import potential
from galpy import df

## Project-specific
# Walk up from the working directory until a `src/` directory is found.
# Give up (FileNotFoundError) once the repository root `tng-dfs` or the
# filesystem root has been searched without success.
src_path = 'src/'
while not os.path.exists(src_path):
    # Directory that was just searched for `src/`. realpath() strips the
    # trailing slash, so the basename of src_path itself is always 'src' and
    # the stop condition must be checked on its parent instead.
    search_dir = os.path.dirname(os.path.realpath(src_path))
    at_fs_root = os.path.dirname(search_dir) == search_dir
    if os.path.basename(search_dir) == 'tng-dfs' or at_fs_root:
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import densprofile as pdens
from tng_dfs import fitting as pfit
from tng_dfs import io as pio
from tng_dfs import kinematics as pkin
from tng_dfs import util as putil

### Notebook setup

%matplotlib inline
# Apply the project matplotlib style file (<src_path>/mpl/project.mplstyle).
# NOTE(review): the original note says this call must sit exactly here,
# presumably relative to the `%matplotlib inline` magic above — confirm
# before moving it.
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: read the project configuration and unpack the values used here
cdict = putil.load_config_to_dict()
keywords = [
    'DATA_DIR', 'MW_ANALOG_DIR', 'FIG_DIR_BASE', 'FITTING_DIR_BASE',
    'RO', 'VO', 'ZO', 'LITTLE_H', 'MW_MASS_RANGE',
]
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
    ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW analog sample, plus the variables used to select it
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure paths: a local directory and this notebook's directory in the
# project figure tree; both are created if missing
local_fig_dir = './fig/'
fig_dir = os.path.join(fig_dir_base,
    'notebooks/4_fit_distribution_functions/2_fit_om_2_combination/')
os.makedirs(local_fig_dir, exist_ok=True)
os.makedirs(fig_dir, exist_ok=True)
show_plots = False

# Load tree data: the primaries and their major mergers
def _load_pickled(filename):
    '''Unpickle and return the object stored in filename.'''
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
tree_primaries = _load_pickled(tree_primary_filename)
tree_major_mergers = _load_pickled(tree_major_mergers_filename)
n_mw = len(tree_primaries)

### Create the grid of DF data for fitting

In [None]:
### Some keywords and properties
force_df_grid = False
test_pickling = True
verbose = True
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile/')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function/')

# Potential interpolator version
interpot_version = 'all_star_dm_enclosed_mass'

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower_softening'
stellar_halo_density_ncut = 500
stellar_halo_densfunc = pdens.TwoPowerSpherical()

# Stellar halo rotation information
stellar_halo_rotation_version = 'tanh_rotation'
stellar_halo_rotation_ncut = 500

# Anisotropy information. The version string labels the r_a grid built below
# (20 points, 0.01 to 1000 kpc).
# NOTE(review): the dispersion-grid and fitting cells below read df_grid.pkl
# from 'ra_N10_01_to_300_softening', so a fresh Restart & Run All does not
# feed this grid to them. Keep the versions in sync.
df_type = 'osipkov_merritt_2_combination'
anisotropy_fit_version = 'ra_N20_001_to_1000_softening'

# Begin logging. The log directory must exist: the FileHandler created by
# basicConfig raises FileNotFoundError otherwise.
log_filename = './log/2_fit_om_2_combination_df_grid.log'
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
if os.path.exists(log_filename):
    os.remove(log_filename)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning Osipkov-Merritt 2 combination DF grid creation. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

# Construct the r_a grid (log-spaced, in kpc)
log_ra_min = -2
log_ra_max = 3
ra_n = 20
ra = (10**np.linspace(log_ra_min, log_ra_max, ra_n, endpoint=True))*apu.kpc

# Ignore some standard warnings while the DFs are built. catch_warnings()
# restores the previous filters when the block exits (also on error), whereas
# warnings.resetwarnings() would additionally discard the default filters.
with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore', 
        message='No particle IDs found', category=UserWarning)
    warnings.filterwarnings(action='ignore', 
        message='maxiter', category=scipy.integrate.AccuracyWarning)
    warnings.filterwarnings(action='ignore', 
        message='invalid value encountered', category=RuntimeWarning)
    warnings.filterwarnings(action='ignore',
        message='subdivisions', category=scipy.integrate.IntegrationWarning)
    warnings.filterwarnings(action='ignore',
        message='divergent', category=scipy.integrate.IntegrationWarning)

    for i in range(n_mw):
        # Only the first MW analog is processed in this cell
        if i > 0: continue

        # Get the primary
        primary = tree_primaries[i]
        z0_sid = primary.subfind_id[0]
        major_mergers = primary.tree_major_mergers
        n_major = primary.n_major_mergers
        n_snap = len(primary.snapnum)
        primary_filename = primary.get_cutout_filename(mw_analog_dir,
            snapnum=primary.snapnum[0])
        co = pcutout.TNGCutout(primary_filename)
        co.center_and_rectify()
        pid = co.get_property('stars','ParticleIDs')

        # Load the interpolator for the sphericalized potential
        interpolator_filename = os.path.join(dens_fitting_dir,
            'spherical_interpolated_potential/',interpot_version,
            str(z0_sid),'interp_potential.pkl')
        with open(interpolator_filename,'rb') as handle:
            interpot = pickle.load(handle)

        for j in range(n_major):
            if verbose:
                msg = f'Analyzing major merger {j+1}/{n_major} for MW {i+1}/{n_mw}'
                logging.info(msg)
                print(msg)

            # Filename and pathing check
            this_fitting_dir = os.path.join(df_fitting_dir, df_type, 
                anisotropy_fit_version, str(z0_sid),'merger_'+str(j+1))
            os.makedirs(this_fitting_dir,exist_ok=True)
            df_grid_filename = os.path.join(this_fitting_dir,'df_grid.pkl')
            if os.path.exists(df_grid_filename) and not force_df_grid:
                if verbose:
                    msg = 'Already have DF grid, continuing'
                    logging.info(msg)
                    print(msg)
                continue

            # Get the major merger and its star particles
            major_merger = primary.tree_major_mergers[j]
            major_acc_sid = major_merger.subfind_id[0]
            major_mlpid = major_merger.secondary_mlpid
            upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
            indx = np.where(np.isin(pid,upid))[0]
            orbs = co.get_orbs('stars')[indx]
            rs = orbs.r().to(apu.kpc).value

            # Get the stellar halo density profile (denspot for the DF)
            stellar_halo_density_filename = os.path.join(dens_fitting_dir,
                'stellar_halo/',stellar_halo_density_version,str(z0_sid),
                'merger_'+str(j+1)+'/', 'sampler.pkl')
            denspot = pfit.construct_pot_from_fit(
                stellar_halo_density_filename, stellar_halo_densfunc, 
                stellar_halo_density_ncut, ro=ro, vo=vo)

            # Build one DF per r_a. An r_a whose DF fails is skipped entirely
            # (previously a stale or undefined DF was still appended), and
            # ra_ok records the successes so the saved r_a array stays
            # aligned with df_grid; the next cell indexes both with k.
            df_grid = []
            fQ_interp_grid = []
            ra_ok = np.zeros(ra_n, dtype=bool)
            for k in range(ra_n):
                # Construct the distribution function and do some dummy
                # sampling to set the interpolators.
                try:
                    if verbose:
                        msg = f'ra: {round(ra[k].value,3)} kpc, building DF'
                        logging.info(msg)
                        print(msg)
                    dfom = df.osipkovmerrittdf(pot=interpot, denspot=denspot, 
                        ra=ra[k], ro=ro, vo=vo, rmax=rs.max()*apu.kpc*1.1)
                    if verbose:
                        print('  Sampling DF')
                    _ = dfom.sample(n=100, rmin=rs.min()*apu.kpc*0.9)
                except Exception as e:
                    msg = f'Failed to build DF, skipping. Error: {e}'
                    logging.info(msg)
                    print(msg)
                    continue

                # A DF that cannot survive a pickle round trip would make the
                # grid dump below fail, so leave it out of the grid.
                if test_pickling:
                    try:
                        pickle.loads(pickle.dumps(dfom))
                    except RecursionError:
                        msg = 'Caught recursion error when (un)pickling, skipping.'
                        logging.info(msg)
                        if verbose:
                            print(msg)
                        continue

                df_grid.append(dfom)
                fQ_interp_grid.append(dfom._logfQ_interp)
                ra_ok[k] = True

            if not np.any(ra_ok):
                msg = 'No DFs could be built for this merger, not saving a grid'
                logging.info(msg)
                print(msg)
                continue

            # Save the grid together with the r_a values it was built on
            with open(df_grid_filename,'wb') as handle:
                pickle.dump([df_grid,fQ_interp_grid,ra[ra_ok]],handle)

            if verbose:
                msg = 'Done with merger'
                logging.info(msg)
                print(msg)

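### Compute a grid of velocity dispersion values over the anisotropy radius $r_a$ and the radius $r$ (the radial grid, stored as `r_ra`, is in kpc)

Use the `fQ` trick: assign `dfom.fQ = lambda Q: np.exp(dfom._logfQ_interp(Q))` so that each reconstructed DF evaluates $f(Q)$ from its stored interpolator of $\log f(Q)$.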

In [None]:
def _plot_dispersion_panels(x, y, sigma_r, sigma_t):
    '''_plot_dispersion_panels:

    Shared plotting routine for the dispersion-grid figures. Draws three
    stacked log-log panels (sigma_r, sigma_T and the derived anisotropy
    beta = 1 - sigma_T^2 / (2 sigma_r^2)) with radius on the x-axis and r_a
    on the y-axis.

    Parameters
    ----------
    x : numpy.ndarray
        1D radii [kpc], one per column of sigma_r / sigma_t.
    y : numpy.ndarray
        1D anisotropy radii r_a [kpc], one per row of sigma_r / sigma_t.
    sigma_r, sigma_t : numpy.ndarray
        Radial and tangential dispersions, shape (len(y), len(x)).

    Returns
    -------
    fig, axs : matplotlib Figure and array of its three Axes.
    '''
    beta = 1-(sigma_t**2)/(2*sigma_r**2)
    panels = [
        (sigma_r, r'$\sigma_r$ [km/s]', 400),
        (sigma_t, r'$\sigma_T$ [km/s]', 400),
        (beta, r'$\beta$', 1),
    ]
    fig = plt.figure(figsize=(8,12))
    axs = fig.subplots(3,1)
    for ax, (values, title, vmax) in zip(axs, panels):
        mesh = ax.pcolormesh(x, y, values, cmap='magma', vmin=0, vmax=vmax)
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel(r'$r$ [kpc]')
        ax.set_ylabel(r'$r_a$ [kpc]')
        ax.set_title(title)
        fig.colorbar(mesh, ax=ax)
    fig.tight_layout()
    return fig,axs

def make_dispersion_grid_plot(sigma_r_grid, sigma_t_grid, r_ra, ra):
    '''make_dispersion_grid_plot:

    Make a plot of the dispersion grid: sigma_r, sigma_T and beta as a
    function of radius (x) and r_a (y).

    Parameters
    ----------
    sigma_r_grid, sigma_t_grid : numpy.ndarray
        Dispersion grids indexed (r_a, r), shape (len(ra), len(r_ra)).
    r_ra : astropy Quantity
        Radii of the grid columns [kpc].
    ra : astropy Quantity
        Anisotropy radii of the grid rows [kpc].

    Returns
    -------
    fig, axs : matplotlib Figure and array of its three Axes.
    '''
    return _plot_dispersion_panels(r_ra.value, ra.value, sigma_r_grid,
        sigma_t_grid)

def make_interpolated_dispersion_grid_plot(sigma_r_interp, sigma_t_interp,
    r_ra, ra, n_interp=100):
    '''make_interpolated_dispersion_grid_plot:
    
    Make a plot of the interpolated dispersion grid. The interpolators are
    evaluated on an n_interp x n_interp log-spaced grid spanning the range
    of ra and r_ra.

    Parameters
    ----------
    sigma_r_interp, sigma_t_interp : scipy.interpolate.RegularGridInterpolator
        Dispersion interpolators taking points as (r_a, r).
    r_ra : astropy Quantity
        Radii [kpc] whose range sets the x extent.
    ra : astropy Quantity
        Anisotropy radii [kpc] whose range sets the y extent.
    n_interp : int, optional
        Number of evaluation points along each axis. Default 100.

    Returns
    -------
    fig, axs : matplotlib Figure and array of its three Axes.
    '''
    ra_new = 10**np.linspace(np.log10(ra.min().value), 
        np.log10(ra.max().value), n_interp)
    r_new = 10**np.linspace(np.log10(r_ra.min().value), 
        np.log10(r_ra.max().value), n_interp)
    ra_mesh, r_mesh = np.meshgrid(ra_new, r_new, indexing='ij')
    # Values are indexed (r_a, r): rows follow ra_new (y), columns follow
    # r_new (x). Plotting them against the 1D coordinate arrays keeps the
    # axis tick values matched to the right quantity (the transposed values
    # used to be drawn on the (r_a, r) meshes, putting r_a ticks on the
    # radius axis).
    zz_r = sigma_r_interp((ra_mesh, r_mesh))
    zz_t = sigma_t_interp((ra_mesh, r_mesh))
    return _plot_dispersion_panels(r_new, ra_new, zz_r, zz_t)

In [None]:
### Some keywords and properties
force_r_ra_grid = True
verbose = True
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile/')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function/')

# Potential interpolator version
interpot_version = 'all_star_dm_enclosed_mass'

# Stellar halo density information
stellar_halo_density_version = 'poisson_twopower_softening'
stellar_halo_density_ncut = 500
stellar_halo_densfunc = pdens.TwoPowerSpherical()

# Stellar halo rotation information
stellar_halo_rotation_version = 'tanh_rotation'
stellar_halo_rotation_ncut = 500

# Anisotropy information
# NOTE(review): the DF grid cell above writes df_grid.pkl under
# 'ra_N20_001_to_1000_softening'; this cell reads it from the version below.
# Keep the two in sync.
df_type = 'osipkov_merritt_2_combination'
anisotropy_fit_version = 'ra_N10_01_to_300_softening'

# Number of radii (log-spaced from the stellar softening length to the
# outermost star) at which the dispersions are computed
r_ra_n = 20

# Begin logging. The log directory must exist: the FileHandler created by
# basicConfig raises FileNotFoundError otherwise.
log_filename = './log/2_fit_om_2_combination_vdisp_grid.log'
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
if os.path.exists(log_filename):
    os.remove(log_filename)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning Osipkov-Merritt 2 combination velocity dispersion'+\
             ' grid creation. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

for i in range(n_mw):
    # Only the first MW analog is processed in this cell
    if i > 0: continue

    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    major_mergers = primary.tree_major_mergers
    n_major = primary.n_major_mergers
    n_snap = len(primary.snapnum)
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()
    pid = co.get_property('stars','ParticleIDs')

    # Load the interpolator for the sphericalized potential
    interpolator_filename = os.path.join(dens_fitting_dir,
        'spherical_interpolated_potential/',interpot_version,
        str(z0_sid),'interp_potential.pkl')
    with open(interpolator_filename,'rb') as handle:
        interpot = pickle.load(handle)

    for j in range(n_major):
        if verbose:
            msg = f'Analyzing major merger {j+1}/{n_major} for MW {i+1}/{n_mw}'
            logging.info(msg)
            print(msg)

        # Filename and pathing check
        this_fitting_dir = os.path.join(df_fitting_dir, df_type, 
            anisotropy_fit_version, str(z0_sid),'merger_'+str(j+1))
        os.makedirs(this_fitting_dir,exist_ok=True)
        sigma_grid_filename = os.path.join(this_fitting_dir,'sigma_grid.pkl')
        if os.path.exists(sigma_grid_filename) and not force_r_ra_grid:
            if verbose:
                msg = 'Already have dispersion grids, continuing'
                logging.info(msg)
                print(msg)
            continue
        
        # Load the DF grid; skip mergers for which no grid was built
        df_grid_filename = os.path.join(this_fitting_dir,'df_grid.pkl')
        if not os.path.exists(df_grid_filename):
            msg = f'No DF grid found at {df_grid_filename}, skipping'
            logging.info(msg)
            print(msg)
            continue
        with open(df_grid_filename,'rb') as handle:
            df_grid,fQ_interp_grid,ra = pickle.load(handle)
        ra_n = len(ra)

        # Get the major merger and its star particles
        major_merger = primary.tree_major_mergers[j]
        major_acc_sid = major_merger.subfind_id[0]
        major_mlpid = major_merger.secondary_mlpid
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        indx = np.where(np.isin(pid,upid))[0]
        orbs = co.get_orbs('stars')[indx]
        rs = orbs.r().to(apu.kpc).value

        # Get the stellar halo density profile (denspot for the DF)
        stellar_halo_density_filename = os.path.join(dens_fitting_dir,
            'stellar_halo/',stellar_halo_density_version,str(z0_sid),
            'merger_'+str(j+1)+'/', 'sampler.pkl')
        denspot = pfit.construct_pot_from_fit(
            stellar_halo_density_filename, stellar_halo_densfunc, 
            stellar_halo_density_ncut, ro=ro, vo=vo)

        # Radii where the individual velocity dispersions are calculated
        log_r_ra_min = np.log10( putil.get_softening_length('stars') )
        log_r_ra_max = np.log10(rs.max())
        r_ra = (10**np.linspace(log_r_ra_min, log_r_ra_max, r_ra_n,
            endpoint=True))*apu.kpc

        # Dispersion grids indexed (r_a, r)
        sigma_r_grid = np.zeros((ra_n,r_ra_n))
        sigma_t_grid = np.zeros((ra_n,r_ra_n))
        if verbose:
            msg = 'Computing velocity dispersions'
            logging.info(msg)
            print(msg)
        for k in range(ra_n):
            dfom = pkin.reconstruct_anisotropic_df(df_grid[k],
                interpot, denspot)
            # The fQ -> interpolator trick. The interpolator is bound as a
            # default argument so the lambda does not look up the global
            # `dfom` at call time, which is reassigned on every iteration.
            dfom.fQ = lambda Q, _logfQ=dfom._logfQ_interp: np.exp(_logfQ(Q))

            # Dispersion^2 = (velocity moment x density) / density
            dens = np.array([
                dfom.vmomentdensity(r, 0, 0).value for r in r_ra])
            vr2_dens = np.array([
                dfom.vmomentdensity(r, 2, 0).value for r in r_ra])
            vt2_dens = np.array([
                dfom.vmomentdensity(r, 0, 2).value for r in r_ra])
            sigma_r_grid[k,:] = np.sqrt(vr2_dens/dens)
            sigma_t_grid[k,:] = np.sqrt(vt2_dens/dens)

        # Save the grid
        with open(sigma_grid_filename,'wb') as handle:
            pickle.dump([sigma_r_grid,sigma_t_grid,ra,r_ra],handle)

        # Make some plots
        this_fig_dir = os.path.join(fig_dir, str(z0_sid), 'merger_'+str(j+1))
        os.makedirs(this_fig_dir,exist_ok=True)
        
        # Plot the dispersion grid
        if verbose:
            msg = 'Making dispersion grid plot'
            logging.info(msg)
            print(msg)
        fig, axs = make_dispersion_grid_plot(sigma_r_grid, sigma_t_grid,
            r_ra, ra)
        fig.savefig(os.path.join(this_fig_dir,'sigma_r_t_grid.png'))
        plt.close(fig)

        # Make a plot of the interpolated dispersions
        if verbose:
            msg = 'Making interpolated dispersion grid plot'
            logging.info(msg)
            print(msg)
        
        sigma_r_interp = scipy.interpolate.RegularGridInterpolator(
            (ra.value,r_ra.value), sigma_r_grid, method='linear',
            bounds_error=False, fill_value=None)
        sigma_t_interp = scipy.interpolate.RegularGridInterpolator(
            (ra.value,r_ra.value), sigma_t_grid, method='linear',
            bounds_error=False, fill_value=None)
        
        fig, axs = make_interpolated_dispersion_grid_plot(sigma_r_interp,
            sigma_t_interp, r_ra, ra)
        fig.savefig(os.path.join(this_fig_dir,'sigma_r_t_beta_interp.png'))
        plt.close(fig)

        if verbose:
            msg = 'Done with merger'
            logging.info(msg)
            print(msg)

### Compute the best-fitting combination Osipkov-Merritt DF using dispersion grids

In [None]:
def mloglike_beta_dispersion_grid(*args, **kwargs):
    '''mloglike_beta_dispersion_grid:

    Negative of loglike_beta_dispersion_grid, for use with minimizers such as
    scipy.optimize.minimize. Takes exactly the same arguments.
    '''
    return -loglike_beta_dispersion_grid(*args, **kwargs)

def loglike_beta_dispersion_grid(params, rs, beta, sigma_r_interp, 
    sigma_t_interp, sigma=None, mass=None, usr_log_prior=None, 
    usr_log_prior_params=None, parts=False):
    '''
    Compute the loglikelihood for a given model using the dispersion grid

    The model is a combination of two Osipkov-Merritt DFs with anisotropy
    radii ra1 and ra2, weighted k and 1-k in the second velocity moments:
    beta = 1 - <v_T^2> / (2 <v_r^2>).

    Parameters
    ----------
    params : list
        The parameters of the model, in order: ra1, ra2, k
    rs : numpy.ndarray
        The radii where beta is defined for the N-body data
    beta : numpy.ndarray
        The anisotropy profile for the N-body data (one value per radius)
    sigma_r_interp : scipy.interpolate.RegularGridInterpolator
        The sigma_r grid, indexed as (ra, r)
    sigma_t_interp : scipy.interpolate.RegularGridInterpolator
        The sigma_t grid, indexed as (ra, r)
    sigma : numpy.ndarray, optional
        User supplied weights for the each radial bin. If not provided, the
        loglikelihood will be unweighted.
    mass : numpy.ndarray, optional
        The mass contained in each radial bin. Used (as 1 / mass fraction) to
        set sigma when sigma is not supplied.
    usr_log_prior : function, optional
        A function that returns the log of the prior for the model. If not 
        provided, no user prior is applied. Takes the parameters as the first
        argument, and any additional arguments can be passed via the
        usr_log_prior_params keyword.
    usr_log_prior_params : list, optional
        Additional positional arguments to pass to the usr_log_prior
        function. Default is no additional arguments.
    parts : bool, optional
        If True, return the individual parts of the loglikelihood calculation.
        Default is False.
    
    Returns
    -------
    loglike : float
        The loglikelihood of the model given the data. -np.inf (a scalar,
        even if parts is True) when params are outside the domain prior or
        the user prior is infinite. If parts is True, returns
        (loglike, logobj, logprior, usrlogprior).
    '''
    if usr_log_prior_params is None:
        usr_log_prior_params = []

    # Evaluate the domain prior
    ra_min, ra_max = sigma_r_interp.grid[0].min(), sigma_r_interp.grid[0].max()
    r_min, r_max = sigma_r_interp.grid[1].min(), sigma_r_interp.grid[1].max()
    if np.min(rs) < r_min or np.max(rs) > r_max:
        warnings.warn('Data outside of interpolation grid.')
    if not domain_prior_beta_dispersion_grid(params, ra_min, ra_max):
        return -np.inf

    # Prior on the parameters (currently flat)
    logprior = 0.

    # Evaluate the user supplied log prior
    if callable(usr_log_prior):
        usrlogprior = usr_log_prior(params, *usr_log_prior_params)
        if np.isinf(usrlogprior):
            return -np.inf
    else:
        usrlogprior = 0.

    # Compute the anisotropy for the model from the weighted second moments
    ra1, ra2, k = params
    svr1 = sigma_r_interp((ra1, rs))
    svt1 = sigma_t_interp((ra1, rs))
    svr2 = sigma_r_interp((ra2, rs))
    svt2 = sigma_t_interp((ra2, rs))
    svr = k*svr1**2 + (1-k)*svr2**2
    svt = k*svt1**2 + (1-k)*svt2**2
    model_beta = 1 - svt/svr/2

    # Sigma for the likelihood
    if sigma is not None:
        _sigma = sigma
    elif mass is not None:
        mass_frac = mass/np.sum(mass)
        _sigma = 1/mass_frac
    else:
        _sigma = np.ones_like(beta)
    
    # Compute the log objective
    logobj = -0.5*(beta-model_beta)**2/_sigma**2

    # Compute the log likelihood
    loglike = np.sum(logobj) + logprior + usrlogprior

    if parts:
        return loglike, logobj, logprior, usrlogprior
    else:
        return loglike

def domain_prior_beta_dispersion_grid(params, ra_min, ra_max):
    '''domain_prior_beta_dispersion_grid:

    Check whether params = (ra1, ra2, k) lie inside the fitting domain:
    both anisotropy radii within [ra_min, ra_max] and 0 <= k <= 1.

    Returns
    -------
    bool : True if inside the domain, False otherwise.
    '''
    ra1, ra2, k = params
    if ra1 < ra_min or ra1 > ra_max:
        return False
    if ra2 < ra_min or ra2 > ra_max:
        return False
    if k < 0 or k > 1:
        return False
    return True

def generate_mcmc_init(init, nwalkers, domain_prior, scale=0.1, rng=None,
    max_tries=100):
    '''generate_mcmc_init:

    Generate initial walker positions by scattering init with Gaussian noise,
    redrawing any walker that falls outside domain_prior.

    Parameters
    ----------
    init : array-like
        Central position (e.g. the optimizer's best-fit parameters).
    nwalkers : int
        Number of walkers.
    domain_prior : callable
        Returns True when a parameter vector is inside the allowed domain.
    scale : float, optional
        Standard deviation of the Gaussian scatter. Default 0.1.
    rng : numpy.random.Generator, optional
        Generator used for the scatter, for reproducible starts. If None
        (default) the global np.random state is used.
    max_tries : int, optional
        A walker is redrawn at most max_tries+1 times before giving up.
        Default 100.

    Returns
    -------
    mcmc_init : numpy.ndarray
        Initial positions, shape (nwalkers, len(init)).

    Raises
    ------
    RuntimeError
        If a walker cannot be placed inside the domain.
    '''
    init = np.asarray(init, dtype=float)
    if rng is None:
        draw = lambda: np.random.randn(len(init))
    else:
        draw = lambda: rng.standard_normal(len(init))
    mcmc_init = np.array([
            init+scale*draw() for i in range(nwalkers)
            ]).reshape(nwalkers, len(init))
    for i in range(nwalkers):
        counter = 0
        while not domain_prior(mcmc_init[i]):
            mcmc_init[i] = init+scale*draw()
            counter += 1
            if counter > max_tries:
                raise RuntimeError('Failed to generate initial conditions.')
    return mcmc_init
    

In [None]:
### Some keywords and properties
force_fit = True
verbose = True
dens_fitting_dir = os.path.join(fitting_dir_base,'density_profile/')
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function/')

# MCMC params
nwalkers = 50
nit = 1000
ncut = 500
nprocs = 10
n_bin = 500
n_bs = 100

# Anisotropy information
df_type = 'osipkov_merritt_2_combination'
anisotropy_fit_version = 'ra_N10_01_to_300_softening'

# Begin logging. The log directory must exist: the FileHandler created by
# basicConfig raises FileNotFoundError otherwise.
log_filename = './log/2_fit_om_2_combination_vdisp_fit.log'
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
if os.path.exists(log_filename):
    os.remove(log_filename)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning Osipkov-Merritt 2 combination velocity dispersion'+\
             ' grid anisotropy fitting. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

for i in range(n_mw):

    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    major_mergers = primary.tree_major_mergers
    n_major = primary.n_major_mergers
    n_snap = len(primary.snapnum)
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()
    pid = co.get_property('stars','ParticleIDs')

    for j in range(n_major):
        if verbose:
            msg = f'Analyzing major merger {j+1}/{n_major} for MW {i+1}/{n_mw}'
            logging.info(msg)
            print(msg)

        # Filename and pathing check
        this_fitting_dir = os.path.join(df_fitting_dir, df_type, 
            anisotropy_fit_version, str(z0_sid),'merger_'+str(j+1))
        os.makedirs(this_fitting_dir,exist_ok=True)
        sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
        if os.path.exists(sampler_filename) and not force_fit:
            if verbose:
                msg = 'Already have fit model, continuing'
                logging.info(msg)
                print(msg)
            continue
        
        # Load the dispersion grids. The grid cell above only builds them for
        # some analogs, so skip mergers without one instead of crashing.
        sigma_grid_filename = os.path.join(this_fitting_dir,'sigma_grid.pkl')
        if not os.path.exists(sigma_grid_filename):
            msg = f'No dispersion grid found at {sigma_grid_filename}, skipping'
            logging.info(msg)
            print(msg)
            continue
        with open(sigma_grid_filename,'rb') as handle:
            sigma_r_grid,sigma_t_grid,ra,r_ra = pickle.load(handle)

        # Convert the dispersion grids into interpolators indexed (r_a, r)
        sigma_r_interp = scipy.interpolate.RegularGridInterpolator(
            (ra.value,r_ra.value), sigma_r_grid, method='linear',
            bounds_error=False, fill_value=None)
        sigma_t_interp = scipy.interpolate.RegularGridInterpolator(
            (ra.value,r_ra.value), sigma_t_grid, method='linear',
            bounds_error=False, fill_value=None)

        # Get the major merger and its star particles
        major_merger = primary.tree_major_mergers[j]
        major_acc_sid = major_merger.subfind_id[0]
        major_mlpid = major_merger.secondary_mlpid
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        indx = np.where(np.isin(pid,upid))[0]
        orbs = co.get_orbs('stars')[indx]
        rs = orbs.r().to(apu.kpc).value

        # Bin the data to calculate beta
        _n_bin = round(n_bin) if (len(orbs) > 10*n_bin) else round(len(orbs)/10)
        r_softening = putil.get_softening_length('stars', z=0, physical=True)
        rmin = np.max([r_softening, np.min(rs)])
        rmax = np.max(rs)
        adaptive_binning_kwargs = {
            'n':_n_bin,
            'rmin':rmin,
            'rmax':rmax,
            'bin_mode':'exact numbers',
            'bin_equal_n':True,
            'end_mode':'ignore',
            'bin_cents_mode':'median',
        }
        bin_edges, bin_cents, bin_n = pkin.get_radius_binning(orbs, 
            **adaptive_binning_kwargs)
        
        # Compute the ingredients for the anisotropy, use dispersions.
        # beta holds one profile per bootstrap sample (percentiles over
        # axis 0); the fit uses the median profile mbeta, the same profile
        # that is plotted below.
        compute_betas_kwargs = {'use_dispersions':True,
                                'return_kinematics':True}
        beta, vr2, vp2, vz2 = pkin.compute_betas_bootstrap(orbs,bin_edges,
            n_bootstrap=n_bs, compute_betas_kwargs=compute_betas_kwargs)
        lbeta, mbeta, ubeta = np.percentile(beta, [16,50,84], axis=0)
        # NOTE(review): ubeta-lbeta is the full 16-84 percentile width
        # (~2 sigma for a Gaussian); confirm it is intended as the per-bin
        # uncertainty.
        sbeta = ubeta-lbeta
    
        # Do the optimization
        if verbose:
            msg = 'Optimizing to find MCMC start point'
            logging.info(msg)
            print(msg)
        init = [1., 10., 0.5]
        opt_fn = lambda params: mloglike_beta_dispersion_grid(params, bin_cents, 
            mbeta, sigma_r_interp, sigma_t_interp, sigma=sbeta, 
            usr_log_prior=None, usr_log_prior_params=[], parts=False)
        opt = scipy.optimize.minimize(opt_fn, init, method='Powell')
        
        # Do MCMC
        if verbose:
            msg = 'Doing MCMC'
            logging.info(msg)
            print(msg)
        # NOTE(review): llfunc is defined in the notebook namespace, so the
        # Pool below presumably relies on the 'fork' start method — confirm
        # on platforms that default to 'spawn'.
        def llfunc(params):
            return loglike_beta_dispersion_grid(params, bin_cents, mbeta, 
                sigma_r_interp, sigma_t_interp, sigma=sbeta, 
                usr_log_prior=None, usr_log_prior_params=[], parts=False)
        mcmc_init = generate_mcmc_init(opt.x, nwalkers, 
            lambda x: domain_prior_beta_dispersion_grid(x, 
                ra.value.min(), ra.value.max()))
        with multiprocessing.Pool(processes=nprocs) as pool:
            sampler = emcee.EnsembleSampler(nwalkers, len(mcmc_init[0]), 
                llfunc, args=[], pool=pool)
            sampler.run_mcmc(mcmc_init, nit, progress=True)
        chain = sampler.get_chain(discard=ncut, flat=True, 
            thin=1)
        
        # Save the results
        opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
        with open(opt_filename,'wb') as handle:
            pickle.dump(opt,handle)
        with open(sampler_filename,'wb') as handle:
            pickle.dump(sampler,handle)
        chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
        with open(chain_filename,'wb') as handle:
            pickle.dump(chain,handle)

        # Make some plots
        this_fig_dir = os.path.join(fig_dir, str(z0_sid), 'merger_'+str(j+1))
        os.makedirs(this_fig_dir,exist_ok=True)
        
        # Corner plot
        if verbose:
            msg = 'Making corner plot'
            logging.info(msg)
            print(msg)
        
        figc = corner.corner(chain, labels=[r'$r_{a,1}$',r'$r_{a,2}$',r'$k$'],
            quantiles=[0.16,0.5,0.84], show_titles=True, truths=opt.x,
            truth_color='Red', title_kwargs={'fontsize':12})
        figc.tight_layout()
        figname = os.path.join(this_fig_dir,'corner_dispersion_grid_fit.png')
        figc.savefig(figname)
        plt.close(figc)
        
        # Make a mock beta plot: data median and 16-84 band, plus model
        # profiles for random posterior draws
        if verbose:
            msg = 'Making beta plot'
            logging.info(msg)
            print(msg)
        fig = plt.figure(figsize=(4,4))
        ax = fig.add_subplot(111)
        ax.plot(bin_cents, mbeta, color='Black', alpha=1.0)
        ax.fill_between(bin_cents, lbeta, ubeta, color='Black', alpha=0.25)
        n_plot = 50
        for _ in range(n_plot):
            ra1, ra2, kom = chain[np.random.randint(len(chain))]
            svr1 = sigma_r_interp((ra1, bin_cents))
            svt1 = sigma_t_interp((ra1, bin_cents))
            svr2 = sigma_r_interp((ra2, bin_cents))
            svt2 = sigma_t_interp((ra2, bin_cents))
            svr = kom*svr1**2 + (1-kom)*svr2**2
            svt = kom*svt1**2 + (1-kom)*svt2**2
            model_beta = 1 - svt/svr/2
            ax.plot(bin_cents, model_beta, color='Red', alpha=0.1)
        ax.set_xscale('log')
        ax.set_xlabel(r'$r$ [kpc]')
        ax.set_ylabel(r'$\beta$')
        fig.tight_layout()
        figname = os.path.join(this_fig_dir,'beta_dispersion_grid_fit.png')
        fig.savefig(figname)
        plt.close(fig)

        if verbose:
            msg = 'Done with merger'
            logging.info(msg)
            print(msg)

### Compute DFs for the best-fitting pair of Osipkov-Merritt scale radii

In [None]:
### Some keywords and properties
# Run-control switches
verbose = True
force_df_creation = True # Rebuild df.pkl even if one already exists
test_pickling = True # Round-trip each DF through (dill) pickle before saving

# Base directories holding the density-profile and DF fits
dens_fitting_dir, df_fitting_dir = (
    os.path.join(fitting_dir_base, subdir)
    for subdir in ('density_profile/', 'distribution_function/')
)

# Version of the sphericalized potential interpolator to load
interpot_version = 'all_star_dm_enclosed_mass'

# Stellar halo density profile fit, used to build the DF's denspot
stellar_halo_densfunc = pdens.TwoPowerSpherical()
stellar_halo_density_version = 'poisson_twopower_softening'
stellar_halo_density_ncut = 500

# Stellar halo rotation fit
stellar_halo_rotation_version, stellar_halo_rotation_ncut = \
    'tanh_rotation', 500

# DF type and anisotropy fit version; together they name the output directory
df_type = 'osipkov_merritt_2_combination'
anisotropy_fit_version = 'ra_N10_01_to_300_softening'

# Ignore some standard warnings that would otherwise flood the output while
# the DFs are built and sampled. Filters are applied in the listed order.
_ignored_warnings = [
    ('No particle IDs found', UserWarning),
    ('maxiter', scipy.integrate.AccuracyWarning),
    ('invalid value encountered', RuntimeWarning),
    ('subdivisions', scipy.integrate.IntegrationWarning),
    ('divergent', scipy.integrate.IntegrationWarning),
]
for _message, _category in _ignored_warnings:
    warnings.filterwarnings(action='ignore', message=_message,
        category=_category)
del _message, _category, _ignored_warnings

# Begin logging. filemode='w' truncates any log left over from a previous run.
# basicConfig opens the file immediately and will not create missing
# directories, so make sure the log directory exists first.
log_filename = './log/2_fit_om_2_combination_best_fit_dfs.log'
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning Osipkov-Merritt 2 combination best-fit DF creation. '+\
             'Time: '+time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

# For every major merger of every MW analog, build one Osipkov-Merritt DF per
# best-fitting anisotropy radius (ra1, ra2) and save the pair together with the
# best-fitting combination weight k to df.pkl.
for i in range(n_mw):

    # Get the primary and its z=0 cutout (centered and rectified)
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    major_mergers = primary.tree_major_mergers
    n_major = primary.n_major_mergers
    n_snap = len(primary.snapnum)
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()
    pid = co.get_property('stars','ParticleIDs')

    # Load the interpolator for the sphericalized potential (shared by all of
    # this primary's mergers)
    interpolator_filename = os.path.join(dens_fitting_dir,
        'spherical_interpolated_potential/',interpot_version,
        str(z0_sid),'interp_potential.pkl')
    with open(interpolator_filename,'rb') as handle:
        interpot = pickle.load(handle)

    for j in range(n_major):
        if verbose:
            msg = f'Analyzing major merger {j+1}/{n_major} for MW {i+1}/{n_mw}'
            logging.info(msg)
            print(msg)

        # Filename and pathing check
        this_fitting_dir = os.path.join(df_fitting_dir, df_type, 
            anisotropy_fit_version, str(z0_sid),'merger_'+str(j+1))
        os.makedirs(this_fitting_dir,exist_ok=True)
        df_filename = os.path.join(this_fitting_dir,'df.pkl')
        if os.path.exists(df_filename) and not force_df_creation:
            if verbose:
                msg = 'Already have DF best fits, continuing'
                logging.info(msg)
                print(msg)
            continue
            
        # Get the major merger and the orbits of the primary's star particles
        # whose IDs match the merger's unique particle IDs
        major_merger = primary.tree_major_mergers[j]
        major_acc_sid = major_merger.subfind_id[0]
        major_mlpid = major_merger.secondary_mlpid
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        indx = np.where(np.isin(pid,upid))[0]
        orbs = co.get_orbs('stars')[indx]
        rs = orbs.r().to(apu.kpc).value
        # Never sample inside the stellar softening length
        r_softening = putil.get_softening_length('stars', z=0, physical=True)
        rmin = np.max([r_softening, np.min(rs)])

        # Get the stellar halo density profile (denspot for the DF)
        stellar_halo_density_filename = os.path.join(dens_fitting_dir,
            'stellar_halo/',stellar_halo_density_version,str(z0_sid),
            'merger_'+str(j+1)+'/', 'sampler.pkl')
        denspot = pfit.construct_pot_from_fit(
            stellar_halo_density_filename, stellar_halo_densfunc, 
            stellar_halo_density_ncut, ro=ro, vo=vo)

        # Load the best-fits: per-parameter medians of the saved chain
        chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
        with open(chain_filename,'rb') as handle:
            chain = pickle.load(handle)
        ra1, ra2, kom = np.median(chain,axis=0)

        ras = [ra1, ra2]
        dfs = []
        
        # Loop over the pair of best-fitting ra and build the DFs. Any failure
        # aborts this merger: without the full pair the combination DF is
        # unusable, and continuing would reuse a stale `dfom` from a previous
        # iteration (or crash on an undefined one).
        for k in range(len(ras)):
            # Construct the distribution function and do some dummy sampling
            # to set the interpolators.
            try:
                if verbose:
                    msg = f'ra: {round(ras[k],3)} kpc, building DF'
                    logging.info(msg)
                    print(msg)
                dfom = df.osipkovmerrittdf(pot=interpot, denspot=denspot, 
                    ra=ras[k]*apu.kpc, ro=ro, vo=vo, rmax=rs.max()*apu.kpc*1.1)
                if verbose:
                    msg = '  Sampling DF'
                    logging.info(msg)
                    print(msg)
                _ = dfom.sample(n=100, rmin=rmin*apu.kpc)
            except Exception as e:
                msg = f'Failed to build DF, skipping. Error: {e}'
                logging.info(msg)
                print(msg)
                break

            # A DF that cannot be (un)pickled cannot be saved below either
            if test_pickling:
                try:
                    pickle.loads(pickle.dumps(dfom))
                except RecursionError:
                    msg = 'Caught recursion error when (un)pickling, skipping.'
                    logging.info(msg)
                    if verbose:
                        print(msg)
                    break
            dfs.append(dfom)

        # Only save a complete pair of DFs
        if len(dfs) != len(ras):
            logging.info('Incomplete DF pair, not saving '+df_filename)
            continue

        # Save the pair of DFs with their ra and the combination weight
        with open(df_filename,'wb') as handle:
            pickle.dump([dfs, ras, kom, 'array of dfs, array of [ra1,ra2], k'],
                handle)

        if verbose:
            msg = 'Done with merger'
            logging.info(msg)
            print(msg)

warnings.resetwarnings()