In [None]:
# ------------------------------------------------------------------------
#
# TITLE - apo_mock_imf
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Create the framework to generate mock APOGEE realizations of the stellar 
halo that can be used to test fitting routines. Investigate IMF sampling
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, glob, subprocess, time, warnings, dill as pickle

## Matplotlib
# import matplotlib as mpl
from matplotlib import pyplot as plt

## scipy
import scipy.integrate
import scipy.interpolate

## Project-specific
sys.path.insert(0,'../../../src/')
# import sample_project.module as project_module

### Scale parameters
# NOTE(review): none of these are referenced elsewhere in the visible notebook
ro = 8.275 # NOTE(review): presumably the distance scale in kpc - confirm
vo = 220 # NOTE(review): presumably the velocity scale in km/s - confirm
zo = 0.0208 # Bennett+ 2019 (presumably the solar height in kpc - confirm)

### Notebook setup
# Draw figures inside the notebook output
%matplotlib inline
plt.style.use('../../../src/mpl/project.mplstyle') # This must be exactly here
# Request high-resolution (retina) inline figures
%config InlineBackend.figure_format = 'retina'
# Reload imported modules before each cell runs so source edits are picked up
%load_ext autoreload
%autoreload 2

## Functions

In [None]:
def chabrier_imf(m,k=0.0193,A=1.):
    '''chabrier_imf:
    Chabrier initial mass function
    
    Lognormal in log10(m) for 0.01 <= m <= 1 and a power law of slope -2.3 
    for m > 1. The IMF is set to zero for m < 0.01. NaN masses give NaN.
    
    Args:
        m (float or array-like) - mass(es) at which to evaluate the IMF
        k (float) - normalization of the m > 1 power law. The default 
            (0.0193) equalizes the m<1 and m>1 branches at m=1
        A (float) - overall amplitude multiplying the whole IMF
    
    Returns:
        Nm (np.ndarray) - IMF evaluated at m, with the shape of 
            np.atleast_1d(m)
    '''
    a = 2.3
    
    # atleast_1d leaves arrays that are already >= 1-D untouched and 
    # promotes scalars / 0-d arrays so that boolean indexing works
    m = np.atleast_1d(m)
    
    where_m_gt_1 = m>1
    # Defined through negations so NaN masses land in the lognormal branch 
    # and propagate as NaN. Masses below 0.01 are left at zero, which also 
    # avoids taking log10 of zero or negative values.
    where_m_lognormal = ~where_m_gt_1 & ~(m<0.01)
    
    Nm = np.zeros(m.shape)
    Nm[where_m_lognormal] = (0.158/(np.log(10)*m[where_m_lognormal]))\
                            *np.exp(-(np.log10(m[where_m_lognormal])\
                                      -np.log10(0.08))**2/(2*0.69**2))
    Nm[where_m_gt_1] = k*m[where_m_gt_1]**(-a)
    return A*Nm
#def

def kroupa_imf(m,k1=1.):
    '''kroupa_imf:
    Kroupa initial mass function
    
    Three-segment broken power law with slopes -0.3, -1.3 and -2.3 on 
    0.01 <= m < 0.08, 0.08 <= m < 0.5 and m >= 0.5. Zero for m < 0.01.
    
    Args:
        m (float or array-like) - mass(es) at which to evaluate the IMF
        k1 (float) - normalization of the lowest-mass segment
    
    Returns:
        Nm (np.ndarray) - IMF evaluated at m (1-D)
    '''
    # Normalizations of the upper segments keep the IMF continuous at the 
    # break masses 0.08 and 0.5
    k2 = 0.08*k1
    k3 = 0.5*k2
    
    masses = m if isinstance(m,np.ndarray) else np.atleast_1d(m)
    
    # (mask, normalization, slope) for each power-law segment
    segments = (
        ((masses>=0.01) & (masses<0.08), k1, 0.3),
        ((masses>=0.08) & (masses<0.5), k2, 1.3),
        (masses>=0.5, k3, 2.3),
    )
    
    xi = np.empty(len(masses))
    for in_segment,norm,slope in segments:
        xi[in_segment] = norm*masses[in_segment]**(-slope)
    ###i
    xi[masses<0.01] = 0
    return xi
#def

### Estimate the average mass of the IMFs over the default mass range (0.1 to 100)

In [None]:
ms = np.logspace(-1,2,100)
# The grid is uniform in log(m), so each grid point stands for a mass 
# interval dm proportional to m. Weighting by xi(m)*m makes the weighted 
# average approximate int(m*xi dm)/int(xi dm), i.e. the mean stellar mass.
# Weighting by xi(m) alone would instead estimate int(xi dm)/int(xi/m dm).
np.average(ms,weights=chabrier_imf(ms)*ms)

### Integrate the IMFs

In [None]:
# Chabrier IMF integrated over 1.1 <= m <= 3; [0] keeps the integral and 
# drops the error estimate
scipy.integrate.quad(chabrier_imf, 1.1, 3)[0]

In [None]:
# Chabrier IMF integrated over m >= 3; [0] keeps the integral and drops the 
# error estimate
scipy.integrate.quad(chabrier_imf, 3, np.inf)[0]

In [None]:
# scipy.integrate.quad(chabrier_imf, a=0.01, b=np.inf)
# scipy.integrate.quad(chabrier_imf, a=0.01, b=100)
# scipy.integrate.quad(kroupa_imf, a=0.01, b=np.inf)
# scipy.integrate.quad(kroupa_imf, a=0.01, b=100)

### Plot the IMFs

In [None]:
ms = np.logspace(-2,2,100)

fig = plt.figure()
ax = fig.add_subplot(111)

# Both quantities are plotted as log10, so the axes are linear in log10(m) 
# and log10(xi). The Kroupa curve uses k1=0.5.
ax.plot(np.log10(ms), np.log10(chabrier_imf(ms) ), color='DodgerBlue',
        label='Chabrier', linestyle='solid', zorder=1)
ax.plot(np.log10(ms), np.log10(kroupa_imf(ms,0.5) ), color='Red', 
        label='Kroupa', linestyle='dashed', zorder=2)
ax.legend()

ax.set_xlabel(r'$\log_{10}(m)$')
# The plotted quantity is log10(xi), not xi
ax.set_ylabel(r'$\log_{10}(\xi)$')

# fig.show() only warns under the non-GUI inline backend; plt.show() works 
# with every backend
plt.show()

## Determine the cumulative probability functions, then the inverse cumulative probability functions

In [None]:
def cimf(f,b,a=0.01,intargs=()):
    '''cimf:
    Cumulative initial mass function: the integral of f from a to b
    
    Args:
        f (callable) - IMF, called as f(m,*intargs)
        b (float) - upper integration limit
        a (float) - lower integration limit
        intargs (tuple) - extra positional arguments forwarded to f
    
    Returns:
        integral (float) - value of the integral (the quadrature error 
            estimate is discarded)
    '''
    quad_result = scipy.integrate.quad(f,a,b,args=intargs)
    integral = quad_result[0]
    return integral
#def

def make_icimf_interpolator(imf,m_min=0.01,m_max=100):
    '''make_icimf_interpolator:
    Make inverse interpolator for the cumulative initial mass function which 
    allows you to map normalized (0 to 1) cumulative IMF onto mass 
    (m_min to m_max). Note that the interpolator maps onto log10(m).
    
    The cumulative IMF is tabulated on 1000 log-spaced masses. It is built by 
    integrating the IMF over each grid interval and accumulating the results, 
    so for a non-negative IMF the table never decreases.
    
    Args:
        imf (callable) - initial mass function, called as imf(m) with a 
            single mass
        m_min (float) - minimum mass (must be > 0)
        m_max (float) - maximum mass (must be finite and > m_min)
        
    Returns:
        icimf_interp (scipy.interpolate.InterpolatedUnivariateSpline) - icimf
            interpolated spline, evaluated as log10(m) = icimf_interp(u) for 
            u in [0,1]
    
    Raises:
        AssertionError - if m_min <= 0 or m_max is not finite
        ValueError - if m_max <= m_min, or if the IMF does not integrate to 
            a positive finite value over the mass range
    '''
    assert m_min > 0 and np.isfinite(m_max), 'mass range out of bounds'
    if not m_max > m_min:
        raise ValueError('m_max must be greater than m_min')
    ##fi
    ms = np.logspace(np.log10(m_min),np.log10(m_max),1000)
    
    # Integrate interval by interval rather than from m_min every time: each 
    # quadrature covers a short stretch of the IMF and the running sum 
    # cannot decrease for a non-negative IMF. cml_imf[0] is the integral 
    # over an empty range, i.e. zero.
    cml_imf = np.zeros(len(ms))
    for i in range(1,len(ms)):
        cml_imf[i] = cml_imf[i-1]+scipy.integrate.quad(imf,ms[i-1],ms[i])[0]
    ###i
    
    # make sure that the cumulative IMF is normalized to 1
    if not (np.isfinite(cml_imf[-1]) and cml_imf[-1] > 0):
        raise ValueError('IMF must integrate to a positive finite value '
                         'between m_min and m_max')
    ##fi
    cml_imf /= cml_imf[-1]

    # The spline requires strictly increasing abscissae, so the IMF must be 
    # positive throughout the mass range
    return scipy.interpolate.InterpolatedUnivariateSpline(cml_imf,np.log10(ms),
                                                          k=3)
#def

# def _m_to_xi(m,a=1.):
#     out= numpy.divide((m/a-1.),(m/a+1.),where=True^numpy.isinf(r))
#     if numpy.any(numpy.isinf(m)):
#         if hasattr(m,'__len__'):
#             out[numpy.isinf(m)]= 1.
#         else:
#             return 1.
#     return out
# #def

# def _xi_to_m(xi,a=1.):
#     return a*np.divide(1.+xi,1.-xi)
# #def

In [None]:
# Inverse cumulative Chabrier IMF over the default mass range
chabrier_icimf_interp = make_icimf_interpolator(imf=chabrier_imf)

In [None]:
# Inverse cumulative Kroupa IMF over the default mass range
kroupa_icimf_interp = make_icimf_interpolator(imf=kroupa_imf)

In [None]:
fig = plt.figure(figsize=(8,3))
axs = fig.subplots(nrows=1,ncols=2)

icimfs = np.arange(0,1,0.001)
ms = np.logspace(-2,2,100)

# Directly integrated cumulative IMFs, to compare against the interpolators
chabrier_cimf = np.empty(len(ms))
kroupa_cimf = np.empty(len(ms))
for i in range(len(ms)):
    chabrier_cimf[i] = cimf(chabrier_imf,ms[i],a=0.01)
    kroupa_cimf[i] = cimf(kroupa_imf,ms[i],a=0.01)
###i
chabrier_cimf /= chabrier_cimf[-1]
kroupa_cimf /= kroupa_cimf[-1]

axs[0].plot(chabrier_icimf_interp(icimfs), icimfs, color='DodgerBlue',
            linestyle='solid', zorder=1, linewidth=2., label='Chabrier')
axs[0].plot(np.log10(ms), chabrier_cimf, color='Black', linestyle='dashed',
            zorder=2, label='Direct integration')

axs[1].plot(kroupa_icimf_interp(icimfs), icimfs, color='Red',
            linestyle='solid', zorder=1, linewidth=2., label='Kroupa')
axs[1].plot(np.log10(ms), kroupa_cimf, color='Black', linestyle='dashed', 
            zorder=2, label='Direct integration')

axs[0].set_xlabel(r'$\log_{10}(m)$')
axs[0].set_ylabel(r'Normalized cumulative IMF')
axs[1].set_xlabel(r'$\log_{10}(m)$')
axs[1].set_ylabel(r'Normalized cumulative IMF')

# The line labels are only visible once a legend is drawn
axs[0].legend()
axs[1].legend()

# fig.show() only warns under the non-GUI inline backend; plt.show() works 
# with every backend
plt.show()

In [None]:
# Draw samples from the inverse cumulative IMFs and see if they match the IMFs

ms = np.logspace(-2,2,100)
icimfs = np.random.random(int(1e6))

# The interpolators return log10(m)
chabrier_samples = chabrier_icimf_interp(icimfs)
kroupa_samples = kroupa_icimf_interp(icimfs)

fig = plt.figure(figsize=(8,3))
axs = fig.subplots(nrows=1,ncols=2)

# Weights convert logarithmic IMF to linear IMF (parameterized in functions)
# See Chabrier 2003 review for details
axs[0].plot(np.log10(ms), chabrier_imf(ms,A=8e5), color='DodgerBlue',
            label='Chabrier', linestyle='solid', zorder=1)
axs[0].hist(chabrier_samples, bins=20, range=(-2,2), log=True, edgecolor='DodgerBlue',
            weights=1/(np.log(10)*np.power(10,chabrier_samples)), histtype='step',
            label='Samples')

axs[1].plot(np.log10(ms), kroupa_imf(ms,4e5), color='Red', 
            label='Kroupa', linestyle='dashed', zorder=2)
axs[1].hist(kroupa_samples, bins=20, range=(-2,2), log=True, edgecolor='Red',
            weights=1/(np.log(10)*np.power(10,kroupa_samples)), histtype='step',
            label='Samples')

axs[0].set_xlim(-2,2)
axs[1].set_xlim(-2,2)

# hist(log=True) puts the y axis on a log scale, so the plotted quantity is 
# xi itself rather than log10(xi)
axs[0].set_xlabel(r'$\log_{10}(m)$')
axs[0].set_ylabel(r'$\xi$')
axs[1].set_xlabel(r'$\log_{10}(m)$')
axs[1].set_ylabel(r'$\xi$')

# The line labels are only visible once a legend is drawn
axs[0].legend()
axs[1].legend()

fig.subplots_adjust(wspace=0.25)
# fig.show() only warns under the non-GUI inline backend; plt.show() works 
# with every backend
plt.show()

### Create and try out a mass sampling function

In [None]:
def sample_mass(m_tot,imf_type='chabrier',m_min=0.01,m_max=100.):
    '''sample_mass:
    Draw mass samples from an IMF until they add up to about m_tot
    
    The number of samples is first guessed from the average mass of the IMF. 
    More samples are drawn while the total falls short, and trailing samples 
    are dropped if the total overshoots.
    
    Args:
        m_tot (float) - Total mass worth of stars to sample
        imf_type (string) - IMF type, either chabrier or kroupa
        m_min - (float) - minimum mass
        m_max - (float) - maximum mass
    
    Returns:
        ms (np.array) - sampled masses; their cumulative sum does not 
            exceed m_tot
    
    Raises:
        ValueError - if imf_type is neither 'chabrier' nor 'kroupa'
    '''
    # Pick the IMF first so that an unknown name fails with a clear message 
    # instead of an unbound local variable further down
    if imf_type == 'chabrier':
        imf = chabrier_imf
    elif imf_type == 'kroupa':
        imf = kroupa_imf
    else:
        raise ValueError("imf_type must be 'chabrier' or 'kroupa', got "
                         +repr(imf_type))
    ##fi
    
    # Average mass on a linear grid, then the icimf interpolator
    ms_for_avg = np.arange(m_min,m_max,0.01)
    m_avg = np.average(ms_for_avg,weights=imf(ms_for_avg))
    icimf_interp = make_icimf_interpolator(imf,m_min=m_min,m_max=m_max)
    
    # Guess how many samples to draw based on the average mass
    n_samples_guess = int(m_tot/m_avg)
    
    # Now draw first round of samples (the interpolator returns log10(m))
    print('Drawing first samples...')
    icimf_samples = np.random.random(n_samples_guess)
    ms = np.power(10,icimf_interp(icimf_samples))
        
    # Add more samples or take some away depending on where things landed
    while np.sum(ms) < m_tot:
        print('Resampling...')
        n_samples_guess = int((m_tot-np.sum(ms))/m_avg)
        # Less than one average star is missing: close enough
        if n_samples_guess < 1: break
        icimf_samples = np.random.random(n_samples_guess)
        ms = np.append(ms,np.power(10,icimf_interp(icimf_samples)))
    ##wh
    if np.sum(ms) > m_tot:
        print('Removing some samples...')
        # Index of the first sample whose running total exceeds m_tot. The 
        # running total is sorted because masses are positive. Unlike 
        # indexing into np.where(...), searchsorted keeps every sample when 
        # np.sum and np.cumsum disagree in the last bits and no running 
        # total exceeds m_tot.
        n_keep = np.searchsorted(np.cumsum(ms),m_tot,side='right')
        ms = ms[:n_keep]
    ##fi
    
    return ms
#def

In [None]:
mtot = 3e8
# perf_counter is the clock intended for measuring elapsed intervals; unlike 
# time.time() it is not affected by system clock adjustments
t1 = time.perf_counter()
ms = sample_mass(mtot,m_min=0.09,m_max=0.87)
t2 = time.perf_counter()
print('Took '+str(round(t2-t1,1))+' s')

In [None]:
# Fraction of the requested total mass that was actually sampled
ms.sum()/mtot

In [None]:
# Number of sampled masses (ms is 1-D, so its size equals its length)
ms.size

### Success!

### Also save some sample masses with varying sizes

In [None]:
data_dir = '/geir_data/scr/lane/projects/ges-mass/apo_mocks/'

# Save the first 1e6, 1e7 and 1e8 masses. Slicing past the end of ms simply 
# saves however many masses there are.
for filename,n_masses in (('masses_1e6.npy',int(1e6)),
                          ('masses_1e7.npy',int(1e7)),
                          ('masses_1e8.npy',int(1e8))):
    np.save(data_dir+filename, ms[:n_masses])
###i

# Save all masses
np.save(data_dir+'masses.npy', ms)