In [None]:
import numpy as np
import matplotlib.pyplot as plt
#%config InlineBackend.figure_format = 'retina'  # For sharper figures, but it takes more time
import scipy as sp
from copy import deepcopy 

from lisatools.utils.constants import *
from lisatools.sensitivity  import SensitivityMatrix, AET1SensitivityMatrix, AE1SensitivityMatrix
from lisatools.analysiscontainer import AnalysisContainer
from lisatools.datacontainer import DataResidualArray

from bbhx.waveforms.phenomhm import PhenomHMAmpPhase
from bbhx.waveformbuild import BBHWaveformFD
from bbhx.utils.interpolate import CubicSplineInterpolant

import noise_generation as noise_generation
from tools.LISASimulator import LISASimulator
from tools.likelihood import get_dh, get_hh
import tools.likelihood as likelihood

from tools.time_freq_likelihood import TimeFreqLikelihood

Generate simulated LISA data with `LISASimulator`. Calling the simulator returns, in order:
1. the time-domain data,
2. the frequency-domain data,
3. the frequency array, which can be reused to generate templates with bbhx,
4. the time array (if needed),
5. the sensitivity matrix used for inner-product calculations.

In [None]:
# --- Observation setup ---
Tobs = YRSID_SI  # observation time: one year (YRSID_SI from lisatools constants)
dt = 5.  # sample spacing of the time series (the STFT cells use fs = 1 / dt)
include_T_channel = False # Set to True if you want to include the T channel in the simulation, otherwise only A and E channels will be included.

# bbhx frequency-domain generator; run_phenomd=False presumably selects the
# higher-mode (PhenomHM) amplitude/phase model rather than PhenomD.
wave_gen = BBHWaveformFD(amp_phase_kwargs=dict(run_phenomd=False))
sim = LISASimulator(Tobs=Tobs, dt=dt, wave_gen=wave_gen, include_T_channel=include_T_channel)

# --- Injected source parameters ---
f_ref = 0.0
phi_ref = 0.0
m1 = 9e5
m2 = 5e5
a1 = 0.2
a2 = 0.4
dist = 10e3 * PC_SI * 1e6  # 10e3 Mpc (= 10 Gpc), converted to SI via PC_SI
inc = np.pi/3
beta = np.pi/4.
lam = np.pi/5.
psi = np.pi/6.
t_ref = 0.5 * YRSID_SI  # in the SSB reference frame

# Full 12-parameter vector in the order wave_gen(*parameters, ...) expects it.
parameters = np.array([m1, m2, a1, a2, dist, phi_ref, f_ref, inc, lam, beta, psi, t_ref])

# Harmonics included in the injected signal.
modes = [(2,2), (2,1), (3,3), (3,2), (4,4), (4,3)]
waveform_kwargs = dict(direct=False, fill=True, squeeze=False, length=1024)

# Returns time-domain data, frequency-domain data, frequency grid, time grid and sensitivity matrix.
data_t, data_f, f_array, t_array, sens_mat = sim(seed = 42, parameters=parameters, modes=modes, waveform_kwargs=waveform_kwargs)

In [None]:
# Exploratory check of the time-frequency likelihood: evaluate it with m1 scaled
# up by a factor of 10, all other parameters at their injected values.
analysis = TimeFreqLikelihood(data_t=data_t, wave_gen=wave_gen, dt=dt)
analysis.get_stft_of_data()

tf_waveform_kwargs = dict(
    length=1024,
    combine=False,  # TODO: check this
    direct=False,
    fill=True,
    squeeze=True,
    freqs=f_array,
    modes=modes,
)
analysis.calculate_time_frequency_likelihood(
    m1 * 10, m2, a1, a2, dist, phi_ref, f_ref, inc, lam, beta, psi, t_ref,
    waveform_kwargs=tf_waveform_kwargs,
)

In [None]:
# Likelihood values noted down from exploratory calls of
# calculate_time_frequency_likelihood (presumably: injected parameters, m1 x 10,
# and m1 x 1000, going by the names). Kept for reference only; no later cell reads them.
best = 0.08640623615147122
m1_10 = -0.024209015459648226
m1_1000 = -9.35790536288313e-10


In [None]:
# imports
from eryn.ensemble import EnsembleSampler
from eryn.prior import ProbDistContainer, uniform_dist
from eryn.state import State

In [None]:
def wrapper_likelihood(x, fixed_parameters, freqs, analysis, **kwargs):
    """Expand a sampler point into the full 12-parameter vector and evaluate the likelihood.

    The sampled coordinates are ``x = [mT, q, phi_ref, t_ref]``, where ``mT`` is the
    total mass and ``q = m2 / m1``. They are mapped back to component masses via
    ``m1 = mT / (1 + q)`` and ``m2 = mT * q / (1 + q)``. The full vector has the
    layout ``[m1, m2, a1, a2, dist, phi_ref, f_ref, inc, lam, beta, psi, t_ref]``.

    Parameters
    ----------
    x : array-like of length 4
        Sampled coordinates ``[mT, q, phi_ref, t_ref]``.
    fixed_parameters : array-like of length 8
        Parameters held fixed, ordered ``[a1, a2, dist, f_ref, inc, lam, beta, psi]``.
    freqs : array
        Frequency grid forwarded to the waveform generator.
    analysis : object
        Must provide ``calculate_time_frequency_likelihood(*params, waveform_kwargs=...)``.
    **kwargs
        Optional ``modes`` (list of ``(l, m)`` tuples). It is forwarded to the waveform
        generator; when omitted, the generator's default mode content is used. Other
        keys are ignored.

    Returns
    -------
    Whatever ``analysis.calculate_time_frequency_likelihood`` returns (presumably
    the log-likelihood the sampler expects).
    """
    mT = x[0]
    q = x[1]

    all_parameters = np.zeros(12)
    all_parameters[0] = mT / (1 + q)        # m1 (heavier body, since q <= 1)
    all_parameters[1] = mT * q / (1 + q)    # m2
    all_parameters[5] = x[2]                # phi_ref
    all_parameters[-1] = x[3]               # t_ref
    # Remaining slots: a1, a2, dist, f_ref, inc, lam, beta, psi.
    all_parameters[np.array([2, 3, 4, 6, 7, 8, 9, 10])] = fixed_parameters

    waveform_kwargs = dict(
        length=1024,
        combine=False,  # TODO: check this
        direct=False,
        fill=True,
        squeeze=True,
        freqs=freqs,
    )
    # Forward the mode content when given, so templates can match the injection's modes.
    modes = kwargs.get("modes")
    if modes is not None:
        waveform_kwargs["modes"] = modes

    ll = analysis.calculate_time_frequency_likelihood(
        *all_parameters,
        waveform_kwargs=waveform_kwargs,
    )
    return ll

In [None]:
# clear (for internal clearing of answers)
 
# Priors on the four sampled coordinates [mT, q, phi_ref, t_ref],
# in the order wrapper_likelihood reads them.
mbh_priors = ProbDistContainer({
    0: uniform_dist(9e5, 5e6),                  # total mass mT
    1: uniform_dist(0.05, 0.999999),            # mass ratio q = m2 / m1
    2: uniform_dist(0.0, 2 * np.pi),            # phi_ref
    3: uniform_dist(0.0, Tobs + 24 * 3600.0),   # t_ref, allowed up to one day past Tobs
})
priors = {"mbh": mbh_priors}

# Injection in the sampler's mass parametrisation: (total mass, mass ratio) replace (m1, m2).
injection_params = np.array([m1 + m2, m2 / m1, a1, a2, dist, phi_ref, f_ref, inc, lam, beta, psi, t_ref])

# Parameters held at their injected values; wrapper_likelihood places these into
# slots [2, 3, 4, 6, 7, 8, 9, 10] of the full 12-parameter vector.
fixed_parameters = np.array([a1, a2, dist, f_ref, inc, lam, beta, psi])

# phi_ref (sampled index 2) wraps around with period 2*pi.
periodic = {"mbh": {2: 2 * np.pi}}

# Parallel-tempered ensemble: ntemps temperatures x nwalkers walkers, one 4-D "mbh" leaf.
ntemps, nwalkers = 10, 32
ndims = {"mbh": 4}
sampler = EnsembleSampler(
    nwalkers, ndims, wrapper_likelihood, priors,
    args=(fixed_parameters, f_array, analysis),
    branch_names=["mbh"],
    tempering_kwargs={"ntemps": ntemps},
    nleaves_max={"mbh": 1},
    periodic=periodic,
)

In [None]:
# Start every walker (at every temperature) in a tiny ball around the injection.
# Coordinate order matches wrapper_likelihood: [total mass, mass ratio, phi_ref, t_ref].
injection_params_sub = np.array([m1 + m2, m2 / m1, phi_ref, t_ref])
n_sub = injection_params_sub.shape[0]

# Seeded generator so the starting ensemble (and hence the chain) is reproducible.
start_rng = np.random.default_rng(42)

# x[None, None, None, :] broadcasts the 1-D injection to (ntemps, nwalkers, nleaves=1, ndim);
# relative (multiplicative) scatter for the scale-like parameters.
start_params = injection_params_sub[None, None, None, :] * (
    1 + 1e-7 * start_rng.standard_normal((ntemps, nwalkers, 1, n_sub))
)
# phi_ref is injected at 0.0: multiplicative scatter would leave every walker at exactly 0,
# and ensemble (stretch-type) moves, which step along walker differences, could never
# leave it. Use additive scatter instead, wrapped into the [0, 2*pi) prior support.
start_params[..., 2] = np.mod(
    injection_params_sub[2] + 1e-7 * start_rng.standard_normal((ntemps, nwalkers, 1)),
    2 * np.pi,
)

start_state = State({"mbh": start_params})
# Every starting point must lie inside the prior support.
start_log_prior = sampler.compute_log_prior(start_state.branches_coords)
assert np.all(np.isfinite(start_log_prior)), "some starting walkers are outside the prior"
sampler.run_mcmc(start_state, 10, progress=True)

In [None]:
from chainconsumer import Chain, ChainConsumer, make_sample, Truth
import pandas as pd

# Sampled coordinates, in the order wrapper_likelihood reads them from x.
sampled_names = ["mT", "q", "phi_ref", "t_ref"]

# Keep temperature index 0 (the untempered chain in eryn's convention), then flatten
# steps x walkers x the single leaf into one row per sample.
samples = sampler.get_chain()["mbh"][:, 0].reshape(-1, len(sampled_names))
samples_df = pd.DataFrame(samples, columns=sampled_names)

c = ChainConsumer()
c.add_chain(Chain(samples=samples_df, name="An Example Contour"))
# Truth keys must match the DataFrame columns; the values are the injected sampled coordinates.
c.add_truth(Truth(location=dict(zip(sampled_names, injection_params_sub))))
fig = c.plotter.plot()

In [None]:
# STFT of the simulated data in the A (index 0) and E (index 1) channels,
# using 15000-sample segments.
stft_data_A, stft_data_E = (sp.signal.stft(data_t[ch], fs=1.0 / dt, nperseg=15000) for ch in (0, 1))
f, t, Zxx_data_A = stft_data_A
f, t, Zxx_data_E = stft_data_E

In [None]:
# Deliberately mismatched template: the injected parameters with m1 raised by 20%
# and m2 lowered by 20% (same parameter order as `parameters`).
parameters_new = deepcopy(parameters)
parameters_new[0], parameters_new[1] = 1.2 * m1, 0.8 * m2

# Frequency-domain TDI response on the simulator's frequency grid. [0] selects the
# first entry of the squeeze=False output (presumably the binary axis); [:2] keeps
# A and E and drops T.
template_f = wave_gen(*parameters_new, freqs=sim.freq, modes=modes, **waveform_kwargs)[0][:2]

# NOTE(review): irfft is called without n= (output length defaults to 2*(n_freq-1),
# which differs from the data length when that length is odd) and without a 1/dt
# factor. If template_f is a continuous Fourier transform, template_t is scaled by
# dt relative to data_t — confirm against how the simulator builds data_t.
template_t = np.fft.irfft(template_f, axis=-1)

In [None]:
# STFT of the template's A and E channels, using the same segmentation as the data.
stft_tmpl_A, stft_tmpl_E = (sp.signal.stft(template_t[ch], fs=1.0 / dt, nperseg=15000) for ch in (0, 1))
f, t, Zxx_A = stft_tmpl_A
f, t, Zxx_E = stft_tmpl_E

# Frequency resolution of the STFT grid, read before the DC bin is overwritten.
df = f[1] - f[0]
# Replace f = 0 by the first nonzero frequency (presumably to avoid evaluating the
# sensitivity at zero frequency).
f[0] = f[1]
sens_mat_new = AE1SensitivityMatrix(f).sens_mat

# <h|h> = 4 * df * sum_{f,t} |h(f,t)|^2 / S(f), summed over the A and E channels.
# NOTE(review): scipy's stft defaults to a Hann window, 50% segment overlap and
# 'spectrum' scaling (division by the window sum), so this is not obviously
# normalised like a frequency-domain inner product — confirm before comparing numbers.
hh_per_channel = [
    np.sum(np.abs(Z) ** 2 / S[:, np.newaxis])
    for Z, S in ((Zxx_A, sens_mat_new[0]), (Zxx_E, sens_mat_new[1]))
]
print((hh_per_channel[0] + hh_per_channel[1]) * 4 * df)

In [None]:
# Template self-overlap <h|h>, accumulated one STFT time segment at a time.
# Each column i adds 4 * df * sum_f |h(f, i)|^2 / S(f) for both A and E.
# Only the template enters (no data), so this is <h|h>, mathematically equal to
# 4 * df * (hh_A + hh_E) computed with whole-array sums; hence the hh_* names.
hh_total = 0.0
for i in range(len(t)):
    hh_A_seg = np.sum((np.abs(Zxx_A[:, i])**2) / sens_mat_new[0]) * df * 4
    hh_E_seg = np.sum((np.abs(Zxx_E[:, i])**2) / sens_mat_new[1]) * df * 4
    hh_total += hh_A_seg + hh_E_seg
print(hh_total)


In [None]:
# Shape of one frequency row of the data STFT: for a 2-D (freq, time) STFT this is
# (number of time segments,).
np.shape(Zxx_data_A[0])

In [None]:
# <d|h> accumulated one STFT time segment at a time: each segment adds
# 4 * df * sum_f Re(d*(f, seg) h(f, seg)) / S(f) for both A and E.
dh_total = 0.0
n_segments = len(Zxx_data_A[0])
for seg in range(n_segments):
    overlap_A = np.real(Zxx_data_A[:, seg].conj() * Zxx_A[:, seg] / sens_mat_new[0])
    overlap_E = np.real(Zxx_data_E[:, seg].conj() * Zxx_E[:, seg] / sens_mat_new[1])
    seg_A = np.sum(overlap_A) * df * 4
    seg_E = np.sum(overlap_E) * df * 4
    dh_total += seg_A + seg_E
print(dh_total)


In [None]:
# Vectorised <d|h>: weight the cross-spectrum h(f, t) d*(f, t) by 1/S(f), broadcasting
# the sensitivity along the time axis; each result has shape (n_freq, n_time).
weighted_A, weighted_E = (
    (Z_h * np.conj(Z_d)) / S[:, np.newaxis]
    for Z_h, Z_d, S in (
        (Zxx_A, Zxx_data_A, sens_mat_new[0]),
        (Zxx_E, Zxx_data_E, sens_mat_new[1]),
    )
)

# Sum the real part over frequency and time, for both channels, and apply 4 * df.
dh_total = 4 * df * np.sum(weighted_A.real + weighted_E.real)
print(dh_total)


In [None]:
# Toy check of row-wise broadcasting: y reshaped to (4, 1) divides each row of x
# by the matching entry of y (the same trick used for S[:, np.newaxis] above).
x = np.arange(1, 13).reshape(4, 3)
y = np.array([10, 10, 100, 1000])
x / y.reshape(-1, 1)

In [None]:
# Template self-overlap resolved per frequency bin: for each f,
# 4 * df * Re sum_t h*(f, t) h(f, t) / S(f), summed over A and E.
# NOTE: despite the dh_* names, these are <h|h> terms (template with itself).
inner_product_A, inner_product_E = (
    Z * np.conj(Z) / S[:, np.newaxis]
    for Z, S in ((Zxx_A, sens_mat_new[0]), (Zxx_E, sens_mat_new[1]))
)
dh_A = inner_product_A.sum(axis=-1)   # sum over the time (last) axis
dh_E = inner_product_E.sum(axis=-1)
inner_product = (dh_A.real + dh_E.real) * 4 * df
print(inner_product)

In [None]:
# Inspect dh_A: the A-channel array from the per-frequency cell before this one
# (template STFT summed over time segments for each frequency bin).
dh_A

In [None]:
# A-channel <h|h> from the template STFT, summed over every frequency bin and time segment.
power = np.abs(Zxx_A)**2  # shape: (n_freq, n_segments)

# Weight each frequency row by the A-channel sensitivity (row 0 of sens_mat_new, which
# was evaluated on this same STFT frequency grid), broadcasting along the time axis.
weighted_power = power / sens_mat_new[0][:, np.newaxis]

# Sum over frequencies and segments, applying the same 4 * df factor as the other
# inner-product cells.
inner_product = 4 * df * np.sum(weighted_power)
print(f"Inner product: {inner_product}")

In [None]:
# A-channel <h|h> as a single expression: 4 * df * sum_{f,t} Re(h* h) / S_A(f),
# with S_A taken from sens_mat_new (same STFT frequency grid f).
np.sum(np.real(np.divide(Zxx_A.conj() * Zxx_A, sens_mat_new[0][:, np.newaxis])) * 4 * df)

In [None]:
# Rebuild the time-frequency likelihood object with the same inputs used when it is
# first constructed in this notebook. A bare TimeFreqLikelihood() has no data_t,
# wave_gen or dt to work with (presumably failing or yielding an unusable object),
# so pass them, then precompute the data STFT as before.
# NOTE(review): the next cell calls analysis.calculate_signal_likelihood, which may
# belong to a different analysis class — confirm which object is intended here.
analysis = TimeFreqLikelihood(data_t=data_t, wave_gen=wave_gen, dt=dt)
analysis.get_stft_of_data()

In [None]:
def wrapper_likelihood(variable_parameters, fixed_parameters, freqs, analysis, **kwargs):
    """Fill the 12-parameter vector from (m1, m2, phi_ref, t_ref) and evaluate the likelihood.

    NOTE(review): this redefinition shadows the earlier ``wrapper_likelihood`` in this
    notebook, which samples (total mass, mass ratio) instead of (m1, m2). Any sampler
    built before this cell keeps its reference to the earlier function.

    Parameters
    ----------
    variable_parameters : sequence
        ``[m1, m2, phi_ref, t_ref]``, placed in slots 0, 1, 5 and 11.
    fixed_parameters : array-like of length 8
        ``[a1, a2, dist, f_ref, inc, lam, beta, psi]``, placed in slots
        2, 3, 4, 6, 7, 8, 9 and 10.
    freqs : array
        Frequency grid forwarded to the waveform generator.
    analysis : object
        Must provide ``calculate_signal_likelihood(*params, waveform_kwargs=..., source_only=...)``
        (this looks like the lisatools AnalysisContainer API — confirm).
    **kwargs
        Accepted for sampler compatibility and ignored.

    Returns
    -------
    The value returned by ``analysis.calculate_signal_likelihood``.
    """
    full_params = np.zeros(12)
    # Slot of each sampled coordinate within the full parameter vector.
    for src, dst in enumerate((0, 1, 5, 11)):
        full_params[dst] = variable_parameters[src]
    full_params[[2, 3, 4, 6, 7, 8, 9, 10]] = fixed_parameters

    wf_kwargs = dict(
        length=1024,
        combine=False,  # TODO: check this
        direct=False,
        fill=True,
        squeeze=True,
        freqs=freqs,
    )
    return analysis.calculate_signal_likelihood(
        *full_params,
        waveform_kwargs=wf_kwargs,
        source_only=True,
    )

In [None]:
# Compare the optimal SNR from the lisatools-based routine with the simulator's own implementation.
snr_from_lisatools, snr_my_code = sim.SNR_optimal_lisatools(), sim.SNR_optimal()
print(snr_from_lisatools, snr_my_code)

In [None]:
# Frequency-domain signal+noise data from the simulator.
data = sim.signal_with_noise_f
# One-sided FFT frequency grid of the time series, in Hz. rfftfreq needs the sample
# spacing d=dt; with the default d=1 it returns cycles/sample, and any df derived from
# it would be off by a factor 1/dt.
# Assumes axis 2 of signal_with_noise_t is the time axis — TODO confirm.
f_array = np.fft.rfftfreq(sim.signal_with_noise_t.shape[2], d=dt)
f_array[0] = f_array[1]  # avoid zero frequency
data[0]

In [None]:
# Disabled sanity check of the simulator's sampling bookkeeping: frequency spacing,
# series length * dt versus Tobs, expected sample count, and the time/frequency grids.
#sim.df, len(sim.signal_with_noise_t[0,0])*sim.dt, sim.Tobs, sim.Tobs / sim.dt, sim.time.shape, sim.freq

In [None]:
# Quick look at the inputs to the inner products below: data (first entry),
# the mismatched template, and the simulator's sensitivity matrix.
(data[0], template_f, sim.sens_mat.sens_mat)

In [None]:
# Matched-filter SNR of the template against the simulated data,
# rho = <d|h> / sqrt(<h|h>) (get_dh / get_hh presumably compute those inner products).
# The spacing is taken from bins 1-2 because f_array[0] was overwritten to avoid f = 0.
freq_step = f_array[2] - f_array[1]
dh = get_dh(data[0], template_f, sens_mat=sim.sens_mat, df=freq_step)
hh = get_hh(template_f, sens_mat=sim.sens_mat, df=freq_step)
dh / np.sqrt(hh)