In [None]:
# ------------------------------------------------------------------------
#
# TITLE - apo_mock_sf
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Create the framework to generate mock APOGEE realizations of the stellar 
halo that can be used to test fitting routines. Investigate selection function
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, glob, subprocess, warnings, dill as pickle, time
from astropy import units as apu

## Matplotlib
# import matplotlib as mpl
from matplotlib import pyplot as plt

## galpy
from galpy import orbit
from galpy import potential
# from galpy import actionAngle as aA
from galpy import df
from galpy.util import _rotate_to_arbitrary_vector

## APOGEE, isochrones, dustmaps
from apogee import select as apsel
from isodist import Z2FEH,FEH2Z
import mwdust
from mwdust.util.extCurves import aebv

## scipy
import scipy.integrate
import scipy.interpolate

## Astropy and healpix
from astropy.coordinates import SkyCoord
import healpy

## Project-specific
sys.path.insert(0,'../../../src/')
# import sample_project.module as project_module

### Scale parameters
ro = 8.275
vo = 220
zo = 0.0208 # Bennett+ 2019

### Notebook setup
%matplotlib inline
plt.style.use('../../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

## Plan
For the selection function portion of the algorithm first winnow down the number 
of samples as much as possible:
1. Remove fields without any spectroscopic targets
2. Remove samples outside the observational footprint
3. Use the supplied isochrones to remove samples whose distance-modulus-adjusted H-band magnitudes are fainter than the survey limit

Then actually apply the main part of the algorithm
1. Calculate the H-band extinction using the dust map
2. Apply the selection function to the reddening-adjusted samples

In [None]:
# # These are the usual ipython objects, including this one you are creating
# ipython_vars = ['In', 'Out', 'exit', 'quit', 'get_ipython', 'ipython_vars']

# # Get a sorted list of the objects and their sizes
# sorted([(x, sys.getsizeof(globals().get(x))) for x in dir() \
#     if not x.startswith('_') and x not in sys.modules \
#     and x not in ipython_vars], key=lambda x: x[1], reverse=True)

### Functions

In [None]:
_parsec_1_2_iso_keys = {'mass_initial':'Mini',
                        }

In [None]:
# def _remove_dead_stars(ms,iso,iso_keys=_parsec_1_2_iso_keys):
#     '''_remove_dead_stars:
    
#     Remove stars with masses greater than the highest initial mass present 
#     in the isochrone.
    
#     Args:
#         ms (array) - array of masses
#         iso (array) - isochrone array
#         iso_keys (array) - 
    
#     Returns:
#         as_indx (array) - index of 'alive' stars to consider
#     '''
#     return np.where(ms < np.nanmax(iso[iso_keys['mass_initial']]))[0]
# #def 
    

def _remove_orbits_outside_footprint(orbs,aposf,field_indx=None,
                                     chunk_size=None):
    '''_remove_orbits_outside_footprint:
    
    Find the samples that lie within the APOGEE observational footprint.
    A sample lies within a field if its angular separation from the field 
    center is smaller than the field radius and larger than the radius of 
    the central plate 'hole' (5.5 arcminutes is used here). Samples that 
    fall in the hole of their nearest field are tested a second time against 
    their second-nearest field. Using field_indx allows for selecting only a 
    subset of the available fields to use.
    
    Args:
        orbs (galpy.orbit.Orbit) - Orbits representing samples. ll() and 
            bb() must return quantities with angular units
        aposf (apogee.select.*) - APOGEE selection function
        field_indx (array) - Indices of fields to consider, if None then 
            use all fields [None]
        chunk_size (int) - Currently unused, kept so that existing calls 
            keep working [None]
    
    Returns:
        fp_indx (array) - Indices into orbs of the samples that lie within 
            the observational footprint
        fp_locid (array) - Location ID of the field each of those samples 
            lies within, same length as fp_indx
    '''
    # Radius of the central hole of each plate
    hole_radius = 5.5*apu.arcmin
    
    # Account for field_indx, fields we want to consider
    if field_indx is None:
        field_indx = np.arange(0,len(aposf._apogeeField),dtype=int)
    ##fi
    
    # field center coordinates, location IDs, radii
    glon = aposf._apogeeField['GLON'][field_indx]
    glat = aposf._apogeeField['GLAT'][field_indx]
    locids = aposf._locations[field_indx]
    radii = np.array([aposf.radius(locid) for locid in locids],dtype=float)
    
    # Nothing can be matched without samples or without fields
    if len(orbs) == 0 or len(locids) == 0:
        return np.array([],dtype=int),locids[:0]
    ##fi
    
    # Make SkyCoord objects
    aposf_sc = SkyCoord(frame='galactic', l=glon*apu.deg, b=glat*apu.deg)
    orbs_sc = SkyCoord(frame='galactic', l=orbs.ll(), b=orbs.bb())
    
    # First nearest-neighbor match
    indx,sep,_ = orbs_sc.match_to_catalog_sky(aposf_sc)
    indx_radii = radii[indx]
    indx_locid = locids[indx]
    fp_indx = np.where(np.logical_and(sep < indx_radii*apu.deg,
                                      sep > hole_radius))[0]
    fp_locid = indx_locid[fp_indx]
    
    # Second nearest-neighbor match for samples inside plate central holes.
    # This needs at least two fields and at least one sample in a hole.
    where_in_hole = np.where(sep < hole_radius)[0]
    if len(where_in_hole) > 0 and len(locids) > 1:
        indx2,sep2,_ = orbs_sc[where_in_hole].match_to_catalog_sky(
            aposf_sc,nthneighbor=2)
        indx2_radii = radii[indx2]
        indx2_locid = locids[indx2]
        fp_indx2 = np.where(np.logical_and(sep2 < indx2_radii*apu.deg,
                                           sep2 > hole_radius))[0]
        if len(fp_indx2) > 0:
            fp_indx = np.append(fp_indx,where_in_hole[fp_indx2])
            fp_locid = np.append(fp_locid,indx2_locid[fp_indx2])
        ##fi
    ##fi
    
    return fp_indx,fp_locid
#def

def _match_isochrone_to_samples(iso,ms,m_err=0.05,iso_keys=_parsec_1_2_iso_keys):
    '''_match_isochrone_to_samples:
    
    Match each sample to the isochrone entry with the nearest initial mass. 
    The isochrone does not need to be sorted by initial mass.
    
    iso_keys must provide the following key:
    'mass_initial' -> name of the initial mass field of iso (e.g. 'Mini')
    
    Args:
        iso (array) - isochrone array
        ms (array) - sample masses
        m_err (float) - Maximum difference in mass between sample and isochrone
            for successful match
        iso_keys (dict) - Dictionary of keys for accessing the isochrone 
            properties, accessible via a common set of strings (see above)
        
    Returns:
        good_match (array) - Indices of ms which found matches in the isochrone 
            array within m_err tolerance and which lie strictly between the 
            smallest and largest isochrone initial mass
        match_indx (array) - array of matches, length len(good_match), 
            indexing ms into iso (in the original ordering of iso)
    '''
    ms = np.asarray(ms)
    
    # Access initial mass
    m0 = np.asarray(iso[iso_keys['mass_initial']])
    if m0.size == 0:
        # An empty isochrone cannot match anything
        return np.array([],dtype=int),np.array([],dtype=int)
    ##fi
    
    # Work on a copy sorted by initial mass, m0_argsort maps sorted positions 
    # back onto positions in iso
    m0_argsort = np.argsort(m0)
    m0_sorted = m0[m0_argsort]
    
    # Search the midpoints between neighboring isochrone masses: the 
    # insertion position of a sample is the sorted position of its nearest 
    # isochrone mass (fast)
    m0_mids = m0_sorted[1:] - np.diff(m0_sorted)/2.
    idx = np.searchsorted(m0_mids, ms)
    residual = ms - m0_sorted[idx]
    
    good_match = np.where( (np.abs(residual) < m_err) &\
                           (ms < m0_sorted[-1]) &\
                           (ms > m0_sorted[0])
                          )[0]
    
    # Convert sorted positions back into indices of the original isochrone
    match_indx = m0_argsort[idx[good_match]]
    
    return good_match,match_indx
#def

def _remove_faint_Hmag_stars(Hmag,aposf):
    '''_remove_faint_Hmag_stars:
    
    Remove stars with Hmag below the limit for each individual field
    
    NOTE(review): this function looks unfinished. The body never uses the 
    Hmag argument, reads the names im_orbs, im_Hmag, im_locid and im_ms which 
    are not parameters (they must exist as notebook globals for the call to 
    succeed), and has no return statement, so the gh_* results are discarded. 
    The same field-by-field cut is implemented inline in 
    apply_APOGEE_selection_function.
    
    Args:
        Hmag (array) - H-band magnitudes of the samples (currently unused)
        aposf (apogee.select.*) - APOGEE selection function
    
    Returns:
        None
    '''
    # Find maximum Hmag values for all cohorts in each field and compare to the 
    # distance modulus-adjusted value for each sample
    field_Hmax = np.nanmax(np.dstack([aposf._short_hmax,
                                      aposf._medium_hmax,
                                      aposf._long_hmax])[0],axis=1)
    # Distance modulus from the distance expressed in pc
    dm = 5.*np.log10(im_orbs.dist().to(apu.pc).value)-5.
    im_Hmag_app = im_Hmag + dm
    # Column index of the (sample, field) equality matrix gives, for each 
    # sample, the position of its location ID in aposf._locations
    im_locid_inds = np.where(im_locid.reshape(im_locid.size, 1) == aposf._locations)[1]
    im_Hmax = field_Hmax[im_locid_inds]
    where_good_Hmag = np.where(im_Hmax > im_Hmag_app)[0]
    print(str(len(where_good_Hmag))+'/'+str(len(im_Hmag_app))+\
              ' samples are bright enough to be observed')

    # Access values with good H-magnitudes
    # NOTE(review): these are locals and are never returned
    gh_orbs = im_orbs[where_good_Hmag]
    gh_locid = im_locid[where_good_Hmag]
    gh_ms = im_ms[where_good_Hmag]
    gh_Hmag_app = im_Hmag_app[where_good_Hmag]
    
    
### Utilities

def load_parsec_isochrone(iso_dir,z,log_age,remove_wd_point=True):
    '''load_parsec_isochrone:
    
    Load a single parsec isochrone from the 
    'parsec1.2-2mass-spitzer-wise-old' set of old isochrones.
        
    The range of metallicities and ages for this set is:
    0.0001 <= Z <= 0.0030 in spacing of 0.0001
    which equates roughly to -2.28 <= [FE/H] <= -0.8
    10 < log Age < 10.15 in spacing of 0.025
    
    The nearest grid metallicity is used to pick the file, and the nearest 
    log age present in that file is used to pick the isochrone. A message is 
    printed whenever the value used differs from the one requested.
    
    Args:
        iso_dir (string) - Directory where the isochrones are located
        z (float) - metallicity (will use nearest)
        log_age (float) - log_age (will use nearest)
        remove_wd_point (bool) - Remove any WD-like points from the end of 
            the isochrone [True]
    
    Returns:
        iso (array) - Structured array holding the rows of the selected 
            isochrone
    '''
    # Find which z to use
    grid_zs = np.arange(0.0001,0.0031,0.0001)
    if z in grid_zs:
        z_load = z
    else:
        z_load = grid_zs[np.argmin(np.abs(z-grid_zs))]
        print('Using z='+str(z_load))
    ##ie
    
    # Get filename, metallicity is written with 4 decimal places (e.g. 0.0010)
    iso_name = 'parsec1.2-2mass-spitzer-wise-old'
    iso_filename = os.path.join(iso_dir,iso_name,iso_name+\
                                '-Z-{:.4f}.dat.gz'.format(z_load))
    
    full_iso = np.genfromtxt(iso_filename, dtype=None, names=True, 
                             skip_header=11)
    
    # Find which log Age to use, based on the ages present in the file
    grid_log_ages = np.unique(full_iso['logAge'])
    if log_age in grid_log_ages:
        log_age_load = log_age
    else:
        log_age_load = grid_log_ages[np.argmin(np.abs(log_age-grid_log_ages))]
        print('Using log age='+str(log_age_load))
    ##ie
    
    # Extract the isochrone
    iso = full_iso[full_iso['logAge']==log_age_load]
    
    # Remove any points that look like WDs. At least two points are needed 
    # to take a difference in Hmag.
    if remove_wd_point and len(iso) > 1:
        wd_inds = np.zeros(len(iso),dtype=bool)
        delta_Hmag = np.diff(iso['Hmag'])
        logg = iso['logg']
        # Start at the end and work backwards until we find the TRGB: a 
        # point is WD-like if it is fainter in H than the point before it 
        # and has a higher logg than the first point of the isochrone
        ind = len(iso)-1
        while ind > 0 and delta_Hmag[ind-1] > 0. and logg[ind] > logg[0]:
            wd_inds[ind] = True
            ind -= 1
        ##wh
        iso = iso[~wd_inds]
    ##fi
    
    return iso
#def

### Summary function

In [None]:
def apply_APOGEE_selection_function(orbs,ms,aposf,iso,dmap,print_stats=False):
    '''apply_APOGEE_selection_function:
    
    Apply the APOGEE selection function to sample data. The samples are 
    winnowed down in stages:
    1. Match each sample to an isochrone entry by initial mass
    2. Remove samples fainter than the faintest H-band limit of any field
    3. Remove samples outside the observational footprint (only fields with 
       spectroscopic targets are considered)
    4. Remove samples fainter than the H-band limit of their own field
    5. Compute the H-band extinction of each sample from the dust map
    6. Keep each sample if the selection function evaluated at its field, 
       extincted apparent H and (J-Ks) exceeds a uniform random number 
       drawn from the global numpy random state
    
    Args:
        orbs (galpy.orbit.Orbit) - Orbits representing the samples
        ms (np.array) - Masses of the samples
        aposf (apogee.select.*) - APOGEE selection function
        iso (np.array) - Isochrone array with 'Mini', 'Hmag', 'Jmag' and 
            'Ksmag' fields
        dmap (mwdust.DustMap3D) - Dust map
        print_stats (bool) - Print sample counts and timings for each 
            stage [False]
        
    Returns:
        orbs (galpy.orbit.Orbit) - Orbits of the surviving samples
        locid (np.array) - Location ID of the field of each surviving sample
        ms (np.array) - Masses of the surviving samples
        iso_match_indx (np.array) - Index into iso of the isochrone entry 
            matched to each surviving sample
    '''
    
    # Remove fields without spectroscopic targets
    nspec = np.nansum(aposf._nspec_short,axis=1) +\
            np.nansum(aposf._nspec_medium,axis=1) +\
            np.nansum(aposf._nspec_long,axis=1)
    good_nspec_fields = np.where(nspec>=1.)[0]
    
    # Get info about APOGEE pointings
    aposf_Hmax = np.dstack([aposf._short_hmax,
                            aposf._medium_hmax,
                            aposf._long_hmax])[0]
    
    # Match samples to isochrone entries based on initial mass
    t1 = time.time()
    init_n = len(ms)
    m_err = 1e-4+np.diff(iso['Mini']).max()/2.
    good_iso_match,iso_match_indx = _match_isochrone_to_samples(iso,ms,m_err=m_err)
    orbs = orbs[good_iso_match]
    ms = ms[good_iso_match]
    Hmag = iso['Hmag'][iso_match_indx]
    t2 = time.time()
    if print_stats:
        print(str(len(good_iso_match))+'/'+str(init_n)+\
              ' samples have good matches in the isochrone')
        print('Kept '+str(round(100*len(good_iso_match)/init_n,2))+\
              ' % of samples')
        print('Matching samples to isochrone entries took '+\
              str(round(t2-t1,1))+' s')
    ##fi
    init_n = len(ms)
    
    # Remove samples with apparent Hmag below faintest APOGEE Hmax
    t1 = time.time()
    dm = 5.*np.log10(orbs.dist().to(apu.pc).value)-5.
    where_good_Hmag1 = np.where(np.nanmax(aposf_Hmax) >\
        Hmag + dm)[0]
    orbs = orbs[where_good_Hmag1]
    ms = ms[where_good_Hmag1]
    dm = dm[where_good_Hmag1]
    Hmag = Hmag[where_good_Hmag1]
    iso_match_indx = iso_match_indx[where_good_Hmag1]
    t2=time.time()
    if print_stats:
        print(str(len(where_good_Hmag1))+'/'+str(init_n)+\
              ' samples are bright enough to be observed')
        print('Kept '+str(round(100*len(where_good_Hmag1)/init_n,2))+\
              ' % of samples')
        print('Removing samples with H-band magnitudes below faintest APOGEE'+\
              ' Hmax limit took '+str(round(t2-t1,1))+' s')
    ##fi
    init_n = len(ms)
        
    # Remove samples that lie outside the APOGEE observational footprint
    t1 = time.time()
    fp_indx,locid = _remove_orbits_outside_footprint(orbs,aposf,good_nspec_fields)
    orbs = orbs[fp_indx]
    ms = ms[fp_indx]
    dm = dm[fp_indx]
    Hmag = Hmag[fp_indx]
    iso_match_indx = iso_match_indx[fp_indx]
    t2 = time.time()
    if print_stats:        
        print('Removing samples outside observational footprint took '+\
              str(round(t2-t1,1))+' s')
        print(str(len(fp_indx))+'/'+str(init_n)+' samples found within'+\
              ' observational footprint')
        print('Kept '+str(round(100*len(fp_indx)/init_n,2))+\
              ' % of samples')
    ##fi
    init_n = len(ms)
    
    # Remove samples with apparent Hmag below faintest Hmax on field-by-field 
    # basis
    t1 = time.time()
    field_Hmax = np.nanmax(aposf_Hmax, axis=1)
    locid_inds = np.where(locid.reshape(locid.size, 1) == aposf._locations)[1]
    Hmax = field_Hmax[locid_inds]
    where_good_Hmag2 = np.where(Hmax > Hmag + dm)[0]
    # Access values with good H-magnitudes
    orbs = orbs[where_good_Hmag2]
    locid = locid[where_good_Hmag2]
    ms = ms[where_good_Hmag2]
    dm = dm[where_good_Hmag2]
    Hmag = Hmag[where_good_Hmag2]
    iso_match_indx = iso_match_indx[where_good_Hmag2]
    Jmag = iso['Jmag'][iso_match_indx]
    Ksmag = iso['Ksmag'][iso_match_indx]
    t2=time.time()
    if print_stats:
        print('Removing samples with H-band magnitudes outside '+\
              'observational limits took '+str(round(t2-t1,1))+' s')
        print(str(len(where_good_Hmag2))+'/'+str(init_n)+\
                  ' samples are bright enough to be observed')
        print('Kept '+str(round(100*len(where_good_Hmag2)/init_n,2))+\
              ' % of samples')
    ##fi
    init_n = len(ms)
        
    # Get lbIndx for the dust map
    t1 = time.time()
    gl = orbs.ll(use_physical=True).value
    gb = orbs.bb(use_physical=True).value
    # Healpix pixel of every sample at every Nside level of the dust map
    dmap_nsides = np.array(dmap._nsides)
    nside_pix = np.zeros((len(orbs),len(dmap_nsides)))
    nside_arr = np.repeat(dmap_nsides[:,np.newaxis],len(orbs),axis=1).T
    deg_to_rad = np.pi/180.
    for i in range(len(dmap_nsides)):
        nside_pix[:,i] = healpy.pixelfunc.ang2pix(dmap_nsides[i],
                                                  (90.-gb)*deg_to_rad,
                                                  gl*deg_to_rad,
                                                  nest=True)
    ###i
    # Calculate healpix u = p + 4*Nside^2, which identifies a pixel uniquely 
    # across Nside levels, for both the dust map and the samples
    dmap_hpu = (dmap._pix_info['healpix_index'] + 4*dmap._pix_info['nside']**2.).astype(int)
    hpu = (nside_pix + 4*nside_arr**2).astype(int)
    # Use searchsorted to get the indices
    dmap_hpu_argsort = np.argsort(dmap_hpu)
    dmap_hpu_sorted = dmap_hpu[dmap_hpu_argsort]
    hpu_indx_sorted = np.searchsorted(dmap_hpu_sorted,hpu)
    hpu_indx = np.take(dmap_hpu_argsort, hpu_indx_sorted, mode="clip")
    hpu_mask = dmap_hpu[hpu_indx] != hpu
    hpu_ma = np.ma.array(hpu_indx, mask=hpu_mask)
    lbIndx = hpu_ma.data[~hpu_ma.mask]
    # Collapsing the (sample, Nside) array only keeps the samples aligned if 
    # each sample matched a dust map pixel at exactly one Nside level
    if len(lbIndx) != len(orbs):
        raise RuntimeError('Each sample must match exactly one dust map '+\
                           'pixel, found '+str(len(lbIndx))+' matches for '+\
                           str(len(orbs))+' samples')
    ##fi
    t2 = time.time()
    if print_stats:
        print('Getting lbIndx took '+str(round(t2-t1,1))+' s')
    ##fi
    
    # Compute AH, batching samples which share a dust map pixel
    # The conversion from E(B-V) to AH is the same for every pixel
    eBV_to_AH = mwdust.util.extCurves.aebv(dmap._filter,sf10=dmap._sf10)
    unique_lbIndx = np.unique(lbIndx).astype(int)
    AH = np.zeros(len(orbs))
    for i in range(len(unique_lbIndx)):
        # First find which samples have this lbIndx
        where_unique = np.where(lbIndx == unique_lbIndx[i])[0]
        # Get the dust map interpolation data for this lbIndx
        dmap_interp_data = scipy.interpolate.InterpolatedUnivariateSpline(
            dmap._distmods, dmap._best_fit[unique_lbIndx[i]], k=dmap._interpk)
        # Calculate AH
        AH[where_unique] = dmap_interp_data(dm[where_unique])*eBV_to_AH
    ###i
    
    # Apply the selection function
    t1 = time.time()
    sf_keep_indx = np.zeros(len(orbs),dtype=bool)
    for i in range(len(orbs)):
        random_n = np.random.random(size=1)[0]
        _H = Hmag[i] + dm[i] + AH[i]
        _JK0 = Jmag[i] - Ksmag[i]
        compare_n = aposf(locid[i],_H,_JK0)
        if compare_n > random_n: sf_keep_indx[i] = True
    ###i
    t2 = time.time()
    if print_stats:
        print('Applying selection function took '+str(round(t2-t1,1))+' s')
        print(str(np.sum(sf_keep_indx))+'/'+str(init_n)+\
                  ' samples survive the selection function')
        print('Kept '+str(round(100*np.sum(sf_keep_indx)/init_n,2))+\
              ' % of samples')
    ##fi
    
    orbs = orbs[sf_keep_indx]
    locid = locid[sf_keep_indx]
    ms = ms[sf_keep_indx]
    iso_match_indx = iso_match_indx[sf_keep_indx]
    
    return orbs, locid, ms, iso_match_indx
#def

### Load selection function, dust map, isochrone, mock orbits and masses

In [None]:
# Directory holding the pickled APOGEE selection function
aposf_data_dir = ('/geir_data/scr/lane/projects/ges-mass/data/gaia_apogee/'
                  'apogee_dr16_l33_gaia_dr2/')
# NOTE: unpickling can execute arbitrary code, only load trusted files
with open(aposf_data_dir+'apogee_SF.dat',mode='rb') as f:
    aposf = pickle.load(f)
##wi

In [None]:
# Dust map used below to get the H-band extinction of the samples, set up for 
# the 2MASS H filter
dmap = mwdust.Combined19(filter='2MASS H') # dustmap from mwdust

In [None]:
# Metallicity (Z) and log age of the isochrone, the loader uses the nearest 
# values available
z, log_age = 0.0010, 10.0
iso = load_parsec_isochrone(os.environ['ISODIST_DATA'],z,log_age,
                            remove_wd_point=True)

In [None]:
# Mock samples are stored by the number of samples drawn
n_samples_str = '1e7'
sample_data_dir = '/geir_data/scr/lane/projects/ges-mass/apo_mocks/'
# Masses are a plain numpy array, orbits are pickled
ms = np.load(sample_data_dir+'masses_'+n_samples_str+'.npy')
with open(sample_data_dir+'orbs_'+n_samples_str+'.pkl',mode='rb') as f:
    orbs = pickle.load(f)
##wi

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)

# Color-magnitude diagram of the isochrone, points are colored by their 
# position in the isochrone array
ax.scatter(iso['Jmag']-iso['Ksmag'], iso['Hmag'], c=np.arange(0,len(iso),1.))
ax.set_ylabel('H')
ax.set_xlabel('J-Ks')
ax.invert_yaxis()

fig.show()

### Remove any fields without spectroscopic targets

In [None]:
# Total number of spectroscopic targets in each field, summed over the 
# short, medium and long cohorts
nspec = sum(np.nansum(cohort_nspec,axis=1) for cohort_nspec in
            (aposf._nspec_short,aposf._nspec_medium,aposf._nspec_long))
# Keep fields that have at least one spectroscopic target
good_nspec_fields = np.flatnonzero(nspec>=1.)
print(len(good_nspec_fields),'/',len(nspec),' fields have',
      ' spectroscopic targets',sep='')

## Here do isochrone matching first, and a coarse removal of samples, then field matching
Match into the isochrone first, then remove all samples with apparent Hmag
below the faintest observational limit in APOGEE (haven't matched to fields yet)
so can't use field-specific values

### Match to isochrone


In [None]:
t1 = time.time()

# Match samples to isochrone entries based on initial mass. The tolerance is 
# half the largest gap between neighboring isochrone masses plus a small margin
m_err = 1e-4+np.diff(iso['Mini']).max()/2.
good_iso_match,iso_match_indx = _match_isochrone_to_samples(iso,ms,m_err=m_err)
print(str(len(good_iso_match))+'/'+str(len(ms))+\
      ' samples have good matches in the isochrone')
print('Kept '+str(round(100*len(good_iso_match)/len(ms),2))+' % of samples')
# Every matched sample must lie within the tolerance of its isochrone mass, 
# the absolute value has to be taken before comparing with the tolerance
assert np.all(np.abs(ms[good_iso_match]-iso['Mini'][iso_match_indx])<=m_err)

# Access isochrone-matched values
im_orbs = orbs[good_iso_match]
im_ms = ms[good_iso_match]
im_Hmag = iso['Hmag'][iso_match_indx]

t2 = time.time()
print('Matching samples to isochrone entries took '+str(round(t2-t1,1))+' s')

In [None]:
fig, ax = plt.subplots()

# Absolute color-magnitude diagram of every 100th isochrone-matched sample
ax.scatter((iso['Jmag'][iso_match_indx]-iso['Ksmag'][iso_match_indx])[::100],
           im_Hmag[::100], alpha=0.1, s=1.)
ax.set_xlabel('J-Ks')
ax.set_ylabel('H')
ax.invert_yaxis()
fig.show()

### Remove samples that lie below the faintest APOGEE Hmax
Should be about 0.1 per cent depending on the isochrone

In [None]:
t1 = time.time()

# Faintest H-band limit over all cohorts of all fields, compared with the 
# distance modulus-adjusted H magnitude of each sample
aposf_Hmax = np.nanmax(np.dstack((aposf._short_hmax,
                                  aposf._medium_hmax,
                                  aposf._long_hmax))[0])
# Distance modulus from the distance expressed in pc
im_dm = 5.*np.log10(im_orbs.dist().to(apu.pc).value)-5.
where_good_Hmag = np.flatnonzero(aposf_Hmax > (im_Hmag+im_dm))
where_bad_Hmag = np.flatnonzero(aposf_Hmax < (im_Hmag+im_dm))
print(len(where_good_Hmag),'/',len(im_Hmag),
      ' samples are bright enough to be observed',sep='')
print('Kept ',round(100*len(where_good_Hmag)/len(im_Hmag),2),
      ' % of samples',sep='')

# Properties of the samples that pass the magnitude cut
gh_orbs = im_orbs[where_good_Hmag]
gh_ms = im_ms[where_good_Hmag]
gh_dm = im_dm[where_good_Hmag]
gh_Hmag = im_Hmag[where_good_Hmag]
gh_iso_match_indx = iso_match_indx[where_good_Hmag]

t2 = time.time()
print('Removing samples with H-band magnitudes outside observational limits',
      ' took ',round(t2-t1,1),' s',sep='')

In [None]:
fig, ax = plt.subplots()

# Apparent color-magnitude diagram of every 100th sample that failed the 
# magnitude cut, the horizontal line marks the faintest APOGEE limit
ax.scatter((iso['Jmag'][iso_match_indx[where_bad_Hmag]]
            -iso['Ksmag'][iso_match_indx[where_bad_Hmag]])[::100],
           (im_Hmag[where_bad_Hmag]+im_dm[where_bad_Hmag])[::100],
           alpha=0.1, s=1.)
ax.axhline(y=aposf_Hmax, color='Black', linestyle='solid')

ax.set_xlabel('J-Ks')
ax.set_ylabel('H')
ax.invert_yaxis()
fig.show()

In [None]:
fig, ax = plt.subplots()

# Apparent color-magnitude diagram of the samples that passed the magnitude 
# cut, the horizontal line marks the faintest APOGEE limit
ax.scatter(iso['Jmag'][gh_iso_match_indx]-iso['Ksmag'][gh_iso_match_indx],
           gh_Hmag+gh_dm, alpha=0.1, s=1.)
ax.axhline(y=aposf_Hmax, color='Black', linestyle='solid')

ax.set_xlabel('J-Ks')
ax.set_ylabel('H')
ax.invert_yaxis()
fig.show()

In [None]:
fig, ax = plt.subplots()

# Galactic coordinates of the samples that passed the magnitude cut
plot_gl = gh_orbs.ll(use_physical=True).value
plot_gb = gh_orbs.bb(use_physical=True).value

# Wrap longitudes into (-180,180]
plot_gl[plot_gl > 180] -= 360

ax.scatter(plot_gl,plot_gb,color='Black',alpha=0.5,s=0.1)
ax.set_xlabel(r'$\ell$')
ax.set_ylabel(r'$b$')
ax.set_xlim(180,-180)
ax.set_ylim(-90,90)

fig.show()

### Remove orbits outside APOGEE fields
Should keep ~12% of orbits depending on APOGEE DR

In [None]:
t1 = time.time()
# Indices of the samples inside the footprint and the location ID of the field 
# each one falls in, only fields with spectroscopic targets are considered
fp_indx,fp_locid = _remove_orbits_outside_footprint(gh_orbs,aposf,
                                                    good_nspec_fields)
fp_orbs = gh_orbs[fp_indx]
fp_ms = gh_ms[fp_indx]
fp_dm = gh_dm[fp_indx]
fp_Hmag = gh_Hmag[fp_indx]
fp_iso_match_indx = gh_iso_match_indx[fp_indx]
t2 = time.time()
print('Removing samples outside observational footprint took ',
      round(t2-t1,1),' s',sep='')
print(len(fp_indx),'/',len(gh_orbs),
      ' samples found within observational footprint',sep='')
print('Kept ',round(100*len(fp_indx)/len(gh_orbs),2),' % of samples',sep='')

In [None]:
fig, ax = plt.subplots()

# Galactic coordinates of the samples inside the observational footprint
plot_gl = fp_orbs.ll(use_physical=True).value
plot_gb = fp_orbs.bb(use_physical=True).value

# Wrap longitudes into (-180,180]
plot_gl[plot_gl > 180] -= 360

ax.scatter(plot_gl,plot_gb,color='Black',alpha=0.5,s=0.1)
ax.set_xlabel(r'$\ell$')
ax.set_ylabel(r'$b$')
ax.set_xlim(180,-180)
ax.set_ylim(-90,90)

fig.show()

### Do a more specific field-by-field Hmax removal procedure

In [None]:
t1 = time.time()

# Faintest H-band limit over the cohorts of each field, compared with the 
# distance modulus-adjusted H magnitude of each sample
field_Hmax = np.nanmax(np.dstack((aposf._short_hmax,
                                  aposf._medium_hmax,
                                  aposf._long_hmax))[0],axis=1)
# Position of each sample's location ID in the selection function's field list
fp_locid_inds = np.nonzero(fp_locid[:,np.newaxis] == aposf._locations)[1]
fp_Hmax = field_Hmax[fp_locid_inds]
where_good_Hmag = np.flatnonzero(fp_Hmax > fp_Hmag + fp_dm)

# Properties of the samples that pass the field-by-field magnitude cut
gh2_orbs = fp_orbs[where_good_Hmag]
gh2_locid = fp_locid[where_good_Hmag]
gh2_ms = fp_ms[where_good_Hmag]
gh2_dm = fp_dm[where_good_Hmag]
gh2_Hmag = fp_Hmag[where_good_Hmag]
gh2_iso_match_indx = fp_iso_match_indx[where_good_Hmag]
gh2_Jmag = iso['Jmag'][gh2_iso_match_indx]
gh2_Ksmag = iso['Ksmag'][gh2_iso_match_indx]

t2 = time.time()
print('Removing samples with H-band magnitudes outside observational limits',
      ' took ',round(t2-t1,1),' s',sep='')
print(len(where_good_Hmag),'/',len(fp_Hmag),
      ' samples are bright enough to be observed',sep='')
print('Kept ',round(100*len(where_good_Hmag)/len(fp_Hmag),2),
      ' % of samples',sep='')

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)

# Fields without a limit in any cohort have a NaN Hmax. Select the finite 
# values for plotting instead of overwriting entries of field_Hmax, which is 
# shared with the other cells.
ax.hist(field_Hmax[np.isfinite(field_Hmax)], histtype='step')
ax.set_xlabel(r'Field $H_{max}$')
ax.set_ylabel('N')

fig.show()

In [None]:
fig, ax = plt.subplots()

# Apparent color-magnitude diagram (no extinction) of the samples that passed 
# the field-by-field magnitude cut
ax.scatter(gh2_Jmag-gh2_Ksmag,gh2_Hmag+gh2_dm,color='Black',s=1.,alpha=1.)
ax.set_ylabel(r'$H_{app}$')
ax.set_xlabel(r'$(J-K_{s})_{0}$')
ax.set_xlim(0.3,1.0)
# ax.set_ylim(5.,15.)
ax.invert_yaxis()

fig.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)

# Galactic coordinates of the samples that passed the field-by-field 
# magnitude cut
plot_gl = gh2_orbs.ll(use_physical=True).value
plot_gb = gh2_orbs.bb(use_physical=True).value

# Wrap longitudes into (-180,180]
plot_gl[plot_gl > 180] = plot_gl[plot_gl > 180] - 360

ax.scatter(plot_gl,plot_gb,color='Black',alpha=0.5,s=1.)
ax.set_xlim(180,-180)
ax.set_ylim(-90,90)
ax.set_xlabel(r'$\ell$')
ax.set_ylabel(r'$b$')

fig.show()

### Use the dustmap to get the reddening correction
Looping through calls to dust map is slow. Perhaps minor speedup by first 
determining lbIndx and then batching calls to the dust map for samples with the 
same lbIndx but different distances.

In [None]:
# Galactic longitude, latitude and distance of the samples, distance is 
# stored as a column vector
gh2_gl = gh2_orbs.ll(use_physical=True).value
gh2_gb = gh2_orbs.bb(use_physical=True).value
gh2_dist = gh2_orbs.dist(use_physical=True).value.reshape(-1,1)

# Per-sample healpix arrays with one column for each Nside level of the dust 
# map: the pixel numbers (filled in later) and the Nside values themselves
dmap_nsides = np.array(dmap._nsides)
gh2_nside_pix = np.zeros((len(gh2_orbs),len(dmap_nsides)))
gh2_nside_arr = np.tile(dmap_nsides,(len(gh2_orbs),1))

In [None]:
# Reference ('slow') results: query the dust map one sample at a time. These 
# are only used to time the queries and to validate the faster method below.
t1 = time.time()
gh2_AH_slow = np.zeros(len(gh2_orbs))
for i in range(len(gh2_AH_slow)):
    # The carriage return keeps the progress counter on a single line
    print('Done '+str(i+1)+'/'+str(len(gh2_AH_slow)),end='\r')
    # NOTE(review): gh2_dist[i] is a length-1 array, so this presumably 
    # assigns a length-1 result to a scalar slot - confirm against mwdust
    gh2_AH_slow[i] = dmap(gh2_gl[i],gh2_gb[i],gh2_dist[i])
###i
t2 = time.time()
print('Getting AH with dmap queries took '+str(t2-t1)+' s')

# Time only the lookup of the dust map pixel index for each sample
t1 = time.time()
gh2_lbIndx_slow = np.zeros(len(gh2_orbs))
for i in range(len(gh2_orbs)):
    print('Done '+str(i+1)+'/'+str(len(gh2_lbIndx_slow)),end='\r')
    gh2_lbIndx_slow[i] = dmap._lbIndx(gh2_gl[i],gh2_gb[i])
###i
t2 = time.time()
print('Getting lbIndx took '+str(t2-t1)+' s')

So when querying the dust map the most expensive operation is determining the lbIndx.
Figure out a way to do this faster, preferably vectorizable

### Faster method of getting $A_{H}$

Querying the healpix pixel can be fast if we use `searchsorted` to get matches 
in the unique healpix identifier:

$u = p + 4N_{side}^{2}$

where $p$ is the healpix index in the range

$p \in [0,12N_{side}^{2}-1]$

Calculate this for the healpix tiles in the dust map as well as the 
samples. Then querying is easy.

Then bulk evaluate the dust map for all samples that share an lbIndx but have different distances.

In [None]:
t1 = time.time()

# Healpix pixel (nested ordering) of every sample at each Nside level of the 
# dust map. ang2pix is given the colatitude (90 deg - b) and the longitude, 
# both converted to radians.
_DEGTORAD = np.pi/180.
for i in range(len(dmap_nsides)):
    gh2_nside_pix[:,i] = healpy.pixelfunc.ang2pix(dmap_nsides[i],
                                                 (90.-gh2_gb)*_DEGTORAD,
                                                  gh2_gl*_DEGTORAD,
                                                  nest=True)

# Calculate u for both the dust map and the samples
dmap_hpu = (dmap._pix_info['healpix_index'] + 4*dmap._pix_info['nside']**2.).astype(int)
gh2_hpu = (gh2_nside_pix + 4*gh2_nside_arr**2).astype(int)

# Use searchsorted to get the indices
dmap_hpu_argsort = np.argsort(dmap_hpu)
dmap_hpu_sorted = dmap_hpu[dmap_hpu_argsort]
hpu_indx_sorted = np.searchsorted(dmap_hpu_sorted,gh2_hpu)
# mode="clip" keeps insertion positions past the end of the sorted array in 
# range, such entries are rejected by the mask below
hpu_indx = np.take(dmap_hpu_argsort, hpu_indx_sorted, mode="clip")
# Mask every (sample, Nside) entry whose u is not present in the dust map
hpu_mask = dmap_hpu[hpu_indx] != gh2_hpu
hpu_ma = np.ma.array(hpu_indx, mask=hpu_mask)
# Collapse to one dust map index per sample (relies on the check below)
gh2_lbIndx = hpu_ma.data[~hpu_ma.mask]

t2 = time.time()
print('Getting lbIndx using searchsorted took '+str(round(t2-t1,1))+' s')

# Sanity checks, there should only be one matching Nside per sample so if 
# we collapse the array there should be exactly enough entries. lbIndx acquired
# using this method should also match lbIndx acquired using the slow method.
assert np.all( np.sum(~hpu_ma.mask,axis=1) == np.ones(len(gh2_orbs)) )
assert np.all( gh2_lbIndx == gh2_lbIndx_slow )

In [None]:
t1 = time.time()

# The conversion from E(B-V) to AH only depends on the dust map settings, so 
# compute it once rather than for every pixel
eBV_to_AH = mwdust.util.extCurves.aebv(dmap._filter,sf10=dmap._sf10)

# First get all the unique lbIndx values, then batch compute AH
unique_lbIndx = np.unique(gh2_lbIndx).astype(int)
gh2_AH = np.zeros(len(gh2_orbs))
for i in range(len(unique_lbIndx)):
    # First find which samples have this lbIndx
    where_unique = np.where(gh2_lbIndx == unique_lbIndx[i])[0]
    # Get the dust map interpolation data for this lbIndx
    dmap_interp_data = scipy.interpolate.InterpolatedUnivariateSpline(
        dmap._distmods, dmap._best_fit[unique_lbIndx[i]], k=dmap._interpk)
    # Calculate AH by evaluating the spline at the distance moduli of the 
    # samples in this pixel
    gh2_AH[where_unique] = dmap_interp_data(gh2_dm[where_unique])*eBV_to_AH
###i

t2 = time.time()
print('Getting AH took '+str(t2-t1)+' s')

In [None]:
# The batched AH must agree with the per-sample dust map queries. Compare the 
# absolute difference, otherwise any negative difference would pass.
assert np.all(np.abs(gh2_AH - gh2_AH_slow) < 1e-10)

In [None]:
fig, ax = plt.subplots()

# Distribution of the H-band extinction of the samples
ax.hist(gh2_AH,edgecolor='Black',histtype='step')

ax.set_ylabel('N')
ax.set_xlabel(r'$A_{H}$')

fig.show()

In [None]:
fig, ax = plt.subplots()

# Apparent color-magnitude diagram of the samples including H-band extinction
ax.scatter(gh2_Jmag-gh2_Ksmag,gh2_Hmag+gh2_dm+gh2_AH,color='Black',s=1.,
           alpha=1.)
ax.set_ylabel(r'$H_{app}$')
ax.set_xlabel(r'$(J-K_{s})_{0}$')
ax.set_xlim(0.3,1.0)
# ax.set_ylim(5.,20.)
ax.invert_yaxis()

fig.show()

### Apply the selection function

In [None]:
t1 = time.time()
sf_keep_indx = np.zeros(len(gh2_orbs),dtype=bool)
for i in range(len(gh2_orbs)):
    # Uniform random number in [0,1) that the selection fraction must exceed
    random_n = np.random.random(size=1)[0]
    # Extincted apparent H magnitude and (J-Ks) color of this sample
    _H = gh2_Hmag[i] + gh2_dm[i] + gh2_AH[i]
    _JK0 = gh2_Jmag[i] - gh2_Ksmag[i]
    # Selection function evaluated for the field this sample lies in
    compare_n = aposf(gh2_locid[i],_H,_JK0)
    if compare_n > random_n:
        sf_keep_indx[i] = True
    ##fi
###i
t2 = time.time()
print('Applying selection function took ',round(t2-t1,1),' s',sep='')

In [None]:
# Total mass of all loaded samples and the number of samples that survive the 
# selection function
print('Mass of the profile is: ',np.sum(ms)/1e8,'e8 Msun',sep='')
print('Number of surviving stars: ',np.sum(sf_keep_indx),sep='')