In [16]:
# Clear all user-defined names from the interactive namespace (-f: no confirmation prompt).
%reset -f

In [17]:
import numpy as np

In [19]:
# Shorthand for pi; used in the `flso` and `DL` formulae further down.
PI = np.pi

In [20]:
import gwbench.basic_constants as bc

In [21]:
# Injection parameters: component masses -> total mass, symmetric mass ratio
# and chirp mass, plus the time and phase entries of the injection.

# Placeholders that are currently not set here (kept disabled):
# e0 = 0.0
# chi1z = 0.0
# chi2z = 0.0

# Component masses in solar masses -- edit these to study a different binary.
M1_solar = 1.4
M2_solar = 1.4

# Time and phase values passed on as inj_params['tc'] / inj_params['phic'].
tc = 0.0
phic = 0.0

# Total mass in solar masses, in SI units, and in seconds (M * G / c^3).
M_solar = M1_solar + M2_solar
M_SI = M_solar*bc.Msun
M_sec = M_SI*bc.GNewton/bc.cLight**3

# Symmetric mass ratio m1*m2 / M^2.
eta = M1_solar*M2_solar/M_solar**2

# Chirp mass M * eta^(3/5), in solar masses and in seconds.
Mc_solar = M_solar*eta**(3/5)
Mc = M_sec*eta**(3/5)

In [22]:
import aiss_model_np as aiss_np     # PSD models; the one actually used (Sh_aLIGO) is selected further down and can be swapped there
from scipy.integrate import quad    # numerical integration routine, used for the SNR integral below

In [23]:
# Frequency band, frequency grid and noise PSD used for the rest of the analysis.

# Fs = 4096   (upper edge of the disabled alternative grid below)


fs = 20                                 # lower cut-off of the grid and of the SNR integral
flso = (6.**(3/2) * PI * M_sec)**(-1)   # upper cut-off: 1 / (6^(3/2) * pi * M_sec)

# Disabled alternative: uniform grid on [0, Fs) sliced down to the [fs, flso] band.
# deltaF = 2**(-2.9)
# f = np.arange(0, Fs, deltaF)

# i_fs = int((fs-0)/deltaF)
# i_flso = int((flso-0)/deltaF) + 1

# f = f[i_fs:i_flso]

# Active grid: step 0.1 over the half-open interval [fs, flso).
f = np.arange(fs, flso, 0.1)

# PSD model (change psd_func to use another noise curve) and its samples on f.
psd_func = aiss_np.Sh_aLIGO
psd = psd_func(f)

# Target SNR used to fix the signal amplitude (and from it the distance) below.
rho0  = 10

def integrand(f1):
    """Integrand of the SNR integral: ``f1**(-7/3)`` divided by the PSD at ``f1``.

    ``f1`` is the frequency at which to evaluate; the PSD is read from the
    module-level ``psd_func``.
    """
    power_law = f1**(-7/3)
    return power_law / psd_func(f1)

# quad returns (integral, estimate of the absolute error). The integrand is
# divided by the PSD, so judge err relative to ans rather than by its raw size.
ans, err = quad(integrand, f[0], flso)
# Amplitude A satisfying 4 * A^2 * ans = rho0^2, i.e. the amplitude that gives SNR rho0.
A =  (rho0**2 / (4*ans))**0.5

# Distance corresponding to amplitude A for chirp mass Mc (Mc is in seconds here).
DL = ((5./24.)**0.5/PI**(2./3.))*(Mc**(5./6.)/A) # From AISS
# Convert to Mpc -- assumes bc.Mpc is one megaparsec in metres; TODO confirm.
DL_Mpc = DL*bc.cLight/bc.Mpc


# Debug output (disabled):
# print(DL)
# print(A)
# print(flso)
# print(f[-1])
# print(f)

## User Choices

#### choose the desired detectors

In [24]:
# network_spec = ['aLIGO_H']

In [25]:
from network_check import Network  # Network class; instantiated below as `net`
import gwbench.wf_class as wfc     # `wfc` is a module alias; it provides the Waveform class (wfc.Waveform)

#### initialize the network with the desired detectors

In [None]:
net = Network()  # `net` is an instance of Network, created here without any arguments
net.wf = wfc.Waveform()  # attach a Waveform instance; its model attributes are assigned further down

# print(net.wf)

#### choose the desired waveform 

In [None]:
import wf_models.tf2_2_np as tf2_2_np
import wf_models.tf2_2_sp as tf2_2_sp

#### pass the chosen waveform to the network for initialization

In [None]:
def select_wf_model_quants(self):
    """Return the waveform-model quantities taken from ``tf2_2_np`` / ``tf2_2_sp``.

    Parameters
    ----------
    self
        Not used by the body; accepted so the function can be called with the
        network object as its argument.

    Returns
    -------
    tuple
        ``(tf2_2_np.wf_symbs_string, tf2_2_np.hfpc, hfpc_sp)`` where
        ``hfpc_sp`` is ``tf2_2_sp.hfpc``, or ``None`` if the ``tf2_2_sp``
        model is ``None``.
    """
    first_model = tf2_2_np
    second_model = tf2_2_sp

    # Only the hfpc attribute of the second model is needed; tolerate its absence.
    second_hfpc = None if second_model is None else second_model.hfpc

    return first_model.wf_symbs_string, first_model.hfpc, second_hfpc

# Store the selected model's symbol string and both hfpc attributes on the network's waveform object.
net.wf.wf_symbs_string, net.wf.hfpc_np, net.wf.hfpc_sp = select_wf_model_quants(net)

#### set the injection parameters

In [None]:
# Injection parameter set handed to the network.
# NOTE(review): eta is hard-coded to 0.2499 instead of the `eta` computed from
# the masses (~0.25), while Mc and DL are derived from the computed value --
# presumably to stay off the equal-mass boundary eta = 1/4; confirm intended.
inj_params = dict(
    Mc=Mc_solar,
    eta=0.2499,
    chi1z=0,
    chi2z=0,
    DL=DL_Mpc,
    tc=tc,
    phic=phic,
    iota=0,
    ra=0,
    dec=0.0,
    psi=0,
    gmst0=0,
)

#### choose the parameters with respect to which derivatives are taken

In [None]:
# Space-separated names of the injection parameters to differentiate with respect to.
deriv_symbs_string = 'Mc eta tc phic'

#### assign which parameters to convert to cos or log versions

In [None]:
conv_cos = ()              # no parameter is converted to its cosine version
conv_log = ('Mc', 'eta')   # log versions are used for these: their derivative and error keys carry a 'log_' prefix

#### choose whether to take Earth's rotation into account

In [None]:
use_rot = 0   # forwarded to net.set_net_vars; presumably 0 means Earth's rotation is ignored -- TODO confirm

#### pass all these variables to the network

In [None]:
# Register the frequency grid, the injection and the derivative settings on the network.
net.set_net_vars(
    f=f,
    inj_params=inj_params,
    deriv_symbs_string=deriv_symbs_string,
    conv_cos=conv_cos,
    conv_log=conv_log,
    use_rot=use_rot,
)

## GW benchmarking

#### compute the WF polarizations and their derivatives

In [None]:
# Compute the waveform polarizations for the variables registered on the network.
net.calc_wf_polarizations()

In [None]:
# Compute the derivatives of the polarizations (the `_num` suffix suggests numerical
# differentiation -- confirm). Presumably this fills net.del_hfpc, which the Fisher step reads.
net.calc_wf_polarizations_derivs_num()

In [None]:
import gwbench.basic_functions as bfs
import fisher_analysis_tools as fat

In [None]:
# Build the 'del_<name>_hfp' keys, one per derivative parameter; parameters
# listed in conv_log get a 'log_' prefix in front of their name.
deriv_symbs_list = deriv_symbs_string.split(' ')
deriv_hfp_list = [
    f"del_{'log_' + symb if symb in conv_log else symb}_hfp"
    for symb in deriv_symbs_list
]
del_vs_f_dic = bfs.get_sub_dict(net.del_hfpc, deriv_hfp_list, 1)

# Fisher-matrix products from the selected derivatives, the PSD samples and the grid f.
# NOTE(review): the derivatives are passed in the iteration order of del_vs_f_dic;
# that order must agree with net.deriv_variables for the error labels to be right -- confirm.
net.fisher, net.cov, net.wc_fisher, net.cond_num = fat.calc_fisher_cov_matrices(
    list(del_vs_f_dic.values()), psd, f, 0
)
net.errs = fat.get_errs_from_cov(net.cov, net.deriv_variables)


In [None]:
#print error values

from math import floor, log10

def round_n(x, n):
    """Round ``x`` to ``n`` significant figures.

    Parameters
    ----------
    x : real number
        Value to round.
    n : int
        Number of significant figures to keep.

    Returns
    -------
    The rounded value. Zero and non-finite values (inf, nan) have no decimal
    exponent, so they are returned unchanged instead of raising.
    """
    # Local import keeps the helper usable independently of the cell's imports.
    from math import floor, isfinite, log10

    # log10(0) raises ValueError and floor(inf/nan) raises as well; pass such values through.
    if x == 0 or not isfinite(x):
        return x

    # Decimal exponent of the leading digit fixes how many decimals to keep.
    return round(x, n - int(floor(log10(abs(x)))) - 1)


# Print the errors stored in net.errs, rounded to 5 significant figures.
# Each row: (label, key in net.errs, scale factor applied before rounding).
_error_report = (
    ("tc(ms): ", 'tc', 1000),        # 1000: seconds -> milliseconds
    ("phic: ", 'phic', 1),
    ("log_Mch: ", 'log_Mc', 100),    # 100: expressed as a percentage
    ("log_eta: ", 'log_eta', 100),   # 100: expressed as a percentage
)
for _label, _key, _scale in _error_report:
    print(_label, round_n(net.errs[_key]*_scale, 5))
# print("log_e0: ", round_n(net.errs['log_e0'],4))
print()
