# Step 1: Generate gravitational waves

## Import libs

In [None]:
import numpy as np
import pyfstat
import os
import warnings
from pyfstat.utils import get_sft_as_arrays
from tqdm import tqdm
import matplotlib.pyplot as plt
import gc
from joblib import Parallel, delayed
import multiprocessing
warnings.filterwarnings("ignore")

## Config

In [None]:
class CFG:
    """Settings for the parallel data-generation run."""

    # Number of parallel jobs.
    n_parallel_jobs = 5
    # Number of instances produced by each parallel job.
    n_instance_per_job = 3000

## Function - generate gravitational waves

In [None]:
def generate(signal_rate=0.01, parallel_job=0) -> "tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]":
    """Simulate one noise-only / signal-plus-noise pair of SFT magnitudes.

    Two PyFstat simulations are run with identical, randomly drawn source
    parameters: one with a loud injection (``h0 = 1000 * sqrtSX``) and one
    with ``h0 = 0``. The loud run is scaled down by ``signal_rate / 1000``
    and added to the ``h0 = 0`` run, so the injected amplitude can be chosen
    after the fact. A random window of 360 consecutive frequency bins is
    then cut out of every array.

    Args:
        signal_rate: Multiplier applied to the loud H1 data before it is
            added to the H1 noise. The L1 multiplier is this value times a
            random factor in [1/3, 3], clamped to [2**-6.6, 2**-3.3].
        parallel_job: Identifier used as the sub-directory of
            ``PyFstat_data_cache`` that PyFstat writes into, so concurrent
            callers do not share output files.

    Returns:
        ``(frequency, sft_H1_noise, sft_L1_noise, sft_H1_signal,
        sft_L1_signal)``. All are float32 arrays cropped to the same 360
        entries along their first axis; the four SFT arrays hold magnitudes
        of complex values multiplied by 1e22.

    Raises:
        ValueError: If the simulated band has too few frequency bins for a
            360-bin window.
    """
    n_bins = 360        # length of the returned frequency window
    amp_scale = 1e22    # factor applied to the raw SFT values
    h0_boost = 1000     # loud injection amplitude, in units of sqrtSX

    # Randomly drawn source parameters (draw order kept as in the original).
    F0 = np.random.uniform(48, 502)
    f1_sign = 1.0 if np.random.uniform(0, 1) < 0.5 else -1.0
    F1 = f1_sign * 10 ** np.random.uniform(-12, -8)
    Band = 0.33
    Alpha = 2 * np.pi * np.random.rand()
    Delta = np.pi * (0.5 - np.random.rand())
    cosi = 1 - 2 * np.random.rand()
    psi = 0.5 * np.pi * (0.5 - np.random.rand())
    phi = 2 * np.pi * np.random.rand()

    writer_kwargs = {
        "label": "single_detector_gaussian_noise",
        "tstart": 1238166018,
        "duration": 120 * 86400,
        "detectors": "H1,L1",
        "sqrtSX": 5e-24,
        "Tsft": 1800,
        "SFTWindowType": "tukey",
        "SFTWindowBeta": 0.01,
        "outdir": f"PyFstat_data_cache/{parallel_job}",
    }

    def simulate(h0):
        """Run one PyFstat simulation with amplitude ``h0`` and load its SFTs."""
        # NOTE(review): AllSkyInjectionParametersGenerator may replace the
        # Alpha/Delta priors with its own isotropic sky prior - confirm
        # against the installed PyFstat version.
        parameters_generator = pyfstat.AllSkyInjectionParametersGenerator(
            priors={
                "tref": writer_kwargs["tstart"],
                "F0": F0,
                "F1": F1,
                "F2": 0,
                "Band": Band,
                "h0": lambda: h0,
                "Alpha": Alpha,
                "Delta": Delta,
                "cosi": cosi,
                "psi": psi,
                "phi": phi,
            },
        )
        writer = pyfstat.Writer(**writer_kwargs, **parameters_generator.draw())
        writer.make_data()
        return get_sft_as_arrays(writer.sftfilepath)

    # Loud run first, then the h0 = 0 run. Both use the same outdir and
    # label, so the second run's files presumably replace the first's; the
    # first run's data has already been loaded into memory by then.
    # NOTE(review): the loud run uses the same sqrtSX, so it is presumably
    # not strictly noise-free; whatever noise it holds is scaled by ~1/1000.
    frequency, _, sft_pure_signal = simulate(writer_kwargs["sqrtSX"] * h0_boost)
    _, _, sft_noise = simulate(writer_kwargs["sqrtSX"] * 0)

    sft_H1_pure_signal = sft_pure_signal['H1'] * amp_scale / h0_boost * signal_rate

    # signal rate of H1 and L1 can be different
    signal_rate_2 = signal_rate * (3 ** np.random.uniform(-1, 1))
    signal_rate_2 = min(max(signal_rate_2, 2 ** (-6.6)), 2 ** (-3.3))
    sft_L1_pure_signal = sft_pure_signal['L1'] * amp_scale / h0_boost * signal_rate_2

    sft_H1_noise = sft_noise['H1'] * amp_scale
    sft_L1_noise = sft_noise['L1'] * amp_scale

    # signal = pure signal + noise
    sft_H1_signal = sft_H1_pure_signal + sft_H1_noise
    sft_L1_signal = sft_L1_pure_signal + sft_L1_noise

    # Pick a random window of n_bins bins. The "+ 1" together with the
    # exclusive upper bound of randint means index 0 and the final index are
    # never part of the window.
    freq_length = frequency.shape[0] - 1
    if freq_length <= n_bins:
        raise ValueError(
            f"need more than {n_bins + 1} frequency bins to cut a "
            f"{n_bins}-bin window, got {frequency.shape[0]}"
        )
    start = np.random.randint(0, freq_length - n_bins) + 1
    window = slice(start, start + n_bins)

    # from complex to real: keep only the magnitude of the cropped data
    frequency = frequency[window].astype(np.float32)
    sft_H1_noise = np.abs(sft_H1_noise[window]).astype(np.float32)
    sft_L1_noise = np.abs(sft_L1_noise[window]).astype(np.float32)
    sft_H1_signal = np.abs(sft_H1_signal[window]).astype(np.float32)
    sft_L1_signal = np.abs(sft_L1_signal[window]).astype(np.float32)

    return frequency, sft_H1_noise, sft_L1_noise, sft_H1_signal, sft_L1_signal

## Demonstration of generated gravitational waves

In [None]:
# Generate a single example with signal_rate = 2**-3.3.
frequency, H1_noise, L1_noise, H1_signal, L1_signal = generate(signal_rate=2**-3.3)

# Plot the generated H1 data: first the signal-plus-noise array, then the
# noise-only array, each as a grayscale image with a colorbar.
for panel_title, panel_data in (("signal", H1_signal), ("noise", H1_noise)):
    plt.figure(figsize=(10, 5))
    plt.title(panel_title)
    plt.imshow(panel_data, aspect="auto", cmap="gray")
    plt.colorbar()
    plt.show()

# Generate gravitational waves through multiprocessing

In [None]:
# Output directory for the generated .npz files. exist_ok=True makes the
# creation idempotent and avoids the check-then-create race of a separate
# os.path.exists() test.
file_path = "./input/generated-data/"
os.makedirs(file_path, exist_ok=True)

In [None]:
def func(parallel_job, max_attempts=10):
    """Generate and save ``CFG.n_instance_per_job`` instances for one job.

    Each instance is produced by ``generate`` with a signal rate drawn as
    ``2 ** uniform(-6, -3.3)`` and written to
    ``<file_path>signal-<parallel_job>-<idx>.npz`` with the arrays stored
    under the keys ``H1``, ``L1``, ``L1_noise`` and ``H1_noise``.

    A failed generation or save is retried with a freshly drawn signal rate.
    Unlike an unbounded retry, a persistent failure cannot hang the worker
    silently: every failure is reported, and only ``Exception`` subclasses
    are retried, so ``KeyboardInterrupt``/``SystemExit`` still stop the run.

    Args:
        parallel_job: Job identifier; forwarded to ``generate`` and embedded
            in the output file names.
        max_attempts: Maximum number of tries per instance before giving up.

    Raises:
        RuntimeError: If one instance fails ``max_attempts`` times in a row;
            the last underlying error is chained as the cause.
    """
    for idx in tqdm(range(CFG.n_instance_per_job)):
        gc.collect()
        out_name = file_path + "signal-" + f"{parallel_job}-" + str(idx) + ".npz"
        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                # signal rate between 2^-6 and 2^-3.3
                signal_rate = 2 ** np.random.uniform(-6, -3.3)
                _, H1_noise, L1_noise, H1_signal, L1_signal = generate(
                    signal_rate=signal_rate, parallel_job=parallel_job
                )
                np.savez(
                    out_name,
                    H1=H1_signal,
                    L1=L1_signal,
                    L1_noise=L1_noise,
                    H1_noise=H1_noise,
                )
            except Exception as err:
                last_error = err
                # tqdm.write prints without breaking the progress bar.
                tqdm.write(
                    f"[job {parallel_job}] instance {idx}: attempt "
                    f"{attempt}/{max_attempts} failed: {err!r}"
                )
            else:
                break
        else:
            raise RuntimeError(
                f"job {parallel_job}: could not generate instance {idx} "
                f"after {max_attempts} attempts"
            ) from last_error

In [None]:
# Run func once per job id in range(CFG.n_parallel_jobs), using a joblib
# pool sized to the number of CPU cores reported by the OS.
num_cores = multiprocessing.cpu_count()
print("num_cores: ", num_cores)
worker_pool = Parallel(n_jobs=num_cores)
worker_pool(delayed(func)(job_id) for job_id in range(CFG.n_parallel_jobs))