In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 1_fit_anisotropy.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Do some fits to the velocity dispersion anisotropy.
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle, time, pdb, multiprocessing, logging

## Plotting
import matplotlib as mpl
from matplotlib import pyplot as plt
import corner

## Astropy
from astropy import units as apu

## Analysis
import scipy.stats
import scipy.interpolate
import scipy.optimize
import emcee

## Project-specific
src_path = 'src/'
while True:
    if os.path.exists(src_path): break
    # Directory in which src/ was just looked for. Note realpath(src_path)
    # itself always ends in 'src', so it cannot be used for the stop test.
    search_dir = os.path.dirname(os.path.realpath(src_path))
    # Stop at the project root (tng-dfs) or at the filesystem root (where a
    # directory is its own parent), otherwise this would loop forever.
    if os.path.basename(search_dir) == 'tng-dfs' or \
            search_dir == os.path.dirname(search_dir):
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import kinematics as pkin
from tng_dfs import util as putil

### Notebook setup

# IPython display setup. The project matplotlib style is applied right after
# selecting the inline backend.
# NOTE(review): presumably it must follow %matplotlib inline because selecting
# the backend can reset rcParams -- confirm.
%matplotlib inline
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
%config InlineBackend.figure_format = 'retina'
# Automatically reload edited modules (e.g. tng_dfs) before executing code.
%load_ext autoreload
%autoreload 2

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: configuration values used throughout this notebook
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR', 'MW_ANALOG_DIR', 'FIG_DIR_BASE', 'FITTING_DIR_BASE',
            'RO', 'VO', 'ZO', 'LITTLE_H', 'MW_MASS_RANGE']
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
 ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# Milky Way analog sample
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure output locations (local copy and project-wide figure tree)
local_fig_dir = './fig/'
fig_dir = os.path.join(
    fig_dir_base, 'notebooks/4_fit_distribution_functions/1_fit_anisotropy/')
os.makedirs(local_fig_dir, exist_ok=True)
os.makedirs(fig_dir, exist_ok=True)
show_plots = False

# Merger-tree data: the primaries and their major mergers (dill pickles).
# NOTE: unpickling can execute arbitrary code; only load trusted files.
def _load_pickle(filename):
    '''Open ``filename`` in binary mode and return the unpickled object.'''
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

tree_primary_filename = os.path.join(
    mw_analog_dir, 'major_mergers/tree_primaries.pkl')
tree_primaries = _load_pickle(tree_primary_filename)
tree_major_mergers_filename = os.path.join(
    mw_analog_dir, 'major_mergers/tree_major_mergers.pkl')
tree_major_mergers = _load_pickle(tree_major_mergers_filename)
n_mw = len(tree_primaries)

### Define the likelihood and any priors

The likelihood is Gaussian, with mean given by the model. If per-bin uncertainties are provided (`sigma`), they are used directly. Otherwise the variances can be proxied by the bin counts, provided via `mass`: if the bin counts are Poisson, the variance of each bin is its count, i.e. its mass. Since we are fitting $\beta$ rather than counts, we normalize by the total mass, so the variance for the likelihood is $m_{i}/\sum_{j} m_{j}$. If neither `sigma` nor `mass` is given, a constant $\sigma = 0.1$ is used.

In [None]:
def mloglike_beta(*args, **kwargs):
    '''Negative of loglike_beta, for use with minimizers.

    Takes exactly the same arguments as loglike_beta. When loglike_beta
    returns a tuple (``parts=True``) each element is negated individually;
    negating the tuple itself would raise a TypeError.

    Returns:
        mloglike (float or tuple) - Negated output of loglike_beta.
    '''
    out = loglike_beta(*args, **kwargs)
    if isinstance(out, tuple):
        return tuple(-part for part in out)
    return -out

def loglike_beta(params, model, r, beta, sigma=None, mass=None, 
    usr_log_prior=None, usr_log_prior_params=(), parts=False):
    '''Log-posterior of an anisotropy model given a binned beta profile.

    Args:
        params (array-like) - Model parameters, passed as model(r, *params).
        model (callable) - Anisotropy model. Its ``__name__`` selects the
            domain prior (domain_prior_beta) and parameter prior
            (logprior_beta).
        r (array) - Radii at which ``beta`` was measured.
        beta (array) - Measured anisotropy at ``r``.
        sigma (array or float, optional) - 1-sigma uncertainty on ``beta``.
            Takes precedence over ``mass``.
        mass (array, optional) - Mass (count proxy) in each bin. Used only when
            ``sigma`` is None: the likelihood variance of bin i is its mass
            fraction m_i / sum_j m_j.
        usr_log_prior (callable, optional) - Extra log-prior, called as
            usr_log_prior(params, *usr_log_prior_params).
        usr_log_prior_params (sequence, optional) - Extra positional arguments
            for ``usr_log_prior`` [default empty].
        parts (bool, optional) - Also return the individual terms.

    Returns:
        loglike (float) - Sum of the Gaussian log-likelihood terms, the model
            log-prior and the user log-prior. If ``parts`` is True, returns
            (loglike, data_term, logprior, usrlogprior). Note the early -inf
            returns (outside the model domain, or user prior of -inf) are a
            bare float even when ``parts`` is True.
    '''
    # Evaluate the domain prior
    if not domain_prior_beta(model, params):
        return -np.inf
    # Evaluate the prior on the beta model
    logprior = logprior_beta(model, params)
    # Evaluate any user supplied prior
    if callable(usr_log_prior):
        usrlogprior = usr_log_prior(params, *usr_log_prior_params)
        if np.isinf(usrlogprior):
            return -np.inf
    else:
        usrlogprior = 0
    # Evaluate the model
    beta_model = model(r, *params)
    # Standard deviation for the likelihood
    if sigma is not None:
        _sigma = sigma
    elif mass is not None:
        # Poisson-like proxy: the *variance* of each bin is its mass fraction,
        # so the standard deviation is the square root of it.
        # NOTE(review): this gives heavier bins a larger variance; confirm
        # that this weighting is intended.
        mass = np.asarray(mass, dtype=float)
        _sigma = np.sqrt(mass/np.sum(mass))
    else:
        _sigma = 0.1
    # Gaussian log-likelihood per bin (constant normalization dropped)
    logobj = -0.5*np.square(beta - beta_model)/_sigma**2
    # Combine into the log-posterior
    loglike = np.sum(logobj) + logprior + usrlogprior
    if parts:
        return loglike, np.sum(logobj), logprior, usrlogprior
    else:
        return loglike

def logprior_beta(model, params, ra_min=0.0001, ra_max=10000.):
    '''Log-prior on the parameters of an anisotropy model.

    Models with an anisotropy radius (``beta_osipkov_merritt`` and
    ``beta_cuddeford91``, identified by ``model.__name__``) get a log-uniform
    prior on ``ra`` over [ra_min, ra_max]. The remaining parameters, and all
    other models, contribute nothing here (return 0).

    Args:
        model (callable) - Anisotropy model; only its ``__name__`` is used.
        params (sequence) - Model parameters: [ra] for Osipkov-Merritt,
            [ra, alpha] for Cuddeford (1991).
        ra_min, ra_max (float, optional) - Support of the log-uniform prior
            on ``ra`` [default 1e-4, 1e4].

    Returns:
        logprior (float) - Log-prior value; -inf when ``ra`` lies outside
            [ra_min, ra_max].
    '''
    name = model.__name__
    if name == 'beta_osipkov_merritt':
        ra, = params
    elif name == 'beta_cuddeford91':
        ra, _alpha = params # no prior on alpha here
    else:
        return 0.
    # logpdf returns -inf outside the support directly, avoiding the log(0)
    # RuntimeWarning produced by np.log(pdf).
    return scipy.stats.loguniform.logpdf(ra, ra_min, ra_max)

def domain_prior_beta(model, params):
    '''Hard domain check on anisotropy-model parameters.

    The model is identified by ``model.__name__``; rejected regions are
        - beta_constant: beta >= 1
        - beta_osipkov_merritt: ra <= 0
        - beta_cuddeford91: ra <= 0 or alpha <= -1
    Any other model is always accepted.

    Returns:
        allowed (bool) - False if ``params`` fall in a rejected region.
    '''
    name = model.__name__
    if name == 'beta_constant':
        (beta,) = params
        return not beta >= 1.
    if name == 'beta_osipkov_merritt':
        (ra,) = params
        return not ra <= 0.
    if name == 'beta_cuddeford91':
        ra, alpha = params
        return not (ra <= 0. or alpha <= -1.)
    return True

### Load the data and fit the anisotropy profiles at $z=0$

In [None]:
# Some keywords and properties
verbose = True
random_seed = 42

# Paths
df_fitting_dir = os.path.join(fitting_dir_base,'distribution_function')
os.makedirs(df_fitting_dir,exist_ok=True)

# Begin logging. basicConfig(filename=...) opens the file immediately and fails
# if its directory is missing, so create it; filemode='w' truncates any
# previous log, so it does not need to be removed first.
log_filename = './log/1_fit_anisotropy.log'
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
logging.basicConfig(filename=log_filename, level=logging.INFO, filemode='w', 
    force=True)
logging.info('Beginning constant anisotropy DF creation. Time: '+\
             time.strftime('%a, %d %b %Y %H:%M:%S',time.localtime()))

# Reproducibility: seed numpy's global RNG, which the MCMC walker
# initialisation draws from.
# NOTE(review): whether emcee and the bootstrap also use this global RNG is not
# verified here.
np.random.seed(random_seed)

# Models to fit and fitting params (lists are indexed in parallel)
models = [pkin.beta_constant, pkin.beta_osipkov_merritt, pkin.beta_cuddeford91]
inits = [[0.], [10.], [10.,0.]]
mcmc_labels = [[r'$\beta$',],
               [r'$r_a$',],
               [r'$r_a$',r'$\alpha$',]]
param_names = [['beta'],
               ['ra'],
               ['ra','alpha']]
plot_labels = [r'Constant-$\beta$',
               r'Osipkov-Merritt',
               r'Cuddeford 1991',]
plot_colors = ['DodgerBlue','Crimson','ForestGreen']
plot_linestyles = ['solid','dashed','dotted']
plot_zorders = [1,2,3]
df_type = ['constant_beta','osipkov_merritt','cuddeford91']
fit_version = ['anisotropy_params_softening','anisotropy_params_softening',
               'anisotropy_params_softening']
fig_version = 'anisotropy_params_softening' # Singular because all plotted at same time
# MCMC params
nwalkers = 100
nit = 1000 # steps per walker
ncut = 500 # initial steps per walker discarded as burn-in
nprocs = 12
# Binning params
n_bs = 100 # Number of bootstrap samples

# Softening length of the star particles at z=0; the same for every merger, so
# it is computed once outside the loops.
r_softening = putil.get_softening_length('stars', z=0, physical=True)

for i in range(n_mw):
    # Get the primary
    primary = tree_primaries[i]
    z0_sid = primary.subfind_id[0]
    n_snap = len(primary.snapnum)
    n_major = primary.n_major_mergers
    co = pcutout.TNGCutout(
        primary.get_cutout_filename(mw_analog_dir,snapnum=primary.snapnum[0]))
    co.center_and_rectify()
    pid = co.get_property('stars','ParticleIDs')
    # Orbits of every star in the cutout; each merger selects a subset below.
    all_orbs = co.get_orbs('stars')

    # Loop over the major mergers
    for j in range(n_major):
        if verbose:
            msg = f'Analyzing major merger {j+1}/{n_major} for MW {i+1}/{n_mw}'
            logging.info(msg)
            print(msg)

        # Get the major merger and its star particles in the cutout
        major_merger = primary.tree_major_mergers[j]
        major_acc_sid = major_merger.subfind_id[0]
        major_mlpid = major_merger.secondary_mlpid
        upid = major_merger.get_unique_particle_ids('stars',data_dir=data_dir)
        indx = np.where(np.isin(pid,upid))[0]
        orbs = all_orbs[indx]
        rs = orbs.r().to(apu.kpc).value
        norbs = len(orbs)
        n_bin = min(500, norbs//10) # n per bin
        # Fewer than 10 stars gives zero stars per bin; nothing to fit.
        if n_bin < 1:
            msg = (f'Skipping major merger {j+1}/{n_major} for MW '
                   f'{i+1}/{n_mw}: only {norbs} star particles')
            logging.warning(msg)
            print(msg)
            continue

        # Compute the binning, starting no closer in than the softening length
        rmin = np.max([r_softening, np.min(rs)])
        rmax = np.max(rs)
        adaptive_binning_kwargs = {
            'n':n_bin,
            'rmin':rmin,
            'rmax':rmax,
            'bin_mode':'exact numbers',
            'bin_equal_n':True,
            'end_mode':'ignore',
            'bin_cents_mode':'median',
        }
        bin_edges, bin_cents, bin_n = pkin.get_radius_binning(orbs, 
            **adaptive_binning_kwargs)
        
        # Compute the ingredients for the anisotropy, use dispersions
        compute_betas_kwargs = {'use_dispersions':True,
                                'return_kinematics':True}
        beta, vr2, vp2, vz2 = pkin.compute_betas_bootstrap(orbs,bin_edges,
            n_bootstrap=n_bs, compute_betas_kwargs=compute_betas_kwargs)
        lbeta, mbeta, ubeta = np.percentile(beta, [16,50,84], axis=0)
        # The 16th-84th percentile range spans +/-1 sigma (2 sigma in total),
        # so halve it to get the 1-sigma uncertainty used in the likelihood.
        sbeta = (ubeta-lbeta)/2.

        # Plotting directory
        this_fig_dir = os.path.join(fig_dir, fig_version, str(z0_sid), 
            'merger_'+str(j+1))
        os.makedirs(this_fig_dir, exist_ok=True)

        # Make the figure showing the beta profile and the different fits
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.plot(bin_cents, mbeta, color='Black')
        ax.fill_between(bin_cents, lbeta, ubeta, color='Black', alpha=0.25)

        ## Fit the different models
        for k in range(len(models)):
            # The Cuddeford (1991) model is currently not fit
            if df_type[k] == 'cuddeford91':
                continue
            if verbose:
                msg = 'Fitting anisotropy model '+df_type[k]
                logging.info(msg)
                print(msg)
            model = models[k]
            init = inits[k]

            # Log-posterior of this model given the binned beta profile.
            # NOTE(review): kept at module scope (not nested in a function) so
            # the Pool workers below can resolve it by name; this presumably
            # relies on the 'fork' start method -- confirm on 'spawn' platforms.
            def llfunc(params):
                return loglike_beta(params, model, bin_cents, mbeta, 
                    sigma=sbeta, mass=None, usr_log_prior=None, 
                    usr_log_prior_params=[], parts=False)

            # Optimize (minimize the negative log-posterior) for a starting point
            opt = scipy.optimize.minimize(lambda params: -llfunc(params), init,
                method='Nelder-Mead', options={'maxiter':1000,})
            if not opt.success:
                msg = f'Optimizer did not converge for {df_type[k]}: {opt.message}'
                logging.warning(msg)
                print(msg)
            
            # MCMC: walkers start in a small Gaussian ball around the optimum
            mcmc_init = opt.x + 0.05*np.random.randn(nwalkers, len(opt.x))
            with multiprocessing.Pool(processes=nprocs) as pool:
                sampler = emcee.EnsembleSampler(nwalkers, mcmc_init.shape[1], 
                    llfunc, pool=pool)
                sampler.run_mcmc(mcmc_init, nit, progress=True)
            # Drop the first ncut steps of every walker as burn-in
            chain = sampler.get_chain(flat=True, discard=ncut)
            
            # Corner plot for these results
            figc = corner.corner(chain, labels=mcmc_labels[k], 
                quantiles=[0.16,0.5,0.84], truths=opt.x, 
                truth_color='Red')
            figname = os.path.join(this_fig_dir,df_type[k]+'_corner.png')
            figc.savefig(figname)
            plt.close(figc)

            # Add the posterior model band (16/50/84 percentiles) to the figure
            beta_model = np.empty((chain.shape[0],len(bin_cents)))
            for m in range(len(chain)):
                beta_model[m,:] = model(bin_cents, *chain[m])
            beta_model_lo, beta_model_med, beta_model_hi = np.percentile(
                beta_model, [16,50,84], axis=0)
            ax.plot(bin_cents, beta_model_med, 
                color=plot_colors[k], label=plot_labels[k], 
                zorder=plot_zorders[k], linestyle=plot_linestyles[k])
            ax.fill_between(bin_cents, beta_model_lo, beta_model_hi, 
                color=plot_colors[k], alpha=0.25)

            # Save the results
            this_fitting_dir = os.path.join(df_fitting_dir, df_type[k],
                fit_version[k], str(z0_sid),'merger_'+str(j+1))
            os.makedirs(this_fitting_dir, exist_ok=True)
            opt_filename = os.path.join(this_fitting_dir,'opt.pkl')
            with open(opt_filename,'wb') as handle:
                pickle.dump(opt, handle)
            sampler_filename = os.path.join(this_fitting_dir,'sampler.pkl')
            with open(sampler_filename,'wb') as handle:
                pickle.dump(sampler, handle)
            chain_filename = os.path.join(this_fitting_dir,'chain.pkl')
            with open(chain_filename,'wb') as handle:
                pickle.dump([chain,param_names[k]], handle)

        ax.set_xlabel(r'$r$ [kpc]')
        ax.set_ylabel(r'$\beta$')
        ax.legend(loc='best', frameon=False, fontsize=8)
        fig.tight_layout()
        figname = os.path.join(this_fig_dir,'beta_fits.png')
        fig.savefig(figname)
        plt.close(fig)