In [None]:
# ------------------------------------------------------------------------
#
# TITLE - 4-df_stats.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Make a figure of the statistics of DF properties
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os, dill as pickle, warnings

## Plotting
import matplotlib as mpl
from matplotlib import pyplot as plt
import corner

## Astropy
from astropy import units as apu

## Project-specific
# Walk upward from the working directory until a `src/` directory is found,
# then make it importable. The search stops with an error once it has looked
# in the project root (a directory named 'tng-dfs') or the filesystem root.
# Otherwise a missing src/ would make the loop spin forever.
src_path = 'src/'
while not os.path.exists(src_path):
    # Directory in which `src_path` was just looked for ('' -> current dir).
    # realpath(src_path) itself always ends in 'src', so it cannot be used
    # for the stop test.
    search_dir = os.path.realpath(
        os.path.dirname(os.path.normpath(src_path)) or '.')
    reached_project_root = os.path.basename(search_dir) == 'tng-dfs'
    # dirname('/') == '/', so this detects the filesystem root portably.
    reached_fs_root = search_dir == os.path.dirname(search_dir)
    if reached_project_root or reached_fs_root:
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..', src_path)
sys.path.insert(0, src_path)
from tng_dfs import plot as pplot
from tng_dfs import util as putil

### Notebook setup

# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Apply the project-wide matplotlib style from src/mpl/. Per the original
# author this call must stay exactly here, right after `%matplotlib inline`;
# presumably activating the inline backend would otherwise override some
# style settings -- TODO confirm.
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
# Emit inline figures as high-resolution ('retina') PNGs.
%config InlineBackend.figure_format = 'retina'
# Automatically reload edited project modules (e.g. tng_dfs) before each cell runs.
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: load the project configuration and unpack the entries used here.
cdict = putil.load_config_to_dict()
keywords = ['DATA_DIR', 'MW_ANALOG_DIR', 'FIG_DIR_BASE', 'FITTING_DIR_BASE',
            'RO', 'VO', 'ZO', 'LITTLE_H', 'MW_MASS_RANGE']
(data_dir, mw_analog_dir, fig_dir_base, fitting_dir_base,
 ro, vo, zo, h, mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW analog sample (return_vars=True makes the helper also return
# `mwsubs_vars`).
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure output locations; each directory is created if it does not exist.
local_fig_dir = './fig/'
fig_dir = os.path.join(fig_dir_base, 'notebooks/paper/')
for out_dir in (local_fig_dir, fig_dir):
    os.makedirs(out_dir, exist_ok=True)
show_plots = False


def _load_pickled(filename):
    """Return the object stored in ``filename`` via the notebook's ``pickle``
    module (dill, imported as ``pickle`` in the imports cell).

    NOTE: unpickling can execute arbitrary code; only use on trusted
    project-generated files.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)


# Merger-tree products for the MW analogs.
tree_primary_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_primaries.pkl')
tree_primaries = _load_pickled(tree_primary_filename)
tree_major_mergers_filename = os.path.join(mw_analog_dir,
    'major_mergers/tree_major_mergers.pkl')
tree_major_mergers = _load_pickled(tree_major_mergers_filename)
n_mw = len(tree_primaries)

### Get the data

In [None]:
# Which version of the merger analysis outputs to read.
analysis_version = 'v1.1'
analysis_dir = os.path.join(mw_analog_dir, 'analysis', analysis_version)

# Per-merger table saved by the analysis step. It is indexed by field name
# later (e.g. merger_data['star_mass']). allow_pickle=True presumably is
# needed because the saved array contains Python objects. Pickle loading can
# execute code, so only point this at trusted project output.
merger_data_filename = os.path.join(analysis_dir, 'merger_data.npy')
merger_data = np.load(merger_data_filename, allow_pickle=True)

### Corner plot of merger properties, density profile parameters, and DF parameters

In [None]:
## Keywords
# Text snippets
msun_txt = r'$\mathrm{M}_{\odot}$'

# Figure size (presumably inches, per the helper's name) and the fractional
# padding applied around each variable's data range when setting axis limits.
columnwidth, textwidth = pplot.get_latex_columnwidth_textwidth_inches()
lim_fac = 0.2

# Fonts, label rotations (degrees) and label padding
scatter_xaxislabel_fs = scatter_yaxislabel_fs = 7
hist_yaxislabel_fs = 6
ticklabel_fs = 5
xaxislabel_rot = yaxislabel_rot = 30
xaxislabel_pad, yaxislabel_pad = 3, 6

# Lower-triangle scatter panels
scatter_s, scatter_alpha, scatter_tick_length = 2.0, 0.5, 2.0
scatter_facecolor, scatter_edgecolor = 'Black', 'none'

# Diagonal histogram panels
hist_nbins = 8
hist_facecolor, hist_edgecolor = 'none', 'Black'
hist_alpha, hist_tick_length = 1.0, 2.0

# Covariance ellipses drawn in the corner of each scatter panel;
# ellipse_fac sets their inset from the upper-right corner as a fraction of
# the axis range.
ellipse_alpha = 0.6
ellipse_facecolor, ellipse_edgecolor = 'none', 'Red'
ellipse_linewidth = 1.0
ellipse_fac = 0.15

# Each entry pairs an axis label with the values plotted for that variable,
# so labels and data rows cannot drift out of order. Row order here is the
# row/column order of the corner plot.
corner_columns = [
    (r'$\log_{10} M_{\star}$',  np.log10(merger_data['star_mass'])),
    (r'$\log_{10} m_{\star}$',  np.log10(merger_data['star_mass_ratio'])),
    (r'$z_\mathrm{merger}$',    merger_data['merger_redshift']),
    (r'$\alpha_{1}$',           merger_data['alpha']),
    (r'$\log_{10} \alpha_{2}$', np.log10(merger_data['beta'])),
    (r'$\log_{10} a$',          np.log10(merger_data['a'])),
    (r'$k_\mathrm{rot}$',       merger_data['krot']),
    (r'$\chi$',                 merger_data['chi']),
    (r'$\beta$',                merger_data['anisotropy']),
    (r'$\log_{10} r_a$',        np.log10(merger_data['ra'])),
    (r'$\log_{10} r_{a,1}$',    np.log10(merger_data['ra1'])),
    (r'$\log_{10} r_{a,2}$',    np.log10(merger_data['ra2'])),
    (r'$k_\mathrm{om}$',        merger_data['kom']),
]
corner_data = np.vstack([values for _, values in corner_columns])
corner_labels = [label for label, _ in corner_columns]

# Example of hand-picked tick positions, one entry per row of corner_data
# (None = matplotlib default); currently unused:
# tick_locations = [None]*3 + [[-4,0,4]] + [None]*9

## Corner plot: an M x M grid with a histogram of each variable on the
## diagonal, pairwise scatter plots below it, and the upper triangle blank.
M, N = corner_data.shape  # M variables (rows of corner_data), N values each

# Optional tick overrides, one entry per row of corner_data (None keeps the
# matplotlib default). Entry k is applied to the x-axis of every panel in
# column k and to the y-axis of the scatter panels in row k,
# e.g. tick_locations[3] = [-4, 0, 4].
tick_locations = [None] * M


def _finite(values):
    """Return only the finite entries of ``values`` (drops NaN and +-inf,
    e.g. produced by log10 of non-positive numbers)."""
    return values[np.isfinite(values)]


def _padded_limits(values):
    """Return (lo, hi) axis limits around the finite range of ``values``.

    The lower edge is padded by lim_fac/3 of the data range and the upper edge
    by lim_fac, leaving room in the upper-right of each scatter panel for its
    covariance ellipse. The same limits are reused wherever a variable appears:
    diagonal histogram x-axis, scatter x-axis in its column, scatter y-axis in
    its row. This keeps every column aligned with the tick labels shown only
    on the bottom row.
    """
    finite = _finite(values)
    dmin, dmax = np.min(finite), np.max(finite)
    dsize = dmax - dmin
    return dmin - lim_fac * dsize / 3, dmax + lim_fac * dsize


def _draw_histogram(ax, k):
    """Diagonal panel k: step histogram of variable k, counts labelled on the right."""
    ax.hist(_finite(corner_data[k]), bins=hist_nbins, histtype='step',
            edgecolor=hist_edgecolor, facecolor=hist_facecolor,
            alpha=hist_alpha)
    # Ticks first: set_xticks may widen the view, set_xlim then restores it.
    if tick_locations[k] is not None:
        ax.set_xticks(tick_locations[k])
    ax.set_xlim(*axis_limits[k])

    ax.set_ylabel('N', fontsize=hist_yaxislabel_fs)
    ax.yaxis.set_label_position('right')
    is_bottom_row = k == M - 1
    if is_bottom_row:
        ax.set_xlabel(corner_labels[k], fontsize=scatter_xaxislabel_fs,
                      labelpad=xaxislabel_pad)
        ax.xaxis.get_label().set_rotation(xaxislabel_rot)
    # Tick styling is identical for every diagonal panel, including the
    # bottom-right one, which alone shows x tick labels.
    ax.tick_params(labelbottom=is_bottom_row, labelleft=False,
                   labelright=True, labelsize=ticklabel_fs,
                   length=hist_tick_length)


def _add_covariance_ellipse(ax, x, y, xlim, ylim):
    """Draw the sample covariance of (x, y) as an ellipse inset in the
    upper-right corner of ``ax`` (at ellipse_fac of each axis range).

    matplotlib's Ellipse width/height are full diameters. Here they are set to
    the standard deviations along the principal axes, so the ellipse is a
    schematic of the correlation's shape and orientation, drawn half the size
    of a 1-sigma contour.
    """
    both_finite = np.isfinite(x) & np.isfinite(y)
    cov = np.cov(x[both_finite], y[both_finite])
    # The covariance is symmetric: eigh gives real eigenvalues and orthonormal
    # eigenvectors (column k pairs with eigvals[k]).
    eigvals, eigvecs = np.linalg.eigh(cov)
    # Clip tiny negative eigenvalues from round-off before the square root.
    width, height = np.sqrt(np.clip(eigvals, 0.0, None))
    angle = np.rad2deg(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
    (x_lo, x_hi), (y_lo, y_hi) = xlim, ylim
    center = (x_hi - ellipse_fac * (x_hi - x_lo),
              y_hi - ellipse_fac * (y_hi - y_lo))
    ax.add_patch(mpl.patches.Ellipse(
        xy=center, width=width, height=height, angle=angle,
        facecolor=ellipse_facecolor, edgecolor=ellipse_edgecolor,
        alpha=ellipse_alpha, linewidth=ellipse_linewidth))


def _draw_scatter(ax, row, col):
    """Lower-triangle panel: variable ``col`` on x against variable ``row`` on y."""
    ax.scatter(corner_data[col], corner_data[row], marker='o', s=scatter_s,
               facecolor=scatter_facecolor, edgecolor=scatter_edgecolor,
               alpha=scatter_alpha)
    if tick_locations[col] is not None:
        ax.set_xticks(tick_locations[col])
    if tick_locations[row] is not None:
        ax.set_yticks(tick_locations[row])
    ax.set_xlim(*axis_limits[col])
    ax.set_ylim(*axis_limits[row])

    # Axis labels and tick labels only on the outer edges of the grid.
    is_bottom_row = row == M - 1
    is_left_col = col == 0
    if is_bottom_row:
        ax.set_xlabel(corner_labels[col], fontsize=scatter_xaxislabel_fs,
                      labelpad=xaxislabel_pad)
        ax.xaxis.get_label().set_rotation(xaxislabel_rot)
    if is_left_col:
        ax.set_ylabel(corner_labels[row], fontsize=scatter_yaxislabel_fs,
                      labelpad=yaxislabel_pad)
        ax.yaxis.get_label().set_rotation(90 - yaxislabel_rot)
    ax.tick_params(labelbottom=is_bottom_row, labelleft=is_left_col,
                   labelsize=ticklabel_fs, length=scatter_tick_length)

    _add_covariance_ellipse(ax, corner_data[col], corner_data[row],
                            axis_limits[col], axis_limits[row])


# One (lo, hi) pair per variable, shared by every panel that shows it.
axis_limits = [_padded_limits(values) for values in corner_data]

fig = plt.figure(figsize=(textwidth, textwidth))
# squeeze=False keeps `axs` 2-D even if only one variable is plotted.
axs = fig.subplots(nrows=M, ncols=M, squeeze=False)

for i in range(M):
    for j in range(M):
        if i == j:
            _draw_histogram(axs[i, j], i)
        elif j < i:
            _draw_scatter(axs[i, j], i, j)
        else:
            # Upper triangle would duplicate the lower one.
            axs[i, j].axis('off')

fig.tight_layout()
# Final spacing is set explicitly here.
fig.subplots_adjust(wspace=0.08, hspace=0.08, left=0.07, right=0.95,
    bottom=0.07, top=0.99)
for ext in ('pdf', 'png'):
    fig.savefig(os.path.join(local_fig_dir, f'density_df_params.{ext}'),
                dpi=500, bbox_inches='tight')
fig.show()