In [None]:
import contextlib
import os

import numpy as np
import matplotlib.pyplot as plt

from gwsnr import GWSNR
from ler.rates import LeR
from ler.utils import append_json, get_param_from_json

In [None]:
# Stand-alone GWSNR instance configured for inner-product SNRs with the
# IMRPhenomXPHM approximant (npool=6 is presumably the worker-pool size).
# NOTE(review): not referenced by ModelGenerator in this notebook section.
gwsnr = GWSNR(
    snr_type='inner_product',
    waveform_approximant='IMRPhenomXPHM',
    npool=6,
)

In [None]:
class ModelGenerator:
    """Collect SNR-balanced samples of spinless BBH events with ``ler``.

    On construction, the generator repeatedly simulates batches of unlensed
    compact-binary events with ``ler.rates.LeR`` (forced to
    ``spin_zero=True, spin_precession=False``). For every detector in
    ``detectors`` it appends an SNR-balanced subset of each batch to
    ``<ler.ler_directory>/snr_<det>_spinless<waveform_approximant>.json``.
    It stops once every detector file holds at least ``size`` events.

    Balancing: the events of a batch are split into the SNR bins
    ``<2, [2,4), [4,6), ..., [14,16), >=16`` (see ``SNR_BIN_EDGES``). Every
    bin below the top one is truncated to at most as many events as the top
    (SNR >= 16) bin of the same batch.

    Parameters
    ----------
    npool : int
        Passed to ``LeR`` as ``npool``.
    size : int
        Minimum number of events to collect per detector file.
    batch_size : int
        Number of events simulated per ``LeR.unlensed_cbc_statistics`` call.
    z_min, z_max : float
        Redshift range passed to ``LeR``.
    detectors : iterable of str or None
        Detector names whose SNR entries are balanced and saved. ``None``
        means ``["L1", "H1", "V1"]``. Each name must be a key of the dict
        returned by ``LeR.unlensed_cbc_statistics`` (keep consistent with
        ``ifos`` if you override it).
    verbose : bool
        Passed to ``LeR`` as ``verbose``.
    create_new : bool
        If True, delete existing output files before collecting.
    **kwargs
        Extra ``LeR``/``gwsnr`` arguments; they override the defaults stored
        in ``self.ler_init_args`` (e.g. ``interpolator_dir``).

    Raises
    ------
    ValueError
        If ``detectors`` is empty.
    """

    # SNR bin edges used for balancing. The last edge opens the top bin
    # (SNR >= 16), whose size caps the size of every other bin.
    SNR_BIN_EDGES = (2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0)

    def __init__(
        self,
        npool=4,
        size=1000,
        batch_size=200000,
        z_min=0.0,
        z_max=5.0,
        detectors=None,
        verbose=True,
        create_new=True,
        **kwargs,  # ler and gwsnr arguments
    ):
        self.npool = npool
        self.size = size
        self.batch_size = batch_size
        self.z_min = z_min
        self.z_max = z_max
        # None instead of a list literal avoids a shared mutable default.
        self.detectors = ["L1", "H1", "V1"] if detectors is None else list(detectors)
        if not self.detectors:
            raise ValueError("detectors must contain at least one detector name")
        self.verbose = verbose

        # Default LeR / gwsnr arguments; caller kwargs override them below.
        # interpolator_dir is deliberately not defaulted here: pass it via
        # **kwargs, otherwise LeR/gwsnr presumably falls back to its own default.
        self.ler_init_args = dict(
            event_type="BBH",
            cosmology=None,
            # gwsnr args
            mtot_min=2.0,
            mtot_max=200,
            ratio_min=0.1,
            ratio_max=1.0,
            mtot_resolution=500,
            ratio_resolution=50,
            sampling_frequency=2048.0,
            waveform_approximant="IMRPhenomXPHM",
            minimum_frequency=20.0,
            snr_type="interpolation",
            psds=None,
            ifos=None,
            create_new_interpolator=False,
            gwsnr_verbose=False,
            multiprocessing_verbose=True,
            mtot_cut=True,
        )
        self.ler_init_args.update(kwargs)

        # generate spinless snr for snr_range and astrophysical range
        self.spinless_waveform_snrs(size=size, batch_size=batch_size, create_new=create_new)

        # TODO: recalculate snr within the snr_range.
        # TODO: find snr in astrophysical range.
        # TODO: save it.

    def spinless_waveform_snrs(self, size=None, batch_size=None, create_new=False):
        """Simulate spinless BBH batches and save SNR-balanced subsets per detector.

        Parameters
        ----------
        size : int or None
            Stop once every detector file holds at least this many events.
            ``None`` uses ``self.size``.
        batch_size : int or None
            Events simulated per ``LeR.unlensed_cbc_statistics`` call.
            ``None`` uses ``self.batch_size``.
        create_new : bool
            Delete the existing output files in ``ler.ler_directory`` first.

        Notes
        -----
        At least one batch is always simulated. If a batch contains no event
        with SNR >= ``SNR_BIN_EDGES[-1]`` for a detector, nothing from that
        batch is kept for that detector. If that happens on every batch, the
        loop never terminates.
        """
        size = self.size if size is None else size
        batch_size = self.batch_size if batch_size is None else batch_size

        # Spin settings are forced so the samples really are spinless, even if
        # the caller passed spin_* options through **kwargs.
        ler_args = dict(self.ler_init_args, spin_zero=True, spin_precession=False)
        ler = LeR(
            npool=self.npool,
            z_min=self.z_min,
            z_max=self.z_max,  # be careful with this value
            verbose=self.verbose,
            **ler_args,
        )
        ler.batch_size = batch_size

        approximant = ler_args.get("waveform_approximant", "IMRPhenomXPHM")
        paths = {
            det: self._snr_output_path(ler.ler_directory, det, approximant)
            for det in self.detectors
        }

        # Delete in the same directory that is written to below.
        if create_new:
            for path in paths.values():
                with contextlib.suppress(FileNotFoundError):
                    os.remove(path)

        len_ = 0
        print(f'total event to collect: {size}\n')
        while len_ < size:
            # Silence LeR's own stdout output for each batch.
            with contextlib.redirect_stdout(None):
                unlensed_param = ler.unlensed_cbc_statistics(size=ler.batch_size, resume=False)

            for det, path in paths.items():
                # Select from the untouched batch for every detector, so one
                # detector's balancing does not filter the next detector's input.
                keep = self._balanced_snr_indices(unlensed_param[det], self.SNR_BIN_EDGES)
                subset = {key: np.asarray(value)[keep] for key, value in unlensed_param.items()}
                append_json(path, subset, replace=False)

            # Continue until the detector with the fewest saved events reaches size.
            len_ = self._collected_events(paths)
            print(f"Collected number of events: {len_}")

    @staticmethod
    def _balanced_snr_indices(snr, edges=SNR_BIN_EDGES):
        """Return indices of an SNR-balanced subset of ``snr``.

        The bins are ``snr < edges[0]``, ``edges[i] <= snr < edges[i+1]``, and
        ``snr >= edges[-1]`` (the top bin). Every bin except the top one keeps
        only its first ``len(top bin)`` indices. The result is the
        concatenation of the bins in ascending-SNR order, each bin in original
        index order. NaN SNRs fall in no bin and are dropped; ``+inf`` falls
        in the top bin.

        Parameters
        ----------
        snr : array_like
            One-dimensional SNR values.
        edges : sequence of float
            Non-empty, increasing bin edges.

        Returns
        -------
        numpy.ndarray
            Integer indices into ``snr``.
        """
        snr = np.asarray(snr)
        edges = tuple(edges)
        top = np.flatnonzero(snr >= edges[-1])
        cap = top.size
        groups = [np.flatnonzero(snr < edges[0])]
        groups.extend(
            np.flatnonzero((snr >= lo) & (snr < hi))
            for lo, hi in zip(edges[:-1], edges[1:])
        )
        # Cap the lower bins at the size of the top bin, then append the top bin.
        groups = [group[:cap] for group in groups]
        groups.append(top)
        return np.concatenate(groups)

    @staticmethod
    def _snr_output_path(directory, det, approximant):
        """Return the JSON path storing spinless-waveform samples for ``det``."""
        return os.path.join(directory, f"snr_{det}_spinless{approximant}.json")

    @staticmethod
    def _collected_events(paths):
        """Return the smallest number of saved events across the detector files.

        ``paths`` maps a detector name to its JSON file. The count for a file
        is ``len`` of the entry keyed by that detector name.
        """
        return min(len(get_param_from_json(path)[det]) for det, path in paths.items())

    