In [None]:
# ------------------------------------------------------------------------
#
# TITLE - plot_density_contours.ipynb
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Make a figure of the density contours for the paper
'''

__author__ = "James Lane"

In [None]:
### Imports

# Basic
import os, sys, pdb, copy
import numpy as np

# Matplotlib and plotting 
import matplotlib
import matplotlib.pyplot as plt

# Project specific
sys.path.insert(0,'../../src/')
from ges_mass import mass as pmass
from ges_mass import densprofiles as pdens
from ges_mass import util as putil
from ges_mass import plot as pplot

### Notebook setup

%matplotlib inline
plt.style.use('../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

### Keywords, Pathing, Loading, Data Preparation

In [None]:
# %load ../../src/nb_modules/keywords_pathing_loading_data_prep.py
## Keywords
# Read the project configuration and unpack the keywords this notebook uses.
# The order of the unpacked names below must match the order of `keywords`.
cdict = putil.load_config_to_dict()
keywords = ['BASE_DIR','APOGEE_DR','APOGEE_RESULTS_VERS','GAIA_DR','NDMOD',
            'DMOD_MIN','DMOD_MAX','LOGG_MIN','LOGG_MAX','FEH_MIN','FEH_MAX',
            'FEH_MIN_GSE','FEH_MAX_GSE','DF_VERSION','KSF_VERSION','NPROCS',
            'RO','VO','ZO']
base_dir,apogee_dr,apogee_results_vers,gaia_dr,ndmod,dmod_min,dmod_max,\
    logg_min,logg_max,feh_min,feh_max,feh_min_gse,feh_max_gse,df_version,\
    ksf_version,nprocs,ro,vo,zo = putil.parse_config_dict(cdict,keywords)
logg_range = [logg_min,logg_max]
feh_range = [feh_min,feh_max]
feh_range_gse = [feh_min_gse,feh_max_gse]
# Lower bound of the general range through the upper bound of the GS/E range
feh_range_all = [feh_min,feh_max_gse]
# NOTE: the [Fe/H] range for an individual fit (feh_range_fit) is not chosen
# here; it is selected from feh_range_gse / feh_range_all according to
# fit_type in each cell that constructs a HaloFit.


## Pathing
fit_paths = putil.prepare_paths(base_dir,apogee_dr,apogee_results_vers,gaia_dr,
                                df_version,ksf_version)
data_dir,version_dir,ga_dir,gap_dir,df_dir,ksf_dir,fit_dir = fit_paths

## Filenames
fit_filenames = putil.prepare_filenames(ga_dir,gap_dir,feh_range_gse)
apogee_SF_filename,apogee_effSF_filename,apogee_effSF_mask_filename,\
    iso_grid_filename,clean_kinematics_filename = fit_filenames

## File loading and data preparation
fit_stuff,other_stuff = putil.prepare_fitting(fit_filenames,
    [ndmod,dmod_min,dmod_max], ro,zo,return_other=True)
apogee_effSF_mask,dmap,iso_grid,jkmins,dmods,ds,effsel_grid,apof,\
    allstar_nomask,orbs_nomask = fit_stuff
# Unpack the effective-selection grid. NOTE: later plotting cells rebind
# Rgrid/phigrid/zgrid, so use effsel_grid directly if the originals are needed.
Rgrid,phigrid,zgrid = effsel_grid

# ## Load the distribution functions (currently disabled)
# df_filename = df_dir+'dfs.pkl'
# betas = [0.3,0.8]
# dfs = putil.load_distribution_functions(df_filename, betas)

In [None]:
# Output directory for saved figures. Create it up front so that
# fig.savefig(fig_dir+...) does not fail on a fresh checkout.
fig_dir = './fig/'
os.makedirs(fig_dir, exist_ok=True)

### Global Parameters

In [None]:
# %load ../../src/nb_modules/global_fitting_params.py
## General options
verbose = True

## Values forwarded to HaloFit. Ordering follows HaloFit.__init__.
# Supplied per-fit at runtime (not set here): fit_type, densfunc, selec,
# usr_log_prior, feh_range, version.
# Loaded in earlier cells: allstar/orbs, effsel arrays, dmods, jkmins,
# iso_filename, logg_range, directories, and ro/vo/zo.
init = None
init_type = 'ML'
mask_disk = True
mask_halo = True

# MCMC sampler sizing
nwalkers = 100
nit = int(2e3)
ncut = int(1e3)

# Mass-integration settings
n_mass = 5000 # alternative: int(nwalkers*(nit-ncut))
int_r_range = [2.,70.]
iso = None # Will read from iso_grid_filename

hf_kwargs = dict(
    # HaloFit parameters
    allstar=allstar_nomask,
    orbs=orbs_nomask,
    init=init,
    init_type=init_type,
    mask_disk=mask_disk,
    mask_halo=mask_halo,
    # _HaloFit parameters
    effsel=apof,
    effsel_mask=apogee_effSF_mask,
    effsel_grid=effsel_grid,
    dmods=dmods,
    nwalkers=nwalkers,
    nit=nit,
    ncut=ncut,
    n_mass=n_mass,
    int_r_range=int_r_range,
    iso=iso,
    iso_filename=iso_grid_filename,
    jkmins=jkmins,
    logg_range=logg_range,
    fit_dir=fit_dir,
    gap_dir=gap_dir,
    ksf_dir=ksf_dir,
    verbose=verbose,
    ro=ro,
    vo=vo,
    zo=zo,
)

## Options for pmass.fit(). nprocs comes from the config file and is
## normally passed at runtime.
force_fit = True
mle_init = True
just_mle = False
return_walkers = True
optimizer_method = 'Powell'
mass_int_type = 'spherical_grid'
batch_masses = True
make_ml_aic_bic = True
calculate_masses = True
post_optimization = True
mcmc_diagnostic = True

fit_kwargs = dict(
    force_fit=force_fit,
    mle_init=mle_init,
    just_mle=just_mle,
    return_walkers=return_walkers,
    optimizer_method=optimizer_method,
    mass_int_type=mass_int_type,
    batch_masses=batch_masses,
    make_ml_aic_bic=make_ml_aic_bic,
    calculate_masses=calculate_masses,
    post_optimization=post_optimization,
    mcmc_diagnostic=mcmc_diagnostic,
)

### Convenience function

In [None]:
def plot_fit_contours(params,densfunc,contour=True,show_sun=True,
    show_major_axis=False,left_handed=True,rmin_cut=5,):
    '''plot_fit_contours:

    Plot density slices in the X-Y, X-Z, and Y-Z planes for a given set of
    parameters. Each panel is a slice through the origin (the out-of-plane
    coordinate is zero), sampled on a regular cell-centred grid covering
    [-50,50] kpc with 500 points per axis.

    Args:
        params (array-like): The parameters passed to densfunc.
        densfunc (callable): The density function to use from pdens, called
            as densfunc(R,phi,z,params).
        contour (bool): If True draw log10(density) contours, otherwise show
            log10(density) as an image.
        show_sun (bool): Whether to plot a point for the Sun.
        show_major_axis (bool): Whether to plot the principal axis returned
            by pplot.make_principal_axis.
        left_handed (bool): Whether to reverse the Y axes to emphasize that
            the coordinate system is left-handed.
        rmin_cut (float or None): In-plane radius [kpc] inside which the
            density is set to NaN, to avoid the central flattening of the
            density profile. None disables the cut.

    Returns:
        fig (matplotlib.figure.Figure): The figure object.
        axs (numpy.ndarray): The three matplotlib Axes, one per panel.
    '''
    # Plotting keywords
    scale = 50. # Panels span [-scale,scale] kpc on both axes
    n = 500 # Grid points along each axis
    sun_kwargs = {'color':'DarkOrange',
                  'marker':r'$\odot$',
                  's':100,
                  'zorder':5,
                  'linewidths':2,}
    major_axis_kwargs = {'color':'Black',
                         'linestyle':'dashed'}
    contour_levels = [-4.5,-3.5,-2.5,-1.5,-0.5] # in log10(density)
    # The cell-centred grid below tiles [-scale,scale] exactly, so the image
    # pixel edges sit at +/-scale. Without an extent imshow would draw in
    # pixel-index coordinates, which fall outside the axis limits set below.
    imshow_kwargs = {'extent':[-scale,scale,-scale,scale]}

    # Fontsize and formatting keywords
    textwidth = 508./72.27 # Whole page. In inches, from pt
    ylabel_fs = 12
    xlabel_fs = 12
    ticklabel_fs = 10
    ticks = [-40,-20,0,20,40]

    # Solar position [kpc], keyed by axis name
    sun_pos = {'X':8.275, 'Y':0., 'Z':0.0208}

    fig = plt.figure(figsize=(textwidth,3))
    axs = fig.subplots(nrows=1,ncols=3)

    # (horizontal, vertical) axis names for each panel
    ind_str = [['X','Y'],['X','Z'],['Y','Z']]

    # Cell-centred 1D grids along X, Y, and Z
    xs = np.linspace(-scale,scale,endpoint=False,num=n)+(scale/n)
    ys = np.linspace(-scale,scale,endpoint=False,num=n)+(scale/n)
    zs = np.linspace(-scale,scale,endpoint=False,num=n)+(scale/n)

    # The principal axis is the same for every panel, so compute it once
    if show_major_axis:
        xt,yt,zt = pplot.make_principal_axis(densfunc,
            pdens.denormalize_parameters(params,densfunc),
                x=[-5*scale,5*scale],y=[0,0],z=[0,0])
        major_axis = {'X':xt,'Y':yt,'Z':zt}

    # Loop over the different views
    for i in range(3):
        h_name,v_name = ind_str[i]

        # In-plane grid; the out-of-plane coordinate is zero
        if i == 0: # XY plane
            c1grid,c2grid = np.meshgrid(xs,ys)
            ogrid = np.zeros_like(c1grid)
            Rgrid,phigrid,zgrid = pplot.xyz_to_Rphiz(c1grid,c2grid,ogrid)
        elif i == 1: # XZ plane
            c1grid,c2grid = np.meshgrid(xs,zs)
            ogrid = np.zeros_like(c1grid)
            Rgrid,phigrid,zgrid = pplot.xyz_to_Rphiz(c1grid,ogrid,c2grid)
        else: # YZ plane
            c1grid,c2grid = np.meshgrid(ys,zs)
            ogrid = np.zeros_like(c1grid)
            Rgrid,phigrid,zgrid = pplot.xyz_to_Rphiz(ogrid,c1grid,c2grid)

        # Calculate densities
        dgrid = densfunc(Rgrid,phigrid,zgrid,params)

        # Cut out the inner region (in-plane radius; the slice passes
        # through the origin)
        if rmin_cut is not None:
            cut_mask = (np.sqrt(c1grid**2.+c2grid**2.) < rmin_cut)
            dgrid[cut_mask] = np.nan

        log_dgrid = np.log10(dgrid)
        if contour:
            axs[i].contour(c1grid,c2grid,log_dgrid,
                levels=contour_levels,linestyles='solid',
                colors='Black')
        else:
            axs[i].imshow(log_dgrid,origin='lower',
                aspect='equal',**imshow_kwargs)

        # Labels and axes
        axs[i].set_xlabel(h_name+' [kpc]',fontsize=xlabel_fs)
        axs[i].set_ylabel(v_name+' [kpc]',fontsize=ylabel_fs)
        axs[i].axhline(0,color='Black',linestyle='solid', linewidth=0.5)
        axs[i].axvline(0,color='Black',linestyle='solid', linewidth=0.5)
        axs[i].tick_params(axis='both',labelsize=ticklabel_fs)
        axs[i].set_aspect('equal')
        axs[i].set_xlim(-scale,scale)
        axs[i].set_ylim(-scale,scale)
        if left_handed:
            # Reverse whichever axis shows Y
            if i == 0:
                axs[i].set_ylim(scale,-scale)
            if i == 2:
                axs[i].set_xlim(scale,-scale)
        axs[i].set_xticks(ticks)
        axs[i].set_yticks(ticks)

        if show_sun:
            axs[i].scatter(sun_pos[h_name], sun_pos[v_name], **sun_kwargs)
        if show_major_axis:
            axs[i].plot(major_axis[h_name], major_axis[v_name],
                        **major_axis_kwargs)

    return fig,axs

## First, show a figure with only our preferred results (the AD selection)

### First, the maximum-likelihood parameters from the post-MCMC optimization

### Set up the fit and load its results

In [None]:
# Configuration of the preferred fit: AD selection, GS/E fit type
selec = 'AD'
fit_type = 'gse'
densfunc = pdens.triaxial_single_cutoff_zvecpa
version = '100w_2e3n'
color = 'Black'

theta_in_degr = False
phi_in_degr = False
rad_to_degr = 180./np.pi

# The [Fe/H] range depends on the fit type: GS/E fits use the GS/E range,
# everything else uses the combined range
feh_range_fit = copy.deepcopy(
    feh_range_gse if fit_type == 'gse' else feh_range_all)
hf = pmass.HaloFit(densfunc=densfunc, fit_type=fit_type,
                   version=version, selec=selec,
                   feh_range=feh_range_fit, **hf_kwargs)

# Load the stored fit results along with the likelihood / AIC / BIC summary
hf.get_results()
hf.get_loglike_ml_aic_bic()

# Report the three flavours of best-fitting parameters
print('Best-fitting parameters:')
for label, ml_type in (('post: ','post'),
                       ('median: ','mcmc_median'),
                       ('ml: ','mcmc_ml')):
    print(label, hf.get_ml_params(ml_type=ml_type))

# Plot using the post-MCMC optimized parameters
params = hf.get_ml_params(ml_type='post')
fig,axs = plot_fit_contours(params, densfunc, contour=True, show_sun=True,
                            show_major_axis=True, left_handed=True,
                            rmin_cut=5)
fig.tight_layout()
# fig.savefig(fig_dir+'contour.pdf')

In [None]:
# Rotate the vector [1,0,0] into the Galactocentric frame for the 'post'
# parameters (presumably the profile's principal axis - confirm against
# HaloFit.get_rotated_coords_in_gc_frame), normalize it, and report its
# altitude and azimuth in degrees.
rotated = hf.get_rotated_coords_in_gc_frame(ml_type='post',
                                            vec=np.array([1,0,0]))
pavec = rotated[0]
pavec = pavec/np.linalg.norm(pavec)
alt,az = putil.vec_to_alt_az(pavec,degrees=True)
for label, angle in (('Altitude: ',alt), ('Azimuth: ',az)):
    print(label, angle)

### Now the MCMC median parameters (this version is saved as contour.pdf)

In [None]:
# Configuration of the preferred fit: AD selection, GS/E fit type
selec = 'AD'
fit_type = 'gse'
densfunc = pdens.triaxial_single_cutoff_zvecpa
version = '100w_2e3n'
color = 'Black'

theta_in_degr = False
phi_in_degr = False
rad_to_degr = 180./np.pi

# Account for changing Fe/H range
if fit_type == 'gse':
    feh_range_fit = copy.deepcopy(feh_range_gse)
else:
    feh_range_fit = copy.deepcopy(feh_range_all)
hf = pmass.HaloFit(densfunc=densfunc, fit_type=fit_type, 
                    version=version, selec=selec, 
                    feh_range=feh_range_fit, **hf_kwargs)

# Load the results. Unlike the previous figure (post-MCMC optimization), this
# one uses the median of the MCMC samples ('mcmc_median'), and it is the
# version saved as the paper figure.
hf.get_results()
hf.get_loglike_ml_aic_bic()

params = hf.get_ml_params(ml_type='mcmc_median')
fig,axs = plot_fit_contours(params,densfunc,contour=True,show_sun=True,
    show_major_axis=True,left_handed=True,rmin_cut=5)
fig.tight_layout()
fig.savefig(fig_dir+'contour.pdf')

In [None]:
# Same as for the 'post' parameters, but for the MCMC median: rotate
# [1,0,0] into the Galactocentric frame, normalize it, and report its
# altitude and azimuth in degrees.
unit_x = np.array([1,0,0])
rotated = hf.get_rotated_coords_in_gc_frame(ml_type='mcmc_median',
                                            vec=unit_x)
pavec = rotated[0]/np.linalg.norm(rotated[0])
alt,az = putil.vec_to_alt_az(pavec,degrees=True)
for label, angle in (('Altitude: ',alt), ('Azimuth: ',az)):
    print(label, angle)

### Show contours of the density summed along the line of sight (projected density)

In [None]:
# Evaluate the current `densfunc`/`params` on a coarse 3D grid spanning
# [-40,40] kpc along each axis
scale = 40.
n = 100

# Cell-centred sample points, offset by half a cell (scale/n) from the edges
axis_pts = np.linspace(-scale,scale,endpoint=False,num=n)+(scale/n)
xs, ys, zs = axis_pts, axis_pts.copy(), axis_pts.copy()

# np.meshgrid's default 'xy' indexing orders the 3D arrays as (y, x, z)
xgrid,ygrid,zgrid = np.meshgrid(xs,ys,zs)
Rgrid,phigrid,zgrid = pplot.xyz_to_Rphiz(xgrid,ygrid,zgrid)
dgrid = densfunc(Rgrid,phigrid,zgrid,params)

In [None]:
# Projected density contours: the density on a 3D grid is summed along the
# line of sight perpendicular to each of the X-Y, X-Z and Y-Z planes. Uses the
# `densfunc` and `params` defined in the cells above.
contour=True
show_sun=True
show_major_axis=True
left_handed=False
rmin_cut=5. # in kpc. Unused: masking of the inner region is not applied here

# Plotting keywords
scale = 50.
n = 100
sun_kwargs = {'color':'DarkOrange',
            'marker':r'$\odot$',
            's':100,
            'linewidths':2,}
major_axis_kwargs = {'color':'Black',
                    'linestyle':'dashed'}
contour_kwargs = {}
imshow_kwargs = {}
# Outer pixel edges for imshow. `extent` was previously used below without
# ever being defined (NameError on a fresh kernel). The cell-centred grid
# tiles [-scale,scale] exactly, so the edges sit at +/-scale.
extent = [-scale,scale,-scale,scale]

# Fontsize and formatting keywords
columnwidth = 244./72.27 # Only one column. In inches, from pt
textwidth = 508./72.27 # Whole page. In inches, from pt
ylabel_fs = 12
xlabel_fs = 12
ticklabel_fs = 10

fig = plt.figure(figsize=(textwidth,3))
axs = fig.subplots(nrows=1,ncols=3)

# Index strings
ind_arr = [[0,1],[0,2],[1,2]]
ind_str = [['X','Y'],['X','Z'],['Y','Z']]

# Make the X,Y,Z grid and R,phi,z grid
xs = np.linspace(-scale,scale,endpoint=False,num=n)+(scale/n)
ys = np.linspace(-scale,scale,endpoint=False,num=n)+(scale/n)
zs = np.linspace(-scale,scale,endpoint=False,num=n)+(scale/n)

# np.meshgrid's default 'xy' indexing orders the 3D arrays as (y, x, z)
xgrid,ygrid,zgrid = np.meshgrid(xs,ys,zs)
Rgrid,phigrid,zgrid = pplot.xyz_to_Rphiz(xgrid,ygrid,zgrid)
dgrid = densfunc(Rgrid,phigrid,zgrid,params)

# The principal axis is the same for every panel, so compute it once
if show_major_axis:
    xt,yt,zt = pplot.make_principal_axis(densfunc,
        pdens.denormalize_parameters(params,densfunc),
            x=[-scale,scale],y=[0,0],z=[0,0])

# Solar position [kpc]
x_sun = 8.275
y_sun = 0.
z_sun = 0.0208

# Loop over the different views
for i in range(3):
    # Sum along the line of sight and arrange the result as
    # (vertical coord, horizontal coord) to match the 2D meshgrid
    if i == 0: # XY: sum over z (axis 2) -> (y, x)
        c1grid,c2grid = np.meshgrid(xs,ys)
        sdens = np.sum(dgrid,axis=2)
    elif i == 1: # XZ: sum over y (axis 0) -> (x, z), transpose to (z, x)
        c1grid,c2grid = np.meshgrid(xs,zs)
        sdens = np.sum(dgrid,axis=0).T
    else: # YZ: sum over x (axis 1) -> (y, z), transpose to (z, y)
        c1grid,c2grid = np.meshgrid(ys,zs)
        sdens = np.sum(dgrid,axis=1).T

    if contour:
        # contour takes explicit X,Y coordinates, so no extent is needed
        axs[i].contour(c1grid,c2grid,np.log10(sdens),
            levels=[0.,0.5,1.,1.5,2.],
            linestyles='solid', colors='Black')
    else:
        axs[i].imshow(np.log10(sdens),origin='lower',
            extent=extent,aspect='equal',**imshow_kwargs)

    # Labels
    axs[i].set_xlabel(ind_str[i][0]+' [kpc]',fontsize=xlabel_fs)
    axs[i].set_ylabel(ind_str[i][1]+' [kpc]',fontsize=ylabel_fs)
    axs[i].axhline(0,color='Black',linestyle='solid', linewidth=0.5)
    axs[i].axvline(0,color='Black',linestyle='solid', linewidth=0.5)
    axs[i].tick_params(axis='both',labelsize=ticklabel_fs)
    axs[i].set_aspect('equal')
    axs[i].set_xlim(-scale,scale)
    axs[i].set_ylim(-scale,scale)
    if left_handed:
        # Reverse whichever axis shows Y
        if i == 0:
            axs[i].set_ylim(scale,-scale)
        if i == 2:
            axs[i].set_xlim(scale,-scale)
    axs[i].set_xticks([-40,-20,0,20,40])
    axs[i].set_yticks([-40,-20,0,20,40])

    if show_sun:
        if i == 0:
            axs[i].scatter(x_sun, y_sun, **sun_kwargs)
        elif i == 1:
            axs[i].scatter(x_sun, z_sun, **sun_kwargs)
        else:
            axs[i].scatter(y_sun, z_sun, **sun_kwargs)
    if show_major_axis:
        if i == 0:
            axs[i].plot(xt,yt,**major_axis_kwargs)
        elif i == 1:
            axs[i].plot(xt,zt,**major_axis_kwargs)
        else:
            axs[i].plot(yt,zt,**major_axis_kwargs)

fig.tight_layout()
# fig.savefig(fig_dir+'contour.pdf')