In [None]:
# ------------------------------------------------------------------------
#
# TITLE - fit_dm_halos.ipynb
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Fit the dark matter halos of simulated primaries at z=0
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, time, dill as pickle
from tqdm.notebook import tqdm

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Fitting
import emcee
import corner
import multiprocessing
import scipy.optimize

## galpy
from galpy import orbit
from galpy import potential

## Project-specific
sys.path.insert(0,'../../src/')
from tng_dfs import fitting as pfit
from tng_dfs import densprofile as pdens
from tng_dfs import util as putil
from tng_dfs import cutout as pcutout

### Scale parameters
# NOTE: ro, vo and zo are reassigned from the project config in the setup
# cell below, so these literals act only as placeholders.
ro, vo, zo = 8., 220, 0.0208 # zo: Bennett+ 2019

### Notebook setup

%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Read the project configuration and unpack the keywords this notebook needs
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR', 'MW_ANALOG_DIR', 'RO', 'VO', 'ZO', 'LITTLE_H',
    'MW_MASS_RANGE']
(data_dir, mw_analog_dir, ro, vo, zo, h,
    mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# Milky Way analog subhalos (and the variables used to select them)
mwsubs, mwsubs_vars = putil.prepare_mwsubs(mw_analog_dir, h=h,
    mw_mass_range=mw_mass_range, return_vars=True, force_mwsubs=False,
    bulge_disk_fraction_cuts=True)

# Output locations: local figures, remote figures, remote fitting results
fig_dir = './fig/dm_halos'
epsen_fig_dir = ('/epsen_data/scr/lane/projects/tng-dfs/figs/notebooks/'
    'fit_density_profiles/dm_halos/')
epsen_fitting_dir = ('/epsen_data/scr/lane/projects/tng-dfs/fitting/'
    'density_profile/dm_halo/')
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(epsen_fig_dir, exist_ok=True)
os.makedirs(epsen_fitting_dir, exist_ok=True)
show_plots = False

# Merger-tree data: the primaries and their major mergers
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
with open(tree_primary_filename, 'rb') as handle:
    tree_primaries = pickle.load(handle)

tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
with open(tree_major_mergers_filename, 'rb') as handle:
    tree_major_mergers = pickle.load(handle)

# Number of MW analogs to process
n_mw = len(tree_primaries)

### Helper functions and fit parameters

In [None]:
def plot_fit_density(r, masses, rvir, rmin_fit, rmax_fit, init, opt, chain,
    raw_chain=None, density_func=None, ndraw=100):
    '''Plot a halo's density and enclosed-mass profiles against NFW fits.

    The particle data are binned into 50 log-spaced spherical shells between
    1 and 500 kpc. The top panel shows log10 of the shell density. The bottom
    panel shows the enclosed mass normalized by the mass inside rvir. Both
    panels overlay the optimization result (dashed black) and up to `ndraw`
    random MCMC draws (red). Dashed vertical lines mark the fitting range.

    Args:
        r (np.ndarray): Spherical radii of the particles [kpc].
        masses (np.ndarray): Particle masses, same length as r.
        rvir (float): Radius used to normalize the enclosed-mass profile.
        rmin_fit (float): Inner edge of the fitting range (vertical line).
        rmax_fit (float): Outer edge of the fitting range (vertical line).
        init: Initial parameters. Unused; kept for interface compatibility.
        opt: Optimization result; `opt.x` is passed to the density function.
        chain (np.ndarray): Flattened MCMC chain in the format this notebook
            saves, i.e. columns (a, log10(amp), ...). Used only to derive the
            raw parameters when `raw_chain` is not given.
        raw_chain (np.ndarray, optional): Flattened chain of raw density
            parameters (a, amp). If None, it is built from `chain` by undoing
            the log10 on column 1 and dropping any extra columns.
        density_func (optional): Density profile object that is callable as
            density_func(r, 0, 0, params) and has a .mass(r, params) method.
            Defaults to the notebook-global `densfunc`.
        ndraw (int): Maximum number of MCMC draws to overplot.

    Returns:
        fig (matplotlib.figure.Figure), [ax1, ax2] (list of Axes)
    '''
    if density_func is None:
        # Fall back to the notebook-global profile, matching earlier behavior
        density_func = densfunc
    if raw_chain is None:
        # Undo the log10(amp) transform applied before the chain is saved
        raw_chain = np.array(chain[:, :2], dtype=float)
        raw_chain[:, 1] = 10**raw_chain[:, 1]

    # Binned data: shell densities and normalized enclosed mass
    nbin = 50
    log_redges = np.linspace(np.log10(1), np.log10(500), nbin+1)
    log_rcents = 0.5*(log_redges[1:] + log_redges[:-1])
    redges = 10**log_redges
    rcents = 10**log_rcents
    vols = 4*np.pi/3 * (redges[1:]**3 - redges[:-1]**3)
    dens = np.array([np.sum(masses[(r >= lo) & (r < hi)])
        for lo, hi in zip(redges[:-1], redges[1:])]) / vols
    menc = np.array([np.sum(masses[r < hi]) for hi in redges[1:]])
    menc = menc / np.sum(masses[r < rvir])
    with np.errstate(divide='ignore'):
        # Empty shells have zero density; plot them as -inf (i.e. not at all)
        log_dens = np.log10(dens)

    # Random subset of posterior draws; never request more than available,
    # since np.random.choice(replace=False) raises in that case
    nt = min(ndraw, len(raw_chain))
    indx = np.random.choice(len(raw_chain), size=nt, replace=False)
    draws = raw_chain[indx]

    log_rc = np.log10(rcents)
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8,6))

    # Top panel: density profile
    ax1.plot(log_rc, log_dens, color='Grey', linewidth=4., zorder=2,
        label='Data')
    ax1.plot(log_rc, np.log10(density_func(rcents, 0, 0, opt.x)),
        color='Black', linewidth=2., zorder=3, linestyle='dashed',
        label='Optimization')
    for params in draws:
        ax1.plot(log_rc, np.log10(density_func(rcents, 0, 0, params)),
            color='Red', alpha=0.1, linewidth=1., zorder=3,
            linestyle='solid')
    ax1.axvline(np.log10(rmin_fit), color='Black', linestyle='dashed')
    ax1.axvline(np.log10(rmax_fit), color='Black', linestyle='dashed')
    ax1.set_xlabel(r'$\log_{10}(r/\mathrm{kpc})$')
    ax1.set_ylabel(r'$\log_{10}(\rho/\mathrm{M}_\odot\,\mathrm{kpc}^{-3})$')
    ax1.legend()

    # Bottom panel: enclosed mass normalized at rvir
    ax2.plot(log_rc, menc, color='Grey', linewidth=4., zorder=2,
        label='Data')
    ax2.plot(log_rc,
        density_func.mass(rcents, opt.x)/density_func.mass(rvir, opt.x),
        color='Black', linewidth=2., zorder=3, linestyle='dashed',
        label='Optimization')
    for params in draws:
        ax2.plot(log_rc,
            density_func.mass(rcents, params)/density_func.mass(rvir, params),
            color='Red', alpha=0.1, linewidth=1., zorder=3,
            linestyle='solid')
    ax2.axvline(np.log10(rmin_fit), color='Black', linestyle='dashed')
    ax2.axvline(np.log10(rmax_fit), color='Black', linestyle='dashed')
    ax2.set_xlabel(r'$\log_{10}(r/\mathrm{kpc})$')
    ax2.set_ylabel(r'$M(r)/M(r_{vir})$')

    fig.tight_layout()

    return fig, [ax1, ax2]

In [None]:
def usr_log_prior_nfw(densfunc, params, rmax):
    '''Flat log-prior that bounds the NFW scale radius from above.

    Args:
        densfunc (pdens.NFWSpherical): Profile used to parse the parameters.
        params: Parameter vector accepted by densfunc._parse_params.
        rmax (float): Largest allowed scale radius.

    Returns:
        float: -inf when the scale radius exceeds rmax, otherwise 0. The
        amplitude is parsed but not constrained.
    '''
    assert isinstance(densfunc, pdens.NFWSpherical)
    scale_radius, _amp = densfunc._parse_params(params)
    return -np.inf if scale_radius > rmax else 0.

# Density profile model that every fit below uses
densfunc = pdens.NFWSpherical()

# Output / display flags
verbose = True
show_plots = False

# emcee settings: walkers, steps per walker, and burn-in steps discarded
nwalkers, nit, ncut = 50, 2000, 500

# Worker processes for the emcee multiprocessing pool
nprocs = 10

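### Fit an NFW profile to the particle data with a Poisson likelihood
Each halo's DM particles are randomly subsampled to $10^{5}$ points, and the particle masses are rescaled so that the total mass is conserved.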

In [None]:
force_fit = False

for i in range(n_mw):
    if verbose: print(f'Fitting DM halo of analog {i+1}/{n_mw}')

    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]

    # Skip analogs whose sampler has already been saved, unless forced
    this_fitting_dir = os.path.join(epsen_fitting_dir,'poisson_nfw',str(z0_sid))
    os.makedirs(this_fitting_dir,exist_ok=True)
    if os.path.exists(os.path.join(this_fitting_dir,'sampler.pkl')) and \
        not force_fit:
        print('Fitting already done, continuing...')
        continue

    # Load the z=0 cutout and its dark matter particles
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()
    orbs = co.get_orbs('dm')
    masses = co.get_masses('dm').value

    # Properties for fitting: spherical r, cylindrical R, phi, z
    r, R, phi, z = orbs.r().value, orbs.R().value, orbs.phi().value, orbs.z().value
    rvir = pdens.get_virial_radius(r, masses).value
    rmin_fit = 0.05*rvir
    rmax_fit = 0.5*rvir
    usr_log_prior_params = [rmax_fit,]
    effvol_params = [rmin_fit, rmax_fit]

    # Pare down the data for fitting. Never request more points than there
    # are particles: np.random.choice(replace=False) raises in that case.
    npts = min(int(1e5), len(masses))
    inds = np.random.choice(len(masses), size=npts, replace=False)
    fac = len(masses)/npts
    r = r[inds]
    R = R[inds]
    phi = phi[inds]
    z = z[inds]
    masses = masses[inds]
    masses *= fac # Increase the mass to compensate for the decrease in points

    # Remove data outside the fitting range
    mask = (r >= rmin_fit) & (r < rmax_fit)
    R_fit = R[mask]
    phi_fit = phi[mask]
    z_fit = z[mask]
    masses_fit = masses[mask]

    # Get the initial conditions
    init = pdens.get_NFW_init_params(densfunc, r, masses)

    # Optimize to get a starting point
    print('Optimizing...')
    opt_fn = lambda params: pfit.mloglike_dens(params, densfunc, R_fit, 
        phi_fit, z_fit, mass=masses_fit, usr_log_prior=usr_log_prior_nfw, 
        usr_log_prior_params=usr_log_prior_params, 
        effvol_params=effvol_params)
    opt = scipy.optimize.minimize(opt_fn, init, method='Nelder-Mead')

    # Fit the density profile with MCMC. llfunc is defined at notebook top
    # level so the multiprocessing pool can reach it by name.
    def llfunc(params):
        return pfit.loglike_dens(params, densfunc, R_fit, phi_fit, z_fit, 
              mass = masses_fit, usr_log_prior = usr_log_prior_nfw, 
              usr_log_prior_params = usr_log_prior_params,
              effvol_params = effvol_params)
    # Walkers start in a small Gaussian ball around the optimum
    mcmc_init = opt.x + 0.1*np.random.randn(nwalkers, len(opt.x))
    with multiprocessing.Pool(processes=nprocs) as pool:
        sampler = emcee.EnsembleSampler(nwalkers, mcmc_init.shape[1], llfunc,
            pool=pool)
        sampler.run_mcmc(mcmc_init, nit, progress=True)
    _chain = sampler.get_chain(flat=True, discard=ncut)
    chain = copy.deepcopy(_chain)
    chain[:,1] = np.log10(chain[:,1])

    # Compute virial masses for the whole chain; chain columns become
    # (a, log10(amp), log10(mvir))
    mvirs = np.array([densfunc.mass(rvir, params) for params in _chain])
    chain = np.append(chain, np.log10(mvirs[:,None]), axis=1)
    mvir = np.sum(masses[r < rvir])

    # Save the results. NOTE: the saved label of the last chain column is
    # 'mvir' although the column holds log10(mvir); kept for readers.
    opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
    with open(opt_filename,'wb') as handle:
        pickle.dump(opt,handle)
    sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
    with open(sampler_filename,'wb') as handle:
        pickle.dump(sampler,handle)
    chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
    with open(chain_filename,'wb') as handle:
        pickle.dump([chain,['a','log10(amp)','mvir']],handle)

    # Plotting
    this_fig_dir = os.path.join(epsen_fig_dir,str(z0_sid))
    os.makedirs(this_fig_dir,exist_ok=True)

    # Density profile
    fig, axs = plot_fit_density(r, masses, rvir, rmin_fit, rmax_fit, init,
        opt, chain)
    if show_plots: fig.show()
    figname = os.path.join(this_fig_dir,'poisson_fit_density.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # Corner plot (raw strings avoid invalid '\l', '\m' escape warnings)
    fig = corner.corner(chain, 
        labels=['a [kpc]',
            r'$\log_{10}(\mathrm{amp}/\mathrm{M}_\odot\,\mathrm{kpc}^{-3})$',
            r'$\log_{10}(M_{vir}/\mathrm{M}_\odot)$'],
        truths=[None, None, np.log10(mvir)]
        )
    corner.overplot_lines(fig, np.median(chain, axis=0), color='Black')
    if show_plots: fig.show()
    figname = os.path.join(this_fig_dir,'poisson_fit_corner.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)


### Fit an NFW profile to the binned spherical density profile

In [None]:
for i in range(n_mw):
    if verbose: print(f'Fitting DM halo of analog {i+1}/{n_mw}')

    # Get the primary and load its z=0 dark matter particles
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()
    orbs = co.get_orbs('dm')
    masses = co.get_masses('dm').value

    # Properties for fitting. r must be the spherical radius: it sets rvir,
    # the spherical shells below and the enclosed virial mass.
    r = orbs.r().value
    rmax = np.max(r)
    rvir = pdens.get_virial_radius(r, masses).value
    rmin_fit = 0.05*rvir
    rmax_fit = 0.5*rvir
    # NOTE(review): the Poisson fit bounds the scale radius by rmax_fit, this
    # fit by the outermost particle radius rmax -- confirm which is intended.
    usr_log_prior_params = [rmax,]

    # Binned spherical density in nbin log-spaced shells over the fit range
    nbin = 32
    log_redges = np.linspace(np.log10(rmin_fit), np.log10(rmax_fit), nbin+1)
    log_rcents = 0.5*(log_redges[1:] + log_redges[:-1])
    redges = 10**log_redges
    rcents = 10**log_rcents
    vols = 4*np.pi/3 * (redges[1:]**3 - redges[:-1]**3)
    dens = np.array([np.sum(masses[(r >= lo) & (r < hi)])
        for lo, hi in zip(redges[:-1], redges[1:])]) / vols

    # Get the initial conditions
    init = pdens.get_NFW_init_params(densfunc, r, masses)

    # Optimize to get a starting point
    print('Optimizing...')
    opt_fn = lambda params: pfit.mloglike_binned_dens(params, densfunc,
        rcents, dens, usr_log_prior_nfw, usr_log_prior_params)
    opt = scipy.optimize.minimize(opt_fn, init, method='Nelder-Mead')

    # Fit the density profile with MCMC. llfunc is defined at notebook top
    # level so the multiprocessing pool can reach it by name.
    def llfunc(params):
        return pfit.loglike_binned_dens(params, densfunc, rcents, dens, 
            usr_log_prior_nfw, usr_log_prior_params)
    # Walkers start in a small Gaussian ball around the optimum
    mcmc_init = opt.x + 0.1*np.random.randn(nwalkers, len(opt.x))
    with multiprocessing.Pool(processes=nprocs) as pool:
        sampler = emcee.EnsembleSampler(nwalkers, mcmc_init.shape[1], llfunc,
            pool=pool)
        sampler.run_mcmc(mcmc_init, nit, progress=True)
    _chain = sampler.get_chain(flat=True, discard=ncut)
    chain = copy.deepcopy(_chain)
    chain[:,1] = np.log10(chain[:,1])

    # Compute virial masses for the whole chain; chain columns become
    # (a, log10(amp), log10(mvir))
    mvirs = np.array([densfunc.mass(rvir, params) for params in _chain])
    chain = np.append(chain, np.log10(mvirs[:,None]), axis=1)
    mvir = np.sum(masses[r < rvir])

    # Save the results. NOTE: the saved label of the last chain column is
    # 'mvir' although the column holds log10(mvir); kept for readers.
    this_fitting_dir = os.path.join(epsen_fitting_dir,'binned_nfw',str(z0_sid))
    os.makedirs(this_fitting_dir,exist_ok=True)
    opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
    with open(opt_filename,'wb') as handle:
        pickle.dump(opt,handle)
    sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
    with open(sampler_filename,'wb') as handle:
        pickle.dump(sampler,handle)
    chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
    with open(chain_filename,'wb') as handle:
        pickle.dump([chain,['a','log10(amp)','mvir']],handle)

    # Plotting: all figures for this analog go in its own directory
    this_fig_dir = os.path.join(epsen_fig_dir,str(z0_sid))
    os.makedirs(this_fig_dir,exist_ok=True)

    # Density profile
    fig, axs = plot_fit_density(r, masses, rvir, rmin_fit, rmax_fit, init, 
        opt, chain)
    if show_plots: fig.show()
    figname = os.path.join(this_fig_dir,'binned_fit_density.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)

    # Corner plot (raw strings avoid invalid '\l', '\m' escape warnings)
    fig = corner.corner(chain, 
        labels=['a [kpc]',
            r'$\log_{10}(\mathrm{amp}/\mathrm{M}_\odot\,\mathrm{kpc}^{-3})$',
            r'$\log_{10}(M_{vir}/\mathrm{M}_\odot)$'],
        truths=[None, None, np.log10(mvir)]
        )
    if show_plots: fig.show()
    figname = os.path.join(this_fig_dir,'binned_fit_corner.png')
    fig.savefig(figname,dpi=300)
    plt.close(fig)