In [None]:
# import misc
# import copy

import holodeck as holo
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from holodeck import utils, cosmo
from holodeck import plot as hplot
import holodeck.accretion
from holodeck.constants import MSOL, PC, YR, MPC, GYR, MYR
import gwb as psgwb
# import seaborn as sns

# Set holodeck's package logger to the INFO level for this notebook session.
holo.log.setLevel(holo.log.INFO)

In [None]:
# Run configuration: uniform initial eccentricity, Eddington fraction for accretion,
# and (presumably) the number of evolution steps.
ECCEN_INIT = 0.5
F_EDD = 0.5
NSTEPS = 100

# Build the Illustris population once only to learn its size, then rebuild it with a
# uniform initial-eccentricity array of that length.
# NOTE(review): this loads the Illustris population twice — confirm whether eccentricity
# can be assigned without the probe construction.
size = holo.population.Pop_Illustris().size
eccen = np.full(size, ECCEN_INIT, dtype=float)
pop = holo.population.Pop_Illustris(eccen=eccen)

# Hardening mechanisms applied during the evolution. Hard_GW is passed as the class itself,
# the others as configured instances.
hards = [
    holo.hardening.Hard_GW,
    holo.hardening.CBD_Torques(),
    holo.hardening.Sesana_Scattering(),
    holo.hardening.Dynamical_Friction_NFW(attenuate=True),
]

# Accretion model 'Siwek22' at Eddington fraction F_EDD.
acc = holo.accretion.Accretion(
    accmod='Siwek22',
    f_edd=F_EDD,
    subpc=True,
    evol_mass=True,
    edd_lim=1.0,
)

In [None]:
# Per-binary final-step indices of the evolution. `evo` is only created by the next cell,
# so on a fresh kernel this shows None instead of raising NameError and halting "Run All".
evo._last_index if 'evo' in globals() else None

In [None]:
# Evolve the population with the hardening list and accretion model configured above.
# NSTEPS comes from the configuration cell; before this change it was defined but never used.
evo = holo.evolution.Evolution(pop, hards, nsteps=NSTEPS, debug=True, acc=acc)
evo.evolve()

In [None]:
# Separation vs. lookback time along each binary's evolution track.
# NOTE(review): the slicing treats `evo.tlook` / `evo.sepa` as flattened along their first
# axis, with `evo._last_index[i]` the index of binary i's final entry — confirm.
fig, ax = holo.plot.figax()
bin_vals = range(evo._size)
# e.g. bin_vals = [3] to plot a single binary

for ibin in bin_vals:
    # Binary i spans [_last_index[i-1] + 1, _last_index[i] + 1); binary 0 starts at 0.
    # Deriving `beg` only from the index array (never from the previous iteration) keeps
    # each track correct for any subset or ordering of `bin_vals`.
    beg = evo._last_index[ibin - 1] + 1 if ibin > 0 else 0
    end = evo._last_index[ibin] + 1
    # Stop at the first track with fewer than two entries (nothing to draw as a line).
    if beg >= end - 1:
        break
    xx = evo.tlook[beg:end] / GYR
    yy = evo.sepa[beg:end] / PC
    ax.plot(xx, yy, marker='x', label=ibin)

ax.set(xlabel='Lookback time [Gyr]', ylabel='Separation [pc]')
ax.legend()
plt.show()

In [None]:
# Deliberate stop: halts "Run All" here. The cells below need the local `misc` helper module
# and build/compare additional evolutions, so presumably they are meant to be run on request.
# An explicit exception replaces the previous call to an undefined name (a NameError).
raise RuntimeError("Execution intentionally stopped here; run the cells below manually.")

# Generate evolution objects for populations with and without circumbinary-disk (CBD) influence

In [None]:
# Build the evolutions compared below with the `misc.generate_evol` helper.
# `misc` (presumably a local helper module) is imported here because its top-of-file import
# is commented out, which made this cell fail with NameError on a fresh kernel.
import misc

keys = ['no_doteb', 'doteb']
f_edd = 0.1
nsteps = 100
eccen_init = 0.01
# `generate_evol` returns the (possibly updated) key list along with the evolutions dict.
keys, evol_dict = misc.generate_evol(keys, f_edd=f_edd, eccen_init=eccen_init, ecc_test=True, nsteps=nsteps)

In [None]:
# Median GWB spectrum, with a 25-75 percentile band over realizations, for each evolution in
# `evol_dict`, compared against the NANOGrav (2023) amplitude at f = 1/yr.
fig, ax = holo.plot.figax()
ax.grid(True, alpha=0.25)
fobs, _ = utils.pta_freqs(num=40)
nreals = 30
plot_nanograv23 = True

lw = 3
ticksize = 5
tickwidth = 2
fs = 20

for key in keys:
    evol = evol_dict[key]
    gwb = holo.gravwaves.GW_Discrete(evol, fobs, nreals=nreals)
    gwb.emit(eccen=True)
    median_gwb = np.median(gwb.both, axis=-1)
    # Optional per-key linestyle; solid by default.
    ls = evol_dict.get(f'ls_{key}', '-')

    cc, = ax.plot(fobs, median_gwb, label=evol_dict[f'label_{key}'],
                  color=evol_dict[f'color_{key}'], linewidth=lw, linestyle=ls)
    conf = np.percentile(gwb.both, [25, 75], axis=-1)
    ax.fill_between(fobs, *conf, color=cc.get_color(), alpha=0.1)

if plot_nanograv23:
    f_det = 1./YR
    amp_det = 2.4e-15
    # Asymmetric error bar: (lower, upper) as a (2, 1) column.
    err_det = np.array([[0.6e-15, 0.7e-15]]).T
    ax.errorbar(f_det, amp_det, yerr=err_det, color='royalblue',
                marker='*', markersize=4*lw, mew=lw, label='NANOGrav\n(2023)')

# Create the 1/yr twin axis exactly once, after all data are drawn. Calling it inside the
# per-key loop stacked one twin axis per key, and it left `twin_ax` undefined when `keys` was
# empty. Creating it last presumably also lets it pick up the final x-limits.
twin_ax = hplot._twin_yr(ax, nano=False, fs=fs, label=False)
for aa in (ax, twin_ax):
    plt.setp(aa.get_xticklabels(which='both'), fontsize=fs, rotation=0)
    plt.setp(aa.get_yticklabels(which='both'), fontsize=fs)
    aa.tick_params(axis='both', which='major', direction='inout', size=ticksize, width=tickwidth)
    aa.tick_params(axis='both', which='minor', direction='inout', size=0.7*ticksize, width=0.7*tickwidth)

twin_ax.set_xlabel(r'$f_{\rm GW} \ [\rm{yr}^{-1}]$', fontsize=fs)
ax.set_xlabel(r'$f_{\rm GW} \ [\rm{Hz}]$', fontsize=fs)
ax.set_ylabel(r'$h_c$', fontsize=fs)
ax.legend(fontsize=0.8*fs, ncol=max(1, int(len(keys)/2.)+1), loc='lower center')
fig.tight_layout()

In [None]:
# Pick out the two evolutions built above by their `evol_dict` keys, then display them.
evo_neg_cbd, evo_pos_cbd = [evol_dict[name] for name in ('no_doteb', 'doteb')]
evo_neg_cbd, evo_pos_cbd

In [None]:
# Names of the evolution attributes summarized in the next cell.
pars = getattr(evo_neg_cbd, '_EVO_PARS')

In [None]:
# For each evolution parameter, summarize its final-step values over binaries whose final
# scale factor is < 1 (presumably those that finish/coalesce before z = 0), for both evolutions.
evos_cbd = {'no_doteb': evo_neg_cbd, 'doteb': evo_pos_cbd}
# The selection depends only on the evolution, so build it once per evolution. Using
# `name`/`ev` as loop names also avoids rebinding the notebook-level `evo`.
sel_final = {name: ev.scafa[:, -1] < 1.0 for name, ev in evos_cbd.items()}

for pp in pars:
    print(pp)
    for name, ev in evos_cbd.items():
        vv = getattr(ev, pp)[sel_final[name], -1]
        print(name, vv.shape, utils.stats(vv))


In [None]:
# Frequency grid for the spectra below: 100 bins for a 100-yr duration, with 10 realizations each.
fobs, _ = utils.pta_freqs(num=100, dur=100 * YR)
nreals = 10

In [None]:
# Median GWB spectrum over realizations for `evo_neg_cbd` ('no_doteb') and `evo_pos_cbd`
# ('doteb'). The GW_Discrete objects are kept in `gwbs` for the harmonic breakdown that follows.
fig, ax = holo.plot.figax()

gwbs = []
# Loop variable is `ev`, not `evo`, so the notebook-level `evo` evolution is not rebound.
for label, ev in [('no_doteb', evo_neg_cbd), ('doteb', evo_pos_cbd)]:
    gwb = holo.gravwaves.GW_Discrete(ev, fobs, nreals=nreals)
    gwb.emit(eccen=True)
    ax.plot(fobs, np.median(gwb.both, axis=-1), label=label)
    gwbs.append(gwb)

ax.set(xlabel='GW frequency [Hz]', ylabel=r'$h_c$')
ax.legend()
plt.show()


In [None]:
# Per-harmonic breakdown of each GWB in `gwbs`; panel order follows `gwbs` ('no_doteb', then
# 'doteb'). Dashed black: sqrt of the sum over harmonics. The first `num` harmonics use the
# default color cycle; the rest use a viridis ramp, with every tenth one labelled.
# NOTE(review): the sqrt suggests `gwb.harms` holds squared strain per (frequency, harmonic) — confirm.
num = 6
fig, axes = holo.plot.figax(figsize=[10, 5], ncols=2, ylim=[1e-20, 1e-13])

for ax, gwb, title in zip(axes, gwbs, ['no_doteb', 'doteb']):
    hc = gwb.harms
    _, nharms = hc.shape

    ax.plot(fobs, np.sqrt(np.sum(hc, axis=1)), 'k--', alpha=0.5)
    # Clamp to the available harmonics so fewer than `num` does not raise IndexError.
    for ii in range(min(num, nharms)):
        ax.plot(fobs, np.sqrt(hc[:, ii]), alpha=0.8, label=ii+1)

    # `plt.get_cmap` replaces `mpl.cm.get_cmap` (removed in Matplotlib 3.9); clamp the
    # count at zero so linspace never receives a negative number of samples.
    colors = plt.get_cmap('viridis')(np.linspace(0.0, 1.0, max(nharms - num, 0)))
    for ii, col in enumerate(colors):
        jj = num + ii
        lab = jj+1 if jj % 10 == 0 else None
        ax.plot(fobs, np.sqrt(hc[:, jj]), color=col, alpha=0.8, label=lab)

    ax.set(title=title, xlabel='GW frequency [Hz]')
    # One legend per panel; a single call after the loop only labelled the last panel.
    ax.legend()

plt.show()

In [None]:
# Median GWB spectrum of the same evolution (`evo_pos_cbd`), emitted with and without
# eccentricity. The GW_Discrete objects replace `gwbs` for the harmonic breakdown that follows.
fig, ax = holo.plot.figax()

gwbs = []
# `use_eccen` rather than `eccen`, so the initial-eccentricity array from the configuration
# cell is not clobbered by a bool.
for use_eccen in [True, False]:
    gwb = holo.gravwaves.GW_Discrete(evo_pos_cbd, fobs, nreals=nreals)
    gwb.emit(eccen=use_eccen)
    ax.plot(fobs, np.median(gwb.both, axis=-1), label=f'eccen={use_eccen}')
    gwbs.append(gwb)

ax.set(xlabel='GW frequency [Hz]', ylabel=r'$h_c$')
ax.legend()
plt.show()


In [None]:
# Per-harmonic breakdown of each GWB in `gwbs`; panel order follows `gwbs` (eccen=True, then
# eccen=False). Dashed black: sqrt of the sum over harmonics. The first `num` harmonics use the
# default color cycle; the rest use a viridis ramp, with every tenth one labelled.
# NOTE(review): the sqrt suggests `gwb.harms` holds squared strain per (frequency, harmonic) — confirm.
num = 6
fig, axes = holo.plot.figax(figsize=[10, 5], ncols=2, ylim=[1e-20, 1e-13])

for ax, gwb, title in zip(axes, gwbs, ['eccen=True', 'eccen=False']):
    hc = gwb.harms
    _, nharms = hc.shape

    ax.plot(fobs, np.sqrt(np.sum(hc, axis=1)), 'k--', alpha=0.5)
    ax.set(title=title, xlabel='GW frequency [Hz]')

    # With a single harmonic (presumably the eccen=False emission) the dashed total is the
    # whole spectrum, so there is nothing further to draw in this panel.
    if nharms == 1:
        continue

    # Clamp to the available harmonics so 1 < nharms < num does not raise IndexError.
    for ii in range(min(num, nharms)):
        ax.plot(fobs, np.sqrt(hc[:, ii]), alpha=0.8, label=ii+1)

    # `plt.get_cmap` replaces `mpl.cm.get_cmap` (removed in Matplotlib 3.9); clamp the
    # count at zero so linspace never receives a negative number of samples.
    colors = plt.get_cmap('viridis')(np.linspace(0.0, 1.0, max(nharms - num, 0)))
    for ii, col in enumerate(colors):
        jj = num + ii
        lab = jj+1 if jj % 10 == 0 else None
        ax.plot(fobs, np.sqrt(hc[:, jj]), color=col, alpha=0.8, label=lab)

    # One legend per panel; a single call after the loop only labelled the last panel.
    ax.legend()

plt.show()