In [None]:
# ------------------------------------------------------------------------
#
# TITLE - examine_gaia_apogee.ipynb
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Look at Gaia-APOGEE data. Explore chemical and kinematic selections and
metallicities, and output the selections (APOGEE IDs and boolean masks) that
the mass modelling is performed on.
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np, pdb, sys, os, numbers, dill as pickle, copy
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import path
import astropy.units as apu

## galpy
from galpy import orbit
from galpy import potential
from galpy import actionAngle as aA

sys.path.append('../../src/')
from ges_mass import plot as pplot
from ges_mass import util as putil

### Notebook setup

# Inline, high-resolution (retina) figures using the project's shared
# matplotlib style. NOTE: the style path (like the sys.path entry above) is
# relative, so the notebook must be run from its own directory.
%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle')
%config InlineBackend.figure_format = 'retina'
# Automatically reload edited modules (e.g. ges_mass.plot / ges_mass.util)
# before each cell is executed
%load_ext autoreload
%autoreload 2

### Colors setup

# Project colour palette, and the 'rainbow' colormap used for the
# colour-coded scatter plots and density images below
project_colors = pplot.colors()
rainbow_cmap = project_colors.colourmap('rainbow')

### Keywords, Pathing, Loading, Data Preparation

In [None]:
## Keywords
# Read every run-time setting from the project configuration in one go; the
# order of `keywords` fixes the order of the unpacked values below.
cdict = putil.load_config_to_dict()
keywords = [
    'BASE_DIR', 'APOGEE_DR', 'APOGEE_RESULTS_VERS', 'GAIA_DR', 'NDMOD',
    'DMOD_MIN', 'DMOD_MAX', 'LOGG_MIN', 'LOGG_MAX', 'FEH_MIN', 'FEH_MAX',
    'FEH_MIN_GSE', 'FEH_MAX_GSE', 'NPROCS', 'DF_VERSION', 'KSF_VERSION',
    'RO', 'VO', 'ZO',
]
(base_dir, apogee_dr, apogee_results_vers, gaia_dr, ndmod,
 dmod_min, dmod_max, logg_min, logg_max, feh_min, feh_max,
 feh_min_gse, feh_max_gse, nprocs, df_version, ksf_version,
 ro, vo, zo) = putil.parse_config_dict(cdict, keywords)

# Convenience [min, max] pairs
logg_range = [logg_min, logg_max]
feh_range = [feh_min, feh_max]
feh_range_gse = [feh_min_gse, feh_max_gse]

## Pathing
fit_paths = putil.prepare_paths(base_dir, apogee_dr, apogee_results_vers,
                                gaia_dr, df_version, ksf_version)
# Only the first four paths are used here. gap_dir is where this notebook
# reads its inputs and writes its outputs.
data_dir, version_dir, ga_dir, gap_dir, _, _, _ = fit_paths
os.makedirs(gap_dir, exist_ok=True)

## Filenames
# NOTE: filenames are built by plain string concatenation, so gap_dir is
# assumed to end with a path separator.
sample_kinematics = True
omask_kinematics_filename = gap_dir + (
    'clean_kinematics_sampled.npy' if sample_kinematics
    else 'clean_kinematics_no_sample.npy')

## Potential
# The return value is not kept, so turn_physical_on must act on the shared
# galpy MWPotential2014 object itself.
mwpot = potential.MWPotential2014
potential.turn_physical_on(mwpot, ro=ro, vo=vo)
# Potential at a very large radius (R=1e10, z=0). It is passed on as the
# zero-point phi0 for the (E - Phi_0) energies plotted below.
phi0 = potential.evaluatePotentials(mwpot, 1e10, 0).value

# Functions
def line_equation(x,m,b):
    '''line_equation:

    Evaluate the straight line y = m*x + b.

    Args:
        x (float or np.array) - x value(s) at which to evaluate the line
        m (float) - slope
        b (float) - y-intercept

    Returns:
        y (float or np.array) - line evaluated at x
    '''
    slope_term = m*x
    return slope_term + b

### Parameters of the boundary lines used to mask the halo and to separate GS/E from the disk and the rest of the halo

In [None]:
# [Mg/Fe]-[Fe/H] halo separation from Lane+2022 (also used in Mackereth+2019).
# The line is defined by two points; slope and intercept are derived from them.
alpha_fe_halo_line_x = [-1.2, -0.9]
alpha_fe_halo_line_y = [0.5, -0.2]
(alpha_fe_halo_line_m,
 alpha_fe_halo_line_b) = pplot.get_params_from_line(alpha_fe_halo_line_x,
                                                    alpha_fe_halo_line_y)

# Candidate broken-line GS/E / thick-disk separation for this work: two lines
# y = m*x + b that join where they intersect, x = (b1-b2)/(m2-m1).
alpha_fe_halo_line_broken_m1 = -0.21
alpha_fe_halo_line_broken_b1 = 0.08
alpha_fe_halo_line_broken_m2 = -0.95
alpha_fe_halo_line_broken_b2 = -0.58
alpha_fe_halo_line_broken_xint = (
    (alpha_fe_halo_line_broken_b1 - alpha_fe_halo_line_broken_b2)
    / (alpha_fe_halo_line_broken_m2 - alpha_fe_halo_line_broken_m1))

# [Mg/Mn]-[Al/Fe] halo separation from Horta+21, as polyline vertices:
# horizontal at [Mg/Mn]=0.25 (the -9999 vertex acts as -infinity) out to
# [Al/Fe]=-0.06, then a sloped segment. Slope/intercept come from the sloped
# segment only.
mg_mn_al_fe_halo_line_x = [-9999, -0.06, 0.18]
mg_mn_al_fe_halo_line_y = [0.25, 0.25, 1.3]
(mg_mn_al_fe_halo_line_m,
 mg_mn_al_fe_halo_line_b) = pplot.get_params_from_line(
    mg_mn_al_fe_halo_line_x[1:], mg_mn_al_fe_halo_line_y[1:])

# [Al/Fe]-[Fe/H] halo separations from Belokurov+22, as polyline vertices
# (+/-9999 act as +/- infinity):
#   _all - boundary of the whole halo region
#   _acc - boundary of the accreted (lower [Al/Fe]) part of the halo
#   _ext - horizontal [Al/Fe]=-0.1 boundary continued out to [Fe/H]=-0.6
al_fe_halo_line_x_all = [-1., -1., -0.6, -0.6]
al_fe_halo_line_y_all = [9999., -0.1, -0.3, -9999.]
al_fe_halo_line_x_acc = [-9999., -1., -0.6, -0.6]
al_fe_halo_line_y_acc = [-0.1, -0.1, -0.3, -9999.]
al_fe_halo_line_x_ext = [-9999., -0.6, -0.6]
al_fe_halo_line_y_ext = [-0.1, -0.1, -0.3]
# Earlier single [Al/Fe] boundary, kept for reference:
# al_fe_halo_line_x = [-9999,-0.6,-0.6]
# al_fe_halo_line_y = [-0.075,-0.075,-9999]


### Get data

In [None]:
# Cleaned kinematics
if os.path.exists(omask_kinematics_filename):
    print('Loading cleaned kinematics from '+omask_kinematics_filename)
    with open(omask_kinematics_filename,'rb') as f:
        omask_kinematics = pickle.load(f)
else:
    sys.exit('Clean kinematics first')

gaia_omask,allstar_omask,os_omask,eELzs_omask,accs_omask,orbextr_omask = \
    omask_kinematics

# Unpack the orbital extrema
zmax_omask,rperi_omask,rapo_omask = orbextr_omask

## Whole Sample

### Look at chemistry of the whole sample

In [None]:
# def line_intersects(ax,ay,bx,by):
#     '''line_intersects
    
#     Args:
        
#     '''
#     det = (by[1]-by[0])*(ax[1]-ax[0]) - (bx[1]-bx[0])*(ay[1]-ay[0])
#     if det:
#         ua = ((bx[1]-bx[0])*(ay[0]-by[0]) - (by[1]-by[0])*(ax[0]-bx[0])) / det
#         ub = ((ax[1]-ax[0])*(ay[0]-by[0]) - (ay[1]-ay[0])*(ax[0]-bx[0])) / det
#     else:
#         return False
#     if not (0 <= ua <= 1 and 0 <= ub <= 1):
#         return False
#     else:
#         return True

# def lines_boundary_mask(xs,ys,lx,ly,where):
#     '''lines_boundary_mask:
    
#     Args:
#         xs (np.array) - x-coordinates of points to be masked
#         ys (np.array) - y-coordinates of points to be masked
#         lx (list) - x coordinates of boundary vertices
#         ly (list) - y coordinates of boundary vertices
#         where (text) - Orientation of points w.r.t. boundary to determine mask
        
#     Returns:
#         mask (np.array) - boolean mask of point position w.r.t. boundary given
#             where
#     '''
#     npt = len(xs)
#     nseg = len(lx)-1 # Number of line segments 1 less than number of vertices
#     assert len(xs) == len(ys)
#     mask = np.zeros(npt,dtype=bool)
    
#     # Based on "where", lines are drawn to inf and intersect w/ each boundary 
#     # segment is checked.
#     if where == 'left': # Mask=True if points are to the left
#         xe = 9999.
#         ye = 0.
#     if where == 'right': # mask=True if points are to the right
#         pass
#     if where == 'below': # mask=True if points are below
#         pass
#     if where == 'above': # mask=True if points are above
#         pass
    
#     for i in range(npt):
#         _mask = False
#         for j in range(nseg):
#             _mask |= line_intersects([xs[i],xe], [ys[i],ye], 
#                                      [lx[j],lx[j+1]], [ly[j],ly[j+1]])
#         mask[i] = _mask
#     return mask

In [None]:
fig = plt.figure(figsize=(15,5))
axs = fig.subplots(nrows=1, ncols=3).flatten()

# Abundance planes shown in each panel, with their labels and axis limits
abunds_x = ['FE_H','FE_H','AL_FE']
abunds_y = ['AL_FE','MG_FE','MG_MN']
abunds_xlabels = ['[Fe/H]', '[Fe/H]', '[Al/Fe]']
abunds_ylabels = ['[Al/Fe]', '[Mg/Fe]', '[Mg/Mn]']
xlims = [ [-2.5,0.75], [-2.5,0.75], [-1.1,1.1] ]
ylims = [ [-0.6,0.6], [-0.4,0.6], [-0.8,1.3] ]
nbins = 50 # Bins per axis for the density contours

# Draw stars in order of increasing eccentricity so high-eccentricity stars
# end up on top
ecc_sort = np.argsort(eELzs_omask[0])

for i in range(len(axs)):
    # Data
    xplot,xplot_err = putil.get_metallicity(allstar_omask,abunds_x[i])
    yplot,yplot_err = putil.get_metallicity(allstar_omask,abunds_y[i])

    # Plot
    pts = axs[i].scatter(xplot[ecc_sort], yplot[ecc_sort],
                         c=eELzs_omask[0][ecc_sort], cmap=rainbow_cmap,
                         edgecolors='None', vmin=0, vmax=1, marker='o',
                         s=5, zorder=4, rasterized=True)

    # Add density contours. The histogram range is padded by one nominal bin
    # on each side so the contours reach the edges of the axes.
    xbinsize = (xlims[i][1]-xlims[i][0])/nbins
    ybinsize = (ylims[i][1]-ylims[i][0])/nbins
    hist,xedge,yedge = np.histogram2d(xplot,yplot,bins=nbins,
                                      range=[[xlims[i][0]-xbinsize,
                                              xlims[i][1]+xbinsize],
                                             [ylims[i][0]-ybinsize,
                                              ylims[i][1]+ybinsize]])
    # Contour grid at the true bin centres, i.e. the midpoints of the edges
    # returned by histogram2d. Because of the padded range the actual bins are
    # wider than xbinsize/ybinsize, so "edge + binsize" is not the centre.
    xcents = 0.5*(xedge[:-1]+xedge[1:])
    ycents = 0.5*(yedge[:-1]+yedge[1:])
    xcnt,ycnt = np.meshgrid(xcents,ycents)
    axs[i].contour(xcnt, ycnt, hist.T, colors='Black', levels=[50,100,500],
                   zorder=5, alpha=0.5)

    axs[i].set_xlabel(abunds_xlabels[i], fontsize=20)
    axs[i].set_ylabel(abunds_ylabels[i], fontsize=20)
    axs[i].set_xlim(xlims[i])
    axs[i].set_ylim(ylims[i])

    # Add selection boundaries
    if abunds_x[i] == 'FE_H' and abunds_y[i] == 'MG_FE':
        # Lane+2022 halo line plus the [Fe/H] = -1 cut
        axs[i].plot(np.arange(-3,2,1.),
                    pplot.line_equation(np.arange(-3,2,1.),
                                        alpha_fe_halo_line_m,
                                        alpha_fe_halo_line_b),
                    linestyle='solid', color='Black', linewidth=2.,
                    zorder=6)
        axs[i].axvline(-1, linestyle='solid', color='Black', linewidth=2.,
                       zorder=6)
    if abunds_x[i] == 'AL_FE' and abunds_y[i] == 'MG_MN':
        # Horta+21 boundary drawn as its full polyline: horizontal out to
        # [Al/Fe] = -0.06, then the sloped segment. (An axhline would wrongly
        # continue the horizontal part across the whole panel.)
        axs[i].plot(mg_mn_al_fe_halo_line_x, mg_mn_al_fe_halo_line_y,
                    linestyle='solid', color='Black', linewidth=2., zorder=6)

    if abunds_x[i] == 'FE_H' and abunds_y[i] == 'AL_FE':
        # Belokurov+22 boundaries. Only the first two segments of the '_acc'
        # polyline are drawn: its last segment coincides with the last
        # segment of '_all'.
        for j in range(len(al_fe_halo_line_x_all)-1):
            axs[i].plot([al_fe_halo_line_x_all[j],al_fe_halo_line_x_all[j+1]],
                        [al_fe_halo_line_y_all[j],al_fe_halo_line_y_all[j+1]],
                        linestyle='dashed', color='Black', linewidth=2.,
                        zorder=6)
        for j in range(2):
            axs[i].plot([al_fe_halo_line_x_acc[j],al_fe_halo_line_x_acc[j+1]],
                        [al_fe_halo_line_y_acc[j],al_fe_halo_line_y_acc[j+1]],
                        linestyle='dashed', color='Black', linewidth=2.,
                        zorder=6)
        for j in range(len(al_fe_halo_line_x_ext)-1):
            axs[i].plot([al_fe_halo_line_x_ext[j],al_fe_halo_line_x_ext[j+1]],
                        [al_fe_halo_line_y_ext[j],al_fe_halo_line_y_ext[j+1]],
                        linestyle='dotted', color='Black', linewidth=2.,
                        zorder=6)

fig.tight_layout()
fig.subplots_adjust(hspace=0.35, wspace=0.25, left=0.09, right=0.9)

# One eccentricity colourbar shared by all panels (they use the same cmap,
# vmin and vmax). It is added once, after the layout calls, instead of
# stacking a new colourbar axis on every loop iteration.
cax = fig.add_axes([0.92,0.2,0.025,0.7])
cbar = fig.colorbar(pts, cax=cax)
cbar.ax.tick_params(labelsize=12)
cbar.set_label(r'eccentricity', fontsize=16)
fig.show()

### Mask the halo sample
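Some of these cuts (e.g. the [Al/Fe]/[Fe/H] quality cut against -9999 placeholder values) should be redundant with the cleaning done upstream.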

In [None]:
# Mask the data based on Al/Fe and Fe/H
print('Making [Al/Fe]-[Fe/H] masks...')
allstar_omask_feh = allstar_omask['FE_H']
allstar_omask_alfe = allstar_omask['AL_FE']

# Chemical masks in the [Al/Fe]-[Fe/H] plane, keeping stars on the 'left' of
# each boundary: one for the whole halo, one for its accreted (lower [Al/Fe])
# portion.
halo_abund_mask = putil.lines_boundary_mask(
    allstar_omask_feh, allstar_omask_alfe,
    al_fe_halo_line_x_all, al_fe_halo_line_y_all, where='left')
acc_abund_mask = putil.lines_boundary_mask(
    allstar_omask_feh, allstar_omask_alfe,
    al_fe_halo_line_x_acc, al_fe_halo_line_y_acc, where='left')

# Surface-gravity window (exclusive bounds). It is applied here rather than in
# the upstream notebook so it is easy to change.
logg_mask = ((allstar_omask['LOGG'] > logg_min)
             & (allstar_omask['LOGG'] < logg_max))

# Drop -9999 placeholder abundances (expected to be redundant with upstream
# cleaning)
quality_mask = (allstar_omask_alfe > -9999.) & (allstar_omask_feh > -9999.)

# Final whole-halo and accreted-halo selections
hmask = halo_abund_mask & logg_mask & quality_mask
accmask = acc_abund_mask & logg_mask & quality_mask

# Persist the selected APOGEE IDs and the boolean masks themselves
apogee_id_halo = allstar_omask['APOGEE_ID'].astype(str)[hmask]
apogee_id_halo_acc = allstar_omask['APOGEE_ID'].astype(str)[accmask]
for save_name, save_arr in (('halo_apogee_ids.npy', apogee_id_halo),
                            ('halo_apogee_ids_acc.npy', apogee_id_halo_acc),
                            ('halo_apogee_mask.npy', hmask),
                            ('halo_apogee_mask_acc.npy', accmask)):
    np.save(gap_dir+save_name, save_arr)
print('done masking')

In [None]:
fig = plt.figure(figsize=(15,5))
axs = fig.subplots(nrows=1, ncols=3).flatten()

# Abundance planes shown in each panel, with their labels and axis limits
abunds_x = ['FE_H','FE_H','AL_FE']
abunds_y = ['AL_FE','MG_FE','MG_MN']
abunds_xlabels = ['[Fe/H]', '[Fe/H]', '[Al/Fe]']
abunds_ylabels = ['[Al/Fe]', '[Mg/Fe]', '[Mg/Mn]']
xlims = [ [-2.5,0.75], [-2.5,0.75], [-1.1,1.1] ]
ylims = [ [-0.6,0.6], [-0.4,0.6], [-0.8,1.3] ]
nbins = 50 # Bins per axis for the density contours

# Eccentricity ordering of the halo-masked stars. It does not depend on the
# panel, so compute it once: low-eccentricity stars are drawn first and
# high-eccentricity stars end up on top.
ecc_sort_mask = np.argsort(eELzs_omask[0][hmask])

for i in range(len(axs)):
    # Data
    xplot,xplot_err = putil.get_metallicity(allstar_omask,abunds_x[i])
    yplot,yplot_err = putil.get_metallicity(allstar_omask,abunds_y[i])

    # Halo-masked stars coloured by eccentricity; all other stars in grey
    pts = axs[i].scatter(xplot[hmask][ecc_sort_mask],
                         yplot[hmask][ecc_sort_mask],
                         c=eELzs_omask[0][hmask][ecc_sort_mask],
                         cmap=rainbow_cmap, edgecolors='None', vmin=0, vmax=1,
                         marker='o', s=5, zorder=4, rasterized=True)
    axs[i].scatter(xplot[~hmask], yplot[~hmask],
                   c='Grey',
                   edgecolors='None', marker='o', s=2,
                   zorder=3, rasterized=True)

    # Add density contours of the whole sample. The histogram range is padded
    # by one nominal bin on each side so the contours reach the axis edges.
    xbinsize = (xlims[i][1]-xlims[i][0])/nbins
    ybinsize = (ylims[i][1]-ylims[i][0])/nbins
    hist,xedge,yedge = np.histogram2d(xplot,yplot,bins=nbins,
                                      range=[[xlims[i][0]-xbinsize,
                                              xlims[i][1]+xbinsize],
                                             [ylims[i][0]-ybinsize,
                                              ylims[i][1]+ybinsize]])
    # Contour grid at the true bin centres, i.e. the midpoints of the edges
    # returned by histogram2d. Because of the padded range the actual bins are
    # wider than xbinsize/ybinsize, so "edge + binsize" is not the centre.
    xcents = 0.5*(xedge[:-1]+xedge[1:])
    ycents = 0.5*(yedge[:-1]+yedge[1:])
    xcnt,ycnt = np.meshgrid(xcents,ycents)
    axs[i].contour(xcnt, ycnt, hist.T, colors='Black', levels=[50,100,500],
                   zorder=5, alpha=0.5)

    axs[i].set_xlabel(abunds_xlabels[i], fontsize=20)
    axs[i].set_ylabel(abunds_ylabels[i], fontsize=20)
    axs[i].set_xlim(xlims[i])
    axs[i].set_ylim(ylims[i])

    # Add selection boundaries
    if abunds_x[i] == 'FE_H' and abunds_y[i] == 'MG_FE':
        # Lane+2022 halo line plus the [Fe/H] = -1 cut
        axs[i].plot(np.arange(-3,2,1.),
                    pplot.line_equation(np.arange(-3,2,1.),
                                        alpha_fe_halo_line_m,
                                        alpha_fe_halo_line_b),
                    linestyle='solid', color='Black', linewidth=2.,
                    zorder=6)
        axs[i].axvline(-1, linestyle='solid', color='Black', linewidth=2.,
                       zorder=6)
    if abunds_x[i] == 'AL_FE' and abunds_y[i] == 'MG_MN':
        # Horta+21 boundary drawn as its full polyline: horizontal out to
        # [Al/Fe] = -0.06, then the sloped segment. (An axhline would wrongly
        # continue the horizontal part across the whole panel.)
        axs[i].plot(mg_mn_al_fe_halo_line_x, mg_mn_al_fe_halo_line_y,
                    linestyle='solid', color='Black', linewidth=2., zorder=6)

    if abunds_x[i] == 'FE_H' and abunds_y[i] == 'AL_FE':
        # Belokurov+22 boundaries. Only the first two segments of the '_acc'
        # polyline are drawn: its last segment coincides with the last
        # segment of '_all'.
        for j in range(len(al_fe_halo_line_x_all)-1):
            axs[i].plot([al_fe_halo_line_x_all[j],al_fe_halo_line_x_all[j+1]],
                        [al_fe_halo_line_y_all[j],al_fe_halo_line_y_all[j+1]],
                        linestyle='dashed', color='Black', linewidth=2.,
                        zorder=6)
        for j in range(2):
            axs[i].plot([al_fe_halo_line_x_acc[j],al_fe_halo_line_x_acc[j+1]],
                        [al_fe_halo_line_y_acc[j],al_fe_halo_line_y_acc[j+1]],
                        linestyle='dashed', color='Black', linewidth=2.,
                        zorder=6)
        for j in range(len(al_fe_halo_line_x_ext)-1):
            axs[i].plot([al_fe_halo_line_x_ext[j],al_fe_halo_line_x_ext[j+1]],
                        [al_fe_halo_line_y_ext[j],al_fe_halo_line_y_ext[j+1]],
                        linestyle='dotted', color='Black', linewidth=1.,
                        zorder=6)

fig.tight_layout()
fig.subplots_adjust(hspace=0.35, wspace=0.25, left=0.09, right=0.9)

# One eccentricity colourbar shared by all panels (they use the same cmap,
# vmin and vmax). It is added once, after the layout calls, instead of
# stacking a new colourbar axis on every loop iteration.
cax = fig.add_axes([0.92,0.2,0.025,0.7])
cbar = fig.colorbar(pts, cax=cax)
cbar.ax.tick_params(labelsize=12)
cbar.set_label(r'eccentricity', fontsize=16)
fig.show()

### Look at kinematics

In [None]:
# Kinematic spaces to examine, in panel order
kin_spaces = ['vRvT','Toomre','ELz','JRLz','eLz','AD']
n_kin_spaces = len(kin_spaces)

# Axis half-range of the action-diamond ('AD') panel, also used as the size
# of the diamond boundary drawn there
AD_dedge = 1.2

# Axis limits per space: [[xmin, xmax], [ymin, ymax]]
lims_dict = dict(
    vRvT=[[-450,450], [-450,450]],
    Toomre=[[-450,450], [0,450]],
    ELz=[[-4500,4500], [-2.,-0.]],
    JRLz=[[-4500,4500], [0,80]],
    eLz=[[-4500,4500], [-0.05,1.05]],
    AD=[[-AD_dedge,AD_dedge], [-AD_dedge,AD_dedge]],
)

# Axis labels per space: [xlabel, ylabel]
labels_dict = dict(
    vRvT=[r'$v_{R}$', r'$v_{T}$'],
    Toomre=[r'$v_{T}$', r'$(v_{R}^{2}+v_{z}^{2})^{1/2}$'],
    ELz=[r'$L_{z}$', r'$(E-\Phi_{0})/10^{5}$'],
    JRLz=[r'$L_{z}$', r'$\sqrt{J_{R}}$'],
    eLz=[r'$L_{z}$', r'eccentricity'],
    AD=[r'$L_{z}/J_{tot}$', r'$(J_{z}-J_{R})/J_\mathrm{tot}$'],
)

# High-beta kinematic selections from Lane+2022, keyed by kinematic space.
# Each entry is either [[xcent, ycent], [xmajor, ymajor]] (ellipse) or
# [[x1, x2], [y1, y2]] (box).
halo_selection_survey_dict = putil.lane2022_kinematic_selections(version='current')

# Drawing style for the selection boundaries
halo_selection_kws = dict(linestyle='dashed', linewidth=2.5, facecolor='None',
                          edgecolor='Black', color='Black',
                          label=r'High-$\beta$')

# Misc
label_fontsize = 12
nbins = 40

In [None]:
# Figure: log-density of the whole sample in each kinematic space, with the
# Lane+2022 high-beta selection boundaries overlaid
fig = plt.figure()
fig.set_size_inches(10,5)
axs = fig.subplots(nrows=2, ncols=3).flatten()
for i in range( n_kin_spaces ):

    # Limits and labels
    xlim,ylim = lims_dict[kin_spaces[i]]
    xlabel,ylabel = labels_dict[kin_spaces[i]]

    # 2D histogram of the whole sample in this kinematic space
    xplot,yplot = pplot.get_plottable_data( [os_omask,], [eELzs_omask,],
        [accs_omask,], np.array([1,]), kin_spaces[i], phi0=phi0, absolute=True)
    hist = np.histogram2d(xplot,yplot,bins=nbins,
                          range=[[xlim[0],xlim[1]],[ylim[0],ylim[1]]])[0]
    # Empty bins give log10(0) = -inf. imshow masks these non-finite pixels,
    # so only the divide-by-zero RuntimeWarning needs silencing.
    with np.errstate(divide='ignore'):
        loghist = np.log10(np.rot90(hist))
    img = axs[i].imshow(loghist, cmap=rainbow_cmap, vmin=-0.5,
                        vmax=3.5, aspect='auto',
                        extent=(xlim[0],xlim[1],ylim[0],ylim[1]))

    # Decoration
    pplot.axis_limits_and_labels(axs[i], xlim, ylim, xlabel, ylabel,
        mixture_text=None, is_left_edge=True, is_top_edge=False,
        label_fontsize=label_fontsize)
    if kin_spaces[i] == 'AD':
        pplot.add_diamond_boundary(axs[i], dedge=AD_dedge)
    pplot.add_selection_boundaries(axs[i],
        halo_selection_survey_dict[kin_spaces[i]],
        plot_kwargs=halo_selection_kws, plot_cent=False)

fig.tight_layout()
fig.subplots_adjust(hspace=0.35, wspace=0.35, left=0.09, right=0.9)

# One shared log-count colourbar (all panels use the same cmap, vmin and
# vmax), added once after the layout calls instead of once per panel
cax = fig.add_axes([0.92,0.15,0.025,0.7])
cbar = fig.colorbar(img, cax=cax)
cbar.ax.tick_params(labelsize=12)
cbar.set_label(r'$\log_{10} N$', fontsize=16)
fig.show()

## Use different kinematic spaces to separate GS/E

### Functions

In [None]:
def get_selection_masks_and_plot(selec_spaces, plot_kinematics=True,
                                 plot_chemistry_scatter=True,
                                 plot_chemistry_hist=True,
                                 plot_feh_hist=True,
                                 plot_peri_apo=True,
                                 plot_feh_energy=True):
    '''get_selection_masks_and_plot:
    
    Function to get the selection masks and make plots of the selection for 
    various kinematic selections. Only argument is select spaces and 
    booleans to decide to make plots, most variables are globals.
    
    Args:
        selec_spaces (arr) - List of selection spaces, composite spaces are
            allowed
        plot_kinematics (bool) - Plot kinematics?
        plot_chemistry_scatter (bool) - Plot chemistry as scatter?
        plot_chemistry_hist (bool) - Plot chemistry as histogram?
        plot_peri_apo (bool) - Plot pericenter and apocenter?
        plot_feh_energy (bool) - Plot energy and [Fe/H]
        
    Returns:
        
    '''
    # Info about selection
    n_selec_spaces = len(selec_spaces)
    selec_spaces_suffix = '-'.join(selec_spaces)
    
    # Perform the selection on the kinematic spaces
    for i in range(n_selec_spaces):

        xplot,yplot = pplot.get_plottable_data( [os_omask,], [eELzs_omask,], 
            [accs_omask,], np.array([1,]), selec_spaces[i], phi0=phi0, 
            absolute=True)
        this_selection = halo_selection_survey_dict[selec_spaces[i]]

        # If this is a combined selection, handle multiple spaces
        if i == 0:
            kmask = pplot.is_in_scaled_selection(xplot, yplot, 
                this_selection, factor=[1.,1.])
        else:
            kmask &= pplot.is_in_scaled_selection(xplot, yplot, this_selection, 
                factor=[1.,1.])
    
    # Combine the whole halo and accreted masks with the kinematic mask
    #h_kmask = hmask & kmask
    acc_kmask = accmask & kmask
    #h_nacc_kmask = hmask & ~accmask & kmask
    
    ## Old way of doing things here: not going to change.. but will compare
    ## with a better, masked way of doing things above

    # Use the selection to get the high-beta stars
    where_in_hb = np.where(kmask)[0]
    gaia_hb = gaia_omask[where_in_hb]
    allstar_hb = allstar_omask[where_in_hb]
    os_hb = os_omask[where_in_hb]
    eELzs_hb = eELzs_omask[:,where_in_hb]
    accs_hb = accs_omask[:,where_in_hb]
    orbextr_hb = orbextr_omask[:,where_in_hb]
    
    # Custom [Al/Fe]-[Fe/H] mask
    print('Making [Al/Fe]-[Fe/H] mask...')
    # _,_ = putil.get_metallicity(allstar_hb,'')
    feh_hb = allstar_hb['FE_H']
    alfe_hb = allstar_hb['AL_FE']
    dmask = putil.lines_boundary_mask(feh_hb, alfe_hb, al_fe_halo_line_x_acc, 
                                      al_fe_halo_line_y_acc, where='left')
    hkmask = putil.lines_boundary_mask(feh_hb, alfe_hb, al_fe_halo_line_x_all,
                                       al_fe_halo_line_y_all, where='left')
    logg_mask = (allstar_hb['LOGG'] > logg_min) &\
                (allstar_hb['LOGG'] < logg_max)
    dmask &= logg_mask
    hkmask &= logg_mask
    print('done masking')

    # Report number of stars in selections
    print('Number of high-beta stars in selection is: '+str(len(where_in_hb)))
    print('Number of high-beta stars in disk-masked selection is: '\
          +str(np.sum(dmask)))
    print('Number of high-beta stars in halo-masked selection is: '\
          +str(np.sum(hkmask)))
    
    # Save three sets of IDs. One with everything selected w/ the kinematic
    # mask. One also including a cut on Fe/H and high Al/Fe (dmask) and 
    # one including both high and low Al/Fe (hmask)
    hb_id_filename = gap_dir+'hb_apogee_ids_'+selec_spaces_suffix+'.npy'
    hb_dmask_id_filename = gap_dir+'hb_apogee_ids_'+selec_spaces_suffix+\
        '_dmask.npy'
    hb_hmask_id_filename = gap_dir+'hb_apogee_ids_'+selec_spaces_suffix+\
        '_hmask.npy'
    apogee_id_hb = allstar_hb['APOGEE_ID'].astype(str)
    print('Saving high-beta APOGEE IDs to '+hb_id_filename)
    np.save(hb_id_filename,apogee_id_hb)
    print('Saving disk-masked high-beta APOGEE IDs to '+hb_dmask_id_filename)
    np.save(hb_dmask_id_filename,apogee_id_hb[dmask])
    print('Saving halo-masked high-beta APOGEE IDs to '+hb_hmask_id_filename)
    np.save(hb_hmask_id_filename,apogee_id_hb[hkmask])
    print('\n\n')    
    
    # Compare with the other masks
    assert np.all(apogee_id_hb[dmask] ==\
                  allstar_omask['APOGEE_ID'][acc_kmask].astype(str))
    
    # Plot kinematics?
    if plot_kinematics:
        
        fig = plt.figure()
        fig.set_size_inches(10,5)
        axs = fig.subplots(nrows=2, ncols=3).flatten()
        nbins = 40
        
        # Loop over kinematic spaces
        for i in range( n_kin_spaces ):

            # Limits and labels
            xlim,ylim = lims_dict[kin_spaces[i]]
            xlabel,ylabel = labels_dict[kin_spaces[i]]

            # Purity
            xplot,yplot = pplot.get_plottable_data( [os_omask,], [eELzs_omask,], 
                [accs_omask,], np.array([1,]), kin_spaces[i], phi0=phi0,
                absolute=True)
            xplot_hb,yplot_hb = pplot.get_plottable_data( [os_hb,], [eELzs_hb,], 
                [accs_hb,], np.array([1,]), kin_spaces[i], phi0=phi0, 
                absolute=True)
            hist = np.histogram2d(xplot,yplot,bins=nbins,
                                  range=[[xlim[0],xlim[1]],[ylim[0],ylim[1]]])[0]
            img = axs[i].imshow(np.log10(np.rot90(hist)), cmap='Greys', vmin=-0.5, 
                                vmax=3.5, aspect='auto', 
                                extent=(xlim[0],xlim[1],ylim[0],ylim[1]))
            axs[i].scatter(xplot_hb, yplot_hb, s=10, facecolor='Red', edgecolor='None',
                           alpha=0.5, label='GE')

            # Decoration
            pplot.axis_limits_and_labels(axs[i], xlim, ylim, xlabel, ylabel, None,
                is_left_edge=True, is_top_edge=False, label_fontsize=18)
            if kin_spaces[i] == 'AD':
                pplot.add_diamond_boundary(axs[i], dedge=AD_dedge)

            if kin_spaces[i] in selec_spaces:
                halo_selection_kws_use = copy.deepcopy(halo_selection_kws)
                halo_selection_kws_use['edgecolor'] = 'Blue'
                pplot.add_selection_boundaries(axs[i], 
                    halo_selection_survey_dict[kin_spaces[i]],
                    plot_kwargs=halo_selection_kws_use, plot_cent=False)
            else:
                pplot.add_selection_boundaries(axs[i], 
                    halo_selection_survey_dict[kin_spaces[i]],
                    plot_kwargs=halo_selection_kws, plot_cent=False)

        fig.tight_layout()
        fig.subplots_adjust(hspace=0.35, wspace=0.35, left=0.09, right=0.9)
        # fig.savefig(fig_dir+'kinematics_'+selec_spaces_suffix+'.png',dpi=300)
        fig.show()
    
    
    # Plot chemistry with scatter points?
    if plot_chemistry_scatter:
        
        fig = plt.figure(figsize=(15,5))
        axs = fig.subplots(nrows=1, ncols=3).flatten()
        n_bins = 25

        # keywords
        marker_s = 25
        mask_disk = True

        abunds_x = ['FE_H','FE_H','AL_FE']
        abunds_y = ['AL_FE','MG_FE','MG_MN']
        abunds_xlabels = ['[Fe/H]', '[Fe/H]', '[Al/Fe]']
        abunds_ylabels = ['[Al/Fe]', '[Mg/Fe]', '[Mg/Mn]']
        xlims = [ [-2.5,0.75], [-2.5,0.75], [-1.1,1.1] ]
        ylims = [ [-0.6,0.6], [-0.4,0.6], [-0.8,1.3] ]

        ecc_sort = np.argsort(eELzs_omask[0])

        for i in range(len(axs)):
            # Data
            xplot,xplot_err = putil.get_metallicity(allstar_hb,abunds_x[i])
            yplot,yplot_err = putil.get_metallicity(allstar_hb,abunds_y[i])
            # Background
            xplot_bg,_ = putil.get_metallicity(allstar_omask,abunds_x[i])
            yplot_bg,_ = putil.get_metallicity(allstar_omask,abunds_y[i])
 
            if mask_disk:
                pts = axs[i].scatter(xplot[dmask], yplot[dmask], 
                                     c=orbextr_hb[0][dmask], cmap=rainbow_cmap, 
                                     edgecolors='None', vmin=0, vmax=25, 
                                     marker='o', s=marker_s, zorder=5)
                pts = axs[i].scatter(xplot[~dmask], yplot[~dmask], 
                                     c=orbextr_hb[0][~dmask], cmap=rainbow_cmap, 
                                     edgecolors='Black', vmin=0, vmax=25, 
                                     marker='x', s=marker_s/2, zorder=5)
            else:
                pts = axs[i].scatter(xplot, yplot, c=orbextr_hb[0], cmap=rainbow_cmap, 
                                     edgecolors='None', vmin=0, vmax=1, 
                                     marker='o', s=5, zorder=4, rasterized=True)

            hist,_,_ = np.histogram2d( xplot_bg, yplot_bg, bins=n_bins, 
                                      range=[xlims[i],ylims[i]] )
            img = axs[i].imshow(np.rot90(np.log10(hist)), cmap='Greys', vmin=-1, 
                                vmax=5, extent=(xlims[i][0],xlims[i][1],
                                                ylims[i][0],ylims[i][1]), 
                                aspect='auto')

            cax = fig.add_axes([0.92,0.2,0.025,0.7])
            cbar = fig.colorbar(pts, cax=cax)
            cbar.ax.tick_params(labelsize=12) 
            cbar.set_label(r'$z_{max}$', fontsize=16)

            axs[i].set_xlabel(abunds_xlabels[i], fontsize=20)
            axs[i].set_ylabel(abunds_ylabels[i], fontsize=20)
            axs[i].set_xlim(xlims[i])
            axs[i].set_ylim(ylims[i])

            # Add lines
            if abunds_x[i] == 'FE_H' and abunds_y[i] == 'MG_FE':
                axs[i].plot(np.arange(-3,2,1.), 
                            pplot.line_equation(np.arange(-3,2,1.),
                                                alpha_fe_halo_line_m,
                                                alpha_fe_halo_line_b), 
                            linestyle='solid', color='Black', linewidth=2., 
                            zorder=6)
                axs[i].axvline(-1, linestyle='solid', color='Black', linewidth=2., 
                               zorder=6)

            if abunds_x[i] == 'AL_FE' and abunds_y[i] == 'MG_MN':
                axs[i].axhline(mg_mn_al_fe_halo_line_y[0], linestyle='solid', 
                    color='Black', linewidth=2., zorder=6)
                axs[i].plot([mg_mn_al_fe_halo_line_x[1],mg_mn_al_fe_halo_line_x[2]],
                    [mg_mn_al_fe_halo_line_y[1],mg_mn_al_fe_halo_line_y[2]],
                    linestyle='solid', color='Black', linewidth=2., zorder=6)

            if abunds_x[i] == 'FE_H' and abunds_y[i] == 'AL_FE':
                for j in range(len(al_fe_halo_line_x_all)-1):
                    axs[i].plot([al_fe_halo_line_x_all[j],al_fe_halo_line_x_all[j+1]],
                                [al_fe_halo_line_y_all[j],al_fe_halo_line_y_all[j+1]],
                                linestyle='dashed', color='Black', linewidth=2., 
                                zorder=6)
                for j in range(2):
                    axs[i].plot([al_fe_halo_line_x_acc[j],al_fe_halo_line_x_acc[j+1]],
                                [al_fe_halo_line_y_acc[j],al_fe_halo_line_y_acc[j+1]],
                                linestyle='dashed', color='Black', linewidth=2., 
                                zorder=6)
                for j in range(len(al_fe_halo_line_x_ext)-1):
                    axs[i].plot([al_fe_halo_line_x_ext[j],al_fe_halo_line_x_ext[j+1]],
                                [al_fe_halo_line_y_ext[j],al_fe_halo_line_y_ext[j+1]],
                                linestyle='dotted', color='Black', linewidth=1., 
                                zorder=6)

        fig.tight_layout()
        fig.subplots_adjust(hspace=0.35, wspace=0.25, left=0.09, right=0.9)
        # fig.savefig(fig_dir+'chemistry_'+selec_spaces_suffix+'.png',dpi=300)
        fig.show()
    
    
    # Plot chemistry with histogram?
    if plot_chemistry_hist:
        
        fig = plt.figure(figsize=(15,5))
        axs = fig.subplots(nrows=1, ncols=3).flatten()
        n_bins = 25

        # keywords
        mask_disk = True

        abunds_x = ['FE_H','FE_H','AL_FE']
        abunds_y = ['AL_FE','MG_FE','MG_MN']
        abunds_xlabels = ['[Fe/H]', '[Fe/H]', '[Al/Fe]']
        abunds_ylabels = ['[Al/Fe]', '[Mg/Fe]', '[Mg/Mn]']
        xlims = [ [-2.5,0.75], [-2.5,0.75], [-1.1,1.1] ]
        ylims = [ [-0.6,0.6], [-0.4,0.6], [-0.8,1.3] ]

        ecc_sort = np.argsort(eELzs_omask[0])

        for i in range(len(axs)):
            # Data
            xplot,xplot_err = putil.get_metallicity(allstar_hb,abunds_x[i])
            yplot,yplot_err = putil.get_metallicity(allstar_hb,abunds_y[i])
            # Background
            xplot_bg,_ = putil.get_metallicity(allstar_omask,abunds_x[i])
            yplot_bg,_ = putil.get_metallicity(allstar_omask,abunds_y[i])

            if mask_disk:
                hist_hb,_,_ = np.histogram2d(xplot[dmask], yplot[dmask],
                                             bins=n_bins, 
                                             range=[xlims[i],ylims[i]])
            else:
                hist_hb,_,_ = np.histogram2d(xplot, yplot, bins=n_bins, 
                                             range=[xlims[i],ylims[i]])

            hist,_,_ = np.histogram2d( xplot_bg, yplot_bg, bins=n_bins, 
                                      range=[xlims[i],ylims[i]] )
            hist_f = 100*hist_hb/hist
            
            img = axs[i].imshow(np.rot90(hist_f), cmap='rainbow', 
                                vmin=0., vmax=50., 
                                extent=(xlims[i][0],xlims[i][1],
                                        ylims[i][0],ylims[i][1]), 
                                aspect='auto')

            cax = fig.add_axes([0.92,0.2,0.025,0.7])
            cbar = fig.colorbar(img, cax=cax)
            cbar.ax.tick_params(labelsize=12) 
            cbar.set_label(r'per cent', fontsize=16)

            axs[i].set_xlabel(abunds_xlabels[i], fontsize=20)
            axs[i].set_ylabel(abunds_ylabels[i], fontsize=20)
            axs[i].set_xlim(xlims[i])
            axs[i].set_ylim(ylims[i])

            # Add lines
            if abunds_x[i] == 'FE_H' and abunds_y[i] == 'MG_FE':
                axs[i].plot(np.arange(-3,2,1.), 
                            pplot.line_equation(np.arange(-3,2,1.),
                                                alpha_fe_halo_line_m,
                                                alpha_fe_halo_line_b), 
                            linestyle='solid', color='Black', linewidth=2., 
                            zorder=6)
                axs[i].axvline(-1, linestyle='solid', color='Black', linewidth=2., 
                               zorder=6)
                
            if abunds_x[i] == 'AL_FE' and abunds_y[i] == 'MG_MN':
                axs[i].axhline(mg_mn_al_fe_halo_line_y[0], linestyle='solid', 
                    color='Black', linewidth=2., zorder=6)
                axs[i].plot([mg_mn_al_fe_halo_line_x[1],mg_mn_al_fe_halo_line_x[2]],
                    [mg_mn_al_fe_halo_line_y[1],mg_mn_al_fe_halo_line_y[2]],
                    linestyle='solid', color='Black', linewidth=2., zorder=6)

            if abunds_x[i] == 'FE_H' and abunds_y[i] == 'AL_FE':
                for j in range(len(al_fe_halo_line_x_all)-1):
                    axs[i].plot([al_fe_halo_line_x_all[j],al_fe_halo_line_x_all[j+1]],
                                [al_fe_halo_line_y_all[j],al_fe_halo_line_y_all[j+1]],
                                linestyle='dashed', color='Black', linewidth=2., 
                                zorder=6)
                for j in range(2):
                    axs[i].plot([al_fe_halo_line_x_acc[j],al_fe_halo_line_x_acc[j+1]],
                                [al_fe_halo_line_y_acc[j],al_fe_halo_line_y_acc[j+1]],
                                linestyle='dashed', color='Black', linewidth=2., 
                                zorder=6)
                for j in range(len(al_fe_halo_line_x_ext)-1):
                    axs[i].plot([al_fe_halo_line_x_ext[j],al_fe_halo_line_x_ext[j+1]],
                                [al_fe_halo_line_y_ext[j],al_fe_halo_line_y_ext[j+1]],
                                linestyle='dotted', color='Black', linewidth=1., 
                                zorder=6)

        fig.tight_layout()
        fig.subplots_adjust(hspace=0.35, wspace=0.25, left=0.09, right=0.9)
        # fig.savefig(fig_dir+'chemistry_'+selec_spaces_suffix+'.png',dpi=300)
        fig.show()
    
    
    # Plot a histogram of [Fe/H]
    if plot_feh_hist:
        fig = plt.figure(figsize=(5,3))
        ax = fig.add_subplot(111)
        
        # Keywords
        feh_range = [-3,0]
        n_bins = 30
        mask_disk = True
        
        feh,_ = putil.get_metallicity(allstar_hb,'FE_H')
        
        if mask_disk:
            ax.hist(feh[dmask], range=feh_range, bins=n_bins, histtype='step', 
                    color='DodgerBlue', label='inside [Al/Fe]-[Fe/H] mask')
            ax.hist(feh[~dmask], range=feh_range, bins=n_bins, histtype='step',
                    color='Red', label='outside [Al/Fe]-[Fe/H] mask')
        else:
            ax.hist(feh, range=feh_range, bins=n_bins, histtype='step',
                    color='DodgerBlue')
            
        ax.set_xlabel('[Fe/H]')
        ax.set_ylabel('N')
    
    
    # Plot the pericenter / apocenter of the selected stars?
    if plot_peri_apo:
        
        fig = plt.figure(figsize=(5,3))
        gs = fig.add_gridspec(nrows=4,ncols=4)
        ax = fig.add_subplot(gs[1:,:-1])
        axrp = fig.add_subplot(gs[0,:-1])
        axra = fig.add_subplot(gs[1:,-1])

        # Keywords
        marker_s = 25
        mask_disk = True
        n_bins = 25
        
        # Data
        feh,_ = putil.get_metallicity(allstar_hb,'FE_H')
        rp = orbextr_hb[1]
        ra = orbextr_hb[2]
        rp_range = [-3.5,1.5]
        ra_range = [0.5,2.5]

        if mask_disk:
            pts = ax.scatter(np.log10(rp)[dmask], np.log10(ra)[dmask], 
                             c=feh[dmask], cmap=rainbow_cmap, 
                             edgecolors='None', vmin=-3, vmax=0., 
                             marker='o', s=marker_s, zorder=5)
            pts = ax.scatter(np.log10(rp)[~dmask], np.log10(ra)[~dmask], 
                             c=feh[~dmask], cmap=rainbow_cmap, 
                             edgecolors='Black', vmin=-3, vmax=0., 
                             marker='x', s=marker_s/2, zorder=5)
            axrp.hist(np.log10(rp)[dmask], range=rp_range, bins=n_bins,
                      histtype='step', color='DodgerBlue')
            axrp.hist(np.log10(rp)[~dmask], range=rp_range, bins=n_bins,
                      histtype='step', color='Red')
            
            axra.hist(np.log10(ra)[dmask], range=ra_range, bins=n_bins,
                      histtype='step', color='DodgerBlue', 
                      orientation='horizontal')
            axra.hist(np.log10(ra)[~dmask], range=ra_range, bins=n_bins,
                      histtype='step', color='Red',
                      orientation='horizontal')
            
            axrp.tick_params(axis='both', labelbottom=False, labelleft=False)
            
            axra.tick_params(axis='both', labelbottom=False, labelleft=False)
        else:
            pts = ax.scatter(np.log10(rp), np.log10(ra), c=feh, cmap=rainbow_cmap, 
                             edgecolors='None', vmin=-3, vmax=0, 
                             marker='o', s=5, zorder=4, rasterized=True)
        
        ax.set_xlabel('pericenter/kpc')
        ax.set_ylabel('apocenter/kpc')
        ax.set_xlim(rp_range)
        ax.set_ylim(ra_range)
        axrp.set_xlim(rp_range)
        axra.set_ylim(ra_range)
        #cbar = fig.colorbar(pts, ax=ax)
        #cbar.ax.tick_params(labelsize=12) 
        #cbar.set_label(r'[Fe/H]', fontsize=16)
    
    
    if plot_feh_energy:
        
        fig = plt.figure(figsize=(10,15))
        axs = fig.subplots(nrows=3,ncols=2)
        
        # Keywords
        marker_s = 25
        
        abunds = ['FE_H','MG_FE']
        abunds_labels = ['[Fe/H]', '[Mg/Fe]']
        abunds_lims = [ [-3.,0.], [-0.4,0.6] ]
        
        for i in range(2):
            abund_halo = putil.get_metallicity(allstar_omask,abunds[i])[0][hmask]
            E_halo = (eELzs_omask[1][hmask]-phi0)/1e5
            abund_hb = putil.get_metallicity(allstar_hb,abunds[i])[0]
            E_hb = (eELzs_hb[1]-phi0)/1e5
            
            colors = ['Black','DodgerBlue','Red']
            
            axs[0,i].scatter(E_halo, abund_halo, c='Black', marker='o', 
                             s=marker_s, alpha=0.1)
            
            axs[1,i].scatter(E_hb[dmask], abund_hb[dmask], c='DodgerBlue', 
                             marker='o', s=marker_s, alpha=1.)
            
            axs[2,i].scatter(E_hb[~dmask], abund_hb[~dmask], c='Red', 
                             marker='x', s=marker_s, alpha=1.)
            
            for j in range(3):
                axs[j,i].set_xlabel('Energy')
                axs[j,i].set_ylabel(abunds_labels[i])
                
    
    return None

In [None]:
# Toggles forwarded to get_selection_masks_and_plot; each flag switches one
# diagnostic figure on (True) or off (False) for every selection space below.
plot_kwargs = dict(
    plot_kinematics=True,
    plot_chemistry_scatter=True,
    plot_chemistry_hist=True,
    plot_feh_hist=True,
    plot_peri_apo=True,
    plot_feh_energy=False,
)

### Action Diamond

In [None]:
# Action-diamond selection: build the masks and draw the enabled diagnostics
selec_spaces = ['AD']
get_selection_masks_and_plot(selec_spaces, **plot_kwargs)

### e-Lz

In [None]:
# Eccentricity / L_z selection: build the masks and draw the enabled diagnostics
selec_spaces = ['eLz']
get_selection_masks_and_plot(selec_spaces, **plot_kwargs)

### JRLz

In [None]:
# J_R / L_z selection: build the masks and draw the enabled diagnostics
selec_spaces = ['JRLz']
get_selection_masks_and_plot(selec_spaces, **plot_kwargs)