In [None]:
# ------------------------------------------------------------------------
#
# TITLE - merger_stats.ipynb
# AUTHOR - James Lane
# PROJECT - tng-dfs
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Make figures of the major merger statistics.

Gathers the stellar and dark matter mass ratios and the merger snapshot
numbers stored on the major mergers of each primary, loads the stellar and
dark matter masses from pickle files, and plots these quantities against
each other: a multi-panel scatter figure with marginal histograms (saved as
merger_stats.pdf in the figure directory) and a corner-style figure.
'''

__author__ = "James Lane"

In [None]:
# %load ../../src/nb_modules/nb_imports.txt
### Imports

## Basic
import numpy as np
import sys, os
import dill as pickle
import pdb, copy, warnings

## Matplotlib
import matplotlib as mpl
from matplotlib import pyplot as plt
import corner

## Astropy
from astropy import units as apu

## Project-specific
# Walk upwards from the working directory, one level per iteration, until a
# src/ directory is found, then put it at the front of the import path.
src_path = 'src/'
while not os.path.exists(src_path):
    # Directory that would have to contain src/ at this level of the search.
    # (The last component of realpath(src_path) itself is always 'src', so it
    # is the parent that has to be inspected to know where the search is.)
    _search_dir = os.path.dirname(os.path.realpath(src_path))
    # Give up once the project root or the filesystem root has been searched;
    # dirname() of the filesystem root is the root itself.
    if os.path.basename(_search_dir) == 'tng-dfs' \
            or _search_dir == os.path.dirname(_search_dir):
        raise FileNotFoundError('Failed to find src/ directory.')
    src_path = os.path.join('..',src_path)
sys.path.insert(0,src_path)
from tng_dfs import cutout as pcutout
from tng_dfs import kinematics as pkin
from tng_dfs import plot as pplot
from tng_dfs import tree as ptree
from tng_dfs import util as putil

### Notebook setup

# Draw figures inline in the notebook output
%matplotlib inline
# Apply the project matplotlib style sheet stored under the src/ directory
plt.style.use(os.path.join(src_path,'mpl/project.mplstyle')) # This must be exactly here
# Request the 'retina' format for inline figures
%config InlineBackend.figure_format = 'retina'
# Enable autoreload so edits to imported modules are picked up without
# restarting the kernel
%load_ext autoreload
%autoreload 2

### Keywords, loading, pathing

In [None]:
# %load ../../src/nb_modules/nb_setup.txt
# Keywords: read the project configuration and unpack the requested values
# in the order the keywords are listed
cdict = putil.load_config_to_dict()
keywords = [
    'DATA_DIR',
    'MW_ANALOG_DIR',
    'RO',
    'VO',
    'ZO',
    'LITTLE_H',
    'MW_MASS_RANGE',
]
(data_dir, mw_analog_dir, ro, vo, zo, h,
 mw_mass_range) = putil.parse_config_dict(cdict, keywords)

# MW Analog
mwsubs, mwsubs_vars = putil.prepare_mwsubs(
    mw_analog_dir,
    h=h,
    mw_mass_range=mw_mass_range,
    return_vars=True,
    force_mwsubs=False,
    bulge_disk_fraction_cuts=True,
)

# Figure path (created if it does not exist yet)
fig_dir = './fig/'
os.makedirs(fig_dir, exist_ok=True)
show_plots = False

# Load tree data: both pickles live under the MW analog directory
tree_primary_filename = os.path.join(
    mw_analog_dir, 'major_mergers/tree_primaries.pkl')
tree_major_mergers_filename = os.path.join(
    mw_analog_dir, 'major_mergers/tree_major_mergers.pkl')
with open(tree_primary_filename, 'rb') as handle:
    tree_primaries = pickle.load(handle)
with open(tree_major_mergers_filename, 'rb') as handle:
    tree_major_mergers = pickle.load(handle)

# Number of primaries in the tree data
n_mw = len(tree_primaries)

### Loop over all primaries and get the properties of each merger

In [None]:
# Per-primary merger properties. Entry i of each list is itself a list with
# one value per major merger of primary i; the nested lists are flattened
# after the loop.
dm_mass_ratio = []
dm_mass_ratio_snapnum = []
merger_snapnum = []
star_mass_ratio = []
star_mass_ratio_snapnum = []

verbose = True

for i in range(n_mw):
    # Progress counter, overwritten in place by the carriage return
    if verbose: print(f'Getting MW {i+1}/{n_mw}', end='\r')

    # Everything needed here is stored as attributes on the tree objects, so
    # no particle cutout is opened for the primary.
    primary = tree_primaries[i]
    mergers = [primary.tree_major_mergers[j]
               for j in range(primary.n_major_mergers)]

    dm_mass_ratio.append([m.dm_mass_ratio for m in mergers])
    dm_mass_ratio_snapnum.append([m.dm_mass_ratio_snapnum for m in mergers])
    merger_snapnum.append([m.merger_snapnum for m in mergers])
    star_mass_ratio.append([m.star_mass_ratio for m in mergers])
    star_mass_ratio_snapnum.append(
        [m.star_mass_ratio_snapnum for m in mergers])


def load_flat_pickle(filename):
    '''Load a pickled sequence of arrays and flatten it into one 1-D array.

    Args:
        filename (str): Path of the pickle file to read.

    Returns:
        numpy.ndarray or None: The concatenated, flattened contents of the
            file, or None (after issuing a warning) if the file does not
            exist.
    '''
    if not os.path.exists(filename):
        warnings.warn('Failed to find '+os.path.basename(filename),
            stacklevel=2)
        return None
    # NOTE: unpickling runs arbitrary code, only point this at trusted files
    with open(filename,'rb') as handle:
        nested = pickle.load(handle)
    return np.concatenate(nested).flatten()

# Stellar and dark matter masses are read from pickle files kept in the data
# directory of another notebook folder. A missing file leaves the
# corresponding variable explicitly set to None, so a value left over from an
# earlier run of the kernel is never reused silently.
mass_data_path = '../5_compare_distribution_functions/data'
star_mass_filename = os.path.join(mass_data_path,'star_mass.pkl')
dm_mass_filename = os.path.join(mass_data_path,'dm_mass.pkl')
star_mass = load_flat_pickle(star_mass_filename)
dm_mass = load_flat_pickle(dm_mass_filename)

# Collapse each per-primary nested list into a single flat 1-D array,
# keeping the same variable names
(dm_mass_ratio,
 dm_mass_ratio_snapnum,
 merger_snapnum,
 star_mass_ratio,
 star_mass_ratio_snapnum) = [
    np.concatenate(per_primary).flatten()
    for per_primary in (dm_mass_ratio,
                        dm_mass_ratio_snapnum,
                        merger_snapnum,
                        star_mass_ratio,
                        star_mass_ratio_snapnum)
]

### Make the plot

In [None]:
# Conversion factor used to turn widths given in pt into inches
pt_per_inch = 72.27
# Column-sized figure width, in inches
columnwidth = 244./pt_per_inch
# Full-sized figure width, in inches
textwidth = 508./pt_per_inch

In [None]:
# Marker appearance and font sizes used by the panels
facecolor, edgecolor = 'none', 'Black'
s = 10
xaxislabel_fs, yaxislabel_fs, ticklabel_fs = 10, 9, 8

# Create the figure layout: a 9-row grid in which the first row holds the top
# histogram and each of the four main panels spans two rows. The last column
# is reserved for the right-hand histograms.
fig = plt.figure(figsize=(columnwidth,5))
ncols = 5
gs = fig.add_gridspec(nrows=9, ncols=ncols, hspace=0.2, wspace=0.1)
panel_rows = [slice(2*k+1, 2*k+3) for k in range(4)]
axs = [fig.add_subplot(gs[rows, :ncols-1]) for rows in panel_rows]
axt = fig.add_subplot(gs[0, :ncols-1])
raxs = [fig.add_subplot(gs[rows, ncols-1]) for rows in panel_rows]

### Make the primary panels
# All panels share the same x values (log stellar mass); only the quantity on
# the y-axis changes from panel to panel, top to bottom.
panel_x = np.log10(star_mass)
panel_y = (
    np.log10(1/star_mass_ratio),
    np.log10(1/dm_mass_ratio),
    np.log10(dm_mass),
    putil.snapshot_to_redshift(merger_snapnum),
)
for panel, yvals in zip(axs, panel_y):
    panel.scatter( panel_x, yvals, marker='o', s=10, facecolor=facecolor,
        edgecolor=edgecolor, alpha=0.5 )

# Axis scales and labels
# Only the bottom panel (index 3) shows x tick labels
for i in range(len(axs)):
    axs[i].tick_params(labelbottom=(i >= 3), labelsize=ticklabel_fs)

# Raw string: '\m' and '\o' are not valid escape sequences in a normal string
# literal. The resulting text is unchanged. This fragment is also used for
# the labels of the corner figure.
_msun_txt = r'\mathrm{M}_{\odot}'
axs[-1].set_xlabel(r'$\log_{10}( M_{\star} / '+_msun_txt+r')$',
    fontsize=xaxislabel_fs)
axs[0].set_ylabel(r'$\log_{10}( m_{\star} )$', fontsize=yaxislabel_fs)
axs[1].set_ylabel(r'$\log_{10}( m_{\rm DM} )$', fontsize=yaxislabel_fs)
axs[2].set_ylabel(r'$\log_{10}( M_{\rm DM} / '+_msun_txt+r')$',
    fontsize=yaxislabel_fs)
axs[3].set_ylabel(r'$z_{\rm merge}$', fontsize=yaxislabel_fs)

### Make the top panel - marginalized histogram of the stellar mass

# Histogram settings (density vs. counts, number of bins, color, type)
dens = False
nbin = 10
haxc = 'Black'
htype = 'step'

top_hist_kwargs = dict(bins=nbin, histtype=htype, color=haxc,
    orientation='vertical', density=dens)
axt.hist( np.log10(star_mass), **top_hist_kwargs )
axt.tick_params(labelbottom=False, labelsize=ticklabel_fs)
axt.set_ylim(0,20)
axt.yaxis.set_ticks([0,10,20])
# Label depends on whether the histogram is normalized
axt.set_ylabel(r'$p(\cdot)$' if dens else r'$N$', fontsize=yaxislabel_fs)

### Make the right panels, marginalized histograms of the dependent quantities

orient = 'horizontal'

# Same quantities, in the same top-to-bottom order, as the y-axes of the
# scatter panels
right_panel_data = (
    np.log10(1/star_mass_ratio),
    np.log10(1/dm_mass_ratio),
    np.log10(dm_mass),
    putil.snapshot_to_redshift(merger_snapnum),
)
for rax, values in zip(raxs, right_panel_data):
    rax.hist( values, bins=nbin, histtype=htype, color=haxc,
        orientation=orient, density=dens )

# The ticks have to lie inside the chosen limits: set_ticks() widens the view
# limits to include every requested tick, so the count ticks must not be
# applied to the density range.
if dens:
    hist_xlim, hist_xticks = (0,1.2), [0,0.6,1.2]
else:
    hist_xlim, hist_xticks = (0,50), [0,25,50]

for i, rax in enumerate(raxs):
    rax.set_xlim(*hist_xlim)
    rax.tick_params(labelleft=False, labelsize=ticklabel_fs)
    rax.xaxis.set_ticks(hist_xticks)
    # Only the bottom panel (index 3) keeps its x tick labels
    if i < 3:
        rax.tick_params(labelbottom=False)

if dens:
    raxs[-1].set_xlabel(r'$p(\cdot)$', fontsize=yaxislabel_fs)
else:
    raxs[-1].set_xlabel(r'$N$', fontsize=yaxislabel_fs)

# Tighten the layout, then restore the vertical spacing between the panels
fig.tight_layout()
fig.subplots_adjust(hspace=0.2)
fig.show()
# os.path.join keeps the path valid whether or not fig_dir ends in a separator
fig.savefig(os.path.join(fig_dir,'merger_stats.pdf'))


In [None]:
# Quantities shown in the corner figure, each paired with its axis label
corner_quantities = [
    (r'$\log_{10}( M_{\star} / '+_msun_txt+r')$',
        np.log10(star_mass)),
    (r'$\log_{10}( M_{\rm DM} / '+_msun_txt+r')$',
        np.log10(dm_mass)),
    (r'$\log_{10}( m_{\star} )$',
        np.log10(1/star_mass_ratio)),
    (r'$\log_{10}( m_{\rm DM} )$',
        np.log10(1/dm_mass_ratio)),
    (r'$z_{\rm merge}$',
        putil.snapshot_to_redshift(merger_snapnum)),
]
corner_labels = [label for label, _values in corner_quantities]
# One row per plotted quantity
corner_data = np.vstack([_values for _label, _values in corner_quantities])

# M rows (quantities) by N columns
M, N = corner_data.shape
# Fraction of the data range added as padding on each side of an axis
lim_fac = 0.1
xaxislabel_fs, yaxislabel_fs, ticklabel_fs = 10, 9, 8

fig = plt.figure( figsize=(textwidth, textwidth) )
axs = fig.subplots(nrows=M, ncols=M)

# Padded axis range of every quantity: the data range extended on both sides
# by lim_fac times its size
padded_lims = []
for row in corner_data:
    lo, hi = np.min(row), np.max(row)
    size = hi - lo
    padded_lims.append( (lo - lim_fac*size, hi + lim_fac*size) )

for i in range(M):
    bottom_row = (i == M-1)
    for j in range(M):
        ax = axs[i,j]

        # Upper triangle duplicates the lower triangle, leave it empty
        if j > i:
            ax.axis('off')
            continue

        # Diagonal: histogram of quantity i, counts labelled on the right
        if j == i:
            ax.hist( corner_data[i,:], bins=10, histtype='step',
                edgecolor='Black' )
            ax.set_xlim( *padded_lims[i] )
            ax.set_ylabel('N', fontsize=yaxislabel_fs)
            ax.yaxis.set_label_position('right')
            ax.tick_params(labelright=True)
            if bottom_row:
                ax.set_xlabel(corner_labels[i], fontsize=xaxislabel_fs)
                ax.tick_params(labelleft=False)
            else:
                ax.tick_params(labelbottom=False, labelleft=False)
            continue

        # Lower triangle: scatter of quantity i (y) against quantity j (x)
        ax.scatter( corner_data[j,:], corner_data[i,:], marker='o', s=10,
            facecolor=facecolor, edgecolor=edgecolor, alpha=0.5 )
        ax.set_xlim( *padded_lims[j] )
        ax.set_ylim( *padded_lims[i] )
        # x labels only along the bottom row
        if bottom_row:
            ax.set_xlabel(corner_labels[j], fontsize=xaxislabel_fs)
        else:
            ax.tick_params(labelbottom=False)
        # y labels only along the first column (i > 0 always holds here)
        if j == 0:
            ax.set_ylabel(corner_labels[i], fontsize=yaxislabel_fs)
        else:
            ax.tick_params(labelleft=False)

fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.1)
fig.show()