# Define GW generator

In [1]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
import math
from scipy import integrate
import scipy.optimize as optim

from tqdm import tqdm
import random
from numpy.fft import fft, fftfreq, ifft

In [2]:
C = 299792458.
G = 6.67e-11
MPC = 3.0857e22
MSUN = 1.989e30
AU = 1.496e11
YEAR = 31558149.763545603

In [3]:
import pycbc.waveform as wf
from scipy.interpolate import InterpolatedUnivariateSpline


# Converting between m1, q, M, eta, Mc, mu, where m1 >= m2, q = m2 / m1 <= 1, eta <= 0.25
def eta_q(q):
    return q / (1. + q) ** 2

def q_eta(eta):
    if eta == 1. / 4.:
        return 1.
    else:
        return (1. - 2. * eta - np.sqrt(1. - 4. * eta)) / 2. / eta

def M_m1_q(m1, q):
    return m1 * (1. + q)

def Mc_m1_q(m1, q):
    m2 = m1 * q
    return (m1 * m2) ** (3. / 5.) / (m1 + m2) ** (1. / 5.)

def mu_m1_q(m1, q):
    m2 = m1 * q
    return m1 * m2 / (m1 + m2)

def q_Mc_mu(Mc, mu):
    temp = (Mc / mu) ** (5. / 2.)
    return np.abs((temp - 2.) - np.sqrt(temp * np.abs(temp - 4.))) / 2

def m1_Mc_mu(Mc, mu):
    q = q_Mc_mu(Mc, mu)
    return mu * (1. + q) / q

def m1_Mc_q(Mc, q):
    return Mc / (q ** 3 / (1. + q)) ** 0.2

def mu_Mc_q(Mc, q):
    m1 = m1_Mc_q(Mc, q)
    return q / (1. + q) * m1

def Mc_M_eta(M, eta):
    return eta ** 0.6 * M

def fISCO(parWF): 
    Mc, mu, tc, phic, D, inc = parWF
    m1 = m1_Mc_mu(Mc, mu)
    q = q_Mc_mu(Mc, mu)
    M = m1 * (1. + q)
    return C ** 3 / 6. ** (3. / 2.) / np.pi / G / M / MSUN

def tf(f, parWF): # leading order PN
    Mc, mu, tc, phic, D, inc = parWF
    Mc = Mc * MSUN
    return tc - (8. * np.pi / 5. * (5. * G * Mc / C ** 3) ** (5. / 8.) * f) ** (-8. / 3.) 

def tISCO(parWF): # leading order PN
    return tf(fISCO(parWF), parWF)


SWIGLAL standard output/error redirection is enabled in IPython.
This may lead to performance penalties. To disable locally, use:

with lal.no_swig_redirect_standard_output_error():
    ...

To disable globally, use:

lal.swig_redirect_standard_output_error(True)

Note however that this will likely lead to error messages from
LAL functions being either misdirected or lost when called from
Jupyter notebooks.


import lal

  import lal as _lal


In [4]:
class MBHBWaveform():
    def __init__(self, approx_method = 'SEOBNRv4_opt', f_lower = None): 
        self.approx_method = approx_method
        self.f_lower = f_lower
        
    def __call__(self, Mc, q, a1, a2, tc, phic, D, inc, psi, T=1., dt=10.):
        """ 
        Mc in [MSUN]
        tc in [s]
        D in [Mpc]
        """
        m1 = m1_Mc_q(Mc, q)
        m2 = m1 * q
        Tobs = tc / YEAR
        if self.f_lower == None: # calculate the lower limit of frequency to the leading order
            f_lower = 1.75e-5 * (Mc / 1e6) ** (-5. / 8.) * (Tobs / 10.) ** (-3. / 8.)
            f_lower = f_lower / 2. # for safty
        else:
            f_lower = self.f_lower
              
        # source frame waveform data 
        hp, hc = wf.get_td_waveform(approximant = self.approx_method, mass1 = m1, mass2 = m2, \
                                    spin1z = a1, spin2z = a2, coa_phase = phic, inclination = inc, \
                                    delta_t = dt, f_lower = f_lower, distance = D)
        hp, hc = hp.trim_zeros(), hc.trim_zeros()
        hSp_data = np.array(hp)
        hSc_data = np.array(hc)
        t_data = np.array(hp.sample_times) + tc
        
        
        # get interpolation functions
        hSp_func = InterpolatedUnivariateSpline(x=t_data, y=hSp_data, k=5, ext='zeros')
        hSc_func = InterpolatedUnivariateSpline(x=t_data, y=hSc_data, k=5, ext='zeros')
        
        # convert to SSB frame waveform 
        t = np.arange(0.0, T * YEAR, dt)
        cos2psi = np.cos(2.0 * psi)
        sin2psi = np.sin(2.0 * psi)

        hSp = hSp_func(t)
        hSc = hSc_func(t)

        hp = hSp * cos2psi - hSc * sin2psi
        hc = hSp * sin2psi + hSc * cos2psi

        return hp + 1j * hc

In [5]:
def CutJumpPointsForTDI2(chans, Mc, tc, t):
    """  
    only suitable for Mc in [1e5, 1e7], q in [0.1, 1]
    t, tc in [s]
    Mc in [MSUN]
    """ 
    cut_idx = np.where(t >= (1.4e-8 * Mc * DAY + tc))[0][0] # empirical value
    new_chans = []
    for chan in chans:
        chan[cut_idx:] = 0.
        new_chans.append(chan)
    return new_chans 

In [6]:
import h5py
from fastlisaresponse import pyResponseTDI, ResponseWrapper
# from astropy import units as un

YEAR = 31558149.763545603
SECOND = 1 / YEAR
DAY = 24 * 60 * 60 * SECOND

In [7]:
use_gpu = True
T = 30 * DAY  # duration of the signal
t0 = 10000.0  # time at which signal starts (chops off data at start of waveform where information is not correct)
sampling_frequency = 0.1
dt = 1 / sampling_frequency
total_len = 16000

In [8]:
# order of the langrangian interpolation
order = 25

# orbit_file_esa = "../orbit_files/esa-trailing-orbits.h5"
orbit_file_esa = "TaijiEqualArm.hdf5"

orbit_kwargs_esa = dict(orbit_file=orbit_file_esa)

# 1st or 2nd or custom (see docs for custom)
tdi_gen = "2nd generation"
# tdi_gen = '1st generation'

index_lambda = 9
index_beta = 10

tdi_kwargs_esa = dict(
    orbit_kwargs=orbit_kwargs_esa, order=order, tdi=tdi_gen, tdi_chan="AET",
)

mbhb = MBHBWaveform()

mbhb_taiji = ResponseWrapper(
    mbhb,
    T + 2*(t0)*SECOND,
    dt,
    index_lambda,
    index_beta,
    t0=t0,
    flip_hx=False,  # set to True if waveform is h+ - ihx
    use_gpu=use_gpu,
    remove_sky_coords=True,  # True if the waveform generator does not take sky coordinates
    is_ecliptic_latitude=True,  # False if using polar angle (theta)
    remove_garbage=True,  # removes the beginning of the signal that has bad information
    **tdi_kwargs_esa,
)

In [9]:
class Constant:
    MSUN_SI = 1.98848e30
    YRSID_SI = 31558149.763545603
    AU_SI = 149597870700.0
    C_SI = 299792458.0
    G_SI = 6.674080e-11
    GMSUN = 1.3271244210789466e20
    MTSUN_SI = 4.925491025873693e-06
    MRSUN_SI = 1476.6250615036158
    PC_SI = 3.0856775814913674e16
    PI = 3.141592653589793238462643383279502884
    PI_2 = 1.570796326794896619231321691639751442
    PI_3 = 1.047197551196597746154214461093167628
    PI_4 = 0.785398163397448309615660845819875721
    SQRTPI = 1.772453850905516027298167483341145183
    SQRTTWOPI = 2.506628274631000502415765284811045253
    INVSQRTPI = 0.564189583547756286948079451560772585
    INVSQRTTWOPI = 0.398942280401432677939946059934381868
    GAMMA = 0.577215664901532860606512090082402431
    SQRT2 = 1.414213562373095048801688724209698079
    SQRT3 = 1.732050807568877293527446341505872367
    SQRT6 = 2.449489742783178098197284074705891392
    INVSQRT2 = 0.707106781186547524400844362104849039
    INVSQRT3 = 0.577350269189625764509148780501957455
    INVSQRT6 = 0.408248290463863016366214012450981898
    F0 = 3.168753578687779e-08
    Omega0 = 1.9909865927683788e-07
    Omegam = 0.3175
    Omegalam = 0.6825
    H0 = 67.1
    H0_SI = H0 * 1000 / (1e6 * PC_SI)
    EPS = 1e-8
    # L_SI = 2.5e9
    # eorbit = 0.004824185218078991
    # ConstOmega = 1.99098659277e-7

In [10]:
class Cosmology(object):
    @staticmethod
    def H(zp, w):
        fn = 1.0 / (Constant.H0 * math.sqrt(Constant.Omegam * math.pow(1.0 + zp, 3.0) + Constant.Omegalam * math.pow(1.0 + zp, 3.0 * w)))
        return fn

    @staticmethod
    def DL(zup, w):
        """
        Usage: DL(3,w=0)[0]
        """
        pd = integrate.quad(Cosmology.H, 0.0, zup, args=(w))[0]
        res = (1.0 + zup) * pd  # in Mpc
        return res * Constant.C_SI * 1.0e-3, pd * Constant.C_SI * 1.0e-3

    @staticmethod
    def findz(zm, dlum, ww):
        """f-n needed for finding z for given DL, w"""
        dofzm = Cosmology.DL(zm, ww)
        return dlum - dofzm[0]

    @staticmethod
    def zofDl(DL, w, tolerance):
        """computes z(DL, w), Assumes DL in Mpc"""
        if tolerance > 1.0e-4:
            tolerance = 1.0e-4
        zguess = DL / 6.6e3
        zres = optim.fsolve(Cosmology.findz, zguess, args=(DL, 0.0), xtol=tolerance)
        return zres

# Generate {numsample} data

In [11]:
# constants
C = 299792458.
G = 6.67e-11
MPC = 3.0857e22
MSUN = 1.989e30
AU = 1.496e11

F_LASER = 3e14 

# Taiji params
L_TJ = 3e9
SACC_TJ = 3e-15
SOPT_TJ = 8e-12

# LISA params
L_LISA = 2.5e9
SACC_LISA = 3e-15
SOPT_LISA = 1e-11

# Cosmology
OmegaM = 0.3111
H0 = 67.66

In [12]:
class PSD_y():
    """
    PSD in the fractional frequency difference unit
    """
    def __init__(self, sacc = SACC_TJ, sopt = SOPT_TJ, L = L_TJ): # default Taiji [sacc] = acceleration, [sopt] = distance
        self.sa = sacc
        self.so = sopt
        self.L = L

    def PSD_Sa(self, f): 
        u = 2. * np.pi * f * self.L / C
        return (self.sa * self.L / u / C ** 2) ** 2 * (1. + (0.4e-3 / f) ** 2) * (1. + (f / 8e-3) ** 4)

    def PSD_So(self, f):
        u = 2. * np.pi * f * self.L / C
        return (u * self.so / self.L) ** 2 * (1. + (2e-3 / f) ** 4)

    def PSD_X(self, f): # self.sa, self.so are the asd of acc and opt noise, WTN and wang and Vallisneri
        u = 2. * np.pi * f * self.L / C
        Sa = self.PSD_Sa(f)
        So = self.PSD_So(f)
        return Sa * (8. * (np.sin(2. * u)) ** 2 + 32. * (np.sin(u)) ** 2) \
                + 16. * So * (np.sin(u)) ** 2
    
    def PSD_X2(self, f):
        u = 2. * np.pi * f * self.L / C
        Sa = self.PSD_Sa(f)
        So = self.PSD_So(f)
        return 64. * (np.sin(2. * u)) ** 2 * (np.sin(u)) ** 2 * (So + (3. + np.cos(2. * u)) * Sa)        

    def PSD_A(self, f): # WTN
        u = 2. * np.pi * f * self.L / C
        Sa = self.PSD_Sa(f)
        So = self.PSD_So(f)
        return 8. * So * (2. + np.cos(u)) * (np.sin(u)) ** 2 \
                + 16. * Sa * (3. + 2. * np.cos(u) + np.cos(2. * u)) * (np.sin(u)) ** 2
    
    def PSD_A2(self, f):
        u = 2. * np.pi * f * self.L / C
        Sa = self.PSD_Sa(f)
        So = self.PSD_So(f)
        return (8. * So * (2. + np.cos(u)) * (np.sin(u)) ** 2 \
                + 16. * Sa * (3. + 2. * np.cos(u) + np.cos(2. * u)) * (np.sin(u)) ** 2) * 4. * np.sin(2. * u) ** 2

    def PSD_T(self, f): # WTN
        u = 2. * np.pi * f * self.L / C
        Sa = self.PSD_Sa(f)
        So = self.PSD_So(f)
        return 16. * So * (1. - np.cos(u)) * (np.sin(u)) ** 2 \
                + 128. * Sa * (np.sin(u)) ** 2 * (np.sin(u / 2.)) ** 4

In [13]:
from numpy.fft import fft, fftfreq, ifft
def NoiseFromPSD(psd_func, t_arr):
    """
    generate noise in the time domain at time array t_arr
    psd_func should be a function which can be called with:  psd_func(freq array) 
    """
    tsample = t_arr[1] - t_arr[0]
    N = len(t_arr)
    sigma0 = np.sqrt(1. / 2. / tsample)
    n0 = np.random.normal(0, sigma0, N)
    n0f = fft(n0)
    f_arr = fftfreq(N, d = tsample)
    asd_arr = np.sqrt(psd_func(np.abs(f_arr[1:])))
    asd_arr = np.insert(asd_arr, 0, 0)  
    n1f = n0f * asd_arr
    n1 = np.real(ifft(n1f))
    return n1

In [14]:
t = np.arange(0, total_len, 1) / sampling_frequency
PSD = PSD_y()
psd_func = PSD.PSD_A2

In [15]:
# signal parameters
f_arr_signal = fftfreq(total_len, d = 10)
psd_signal = psd_func(f_arr_signal)
psd_signal = psd_signal[:total_len // 2 + 1]

  return (self.sa * self.L / u / C ** 2) ** 2 * (1. + (0.4e-3 / f) ** 2) * (1. + (f / 8e-3) ** 4)
  return (u * self.so / self.L) ** 2 * (1. + (2e-3 / f) ** 4)
  return (u * self.so / self.L) ** 2 * (1. + (2e-3 / f) ** 4)
  + 16. * Sa * (3. + 2. * np.cos(u) + np.cos(2. * u)) * (np.sin(u)) ** 2) * 4. * np.sin(2. * u) ** 2


In [16]:
def tukey(M, alpha=0.5):
    n = np.arange(0, M)
    width = int(np.floor(alpha * (M - 1) / 2.0))
    n1 = n[0 : width + 1]
    n2 = n[width + 1 : M - width - 1]
    n3 = n[M - width - 1 :]

    w1 = 0.5 * (1 + np.cos(np.pi * (-1 + 2.0 * n1 / alpha / (M - 1))))
    w2 = np.ones(n2.shape)
    w3 = 0.5 * (1 + np.cos(np.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / (M - 1))))
    w = np.concatenate((w1, w2, w3))

    return np.array(w[:M])

In [17]:
def whiten_data(signal, psd):
    idx = np.argwhere(psd >= 1e-45)
    win = tukey(len(signal), alpha=1.0 / 8.0)
    xf = np.fft.rfft(win * signal)
    invpsd = np.zeros(psd.size)
    invpsd[idx] = 1.0 / psd[idx]
    xf *= np.sqrt(invpsd)
    xf[0] = 0.0
    x = np.fft.irfft(xf)
    return x

In [18]:
def get_whiten_mf_snr(data, T_obs, fs=0.1, fmin=1e-5):
    """
    computes the snr of a signal given a PSD starting from a particular frequency index
    """
    N = int(T_obs * fs)
    df = 1.0 / T_obs
    dt = 1.0 / fs
    fidx = int(fmin / df)

    win = tukey(N, alpha=1.0 / 8.0)

    xf = np.fft.rfft(data * win) * dt
    invpsd = np.ones(xf.size)

    SNRsq = (4.0 * np.sum((np.abs(xf[fidx:]) ** 2) * invpsd[fidx:]) * df)
    return np.sqrt(SNRsq)

In [19]:
# mkdir for saving data
import os
if not os.path.exists('../data/GW/clean'):
    os.makedirs('../data/GW/clean')
if not os.path.exists('../data/GW/noisy'):
    os.makedirs('../data/GW/noisy')

In [20]:
numsample = 10000 # modify this to contral num

In [21]:
for i in tqdm(range(numsample)):
    Mc=np.random.uniform(1e6, 1e8)
    q=np.random.uniform(0.01, 1)
    a1=np.random.uniform(-0.99, 0.99)
    a2=np.random.uniform(-0.99, 0.99)
    # we set the tc in half of the total time
    tc = 15*24*60*50 + 10000
    tc_idx = int(tc / dt)
    phic=np.random.uniform(0, 2*np.pi)
    inc=np.random.uniform(0, np.pi)
    psi=np.random.uniform(0, np.pi)
    lam = np.random.uniform(0, 2 * np.pi) 
    beta = np.random.uniform(-np.pi / 2, np.pi / 2)

    # distance
    dgpc = np.random.uniform(10, 15)
    D = dgpc*1000

    chans = mbhb_taiji(Mc, q, a1, a2, tc, phic, D, inc, psi, lam, beta, dt=dt)

    # cut the error data
    # times = np.arange(len(chans[0])) * dt + t0
    # chans = CutJumpPointsForTDI2(chans, Mc, tc, times)
    
    chans = cp.array(chans[0])
    # segment (we only need tc 16000s)
    chans = chans[tc_idx-12000:tc_idx+4000]


    chans_numpy = chans.get()
    raw_signal = chans_numpy
    raw_noise = NoiseFromPSD(psd_func, t)
    whiten_signal = whiten_data(chans_numpy, psd_signal)
    whiten_noise = whiten_data(raw_noise, psd_signal)

    # rescale to target snr
    now_snr = get_whiten_mf_snr(whiten_signal, total_len*10)
    target_snr = 50

    # log in tqdm

    whiten_signal = whiten_signal * target_snr / now_snr

    np.save(f'../data/GW/clean/{i}.npy', whiten_signal)
    np.save(f'../data/GW/noisy/{i}.npy', whiten_signal + whiten_noise)


100%|██████████| 10000/10000 [1:16:20<00:00,  2.18it/s]
