In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
import holodeck.gravwaves
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC

# Silence annoying numpy floating-point errors (divide-by-zero, invalid, overflow)
# and all UserWarnings
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# `mpl.style.use('default')` resets rcParams to matplotlib's defaults, so it must
# run BEFORE the customizations below; if applied last it silently discards them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): 'Times' is registered under the *sans-serif* list while the active
# family is 'serif', so it is currently unused; confirm which font was intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Show INFO-level (and above) messages from holodeck's logger
log = holo.log
log.setLevel(logging.INFO)

# Quick: Population, Evolution, GW Spectrum

In [None]:
# ---- 1. Population: Illustris binaries, resampled (factor 2.0) to increase their number
pop = holo.population.Pop_Illustris()
mod_resamp = holo.population.PM_Resample(resample=2.0)
pop.modify(mod_resamp)   # NOTE: modifies `pop` in place

# ---- 2. Evolution: a fixed-total-time hardening model (2 Gyr for every binary)
fixed = holo.evolution.Fixed_Time.from_pop(pop, 2.0 * GYR)
evo = holo.evolution.Evolution(pop, fixed)
evo.evolve()

# ---- 3. GW background: sampling frequencies for a 20 yr duration at 0.3 yr
#         cadence, then discretized GW signals over 30 realizations
freqs = holo.utils.nyquist_freqs(dur=20.0*YR, cad=0.3*YR)
gwb = holo.gravwaves.GW_Discrete(evo, freqs, nreals=30)
gwb.emit()

plot.plot_gwb(gwb)
plt.show()

# Step-by-Step

## Construct Illustris-Based Binary Population

In [None]:
pop = holo.population.Pop_Illustris()
# The Illustris run name is taken as the second '_'-separated token of the
# population's data-file basename
fname_base = os.path.basename(pop._fname)
ill_name = fname_base.split('_')[1]
print("Loaded", pop.size, "binaries from Illustris", ill_name)

In [None]:
# Visual overview of the freshly loaded binary population
plot.plot_bin_pop(pop)
plt.show()

### Apply a modifier to resample the binary population

In [None]:
# Resample (factor 2.0) to increase the number of binaries.
# NOTE: `pop.modify` works in place, so re-running this cell resamples again.
mod_resamp = holo.population.PM_Resample(resample=2.0)
pop.modify(mod_resamp)
print("Population now has", pop.size, "elements")

### Apply a Modifier to Use McConnell & Ma (2013) BH Masses

In [None]:
# Create the modifier: reset BH masses from the McConnell & Ma (2013)
# M-Mbulge relation, including its scatter
mmbulge = holo.relations.MMBulge_MM13()
mod_mm13 = holo.population.PM_Mass_Reset(mmbulge, scatter=True)

# Percentiles to report: the minimum (0), the -1/0/+1 sigma points of a
# standard normal (~15.9, 50, ~84.1), and the maximum (100)
percs = 100*sp.stats.norm.cdf([-1, 0, 1])
percs = [0,] + percs.tolist() + [100,]


def str_array(xx):
    """Format an iterable of numbers as a comma-separated string in `.2e` notation."""
    return ", ".join(["{:.2e}".format(yy) for yy in xx])


def str_masses(xx):
    """Format the `percs` percentiles of masses `xx`, converted to solar masses.

    `xx` is expected in the same units as `MSOL`.
    """
    return str_array(np.percentile(xx/MSOL, percs))


# Modify population (in place) and compare the mass distributions.
# NOTE: re-running this cell applies the mass reset again to the current `pop`.
print("Masses before: ", str_masses(pop.mass))
pop.modify(mod_mm13)
print("Masses after : ", str_masses(pop.mass))

plot.plot_mbh_scaling_relations(pop)
plt.show()

# Binary Evolution

In [None]:
# Timescale over which every binary should merge
fix_time = 2.0 * GYR

# Build the fixed-time 'hardening' model for this population, attach both to
# an evolution instance, and run it
fixed = holo.evolution.Fixed_Time.from_pop(pop, fix_time)
evo = holo.evolution.Evolution(pop, fixed)
evo.evolve()

## Compare resulting lifetimes to the target lifetime

Check that the resulting binary lifetimes are consistent with the target timescale (`fix_time`), i.e. that the ratio actual/specified is close to one.
Small deviations (a few percent) are expected because the method is approximate.

In [None]:
# Total evolution time of each binary: first column minus last column of
# `evo.tlbk` (presumably lookback times, one row per binary, one column per step)
tlbk = evo.tlbk
dt = tlbk[:, 0] - tlbk[:, -1]

# Distribution of actual/target lifetime -- should be close to unity
fig, ax = plot.figax(scale='lin', xlabel='Time: actual/specified', ylabel='density')
kale.dist1d(dt/fix_time, ax=ax, density=True)
plt.show()

## Plot Hardening Rate vs. Separation

In [None]:
# x-axis grid: 100 log-spaced separations from 1e-4 to 1e4 pc
sepa = PC * np.logspace(-4, 4, 100)

# Hardening rates of the evolved population as a function of separation
plot.plot_evo(evo, sepa=sepa)
plt.show()

## Plot Hardening Rate vs. Frequency

In [None]:
# x-axis grid: sampling frequencies for a 20 yr duration at 0.3 yr cadence.
# Duration and cadence are passed in physical units (multiples of `YR`), matching
# the other `nyquist_freqs` calls in this notebook, so no rescaling is needed.
fobs = holo.utils.nyquist_freqs(dur=20.0*YR, cad=0.3*YR)

# Hardening rates of the evolved population as a function of frequency
plot.plot_evo(evo, freqs=fobs)
plt.show()

# Calculate GWB

In [None]:
# Discretized GW signals from the evolved binaries at the sampling
# frequencies `fobs`, using 5 realizations
gwb = holo.gravwaves.GW_Discrete(evo, fobs, nreals=5)
gwb.emit()

In [None]:
# Plot the GW background computed above
plot.plot_gwb(gwb)
plt.show()

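# Sample Full Light-Cone Population of Binaries

Rebuild and evolve the population (using an M–Mbulge relation with a modified `mamp`), then compute the GWB at a finer cadence. The final cell draws discrete binaries from the full population to compare the loudest sources against the total; it is currently commented out.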

In [None]:
# Rebuild the population from scratch: Illustris binaries, resampled (factor 2.0)
pop = holo.population.Pop_Illustris()
mod_resamp = holo.population.PM_Resample(resample=2.0)
pop.modify(mod_resamp)

# Reset BH masses with the McConnell & Ma (2013) M-Mbulge relation, including
# scatter; `mamp` is set to 7e8 Msol (presumably the relation's mass normalization)
mmbulge = holo.relations.MMBulge_MM13(mamp=7e8*MSOL)
mod_mm13 = holo.population.PM_Mass_Reset(mmbulge, scatter=True)
pop.modify(mod_mm13)

# Evolve every binary to merge over a fixed 2 Gyr
fixed = holo.evolution.Fixed_Time.from_pop(pop, 2.0 * GYR)
evo = holo.evolution.Evolution(pop, fixed)
evo.evolve()

In [None]:
# Sampling frequencies for a 20 yr duration at a finer 0.2 yr cadence
fobs = holo.utils.nyquist_freqs(20.0 * YR, 0.2 * YR)

In [None]:
# Discretized GW signals at the finer frequencies, using 10 realizations
gwb = holo.gravwaves.GW_Discrete(evo, fobs, nreals=10)
gwb.emit()

# Plot the resulting GW background
plot.plot_gwb(gwb)
plt.show()

In [None]:
# Disabled exploratory cell (uncomment to enable).
# Samples discrete binaries from the full population and, for each frequency bin,
# plots the total characteristic strain (dashed black line). For each of the
# `nloud` loudest binaries in that bin, it plots that binary's strain (point) and
# the strain of all quieter binaries (thin line).
# NOTE(review): relies on `evo.sample_full_population` and the private helper
# `holo.sam._strains_from_samples`; confirm both exist in the current holodeck API.
# samples = evo.sample_full_population(fobs, DOWN=None)
# hs, fo = holo.sam._strains_from_samples(samples)

# nloud = 5
# colors = plot._get_cmap('plasma')(np.linspace(0.05, 0.95, nloud))# print(colors)

# fig, ax = plot.figax(figsize=[12, 8], xlabel='Frequency [yr$^{-1}$]', ylabel='c-Strain')
# for ii in utils.tqdm(range(fobs.size-1)):
#     # if ii < 10 or ii > 16:
#     #     continue

#     # edges of the ii-th frequency bin, and the number of cycles per log-frequency
#     fextr = [fobs[ii+jj] for jj in range(2)]
#     fextr = np.asarray(fextr)
#     cycles = 1.0 / np.diff(np.log(fextr))[0]

#     # select the sampled binaries whose frequencies fall inside this bin
#     idx = (fextr[0] <= fo) & (fo < fextr[1])
#     hs_bin = hs[idx]
#     fo_bin = fo[idx]

#     tot = np.sqrt(np.sum(cycles * hs_bin**2))
#     ax.plot(fextr*YR, tot * np.ones_like(fextr), 'k--')

#     # indices of binaries in this bin, ordered from loudest to quietest
#     idx = np.argsort(hs_bin)[::-1]

#     for jj, cc in enumerate(colors):
#         # NOTE(review): was `jj > len(idx)`, which lets `idx[jj]` index past the
#         # end when a bin holds fewer than `nloud` binaries; `>=` stops in time.
#         if jj >= len(idx):
#             break
#         hi = idx[jj]
#         lo = idx[jj+1:]
#         gw_hi = np.sqrt(np.sum(cycles * hs_bin[hi]**2))
#         gw_lo = np.sqrt(np.sum(cycles * hs_bin[lo]**2))

#         fave = np.average(fo_bin[hi], weights=hs_bin[hi])
#         ax.plot(fextr*YR, gw_lo * np.ones_like(fextr), color=cc, lw=0.5)
#         ax.scatter(fave*YR, gw_hi, marker='.', color=cc, alpha=0.5)

# plt.show()