In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 3_fit_stellar_halo_rotation.ipynb
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Fit rotation DF wrappers to remnants in stellar halos at z=0
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, time, dill as pickle, logging

## Matplotlib
from matplotlib import pyplot as plt

## Astropy
from astropy import units as apu

## Fitting
import emcee
import corner
import multiprocessing
import scipy.optimize

## Project-specific
sys.path.insert(0,'../../src/')
from tng_dfs import util as putil
from tng_dfs import cutout as pcutout

### Notebook setup

%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: pull the project configuration values used by this notebook
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR','MW_ANALOG_DIR','FIG_DIR_BASE','FITTING_DIR_BASE',
            'RO','VO','ZO','LITTLE_H','MW_MASS_RANGE']
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
    ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW analog subhalos (and their selection variables)
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure output locations: a local directory and the project figure tree
local_fig_dir = './fig/'
fig_dir = os.path.join(fig_dir_base,
    'notebooks/3_fit_density_profiles/3_fit_stellar_halo_rotation/')
for _out_dir in (local_fig_dir, fig_dir):
    os.makedirs(_out_dir, exist_ok=True)
show_plots = False

# Merger-tree data for the analogs. NOTE: `pickle` here is dill; loading a
# pickle can execute code, so only load trusted files.
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
with open(tree_primary_filename, 'rb') as handle:
    tree_primaries = pickle.load(handle)
with open(tree_major_mergers_filename, 'rb') as handle:
    tree_major_mergers = pickle.load(handle)
n_mw = len(tree_primaries)

### Functions

In [None]:
def tanh_rotation_function(Lz,frot=0.,chi=1.):
    '''Multiplicative tanh rotation kernel evaluated at Lz.

    Computes 1 - frot/2 + (frot/2)*tanh(Lz/chi), which tends to 1 - frot for
    Lz << -chi and to 1 for Lz >> chi.

    Args:
        Lz (float, array, or astropy Quantity) - z angular momentum; a
            Quantity is converted to kpc km/s
        frot (float) - rotation fraction
        chi (float or astropy Quantity) - tanh scale in Lz; a Quantity is
            converted to kpc km/s

    Returns:
        kernel (float or array)
    '''
    def _as_kpc_kms(value):
        # Strip astropy units (if present) after converting to kpc km/s
        if isinstance(value, apu.Quantity):
            return value.to(apu.kpc*apu.km/apu.s).value
        return value

    Lz = _as_kpc_kms(Lz)
    chi = _as_kpc_kms(chi)
    half_frot = 0.5*frot
    return (1 - half_frot) + half_frot*np.tanh(Lz/chi)

def tanh_rotation_effvol(params, Lzmin, Lzmax):
    '''Integral of the tanh rotation kernel over Lz in [Lzmin, Lzmax].

    With k = frot/2 the kernel is 1 - k + k*tanh(Lz/chi), whose integral is

        (1-k)*(Lzmax-Lzmin) + k*chi*[ln cosh(Lzmax/chi) - ln cosh(Lzmin/chi)]

    For a symmetric range (Lzmin = -Lzmax) the ln cosh term vanishes and this
    reduces to (1-frot/2)*(Lzmax-Lzmin).

    Args:
        params (list) - [frot, chi]
        Lzmin (float) - lower integration limit, same units as chi
        Lzmax (float) - upper integration limit, same units as chi

    Returns:
        effvol (float)
    '''
    frot, chi = params
    k = frot/2.
    term1 = (Lzmax-Lzmin)*(1.-k)
    if chi == 0:
        # chi -> 0 limit: chi*ln cosh(Lz/chi) -> |Lz|
        term2 = k*(np.abs(Lzmax) - np.abs(Lzmin))
    else:
        # ln cosh(x) = logaddexp(x, -x) - ln 2. The ln 2 cancels in the
        # difference, and logaddexp avoids overflowing cosh at large |x|.
        xmax = Lzmax/chi
        xmin = Lzmin/chi
        term2 = k*chi*(np.logaddexp(xmax, -xmax) - np.logaddexp(xmin, -xmin))
    return term1 + term2

def mloglike_tanh_rotation(*args, **kwargs):
    '''Negative of loglike_tanh_rotation, for use with minimizers.

    All positional and keyword arguments are forwarded unchanged.
    '''
    loglike = loglike_tanh_rotation(*args, **kwargs)
    return -loglike

def loglike_tanh_rotation(params, Lz, usr_log_prior=None, 
    usr_log_prior_params=[], effvol_params=[], parts=False):
    '''Log-likelihood of the Lz data under the tanh rotation kernel.

    loglike = sum_i ln kernel(Lz_i) - effvol + logprior + usrlogprior

    Args:
        params (list) - [frot, chi]
        Lz (array) - z angular momenta
        usr_log_prior (callable or None) - extra log prior, called as
            usr_log_prior(params, *usr_log_prior_params)
        usr_log_prior_params (list) - extra arguments for usr_log_prior
        effvol_params (list) - extra arguments for tanh_rotation_effvol
            (Lzmin, Lzmax); the empty default does not supply them, so callers
            must pass them
        parts (bool) - if True, also return the individual terms

    Returns:
        loglike (float), or (loglike, sum ln kernel, effvol, logprior,
            usrlogprior) if parts is True. -np.inf when params fail the domain
            prior, the user prior is +/-inf, or any ln kernel is NaN.
    '''
    if not domain_prior_tanh_rotation(params):
        return -np.inf
    # Prior on the rotation-kernel parameters
    logprior = logprior_tanh_rotation(params)

    usrlogprior = 0
    if callable(usr_log_prior):
        usrlogprior = usr_log_prior(params, *usr_log_prior_params)
        if np.isinf(usrlogprior):
            return -np.inf

    frot, chi = params
    logkernel = np.log(tanh_rotation_function(Lz, frot=frot, chi=chi))
    if np.any(np.isnan(logkernel)):
        return -np.inf
    sum_logkernel = np.sum(logkernel)

    effvol = tanh_rotation_effvol(params, *effvol_params)
    loglike = sum_logkernel - effvol + logprior + usrlogprior
    if not parts:
        return loglike
    return loglike, sum_logkernel, effvol, logprior, usrlogprior

def logprior_tanh_rotation(params):
    '''Log prior on [frot, chi] for the tanh rotation kernel.

    Flat: always returns 0., independent of params.
    '''
    flat_logprior = 0.
    return flat_logprior

def domain_prior_tanh_rotation(params):
    '''Whether [frot, chi] lies in 0 <= frot <= 1 and chi >= 0.

    Returns:
        bool - False only when one of the bound comparisons fails, so NaN
            values are not rejected here.
    '''
    frot, chi = params
    return not (frot < 0. or frot > 1. or chi < 0.)

# Also define likelihoods for fitting tanh rotation in terms of the asymmetry
# of angular momentum counts.

def tanh_rotation_function_asymmetry(Lz,frot=0.,chi=1.):
    '''Model for the Lz count asymmetry, A(Lz) = 1/2 + (frot/2)*tanh(Lz/chi).

    A(0) = 1/2. A tends to (1 + frot)/2 for Lz >> chi and to (1 - frot)/2
    for Lz << -chi.

    Args:
        Lz (float, array, or astropy Quantity) - z angular momentum; a
            Quantity is converted to kpc km/s
        frot (float) - rotation fraction (may be negative)
        chi (float or astropy Quantity) - tanh scale in Lz; a Quantity is
            converted to kpc km/s

    Returns:
        asymmetry (float or array)
    '''
    lz_unit = apu.kpc*apu.km/apu.s
    if isinstance(Lz, apu.Quantity):
        Lz = Lz.to(lz_unit).value
    if isinstance(chi, apu.Quantity):
        chi = chi.to(lz_unit).value
    return 0.5 + frot*np.tanh(Lz/chi)/2.

def mloglike_tanh_rotation_asym(*args, **kwargs):
    '''Negative of loglike_tanh_rotation_asym, for use with minimizers.

    All positional and keyword arguments are forwarded unchanged.
    '''
    loglike = loglike_tanh_rotation_asym(*args, **kwargs)
    return -loglike

def loglike_tanh_rotation_asym(params, Lz, usr_log_prior=None, 
    usr_log_prior_params=[], parts=False):
    '''Gaussian log-likelihood of the tanh model for the binned Lz asymmetry.

    Lz is histogrammed in min(50, len(Lz)/20) equal bins (at least 1) spanning
    [-L, L], where L is the 80th percentile of |Lz|. In each bin the
    asymmetry A = N(Lz) / (N(Lz) + N(-Lz)) is compared with
    tanh_rotation_function_asymmetry at the bin center. The Gaussian sigma is
    the fraction of binned counts in that bin. Bins where both the bin and
    its mirror are empty are skipped.

    Args:
        params (list) - [frot, chi]
        Lz (array) - z angular momenta (plain values, no units)
        usr_log_prior (callable or None) - extra log prior, called as
            usr_log_prior(params, *usr_log_prior_params)
        usr_log_prior_params (list) - extra arguments for usr_log_prior
        parts (bool) - if True, also return the individual terms

    Returns:
        loglike (float), or (loglike, sum of log objective, logprior,
            usrlogprior) if parts is True. -np.inf when params fail the domain
            prior or the user prior is +/-inf.
    '''
    # Evaluate the domain prior
    if not domain_prior_tanh_rotation_asym(params):
        return -np.inf
    # Evaluate the prior on the rotation parameters
    logprior = logprior_tanh_rotation_asym(params)
    # Evaluate any user supplied prior
    if callable(usr_log_prior):
        usrlogprior = usr_log_prior(params, *usr_log_prior_params)
        if np.isinf(usrlogprior):
            return -np.inf
    else:
        usrlogprior = 0

    # Bin Lz symmetrically about zero. max(1, ...) keeps np.histogram from
    # receiving bins=0 when there are fewer than 20 particles.
    n_bin = max(1, int(np.min([50,len(Lz)/20])))
    Lz_max = np.percentile(np.abs(Lz), 80)
    Lz_N, Lz_edges = np.histogram(Lz, bins=n_bin, range=(-Lz_max, Lz_max))
    Lz_centers = 0.5*(Lz_edges[1:] + Lz_edges[:-1])
    # The range is symmetric, so bin i mirrors bin n_bin-1-i
    Lz_mirror_sum = Lz_N + Lz_N[::-1]
    Lz_asym_mask = Lz_mirror_sum > 0

    # Apply the mask exactly once so data, model and sigma have equal length
    Lz_asym = Lz_N[Lz_asym_mask] / Lz_mirror_sum[Lz_asym_mask]
    frot, chi = params
    Lz_asym_model = tanh_rotation_function_asymmetry(
        Lz_centers[Lz_asym_mask], frot=frot, chi=chi)

    # Gaussian objective with sigma = fraction of binned counts per bin.
    # NOTE(review): a bin that is empty while its mirror is not gets sigma=0,
    # making its term -inf (or NaN); confirm this weighting is intended.
    mass_frac = Lz_N/np.sum(Lz_N)
    _sigma = mass_frac[Lz_asym_mask]
    logobj = -0.5*np.square(Lz_asym - Lz_asym_model)/_sigma**2

    # Evaluate the log likelihood
    loglike = np.sum(logobj) + logprior + usrlogprior
    if parts:
        return loglike, np.sum(logobj), logprior, usrlogprior
    return loglike

def logprior_tanh_rotation_asym(params, chi_min=0.001, chi_max=10000.):
    '''Log prior for the asymmetry fit: log-uniform in chi, flat in frot.

    ln p(chi) = -ln(chi) - ln(ln(chi_max/chi_min)) for chi_min <= chi <= chi_max
    and -inf outside. This is the log-density of
    scipy.stats.loguniform(chi_min, chi_max), written in closed form.
    The closed form avoids relying on scipy.stats, which this notebook never
    imports explicitly (only scipy.optimize), and is cheap inside the MCMC
    loop.

    Args:
        params (list) - [frot, chi]; frot does not enter the prior
        chi_min (float) - lower edge of the chi prior (default matches the
            bound in domain_prior_tanh_rotation_asym)
        chi_max (float) - upper edge of the chi prior (default matches the
            bound in domain_prior_tanh_rotation_asym)

    Returns:
        logprior (float)
    '''
    frot, chi = params
    if not chi_min <= chi <= chi_max:
        return -np.inf
    return -np.log(chi) - np.log(np.log(chi_max/chi_min))

def domain_prior_tanh_rotation_asym(params):
    '''Whether [frot, chi] lies in -1 <= frot <= 1 and 0.001 <= chi <= 10000.

    Returns:
        bool - False only when one of the bound comparisons fails, so NaN
            values are not rejected here.
    '''
    frot, chi = params
    outside = frot < -1. or frot > 1. or chi < 0.001 or chi > 10000
    return not outside

### Fit the rotation of remnants

In [None]:
def plot_tanh_rotation_fit(Lz, chain, n_draw=100):
    '''Plot the binned Lz asymmetry with tanh-model draws from an MCMC chain.

    The binning matches the asymmetry likelihood. It uses min(50, len(Lz)/20)
    bins (at least 1) spanning [-L, L], where L is the 80th percentile of
    |Lz|. A = N(Lz) / (N(Lz) + N(-Lz)) is shown only for bins where the bin or
    its mirror is populated.

    Args:
        Lz (array) - z angular momenta
        chain (array) - MCMC samples, one [frot, chi] row per sample
        n_draw (int) - maximum number of samples drawn (without replacement)
            from chain and overplotted; capped at len(chain)

    Returns:
        fig, ax - matplotlib figure and axis
    '''
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Bin up Lz
    n_bin = max(1, int(np.min([50,len(Lz)/20])))
    Lz_max = np.percentile(np.abs(Lz), 80)
    Lz_N, Lz_edges = np.histogram(Lz, bins=n_bin, range=(-Lz_max, Lz_max))
    Lz_centers = 0.5*(Lz_edges[1:] + Lz_edges[:-1])
    Lz_mirror_sum = Lz_N + Lz_N[::-1]
    Lz_asym_mask = Lz_mirror_sum > 0
    Lz_asym = Lz_N[Lz_asym_mask] / Lz_mirror_sum[Lz_asym_mask]
    # Plot against the masked centers so x and y have the same length
    ax.plot(Lz_centers[Lz_asym_mask], Lz_asym, color='Grey', linewidth=2.)

    # Overplot model draws. Never request more samples than the chain holds
    # (replace=False), and keep alpha within matplotlib's [0, 1].
    n_draw = min(n_draw, len(chain))
    if n_draw > 0:
        alpha = min(1., 5/n_draw)
        indx = np.random.choice(np.arange(len(chain),dtype=int), size=n_draw,
            replace=False)
        for frot, chi in chain[indx]:
            Lz_asym_model = tanh_rotation_function_asymmetry(Lz_centers,
                frot=frot, chi=chi)
            ax.plot(Lz_centers, Lz_asym_model, color='Red', linewidth=1.,
                linestyle='solid', alpha=alpha)
    ax.set_xlabel(r'$L_z$')
    ax.set_ylabel(r'$A(L_z)$')

    return fig, ax

In [None]:
# Properties
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function/')
stellar_halo_rotation_dftype = 'tanh_rotation'
stellar_halo_rotation_version = 'asymmetry_fit'
verbose = True
force_fit = False
nwalkers = 100 # number of emcee walkers
nit = 2000 # MCMC steps per walker
ncut = 500 # burn-in steps discarded from each walker
nprocs = 10 # worker processes for the MCMC pool
# Only these analog indices (into tree_primaries) and merger indices (into
# each primary's tree_major_mergers) are fit; e.g. use range(n_mw) to fit
# every analog.
analog_indices = [0]
merger_indices = [0]

# Begin logging. filemode='w' truncates any previous log. basicConfig does
# not create missing directories, so make sure the log directory exists.
log_filename = './log/3_fit_stellar_halo_rotation.log'
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning rotation kernel fitting for stellar halos. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

for i in range(n_mw):
    if i not in analog_indices: continue

    # Get the primary and its z=0 cutout, centered and rectified
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_major = primary.n_major_mergers
    primary_filename = primary.get_cutout_filename(mw_analog_dir,
        snapnum=primary.snapnum[0])
    co = pcutout.TNGCutout(primary_filename)
    co.center_and_rectify()

    for j in range(n_major):
        if j not in merger_indices: continue
        
        if verbose: 
            msg = f'Fitting merger {j+1}/{n_major}, of analog {i+1}/{n_mw}'
            logging.info(msg)
            print(msg)

        # Fitting dir, skip if a sampler has already been saved
        this_fitting_dir = os.path.join(df_fitting_dir,
            stellar_halo_rotation_dftype,stellar_halo_rotation_version,
            str(z0_sid),'merger_'+str(j+1))
        os.makedirs(this_fitting_dir,exist_ok=True)
        sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
        if os.path.exists(sampler_filename) and not force_fit:
            if verbose:
                msg = f'Fitting already exists for {z0_sid}, merger {j+1}'
                logging.info(msg)
                print(msg)
            continue

        # Select the z=0 star particles whose IDs belong to this major merger
        major_merger = primary.tree_major_mergers[j]
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        pid = co.get_property('stars','ParticleIDs')
        indx = np.where(np.isin(pid,upid))[0]
        orbs = co.get_orbs('stars')[indx]
        
        # Properties for fitting
        Lz = orbs.Lz().value
        usr_log_prior_params = []

        # Remove data outside the fitting range, null for now
        mask = np.ones_like(Lz, dtype=bool)
        Lz_fit = Lz[mask]

        # Starting guess [frot, chi] for the optimizer
        init = np.array([0.25, 425.])

        # Do an optimization to get the initial conditions
        if verbose:
            msg = 'Optimizing initial conditions'
            logging.info(msg)
            print(msg)
        opt_fn = lambda params: mloglike_tanh_rotation_asym(params, 
            Lz_fit, usr_log_prior=None, 
            usr_log_prior_params=usr_log_prior_params)
        opt = scipy.optimize.minimize(opt_fn, init, method='Nelder-Mead',
            options={'maxiter':2000,})
        if not opt.success:
            # The result is still used to seed the walkers, but flag it
            msg = f'Optimization did not converge: {opt.message}'
            logging.warning(msg)
            if verbose:
                print(msg)
        
        # Do the MCMC
        if verbose:
            msg = 'Running MCMC'
            logging.info(msg)
            print(msg)
        def llfunc(params):
            return loglike_tanh_rotation_asym(params, 
            Lz_fit, usr_log_prior=None, 
            usr_log_prior_params=usr_log_prior_params)
        ndim = len(opt.x)
        # Walkers start in a small Gaussian ball around the optimum
        mcmc_init = opt.x + 0.1*np.random.randn(nwalkers, ndim)
        with multiprocessing.Pool(processes=nprocs) as pool:
            sampler = emcee.EnsembleSampler(nwalkers, ndim, llfunc, pool=pool)
            sampler.run_mcmc(mcmc_init, nit, progress=True)
        chain = copy.deepcopy(sampler.get_chain(flat=True, discard=ncut))

        # Remove all points outside of the domain
        in_domain = np.array(
            [domain_prior_tanh_rotation_asym(p) for p in chain], dtype=bool)
        chain = chain[in_domain,:]

        # Plotting
        this_fig_dir = os.path.join(fig_dir,stellar_halo_rotation_dftype,
            stellar_halo_rotation_version,str(z0_sid),'merger_'+str(j+1))
        os.makedirs(this_fig_dir, exist_ok=True)

        # Binned Lz asymmetry with posterior model draws
        fig, ax = plot_tanh_rotation_fit(Lz, chain)
        figname = os.path.join(this_fig_dir,'asymmetry.png')
        fig.savefig(figname,dpi=300)
        plt.close(fig)

        # Corner
        fig = corner.corner(chain, 
            labels=[r'$f_{rot}$',r'$\chi$'], quantiles=[0.16,0.5,0.84])
        corner.overplot_lines(fig, np.median(chain,axis=0), color='Black')
        figname = os.path.join(this_fig_dir,'corner.png')
        fig.savefig(figname, dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Save the results (sampler_filename was set above)
        opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
        with open(opt_filename,'wb') as handle:
            pickle.dump(opt,handle)
        with open(sampler_filename,'wb') as handle:
            pickle.dump(sampler,handle)
        chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
        with open(chain_filename,'wb') as handle:
            pickle.dump([chain,['frot','chi']],handle)

        if verbose:
            msg = 'Done with this merger'
            logging.info(msg)
            print(msg)

### Some extra diagnostic code

In [None]:
# primary = tree_primaries[1]
# z0_sid = primary.subfind_id[0]
# # n_snap = len(primary.snapnum)
# n_major = primary.n_major_mergers
# primary_filename = primary.get_cutout_filename(mw_analog_dir,
#     snapnum=primary.snapnum[0])
# co = pcutout.TNGCutout(primary_filename)
# co.center_and_rectify()

# # Get the tree major merger object
# major_mergers = primary.tree_major_mergers
# n_major = primary.n_major_mergers
# # _unique_particle_ids = []

# # Get the major merger
# major_merger = primary.tree_major_mergers[0]
# upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
# pid = co.get_property('stars','ParticleIDs')
# indx = np.where(np.isin(pid,upid))[0]
# orbs = co.get_orbs('stars')[indx]

# # Properties for fitting
# Lz = orbs.Lz().value

In [None]:
# fig = plt.figure()
# ax = fig.add_subplot(111)

# # Lz_max = np.percentile(np.abs(Lz), 90)
# # Lz_min = -Lz_max
# # Lz_N, Lz_edges = np.histogram( Lz, bins=50, range=(Lz_min, Lz_max) )
# # Lz_centers = 0.5*(Lz_edges[1:] + Lz_edges[:-1])

# Lz_max = np.percentile(np.abs(Lz), 80)
# Lz_min = -Lz_max
# Lz_N, Lz_edges = np.histogram( Lz, bins=50, range=(Lz_min, Lz_max) )
# Lz_centers = 0.5*(Lz_edges[1:] + Lz_edges[:-1])
# Lz_mirror_sum = (Lz_N + Lz_N[::-1])
# Lz_asym_mask = Lz_mirror_sum > 0
# Lz_asym = np.ones_like(Lz_centers)
# Lz_asym = Lz_N[Lz_asym_mask] / Lz_mirror_sum[Lz_asym_mask]
# # Evaluate the tanh rotation asymmetry and the objective function

# ax.scatter(Lz_centers, Lz_N / (Lz_N + Lz_N[::-1]), color='Grey')

# frot, chi = [0.95, 200]
# ax.scatter(Lz_centers, 
#     tanh_rotation_function_asymmetry(Lz_centers, frot=frot, chi=chi), 
#     color='Red')
# frot, chi = [0.25, 10.]
# # ax.scatter(Lz_centers, 
# #     tanh_rotation_function_asymmetry(Lz_centers, frot=frot, chi=chi), 
# #     color='Blue')
# # ax.plot(Lz_centers, 
# #     tanh_rotation_function_asymmetry(Lz_centers, frot=0.5, chi=1000), 
# #     color='Blue')

# fig.show()

In [None]:
# fig = plt.figure()
# ax = fig.add_subplot(111)

# # Lz_max = np.percentile(np.abs(Lz), 90)
# # Lz_min = -Lz_max
# # Lz_N, Lz_edges = np.histogram( Lz, bins=50, range=(Lz_min, Lz_max) )
# # Lz_centers = 0.5*(Lz_edges[1:] + Lz_edges[:-1])

# Lz_max = np.percentile(np.abs(Lz), 90)
# Lz_min = -Lz_max
# Lz_N, Lz_edges = np.histogram( Lz, bins=50, range=(Lz_min, Lz_max) )
# Lz_centers = 0.5*(Lz_edges[1:] + Lz_edges[:-1])
# Lz_mirror_sum = (Lz_N + Lz_N[::-1])
# Lz_asym_mask = Lz_mirror_sum > 0
# Lz_asym = np.ones_like(Lz_centers)
# Lz_asym = Lz_N[Lz_asym_mask] / Lz_mirror_sum[Lz_asym_mask]
# # Evaluate the tanh rotation asymmetry and the objective function

# params = [0.95, 200.]
# frot, chi = params
# Lz_asym_model = tanh_rotation_function_asymmetry(Lz_centers, frot=frot, 
#     chi=chi)
# # logobj = np.log(np.square(Lz_asym[Lz_asym_mask] - Lz_asym_model[Lz_asym_mask]))*Lz_N[Lz_asym_mask]/np.sum(Lz_N[Lz_asym_mask])
# mass_frac = Lz_N/np.sum(Lz_N)
# _sigma = mass_frac[Lz_asym_mask]
# obj = np.exp(-0.5*np.square(Lz_asym[Lz_asym_mask] - Lz_asym_model[Lz_asym_mask])/_sigma**2)
# logobj = np.log(obj)
# print(params, np.sum(logobj))
# ax.scatter(Lz_centers, logobj, color='Grey')
# ax.axhline(np.average(logobj), color='Grey', linestyle='dashed')

# params = [0.25, 500.]
# frot, chi = params
# Lz_asym_model = tanh_rotation_function_asymmetry(Lz_centers, frot=frot, 
#     chi=chi)
# # logobj = np.log(np.square(Lz_asym[Lz_asym_mask] - Lz_asym_model[Lz_asym_mask]))*Lz_N[Lz_asym_mask]/np.sum(Lz_N[Lz_asym_mask])
# mass_frac = Lz_N/np.sum(Lz_N)
# _sigma = mass_frac[Lz_asym_mask]
# obj = np.exp(-0.5*np.square(Lz_asym[Lz_asym_mask] - Lz_asym_model[Lz_asym_mask])/_sigma**2)
# logobj = np.log(obj)
# print(params, np.sum(logobj))
# ax.scatter(Lz_centers, logobj, color='Blue')
# ax.axhline(np.average(logobj), color='Blue', linestyle='dashed')

# fig.show()

In [None]:
# frots = np.linspace(0., 1., 50)
# loglikes = np.zeros_like(frots)
# for i in range(len(frots)):
#     loglikes[i] = loglike_tanh_rotation_asym([frots[i], 500], Lz, 
#         usr_log_prior=None, usr_log_prior_params=[], parts=False)

# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.plot(frots, loglikes, color='Grey')

# fig.show()

In [None]:
# chis = np.linspace(0., 1000, 50)
# loglikes = np.zeros_like(chis)
# for i in range(len(chis)):
#     loglikes[i] = loglike_tanh_rotation_asym([0.25, chis[i]], Lz, 
#         usr_log_prior=None, usr_log_prior_params=[], parts=False)

# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.plot(chis, loglikes, color='Grey')

# fig.show()

In [None]:
# frots = np.linspace(0., 1., 100)
# chis = np.linspace(0, 1000, 100)

# loglikes = np.zeros((len(frots), len(chis)))
# min_loglike = 0.
# for i in range(len(frots)):
#     for j in range(len(chis)):
#         loglikes[i,j] = mloglike_tanh_rotation_asym([frots[i], chis[j]], 
#             Lz_fit, usr_log_prior=None, usr_log_prior_params=[], parts=False)
#         if loglikes[i,j] < min_loglike:
#             min_loglike = loglikes[i,j]
#             min_chi = chis[j]
#             min_frot = frots[i]
        

# fig = plt.figure()
# ax = fig.add_subplot(111)
# img = ax.pcolormesh(chis, frots, loglikes, vmin=-80, vmax=-50)
# ax.scatter(min_chi, min_frot, facecolor='none', edgecolor='Black', s=40)
# ax.set_xlabel(r'$\chi$')
# ax.set_ylabel(r'$f_{rot}$')
# fig.colorbar(img)
# fig.show()
