In [None]:
%reload_ext autoreload
%autoreload 2

import os
import sys
import warnings
import numpy as np
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py

import kalepy as kale

import holodeck as holo
from holodeck.constants import MSOL, PC, YR, MPC

# Silence numpy floating-point warnings (divide-by-zero, invalid, overflow)
# raised e.g. when taking log10 of populations containing zeros.
# NOTE(review): this also hides genuine numerical problems, and the filter
# below hides *all* UserWarnings for the rest of the session.
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# 'Times' must be listed under 'serif' (not 'sans-serif') to be used when the
# family is 'serif'; the trailing entries are fallbacks if Times is missing.
mpl.rc('font', **{'family': 'serif', 'serif': ['Times', 'Times New Roman', 'DejaVu Serif'], 'size': 12})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

In [None]:
# Directory holding the external data files.  Set the HOLODECK_DATA_EXTERNAL
# environment variable to point elsewhere instead of editing this
# machine-specific default.
PATH_DATA = os.environ.get(
    "HOLODECK_DATA_EXTERNAL",
    "/Users/lzkelley/Research/NANOGrav/holodeck/data_external/",
)

# Full paths to the required input files
FNAME_ILLUSTRIS_DATA = os.path.join(
    PATH_DATA, "illustris-galaxy-mergers_L75n1820FP_gas-100_dm-100_star-100_bh-000.hdf5"
)
FNAME_MCCONNELL_MA_2013 = os.path.join(PATH_DATA, "mcconnell+ma-2013_1211.2816.txt")

_fnames = [FNAME_ILLUSTRIS_DATA, FNAME_MCCONNELL_MA_2013]

# Fail fast, with the full path in the message, if any input file is missing
for fn in _fnames:
    if not os.path.isfile(fn):
        err = "Could not find '{}'!\n{}\nSet HOLODECK_DATA_EXTERNAL to the data directory.".format(
            os.path.basename(fn), fn
        )
        raise FileNotFoundError(err)

## Utility Functions

In [None]:
def plot_bin_pop(bin_pop):
    """Draw a KDE corner plot of the main parameters of a binary population.

    Columns plotted (log10): total mass / MSOL, mass ratio q, separation / PC
    and 1+z.  When ``bin_pop.eccs`` is not None, eccentricity is appended as a
    linear column reflected on [0, 1].

    Returns
    -------
    kale.Corner
        The corner-plot object.
    """
    m_tot, q_ratio = holo.utils.mtmr_from_m1m2(bin_pop.mass)
    one_plus_z = 1 + holo.utils.a_to_z(bin_pop.time)

    # (values, reflection bounds, label) for each corner-plot column
    columns = [
        (np.log10(m_tot/MSOL), None, 'M'),
        (np.log10(q_ratio), [None, 0], 'q'),          # log10(q) <= 0
        (np.log10(bin_pop.sepa/PC), None, 'a'),
        (np.log10(one_plus_z), [0, None], '1+z'),     # log10(1+z) >= 0
    ]
    if bin_pop.eccs is not None:
        columns.append((bin_pop.eccs, [0.0, 1.0], 'e'))

    data, reflect, labels = (list(col) for col in zip(*columns))

    kde = kale.KDE(data, reflect=reflect)
    corner = kale.Corner(kde, labels=labels, figsize=[8, 8])
    corner.plot_data(kde)
    return corner


def plot_mbh_scaling_relations(pop, fname=None, color='r'):
    """Plot MBH mass against bulge mass for ``pop``.

    If ``fname`` is given, the McConnell & Ma (2013) data loaded from it are
    drawn as well.  Population contours are drawn in ``color``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=[8, 5])

    handles, names = [], []

    # Observational data (only when a data file is supplied)
    if fname is not None:
        handles.append(_draw_mm13_data(ax, fname))
        names.append('McConnell+Ma')

    # Population contours; returns legend handles and labels
    pop_handles, pop_names = _draw_pop_masses(ax, pop, color)
    handles.extend(pop_handles)
    names.extend(pop_names)

    ax.legend(handles, names)
    return fig


def _draw_mm13_data(ax, fname):
    """Draw McConnell & Ma (2013) MBH masses against bulge mass on ``ax``.

    Every column except 'name' is converted to log10 before plotting.  Only
    points whose log10 values are positive on both axes are drawn.

    Parameters
    ----------
    ax : matplotlib Axes
    fname : str
        Path passed to ``holo.observations.load_mcconnell_ma_2013``.

    Returns
    -------
    handle
        The scatter handle, for use in a legend.

    Raises
    ------
    ValueError
        If the bulge-mass column is 2D with neither 2 nor 3 columns.
    """
    data = holo.observations.load_mcconnell_ma_2013(fname)
    data = {kk: data[kk] if kk == 'name' else np.log10(data[kk]) for kk in data.keys()}
    key = 'mbulge'

    # Mass columns are indexed as (low, central, high) -> asymmetric y-errors, shape (2, N)
    mass = data['mass']
    yy = mass[:, 1]
    yerr = np.array([yy - mass[:, 0], mass[:, 2] - yy])

    vals = data[key]
    if np.ndim(vals) == 1:
        xx, xerr = vals, None
    elif vals.shape[1] == 2:
        # (value, error) -> symmetric x-errors, shape (N,)
        # NOTE(review): the error column was log10'd above along with the values,
        # which is not an error in log-space -- confirm intended treatment.
        xx, xerr = vals[:, 0], vals[:, 1]
    elif vals.shape[1] == 3:
        # (low, central, high) -> asymmetric x-errors, shape (2, N)
        xx = vals[:, 1]
        xerr = np.array([xx - vals[:, 0], vals[:, 2] - xx])
    else:
        raise ValueError("Unexpected shape {} for '{}' data!".format(np.shape(vals), key))

    idx = (xx > 0.0) & (yy > 0.0)
    if xerr is not None:
        # Select along the last (per-point) axis so both symmetric (N,) and
        # asymmetric (2, N) error arrays work; `xerr[:, idx]` fails for (N,).
        xerr = xerr[..., idx]
    ax.errorbar(xx[idx], yy[idx], xerr=xerr, yerr=yerr[:, idx], fmt='none', zorder=10)
    handle = ax.scatter(xx[idx], yy[idx], zorder=10)
    ax.set(ylabel='MBH Mass', xlabel=key)

    return handle


def _draw_pop_masses(ax, pop, color='r', skip=4):
    """Draw KDE contours of log10 MBH mass vs. log10 bulge mass on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
    pop : population with ``mbulge`` and ``mass`` arrays (optionally ``_mass``)
        Arrays are flattened and divided by MSOL before taking log10.
    color : matplotlib color
        Color for the current masses (``pop.mass``).  If ``pop._mass`` exists it
        is drawn as well, in grey.
    skip : int
        Use every ``skip``-th point, to keep the contouring fast.

    Returns
    -------
    handles : list
        Empty line artists usable as legend proxies, one per mass set.
    names : list of str
        Legend labels: 'new' and, if present, 'old'.
    """
    xx = pop.mbulge.flatten() / MSOL
    yy_list = [pop.mass]
    names = ['new']
    # presumably `_mass` holds the masses before a modifier replaced them -- TODO confirm
    if hasattr(pop, '_mass'):
        yy_list.append(pop._mass)
        names.append('old')

    colors = [color, '0.5']
    handles = []
    if skip > 1:
        print("Plotting every {}th data-point".format(skip))
    cut = slice(None, None, skip)
    xx_cut = xx[cut]   # loop-invariant
    for yy, col in zip(yy_list, colors):
        yy = yy.flatten() / MSOL
        data = np.log10([xx_cut, yy[cut]])
        kale.plot.dist2d(
            data, ax=ax, color=col, hist=False, contour=True,
            median=True, mask_dense=True,
        )
        # Legend proxy: draw on `ax` explicitly; `plt.plot` targets the
        # *current* axes, which need not be `ax`.
        hh, = ax.plot([], [], color=col)
        handles.append(hh)

    return handles, names


def plot_gwb(gwb):
    """Plot a GWB spectrum on log-log axes, with frequencies scaled by YR.

    The plot contains a few random background realizations (blue, faint), the
    25-75 percentile band across realizations, and a reference power law
    normalized at frequency 1 in the scaled units.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=[10, 5])
    ax.set(xscale='log', yscale='log')
    ax.grid(True)

    # Individual realizations, then the percentile band
    draw_gwb_sample(ax, gwb, color='b', alpha=0.1)
    draw_gwb_conf(ax, gwb)

    # Reference power law in the same scaled frequency units
    freqs_scaled = gwb.freqs * YR
    draw_plaw(ax, freqs_scaled, f0=1)

    return fig


def draw_gwb_sample(ax, gwb, num=10, back=True, fore=False, **kwargs):
    """Draw ``num`` randomly chosen realizations of a GWB spectrum on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
    gwb : object with ``freqs``, ``eccen_back`` and (if ``fore``) ``eccen_fore``
        ``eccen_back`` / ``eccen_fore`` are indexed as (frequency, realization).
    num : int
        Number of realizations to draw (without replacement).  Capped at the
        number of available realizations.
    back : bool
        Draw the background spectra as lines.
    fore : bool
        Draw the foreground values as scattered points.
    **kwargs
        Passed to ``ax.plot`` / ``ax.scatter``.  ``alpha`` defaults to 0.5 and
        ``lw`` to 1.0.
    """
    freqs = gwb.freqs * YR
    _back = gwb.eccen_back
    kwargs.setdefault('alpha', 0.5)
    kwargs.setdefault('lw', 1.0)
    # Never request more distinct realizations than exist:
    # `np.random.choice(..., replace=False)` raises otherwise.
    num = min(num, _back.shape[1])
    cut = np.random.choice(_back.shape[1], num, replace=False)

    if back:
        ax.plot(freqs, _back[:, cut], **kwargs)

    if fore:
        _fore = gwb.eccen_fore[:, cut]
        # Broadcast frequencies to match the (frequency, realization) shape
        ax.scatter(freqs[:, np.newaxis] * np.ones_like(_fore), _fore, **kwargs)

    return


def draw_gwb_conf(ax, gwb, conf=(0.25, 0.75), **kwargs):
    """Shade a percentile band of the GWB background across realizations.

    Parameters
    ----------
    ax : matplotlib Axes
    gwb : object with ``freqs`` and ``eccen_back``
        ``eccen_back`` percentiles are taken over its last axis.
    conf : sequence of float in [0, 1]
        Quantiles of the band edges (default: 25th-75th percentile).  They are
        passed positionally to ``ax.fill_between`` after the frequencies.
    **kwargs
        Passed to ``ax.fill_between``.  ``alpha`` and ``lw`` default to 0.5.

    Returns
    -------
    The return value of ``ax.fill_between``.
    """
    freqs = gwb.freqs * YR
    back = gwb.eccen_back
    kwargs.setdefault('alpha', 0.5)
    kwargs.setdefault('lw', 0.5)
    # One row per requested quantile
    band = np.percentile(back, 100*np.array(conf), axis=-1)
    return ax.fill_between(freqs, *band, **kwargs)


def draw_plaw(ax, freqs, amp=1e-15, f0=1/YR, **kwargs):
    """Draw the reference power law ``amp * (freqs / f0) ** (-2/3)`` on ``ax``.

    ``f0`` is the reference frequency in the same units as ``freqs``.  Style
    defaults are alpha=0.5, color '0.5' (grey) and a dashed line; any of them
    can be overridden through ``**kwargs``, which go to ``ax.plot``.

    Returns
    -------
    The return value of ``ax.plot``.
    """
    style_defaults = (('alpha', 0.5), ('color', '0.5'), ('ls', '--'))
    for opt, value in style_defaults:
        kwargs.setdefault(opt, value)

    scaled = freqs / f0
    plaw = amp * np.power(scaled, -2/3)
    return ax.plot(freqs, plaw, **kwargs)
    

# Binary Population

## Construct Illustris-Based Binary Population

In [None]:
# Load the Illustris galaxy-merger catalog as a binary population
bin_pop = holo.BP_Illustris(FNAME_ILLUSTRIS_DATA)

# Simulation name = second '_'-separated token of the data file's basename
ill_name = os.path.basename(bin_pop._fname).split('_')[1]
print("Loaded", bin_pop.size, "binaries from Illustris", ill_name)

In [None]:
# Corner plot of the unmodified population
plot_bin_pop(bin_pop)
plt.show()

### Apply modifiers to add (arbitrary) eccentricities and resample the population 2x

In [None]:
# Build the modifiers (eccentricities first, then 2x resampling) and apply them
mods = [holo.PM_Eccentricity(), holo.PM_Resample(resample=2.0)]
mod_ecc, mod_resamp = mods
bin_pop.modify(mods)

print("Population now has", bin_pop.size, "elements")

In [None]:
# Corner plot after adding eccentricities and resampling
plot_bin_pop(bin_pop)
plt.show()

### Apply Modifier to Use McConnell+Ma 2013 BH Masses

In [None]:
# NOTE: this cell is disabled (all code is commented out).  A McConnell+Ma 2013
# mass modifier, `holo.PM_MM13(relation='mbulge')`, is applied instead in the
# eccentric-evolution cell further below.

# # Create the modifier using M-Mbulge relation
# mod_mm13 = holo.PM_MM13(relation='mbulge')

# # Choose percentiles
# percs = 100*sp.stats.norm.cdf([-1, 0, 1])
# percs = [0,] + percs.tolist() + [100,]

# # Format nicely
# str_array = lambda xx: ", ".join(["{:.2e}".format(yy) for yy in xx])
# str_masses = lambda xx: str_array(np.percentile(xx/MSOL, percs))

# # Modify population
# print("Masses before: ", str_masses(bin_pop.mass))
# bin_pop.modify(mod_mm13)
# print("Masses after : ", str_masses(bin_pop.mass))
    
# plot_mbh_scaling_relations(bin_pop, fname=FNAME_MCCONNELL_MA_2013)
# plt.show()

# Binary Evolution

In [None]:
# Evolve the binaries with holo.BE_Magic_Delay
bin_evo = holo.BE_Magic_Delay(bin_pop)
bin_evo.evolve()

# Calculate GWB

In [None]:
# Frequencies presumably in 1/yr (scaled by 1/YR before use) -- TODO confirm nyquist_freqs units
freqs = holo.utils.nyquist_freqs(10.0, 1.0)
gwb = holo.GWB(bin_evo, freqs/YR)

In [None]:
# Spectrum: sample realizations, percentile band and reference power law
plot_gwb(gwb)
plt.show()

In [None]:
# NO ECC
#
# Recorded diagnostic output from an earlier run, kept for reference.  It is
# log text, not Python, so it stays commented out to keep Run All working.
#
# _calc_mc_at_fobs()
# fobs=3.17e-09, temp =(6.28e-47, 3.02e-44, 2.14e-42, 7.63e-40, 6.89e-32), for (0%, 16%, 50%, 84%, 100%)
# fobs=3.17e-09, hs2  =(6.28e-47, 3.02e-44, 2.14e-42, 7.63e-40, 6.89e-32), for (0%, 16%, 50%, 84%, 100%)
# fobs=3.17e-09, gne  =(1, 1, 1, 1, 1), for (0%, 16%, 50%, 84%, 100%)
# fobs=3.17e-09, harms=(2, 2, 2, 2, 2), for (0%, 16%, 50%, 84%, 100%)



In [None]:
import zcode.math as zmath

In [None]:
# Re-load the raw Illustris population so this cell does not depend on the
# modifications applied above
bin_pop = holo.BP_Illustris(FNAME_ILLUSTRIS_DATA)

# Modifiers (presumably applied in list order): eccentricities, 5x resampling,
# McConnell+Ma 2013 M-Mbulge masses
# mod_ecc = holo.PM_Eccentricity([1.0e-4, 0.01])
mod_ecc = holo.PM_Eccentricity()
mod_resamp = holo.PM_Resample(resample=5.0)
mod_mm13 = holo.PM_MM13(relation='mbulge')
mods = [mod_ecc, mod_resamp, mod_mm13]
bin_pop.modify(mods)

# Eccentric evolution with time_delay = 1e9*YR
# bin_evo = holo.evolution.BE_Magic_Delay_Circ(bin_pop)
bin_evo = holo.evolution.BE_Magic_Delay_Eccen(bin_pop, time_delay=1e9*YR, nsteps=100)
bin_evo.evolve()

# GWB from 30 realizations
freqs = holo.utils.nyquist_freqs(10.0, 1.0)
gwb = holo.GWB(bin_evo, freqs/YR, nreals=30)

plot_gwb(gwb)
plt.show()

In [None]:
# Plot deferred: `plot_sepa_eccen` is defined (and called) in the following
# cell, so calling it here fails with NameError on a fresh kernel.

In [None]:
import zcode.plot as zplot

def plot_sepa_eccen(evo):
    """Plot separation, eccentricity and da/dt against observer-frame frequency.

    Values come from ``evo.at('fobs', ...)`` at 20 log-spaced frequencies
    between 1e-3 and 10 (in units of 1/YR).  For each quantity the median is
    drawn as a line and the 25-75 percentile range as a shaded band.
    Separation (divided by PC) uses the primary axis.  Eccentricity (linear)
    and da/dt (divided by PC/(1e9*YR)) each get a twin axis.  Quantities that
    ``evo.at`` returns as None are skipped.
    """
    fobs = np.logspace(-3, 1, 20) / YR

    fig, ax = plt.subplots(figsize=[10, 5])
    ax.set(xscale='log', yscale='log')

    data = evo.at('fobs', fobs)
    xx = fobs * YR

    # (key, color, units, twin-axis options or None for the primary axis)
    # NOTE(review): if da/dt is negative (hardening), it will not appear on the
    # log-scaled twin axis -- confirm sign convention.
    panels = [
        ('sepa', 'blue', PC, None),
        ('eccen', 'green', 1.0, dict(pos=1.0, scale='lin')),
        ('dadt', 'red', PC/(1e9*YR), dict(pos=1.1, scale='log')),
    ]
    for name, color, units, twin in panels:
        if twin is not None:
            # Each new twin axis is created from the most recent axis
            ax = zplot.twin_axis(ax, color=color, **twin)
        vals = data[name]
        if vals is None:
            continue
        lo, med, hi = np.percentile(vals / units, [25, 50, 75], axis=0)
        line, = ax.plot(xx, med, color=color)
        ax.fill_between(xx, lo, hi, alpha=0.25, color=line.get_color())
        ax.set_ylabel(name)

    return

# Binary parameters vs. observer-frame frequency for the eccentric evolution
plot_sepa_eccen(bin_evo)
plt.show()

In [None]:
def plot_sepa_eccen(evo):
    """Scatter eccentricity vs. separation (left) and de/dt vs. da/dt (right).

    Separation is divided by PC.  Also prints the min/max of ``evo.dadt`` and
    ``evo.dedt``.

    NOTE(review): this redefines (and shadows) the ``plot_sepa_eccen`` from the
    previous cell; consider giving it a distinct name.
    """
    fig, (ax_left, ax_right) = plt.subplots(figsize=[10, 5], ncols=2)

    ax_left.set(xscale='log')
    ax_left.scatter(evo.sepa/PC, evo.eccen)

    # NOTE(review): log axes drop non-positive values; if dadt/dedt are negative
    # nothing will be visible here -- confirm sign convention.
    ax_right.set(xscale='log', yscale='log')
    ax_right.scatter(evo.dadt, evo.dedt)

    for label, arr in (("dadt = ", evo.dadt), ("dedt = ", evo.dedt)):
        print(label, holo.utils.minmax(arr))

    return

# Scatter diagnostics of the evolved binaries (uses the definition just above)
plot_sepa_eccen(bin_evo)
plt.show()