In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py

from holodeck import plot, detstats, utils
from holodeck.constants import YR, MSOL, MPC, GYR
import holodeck as holo

In [None]:
# PTA frequency bins: centers and edges, used by every cell below
fobs_cents, fobs_edges = utils.pta_freqs()

# Default semi-analytic binary population model
sam = holo.sams.Semi_Analytic_Model()

# Hardening model built on `sam` with a 3 Gyr time parameter
# (presumably the fixed binary evolution time — confirm in holodeck docs)
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, 3 * GYR)


## Get Strain Model

In [None]:
# How many loudest single sources to extract, and how many strain realizations
NLOUDEST, NREALS = 10, 30

In [None]:
# Characteristic strain of the NLOUDEST loudest single sources (hc_ss) and of the
# remaining background (hc_bg), for NREALS realizations, plus binary parameters.
# Pass the hardening model constructed in the setup cell. Without `hard=...` that
# Fixed_Time_2PL_SAM instance is never used and gwb presumably falls back to its
# default hardening — confirm against holodeck's Semi_Analytic_Model.gwb signature.
hc_ss, hc_bg, sspar, bgpar = sam.gwb(fobs_gw_edges=fobs_edges, hard=hard, realize=NREALS,
                                     loudest=NLOUDEST, params=True)

# Get Pulsar (PTA) Model

In [None]:
# Sky realizations per single source, and number of pulsars in the array
NSKIES, NPSRS = 25, 10

In [None]:
# for the 0th realization
psrs, sigmin, sigmax, sigma = detstats.calibrate_one_pta(hc_bg[:,0], fobs_cents, NPSRS, ret_sig=True)
print(f"{sigmin=}, {sigmax=}, {sigma=}")

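## Noise Spectral Density
Noise spectral density $S_{h,\mathrm{rest}}$ from all sources except one: for each of the loudest single sources, the background plus every *other* loud source acts as noise.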

In [None]:
# S_h,rest: noise spectral density from everything except one loud source,
# evaluated for each of the loudest sources (last axis, indexed in the plots below)
Sh_rest = detstats._Sh_rest_noise(hc_ss, hc_bg, fobs_cents)

print(f"{Sh_rest.shape=}")

In [None]:
# Compare S_h,rest when excluding the loudest source vs. the faintest retained one.
xx = fobs_cents * YR

y1 = Sh_rest[:, :, 0]
label1 = 'S_h,rest (all but 1st loudest)'

# Index NLOUDEST-1 is the last (faintest) of the retained loud sources;
# avoids a hard-coded 9 that breaks if NLOUDEST changes.
y2 = Sh_rest[:, :, NLOUDEST - 1]
label2 = f'S_h,rest (all but {NLOUDEST}th loudest)'

# S_h = hc^2 / (12 pi^2 f^3) with dimensionless hc, so the unit is Hz^-3 (not Hz^-1)
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Power Spectral Density (Hz$^{-3}$)')
h1 = plot.draw_med_conf(ax, xx, y1)
h2 = plot.draw_med_conf(ax, xx, y2)
ax.legend(handles=[h1, h2], labels=[label1, label2]);

In [None]:
# Overlay the power spectral density of individual loud sources (one hollow
# marker per realization) on the S_h,rest bands.
# Sh_rest = hc^2 / (12 pi^2 freqs^3)
xx = fobs_cents * YR

# Recompute the S_h,rest curves here: later cells reuse y1/y2/label1/label2
# for other quantities, so relying on the previous cell's values is fragile.
y1 = Sh_rest[:, :, 0]
y2 = Sh_rest[:, :, NLOUDEST - 1]
label1 = 'S_h,rest (all but 1st loudest)'
label2 = f'S_h,rest (all but {NLOUDEST}th loudest)'

# power spectral density of the loudest and of the faintest retained single source
l1 = detstats._power_spectral_density(hc_ss[:, :, 0], fobs_cents)
l2 = detstats._power_spectral_density(hc_ss[:, :, NLOUDEST - 1], fobs_cents)
label3 = '1st loudest'
label4 = f'{NLOUDEST}th loudest'

fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Power Spectral Density (Hz$^{-3}$)')
h1 = plot.draw_med_conf(ax, xx, y1)
h2 = plot.draw_med_conf(ax, xx, y2)
c1 = h1[0].get_color()
c2 = h2[0].get_color()
for rr in range(NREALS):
    # fc='none' draws hollow markers; fc=None falls back to a filled default color
    h3 = ax.scatter(xx, l1[:, rr], ec=c1, fc='none', alpha=0.1)
    h4 = ax.scatter(xx, l2[:, rr], ec=c2, fc='none', alpha=0.1)

ax.legend(handles=[h1, h2, h3, h4], labels=[label1, label2, label3, label4]);

In [None]:
# Noise budget for one realization `rr`: total noise, the loudest single
# source's PSD, and the white-noise level.
dur = 1.0/fobs_cents[0]        # 1 / lowest frequency bin (presumably the PTA duration); not used below
cad = 1.0/(2*fobs_cents[-1])   # highest bin taken as Nyquist: cadence = 1 / (2 f_max)
sigmas = np.ones(NPSRS)*sigma  # identical white-noise sigma for every pulsar

rr = 0

# NOTE(review): _total_noise output is presumably (P, F, R, L); [0, :, :, 0] keeps
# pulsar 0 and loudest source 0 — TODO confirm against detstats.
y1 = detstats._total_noise(cad, sigmas, hc_ss[:, rr:rr+1], hc_bg[:, rr:rr+1], fobs_cents)[0, :, :, 0]
# Same call as the '1st loudest' PSD above: this is the loudest source's PSD,
# not an S_h,rest quantity, so label it accordingly.
y2 = detstats._power_spectral_density(hc_ss[:, rr:rr+1, 0], fobs_cents)
y3 = detstats._white_noise(cad, sigmas)  # (P,)
label1 = 'total noise'
label2 = 'S_h (1st loudest)'
label3 = 'white noise'
print(f"{y1.shape=}, {y2.shape=}, {y3.shape=}")

# xx is in units of 1/yr (fobs_cents * YR), so use the 1/yr axis label, not nHz.
xx = fobs_cents * YR
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Noise Spectral Density (Hz$^{-3}$)')
h1 = plot.draw_med_conf(ax, xx, y1)
h2 = plot.draw_med_conf(ax, xx, y2)
# all sigmas are equal, so pulsar 0's white-noise level presumably represents every pulsar
h3 = ax.axhline(y3[0], linestyle='--', color='k', alpha=0.5)

ax.legend(handles=[h1, h2, h3], labels=[label1, label2, label3]);

In [None]:
# `l3` is never defined anywhere (NameError on Restart & Run All); show the
# per-pulsar white-noise level y3 from the noise-budget cell instead.
print(f"{y3=}")

## Total noise (Sh_rest + white noise)

## Single-source SNR (SNR_ss)

# Detection Probability

In [None]:
# Single-source detection probabilities and SNRs over NSKIES sky realizations,
# then the background detection probability and SNR, all for the calibrated PTA.
dp_ss, snr_ss, gamma_ssi = detstats.detect_ss_pta(
    psrs, fobs_cents, hc_ss, hc_bg,
    nskies=NSKIES, ret_snr=True,
)
dp_bg, snr_bg = detstats.detect_bg_pta(psrs, fobs_cents, hc_bg, ret_snr=True)

print(f"{dp_ss.shape=}, {snr_ss.shape=}, {dp_bg.shape=}, {snr_bg.shape=}")

In [None]:
# Detection probability: background (horizontal line), loudest single source
# per frequency as a band over sky realizations, and the overall single-source
# detection probability of each sky realization (faint red lines).
xx = fobs_cents * YR
y1 = dp_bg[0]               # 1
y2 = gamma_ssi[:, 0, :, 0]  # F, S
y3 = dp_ss[0]               # S
label1, label2, label3 = (
    'BG Detprob',
    '1st Loudest Detprob, for each sky realization',
    'Overall SS Detprob, for each sky realization',
)

fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Detection Probability')
h1 = ax.axhline(y1)
h2 = plot.draw_med_conf_color(ax, xx, y2, color='orange')
for ss in range(NSKIES):
    h3 = ax.axhline(y3[ss], color='tab:red', alpha=0.2)

ax.legend(handles=[h1, h2, h3], labels=[label1, label2, label3])

# SNR

In [None]:
# SNR: background SNR (horizontal line) and the loudest single source's SNR
# per frequency as a band over sky realizations.
xx = fobs_cents * YR
y1 = snr_bg[0]              # first entry; must be a scalar since it is drawn with axhline
y2 = snr_ss[:, 0, :, 0]     # F, S
label1 = 'BG SNR'
label2 = '1st Loudest SNR, for each sky realization'

fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='SNR')
h1 = ax.axhline(y1)
h2 = plot.draw_med_conf_color(ax, xx, y2, color='orange')

ax.legend(handles=[h1, h2], labels=[label1, label2]);