In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import os

from holodeck import plot, detstats, utils
from holodeck.constants import YR, MSOL, MPC, GYR
import holodeck as holo

import hasasia.sensitivity as hsen
import hasasia.sim as hsim
import hasasia.skymap as hsky

In [None]:
# --- Data-set selection (defaults used by get_data to locate the data file) ---
SHAPE = None
NVARS = 21
NREALS = 500

# --- Array sizes assumed by the reshapes further down ---
NFREQS = 40
NLOUDEST = 10

# --- Options not referenced further in this notebook ---
SAVEFIG = False
TOL = 0.01
MAXBADS = 5

# --- Red-noise settings (RED2WHITE is copied into `red2white` below) ---
RED_GAMMA = None
RED2WHITE = None

# --- PTA simulation ---
NPSRS = 40
NSKIES = 100

# --- Parameter being varied ---
TARGET = 'gsmf_phi0'  # EDIT AS NEEDED
TITLE = plot.PARAM_KEYS[TARGET]  # EDIT AS NEEDED

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma=None, red2white=None,
        path='/Users/emigardiner/GWs/holodeck/output/anatomy_redz',
):
    """Load the pre-computed ``data`` and ``params`` arrays for one parameter study.

    Reads ``{path}/{target}_v{nvars}_r{nreals}_shape{shape}/data_params.npz``.

    Parameters
    ----------
    target : str
        Name of the varied parameter, e.g. 'gsmf_phi0'. First part of the directory name.
    nvars : int
        Number of parameter variations; formatted into the directory name.
    nreals : int
        Number of realizations; formatted into the directory name.
    nskies : int
        Unused; kept for interface compatibility.
    shape : object
        Formatted with ``str()`` into the directory name (``None`` -> 'shapeNone').
    red_gamma, red2white : optional
        Unused; kept for interface compatibility.
    path : str
        Parent directory holding the per-target output directories.
        NOTE(review): hard-coded absolute local path -- override it on other machines.

    Returns
    -------
    data : ndarray
        The 'data' entry of the npz file (loaded with allow_pickle, so presumably an
        object array of per-variation records -- verify).
    params : ndarray
        The 'params' entry of the npz file.

    Raises
    ------
    FileNotFoundError
        If the npz file does not exist. This is a subclass of ``Exception``, so callers
        that caught the previous bare ``Exception`` still work.
    """
    load_data_from_file = os.path.join(
        path, f'{target}_v{nvars}_r{nreals}_shape{str(shape)}', 'data_params.npz')

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # allow_pickle is needed for object arrays. Unpickling can execute arbitrary code,
    # so only load files you produced yourself.
    # The context manager closes the archive even if reading an entry fails.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']
    print(target, "got data")

    return data, params

In [None]:
data, params = get_data(target=TARGET)  # all other arguments use the config-cell defaults

In [None]:
var = 10  # index of the parameter variation to inspect

hc_ss = data[var]['hc_ss']
hc_bg = data[var]['hc_bg']
fobs_cents = data[var]['fobs_cents']

# Recover the observing span and cadence by inverting the first and last frequency bins
dur = 1.0 / fobs_cents[0]
cad = 1.0 / (2.0 * fobs_cents[-1])
hc_bg = data[var]['hc_bg']

In [None]:
phis = None
thetas = None
npsrs = NPSRS
red2white = RED2WHITE
sigstart = 1.5e-7  # passed to hasasia as sigma (presumably per-TOA rms in seconds -- verify)

# Randomize pulsar sky positions (theta: polar angle, phi: azimuth).
if phis is None:
    phis = np.random.uniform(0, 2*np.pi, size=npsrs)
if thetas is None:
    # Previously np.random.uniform(np.pi/2, np.pi/2, ...): a zero-width range that put every
    # pulsar at theta = pi/2. Sample the full polar range instead, matching the theta_ss
    # convention in the detect_ss_pta_custom_noise draft below. For an isotropic sky, use
    # np.arccos(np.random.uniform(-1, 1, size=npsrs)).
    thetas = np.random.uniform(0, np.pi, size=npsrs)
sigma = sigstart
# TODO: red2white is not used yet. The intended red-noise amplitude was
#       _white_noise(cad, sigma) * red2white

# hasasia takes the timespan in years and the cadence as observations per year.
psrs = hsim.sim_pta(timespan=dur/YR, cad=1/(cad/YR), sigma=sigma,
                   phi=phis, theta=thetas)

In [None]:
# Background detection probability for this variation, with white noise only
# (red_amp/red_gamma disabled, ss_noise=False).
# NOTE(review): np.newaxis inserts a length-1 axis at position 1 of both strain arrays;
# confirm this is the layout detect_bg_pta expects.
bg_results = detstats.detect_bg_pta(
    psrs,
    fobs_cents,
    hc_bg=hc_bg[:, np.newaxis],
    hc_ss=hc_ss[:, np.newaxis, :],
    red_amp=None,
    red_gamma=None,
    ss_noise=False,
)
dp_bg = bg_results[0]
print(holo.utils.stats(dp_bg))

In [None]:
# One hasasia Spectrum per simulated pulsar, evaluated on the holodeck frequency bins.
spectra = []
for pulsar in psrs:
    spectrum = hsen.Spectrum(pulsar, freqs=fobs_cents)
    # Access NcalInv so the property is evaluated here for each pulsar
    # (presumably hasasia caches it on the Spectrum -- verify).
    spectrum.NcalInv
    spectra.append(spectrum)

In [None]:
# Characteristic-strain sensitivity curves built from the same per-pulsar spectra.
ss_curve = hsen.DeterSensitivityCurve(spectra)
sc_ss = ss_curve.h_c  # deterministic (single-source) sensitivity
bg_curve = hsen.GWBSensitivityCurve(spectra)
sc_bg = bg_curve.h_c  # GW background sensitivity

In [None]:
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN)
xx = fobs_cents * YR

# hasasia sensitivity curves (dashed)
for curve, curve_label, curve_color in (
        (sc_bg, 'BG Sensitivity', 'tab:blue'),
        (sc_ss, 'SS Sensitivity', 'tab:pink'),
):
    ax.plot(xx, curve, label=curve_label, linestyle='--', color=curve_color)

# Simulated strain: interquartile band and median over the last axis.
# For single sources only index 0 of the third axis is shown.
for strain, strain_label, strain_color in (
        (hc_bg, 'hc_BG', 'tab:blue'),
        (hc_ss[:, :, 0], 'hc_SS', 'tab:pink'),
):
    q25, q75 = np.percentile(strain, (25, 75), axis=-1)
    ax.fill_between(xx, q25, q75, color=strain_color, alpha=0.2)
    ax.plot(xx, np.median(strain, axis=-1), label=strain_label, color=strain_color)

ax.legend()

# Calculate detstats

In [None]:
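# NOTE: commented-out draft kept for reference only. The cells below call
# detstats.detect_ss_pta(..., custom_noise=...) instead. If this is ever revived, note that
# the shape check uses `npsrs`, which is not defined inside the function (use len(pulsars)).
# def detect_ss_pta_custom_noise(pulsars, fobs, hc_ss, hc_bg, noise,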
#               theta_ss=None, phi_ss=None, Phi0_ss=None, iota_ss=None, psi_ss=None, nskies=25, 
#               Fe_bar = None, red_amp=None, red_gamma=None, alpha_0=0.001, Fe_bar_guess=15,
#               ret_snr=False, print_nans=False, snr_cython=True, gamma_cython=True, grid_path=detstats.GAMMA_RHO_GRID_PATH):
#     """ Calculate the single source detection probability, and all intermediary steps for
#     R strain realizations and S sky realizations.

#     Parameters
#     ----------
#     pulsars : (P,) list of hasasia.Pulsar objects
#         A set of pulsars generated by hasasia.sim.sim_pta()
#     fobs : (F,) 1Darray of scalars
#         Observer frame gw frequency bin centers in Hz.
#     hc_ss : (F,R,L) NDarray of scalars
#         Characteristic strain of the L loudest single sources at
#         each frequency, for R realizations.
#     hc_bg : (F,R)
#         Characteristic strain of the background at each frequency,
#         for R realizations.
#     theta_ss : (F,S,L) NDarray
#         Polar (latitudinal) angular position in the sky of each single source.
#         Must be provided, to give the shape for sky realizations.
#     phi_ss : (F,S,L) NDarray or None
#         Azimuthal (longitudinal) angular position in the sky of each single source.
#         If None, random values between 0 and 2pi will be assigned.
#     Phi0_ss : (F,S,L) NDarray or None
#         Initial GW phase.
#         If None, random values between 0 and 2pi will be assigned.
#     iota_ss : (F,S,L) NDarray or None
#         Inclination of each single source with respect to the line of sight.
#         If None, random values between 0 and pi will be assigned.
#     psi_ss : (F,S,L) NDarray or None
#         Polarization of each single source.
#         If None, random values between 0 and pi will be assigned.
#     Fe_bar : scalar or None
#         Threshold F-statistic
#     red_amp : scalar or None
#         Amplitude of pulsar red noise.
#     red_gamma : scalar or None
#         Power law index of pulsar red noise.
#     alpha_0 : scalar
#         False alarm probability
#     ret_snr : Bool
#         Whether or not to also return snr_ss.

#     Returns
#     -------
#     gamma_ss : (R,S) NDarray
#         Probability of detecting any single source, for each R and S realization.
#     snr_ss : (F,R,S,L) NDarray
#         SNR of each single source. Returned only if ret_snr is True.
#     gamma_ssi : (F,R,S,L) NDarray
#         DP of each single source. Returned only if ret_snr is True.

#     """

#     dur = 1.0/fobs[0]
#     cad = 1.0/(2*fobs[-1])
#     fobs_cents, fobs_edges = utils.pta_freqs(dur, num=len(fobs))
#     dfobs = np.diff(fobs_edges)

#     # Assign random single source sky params, if not provided.
#     nfreqs, nreals, nloudest = [*hc_ss.shape]
#     if theta_ss is None:
#         theta_ss = np.random.uniform(0,np.pi, size=nfreqs*nskies*nloudest).reshape(nfreqs, nskies, nloudest)
#     if phi_ss is None:
#         phi_ss = np.random.uniform(0,2*np.pi, size=theta_ss.size).reshape(theta_ss.shape)
#     if Phi0_ss is None:
#         Phi0_ss = np.random.uniform(0,2*np.pi, size=theta_ss.size).reshape(theta_ss.shape)
#     if iota_ss is None:
#         iota_ss = np.random.uniform(0, np.pi, size = theta_ss.size).reshape(theta_ss.shape)
#     if psi_ss is None:
#         psi_ss = np.random.uniform(0, np.pi, size = theta_ss.size).reshape(theta_ss.shape)

#     # unitary vectors
#     m_hat = detstats._m_unitary_vector(theta_ss, phi_ss, psi_ss) # (3,F,S,L)
#     n_hat = detstats._n_unitary_vector(theta_ss, phi_ss, psi_ss) # (3,F,S,L)
#     Omega_hat = detstats._Omega_unitary_vector(theta_ss, phi_ss) # (3,F,S,L)


#     # get pulsar properties
#     thetas = np.zeros(len(pulsars))
#     phis = np.zeros(len(pulsars))
#     sigmas = np.zeros(len(pulsars))
#     for ii in range(len(pulsars)):
#         thetas[ii] = pulsars[ii].theta
#         phis[ii] = pulsars[ii].phi
#         sigmas[ii] = np.mean(pulsars[ii].toaerrs)

#     pi_hat = detstats._pi_unitary_vector(phis, thetas) # (3,P)

#     # antenna pattern functions
#     F_iplus, F_icross = detstats._antenna_pattern_functions(m_hat, n_hat, Omega_hat,
#                                                    pi_hat) # (P,F,S,L)

#     # noise spectral density
#     # S_i = detstats._total_noise(cad, sigmas, hc_ss, hc_bg, fobs, red_amp, red_gamma)
#     S_i = noise
#     if noise.shape != (npsrs, nfreqs, nreals, nloudest):
#         err = f"{noise.shape=}, must be shape (P,F,R,L)"
#         raise ValueError(err)

#     # amplitude
#     amp = detstats._amplitude(hc_ss, fobs, dfobs) # (F,R,L)

#     # SNR (includes a_pol, b_pol, and Phi_T calculations internally)
#     if snr_cython:
#         snr_ss = detstats._snr_ss(amp, F_iplus, F_icross, iota_ss, dur, Phi0_ss, S_i, fobs) # (F,R,S,L)
#     else:
#         snr_ss = detstats._snr_ss_5dim(amp, F_iplus, F_icross, iota_ss, dur, Phi0_ss, S_i, fobs) # (F,R,S,L)

#     if gamma_cython:
#         gamma_ssi = detstats._gamma_ssi_cython(snr_ss, grid_path=grid_path) # (F,R,S,L)
#     else:
#         if (Fe_bar is None):
#             Num = hc_ss[:,0,:].size # number of single sources in a single strain realization (F*L)
#             Fe_bar = detstats._Fe_thresh(Num, alpha_0=alpha_0, guess=Fe_bar_guess) # scalar

#         gamma_ssi = detstats._gamma_ssi(Fe_bar, rho=snr_ss, print_nans=print_nans) # (F,R,S,L)
#     gamma_ss = detstats._ss_detection_probability(gamma_ssi) # (R,S)

#     if ret_snr:
#         return gamma_ss, snr_ss, gamma_ssi
#     else:
#         return gamma_ss

In [None]:
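# Build noise_ss with shape (P, F, R, L)

Convert the hasasia single-source sensitivity curve $h_{c,\mathrm{sens}}(f)$ into a noise power spectral density, repeat it for every pulsar $P$, realization $R$ and loudest source $L$, then add the remaining-source noise `detstats._Sh_rest_noise(...)`:

$$ S_{\mathrm{noise}}(f) = \frac{h_{c,\mathrm{sens}}^2(f)}{12\pi^2 f^3} $$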

In [None]:
# Take sizes from the loaded data rather than from NFREQS/NREALS/NLOUDEST. A data file whose
# dimensions differ from the config constants then fails loudly (or just works) instead of
# being silently mis-reshaped.
nfreqs_ss, nreals_ss, nloudest_ss = hc_ss.shape
npsrs_ss = len(psrs)
assert sc_ss.shape == fobs_cents.shape, f"{sc_ss.shape=} does not match {fobs_cents.shape=}"

# Residual PSD implied by the hasasia single-source sensitivity curve, shape (F,)
noise_has_ss = sc_ss**2 / fobs_cents**3 / (12*np.pi**2)

# Same value for every pulsar, realization and loudest source: (F,) -> (P,F,R,L).
# This is identical to the former repeat/reshape/swapaxes sequence.
noise_ss = np.broadcast_to(
    noise_has_ss[np.newaxis, :, np.newaxis, np.newaxis],
    (npsrs_ss, nfreqs_ss, nreals_ss, nloudest_ss))

Sh_rest = detstats._Sh_rest_noise(hc_ss, hc_bg, fobs_cents) # (F,R,L)
# The addition produces a fresh, writable (P,F,R,L) array (broadcast_to alone is a read-only view)
noise_ss = noise_ss + Sh_rest[np.newaxis, :, :, :]

In [None]:
# Single-source detection statistics under two noise models:
#   *_has : custom noise from the hasasia sensitivity curve plus remaining sources (noise_ss)
#   *_def : custom_noise=None (labelled 'white + rest' in the plots below)
_ss_kwargs = dict(nskies=NSKIES, ret_snr=True)
dp_ss_has, snr_ss_has, dp_ssi_has = detstats.detect_ss_pta(
    psrs, fobs_cents, hc_ss, hc_bg, custom_noise=noise_ss, **_ss_kwargs)
dp_ss_def, snr_ss_def, dp_ssi_def = detstats.detect_ss_pta(
    psrs, fobs_cents, hc_ss, hc_bg, custom_noise=None, **_ss_kwargs)

In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='0th Loudest Detection Probability')
xx = fobs_cents * YR

# Loudest source only; all remaining axes flattened into one sample axis per frequency
dp_curves = {
    'noise = has.sc + rest': dp_ssi_has[:, :, :, 0].reshape(NFREQS, NREALS*NSKIES),
    'noise = white + rest': dp_ssi_def[:, :, :, 0].reshape(NFREQS, NREALS*NSKIES),
}
handles = [plot.draw_med_conf(ax, xx, samples) for samples in dp_curves.values()]
ax.legend(handles=handles, labels=list(dp_curves))


In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Top 3 Detection Probability')
xx = fobs_cents * YR

# Three loudest sources; all remaining axes flattened into one sample axis per frequency
dp3_curves = {
    'noise = has.sc + rest': dp_ssi_has[:, :, :, :3].reshape(NFREQS, NREALS*3*NSKIES),
    'noise = white + rest': dp_ssi_def[:, :, :, :3].reshape(NFREQS, NREALS*3*NSKIES),
}
handles = [plot.draw_med_conf(ax, xx, samples) for samples in dp3_curves.values()]
ax.legend(handles=handles, labels=list(dp3_curves))


In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='0th Loudest SNR')
xx = fobs_cents * YR

# SNR of the loudest source; all remaining axes flattened into one sample axis per frequency
snr_curves = {
    'noise = has.sc + rest': snr_ss_has[:, :, :, 0].reshape(NFREQS, NREALS*NSKIES),
    'noise = white + rest': snr_ss_def[:, :, :, 0].reshape(NFREQS, NREALS*NSKIES),
}
handles = [plot.draw_med_conf(ax, xx, samples) for samples in snr_curves.values()]
ax.legend(handles=handles, labels=list(snr_curves))


In [None]:
# Same white-noise sigma for every pulsar
sigmas = np.repeat(sigma, NPSRS)

total_noise_ss = detstats._total_noise(cad, sigmas, hc_ss, hc_bg, fobs_cents)
print(total_noise_ss.shape)

white_noise = detstats._white_noise(cad, sigmas)
print(white_noise.shape)

In [None]:
print(np.shape(noise_has_ss))  # expected (F,): one value per frequency bin

In [None]:
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='noise')
xx = fobs_cents*YR
nfreqs_plot = len(xx)

# Every curve is arranged as (F, N samples) so it lines up with xx; draw_med_conf
# presumably summarizes over the last axis at each frequency.
y1 = np.swapaxes(noise_ss[:,:,:,0], 0,1).reshape(nfreqs_plot, -1)         # (P,F,R) -> (F, P*R)
label1 = 'noise = has.sc + hc_rest'
y2 = np.swapaxes(total_noise_ss[:,:,:,0], 0,1).reshape(nfreqs_plot, -1)   # (P,F,R) -> (F, P*R)
label2 = 'noise = white + hc_rest'
# white_noise is computed from one sigma per pulsar, so it is indexed by pulsar (P,) and has no
# frequency dependence. Repeat it across frequencies -> (F, P). The previous
# white_noise[:, np.newaxis] gave (P, 1), which was plotted against the F frequencies and only
# lined up because NPSRS == NFREQS.
y3 = np.broadcast_to(white_noise[np.newaxis, :], (nfreqs_plot, white_noise.size))
label3 = 'white'
y4 = (detstats._Sh_rest_noise(hc_ss, hc_bg, fobs_cents))[:,:,0]          # (F,R)
label4 = 'hc_rest'
y5 = noise_has_ss[:,np.newaxis]                                          # (F,1)
label5 = 'has sens curve'
print(y3.shape, y4.shape)

handles = [plot.draw_med_conf(ax, xx, yy) for yy in (y1, y2, y3, y4, y5)]
labels = [label1, label2, label3, label4, label5]
ax.legend(handles=handles, labels=labels)


In [None]:
print(sigma)  # white-noise sigma passed to hasasia for every pulsar
white_noise_summary = holo.utils.stats(white_noise)
print(white_noise_summary)