In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import os

from holodeck import plot, detstats, utils
from holodeck.constants import YR, MSOL, MPC, GYR
import holodeck as holo

import hasasia.sensitivity as hsen
import hasasia.sim as hsim
import hasasia.skymap as hsky

In [None]:
# ---- data / model dimensions ----
SHAPE = None      # shape tag for the library data (the get_data draft puts it in the path)
NREALS = 500      # realizations per variation (R axis of hc_ss / hc_bg)
NFREQS = 40       # frequency bins (F axis; must equal len(fobs_cents))
NLOUDEST = 10     # loudest single sources kept per bin (L axis of hc_ss)

# ---- output / calibration ----
SAVEFIG = False
TOL = 0.01        # mirrors the `tol` default of the calibration routine below
MAXBADS = 5       # mirrors the `maxbads` default of the calibration routine below

# ---- red noise (None = not used) ----
RED_GAMMA = None
RED2WHITE = None

# ---- parameter-space variations ----
NVARS = 21
BGL = 1

# ---- PTA / detection ----
NPSRS = 40
NSKIES = 100
TARGET = 'gsmf_phi0' # EDIT AS NEEDED
TITLE = plot.PARAM_KEYS[TARGET]  # EDIT AS NEEDED

# ---- reproducibility ----
# None reseeds NumPy's global RNG from OS entropy, so every run differs as before.
# Set an int to make the np.random pulsar-position draws reproducible (and, presumably,
# any library draws that use the global NumPy RNG — TODO confirm for detstats/hasasia).
SEED = None
np.random.seed(SEED)

In [None]:
# NOTE(review): superseded local draft of the data loader, kept for reference only.
# The live cell below calls `detstats.get_data(TARGET)` and unpacks three values
# (data, params, dets); this draft returned only (data, params).
# Its default `path` is a machine-specific absolute path — do not re-enable as-is.
# def get_data(
#         target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma = None, red2white=None,
#     path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz',     
# ):
#     load_data_from_file = path+f'/{target}_v{nvars}_r{nreals}_shape{str(shape)}/data_params.npz' 

#     if os.path.exists(load_data_from_file) is False:
#         err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
#         raise Exception(err)
#     file = np.load(load_data_from_file, allow_pickle=True)
#     data = file['data']
#     params = file['params']
#     file.close()
#     print(target, "got data")

#     return data, params

In [None]:
# Load the model outputs for TARGET. `data` is indexed per variation below
# (data[var]['fobs_cents'], ['hc_ss'], ['hc_bg']); `params` and `dets` are not
# used in the visible cells.
data, params, dets = detstats.get_data(TARGET)

# Red Noise

In [None]:
var = 10  # which parameter variation to inspect

_var_data = data[var]
fobs_cents = _var_data['fobs_cents']
hc_ss = _var_data['hc_ss']
hc_bg = _var_data['hc_bg']

# Observing span and cadence implied by the frequency grid:
# lowest bin = 1/duration, highest bin = Nyquist frequency 1/(2*cadence).
dur = 1.0 / fobs_cents[0]
cad = 1.0 / (2.0 * fobs_cents[-1])

In [None]:
red_gamma = -3.0
red2white = 0.00
rr = 0  # realization index used for this single-PTA check

# Calibrate a 50-pulsar PTA against realization rr (ret_sig=True also returns the
# sigma start/min/max from the calibration search).
(psrs, red_amp,
 _sigstart, _sigmin, _sigmax) = detstats.calibrate_one_pta(
    hc_bg[:, rr], hc_ss[:, rr, :], fobs_cents,
    npsrs=50, debug=True, ret_sig=True,
    red_gamma=red_gamma, red2white=red2white,
)

# Background and single-source detection statistics for those pulsars.
_bg_slice = hc_bg[:, rr:rr+1]
_ss_slice = hc_ss[:, rr:rr+1, :]
_dp_bg, _snr_bg = detstats.detect_bg_pta(
    psrs, fobs_cents, _bg_slice, _ss_slice,
    ret_snr=True, red_amp=red_amp, red_gamma=red_gamma,
)
_dp_ss, _snr_ss, _gamma_ssi = detstats.detect_ss_pta(
    psrs, fobs_cents, _ss_slice, _bg_slice,
    custom_noise=None, red_amp=red_amp, red_gamma=red_gamma,
    nskies=NSKIES, ret_snr=True,
)

# Per-pulsar white-noise level: mean TOA uncertainty of each calibrated pulsar.
sigmas = np.array([np.mean(psr.toaerrs) for psr in psrs], dtype=float)

In [None]:

# Top: noise PSD components; bottom: single-source detection probability.
fig, axs = plot.figax(nrows=2, sharex=True, xlabel=plot.LABEL_GW_FREQUENCY_YR)
axs[0].set_ylabel('Noise Power Spectral Density')
axs[1].set_ylabel('Detection Probability')

xx = fobs_cents*YR

# White noise, presumably one value per pulsar; the first pulsar's level is drawn.
white_noise = detstats._white_noise(cad, sigmas)
print(f"white: {white_noise.shape}, {holo.utils.stats(white_noise)}")
axs[0].axhline(white_noise[0], color='grey', label='white')

# Red noise, one value per frequency bin (it is plotted directly against xx).
red_noise = detstats._red_noise(red_amp, red_gamma, fobs_cents)
print(f"red: {red_noise.shape}, {holo.utils.stats(red_noise)}")
print(f"{red_amp=}, {red_gamma=}")
axs[0].plot(xx, red_noise, color='red', label='red')

# Total noise: pulsars x frequencies, swapped to frequencies x pulsars for the band.
noise = white_noise[:, np.newaxis] + red_noise[np.newaxis, :]
print(f"total: {noise.shape}, {holo.utils.stats(noise)}")
y1 = np.swapaxes(noise, 0, 1)
print(y1.shape)
plot.draw_med_conf_color(axs[0], xx, y1, color='green')

axs[0].axvline(1, linestyle=':', color='k')  # f = 1/yr reference
axs[0].legend()

y2 = _gamma_ssi[..., 0].squeeze()
axs[1].plot(xx, y2, alpha=0.5)

fig.subplots_adjust(hspace=0)

In [None]:
def plot_it(debug=True, fobs=None, cadence=None, psr_sigmas=None,
            amp=None, gamma=None, gamma_ssi=None):
    """Plot noise PSD components (top) and single-source detection probability (bottom).

    Every data argument is optional. When left as None it falls back to the notebook
    global that earlier versions of this function read implicitly, so
    ``plot_it(debug=False)`` keeps working:
    fobs -> fobs_cents, cadence -> cad, psr_sigmas -> sigmas, amp -> red_amp,
    gamma -> red_gamma, gamma_ssi -> _gamma_ssi.
    Passing them explicitly removes the hidden dependence on whichever cell last
    rebound those globals.

    Parameters
    ----------
    debug : bool
        Print shapes and summary statistics of the noise arrays.
    fobs : array or None
        Frequency bin centers; multiplied by YR for the x-axis.
    cadence : float or None
        Observing cadence passed to ``detstats._white_noise``.
    psr_sigmas : array or None
        Per-pulsar white-noise sigma passed to ``detstats._white_noise``.
    amp, gamma : float or None
        Red-noise amplitude and spectral index passed to ``detstats._red_noise``.
    gamma_ssi : array or None
        Single-source detection probabilities; ``gamma_ssi[..., 0]`` (squeezed) is
        drawn in the bottom panel.

    Returns
    -------
    fig : matplotlib Figure
    """
    if fobs is None:
        fobs = fobs_cents
    if cadence is None:
        cadence = cad
    if psr_sigmas is None:
        psr_sigmas = sigmas
    if amp is None:
        amp = red_amp
    if gamma is None:
        gamma = red_gamma
    if gamma_ssi is None:
        gamma_ssi = _gamma_ssi

    fig, axs = plot.figax(nrows=2, sharex=True, xlabel=plot.LABEL_GW_FREQUENCY_YR)
    for ax, ylabel in zip(axs, ['Noise Power Spectral Density', 'Detection Probability']):
        ax.set_ylabel(ylabel)

    xx = fobs*YR

    # White noise, presumably one value per pulsar; draw the first pulsar's level.
    white_noise = detstats._white_noise(cadence, psr_sigmas)
    if debug: print(f"white: {white_noise.shape}, {holo.utils.stats(white_noise)}")
    axs[0].axhline(white_noise[0], color='grey', label='white')

    # Red noise, one value per frequency bin.
    red_noise = detstats._red_noise(amp, gamma, fobs)
    if debug: print(f"red: {red_noise.shape}, {holo.utils.stats(red_noise)}")
    if debug: print(f"red_amp={amp!r}, red_gamma={gamma!r}")
    axs[0].plot(xx, red_noise, color='red', label='red')

    # Total noise (P,F) -> (F,P) so the band is taken over pulsars at each frequency.
    noise = white_noise[:,np.newaxis] + red_noise[np.newaxis,:]
    if debug: print(f"total: {noise.shape}, {holo.utils.stats(noise)}")
    y1 = np.swapaxes(noise, 0, 1)
    if debug: print(y1.shape)
    plot.draw_med_conf_color(axs[0], xx, y1, color='green')

    axs[0].axvline(1, linestyle=':', color='k')  # f = 1/yr reference
    axs[0].legend()

    y2 = gamma_ssi[...,0].squeeze()
    axs[1].plot(xx, y2, alpha=0.5)

    fig.subplots_adjust(hspace=0)
    return fig

In [None]:
nfreqs = len(fobs_cents)
# Frequency in 1/yr for every cell of _gamma_ssi. The shape is taken from
# _gamma_ssi itself rather than from NSKIES/NLOUDEST, so the grid cannot drift
# out of sync with what detect_ss_pta actually returned.
freqs = np.broadcast_to(
    (fobs_cents*YR)[:, np.newaxis, np.newaxis, np.newaxis], _gamma_ssi.shape
).copy()

# Detection-probability-weighted mean (and variance) of log10 frequency.
logmean, logvar2 = detstats.weighted_mean_variance(np.log10(freqs), weights=_gamma_ssi)
print(10**(logmean))

print(freqs.shape)
print(_gamma_ssi.shape)

In [None]:
red_gamma = -1.5
rr = 0  # realization index, fixed for the whole red-to-white sweep

for red2white in [0.5, 1, 2]:
    psrs, red_amp, _sigstart, _sigmin, _sigmax = detstats.calibrate_one_pta(
        hc_bg[:,rr], hc_ss[:,rr,:], fobs_cents,
        npsrs=50, debug=True, ret_sig=True,
        red_gamma=red_gamma, red2white=red2white)

    # use those psrs to calculate realization detstats
    _dp_bg, _snr_bg = detstats.detect_bg_pta(
        psrs, fobs_cents, hc_bg[:,rr:rr+1],  hc_ss[:,rr:rr+1,:],
        ret_snr=True, red_amp=red_amp, red_gamma=red_gamma)
    _dp_ss, _snr_ss, _gamma_ssi = detstats.detect_ss_pta(
        psrs, fobs_cents, hc_ss[:,rr:rr+1], hc_bg[:,rr:rr+1],
        custom_noise=None, red_amp=red_amp, red_gamma=red_gamma,
        nskies=NSKIES, ret_snr=True,)

    # plot_it() reads the module-level sigmas, red_amp, red_gamma and _gamma_ssi.
    sigmas = np.array([np.mean(psr.toaerrs) for psr in psrs], dtype=float)

    logmean, logvar2 = detstats.weighted_mean_variance(np.log10(freqs), weights=_gamma_ssi)
    print(f"Q={red2white}, favg={10**(logmean)}")
    print(f"{red_amp=}, {holo.utils.stats(_gamma_ssi)=}")
    fig = plot_it(debug=False)

    # Mark the weighted-mean frequency. Text format and title match the red_gamma=-3
    # sweep, so figures from the two sweeps can be told apart.
    ax1 = fig.axes[1]
    ax1.axvline(10**(logmean))
    ax1.text(0.85, 0.25, f"favg={10**logmean:.2e}")
    ax0 = fig.axes[0]
    ax0.set_title(f"{red2white=}, {red_gamma=}")

In [None]:
red_gamma = -3

for red2white in [0, 0.01, 1, 100]:
    rr = 0

    # 1) calibrate a 50-pulsar PTA to realization rr at this red/white ratio
    (psrs, red_amp,
     _sigstart, _sigmin, _sigmax) = detstats.calibrate_one_pta(
        hc_bg[:, rr], hc_ss[:, rr, :], fobs_cents,
        npsrs=50, debug=True, ret_sig=True,
        red_gamma=red_gamma, red2white=red2white,
    )

    # 2) background and single-source detection statistics with those pulsars
    _noise_kw = dict(red_amp=red_amp, red_gamma=red_gamma)
    _dp_bg, _snr_bg = detstats.detect_bg_pta(
        psrs, fobs_cents, hc_bg[:, rr:rr+1], hc_ss[:, rr:rr+1, :],
        ret_snr=True, **_noise_kw,
    )
    _dp_ss, _snr_ss, _gamma_ssi = detstats.detect_ss_pta(
        psrs, fobs_cents, hc_ss[:, rr:rr+1], hc_bg[:, rr:rr+1],
        custom_noise=None, nskies=NSKIES, ret_snr=True, **_noise_kw,
    )
    sigmas = np.array([np.mean(p.toaerrs) for p in psrs], dtype=float)

    # 3) detection-weighted mean frequency, then the diagnostic figure
    logmean, logvar2 = detstats.weighted_mean_variance(np.log10(freqs), weights=_gamma_ssi)
    _favg = 10**logmean
    print(f"Q={red2white}, favg={_favg}")
    print(f"{red_amp=}, {holo.utils.stats(_gamma_ssi)=}")
    fig = plot_it(debug=False)

    ax1 = fig.axes[1]
    ax1.axvline(_favg)
    ax1.text(0.85, 0.25, f"favg={_favg:.2e}")

    ax0 = fig.axes[0]
    ax0.set_title(f"{red2white=}, {red_gamma=}")

# Old: uncalibrated comparison (fixed white-noise sigma)

In [None]:
var = 10
fobs_cents, hc_ss, hc_bg = (data[var][key] for key in ('fobs_cents', 'hc_ss', 'hc_bg'))
dur = 1.0 / fobs_cents[0]               # observing span = 1 / lowest bin
cad = 1.0 / (2.0 * fobs_cents[-1])      # cadence putting the highest bin at Nyquist

In [None]:
phis = None
thetas = None
npsrs = NPSRS
red2white = RED2WHITE   # only used by the disabled red-noise lines below
sigstart = 2.6e-6

# Randomize pulsar positions isotropically on the sky (phi = azimuth, theta = polar
# angle). The previous theta draw, uniform(pi/2, pi/2), had equal bounds and so
# placed every pulsar on the equator.
if phis is None:
    phis = np.random.uniform(0, 2*np.pi, size=npsrs)
if thetas is None:
    thetas = np.arccos(np.random.uniform(-1, 1, size=npsrs))
sigma = sigstart
# if red2white is not None:
#     red_amp = detstats._white_noise(cad, sigma) * red2white

# hasasia takes the span in years and the cadence in observations per year.
psrs = hsim.sim_pta(timespan=dur/YR, cad=1/(cad/YR), sigma=sigma,
                phi=phis, theta=thetas)

In [None]:
# Background detection probability of the fixed-sigma PTA, with red_amp=None,
# red_gamma=None and ss_noise=False.
_bg_kwargs = dict(hc_bg=hc_bg, red_amp=None, red_gamma=None, ss_noise=False)
dp_bg = detstats.detect_bg_pta(psrs, fobs_cents, **_bg_kwargs)
print(holo.utils.stats(dp_bg))
_dp_bg_mean, _dp_bg_median = np.mean(dp_bg), np.median(dp_bg)
print(_dp_bg_mean, _dp_bg_median)

In [None]:
def _psr_spectrum(psr):
    """hasasia Spectrum of one pulsar on the fobs_cents grid, with NcalInv touched."""
    spec = hsen.Spectrum(psr, freqs=fobs_cents)
    spec.NcalInv  # attribute access makes hasasia evaluate NcalInv now
    return spec

spectra = [_psr_spectrum(psr) for psr in psrs]

In [None]:
# Characteristic-strain sensitivity curves built from the pulsar spectra:
# deterministic (single source) and GWB (stochastic background).
_deter_curve = hsen.DeterSensitivityCurve(spectra)
sc_ss = _deter_curve.h_c
_gwb_curve = hsen.GWBSensitivityCurve(spectra)
sc_bg = _gwb_curve.h_c

In [None]:
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN)
xx = fobs_cents*YR

# Dashed sensitivity curves first, then the strain median and interquartile band
# for the background and for the loudest single source.
ax.plot(xx, sc_bg, label='BG Sensitivity', linestyle='--', color='tab:blue')
ax.plot(xx, sc_ss, label='SS Sensitivity', linestyle='--', color='tab:pink')
for _hc, _lab, _col in [(hc_bg, 'hc_BG', 'tab:blue'),
                        (hc_ss[:, :, 0], 'hc_SS', 'tab:pink')]:
    ax.fill_between(xx, *np.percentile(_hc, (25, 75), axis=-1), color=_col, alpha=0.2)
    ax.plot(xx, np.median(_hc, axis=-1), label=_lab, color=_col)
ax.legend()

# Calculate detstats

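Convert a hasasia sensitivity curve $h_{c,\rm sens}(f)$ into a noise power spectral density:
$$ S_{\rm noise}(f) = \frac{h_{c,\rm sens}^2(f)}{12\pi^2 f^3} $$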

In [None]:
# Turn the hasasia sensitivity curves into noise PSDs, S = h_c^2 / (12 pi^2 f^3), and
# tile them to the (pulsar, freq, realization[, loudest]) layout detect_*_pta receives.
# The tiling relies on the notebook constants, so check them against the data first.
assert np.shape(sc_ss) == (NFREQS,), f"sc_ss has shape {np.shape(sc_ss)}, expected ({NFREQS},)"
assert hc_ss.shape == (NFREQS, NREALS, NLOUDEST), f"hc_ss has shape {hc_ss.shape}"
assert len(psrs) == NPSRS, f"{len(psrs)} pulsars, expected NPSRS={NPSRS}"

# Single-source noise: DeterSensitivityCurve term (identical for every pulsar,
# realization and loudest source) plus the realization-dependent S_h,rest term.
noise_has_ss = sc_ss**2 / fobs_cents**3 / (12*np.pi**2)
Sh_rest = detstats._Sh_rest_noise(hc_ss, hc_bg, fobs_cents) # (F,R,L)
noise_ss = (
    np.broadcast_to(noise_has_ss[np.newaxis, :, np.newaxis, np.newaxis],
                    (NPSRS, NFREQS, NREALS, NLOUDEST))
    + Sh_rest[np.newaxis, :, :, :]
)  # (P,F,R,L)

# Background noise: GWBSensitivityCurve term only, identical for every pulsar/realization.
noise_has_bg = sc_bg**2 / fobs_cents**3 / (12*np.pi**2)
noise_bg = np.broadcast_to(noise_has_bg[np.newaxis, :, np.newaxis],
                           (NPSRS, NFREQS, NREALS)).copy()  # (P,F,R)

In [None]:
# Single-source detection with (a) DeterSensitivityCurve + S_rest noise (noise_ss)
# and (b) the library default noise (custom_noise=None).
_ss_common = dict(nskies=NSKIES, ret_snr=True)
dp_ss_has, snr_ss_has, dp_ssi_has = detstats.detect_ss_pta(
    psrs, fobs_cents, hc_ss, hc_bg, custom_noise=noise_ss, **_ss_common)
dp_ss_def, snr_ss_def, dp_ssi_def = detstats.detect_ss_pta(
    psrs, fobs_cents, hc_ss, hc_bg, custom_noise=None, **_ss_common)

# Background detection with (a) GWBSensitivityCurve noise (noise_bg) and (b) default noise.
dp_bg_has, snr_bg_has = detstats.detect_bg_pta(
    psrs, fobs_cents, hc_bg, ret_snr=True, custom_noise=noise_bg)
dp_bg_def, snr_bg_def = detstats.detect_bg_pta(
    psrs, fobs_cents, hc_bg, ret_snr=True, custom_noise=None)


In [None]:
# Summary stats of the uncalibrated background detection probability: default noise
# (custom_noise=None) vs. hasasia GWBSensitivityCurve noise (custom_noise=noise_bg).
print(f"{holo.utils.stats(dp_bg_def)=}")
print(f"{holo.utils.stats(dp_bg_has)=}")
print(np.mean(dp_bg_has))

In [None]:
# Frequency (1/yr) of every (freq, realization, sky, loudest) cell: the variable whose
# detection-probability-weighted mean and variance are taken.
_f_yr = (fobs_cents * YR)[:, np.newaxis, np.newaxis, np.newaxis]
freqs = np.broadcast_to(_f_yr, (NFREQS, NREALS, NSKIES, NLOUDEST)).copy()
favg_has, var2_has = detstats.weighted_mean_variance(freqs, dp_ssi_has)
favg_def, var2_def = detstats.weighted_mean_variance(freqs, dp_ssi_def)

In [None]:
print(np.shape(dp_bg_has))  # shape of the GWB-sensitivity-noise background DP

In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='0th Loudest Detection Probability (unclbrtd)')
xx = fobs_cents*YR

# Loudest-source detection probability per frequency, pooled over realizations x skies.
y1 = dp_ssi_has[:,:,:,0].reshape(NFREQS, NREALS*NSKIES)
label1 = 'noise = DeterSC + S_rest'
y2 = dp_ssi_def[:,:,:,0].reshape(NFREQS, NREALS*NSKIES)
label2 = 'noise = S_WN + S_rest'

colors = ['tab:blue', 'tab:orange',]
handles = []

for ii,yy  in enumerate([y1, y2]):
    hh = plot.draw_med_conf_color(ax, xx, yy, color=colors[ii])
    handles.append(hh)

# Detection-probability-weighted mean frequency for each noise model.
colors = ['blue', 'orangered',]
for ii,favg in enumerate([favg_has, favg_def]):
    hh = ax.axvline(favg, color=colors[ii], linestyle='--')
    handles.append(hh)

# Background DP: median line plus BOTH the 50% and 95% bands. The band call used to
# sit outside the `pp` loop, so only the 95% band was drawn. nan-aware statistics
# keep any nan entries from blanking the lines.
y3 = dp_bg_has
y4 = dp_bg_def
colors = ['darkblue', 'saddlebrown']
for ii,yy in enumerate([y3, y4]):
    hh = ax.axhline(np.nanmedian(yy), color=colors[ii])
    handles.append(hh)
    for pp in [50, 95]:
        percs = (50-pp/2, 50+pp/2)
        ax.axhspan(*np.nanpercentile(yy, percs), color=colors[ii], alpha=0.2)


labels = [label1, label2, 'favg (DSC)', 'favg (WN)', 'DP_BG (GSC)', 'DP_BG (WN)']
ax.legend(handles=handles, labels=labels)


In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='Top 3 Detection Probability')
xx = fobs_cents*YR

# Pool the three loudest sources with all realizations and skies at each frequency.
_curves = {
    'noise = has.sc + rest': dp_ssi_has[:, :, :, :3].reshape(NFREQS, NREALS*3*NSKIES),
    'noise = white + rest': dp_ssi_def[:, :, :, :3].reshape(NFREQS, NREALS*3*NSKIES),
}
handles = [plot.draw_med_conf(ax, xx, yy,) for yy in _curves.values()]
ax.legend(handles=handles, labels=list(_curves))


In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='0th Loudest SNR')
xx = fobs_cents*YR

# Loudest-source SNR per frequency, pooled over realizations x skies.
handles = []
labels = []
for snr, label in ((snr_ss_has, 'noise = has.sc + rest'),
                   (snr_ss_def, 'noise = white + rest')):
    yy = snr[:, :, :, 0].reshape(NFREQS, NREALS*NSKIES)
    handles.append(plot.draw_med_conf(ax, xx, yy,))
    labels.append(label)
ax.legend(handles=handles, labels=labels)


In [None]:
# Total single-source noise and pure white noise for NPSRS pulsars sharing one sigma.
total_noise_ss = detstats._total_noise(
    cad, np.full(NPSRS, sigma), hc_ss, hc_bg, fobs_cents)
print(total_noise_ss.shape)

white_noise = detstats._white_noise(cad, np.full(NPSRS, sigma))
print(white_noise.shape)

In [None]:
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='noise')
xx = fobs_cents*YR

# Each curve is arranged as (frequency, samples); median and 5-95% band over samples.
y1 = np.swapaxes(noise_ss[:,:,:,0], 0,1).reshape(NFREQS, NPSRS*NREALS)
label1 = 'noise = has.sc + hc_rest'
y2 = np.swapaxes(total_noise_ss[:,:,:,0], 0,1).reshape(NFREQS, NPSRS*NREALS)
label2 = 'noise = white + hc_rest'
# White noise is one value per pulsar and frequency independent: tile it across the
# frequency axis -> (F, P). The old white_noise[:, np.newaxis] was (P, 1), so its
# "per-frequency" median was really per pulsar. It only lined up with xx because
# NPSRS happened to equal NFREQS.
_wn = np.ravel(white_noise)
y3 = np.broadcast_to(_wn, (NFREQS, _wn.size))
label3 = 'white'
y4 = (detstats._Sh_rest_noise(hc_ss, hc_bg, fobs_cents))[:,:,0]
label4 = 'hc_rest'
y5 = noise_has_ss[:,np.newaxis]
label5 = 'has sens curve'
print(y3.shape, y4.shape)

handles = []
colors = ['tab:blue', 'tab:orange', 'blue', 'green', 'orangered']
linestyles = ['-', '-', '--', ':', ':']
for ii,yy  in enumerate([y1, y2, y3, y4, y5]):
    hh, = ax.plot(xx, np.median(yy, axis=-1), color=colors[ii], linestyle=linestyles[ii])
    ax.fill_between(xx, *np.percentile(yy, (5, 95), axis=-1), color=colors[ii], alpha=0.25)
    handles.append(hh)
labels = [label1, label2, label3, label4, label5]
ax.legend(handles=handles, labels=labels)


In [None]:
fig, ax = plot.figax(xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='noise')
xx = fobs_cents*YR

# Each curve is arranged as (frequency, samples); median and 5-95% band over samples.
y1 = np.swapaxes(noise_bg[:,:,:], 0,1).reshape(NFREQS, NPSRS*NREALS)
label1 = 'noise = has.gwbSC'
# White noise is one value per pulsar: tile it across frequencies -> (F, P).
# (The old white_noise[:, np.newaxis] was (P, 1) and only matched xx because NPSRS == NFREQS.)
_wn = np.ravel(white_noise)
y3 = np.broadcast_to(_wn, (NFREQS, _wn.size))
label3 = 'white'
y4 = (detstats._Sh_rest_noise(hc_ss, hc_bg, fobs_cents))[:,:,0]
label4 = 'hc_rest'
print(y3.shape, y4.shape)

handles = []
colors = ['tab:blue', 'green', 'orangered']
linestyles = ['-', '-', '--']
for ii,yy  in enumerate([y1,  y3, y4]):
    hh, = ax.plot(xx, np.median(yy, axis=-1), color=colors[ii], linestyle=linestyles[ii])
    ax.fill_between(xx, *np.percentile(yy, (5, 95), axis=-1), color=colors[ii], alpha=0.25)
    handles.append(hh)
labels = [label1, label3, label4]
ax.legend(handles=handles, labels=labels)


# Calibrate by Realization

In [None]:
from datetime import datetime

In [None]:
def detect_pspace_model_clbrt_pta(fobs_cents, hc_ss, hc_bg, npsrs, nskies, DSC=False,
                        sigstart=1e-6, sigmin=1e-9, sigmax=1e-4, tol=0.01, maxbads=5,
                        thresh=detstats.DEF_THRESH, debug=False, save_snr_ss=False, save_gamma_ssi=True,
                        red_amp=None, red_gamma=None, red2white=None, ss_noise=False): 
    """ Detect pspace model using individual sigma calibration for each realization.

    For every realization, a PTA of ``npsrs`` pulsars is calibrated with
    ``detstats.calibrate_one_pta``. That PTA then gives the background and
    single-source detection statistics for that realization only. The sigma
    search window returned for one realization is widened and reused for the next.

    Parameters
    ----------
    fobs_cents : (F,) array
        Frequency bin centers.
    hc_ss : (F,R,L) array
        Characteristic strain of the L loudest single sources per bin and realization.
    hc_bg : (F,R) array
        Background characteristic strain.
    npsrs : int
        Number of pulsars in each calibrated PTA.
    nskies : int
        Number of sky realizations for single-source detection.
    DSC : bool
        If True, single-source noise = hasasia DeterSensitivityCurve noise + S_h,rest.
        Otherwise the library default is used (custom_noise=None).
    sigstart, sigmin, sigmax : float
        Initial white-noise sigma and its search bounds for the first realization.
    tol, maxbads : 
        Passed through to ``calibrate_one_pta``.
    thresh : float
        Not used; kept for interface compatibility.
    debug : bool
        Forwarded to the calibration; also prints timing every 100 realizations.
    save_snr_ss, save_gamma_ssi : bool
        Include 'snr_ss' / 'gamma_ssi' (both (F,R,S,L)) in the returned dict.
    red_amp, red_gamma : 
        Red-noise amplitude and spectral index. The same caller-supplied ``red_amp``
        is passed to every realization's calibration.
    red2white : scalar or None
        Fixed ratio between red and white noise amplitude, if not None. 
        Otherwise, red noise stays fixed
    ss_noise : bool
        Passed through to ``calibrate_one_pta``.

    Returns
    -------
    dict
        'dp_ss' (R,S), 'dp_bg' (R,), 'snr_bg' (R,), 'df_ss', 'df_bg', 'ev_ss', plus the
        optional entries above. Realizations whose calibration failed stay nan.
    """
    nfreqs, nreals, nloudest = hc_ss.shape

    # Per-realization outputs start as nan and are only overwritten when a PTA
    # is successfully calibrated for that realization.
    dp_ss = np.full((nreals, nskies), np.nan)
    dp_bg = np.full(nreals, np.nan)
    snr_ss = np.full((nfreqs, nreals, nskies, nloudest), np.nan)
    snr_bg = np.full(nreals, np.nan)
    gamma_ssi = np.full((nfreqs, nreals, nskies, nloudest), np.nan)

    # Sigma window: user values for the first realization, afterwards the
    # (widened) window returned by the previous calibration.
    _sigstart, _sigmin, _sigmax = sigstart, sigmin, sigmax
    # The clock always starts: the summary print at the end runs even when
    # debug=False (mod_start used to exist only in debug mode -> NameError).
    mod_start = datetime.now()
    real_dur = mod_start
    failed_psrs = 0
    for rr in range(nreals):
        if debug and (rr % 100 == 99):
            now = datetime.now()
            print(f"{rr=}, {(now-real_dur)/100} s per realization, {_sigmin=:.2e}, {_sigmax=:.2e}, {_sigstart=:.2e}")
            real_dur = now

        # Calibrated psrs for this realization. The returned red-noise amplitude goes
        # into its own variable so the caller's `red_amp` is not replaced by a
        # previous realization's result (or lost if that calibration failed).
        psrs, _red_amp, _sigstart, _sigmin, _sigmax = detstats.calibrate_one_pta(
            hc_bg[:,rr], hc_ss[:,rr,:], fobs_cents, npsrs, tol=tol, maxbads=maxbads,
            sigstart=_sigstart, sigmin=_sigmin, sigmax=_sigmax, debug=debug, ret_sig=True,
            red_amp=red_amp, red_gamma=red_gamma, red2white=red2white, ss_noise=ss_noise)
        # Widen the window for the next realization. The +2e-20 keeps sigmax strictly
        # positive so it does not immediately fail the calibration's 0 check. The old
        # `_sigmax *= 2 + 2e-20` parsed as `*= (2 + 2e-20)` and never added it.
        _sigmin /= 2
        _sigmax = 2 * _sigmax + 2e-20

        if psrs is None:
            failed_psrs += 1
            continue # leave values as nan, if no successful PTA was found

        # Background detection statistics for this realization.
        _dp_bg, _snr_bg = detstats.detect_bg_pta(
            psrs, fobs_cents, hc_bg[:,rr:rr+1], hc_ss[:,rr:rr+1,:],
            ret_snr=True, red_amp=_red_amp, red_gamma=red_gamma)
        dp_bg[rr], snr_bg[rr] = _dp_bg.squeeze(), _snr_bg.squeeze()

        # Optional single-source noise: DeterSensitivityCurve term + S_h,rest.
        if DSC:
            spectra = []
            for psr in psrs:
                sp = hsen.Spectrum(psr, freqs=fobs_cents)
                sp.NcalInv  # touch NcalInv so hasasia evaluates it before the curve
                spectra.append(sp)
            sc_hc = hsen.DeterSensitivityCurve(spectra).h_c
            noise_dsc = sc_hc**2 / (12 * np.pi**2 * fobs_cents**3)  # (F,)
            noise_dsc = np.broadcast_to(noise_dsc[np.newaxis, :, np.newaxis, np.newaxis],
                                        (npsrs, nfreqs, 1, nloudest))  # (P,F,R=1,L)
            noise_rest = detstats._Sh_rest_noise(hc_ss[:,rr:rr+1,:], hc_bg[:,rr:rr+1], fobs_cents) # (F,R,L)
            noise_ss = noise_dsc + noise_rest[np.newaxis,:,:,:]
        else:
            noise_ss = None

        # Single-source detection statistics for this realization.
        _dp_ss, _snr_ss, _gamma_ssi = detstats.detect_ss_pta(
            psrs, fobs_cents, hc_ss[:,rr:rr+1], hc_bg[:,rr:rr+1], custom_noise=noise_ss,
            nskies=nskies, ret_snr=True, red_amp=_red_amp, red_gamma=red_gamma)
        dp_ss[rr], snr_ss[:,rr], gamma_ssi[:,rr] = _dp_ss.squeeze(), _snr_ss.squeeze(), _gamma_ssi.squeeze()

    ev_ss = detstats.expval_of_ss(gamma_ssi)
    df_ss, df_bg = detstats.detfrac_of_reals(dp_ss, dp_bg)
    _dsdat = {
        'dp_ss': dp_ss, 'dp_bg': dp_bg, 'snr_bg': snr_bg,
        'df_ss': df_ss, 'df_bg': df_bg, 'ev_ss': ev_ss,
        }
    # The large (F,R,S,L) arrays are only returned on request. Previously both were
    # always stored, so these flags had no effect.
    if save_gamma_ssi:
        _dsdat['gamma_ssi'] = gamma_ssi
    if save_snr_ss:
        _dsdat['snr_ss'] = snr_ss
    print(f"Model took {datetime.now() - mod_start} s, {failed_psrs}/{nreals} realizations failed.")
    return _dsdat

In [None]:
def psrs_spectra_gwbnoise(psrs, fobs, nreals, npsrs):
    """Build hasasia spectra for `psrs` and the matching GWB-sensitivity noise.

    Returns
    -------
    spectra : list
        One hasasia Spectrum per pulsar, evaluated at `fobs`, with NcalInv touched.
    noise_gsc : (npsrs, F, nreals) array
        h_c^2 / (12 pi^2 f^3) from hasasia's GWBSensitivityCurve, repeated for every
        pulsar and realization.
    """
    def _spectrum(psr):
        spec = hsen.Spectrum(psr, freqs=fobs)
        spec.NcalInv  # attribute access makes hasasia evaluate NcalInv now
        return spec

    spectra = [_spectrum(psr) for psr in psrs]
    hc_curve = hsen.GWBSensitivityCurve(spectra).h_c
    psd = hc_curve**2 / (12 *np.pi**2 *fobs**3)
    tiled = np.repeat(psd, npsrs*nreals).reshape(len(fobs), npsrs, nreals)  # (F,P,R)
    return spectra, np.swapaxes(tiled, 0, 1)  # (P,F,R)

def dsc_noise(fobs, nreals, npsrs, nloudest, psrs=None, spectra=None):
    """DeterSensitivityCurve noise PSD tiled to shape (npsrs, F, nreals, nloudest).

    Uses `spectra` when given. Otherwise hasasia spectra are built from `psrs`, and
    one of the two must be provided.
    """
    if spectra is None:
        assert psrs is not None, 'Must provide spectra or psrs'
        spectra = [hsen.Spectrum(psr, freqs=fobs) for psr in psrs]
        for spec in spectra:
            spec.NcalInv  # attribute access makes hasasia evaluate NcalInv now

    hc_det = hsen.DeterSensitivityCurve(spectra).h_c
    psd = hc_det**2 / (12 *np.pi**2 *fobs**3)
    shape_fprl = (len(fobs), npsrs, nreals, nloudest)
    tiled = np.repeat(psd, npsrs*nreals*nloudest).reshape(shape_fprl)  # (F,P,R,L)
    return np.swapaxes(tiled, 0, 1)  # (P,F,R,L)

In [None]:
# Calculate calibrated DSC detstats
_dsdat_has = detect_pspace_model_clbrt_pta(
    fobs_cents, hc_ss, hc_bg, NPSRS, NSKIES,
    DSC=True, save_snr_ss=True, save_gamma_ssi=True, debug=True,
)
dp_ss_clbrt_has, snr_ss_clbrt_has, dp_ssi_clbrt_has = (
    _dsdat_has[key] for key in ('dp_ss', 'snr_ss', 'gamma_ssi'))

In [None]:
# Calculate calibrated default (Rosado) detstats
_dsdat_def = detect_pspace_model_clbrt_pta(
    fobs_cents, hc_ss, hc_bg, NPSRS, NSKIES, DSC=False,
    debug=True, save_gamma_ssi=True, save_snr_ss=True,
)
dp_ss_clbrt_def, snr_ss_clbrt_def, dp_ssi_clbrt_def = [
    _dsdat_def[k] for k in ('dp_ss', 'snr_ss', 'gamma_ssi')]

In [None]:
# Difference in calibrated background DP between the DSC and default runs.
_dp_bg_diff = _dsdat_has['dp_bg'] - _dsdat_def['dp_bg']
print(holo.utils.stats(_dp_bg_diff))

In [None]:
# Frequency grid (1/yr) over (freq, realization, sky, loudest) for the calibrated gamma_ssi.
_fobs_yr = fobs_cents * YR
freqs = np.empty((NFREQS, NREALS, NSKIES, NLOUDEST), dtype=_fobs_yr.dtype)
freqs[...] = _fobs_yr[:, np.newaxis, np.newaxis, np.newaxis]
favg_clbrt_has, var2_clbrt_has = detstats.weighted_mean_variance(freqs, dp_ssi_clbrt_has)
favg_clbrt_def, var2_clbrt_def = detstats.weighted_mean_variance(freqs, dp_ssi_clbrt_def)

In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='0th Loudest Detection Probability (Clbrtd)')
xx = fobs_cents*YR

# Loudest-source detection probability per frequency, pooled over realizations x skies.
y1 = dp_ssi_clbrt_has[:,:,:,0].reshape(NFREQS, NREALS*NSKIES)
label1 = 'noise = DeterSC + S_rest'
y2 = dp_ssi_clbrt_def[:,:,:,0].reshape(NFREQS, NREALS*NSKIES)
label2 = 'noise = S_WN + S_rest'

colors = ['tab:blue', 'tab:orange',]
handles = []

for ii,yy  in enumerate([y1, y2]):
    hh = plot.draw_med_conf_color(ax, xx, yy, color=colors[ii])
    handles.append(hh)

# Detection-probability-weighted mean frequency for each noise model.
colors = ['blue', 'orangered',]
for ii,favg in enumerate([favg_clbrt_has, favg_clbrt_def]):
    hh = ax.axvline(favg, color=colors[ii], linestyle='--')
    handles.append(hh)

# Background DP: median plus 50% and 95% bands. The band call now sits inside the
# `pp` loop; before, only the 95% band was drawn. Failed calibrations leave nan in
# dp_bg, so nan-aware statistics are used.
y3 = _dsdat_has['dp_bg']
y4 = _dsdat_def['dp_bg']
colors = ['darkblue', 'saddlebrown']
for ii,yy in enumerate([y3, y4]):
    hh = ax.axhline(np.nanmedian(yy), color=colors[ii])
    handles.append(hh)
    for pp in [50, 95]:
        percs = (50-pp/2, 50+pp/2)
        ax.axhspan(*np.nanpercentile(yy, percs), color=colors[ii], alpha=0.2)

# Both calibrated runs compute DP_BG with detect_bg_pta's default noise (DSC only
# changes the single-source noise), so label the BG lines by run, not as GSC.
labels = [label1, label2, 'favg (DSC)', 'favg (WN)', 'DP_BG (DSC run)', 'DP_BG (WN run)']
ax.legend(handles=handles, labels=labels)


# Use GWB Sensitivity Curve Calibration

In [None]:
# Per-realization calibration from the library's `_gsc` routine
# (GWB Sensitivity Curve calibration, per this section's title).
_gsc_kwargs = dict(npsrs=NPSRS, nskies=NSKIES, debug=True,
                   save_snr_ss=True, save_gamma_ssi=True)
_dsdat_has_clbrt = detstats.detect_pspace_model_clbrt_pta_gsc(
    fobs_cents, hc_ss, hc_bg, **_gsc_kwargs)

In [None]:
# Spread of the GSC-calibrated background detection probability across realizations.
_dp_bg_gsc = _dsdat_has_clbrt['dp_bg']
print(holo.utils.stats(_dp_bg_gsc))

In [None]:
# Frequency grid in 1/yr matching the (F,R,S,L) layout of the GSC-calibrated gamma_ssi.
freqs = np.broadcast_to(
    (fobs_cents*YR).reshape(NFREQS, 1, 1, 1), (NFREQS, NREALS, NSKIES, NLOUDEST)
).copy()
favg_clbrt_sc, var2_clbrt_sc = detstats.weighted_mean_variance(
    freqs, _dsdat_has_clbrt['gamma_ssi'])


In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel='0th Loudest Detection Probability (Clbrtd)')
xx = fobs_cents*YR

# Loudest-source detection probability per frequency, pooled over realizations x skies.
y1 = dp_ssi_clbrt_has[:,:,:,0].reshape(NFREQS, NREALS*NSKIES)
label1 = 'noise = DeterSC + S_rest'
y2 = dp_ssi_clbrt_def[:,:,:,0].reshape(NFREQS, NREALS*NSKIES)
label2 = 'noise = S_WN + S_rest'
y3 = _dsdat_has_clbrt['gamma_ssi'][:,:,:,0].reshape(NFREQS, NREALS*NSKIES)
label3 = 'hasasia GSC-calibrated'

colors = ['tab:blue', 'tab:orange', 'tab:green']
handles = []

for ii,yy  in enumerate([y1, y2,y3]):
    hh = plot.draw_med_conf_color(ax, xx, yy, color=colors[ii])
    handles.append(hh)

# Detection-probability-weighted mean frequency for each run.
colors = ['blue', 'orangered', 'limegreen']
for ii,favg in enumerate([favg_clbrt_has, favg_clbrt_def, favg_clbrt_sc]):
    hh = ax.axvline(favg, color=colors[ii], linestyle='--')
    handles.append(hh)

# Background DP: median plus 50% and 95% bands. The band call now sits inside the
# `pp` loop; before, only the 95% band was drawn. Failed calibrations leave nan in
# dp_bg, so nan-aware statistics are used.
y4 = _dsdat_has['dp_bg']
y5 = _dsdat_def['dp_bg']
y6 = _dsdat_has_clbrt['dp_bg']
colors = ['darkblue', 'saddlebrown', 'darkgreen']
lws = [3,2,1]
for ii,yy in enumerate([y4, y5, y6]):
    hh = ax.axhline(np.nanmedian(yy), color=colors[ii], lw=lws[ii])
    handles.append(hh)
    for pp in [50, 95]:
        percs = (50-pp/2, 50+pp/2)
        ax.axhspan(*np.nanpercentile(yy, percs), color=colors[ii], alpha=0.2)

# The DSC and WN calibrated runs both compute DP_BG with detect_bg_pta's default
# noise, so those two BG lines are labelled by run rather than as GSC.
labels = [label1, label2, label3, 'favg (DSC)', 'favg (WN)', 'favg (GSC-clbrtd)',
          'DP_BG (DSC run)', 'DP_BG (WN run)', 'DP_BG (GSC-clbrtd)']
ax.legend(handles=handles, labels=labels, loc='upper right')
