In [None]:
from typing import Callable, Iterator, Tuple
import chex
import jax

import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import jaxopt
import h5py
import pandas as pd
from scipy import signal, interpolate
import sxs
import glob

from math import pi, log
from ripple.typing import Array
from scipy.optimize import minimize, minimize_scalar

from numpy import random, abs
from ripple.waveforms import IMRPhenomD, IMRPhenomD_utils
from ripple.waveforms.IMRPhenomD import *
from ripple.waveforms.IMRPhenomD_utils import get_coeffs, get_transition_frequencies
from ripple import ms_to_Mc_eta, Mc_eta_to_ms
from jax import grad, vmap, scipy
from functools import partial
import time
from tqdm import tqdm
import json

# Original IMRPhenomD coefficient table (the baseline for `original_loss`
# below): 19 rows (rho1 ... a5, labelled per row) by 11 fitting parameters.
# `_get_coeffs` evaluates row r, with t = table[r], as
#   t[0] + t[1]*eta
#     + sum_{k=1..3} (chiPN - 1)**k * (t[3k-1] + t[3k]*eta + t[3k+1]*eta**2)
# Fitted replacements (e.g. `lambdas`) are arranged in the same (19, 11) layout.
# NOTE(review): values presumably mirror the published PhenomD fit used by
# lalsuite -- confirm before editing any entry.
PhenomD_coeff_table = jnp.array(
    [
        [  # rho1 (element 0)
            3931.8979897196696,
            -17395.758706812805,
            3132.375545898835,
            343965.86092361377,
            -1.2162565819981997e6,
            -70698.00600428853,
            1.383907177859705e6,
            -3.9662761890979446e6,
            -60017.52423652596,
            803515.1181825735,
            -2.091710365941658e6,
        ],
        [  # rho2 (element 1)
            -40105.47653771657,
            112253.0169706701,
            23561.696065836168,
            -3.476180699403351e6,
            1.137593670849482e7,
            754313.1127166454,
            -1.308476044625268e7,
            3.6444584853928134e7,
            596226.612472288,
            -7.4277901143564405e6,
            1.8928977514040343e7,
        ],
        [  # rho3 (element 2)
            83208.35471266537,
            -191237.7264145924,
            -210916.2454782992,
            8.71797508352568e6,
            -2.6914942420669552e7,
            -1.9889806527362722e6,
            3.0888029960154563e7,
            -8.390870279256162e7,
            -1.4535031953446497e6,
            1.7063528990822166e7,
            -4.2748659731120914e7,
        ],
        [  # v2 (element 3)
            0.8149838730507785,
            2.5747553517454658,
            1.1610198035496786,
            -2.3627771785551537,
            6.771038707057573,
            0.7570782938606834,
            -2.7256896890432474,
            7.1140380397149965,
            0.1766934149293479,
            -0.7978690983168183,
            2.1162391502005153,
        ],
        [  # gamma1 (element 4)
            0.006927402739328343,
            0.03020474290328911,
            0.006308024337706171,
            -0.12074130661131138,
            0.26271598905781324,
            0.0034151773647198794,
            -0.10779338611188374,
            0.27098966966891747,
            0.0007374185938559283,
            -0.02749621038376281,
            0.0733150789135702,
        ],
        [  # gamma2 (element 5)
            1.010344404799477,
            0.0008993122007234548,
            0.283949116804459,
            -4.049752962958005,
            13.207828172665366,
            0.10396278486805426,
            -7.025059158961947,
            24.784892370130475,
            0.03093202475605892,
            -2.6924023896851663,
            9.609374464684983,
        ],
        [  # gamma3 (element 6)
            1.3081615607036106,
            -0.005537729694807678,
            -0.06782917938621007,
            -0.6689834970767117,
            3.403147966134083,
            -0.05296577374411866,
            -0.9923793203111362,
            4.820681208409587,
            -0.006134139870393713,
            -0.38429253308696365,
            1.7561754421985984,
        ],
        [  # sig1 (element 7)
            2096.551999295543,
            1463.7493168261553,
            1312.5493286098522,
            18307.330017082117,
            -43534.1440746107,
            -833.2889543511114,
            32047.31997183187,
            -108609.45037520859,
            452.25136398112204,
            8353.439546391714,
            -44531.3250037322,
        ],
        [  # sig2 (element 8)
            -10114.056472621156,
            -44631.01109458185,
            -6541.308761668722,
            -266959.23419307504,
            686328.3229317984,
            3405.6372187679685,
            -437507.7208209015,
            1.6318171307344697e6,
            -7462.648563007646,
            -114585.25177153319,
            674402.4689098676,
        ],
        [  # sig3 (element 9)
            22933.658273436497,
            230960.00814979506,
            14961.083974183695,
            1.1940181342318142e6,
            -3.1042239693052764e6,
            -3038.166617199259,
            1.8720322849093592e6,
            -7.309145012085539e6,
            42738.22871475411,
            467502.018616601,
            -3.064853498512499e6,
        ],
        [  # sig4 (element 10)
            -14621.71522218357,
            -377812.8579387104,
            -9608.682631509726,
            -1.7108925257214056e6,
            4.332924601416521e6,
            -22366.683262266528,
            -2.5019716386377467e6,
            1.0274495902259542e7,
            -85360.30079034246,
            -570025.3441737515,
            4.396844346849777e6,
        ],
        [  # beta1 (element 11)
            97.89747327985583,
            -42.659730877489224,
            153.48421037904913,
            -1417.0620760768954,
            2752.8614143665027,
            138.7406469558649,
            -1433.6585075135881,
            2857.7418952430758,
            41.025109467376126,
            -423.680737974639,
            850.3594335657173,
        ],
        [  # beta2 (element 12)
            -3.282701958759534,
            -9.051384468245866,
            -12.415449742258042,
            55.4716447709787,
            -106.05109938966335,
            -11.953044553690658,
            76.80704618365418,
            -155.33172948098394,
            -3.4129261592393263,
            25.572377569952536,
            -54.408036707740465,
        ],
        [  # beta3 (element 13)
            -0.000025156429818799565,
            0.000019750256942201327,
            -0.000018370671469295915,
            0.000021886317041311973,
            0.00008250240316860033,
            7.157371250566708e-6,
            -0.000055780000112270685,
            0.00019142082884072178,
            5.447166261464217e-6,
            -0.00003220610095021982,
            0.00007974016714984341,
        ],
        [  # a1 (element 14)
            43.31514709695348,
            638.6332679188081,
            -32.85768747216059,
            2415.8938269370315,
            -5766.875169379177,
            -61.85459307173841,
            2953.967762459948,
            -8986.29057591497,
            -21.571435779762044,
            981.2158224673428,
            -3239.5664895930286,
        ],
        [  # a2 (element 15)
            -0.07020209449091723,
            -0.16269798450687084,
            -0.1872514685185499,
            1.138313650449945,
            -2.8334196304430046,
            -0.17137955686840617,
            1.7197549338119527,
            -4.539717148261272,
            -0.049983437357548705,
            0.6062072055948309,
            -1.682769616644546,
        ],
        [  # a3 (element 16)
            9.5988072383479,
            -397.05438595557433,
            16.202126189517813,
            -1574.8286986717037,
            3600.3410843831093,
            27.092429659075467,
            -1786.482357315139,
            5152.919378666511,
            11.175710130033895,
            -577.7999423177481,
            1808.730762932043,
        ],
        [  # a4 (element 17)
            -0.02989487384493607,
            1.4022106448583738,
            -0.07356049468633846,
            0.8337006542278661,
            0.2240008282397391,
            -0.055202870001177226,
            0.5667186343606578,
            0.7186931973380503,
            -0.015507437354325743,
            0.15750322779277187,
            0.21076815715176228,
        ],
        [  # a5 (element 18)
            0.9974408278363099,
            -0.007884449714907203,
            -0.059046901195591035,
            1.3958712396764088,
            -4.516631601676276,
            -0.05585343136869692,
            1.7516580039343603,
            -5.990208965347804,
            -0.017945336522161195,
            0.5965097794825992,
            -2.0608879367971804,
        ],
    ]
)

In [None]:
@jax.jit
def _get_coeffs(theta: Array, table: Array) -> Array:
    # Retrives the coefficients needed to produce the waveform

    m1, m2, chi1, chi2 = theta
    m1_s = m1 * gt
    m2_s = m2 * gt
    M_s = m1_s + m2_s
    eta = m1_s * m2_s / (M_s ** 2.0)

    # Definition of chiPN from lalsuite
    chi_s = (chi1 + chi2) / 2.0
    chi_a = (chi1 - chi2) / 2.0
    seta = (1 - 4 * eta) ** (1 / 2)
    chiPN = chi_s * (1 - 76 * eta / 113) + seta * chi_a

    coeff = (
        table[:, 0]
        + table[:, 1] * eta
        + (chiPN - 1.0)
        * (
            table[:, 2]
            + table[:, 3] * eta
            + table[:, 4] * (eta ** 2.0)
        )
        + (chiPN - 1.0) ** 2.0
        * (
            table[:, 5]
            + table[:, 6] * eta
            + table[:, 7] * (eta ** 2.0)
        )
        + (chiPN - 1.0) ** 3.0
        * (
            table[:, 8]
            + table[:, 9] * eta
            + table[:, 10] * (eta ** 2.0)
        )
    )

    # FIXME: Change to dictionary lookup
    return coeff

# Detector noise curve: column 0 is kept as the frequency grid and column 2 as
# the noise values. Read from an absolute Google Drive (Colab) path, unlike
# the NR data further down, which use relative paths.
# NOTE(review): nothing in the visible cells reads `noise_f` / `noise_curve`
# at runtime.
noise_dataframe = pd.read_csv(
    '/content/drive/MyDrive/Colab/final_data/aLIGOZeroDetHighPower_fs.dat',
    delimiter=' ',
    header=None,
)
noise_f, noise_curve = noise_dataframe.values[:, 0], noise_dataframe.values[:, 2]

@jax.jit
def inner(h1: Array, h2: Array, f, noise=1.0):
    """Noise-weighted inner product of two frequency-domain signals.

    Computes the Riemann sum ``4 * sum(Re[h1 * conj(h2)] / noise * df)``.

    Args:
        h1, h2: Complex frequency series sampled on ``f``.
        f: Frequency grid. It is assumed uniformly spaced, because only
            ``f[1] - f[0]`` is used as the bin width.
        noise: Weight the integrand is divided by (a scalar or an array
            broadcastable against ``h1``). The default of 1 gives the
            unweighted inner product that was previously hard-coded.
            NOTE(review): to weight by the loaded detector curve, pass e.g.
            ``jnp.interp(f, noise_f, noise_curve)``. First confirm whether
            that file holds a PSD or an ASD.

    Returns:
        Real scalar inner product.
    """
    df = f[1] - f[0]
    cross_multi = jnp.real(h1 * jnp.conj(h2)) / noise
    return 4 * jnp.sum(cross_multi * df)

@jax.jit
def mismatch(h1: Array, h2: Array, f):
    """Return one minus the normalised overlap of ``h1`` and ``h2`` on ``f``.

    Each signal is normalised by its own ``inner`` norm, so two signals that
    are positive multiples of each other have mismatch 0. No time or phase
    maximisation is performed here.
    """
    overlap = inner(h1, h2, f)
    norm = jnp.sqrt(inner(h1, h1, f) * inner(h2, h2, f))
    return 1 - overlap / norm

def loss(lambdas: Array, intrin: Array, extrin: Array, f: Array, NR_complex: Array,
         n_points: int = 100) -> Array:
    """Mismatch between an NR waveform and IMRPhenomD built from ``lambdas``.

    Steps:

    1. The IMRPhenomD waveform is generated on ``f`` using the coefficients
       ``_get_coeffs(intrin, lambdas)``.
    2. The NR waveform is aligned to it by a least-squares fit of the
       unwrapped phase difference to ``two_pi_t0 * f + phi0``. This is a
       phase linear in frequency (a time shift) plus a constant phase.
    3. The NR waveform is rotated by that fit.
    4. The mismatch is evaluated on every ``f_sep``-th sample.

    Args:
        lambdas: Coefficient table for ``_get_coeffs`` (same layout as
            ``PhenomD_coeff_table``).
        intrin: Intrinsic parameters ``(m1, m2, chi1, chi2)``.
        extrin: Extrinsic parameters forwarded to ``_gen_IMRPhenomD``.
        f: Frequency grid shared by both waveforms. ``f[0]`` is also passed
            as the last argument of ``_gen_IMRPhenomD`` (presumably the
            reference frequency -- confirm).
        NR_complex: Complex NR frequency-domain waveform sampled on ``f``.
        n_points: Approximate number of samples used for the mismatch. It must
            be positive; the default of 100 is the value previously
            hard-coded.

    Returns:
        Scalar mismatch (lower means better agreement).
    """
    # Subsampling stride. It is clamped to 1 so that grids shorter than
    # `n_points` do not yield a zero slice step, which raises ValueError.
    f_sep = max(1, len(f) // n_points)

    NR_phase = -jnp.unwrap(jnp.angle(NR_complex))
    IMR = IMRPhenomD._gen_IMRPhenomD(f, intrin, extrin, _get_coeffs(intrin, lambdas), f[0])
    IMR_phase = -jnp.unwrap(jnp.angle(IMR))
    phase_diff = NR_phase - IMR_phase

    # Linear least squares: phase_diff ~= two_pi_t0 * f + phi0.
    A = jnp.vstack([f, jnp.ones(len(f))]).T
    two_pi_t0, phi0 = jnp.linalg.lstsq(A, phase_diff, rcond=None)[0]

    # Multiplying by exp(+i * fit) makes angle(NR_shifted) ~= angle(IMR).
    NR_shifted = NR_complex * jnp.exp(1j * (two_pi_t0 * f + phi0))

    return mismatch(NR_shifted[0::f_sep], IMR[0::f_sep], f[0::f_sep])

import matplotlib

# Default figure size (width, height in inches) for figures drawn later on.
matplotlib.rcParams.update({'figure.figsize': (6, 4)})

In [None]:
# Fitted coefficient table for run '1neg2pos':
#   1. coerce the raw file values to an (11, 211) grid;
#   2. keep its last row;
#   3. lay that row's leading 19 * 11 entries out like `PhenomD_coeff_table`.
# NOTE(review): np.resize never raises on a size mismatch. It silently
# truncates or repeats, so confirm the file holds the expected number of values.
lambdas = np.resize(
    np.resize(pd.read_csv('./lambdas/1neg2pos.dat', sep=" ", header=None).values, (11, 211))[-1],
    (19, 11),
)

In [None]:
# Alias of the original coefficient table.
# NOTE(review): `scale` is not read anywhere in the visible cells.
scale = PhenomD_coeff_table

# Catalog IDs to evaluate. Each ID is used to build the file names
# ./NR_waveform/NR_<id>_metadata.json and ./NR_waveform/NR_<id>.dat below.
# NOTE(review): presumably SXS:BBH:<id> simulations -- confirm.
catalog_list = ['0226', '0219', '2100', '2093', '0325', '0326', '0224', '1492',
       '0225', '0221', '2099', '2098', '0220', '2095', '2094', '1503',
       '1502', '1481', '0436', '0415', '0418', '0327', '0304', '0459',
       '0370', '1124', '2092', '0159', '0218', '0228', '0389', '0328',
       '1475', '0212', '2102', '0070', '0072', '0073', '0001', '0066',
       '0086', '0067', '0071', '0068', '0090', '1137', '0152', '1135',
       '0002', '0150', '0171', '0170', '1132', '0180', '0153', '0160',
       '0148', '1154', '1122', '1155', '1123', '0158', '0149', '2089',
       '2104', '0157', '1153', '0178', '0154', '0176', '0155', '1114',
       '2086', '1141', '0215', '0230', '0151', '1477', '0366', '1134',
       '0172', '0329', '1509', '2096', '1125', '2097', '1506', '0217',
       '1497', '0330', '2085', '0213', '2091', '2087', '1476', '0211',
       '1507', '0156', '0376', '0394', '0462', '0447', '0216', '2090',
       '0227', '2088', '2101', '0214', '0585', '0229', '2103', '1495',
       '1501', '0222', '1499', '1500', '2106', '2083', '2084', '0209',
       '2105', '0210', '0232', '0231', '0004', '0005', '1376', '1498',
       '1351', '0544', '0518', '0396', '1513', '1352', '1496', '0626',
       '0311', '0523', '0198', '0312', '0313', '1353', '0318', '0310',
       '0305', '0314', '0307', '1490', '0465', '0486', '0438', '0559',
       '0475', '0409', '0503', '0535', '0386', '0398', '1223', '0591',
       '0464', '1143', '0466', '1142', '0377', '0525', '0507', '1487',
       '0315', '1508', '1474', '1493', '0306', '1505', '1471', '1482',
       '0625', '1413', '1473', '1511', '1146', '0437', '0404', '0579',
       '0361', '0369', '0392', '0019', '0025', '0441', '0385', '0008',
       '0007', '0593', '0397', '1415', '0440', '0372', '0016', '0014',
       '0013', '0012', '0009', '0194', '1470', '0499', '1480', '1479',
       '1488', '1412', '1491', '1465', '0501', '0423', '0488', '0473',
       '0435', '0355', '0414', '0566', '0512', '0451', '0550', '0382',
       '0402', '0371', '0454', '0552', '0388', '1510', '1416', '0545',
       '1414', '1354', '1469', '1466', '0580', '0530', '1478', '1504',
       '0482', '0239', '0252', '2113', '0248', '2126', '2122', '0332',
       '2117', '0243', '0258', '0399', '0333', '1222', '0410', '0233',
       '0234', '0331', '0407', '0375', '0584', '0257', '2107', '0354',
       '2108', '0334', '0335', '2128', '2121', '0574', '2132', '2131',
       '0245', '2109', '2119', '0599', '1112', '2118', '2127', '2114',
       '2115', '2123', '2116', '2124', '0253', '0247', '0238', '0235',
       '2130', '0448', '0254', '0256', '0244', '0237', '0251', '2120',
       '0249', '0242', '0240', '2111', '0554', '0184', '0255', '2112',
       '0236', '0169', '2129', '0461', '1166', '0513', '0387', '2125',
       '0412', '1164', '2110', '0241', '0250', '1167', '1165', '1148',
       '1147', '1494', '1467', '1459', '1468', '0201', '0631', '1453',
       '1472', '1512', '1454', '0259', '0191', '1462', '1461', '1484',
       '1387', '1456', '2151', '0280', '0261', '2146', '0293', '2133',
       '0260', '0279', '2145', '0274', '0292', '0273', '1221', '2152',
       '2156', '2135', '0263', '0268', '2141', '0264', '2142', '1152',
       '0277', '1151', '0046', '1150', '0285', '0291', '2139', '0045',
       '0047', '0278', '0286', '0267', '2157', '0284', '2136', '1172',
       '0275', '0269', '0287', '2153', '2159', '1175', '1174', '1173',
       '0283', '0281', '0272', '0270', '2140', '2150', '2147', '1178',
       '0262', '1179', '0266', '0290', '0183', '2138', '0168', '0265',
       '2148', '0288', '0289', '2154', '2134', '2149', '2163', '2143',
       '2160', '2162', '2144', '2265', '2137', '2158', '2155', '2161',
       '0271', '0282', '0174', '0036', '0041', '0031', '0040', '0038',
       '1485', '1446', '1447', '1483', '1457', '0200', '0317', '1489',
       '0193', '0294', '1452', '1486', '1458', '1932', '1907', '1966',
       '1962', '1936', '1942', '1911', '1938', '1906', '2018', '1937',
       '1417', '2036', '2013', '1961', '2014', '1418', '0182', '0167',
       '1931', '2040', '1220', '1451', '1450', '1449', '1434', '0190',
       '0295', '1445', '1463', '0112', '0055', '0056', '1111', '0208',
       '0110', '0109', '0187', '0296', '1428', '0197', '1440', '1443',
       '1432', '1438', '1444', '0181', '0166', '1437', '1425', '1436',
       '1424', '1439', '0297', '1464', '0192', '1442', '1435', '1448',
       '0202', '0207', '0298', '1110', '0205', '0204', '0206', '0203',
       '0188', '1427', '0299', '1429', '0195', '1421', '1422', '1426',
       '1420', '1419', '1423', '1375', '1430', '1441', '1433', '1431',
       '0063', '1455', '1460', '0064', '0114', '0186', '0300', '0199',
       '0301', '0189', '1108', '0302', '0196', '0185', '1107', '0303']

# IDs routed to the second (else) branch of the evaluation loop below.
# An earlier selection, labelled "testing waveforms" in the original note, was:
#   ['0156', '0151', '0001', '0152', '0172', '0234', '0235', '0169',
#    '0256', '0257', '1418', '0167', '1417', '1419', '0063', '1426']
catalog_droplist = []

# Total mass shared by every system; the parameters built below satisfy m1 + m2 == M.
M = 50.0
# Extrinsic parameters handed to IMRPhenomD.
# NOTE(review): presumably (distance, t_c, phi_c) per ripple's convention -- confirm.
theta_extrinsic = np.array([440.0, 0.0, 0.0])

# log10(new mismatch / original mismatch), one entry per evaluated waveform.
logdiff_test, logdiff_train = [], []

# Intrinsic parameters [m1, m2, chi1, chi2] for every catalog entry, read from
# its NR metadata:
#   - the mass ratio q is rounded to three decimals;
#   - q is split so that m1 + m2 == M;
#   - only component [2] of each spin vector is kept.
theta_intrinsic_list = []
for catalog_number in catalog_list:
    with open(f'./NR_waveform/NR_{catalog_number}_metadata.json') as file:
        metadata = json.load(file)
    q = round(metadata['reference_mass_ratio'] * 1000) / 1000
    chi1, chi2 = (
        metadata[spin_key][2]
        for spin_key in ('reference_dimensionless_spin1', 'reference_dimensionless_spin2')
    )
    theta_intrinsic = [M * q / (1 + q), M / (1 + q), chi1, chi2]
    theta_intrinsic_list.append(theta_intrinsic)

# For every NR waveform, compute the mismatch with the fitted table
# (`lambdas`) and with the original table. Also record log10(new / original);
# negative values mean the fitted table agrees better with NR.
# NOTE(review): IDs NOT in `catalog_droplist` go to the *_test lists and
# dropped IDs to *_train, although the commented drop list is labelled
# "testing waveforms". The names look inverted; they are kept because the
# export cell reads the *_test lists.
# NOTE(review): the two groups are read from different locations/extensions
# (relative .dat vs. absolute Colab .txt) -- confirm both are intended.
with tqdm(total=len(catalog_list)) as pbar:
    original_loss_list_test = []
    new_loss_list_test = []
    original_loss_list_train = []
    new_loss_list_train = []
    for i, catalog_number in enumerate(catalog_list):
        # Pick the input file and the result lists for this waveform, then
        # run the shared evaluation once (the two branches used to be copies).
        if catalog_number not in catalog_droplist:
            path = './NR_waveform/NR_' + str(catalog_number) + '.dat'
            original_losses, new_losses, logdiffs = (
                original_loss_list_test, new_loss_list_test, logdiff_test)
        else:
            path = '/content/drive/MyDrive/Colab/final_data/NR_waveform/NR_' + str(catalog_number) + '.txt'
            original_losses, new_losses, logdiffs = (
                original_loss_list_train, new_loss_list_train, logdiff_train)

        # Columns: frequency, real part, imaginary part of the NR waveform.
        data = pd.read_csv(path, sep=" ", header=None)
        f_uniform = data.values[:, 0]
        NR_waveform = data.values[:, 1] + 1j * data.values[:, 2]
        new_loss = loss(lambdas, theta_intrinsic_list[i], theta_extrinsic, f_uniform, NR_waveform)
        original_loss = loss(PhenomD_coeff_table, theta_intrinsic_list[i], theta_extrinsic, f_uniform, NR_waveform)

        original_losses.append(original_loss)
        new_losses.append(new_loss)
        logdiffs.append(np.log10(new_loss / original_loss))
        pbar.update(1)

In [None]:
# Export rows of (catalog ID, original mismatch, new mismatch) for every
# waveform that went into the *_test lists, i.e. every ID not in
# `catalog_droplist`. The IDs are taken from that same filtered sequence so
# the rows stay aligned when the drop list is non-empty. Pairing with the full
# `catalog_list` would shift the losses onto the wrong IDs and pad with NaN.
evaluated_ids = [c for c in catalog_list if c not in catalog_droplist]
data = pd.concat(
    [pd.DataFrame(evaluated_ids), pd.DataFrame(original_loss_list_test), pd.DataFrame(new_loss_list_test)],
    axis=1,
)
data.to_csv('./mismatch/1neg2pos.dat', sep=' ', index=False, header=False)