In [None]:
# ------------------------------------------------------------------------
#
# TITLE - apo_mock_complete
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Use all the functions to generate complete sets of samples
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, glob, subprocess, warnings, dill as pickle, time, gc
from astropy import units as apu

## Matplotlib
# import matplotlib as mpl
from matplotlib import pyplot as plt

## galpy
from galpy import orbit
from galpy import potential
# from galpy import actionAngle as aA
from galpy import df
from galpy.util import _rotate_to_arbitrary_vector

## APOGEE, isochrones, dustmaps
from apogee import select as apsel
from isodist import Z2FEH,FEH2Z
import mwdust
from mwdust.util.extCurves import aebv

## scipy
import scipy.integrate
import scipy.interpolate

## Astropy and healpix
from astropy.coordinates import SkyCoord
import healpy

## APOGEE mocks
import apomock

### Scale parameters
ro = 8.275
vo = 220
zo = 0.0208 # Bennett+ 2019

### Notebook setup
%matplotlib inline
plt.style.use('../../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

### Helper functions

In [None]:

def chabrier_imf(m,k=0.0193,A=1.):
    '''chabrier_imf:
    
    Chabrier initial mass function: a log-normal form for 0.01 <= m <= 1 and 
    a power law of slope -2.3 for m > 1. Masses below 0.01 evaluate to zero.
    
    Args:
        m (float or array) - Mass(es) at which to evaluate the IMF
        k (float) - Normalization of the m > 1 power law. The default 
            (0.0193) equalizes the m<1 and m>1 branches at m=1
        A (float) - Overall multiplicative amplitude
    
    Returns:
        Nm (np.ndarray) - IMF evaluated at each m, multiplied by A
    '''
    # Note: k is no longer overwritten here, so a user-supplied value is used
    a = 2.3 # Slope of the high-mass power law
    
    if not isinstance(m,np.ndarray):
        m = np.atleast_1d(m)
    ##fi
    
    where_m_gt_1 = m>1
    # Log-normal branch excludes m<0.01 so that log10 is never evaluated on 
    # zero/negative masses. NaN masses fall in this branch and stay NaN.
    where_m_lognorm = np.logical_and(~where_m_gt_1,~(m<0.01))
    Nm = np.zeros(len(m))
    Nm[where_m_lognorm] = (0.158/(np.log(10)*m[where_m_lognorm]))\
                          *np.exp(-(np.log10(m[where_m_lognorm])-np.log10(0.08))**2\
                                 /(2*0.69**2))
    Nm[where_m_gt_1] = k*m[where_m_gt_1]**(-a)
    return A*Nm
#def

def kroupa_imf(m,k1=1.):
    '''kroupa_imf:
    
    Kroupa initial mass function, a three-segment broken power law:
    slope -0.3 for 0.01 <= m < 0.08, slope -1.3 for 0.08 <= m < 0.5 and 
    slope -2.3 for m >= 0.5. Masses below 0.01 evaluate to zero.
    
    Args:
        m (float or array) - Mass(es) at which to evaluate the IMF
        k1 (float) - Normalization of the lowest-mass segment; the other 
            two normalizations are derived from it
    
    Returns:
        Nm (np.ndarray) - IMF evaluated at each m
    '''
    if not isinstance(m,np.ndarray):
        m = np.atleast_1d(m)
    ##fi
    
    # Normalizations of the upper segments, chained from k1
    k2 = 0.08*k1
    k3 = 0.5*k2
    
    # Each segment: (lower edge, upper edge or None if open, slope, norm)
    segments = ((0.01, 0.08, 0.3, k1),
                (0.08, 0.5 , 1.3, k2),
                (0.5 , None, 2.3, k3))
    
    Nm = np.empty(len(m))
    for m_lo,m_hi,slope,norm in segments:
        if m_hi is None:
            in_segment = m>=m_lo
        else:
            in_segment = np.logical_and(m>=m_lo,m<m_hi)
        ##ie
        Nm[in_segment] = norm*m[in_segment]**(-slope)
    ###i
    Nm[m<0.01] = 0
    return Nm
#def

def _remove_orbits_outside_footprint(orbs,aposf,field_indx=None,
                                     chunk_size=None):
    '''_remove_orbits_outside_footprint:
    
    Remove stellar samples from outside the APOGEE observational footprint.
    Each sample is matched to its nearest field center and kept if its 
    separation is smaller than that field's radius and larger than the 
    central plate 'hole' (5.5 arcmin is used here). Samples that fall in the 
    hole of their nearest field get a second chance with their 
    second-nearest field. Using field_indx allows for selecting only a 
    subset of the available fields to use.
    
    Args:
        orbs (galpy.orbit.Orbit) - Orbits representing samples
        aposf (apogee.select.*) - APOGEE selection function
        field_indx (array) - Indices of fields to consider. If None then 
            all fields are used
        chunk_size (int) - Currently unused; kept for interface 
            compatibility
    
    Returns:
        fp_indx (array) - Index of samples that lie within the observational
            footprint
        fp_locid (array) - Location IDs of field each sample lies within
    '''
    # Radius of the central plate hole
    hole_radius = 5.5*apu.arcmin
    
    # Account for field_indx, fields we want to consider. Use the same 
    # _apogeeField attribute that is read below for the field coordinates.
    if field_indx is None:
        field_indx = np.arange(0,len(aposf._apogeeField),dtype=int)
    ##fi
    
    # field center coordinates, location IDs, radii
    glon = aposf._apogeeField['GLON'][field_indx]
    glat = aposf._apogeeField['GLAT'][field_indx]
    locids = aposf._locations[field_indx]
    radii = np.zeros(len(field_indx))
    for i in range(len(locids)):
        radii[i] = aposf.radius(locids[i])
    ###i
    
    # Make SkyCoord objects
    aposf_sc = SkyCoord(frame='galactic', l=glon*apu.deg, b=glat*apu.deg)
    orbs_sc = SkyCoord(frame='galactic', l=orbs.ll(), b=orbs.bb())
    
    # First nearest-neighbor match
    indx,sep,_ = orbs_sc.match_to_catalog_sky(aposf_sc)
    indx_radii = radii[indx]
    indx_locid = locids[indx]
    fp_indx = np.where(np.logical_and(sep < indx_radii*apu.deg,
                                      sep > hole_radius))[0]
    fp_locid = indx_locid[fp_indx]
    
    # Second nearest-neighbor match for samples inside plate central holes.
    # Only attempted when there is something to match and a second field 
    # exists to match against.
    where_in_hole = np.where(sep < hole_radius)[0]
    if len(where_in_hole) > 0 and len(locids) > 1:
        indx2,sep2,_ = orbs_sc[where_in_hole].match_to_catalog_sky(
            aposf_sc, nthneighbor=2)
        indx2_radii = radii[indx2]
        indx2_locid = locids[indx2]
        fp_indx2 = np.where(np.logical_and(sep2 < indx2_radii*apu.deg,
                                           sep2 > hole_radius))[0]
        if len(fp_indx2) > 0:
            fp_indx = np.append(fp_indx,where_in_hole[fp_indx2])
            fp_locid = np.append(fp_locid,indx2_locid[fp_indx2])
        ##fi
    ##fi
    
    return fp_indx,fp_locid
#def

def _match_isochrone_to_samples(iso,ms,m_err=0.05,iso_keys=_parsec_1_2_iso_keys):
    '''_match_isochrone_to_samples:
    
    Match the samples to entries in an isochrone according to initial mass
    
    iso_keys must accept the following keys:
    'Mini' -> initial mass key
    
    Args:
        iso (array) - isochrone array
        ms (array) - sample masses
        m_err (float) - Maximum difference in mass between sample and isochrone
            for successful match
        iso_keys (dict) - Dictionary of keys for accessing the isochrone 
            properties, accessible via a common set of strings (see above)
        
    Returns:
        good_match (array) - Indices of ms which found matches in the isochrone 
            array within m_err tolerance
        match_indx (array) - array of matches, length len(good_match), 
            indexing ms into iso
    '''
    # Access initial mass
    mass_initial_key = iso_keys['mass_initial']
    m0 = iso[mass_initial_key]
    
    # Ensure isochrone is sorted by initial mass
    m0_argsort = np.argsort(m0)
    m0_sorted = m0[m0_argsort]
    
    # Search the sorted array for the nearest neighbors (fast)
    m0_mids = m0_sorted[1:] - np.diff(m0_sorted.astype('f'))/2
    idx = np.searchsorted(m0_mids, ms)
    cand_indx = m0_argsort[idx]
    residual = ms - m0_sorted[cand_indx]
    
    good_match = np.where( (np.abs(residual) < m_err) &\
                           (ms < m0[-1]) &\
                           (ms > m0[0])
                          )[0]

    match_indx = np.argsort(m0_argsort)[cand_indx[good_match]]

    np.all(np.abs(ms[good_match]-m0[match_indx]) <= m_err)
    
    return good_match,match_indx
#def

def _sample_r(denspot,n=1,r_min=0.,r_max=np.inf,a=1.):
    '''_sample_r:
    
    Draw radial position samples by inverse-transform sampling: uniform 
    deviates are pushed through the inverse cumulative mass function (iCMF). 
    The iCMF interpolator works in the variable xi, defined as:
    
    .. math:: \\xi = \\frac{r/a-1}{r/a+1}
    
    so that xi is in the range [-1,1], which corresponds to an r range of 
    [0,infinity). The interpolated xi values are converted back to r before 
    being returned.
    
    Args:
        denspot (galpy.potential.Potential) - galpy potential representing
            the density profile. Must be spherical
        n (int) - Number of samples
        r_min (float) - Minimum radius, forwarded to the iCMF interpolator
        r_max (float) - Maximum radius, forwarded to the iCMF interpolator
        a (float) - Scale used in the r <-> xi transformation
        
    Returns:
        r_samples (np.ndarray) - Radial position samples
    '''
    # Interpolator mapping normalized cumulative mass -> xi
    xi_of_icmf = _make_icmf_xi_interpolator(denspot,r_min=r_min,
        r_max=r_max,a=a)
    
    # Uniform deviates stand in for the normalized cumulative mass
    n_draw = int(n)
    uniform_draws = np.random.uniform(size=n_draw)
    r_samples = _xi_to_r(xi_of_icmf(uniform_draws),a=a)
    return r_samples
#def

def _sample_position_angles(n=1):
    '''_sample_position_angles:
    
    Draw galactocentric, spherical angle samples.
    
    Args:
        n (int) - Number of samples
    
    Returns:
        phi_samples (np.ndarray) - Spherical azimuth
        theta_samples (np.ndarray) - Spherical polar angle
    '''
    phi_samples= np.random.uniform(size=n)*2*np.pi
    theta_samples= np.arccos(1.-2*np.random.uniform(size=n))
    return phi_samples,theta_samples
#def

def _make_icmf_xi_interpolator(denspot,r_min=0.,r_max=np.inf,a=1.):
    '''_make_icmf_xi_interpolator:

    Create the interpolator object which maps the normalized cumulative 
    mass function (CMF) onto the variable xi, i.e. the inverse CMF.
    Note - _xi_to_r() must be applied to any output of the interpolator to 
    recover radii.
    Note - xi is defined as:

    .. math:: \\xi = \\frac{r/a-1}{r/a+1}

    so that xi is in the range [-1,1], which corresponds to an r range of 
    [0,infinity)

    Args:
        denspot (galpy.potential.Potential) - galpy potential whose enclosed 
            mass defines the CMF
        r_min (float) - Minimum radius; mass enclosed within it is subtracted
        r_max (float) - Maximum radius; mass enclosed within it normalizes 
            the CMF
        a (float) - Scale used in the r <-> xi transformation

    Returns
        icmf_xi_interpolator (scipy.interpolate.InterpolatedUnivariateSpline)
            - cubic spline of xi as a function of normalized cumulative mass
    '''
    # Regular grid in xi between the bounds, and the matching radii
    xi_lo = _r_to_xi(r_min,a=a)
    xi_hi = _r_to_xi(r_max,a=a)
    xi_grid = np.arange(xi_lo,xi_hi,1e-4)
    r_grid = _xi_to_r(xi_grid,a=a)
    
    # Enclosed mass on the grid; fall back to a per-radius evaluation if the 
    # array call is rejected
    try:
        cmf = potential.mass(denspot,r_grid,use_physical=False)
    except (AttributeError,TypeError):
        cmf = np.array([potential.mass(denspot,r,use_physical=False) 
                        for r in r_grid])
    ##te
    m_total = potential.mass(denspot,r_max,use_physical=False)

    # Remove the mass interior to r_min from both the CMF and its norm
    if r_min > 0:
        m_inner = potential.mass(denspot,r_min,use_physical=False)
        cmf -= m_inner
        m_total -= m_inner
    cmf /= m_total
        
    # Add total mass point
    if np.isinf(r_max):
        xi_grid = np.append(xi_grid,1)
        cmf = np.append(cmf,1)
    icmf_xi_interpolator = scipy.interpolate.InterpolatedUnivariateSpline(
        cmf,xi_grid,k=3)
    return icmf_xi_interpolator
#def

def make_icimf_interpolator(imf,m_min=0.01,m_max=100):
    '''make_icimf_interpolator:
    
    Build the inverse of the cumulative initial mass function as a spline, 
    mapping the normalized (0 to 1) cumulative IMF onto mass 
    (m_min to m_max). Note that the interpolator returns log10(m), not m.
    
    Args:
        imf (callable) - IMF, integrated over mass via cimf()
        m_min (float) - minimum mass (must be > 0)
        m_max (float) - maximum mass (must be finite)
        
    Returns:
        icimf_interp (scipy.interpolate.InterpolatedUnivariateSpline) - icimf
            interpolated spline
    '''
    assert m_min > 0 and np.isfinite(m_max), 'mass range out of bounds'
    
    # Logarithmically spaced mass grid and the IMF integrated from m_min up 
    # to each grid mass
    ms = np.logspace(np.log10(m_min),np.log10(m_max),1000)
    cml_imf = np.array([cimf(imf,m_upper,a=m_min) for m_upper in ms],
                       dtype=float)
    
    # make sure that the cumulative IMF is normalized to 1
    cml_imf = cml_imf/cml_imf[-1]

    icimf_interp = scipy.interpolate.InterpolatedUnivariateSpline(
        cml_imf,np.log10(ms),k=3)
    return icimf_interp
#def

def cimf(f,b,a=0.01,intargs=()):
    '''cimf:
    
    Cumulative of the initial mass function: the integral of f from a to b.
    
    Args:
        f (callable) - IMF to integrate
        b (float) - Upper integration limit
        a (float) - Lower integration limit
        intargs (tuple) - Extra arguments passed through to f
    
    Returns:
        integral (float) - Value of the integral (error estimate is dropped)
    '''
    integral,_abserr = scipy.integrate.quad(f,a,b,args=intargs)
    return integral
#def

def _r_to_xi(r,a=1.):
    '''_r_to_xi:
    
    Convert r to the variable xi
    '''
    out= np.divide((r/a-1.),(r/a+1.),where=True^np.isinf(r))
    if np.any(np.isinf(r)):
        if hasattr(r,'__len__'):
            out[np.isinf(r)]= 1.
        else:
            return 1.
    return out
#def

def _xi_to_r(xi,a=1.):
    '''_xi_to_r:
    
    Convert the variable xi to r
    '''
    return a*np.divide(1.+xi,1.-xi)
#def

def _transform_zvecpa(x,y,z,zvec,pa):
    '''_transform_zvecpa:
    
    Transform coordinates using the axis-angle method. First align the
    z-axis of the coordinate system with a vector (zvec) and then rotate 
    about the new z-axis by an angle (pa).
    
    Args:
        x,y,z (array) - Coordinates
        zvec (list or array) - z-axis to align the new coordinate system. 
            Need not be normalized; it is not modified
        pa (float) - Rotation about the transformed z-axis
        
    Returns:
        x_rot,y_rot,z_rot (array) - Rotated coordinates 
    '''
    pa_rot = np.array([[ np.cos(pa), np.sin(pa), 0.],
                      [-np.sin(pa), np.cos(pa), 0.],
                      [0.         , 0.        , 1.]])

    # Normalize a float copy so that lists and integer arrays are accepted 
    # and the caller's zvec is left untouched
    zvec = np.array(zvec,dtype=float)
    zvec = zvec/np.sqrt(np.sum(zvec**2.))
    zvec_rot = _rotate_to_arbitrary_vector(np.array([[0.,0.,1.]]),
                                           zvec,inv=True)[0]
    R = np.dot(pa_rot,zvec_rot)
    
    # Stack into (N,3), or (3,) for a single point, and apply the rotation
    xyz = np.squeeze(np.dstack([x,y,z]))
    if np.ndim(xyz) == 1:
        xyz_rot = np.dot(R, xyz)
        x_rot,y_rot,z_rot = xyz_rot[0],xyz_rot[1],xyz_rot[2]
    else:
        xyz_rot = np.einsum('ij,aj->ai', R, xyz)
        x_rot,y_rot,z_rot = xyz_rot[:,0],xyz_rot[:,1],xyz_rot[:,2]
    return x_rot,y_rot,z_rot
#def

def _transform_alpha_beta_gamma(x,y,z,alpha,beta,gamma):
    '''_transform_alpha_beta_gamma:
    
    Transform x,y,z coordinates by a yaw-pitch-roll transformation.
    
    Args:
        x,y,z (array) - Coordinates
        alpha (float) - Roll rotation about the x-axis 
        beta (float) - Pitch rotation about the transformed y-axis
        gamma (float) - Yaw rotation around twice-transformed z-axis
        
    Returns:
        x_rot,y_rot,z_rot (array) - Rotated coordinates 
    '''
    # Roll matrix
    Rx = np.zeros([3,3])
    Rx[0,0] = 1
    Rx[1]   = [0           , np.cos(alpha), -np.sin(alpha)]
    Rx[2]   = [0           , np.sin(alpha), np.cos(alpha)]
    # Pitch matrix
    Ry = np.zeros([3,3])
    Ry[0]   = [np.cos(beta), 0            , np.sin(beta)]
    Ry[1,1] = 1
    Ry[2]   = [-np.sin(beta), 0, np.cos(beta)]
    # Yaw matrix
    Rz = np.zeros([3,3])
    Rz[0]   = [np.cos(gamma), -np.sin(gamma), 0]
    Rz[1]   = [np.sin(gamma), np.cos(gamma), 0]
    Rz[2,2] = 1
    R = np.matmul(Rx,np.matmul(Ry,Rz))
    
    xyz = np.squeeze(np.dstack([x,y,z]))
    if np.ndim(xyz) == 1:
        xyz_rot = np.dot(R, xyz)
        x_rot,y_rot,z_rot = xyz_rot[0],xyz_rot[1],xyz_rot[2]
    else:
        xyz_rot = np.einsum('ij,aj->ai', R, xyz)
        x_rot,y_rot,z_rot = xyz_rot[:,0],xyz_rot[:,1],xyz_rot[:,2]
    return x_rot,y_rot,z_rot
#def
    
### Utilities

def load_parsec_isochrone(iso_dir,z,log_age,remove_wd_point=True):
    '''load_parsec_isochrone:
    
    Load a parsec isochrone from a set of old isochrones. The file for the 
    grid metallicity nearest to z is read, and the rows with the log age 
    nearest to log_age (among those present in the file) are returned.
        
    The metallicity grid assumed here is:
    0.0001 <= Z <= 0.0030 in spacing of 0.0001
    which equates roughly to -2.28 <= [FE/H] <= -0.8
    
    Args:
        iso_dir (string) - Directory where the isochrones are located
        z (float) - metallicity (will use nearest)
        log_age (float) - log_age (will use nearest)
        remove_wd_point (bool) - Remove any WD-like points from the end of 
            the isochrone
    
    Returns:
        iso (np.ndarray) - Structured array of isochrone points for the 
            selected metallicity and log age
    '''
    # Find which z to use
    grid_zs = np.arange(0.0001,0.0031,0.0001)
    if z in grid_zs:
        z_load = z
    else:
        z_load = grid_zs[np.argmin(np.abs(z-grid_zs))]
        print('Using z='+str(z_load))
    ##ie
    
    # Get filename
    iso_name = 'parsec1.2-2mass-spitzer-wise-old'
    iso_filename = os.path.join(iso_dir,iso_name,iso_name+\
                                '-Z-{:<06.4}.dat.gz'.format(z_load))
    
    full_iso = np.genfromtxt(iso_filename, dtype=None, names=True, 
                             skip_header=11)
    
    # Find which log Age to use, from the ages actually present in the file
    grid_log_ages = np.unique(full_iso['logAge'])
    if log_age in grid_log_ages:
        log_age_load = log_age
    else:
        log_age_load = grid_log_ages[np.argmin(np.abs(log_age-grid_log_ages))]
        print('Using log age='+str(log_age_load))
    ##ie
    
    # Extract the isochrone
    iso = full_iso[full_iso['logAge']==log_age_load]
    
    # Remove any points that look like WDs
    if remove_wd_point:
        wd_inds = np.zeros(len(iso),dtype=bool)
        # Point-to-point change in H magnitude, computed once
        hmag_step = np.diff(iso['Hmag'])
        ind = int(len(iso)-1)
        # Start at the end and work backwards until we find the TRGB. A point 
        # is WD-like if it is fainter in H than its predecessor and has a 
        # higher logg than the first isochrone point. The ind >= 1 guard 
        # keeps the first point and protects isochrones with < 2 points.
        while ind >= 1 and hmag_step[ind-1] > 0. and\
              iso['logg'][ind] > iso['logg'][0]:
            wd_inds[ind] = True
            ind -= 1
        ##wh
        iso = iso[~wd_inds]
    ##fi
    
    return iso
#def

def make_fake_allstar(iso,iso_match_indx,locid,fe_h):
    '''make_fake_allstar:
    
    Make a numpy structured array with all the fields required by the density 
    fitting algorithm to function properly.
    
    Args:
        iso (array) - Isochrone; must provide 'logg' and 'logTe' fields
        iso_match_indx (array) - Indices which match samples to isochrone points
        locid (array) - location IDs of APOGEE field where each sample lies
        fe_h (array) - [Fe/H] abundance of samples
    
    Returns:
        allstar (np.ndarray) - Structured array with fields LOCATION_ID, 
            LOGG, TEFF and FE_H, one row per sample
    '''
    iso_match = iso[iso_match_indx]
    atype = np.dtype([('LOCATION_ID', 'i4'),
                     ('LOGG', 'f4'),
                     ('TEFF', 'f4'),
                     ('FE_H', 'f4')
                     ])
    allstar = np.empty(len(iso_match_indx), dtype=atype)
    allstar['LOCATION_ID'] = locid
    allstar['LOGG'] = iso_match['logg']
    # The isochrone stores log10 of the effective temperature; TEFF holds 
    # the temperature itself, so undo the logarithm
    allstar['TEFF'] = np.power(10.,iso_match['logTe'])
    allstar['FE_H'] = fe_h
    return allstar
#def

def join_orbs(orbs):
    '''join_orbs:
    
    Join a list of orbit.Orbit objects together into a single orbit.Orbit. 
    All inputs must share the same ro and vo.
    
    Args:
        orbs (list) - orbit.Orbit objects to join
    
    Returns:
        orbs_joined (orbit.Orbit) - Single Orbit holding the phase-space 
            points of all inputs, in order
    
    Raises:
        ValueError - If orbs is empty
        AssertionError - If ro and/or vo differ between the inputs
    '''
    orbs = list(orbs)
    if len(orbs) == 0:
        raise ValueError('orbs must contain at least one orbit.Orbit')
    ##fi
    
    # Scales of the first orbit are the reference for all others
    ro = orbs[0]._ro
    vo = orbs[0]._vo
    
    # Gather the internal phase-space arrays, then join them in one go 
    # rather than growing an array with repeated appends
    vxvv_chunks = []
    for o in orbs:
        assert ro==o._ro and vo==o._vo, 'ro and/or vo do not match'
        vxvv_chunks.append(o._call_internal())
    ###i
    if len(vxvv_chunks) == 1:
        vxvvs = vxvv_chunks[0]
    else:
        vxvvs = np.concatenate(vxvv_chunks,axis=1)
    ##ie
    return orbit.Orbit(vxvvs.T,ro=ro,vo=vo)
#def

### Sampling functions

In [None]:
# Multiplicative factor converting degrees to radians
_DEGTORAD = np.pi/180.

def sample_mass(m_tot,imf_type='chabrier',m_min=0.01,m_max=100.):
    '''sample_mass:
    
    Draw mass samples from an IMF until their summed mass reaches m_tot. The 
    number of draws is estimated from the IMF-weighted average mass, topped 
    up if the total falls short, and truncated at the first sample whose 
    cumulative mass exceeds m_tot.
    
    Args:
        m_tot (float) - Total mass worth of stars to sample
        imf_type (string) - IMF type, either chabrier or kroupa
        m_min - (float) - minimum mass
        m_max - (float) - maximum mass
    
    Returns:
        ms (np.array) - sampled masses
    
    Raises:
        ValueError - If imf_type is not 'chabrier' or 'kroupa'
    '''
    # Fail early and clearly on an unknown IMF rather than hitting an 
    # undefined variable further down
    if imf_type not in ('chabrier','kroupa'):
        raise ValueError("imf_type must be 'chabrier' or 'kroupa', got "+\
                         repr(imf_type))
    ##fi
    if imf_type == 'chabrier':
        imf = chabrier_imf
    else:
        imf = kroupa_imf
    ##ie
    
    # IMF-weighted average mass on a linear grid, and the icimf interpolator
    ms_for_avg = np.arange(m_min,m_max,0.01)
    m_avg = np.average(ms_for_avg,weights=imf(ms_for_avg))
    icimf_interp = make_icimf_interpolator(imf,m_min=m_min,m_max=m_max)
    
    # Guess how many samples to draw based on the average mass
    n_samples_guess = int(m_tot/m_avg)
    
    # Now draw first round of samples (the interpolator returns log10 mass)
    print('Drawing first samples...')
    icimf_samples = np.random.random(n_samples_guess)
    ms = np.power(10,icimf_interp(icimf_samples))
        
    # Add more samples or take some away depending on where things landed
    while np.sum(ms) < m_tot:
        print('Resampling...')
        n_samples_guess = int((m_tot-np.sum(ms))/m_avg)
        if n_samples_guess < 1: break
        icimf_samples = np.random.random(n_samples_guess)
        ms = np.append(ms,np.power(10,icimf_interp(icimf_samples)))
    ##wh
    if np.sum(ms) > m_tot:
        print('Removing some samples...')
        # Keep samples up to, but not including, the first one that pushes 
        # the cumulative mass past m_tot
        ms = ms[:np.where(np.cumsum(ms) > m_tot)[0][0]]
    ##fi
    
    return ms
#def

def sample_positions(denspot,n=1,r_min=0.,r_max=np.inf,a=None,b=None,c=None,
                     zvec=None,pa=None,alpha=None,beta=None,gamma=None,
                     return_orbits=False,ro=8.,vo=220.,zo=0.):
    '''sample_positions:
    
    Draw position samples from the density profile. The density profile 
    must be spherically symmetric. Density profile can be modified to be 
    triaxial using b and c, and then rotated using either an axis-angle 
    (zvec and pa) or yaw-pitch-roll (alpha, beta, gamma) methods. The 
    rotation is only applied when both b and c are supplied.
    
    Args:
        denspot (galpy.potential.Potential) - galpy potential representing
            the density profile. Must be spherical
        n (int) - Number of samples to draw
        r_min (float) - Minimum radius of the samples
        r_max (float) - Maximum radius of the samples
        a (float) - Density profile scale radius (optional, 1 if None)
        b (float) - triaxial y/x scale ratio
        c (float) - triaxial z/x scale ratio
        zvec (list) - z-axis to align the new coordinate system
        pa (float) - Rotation about the transformed z-axis
        alpha (float) - Roll rotation about the x-axis 
        beta (float) - Pitch rotation about the transformed y-axis
        gamma (float) - Yaw rotation around twice-transformed z-axis
        return_orbits (bool) - Return the orbits
        ro (float) - Distance scale passed to orbit.Orbit
        vo (float) - Velocity scale passed to orbit.Orbit
        zo (float) - Solar height passed to orbit.Orbit
    
    Returns:
        orbs (galpy.orbit.Orbit) - orbits (optional, otherwise None)
    '''
    # The radial sampler divides by the scale radius, so fall back to its 
    # own default of 1 when none is given
    if a is None:
        a = 1.
    ##fi
    # Sample count must be an integer for the array constructors below
    n = int(n)
    
    # Draw radial and angular samples
    r_samples = _sample_r(denspot,n=n,r_min=r_min,r_max=r_max,a=a)
    phi_samples,theta_samples = _sample_position_angles(n=n)
    R_samples = r_samples*np.sin(theta_samples)
    z_samples = r_samples*np.cos(theta_samples)
    
    # Apply triaxial scalings and possibly a rotation
    if b is not None and c is not None:
        x_samples = R_samples*np.cos(phi_samples)
        y_samples = R_samples*np.sin(phi_samples)
        y_samples *= b
        z_samples *= c
        if zvec is not None and pa is not None:
            x_samples,y_samples,z_samples = _transform_zvecpa(x_samples,
                y_samples, z_samples, zvec, pa)
        elif alpha is not None and beta is not None and gamma is not None:
            x_samples,y_samples,z_samples = _transform_alpha_beta_gamma(
                x_samples, y_samples, z_samples, alpha, beta, gamma)
        ##ei
        # Back to cylindrical coordinates
        R_samples = np.sqrt(x_samples**2.+y_samples**2.)
        phi_samples = np.arctan2(y_samples,x_samples)
    ##fi
    
    # Make into orbits, with all velocities set to zero
    orbs = orbit.Orbit(vxvv=np.array([R_samples,np.zeros(n),np.zeros(n),
        z_samples,np.zeros(n),phi_samples]).T,ro=ro,vo=vo,zo=zo)
    if return_orbits:
        return orbs
    ##fi
#def

def apply_APOGEE_selection_function(orbs,ms,aposf,iso,dmap,print_stats=False):
    '''apply_APOGEE_selection_function:
    
    Apply the APOGEE selection function to sample data. The samples are 
    filtered in sequence:
    1. Keep samples whose mass matches an isochrone point
    2. Keep samples brighter (in H + distance modulus) than the faintest 
       H limit of any field
    3. Keep samples that lie inside the observational footprint
    4. Keep samples brighter than the faintest H limit of their own field
    5. Compute the H-band extinction from the dust map and keep each sample 
       with probability given by the selection function evaluated at its 
       field, extincted apparent H magnitude and J-Ks color
    
    Args:
        orbs (galpy.orbit.Orbit) - Orbits representing the samples
        ms (np.array) - Masses of the samples
        aposf (apogee.select.*) - APOGEE selection function
        iso (np.array) - Isochrone array; the fields 'Mini', 'Hmag', 'Jmag' 
            and 'Ksmag' are read
        dmap (mwdust.DustMap3D) - Dust map
        print_stats (bool) - Print the number of surviving samples and the 
            timing of each step
        
    Returns:
        orbs (galpy.orbit.Orbit) - Orbits of the surviving samples
        locid (np.array) - Location ID of the field of each surviving sample
        ms (np.array) - Masses of the surviving samples
        iso_match_indx (np.array) - Isochrone index matched to each 
            surviving sample
    '''
    
    # Remove fields without spectroscopic targets
    nspec = np.nansum(aposf._nspec_short,axis=1) +\
            np.nansum(aposf._nspec_medium,axis=1) +\
            np.nansum(aposf._nspec_long,axis=1)
    good_nspec_fields = np.where(nspec>=1.)[0]
    
    # Get info about APOGEE pointings
    # NOTE(review): presumably one row per field and one column per cohort 
    # (short, medium, long) - confirm shapes of the _*_hmax attributes
    aposf_Hmax = np.dstack([aposf._short_hmax,
                            aposf._medium_hmax,
                            aposf._long_hmax])[0]
    
    # Match samples to isochrone entries based on initial mass
    t1 = time.time()
    init_n = len(ms)
    # Tolerance is half the largest gap between consecutive isochrone masses 
    # plus a small padding
    m_err = 1e-4+np.diff(iso['Mini']).max()/2.
    good_iso_match,iso_match_indx = _match_isochrone_to_samples(iso,ms,m_err=m_err)
    # assert np.all(np.abs(ms[good_iso_match]-iso['Mini'][iso_match_indx]<=m_err))
    orbs = orbs[good_iso_match]
    ms = ms[good_iso_match]
    Hmag = iso['Hmag'][iso_match_indx]
    t2 = time.time()
    if print_stats:
        print(str(len(good_iso_match))+'/'+str(init_n)+\
              ' samples have good matches in the isochrone')
        print('Kept '+str(round(100*len(good_iso_match)/init_n,2))+\
              ' % of samples')
        print('Matching samples to isochrone entries took '+\
              str(round(t2-t1,1))+' s')
    ##fi
    init_n = len(ms)
    
    # Remove samples with apparent Hmag below faintest APOGEE Hmax
    t1 = time.time()
    # Distance modulus from the distance in pc. Extinction is not included 
    # in this cut; it is only added when the selection function is evaluated
    dm = 5.*np.log10(orbs.dist().to(apu.pc).value)-5.
    where_good_Hmag1 = np.where(np.nanmax(aposf_Hmax) >\
        Hmag + dm)[0]
    orbs = orbs[where_good_Hmag1]
    ms = ms[where_good_Hmag1]
    dm = dm[where_good_Hmag1]
    Hmag = Hmag[where_good_Hmag1]
    iso_match_indx = iso_match_indx[where_good_Hmag1]
    t2=time.time()
    if print_stats:
        print(str(len(where_good_Hmag1))+'/'+str(init_n)+\
              ' samples are bright enough to be observed')
        print('Kept '+str(round(100*len(where_good_Hmag1)/init_n,2))+\
              ' % of samples')
        print('Removing samples with H-band magnitudes below faintest APOGEE'+\
              ' Hmax limit took '+str(round(t2-t1,1))+' s')
    ##fi
    init_n = len(ms)
        
    # Remove samples that lie outside the APOGEE observational footprint, 
    # considering only fields with spectroscopic targets
    t1 = time.time()
    fp_indx,locid = _remove_orbits_outside_footprint(orbs,aposf,good_nspec_fields)
    orbs = orbs[fp_indx]
    ms = ms[fp_indx]
    dm = dm[fp_indx]
    Hmag = Hmag[fp_indx]
    iso_match_indx = iso_match_indx[fp_indx]
    t2 = time.time()
    if print_stats:        
        print('Removing samples outside observational footprint took '+\
              str(round(t2-t1,1))+' s')
        print(str(len(fp_indx))+'/'+str(init_n)+' samples found within'+\
              ' observational footprint')
        print('Kept '+str(round(100*len(fp_indx)/init_n,2))+\
              ' % of samples')
    ##fi
    init_n = len(ms)
    
    # Remove samples with apparent Hmag below faintest Hmax on field-by-field 
    # basis
    t1 = time.time()
    field_Hmax = np.nanmax(aposf_Hmax, axis=1)
    # Position of each sample's location ID within aposf._locations, found 
    # by broadcasting an (N,1) column against the location array
    locid_inds = np.where(locid.reshape(locid.size, 1) == aposf._locations)[1]
    Hmax = field_Hmax[locid_inds]
    where_good_Hmag2 = np.where(Hmax > Hmag + dm)[0]
    # pdb.set_trace()
    # Access values with good H-magnitudes
    orbs = orbs[where_good_Hmag2]
    locid = locid[where_good_Hmag2]
    ms = ms[where_good_Hmag2]
    dm = dm[where_good_Hmag2]
    Hmag = Hmag[where_good_Hmag2]
    iso_match_indx = iso_match_indx[where_good_Hmag2]
    Jmag = iso['Jmag'][iso_match_indx]
    Ksmag = iso['Ksmag'][iso_match_indx]
    t2=time.time()
    if print_stats:
        print('Removing samples with H-band magnitudes outside '+\
              'observational limits took '+str(round(t2-t1,1))+' s')
        print(str(len(where_good_Hmag2))+'/'+str(init_n)+\
                  ' samples are bright enough to be observed')
        print('Kept '+str(round(100*len(where_good_Hmag2)/init_n,2))+\
              ' % of samples')
    ##fi
    init_n = len(ms)
        
    # Get lbIndx for the dust map
    t1 = time.time()
    gl = orbs.ll(use_physical=True).value
    gb = orbs.bb(use_physical=True).value
    # NOTE(review): dist is not used again in this function
    dist = np.atleast_2d(orbs.dist(use_physical=True).value).T
    # Information about the dust map
    dmap_nsides = np.array(dmap._nsides)
    nside_pix = np.zeros((len(orbs),len(dmap_nsides)))
    nside_arr = np.repeat(dmap_nsides[:,np.newaxis],len(orbs),axis=1).T
    # Nested-scheme healpix pixel of each sample at every Nside of the map
    for i in range(len(dmap_nsides)):
        nside_pix[:,i] = healpy.pixelfunc.ang2pix(dmap_nsides[i],
                                                  (90.-gb)*_DEGTORAD,
                                                   gl*_DEGTORAD, nest=True)
    # Calculate healpix u: a single integer combining pixel index and Nside 
    # (index + 4*Nside^2), for both the dust map pixels and the samples
    dmap_hpu = (dmap._pix_info['healpix_index'] + 4*dmap._pix_info['nside']**2.).astype(int)
    hpu = (nside_pix + 4*nside_arr**2).astype(int)
    # Use searchsorted to get the indices
    dmap_hpu_argsort = np.argsort(dmap_hpu)
    dmap_hpu_sorted = dmap_hpu[dmap_hpu_argsort]
    hpu_indx_sorted = np.searchsorted(dmap_hpu_sorted,hpu)
    hpu_indx = np.take(dmap_hpu_argsort, hpu_indx_sorted, mode="clip")
    # Mask (sample, Nside) candidates that are not exact matches in the map
    hpu_mask = dmap_hpu[hpu_indx] != hpu
    hpu_ma = np.ma.array(hpu_indx, mask=hpu_mask)
    # Check if somehow a sample has no matches or more than one match. The 
    # former should not happen ever, the latter has been observed.
    hpu_ma_sum = np.sum(~hpu_ma.mask,axis=1)
    lbIndx = np.zeros(len(orbs))
  
    # lbIndx = hpu_ma.data[~hpu_ma.mask]
    if np.any(hpu_ma_sum==0):
        # Need to code in a solution
        raise RuntimeError('A sample did not find a lbIndx in the dust map')
    hpu_ma_multi = hpu_ma_sum>1
    if np.any(hpu_ma_multi):
        print('Warning: At least one sample has more than one match at'\
              +'different Nside, choosing the highest resolution matches')
        # Samples with exactly one match take that match directly
        lbIndx[~hpu_ma_multi] = hpu_ma.data[~hpu_ma_multi][~hpu_ma.mask[~hpu_ma_multi]]
        where_hpu_ma_multi = np.where(hpu_ma_multi)[0]
        for i in range(len(where_hpu_ma_multi)):
            multi_ind = where_hpu_ma_multi[i]
            # Choose the highest resolution index
            # NOTE(review): this takes the first unmasked match, which is the 
            # highest resolution only if dmap._nsides is ordered that way - 
            # confirm
            lbIndx[multi_ind] = hpu_ma.data[multi_ind][~hpu_ma.mask[multi_ind]][0]
    else:
        lbIndx = hpu_ma.data[~hpu_ma.mask]
    
    t2 = time.time()
    if print_stats:
        print('Getting lbIndx took '+str(round(t2-t1,1))+' s')
    ##fi
        
    # Compute AH, grouping samples by dust map pixel so that one 
    # interpolator is built per unique pixel
    t1 = time.time()
    unique_lbIndx = np.unique(lbIndx).astype(int)
    AH = np.zeros(len(orbs))
    for i in range(len(unique_lbIndx)):
        # First find which samples have this lbIndx
        where_unique = np.where(lbIndx == unique_lbIndx[i])[0]
        # Get the dust map interpolation data for this lbIndx
        dmap_interp_data = scipy.interpolate.InterpolatedUnivariateSpline(
            dmap._distmods, dmap._best_fit[unique_lbIndx[i]], k=dmap._interpk)
        # Calculate AH: the interpolated dust map value at each sample's 
        # distance modulus, scaled by the extinction-curve factor
        eBV_to_AH = mwdust.util.extCurves.aebv(dmap._filter,sf10=dmap._sf10)
        AH[where_unique] = dmap_interp_data(dm[where_unique])*eBV_to_AH
    ###i
    t2 = time.time()
#     if print_stats:
#         print('Getting AH took '+str(t2-t1)+' s')
#     ##fi
    
    # Apply the selection function: each sample is kept if the selection 
    # function value exceeds a uniform random draw
    t1 = time.time()
    sf_keep_indx = np.zeros(len(orbs),dtype=bool)
    for i in range(len(orbs)):
        random_n = np.random.random(size=1)[0]
        # Extincted apparent H magnitude and isochrone J-Ks color
        _H = Hmag[i] + dm[i] + AH[i]
        _JK0 = Jmag[i] - Ksmag[i]
        compare_n = aposf(locid[i],_H,_JK0)
        if compare_n > random_n: sf_keep_indx[i] = True
    t2 = time.time()
    if print_stats:
        print('Applying selection function took '+str(round(t2-t1,1))+' s')
        print(str(np.sum(sf_keep_indx))+'/'+str(init_n)+\
                  ' samples survive the selection function')
        print('Kept '+str(round(100*np.sum(sf_keep_indx)/init_n,2))+\
              ' % of samples')
    ##fi
    
    orbs = orbs[sf_keep_indx]
    locid = locid[sf_keep_indx]
    ms = ms[sf_keep_indx]
    iso_match_indx = iso_match_indx[sf_keep_indx]
    
    return orbs, locid, ms, iso_match_indx
#def

### Load selection function, dust map, isochrone, mock orbits and masses

In [None]:
# Load the pickled APOGEE selection function object (`pickle` is dill here,
# see the import cell). Unpickling runs arbitrary code, so only point this
# at trusted project files.
aposf_data_dir = ('/geir_data/scr/lane/projects/ges-mass/data/gaia_apogee/'
                  'apogee_dr16_l33_gaia_dr2/')
aposf_filename = aposf_data_dir+'apogee_SF.dat'
with open(aposf_filename,'rb') as sf_file:
    aposf = pickle.load(sf_file)
##wi

In [None]:
# Combined19 dust map from mwdust, set up with the '2MASS H' filter. It is
# handed to the selection-function routines in the cells below as `dmap`.
dmap = mwdust.Combined19(filter='2MASS H') # dustmap from mwdust

In [None]:
# Map from the generic isochrone quantity names (keys) to the column names
# used in the isochrone grid (values). Passed to APOGEEMock.load_isochrone
# as `iso_keys` further down.
_parsec_1_2_iso_keys = {'mass_initial':'Mini',
                        'z_initial':'Zini',
                        'log_age':'logAge',
                        'jmag':'Jmag',
                        'hmag':'Hmag',
                        'ksmag':'Ksmag',
                        'logg':'logg',
                        'logteff':'logTe'
                        }

# Metallicity and log-age of the single isochrone to pull out of the grid
z = 0.0010
log_age = 10.0
iso_grid = np.load('/geir_data/scr/lane/projects/ges-mass/data/gaia_apogee/apogee_dr16_l33_gaia_dr2/iso_grid.npy')

# Match grid points with a floating-point tolerance rather than exact `==`,
# so the selection does not depend on how the grid values were rounded or
# stored. The default np.isclose tolerances are far smaller than any
# plausible grid spacing -- NOTE(review): confirm against the grid.
iso_select = np.isclose(iso_grid['Zini'],z) &\
             np.isclose(iso_grid['logAge'],log_age)
iso = iso_grid[iso_select]

# Keep only points with logL > -9
# NOTE(review): presumably removes sentinel/placeholder rows - confirm
iso = iso[iso['logL']>-9.]

# An empty isochrone would otherwise only surface as an obscure failure in
# a later cell (e.g. iso['Mini'].max()), so stop here with a clear message.
if len(iso) == 0:
    raise ValueError('No isochrone points found for Zini='+str(z)+\
                     ', logAge='+str(log_age))
##fi

In [None]:
# Store the selected isochrone on disk; the file name is tagged with the
# metallicity and log-age that were used for the selection
iso_filename = './data/iso_z_'+str(z)+'_log_age_'+str(log_age)+'.npy'
np.save(iso_filename,iso)

### Try the APOGEEMock class

In [None]:
# Density profile for the mock: spherical power law with index alpha, built
# with the notebook-wide ro/vo scales
alpha = 2.5
denspot_args = {'alpha':alpha}
denspot = potential.PowerSphericalPotential(**denspot_args, ro=ro, vo=vo)

# Sampling configuration. These are identical for every chunk, so they are
# set once here instead of on every loop iteration.
m_tot = 2e6      # total mass to sample, split evenly over the chunks
n_chunk = 2      # number of chunks the sampling is divided into
m_min = 0.08     # minimum mass handed to sample_masses
r_min = 2./ro    # radial sampling range, scaled by ro
r_max = 70./ro

# Per-chunk results, merged into `fallstar` / `orbs` after the loop
fallstar_chunks = []
orbs_chunks = []

for i in range(n_chunk):
    mock = apomock.APOGEEMock(denspot, ro=ro, vo=vo)
    mock.load_isochrone(iso=iso, iso_keys=_parsec_1_2_iso_keys)

    print('Sampling masses')
    t1 = time.time()
    mock.sample_masses(m_tot/n_chunk, m_min=m_min)
    t2 = time.time()
    print('Took '+str(t2-t1)+' s')

    print('Sampling positions')
    t1 = time.time()
    mock.sample_positions(r_min=r_min, r_max=r_max)
    t2 = time.time()
    print('Took '+str(t2-t1)+' s')

    print('Applying selection function')
    t1 = time.time()
    mock.apply_selection_function(aposf, dmap)
    t2 = time.time()
    print('Took '+str(t2-t1)+' s')

    fallstar_chunks.append( mock.make_allstar() )
    orbs_chunks.append( mock.orbs )
    # Keep the final mock alive: the next cell calls mock._write_mock_summary()
    if i+1 < n_chunk:
        del mock
    gc.collect()

# Merge the chunks. Previously `fallstar` and `orbs` stayed Python lists, so
# the sample count below reported the number of chunks, and the later cells
# (orbs.ll(), fallstar['LOCATION_ID'], np.save) were handed lists instead of
# a single array / Orbit. This mirrors how the batch results are joined
# further down in the notebook.
# NOTE(review): assumes make_allstar() returns a numpy (structured) array and
# that mock.orbs is an Orbit accepted by the join_orbs helper - confirm
fallstar = np.concatenate(fallstar_chunks)
orbs = join_orbs(orbs_chunks)

print('Found '+str(len(fallstar))+' samples')

In [None]:
# Build the summary for the mock. Only the `mock` from the final chunk is
# still alive here (earlier ones are deleted inside the sampling loop).
# NOTE(review): presumably the summary describes that last chunk only - confirm
summary = mock._write_mock_summary()

In [None]:
# Dump the summary to a plain-text file, one entry per line
with open('test_summary.txt','w') as summary_file:
    summary_file.writelines(entry+'\n' for entry in summary)


In [None]:
# Persist the mock under ./data/mock_<mock_number>/ : the orbits go into a
# pickle and the allstar data into a numpy file
mock_number = '5'
mock_dir = './data/mock_'+mock_number
with open(mock_dir+'/orbs.pkl','wb') as orbs_file:
    pickle.dump(orbs,orbs_file)
##wi
np.save(mock_dir+'/allstar.npy',fallstar)

In [None]:
# Cut bulge fields. Within 20 degrees of the galactic center
glon = orbs.ll().value
glat = orbs.bb().value
in_bulge = ((glon > 340.) | (glon < 20.)) & (np.fabs(glat) < 20.)
omask_bulge = ~in_bulge

# Cut fields containing enhancements of globular cluster stars
gc_locid = [2011,4353,5093,5229,5294,5295,5296,5297,5298,5299,5300,5325,5328,
            5329,5438,5528,5529,5744,5801]
omask_gc = np.isin(fallstar['LOCATION_ID'],gc_locid,invert=True)

# Keep only stars with 1 < log(g) < 3
fallstar_logg = fallstar['LOGG']
omask_logg = (fallstar_logg > 1.) & (fallstar_logg < 3.)

# A sample has to pass all three cuts
omask = omask_bulge & omask_gc & omask_logg
n_unmasked = np.sum(omask)
print(str(n_unmasked)+' samples survived observational masking')

np.save('./data/mock_'+mock_number+'/omask.npy',omask)

## Do it the longer way, step by step, with the old routines

In [None]:
# Draw stellar masses adding up to a total of mtot, between 0.08 and the
# largest initial mass on the isochrone; report the wall-clock time
mtot = 1e8
m_max_iso = iso['Mini'].max()
t1 = time.time()
ms = sample_mass(mtot,m_min=0.08,m_max=m_max_iso)
t2 = time.time()
elapsed = round(t2-t1,1)
print('Took '+str(elapsed)+' s')

### Sample Positions

In [None]:
t1 = time.time()

# Density profile: spherical power law with index alpha, with physical-unit
# output switched off
# rc = 30./ro
alpha = 2.0
denspot_args = {'alpha':alpha}
denspot = potential.PowerSphericalPotential(**denspot_args)
potential.turn_physical_off(denspot)

# Radial sampling range, scaled by ro
r_min, r_max = 2./ro, 70./ro

# Shape parameters handed to sample_positions (all unity here)
a, b, c = 1., 1., 1.

# Rotation settings. They are defined but not used by the call below, which
# passes zvec=None and pa=None.
rot_zvec = np.array([0.,0.,1.])
rot_pa = 0.
rot_alpha = rot_beta = rot_gamma = 0.

# One position per sampled mass, returned as orbits
orbs = sample_positions(denspot,
                        n=len(ms),
                        r_min=r_min, r_max=r_max,
                        a=a, b=b, c=c,
                        zvec=None, pa=None,
                        return_orbits=True,
                        ro=ro, vo=vo, zo=zo)
t2 = time.time()
elapsed = round(t2-t1,1)
print('Took '+str(elapsed)+' s')

In [None]:
# Run the samples through the APOGEE selection function (printing its
# statistics). The routine returns the surviving orbits, their location
# IDs, their masses and their isochrone match indices.
orbs_sf,locid,ms_sf,iso_match_indx = apply_APOGEE_selection_function(
    orbs, ms, aposf, iso, dmap, print_stats=True)
n_kept = len(orbs_sf)
print('Found '+str(n_kept)+' samples')

In [None]:
# Build the fake allstar data for the survivors; every sample is given the
# same [Fe/H], converted from the isochrone metallicity z
feh_single = Z2FEH(z)
feh_all = np.ones(len(orbs_sf))*feh_single
fallstar = make_fake_allstar(iso,iso_match_indx,locid,feh_all)

In [None]:
# Persist the selection-function survivors under ./data/mock_<mock_number>/ :
# the orbits go into a pickle and the allstar data into a numpy file
mock_number = '5'
mock_dir = './data/mock_'+mock_number
with open(mock_dir+'/orbs.pkl','wb') as orbs_file:
    pickle.dump(orbs_sf,orbs_file)
##wi
np.save(mock_dir+'/allstar.npy',fallstar)

### Apply the observational masks for fitting
Exclude stars which:
1. Lie within a 20 degree square around the bulge
2. Lie within a field containing a globular cluster
3. Have log(g) > 3 or log(g) < 1

In [None]:
# Cut bulge fields. Within 20 degrees of the galactic center
glon_sf = orbs_sf.ll().value
glat_sf = orbs_sf.bb().value
toward_center = (glon_sf > 340.) | (glon_sf < 20.)
near_plane = np.fabs(glat_sf) < 20.
omask_bulge = ~(toward_center & near_plane)

# Cut fields containing enhancements of globular cluster stars
gc_locid = [2011,4353,5093,5229,5294,5295,5296,5297,5298,5299,5300,5325,5328,
            5329,5438,5528,5529,5744,5801]
omask_gc = np.isin(fallstar['LOCATION_ID'],gc_locid,invert=True)

# Keep only stars with 1 < log(g) < 3
fallstar_logg = fallstar['LOGG']
omask_logg = (fallstar_logg > 1.) & (fallstar_logg < 3.)

# A sample has to pass all three cuts
omask = omask_bulge & omask_gc & omask_logg
n_unmasked = np.sum(omask)
print(str(n_unmasked)+' samples survived observational masking')

np.save('./data/mock_'+mock_number+'/omask.npy',omask)

### Batch compute for large numbers of samples

In [None]:
# Total mass to sample, mass per batch, and the resulting number of batches
mtot = 2e8
batch_m = 5e7
batch_n = int(mtot / batch_m)

# Density profile: spherical power law with index alpha, with physical-unit
# output switched off
# rc = 30./ro
alpha = 2.0
denspot_args = {'alpha':alpha}
denspot = potential.PowerSphericalPotential(**denspot_args)
potential.turn_physical_off(denspot)

# Position sampling: radial range scaled by ro, shape parameters a/b/c and
# the rotation settings (only rot_zvec and rot_pa are passed on below)
r_min, r_max = 2./ro, 50./ro
a, b, c = 1., 0.8, 0.5
rot_zvec = np.array([0.,0.,1.])
rot_pa = np.pi/4
rot_alpha = rot_beta = rot_gamma = 0.

# Selection-function survivors, one list entry per batch
orbs_b, ms_b, locid_b, iso_match_indx_b = [], [], [], []

for i in range(batch_n):
    # Mass handled by this batch
    batch_mtot = mtot / batch_n
    batch_label = str(i+1)
    print('\nDoing batch '+batch_label+', mass: '+str(batch_mtot/1e8)+'e8 Msun')

    # Step 1: masses
    t1 = time.time()
    ms = sample_mass(batch_mtot,m_min=0.09,m_max=0.87)
    t2 = time.time()
    print('Masses took '+str(round(t2-t1,1))+' s')

    # Step 2: one position (returned as an orbit) per sampled mass
    t1 = time.time()
    orbs = sample_positions(denspot,
                            n=len(ms),
                            r_min=r_min, r_max=r_max,
                            a=a, b=b, c=c,
                            zvec=rot_zvec, pa=rot_pa,
                            return_orbits=True,
                            ro=ro, vo=vo, zo=zo)
    t2 = time.time()
    print('Orbits took '+str(round(t2-t1,1))+' s')

    # Step 3: APOGEE selection function (its own statistics are silenced)
    t1 = time.time()
    orbs_sf,locid,ms_sf,iso_match_indx = apply_APOGEE_selection_function(
        orbs, ms, aposf, iso, dmap, print_stats=False)
    t2 = time.time()
    print('SF application took '+str(round(t2-t1,1))+' s')

    # Stash the survivors of this batch
    orbs_b.append(orbs_sf)
    ms_b.append(ms_sf)
    locid_b.append(locid)
    iso_match_indx_b.append(iso_match_indx)

    # Release the full pre-selection samples before the next batch
    del orbs,ms
    gc.collect()

In [None]:
# Stitch the per-batch survivors back together: plain arrays are
# concatenated along their first axis, orbits go through join_orbs
ms_sf = np.concatenate(ms_b,axis=0)
locid = np.concatenate(locid_b,axis=0)
iso_match_indx = np.concatenate(iso_match_indx_b,axis=0)
orbs_sf = join_orbs(orbs_b)

In [None]:
# Build the fake allstar data for the joined batches; every sample is given
# the same [Fe/H], converted from the isochrone metallicity z
feh_single = Z2FEH(z)
feh_all = np.ones(len(orbs_sf))*feh_single
fallstar = make_fake_allstar(iso,iso_match_indx,locid,feh_all)

In [None]:
# Persist the batch-computed mock under ./data/mock_<mock_number>/ : the
# orbits go into a pickle and the allstar data into a numpy file
mock_number = '4'
mock_dir = './data/mock_'+mock_number
with open(mock_dir+'/orbs.pkl','wb') as orbs_file:
    pickle.dump(orbs_sf,orbs_file)
##wi
np.save(mock_dir+'/allstar.npy',fallstar)

In [None]:
# Cut bulge fields. Within 20 degrees of the galactic center
glon_sf = orbs_sf.ll().value
glat_sf = orbs_sf.bb().value
in_bulge = ((glon_sf > 340.) | (glon_sf < 20.)) & (np.fabs(glat_sf) < 20.)
omask_bulge = ~in_bulge

# Cut fields containing enhancements of globular cluster stars
gc_locid = [2011,4353,5093,5229,5294,5295,5296,5297,5298,5299,5300,5325,5328,
            5329,5438,5528,5529,5744,5801]
omask_gc = np.isin(fallstar['LOCATION_ID'],gc_locid,invert=True)

# Keep only stars with 1 < log(g) < 3
fallstar_logg = fallstar['LOGG']
omask_logg = (fallstar_logg > 1.) & (fallstar_logg < 3.)

# A sample has to pass all three cuts
omask = omask_bulge & omask_gc & omask_logg

np.save('./data/mock_'+mock_number+'/omask.npy',omask)