In [1]:
import numpy as np
import matplotlib.pyplot as plt
from few.waveform import GenerateEMRIWaveform, FastSchwarzschildEccentricFlux, FastKerrEccentricEquatorialFlux
from few.utils.constants import Gpc, MRSUN_SI, YRSID_SI
from typing import Optional, Union, Callable
from tqdm import tqdm

use_gpu = True  # set to False to restrict FEW to its CPU backend (configured below)

from stableemrifisher.fisher.fisher import StableEMRIFisher
from stableemrifisher.utils import inner_product

from fastlisaresponse import ResponseWrapper  # Response function
from lisatools.detector import EqualArmlengthOrbits
from lisatools.sensitivity import get_sensitivity, A1TDISens, E1TDISens, T1TDISens

# Backend selection.
# `force_backend` is defined on BOTH paths so later cells can reference it
# unconditionally; previously it only existed on the CPU path, so using it with
# use_gpu=True raised NameError. None means no backend is forced here.
if use_gpu:
    force_backend = None  # let the backend decide for itself.
else:
    import few

    # tune few configuration: reset, allow only the CPU backend, verbose logging
    cfg_set = few.get_config_setter(reset=True)
    cfg_set.enable_backends("cpu")
    cfg_set.set_log_level("info")
    force_backend = 'cpu'

startup


In [2]:
# Custom PSD callable: it wraps the lisatools.sensitivity.get_sensitivity *function*
# (it does not inherit from anything) and adds an extra multiplicative "degradation"
# parameter d.

class get_sens:
    """PSD callable returning ``d * get_sensitivity(f, **noise_kwargs)``.

    An instance is used as the ``noise_model`` of StableEMRIFisher. SEF presumably
    forwards each channel's ``noise_kwargs`` dict here as keyword arguments, so a
    ``"d"`` entry in that dict selects the degradation for that channel
    (TODO confirm against SEF's PSD code).
    """

    def __call__(self,
                 f,  # the frequency must be the first parameter for SEF's PSD calculations
                 d=1.0,  # degradation parameter (1.0 = no degradation)
                 **noise_kwargs):
        """Return the (optionally degraded) sensitivity at frequency ``f``.

        Parameters
        ----------
        f : scalar or array
            Frequency, passed straight through to ``get_sensitivity``.
        d : float, optional
            Multiplicative PSD degradation factor, default 1.0. It must be strictly
            positive: a zero, negative or NaN PSD makes the noise-weighted inner
            product (and hence the SNR/Fisher matrix) infinite or NaN.
        **noise_kwargs
            Forwarded unchanged to ``get_sensitivity`` (e.g. ``sens_fn``).

        Raises
        ------
        ValueError
            If ``d`` is not strictly positive (NaN included).
        """
        # `not all(d > 0)` rather than `any(d <= 0)` so that NaN is rejected too.
        if not np.all(np.asarray(d) > 0):
            raise ValueError(f"PSD degradation factor d must be > 0, got {d!r}")
        return d * get_sensitivity(f, **noise_kwargs)

In [3]:
# waveform class setup
waveform_class = FastSchwarzschildEccentricFlux
waveform_class_kwargs = dict(
    inspiral_kwargs=dict(err=1e-11),
    mode_selector_kwargs=dict(mode_selection_threshold=1e-5),
)

# waveform generator setup
waveform_generator = GenerateEMRIWaveform
waveform_generator_kwargs = dict(return_list=False)

# ResponseWrapper setup (ResponseWrapper itself is the class imported from fastlisaresponse;
# the former `ResponseWrapper = ResponseWrapper` self-assignment was a no-op and is gone).
tdi_gen = "1st generation"  # alternative: "2nd generation" -- presumably the A1/E1 sensitivity channels below would then need to change too
order = 20  # interpolation order (should not change the result too much)
tdi_kwargs_esa = dict(
    orbits=EqualArmlengthOrbits(), order=order, tdi=tdi_gen, tdi_chan="AE",
)  # could do "AET"
# Positions of the sky angles in the waveform parameter list
# (pars_list order: ..., qS at index 7, phiS at index 8, ...).
index_lambda = 8
index_beta = 7
# with longer signals we care less about this
t0 = 10000.0  # throw away on both ends when our orbital information is weird
T = 0.1  # observation time; presumably years (FEW convention) -- confirm
dt = 10.0  # sampling cadence; presumably seconds -- confirm

ResponseWrapper_kwargs = dict(
    Tobs=T,
    dt=dt,
    index_lambda=index_lambda,
    index_beta=index_beta,
    t0=t0,
    flip_hx=True,
    is_ecliptic_latitude=False,
    remove_garbage="zero",
    **tdi_kwargs_esa,
)

# noise setup
channels = [A1TDISens, E1TDISens]
noise_model = get_sens()  # use the custom PSD class
noise_kwargs = [{"sens_fn": channel_i} for channel_i in channels]

# Degraded noise: identical per-channel kwargs plus a PSD degradation factor "d",
# derived from noise_kwargs so the two setups cannot drift apart.
psd_degradation = 2.0  # multiplicative PSD degradation applied to every channel
noise_kwargs_deg = [{**channel_kwargs, "d": psd_degradation} for channel_kwargs in noise_kwargs]

In [4]:
# Two SEF instances are compared in this notebook; this one evaluates everything
# against the degraded PSD (noise_kwargs_deg carries the per-channel "d" factor).
sef_deg = StableEMRIFisher(
    waveform_class=waveform_class,
    waveform_class_kwargs=waveform_class_kwargs,
    waveform_generator=waveform_generator,
    waveform_generator_kwargs=waveform_generator_kwargs,
    ResponseWrapper=ResponseWrapper,
    ResponseWrapper_kwargs=ResponseWrapper_kwargs,
    noise_model=noise_model,
    noise_kwargs=noise_kwargs_deg,
    channels=channels,
    stats_for_nerds=False,
    use_gpu=use_gpu,
    T=T,
    dt=dt,
    deriv_type='stable',
)

In [5]:
# Reference SEF instance with the usual (undegraded) PSD; otherwise configured
# identically to the degraded one, except that stats_for_nerds is enabled here.
sef_nodeg = StableEMRIFisher(
    waveform_class=waveform_class,
    waveform_class_kwargs=waveform_class_kwargs,
    waveform_generator=waveform_generator,
    waveform_generator_kwargs=waveform_generator_kwargs,
    ResponseWrapper=ResponseWrapper,
    ResponseWrapper_kwargs=ResponseWrapper_kwargs,
    noise_model=noise_model,
    noise_kwargs=noise_kwargs,
    channels=channels,
    stats_for_nerds=True,
    use_gpu=use_gpu,
    T=T,
    dt=dt,
    deriv_type='stable',
)

In [8]:
# Source parameters, listed in the positional order SNRcalc_SEF receives them
# (unpacked via *pars_list below). Units follow the waveform package's conventions
# (presumably solar masses for m1/m2 and Gpc for dist -- confirm in the FEW docs).
m1 = 1e6
m2 = 1e1
a = 0.0  # spin; a float (previously the int 0) so every run passes the same value type
p0 = 9.5
e0 = 0.4
xI0 = 1.0
dist = 0.1
qS = np.pi/3
phiS = np.pi/4
qK = np.pi/6
phiK = np.pi/8
Phi_phi0 = 1.0
Phi_theta0 = 0.0
Phi_r0 = 0.0

pars_list = [m1, m2, a, p0, e0, xI0, dist, qS, phiS, qK, phiK, Phi_phi0, Phi_theta0, Phi_r0]

# Same observation time and cadence as used to configure the SEF instances.
emri_kwargs = {"T": T, "dt": dt}

SNR_deg = sef_deg.SNRcalc_SEF(*pars_list, **emri_kwargs, use_gpu=use_gpu)

wave ndim: 2
Computing SNR for parameters: (1000000.0, 10.0, 0, 9.5, 0.4, 1.0, 0.1, 1.0471975511965976, 0.7853981633974483, 0.5235987755982988, 0.39269908169872414, 1.0, 0.0, 0.0)


--- Logging error ---
Traceback (most recent call last):
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/logging/__init__.py", line 1160, in emit
    msg = self.format(record)
          ^^^^^^^^^^^^^^^^^^^
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/logging/__init__.py", line 999, in format
    return fmt.format(record)
           ^^^^^^^^^^^^^^^^^^
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/logging/__init__.py", line 703, in format
    record.message = record.getMessage()
                     ^^^^^^^^^^^^^^^^^^^
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/logging/__init__.py", line 392, in getMessage
    msg = msg % self.args
          ~~~~^~~~~~~~~~~
TypeError: not all arguments converted during string formatting
Call stack:
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, m

In [9]:
# Reuse exactly the source parameters (pars_list) and time settings (emri_kwargs)
# defined for the degraded-PSD SNR above, so the two SNRs differ ONLY in the PSD.
# Re-typing all parameters here by hand had already drifted (spin 0 vs 0.).
SNR_nodeg = sef_nodeg.SNRcalc_SEF(*pars_list, **emri_kwargs, use_gpu=use_gpu)

wave ndim: 2
Computing SNR for parameters: (1000000.0, 10.0, 0.0, 9.5, 0.4, 1.0, 0.1, 1.0471975511965976, 0.7853981633974483, 0.5235987755982988, 0.39269908169872414, 1.0, 0.0, 0.0)


--- Logging error ---
Traceback (most recent call last):
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/logging/__init__.py", line 1160, in emit
    msg = self.format(record)
          ^^^^^^^^^^^^^^^^^^^
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/logging/__init__.py", line 999, in format
    return fmt.format(record)
           ^^^^^^^^^^^^^^^^^^
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/logging/__init__.py", line 703, in format
    record.message = record.getMessage()
                     ^^^^^^^^^^^^^^^^^^^
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/logging/__init__.py", line 392, in getMessage
    msg = msg % self.args
          ~~~~^~~~~~~~~~~
TypeError: not all arguments converted during string formatting
Call stack:
  File "/media/shubham/Ubuntu-HDD/miniconda3/envs/torchenv/lib/python3.12/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, m

In [12]:
# The degradation factor d appears as sqrt(d) in the ratio between the two SNRs:
# the SNR is sqrt(<h|h>), and the noise-weighted inner product
# <h|h> = 4 Re ∫ |h(f)|^2 / S_n(f) df is inversely proportional to the PSD.
# Multiplying S_n by a constant d divides <h|h> by d, hence the SNR by sqrt(d),
# so SNR_nodeg / SNR_deg should equal sqrt(d).
degradation_factors = {channel_kwargs["d"] for channel_kwargs in noise_kwargs_deg}
assert len(degradation_factors) == 1, "this check assumes the same degradation d in every channel"
d_factor = degradation_factors.pop()

snr_ratio = float(SNR_nodeg / SNR_deg)
assert np.isclose(snr_ratio, np.sqrt(d_factor), rtol=1e-8), (snr_ratio, np.sqrt(d_factor))
print(snr_ratio)

1.414213562373095
