In [15]:
# ------------------------------------------------------------------------
#
# TITLE - 2-fit_stellar_halo_density.ipynb
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Fit density profiles to remnants in stellar halos at z=0
'''

__author__ = "James Lane"

In [16]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, time, dill as pickle
from tqdm.notebook import tqdm

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Fitting
import emcee
import corner
import multiprocessing
import scipy.optimize

## galpy
from galpy import orbit
from galpy import potential

## Project-specific
sys.path.insert(0,'../../src/')
from tng_dfs import fitting as pfit
from tng_dfs import densprofile as pdens
from tng_dfs import util as putil
from tng_dfs import cutout as pcutout

### Scale parameters
# Scale values (presumably kpc, km/s, kpc in galpy convention). Note the
# config cell below re-reads RO, VO and ZO and overwrites all three names.
ro, vo, zo = 8., 220, 0.0208 # zo from Bennett+ 2019

### Notebook setup

%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
# %load ../../src/nb_modules/nb_setup.txt
# Project-wide keywords read from the config file. This overwrites the
# ro, vo, zo defaults set in the imports cell.
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR','MW_ANALOG_DIR','RO','VO','ZO','LITTLE_H',
            'MW_MASS_RANGE']
(data_dir, mw_analog_dir, ro, vo, zo, h,
    mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# Milky Way analog subhalos, with bulge / disk fraction cuts applied
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
    )

# Output directories: local figures, plus figures and fitting results on
# epsen. NOTE: fig_dir has no trailing separator (the epsen paths do), so
# build paths under it with os.path.join rather than string concatenation.
fig_dir = './fig/stellar_halo'
epsen_fig_dir = ('/epsen_data/scr/lane/projects/tng-dfs/figs/notebooks/'
                 'fit_density_profiles/stellar_halo/')
epsen_fitting_dir = ('/epsen_data/scr/lane/projects/tng-dfs/fitting/'
                     'density_profile/stellar_halo/')
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(epsen_fig_dir, exist_ok=True)
os.makedirs(epsen_fitting_dir, exist_ok=True)
show_plots = False

# Merger-tree pickles (loaded with dill, imported as `pickle`): the
# primaries and their major mergers
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
with open(tree_primary_filename, 'rb') as handle:
    tree_primaries = pickle.load(handle)
with open(tree_major_mergers_filename, 'rb') as handle:
    tree_major_mergers = pickle.load(handle)
n_mw = len(tree_primaries)

46 Milky way like galaxies found
Loading subhalo data from /epsen_data/scr/lane/projects/tng-dfs/data/mw_analogs/subs/mwsubs.pkl
File has 46 subhalos
Cutting on bulge and disk fractions
Cut to  30  subhalos


### Parameters, preparation

In [18]:
def _spherical_binned_density(r, masses, rmin, rmax, nbin=50):
    '''Mass density of particles in logarithmically spaced spherical shells.

    Args:
        r (np.ndarray): spherical radii of the particles
        masses (np.ndarray): particle masses, same length as r
        rmin, rmax (float): inner / outer edge of the binning range (> 0)
        nbin (int): number of log-spaced shells

    Returns:
        rcents (np.ndarray): log-centred radius of each shell, length nbin
        dens (np.ndarray): shell mass / shell volume, length nbin (zero for
            empty shells)
    '''
    log_redges = np.linspace(np.log10(rmin), np.log10(rmax), nbin+1)
    redges = 10**log_redges
    rcents = 10**(0.5*(log_redges[1:] + log_redges[:-1]))
    # np.histogram closes the last bin on the right, so particles sitting on
    # the outer edge are counted (a half-open [lo, hi) loop drops them)
    shell_mass, _ = np.histogram(r, bins=redges, weights=masses)
    shell_vol = 4.*np.pi/3.*(redges[1:]**3 - redges[:-1]**3)
    return rcents, shell_mass/shell_vol

def plot_fit_density(r, masses, chain, rmin, rmax, nbin=50, nrand=100,
    densfunc=None):
    '''Plot the binned particle density against posterior density profiles.

    Args:
        r (np.ndarray): spherical radii of the particles [kpc]
        masses (np.ndarray): particle masses [Msun], same length as r
        chain (np.ndarray): flattened posterior samples, shape
            (nsample, nparam), in the parameterization densfunc accepts
            (i.e. the raw chain, not one with a log10 amplitude column)
        rmin, rmax (float): radial range of the binned profile (> 0)
        nbin (int): number of log-spaced radial shells [default 50]
        nrand (int): number of posterior samples to overplot [default 100];
            capped at len(chain)
        densfunc (callable): density profile, called here as
            densfunc(rcents, 0., 0., params) -- presumably (R, phi, z,
            params) evaluated in the midplane, where R equals r. Defaults
            to the notebook-level ``densfunc`` (the original behaviour).

    Returns:
        fig (matplotlib.figure.Figure): the figure
        ax (matplotlib.axes.Axes): the single axis
    '''
    if densfunc is None:
        # Backwards-compatible fallback to the notebook-global profile
        try:
            densfunc = globals()['densfunc']
        except KeyError:
            raise NameError("plot_fit_density needs a densfunc: pass one "
                "explicitly or define a notebook-level 'densfunc'") from None

    rcents, dens = _spherical_binned_density(r, masses, rmin, rmax,
        nbin=nbin)
    log_rcents = np.log10(rcents)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Empty shells give log10(0) = -inf, which matplotlib leaves as a gap;
    # silence the accompanying divide-by-zero warning
    with np.errstate(divide='ignore'):
        log_dens = np.log10(dens)
    ax.plot(log_rcents, log_dens, color='Grey', linewidth=4., zorder=2,
        label='Data')

    # Overplot a random subset of posterior samples. Sampling without
    # replacement raises if asked for more draws than samples, so cap it.
    ndraw = min(nrand, len(chain))
    inds = np.random.choice(len(chain), size=ndraw, replace=False)
    # Matplotlib requires 0 <= alpha <= 1
    sample_alpha = min(1., 5./ndraw) if ndraw > 0 else 1.
    for k, ind in enumerate(inds):
        fdens = densfunc(rcents, 0., 0., chain[ind,:])
        ax.plot(log_rcents, np.log10(fdens), color='Red', linewidth=1.,
            zorder=3, linestyle='solid', alpha=sample_alpha,
            label='Posterior samples' if k == 0 else None)

    ax.set_xlabel(r'$\log_{10}(r$ / kpc$)$')
    ax.set_ylabel(r'$\log_{10}(\rho$ / M$_\odot$ kpc$^{-3})$')
    ax.legend()
    return fig, ax

In [90]:
def usr_log_prior_twopower(densfunc, params, rmax):
    '''User log-prior for the TwoPowerSpherical fit: reject any scale radius
    larger than the maximum radius.

    Args:
        densfunc (pdens.TwoPowerSpherical): the profile being fit (asserted)
        params (array-like): profile parameters, unpacked by
            densfunc._parse_params into (alpha, beta, a, amp)
        rmax (float): largest allowed scale radius a

    Returns:
        log_prior (float): -np.inf when a > rmax, otherwise 0.
    '''
    assert isinstance(densfunc, pdens.TwoPowerSpherical)
    # Only the scale radius enters this prior
    _, _, scale_radius, _ = densfunc._parse_params(params)
    return -np.inf if scale_radius > rmax else 0.

# Density profile model to fit
densfunc = pdens.TwoPowerSpherical()

# Output switches
verbose = True
show_plots = False

# emcee settings: number of walkers, steps per walker, burn-in steps
# discarded per walker (get_chain discard=), and pool worker processes
nwalkers, nit, ncut, nprocs = 50, 2000, 500, 10

In [None]:
# Fit a TwoPowerSpherical density profile to the z=0 star particles of each
# major merger of each MW analog: Nelder-Mead for a starting point, then
# emcee. Per merger, opt.pkl, chain.pkl and sampler.pkl are written to the
# fitting dir; an existing sampler.pkl marks the fit as done and it is
# skipped unless force_fit is True.
force_fit = False
version_prefix = 'poisson_twopower'
max_fit_npts = 100000 # Subsample mergers with more star particles than this
for i in range(n_mw):

    # Get the primary and its z=0 cutout
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    # Major mergers of this primary
    major_mergers = primary.tree_major_mergers
    n_major = primary.n_major_mergers

    for j in range(n_major):
        if verbose: print(f'Fitting merger {j+1}/{n_major}, of analog {i+1}/{n_mw}')

        # Fitting dir; skip if a sampler was already saved there
        this_fitting_dir = os.path.join(epsen_fitting_dir,version_prefix,
            str(z0_sid),'merger_'+str(j+1))
        os.makedirs(this_fitting_dir,exist_ok=True)
        sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
        if os.path.exists(sampler_filename) and not force_fit:
            print(f'Fitting already exists for {z0_sid}, merger {j+1}')
            continue

        # Star particles in the z=0 cutout that came from this merger
        major_merger = major_mergers[j]
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        pid = co.get_property('stars','ParticleIDs')
        indx = np.where(np.isin(pid,upid))[0]
        if len(indx) == 0:
            # Nothing to fit (np.min/np.max below would raise)
            print(f'No star particles found for {z0_sid}, merger {j+1}, '
                'skipping')
            continue
        orbs = co.get_orbs('stars')[indx]
        masses = co.get_masses('stars')[indx].value
        # Exact total debris mass, taken before any subsampling
        mtot = np.sum(masses)

        # Properties for fitting
        r, R, phi, z = orbs.r().value, orbs.R().value, orbs.phi().value, orbs.z().value
        rmin = np.min(r)
        rmax = np.max(r)
        effvol_params = [rmin, rmax]
        usr_log_prior_params = [rmax,]

        # Pare down the data for fitting; up-weight the kept particles so the
        # subsample's total mass approximates that of the full sample
        if len(masses) > max_fit_npts:
            inds = np.random.choice(len(masses), size=max_fit_npts,
                replace=False)
            fac = len(masses)/max_fit_npts
            r, R, phi, z = r[inds], R[inds], phi[inds], z[inds]
            masses = masses[inds]*fac

        # Remove data outside the fitting range (currently a no-op, since
        # rmin / rmax are the extremes of r)
        mask = (r >= rmin) & (r <= rmax)
        r_fit, R_fit, phi_fit, z_fit = r[mask], R[mask], phi[mask], z[mask]
        masses_fit = masses[mask]

        # Initial guess. The amplitude scales a unit-amplitude profile so its
        # mass out to r=1e8 (presumably the enclosed mass, i.e. effectively
        # the total) matches the fitted particle mass.
        _alpha = 1.
        _beta = 4.
        _a = 10.
        _amp = np.sum(masses_fit)/densfunc.mass(1e8, [_alpha, _beta, _a, 1.])
        init = np.array([_alpha, _beta, _a, _amp])

        # Do an optimization to get the initial conditions
        if verbose: print('Optimizing...')
        opt_fn = lambda params: pfit.mloglike_dens(params, densfunc,
            R_fit, phi_fit, z_fit, mass=masses_fit,
            usr_log_prior=usr_log_prior_twopower,
            usr_log_prior_params=usr_log_prior_params,
            effvol_params=effvol_params)
        opt = scipy.optimize.minimize(opt_fn, init, method='Nelder-Mead',
            options={'maxiter':2000,})
        if not opt.success:
            print(f'Warning: optimizer did not converge for {z0_sid}, '
                f'merger {j+1}: {opt.message}')

        # Log-likelihood for emcee. Kept at notebook (module) scope: the pool
        # pickles functions by reference, so a closure nested inside another
        # function would likely fail to pickle.
        def llfunc(params):
            return pfit.loglike_dens(params, densfunc, R_fit, phi_fit, z_fit,
                mass=masses_fit, effvol_params=effvol_params,
                usr_log_prior=usr_log_prior_twopower,
                usr_log_prior_params=usr_log_prior_params)
        # Start the walkers in a small Gaussian ball around the optimum
        mcmc_init = np.array([
            opt.x+0.1*np.random.randn(len(opt.x)) for _ in range(nwalkers)
            ])
        with multiprocessing.Pool(processes=nprocs) as pool:
            sampler = emcee.EnsembleSampler(nwalkers, len(mcmc_init[0]), llfunc,
                pool=pool)
            sampler.run_mcmc(mcmc_init, nit, progress=True)
        # Flattened posterior after discarding ncut burn-in steps per walker.
        # _chain keeps the raw amplitude; chain stores log10(amp) in column 3.
        _chain = sampler.get_chain(flat=True, discard=ncut)
        chain = copy.deepcopy(_chain)
        chain[:,3] = np.log10(chain[:,3])

        # Mass for every sample, from densfunc.effective_volume over
        # [rmin, rmax], appended to chain as a log10 column
        ms = np.zeros(len(chain))
        for k in range(len(chain)):
            ms[k] = densfunc.effective_volume(_chain[k,:], *effvol_params)
        chain = np.append(chain, np.log10(ms[:,None]), axis=1)
        chain_labels = ['alpha','beta','a','log10(amp)','log10(mass)']

        # Save the results before plotting, so a plotting failure cannot lose
        # the MCMC run. sampler.pkl goes last: its existence marks the fit
        # as done.
        opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
        with open(opt_filename,'wb') as handle:
            pickle.dump(opt,handle)
        chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
        with open(chain_filename,'wb') as handle:
            pickle.dump([chain,chain_labels],handle)
        with open(sampler_filename,'wb') as handle:
            pickle.dump(sampler,handle)

        # Quick-look corner plot in the local figure directory. fig_dir has
        # no trailing separator, so join rather than concatenate.
        fig = corner.corner(chain,
            labels=[r'$\alpha$',r'$\beta$','scale [kpc]','log Amplitude',
                    'log Mass [Msun]'],
            truths = [None,None,None,None,np.log10(mtot)],
            truth_color='Red')
        figname = os.path.join(fig_dir,
            str(z0_sid)+'_merger_'+str(j+1)+'_two_power_corner.png')
        print(figname)
        fig.savefig(figname, dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Plotting into the per-merger epsen figure directory
        this_fig_dir = os.path.join(epsen_fig_dir,str(z0_sid),'merger_'+str(j+1))
        os.makedirs(this_fig_dir, exist_ok=True)
        fig_prefix = os.path.join(this_fig_dir,str(z0_sid)+'_merger_'+str(j+1))

        # Density profile (takes the raw-amplitude chain)
        fig, axs = plot_fit_density(r, masses, _chain, rmin, rmax)
        figname = fig_prefix+'_two_power_fit_dens.png'
        fig.savefig(figname,dpi=300)
        plt.close(fig)

        # Corner plot
        fig = corner.corner(chain,
            labels=[r'$\alpha$', r'$\beta$', r'$a$ [kpc]',
                r'$\log_{10}(\mathrm{amp}/\mathrm{M}_\odot\,\mathrm{kpc}^{-3})$',
                r'$\log_{10}(M/\mathrm{M}_\odot)$'],
            truths=[None, None, None, None, np.log10(mtot)]
            )
        figname = fig_prefix+'_two_power_fit_corner.png'
        print(figname)
        fig.savefig(figname, dpi=300, bbox_inches='tight')
        plt.close(fig)