# Example use-case: LISA UCB convolution of double white-dwarf binaries at the current day

This example fleshes out the steps required to estimate the population of observable double white dwarf systems in the LISA band. A relevant study is https://arxiv.org/abs/2405.20484.

Several ingredients are necessary here:
- population-synthesis results that contain white dwarfs
- a Milky Way star formation rate history model
- a method to evolve double white-dwarf binaries under the influence of gravitational-wave radiation.

Convolution-by-sampling was developed especially for this project, as we want to 'generate' double white dwarf systems at a certain lookback time, and evolve them (through gravitational-wave radiation) to the present day.

The convolution broadly is done as follows:
- In a given lookback-time bin we calculate the total mass formed into stars.
- We use that to generate double white dwarf systems (number = mass_formed * yield-per-mass-formed).
- We assign a birth time to each of these systems, drawn between the edges of the lookback-time bin.
- We 'evolve' these systems up to the current day under the influence of gravitational-wave radiation. In this example we use LEGWORK ([Wagg et al. 2022](https://ui.adsabs.harvard.edu/abs/2022ApJS..260...52W/abstract)).
- We filter out certain systems (in particular those that, at the current day, are not in the LISA waveband).
- We calculate the detectability (here: the signal-to-noise ratio) of the remaining systems, based on their position in the Milky Way (either randomly assigned or motivated by a spatially defined SFH) and their system properties.
- We use this information to predict the observable population of DWD systems.
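The steps above can be sketched for a single lookback-time bin in plain NumPy. This is a toy illustration, not the package's implementation: the Poisson draw of the number of systems, the yields, and the delay times are all hypothetical assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical toy inputs (not taken from the real dataset)
sfr = 6.0                                       # star-formation rate in Msun / yr
bin_lower, bin_upper = 1.0e9, 1.25e9            # lookback-time bin edges in yr
yield_per_msun = np.array([2e-5, 1e-5, 5e-6])   # normalized yield of each DWD system (per Msun formed)
delay_time = np.array([0.5e9, 1.1e9, 3.0e9])    # time between birth and DWD formation, in yr

# 1) Total stellar mass formed in this bin
mass_formed = sfr * (bin_upper - bin_lower)

# 2) Expected number of each system; draw integer counts (Poisson draw is an assumption)
expected_counts = mass_formed * yield_per_msun
counts = rng.poisson(expected_counts)
sampled_indices = np.repeat(np.arange(len(yield_per_msun)), counts)

# 3) Assign birth (formation) lookback times uniformly within the bin
formation_lookback = rng.uniform(bin_lower, bin_upper, size=sampled_indices.size)

# 4) Lookback time at which the DWD forms; non-positive values lie in the future
event_lookback = formation_lookback - delay_time[sampled_indices]
keep = event_lookback > 0
sampled_indices, event_lookback = sampled_indices[keep], event_lookback[keep]
```

System 2 never survives here, because its delay time exceeds the largest possible formation lookback time of this bin.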

In this notebook I show how to set up all the necessary pieces and how to put them together. We start with some general imports as usual, then load the population-synthesis data, set up the `sfr_dict`, design the post-convolution function, and finally combine everything and run the convolution.

In [1]:
import os
import json
import time
import copy
import astropy.units as u
import legwork as lw
import numpy as np
import astropy.constants as const
import pandas as pd
import pkg_resources
import h5py

from syntheticstellarpopconvolve import convolve, default_convolution_config, default_convolution_instruction
from syntheticstellarpopconvolve.general_functions import temp_dir, generate_boilerplate_outputfile
from syntheticstellarpopconvolve.convolution_by_sampling import (
    select_dict_entries_with_new_indices,
)

from syntheticstellarpopconvolve.usecase_notebook_utils.usecase_lisa_utils import get_mass_norm, sample_distances_simple, get_period, sample_distances_interpolated

TMP_DIR = temp_dir("code", "convolve_stochastically", clean_path=True)

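# The flag below allows the user to run this notebook without the full data or star formation rate.
# Set EXAMPLE_USECASE_UCB_VERSION to e.g. "1" or "true" to run the full version.
FULL_VERSION = os.getenv("EXAMPLE_USECASE_UCB_VERSION", "").lower() in ("1", "true", "yes")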


Next, we load the data. Here we use data in the BinCodex T0 format, sourced from SeBa simulations, which are Monte-Carlo based (as opposed to grid-based). We select only double white-dwarf systems and assign them a `normalized_yield`.

In [2]:
# Load the data
data_filename = os.getenv('EXAMPLE_DATA_USECASE_UCB_FILENAME') if FULL_VERSION else None
example_usecase_UCB_events_filename = data_filename if data_filename is not None else pkg_resources.resource_filename(
    "syntheticstellarpopconvolve",
    "example_data/example_BinCodex_dwd.h5"
)
example_usecase_UCB_events_data = pd.read_hdf(example_usecase_UCB_events_filename, key='T0')

##
# Determine the normalized_yield and select only systems that are double white dwarfs

# Get the mass normalisation
mass_normalisation_fiducial = get_mass_norm(IC_model="fiducial", binary_fraction=0.5)

# Set the normalised yield (per unit of stellar mass formed)
example_usecase_UCB_events_data["normalized_yield"] = 1 / mass_normalisation_fiducial

# Query the dataset to select the formation of the WDs.
# The codes are numeric; checking their leading digit is easiest on string versions of the columns.
example_usecase_UCB_events_data["str_event"] = example_usecase_UCB_events_data["event"].astype(str)
example_usecase_UCB_events_data["str_type1"] = example_usecase_UCB_events_data["type1"].astype(str)
example_usecase_UCB_events_data["str_type2"] = example_usecase_UCB_events_data["type2"].astype(str)

# Select the type-changing events (event codes starting with '1'). Any type change will do.
# The .str accessor requires the python query engine.
wd_binaries = example_usecase_UCB_events_data.query("str_event.str.startswith('1')", engine="python")

# The type should change to a WD type (and the other component should already be one)
wd_binaries = wd_binaries.query("str_type1.str.startswith('2')", engine="python")
wd_binaries = wd_binaries.query("str_type2.str.startswith('2')", engine="python")

# Delete the string versions of the columns again
wd_binaries = wd_binaries.drop(columns=["str_event", "str_type1", "str_type2"])

# Delete the original dataframe to free memory
del example_usecase_UCB_events_data
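The string-prefix queries above can also be written with boolean masks, which avoids the temporary `str_*` columns. A minimal sketch on a hypothetical toy table, assuming the code conventions stated in the comments above (event codes starting with `1` mark a type change, type codes starting with `2` are white dwarfs); the numeric codes in the toy table are made up for illustration:

```python
import pandas as pd

# Hypothetical toy rows mimicking the columns used above
toy = pd.DataFrame({
    "event": [11, 13, 31, 12],
    "type1": [21, 22, 22, 14],
    "type2": [22, 23, 21, 22],
})

# Type-changing events: event code starts with "1"
is_type_change = toy["event"].astype(str).str.startswith("1")

# Both components are white dwarfs: type codes start with "2"
both_wd = (
    toy["type1"].astype(str).str.startswith("2")
    & toy["type2"].astype(str).str.startswith("2")
)

toy_dwd = toy[is_type_change & both_wd]
```

Row 2 is dropped because its event code does not start with `1`, and row 3 because its first component is not a white dwarf.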

Let's now set up a star formation rate for the Milky Way. There are many estimates and descriptions of increasing complexity, but for simplicity we take a constant star formation rate of 6 Msun/yr, which over 10 Gyr forms about 6e10 Msun, roughly the stellar mass of the Milky Way. We ignore metallicity here. In the reduced (non-`FULL_VERSION`) run the rate is scaled down by a factor of 10.

In [3]:
sfr_dict = {}

# Factor by which the star formation rate is reduced in the reduced (non-FULL_VERSION) run
scale = 1e-1

# Set up the lookback time bins, spanning 0 to 10 Gyr (linspace includes the end point)
if FULL_VERSION:
    sfr_dict["lookback_time_bin_edges"] = (np.linspace(0, 10, 101) * u.Gyr).to(u.yr)
else:
    sfr_dict["lookback_time_bin_edges"] = (np.linspace(0, 10, 11) * u.Gyr).to(u.yr)

# Constant star formation rate of 6 Msun/yr in every bin
sfr_dict["starformation_rate_array"] = (6 * u.Msun / u.yr) * np.ones(sfr_dict["lookback_time_bin_edges"].shape[0] - 1)

if not FULL_VERSION:
    sfr_dict["starformation_rate_array"] *= scale
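A quick arithmetic check of the mass budget, in plain NumPy with times in yr and masses in Msun. Note that `np.arange` excludes its end point, so edges built with `np.arange(0., 10., 1)` stop at 9 Gyr, whereas `np.linspace(0, 10, 11)` includes 10 Gyr.

```python
import numpy as np

sfr = 6.0                                          # Msun / yr
bin_edges_yr = np.linspace(0.0, 10.0, 11) * 1e9    # 1 Gyr bins spanning 0 to 10 Gyr
bin_widths = np.diff(bin_edges_yr)

mass_per_bin = sfr * bin_widths                    # Msun formed in each bin
total_mass = mass_per_bin.sum()                    # about 6e10 Msun
total_mass_reduced = 0.1 * total_mass              # reduced run with scale = 1e-1

# For comparison: np.arange excludes the end point, so these edges stop at 9 Gyr
arange_edges_gyr = np.arange(0.0, 10.0, 1.0)
```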

With the star formation information configured, we can design the actual machinery of this calculation: the post-convolution step.

Convolution by sampling 'generates' systems and returns their indices, which allow us to link back to the system properties. It also passes information about the current lookback-time bin (`time_bin_info_dict`), from which we draw the formation times.
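To illustrate how the indices link back to the data: sampled indices can repeat, and fancy indexing pulls out the properties of each sampled system. The sketch below also converts semi-major axes to orbital frequencies with Kepler's third law, $P = 2\pi\sqrt{a^3/(G(m_1+m_2))}$. We assume this is what the notebook's `get_period` helper does, but this stand-alone version (rounded SI constants, hypothetical toy arrays) is not its implementation.

```python
import numpy as np

G = 6.674e-11      # gravitational constant, m^3 kg^-1 s^-2
MSUN = 1.989e30    # solar mass, kg
RSUN = 6.957e8     # solar radius, m

# Hypothetical per-system property arrays (one entry per input system)
semimajor_axis_rsun = np.array([0.1, 0.5, 2.0])
mass1_msun = np.array([0.6, 0.3, 0.8])
mass2_msun = np.array([0.4, 0.3, 0.7])

# Sampled indices may repeat: the same system can be 'generated' many times
sampled_indices = np.array([2, 0, 0, 1])

a = semimajor_axis_rsun[sampled_indices] * RSUN
m_tot = (mass1_msun[sampled_indices] + mass2_msun[sampled_indices]) * MSUN

# Kepler's third law
period_s = 2 * np.pi * np.sqrt(a**3 / (G * m_tot))
f_orb_hz = 1.0 / period_s
```

For system 0 (0.1 Rsun, 1 Msun in total) this gives a period of roughly 5 minutes, i.e. an orbital frequency of a few mHz, well inside the LISA band.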

The idea in this routine will be to:
- Use the 'generated' system (indices)
- Calculate the time between the formation of the double white dwarf and the current day (lookback time = 0). This is the lookback time of the event, i.e. the formation lookback time minus the delay time. Filter out events with a negative lookback time, which would occur in the future.
- Use this time to 'evolve' the double white dwarf system forward in time under the influence of gravitational-wave radiation (we use LEGWORK).
- Filter out systems that fall outside of the LISA waveband.

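All of this is handled within the post-convolution function, which thus keeps only the systems that are unmerged and within the LISA waveband at the current day. For those systems it then samples Galactic distances, computes their signal-to-noise ratio (SNR) with LEGWORK, and keeps only the systems with SNR > 7.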
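The bookkeeping in the time-filtering step is easy to get wrong: once systems in the future are removed, every per-sample array (the sampled indices, the lookback times, and the local indices used to subselect the results) has to be filtered with the same mask. A minimal NumPy sketch with hypothetical toy numbers (times in Gyr, no units attached):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical inputs for a single time bin (times in Gyr)
bin_edge_lower, bin_size = 0.0, 0.25
delay_time = np.array([0.05, 0.2, 3.0])          # one entry per input system
system_indices = np.array([0, 1, 2, 2, 1, 0])    # sampled systems (may repeat)

formation_lookback = bin_edge_lower + bin_size * rng.random(system_indices.shape)
event_lookback = formation_lookback - delay_time[system_indices]

# Keep only systems whose DWD has already formed (event in the past)
in_past = event_lookback > 0

# Apply the same mask to every per-sample array so they stay aligned,
# then re-number the local indices from zero for the reduced arrays
system_indices = system_indices[in_past]
event_lookback = event_lookback[in_past]
local_indices = np.arange(system_indices.size)
```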

In [11]:
import time
import functools
import logging

from syntheticstellarpopconvolve.usecase_notebook_utils.usecase_lisa_utils import precompute_radial_cdf, sample_distances_interpolated

# Precompute the inverse radial CDF once, so every call of the post-convolution function can reuse it
Hr = 4  # Radial scale length
inverse_cdf = precompute_radial_cdf(Hr)
bound_sample_distances_interpolated = functools.partial(
    sample_distances_interpolated,
    inverse_cdf=inverse_cdf
)


def post_convolution_function(
    config, sfr_dict, data_dict, time_bin_info_dict, convolution_results, convolution_instruction
):
    """
    Post-convolution function that evolves the sampled systems to the current day under
    gravitational-wave radiation and keeps only those that are unmerged, within the LISA
    waveband, and detectable (SNR > 7).

    Local indices are used to select the entries of `convolution_results`, and Alexey's
    distance sampler is used to sample the distances.

    To keep this step fast, distances are only sampled for the systems that survive the
    frequency cuts:
    - evolve all systems with dummy distances
    - filter on orbital frequency (merged or outside the LISA waveband)
    - sample distances for the remaining systems
    - compute the SNR and filter on it
    """
    snr_threshold = 7

    start = time.time()

    ######
    # Unpack data: the sampled indices link back to the data of the systems
    start_unpack = time.time()
    system_indices = convolution_results["sampled_indices"]
    local_indices = np.arange(len(system_indices))
    stop_unpack = time.time()

    ######
    # Draw formation lookback times uniformly within this bin. The DWD forms a delay
    # time later, i.e. at a smaller lookback time.
    formation_lookback_times = (
        np.random.random(system_indices.shape) * time_bin_info_dict["bin_size"]
        + time_bin_info_dict["bin_edge_lower"]
    )
    event_lookback_times = formation_lookback_times - data_dict["delay_time"][system_indices]

    # Filter out systems whose DWD forms in the future. Every per-system array
    # must be filtered with the same mask to keep them aligned.
    past_mask = event_lookback_times > 0
    number_of_future_events = past_mask.size - past_mask.sum()

    system_indices = system_indices[past_mask]
    event_lookback_times = event_lookback_times[past_mask]
    convolution_results = select_dict_entries_with_new_indices(
        sampled_data_dict=convolution_results,
        new_indices=local_indices[past_mask],
    )
    # Re-number the local indices for the reduced arrays
    local_indices = np.arange(len(system_indices))

    config["logger"].info(f"filtered out {number_of_future_events} future events")

    ######
    # Select the system properties using the indices
    start_readout = time.time()
    sma = data_dict["semimajor_axis"][system_indices] * u.Rsun
    m_1 = data_dict["mass1"][system_indices] * u.Msun
    m_2 = data_dict["mass2"][system_indices] * u.Msun
    eccentricity = data_dict["eccentricity"][system_indices]
    periods = get_period(sma, m_1, m_2)
    f_orb_i = (1 / periods).to(u.Hz)
    # Dummy distances: the real distances are only sampled after the frequency cuts below
    dummy_dist = np.ones(sma.shape) * u.kpc
    stop_readout = time.time()

    #########
    # Evolve the systems to the current day under GW radiation with LEGWORK
    start_GW = time.time()
    sources = lw.source.Source(
        m_1=m_1,
        m_2=m_2,
        ecc=eccentricity,
        f_orb=f_orb_i,
        dist=dummy_dist,
        interpolate_g=True,
    )
    sources.evolve_sources(event_lookback_times)
    f_orb_now = sources.f_orb  # current-day orbital frequencies
    convolution_results["f_orb_now"] = f_orb_now  # store them in the result dict
    stop_GW = time.time()

    ####
    # Categorise and filter the systems
    start_mask = time.time()

    # Systems with an orbital frequency of 100 Hz or more are treated as merged
    merged_frequency = 1e2 * u.Hz
    unmerged_mask = f_orb_now < merged_frequency
    config["logger"].info(
        f"Of the total of {len(local_indices)} systems {np.count_nonzero(~unmerged_mask)} are merged by today and {np.count_nonzero(unmerged_mask)} are not"
    )

    # Systems within the LISA frequency passband
    lower_bound_LISA_passband = 1e-5 * u.Hz
    upper_bound_LISA_passband = 1e-1 * u.Hz
    waveband_mask = (f_orb_now >= lower_bound_LISA_passband) & (f_orb_now <= upper_bound_LISA_passband)
    config["logger"].info(
        f"Of the total of {len(local_indices)} systems {np.count_nonzero(waveband_mask)} are within the LISA frequency passband ([{lower_bound_LISA_passband}, {upper_bound_LISA_passband}])"
    )

    # Combine the masks and filter
    combined_mask = unmerged_mask & waveband_mask
    local_indices_unmerged_systems_in_waveband = local_indices[combined_mask]
    config["logger"].info(
        f"Of the total of {len(local_indices)} systems {np.count_nonzero(combined_mask)} are within the LISA frequency passband ([{lower_bound_LISA_passband}, {upper_bound_LISA_passband}]) and are not merged"
    )

    #####################################################
    # Sample distances for the remaining systems only
    start_dist = time.time()
    dist = bound_sample_distances_interpolated(NBin=len(local_indices_unmerged_systems_in_waveband), Hr=Hr)
    stop_dist = time.time()

    #######
    # Compute the SNR and filter out undetectable systems
    sources = lw.source.Source(
        m_1=m_1[local_indices_unmerged_systems_in_waveband],
        m_2=m_2[local_indices_unmerged_systems_in_waveband],
        ecc=eccentricity[local_indices_unmerged_systems_in_waveband],
        f_orb=f_orb_now[local_indices_unmerged_systems_in_waveband],
        dist=dist,
        interpolate_g=True,
    )
    snr = sources.get_snr(verbose=False)
    detectability_mask = snr > snr_threshold
    local_indices_detectable_unmerged_systems_in_waveband = local_indices_unmerged_systems_in_waveband[detectability_mask]

    config["logger"].info(
        f"time-bin {time_bin_info_dict['bin_number']}: Of the total of {len(local_indices_unmerged_systems_in_waveband)} unmerged systems in the LISA waveband {np.count_nonzero(detectability_mask)} have a signal-to-noise ratio (SNR) above {snr_threshold} in an observation period of 4 years"
    )

    stop_mask = time.time()
    stop = time.time()

    config["logger"].info(
        "The post-convolution step took {:.3f}s: {:.3f}s for unpacking, {:.3f}s for readout, {:.3f}s for distance sampling, {:.3f}s for GW integration and {:.3f}s for masking.".format(
            stop - start,
            stop_unpack - start_unpack,
            stop_readout - start_readout,
            stop_dist - start_dist,
            stop_GW - start_GW,
            stop_mask - start_mask,
        )
    )

    #######
    # Return only the detectable, unmerged systems within the LISA passband.
    # `select_dict_entries_with_new_indices` applies the new indices to every entry in the dict.
    convolution_results = select_dict_entries_with_new_indices(
        sampled_data_dict=convolution_results,
        new_indices=local_indices_detectable_unmerged_systems_in_waveband,
    )
    convolution_results["dist"] = dist[detectability_mask]
    convolution_results["snr"] = snr[detectability_mask]

    return convolution_results
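LEGWORK's `evolve_sources` handles eccentric binaries as well. For intuition, circular binaries have a closed-form solution (Peters 1964): integrating $\dot f_{\rm orb} = K f_{\rm orb}^{11/3}$ with $K = \frac{96}{5}(2\pi)^{8/3}(G\mathcal{M}_c/c^3)^{5/3}$ gives $f_{\rm orb}^{-8/3}(t) = f_{\rm orb,0}^{-8/3} - \frac{8}{3}Kt$. The sketch below is a stand-alone illustration, not LEGWORK's implementation; it flags merged systems with `np.inf` (an assumption of this sketch), whereas the post-convolution function treats systems with $f_{\rm orb} \geq 100$ Hz as merged. Constants are rounded SI values; masses and frequencies are hypothetical.

```python
import numpy as np

G = 6.674e-11      # m^3 kg^-1 s^-2
C = 2.998e8        # m / s
MSUN = 1.989e30    # kg
YR = 3.156e7       # s


def chirp_mass(m1, m2):
    """Chirp mass, in the same units as the input masses."""
    return (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2


def evolve_f_orb_circular(f_orb0, m1, m2, t):
    """Orbital frequency (Hz) after time t (s) for circular binaries (masses in kg).

    Merged systems are returned as np.inf.
    """
    k = 96.0 / 5.0 * (2 * np.pi) ** (8 / 3) * (G * chirp_mass(m1, m2) / C**3) ** (5 / 3)
    x = f_orb0 ** (-8 / 3) - (8 / 3) * k * t
    with np.errstate(divide="ignore"):
        return np.where(x > 0, np.abs(x) ** (-3 / 8), np.inf)


# Hypothetical toy systems
m1 = np.array([0.6, 0.3]) * MSUN
m2 = np.array([0.4, 0.3]) * MSUN
f0 = np.array([3e-3, 1e-4])          # initial orbital frequencies in Hz
t = np.array([1e6, 1e9]) * YR        # time to evolve, in s

f_now = evolve_f_orb_circular(f0, m1, m2, t)

# Same categories as in the post-convolution function
merged = ~np.isfinite(f_now)
in_band = (~merged) & (f_now >= 1e-5) & (f_now <= 1e-1)
```

The first toy system merges within its 1 Myr of evolution, while the second chirps up from 0.1 mHz but stays in the LISA band.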

Lastly, we set up the `convolution_config`, which binds everything together. Most of this is similar to the other example notebooks, but here we use a different kind of convolution, namely convolution by sampling (`"convolution_type": "sample"`). Information about this type of sampling is available **here TODO: link**

In [12]:
# Create the output file
output_hdf5_filename = os.path.join(TMP_DIR, "output_hdf5.h5")
generate_boilerplate_outputfile(output_hdf5_filename)

# Store the input data frame in the HDF5 file
wd_binaries.to_hdf(output_hdf5_filename, key="input_data/example_usecase_UCB_events")

# General configuration
convolution_config = copy.copy(default_convolution_config)
convolution_config["output_filename"] = output_hdf5_filename
convolution_config["tmp_dir"] = TMP_DIR
convolution_config["num_cores"] = 1  # change this for any production run
convolution_config["logger"].setLevel(logging.INFO)

# Convolution instructions
convolution_config["convolution_instructions"] = [
    {
        **default_convolution_instruction,
        "convolution_type": "sample",
        "input_data_name": "example_usecase_UCB_events",
        "output_data_name": "example_usecase_UCB_events",
        "post_convolution_function": post_convolution_function,
        "data_column_dict": {
            # required
            "normalized_yield": "normalized_yield",
            "delay_time": {"column_name": "time", "unit": u.Myr},
            # The columns below are required because we use them in the post-convolution function.
            # We can either give their units here, or assign them in the post-convolution function.
            "semimajor_axis": "semiMajor",
            "mass1": "mass1",
            "mass2": "mass2",
            "eccentricity": "eccentricity",
        },
        "multiply_by_sfr_time_binsize": True,
    },
]

# Lookback-time bins of the convolution
convolution_config["convolution_lookback_time_bin_edges"] = np.arange(0, 12, 0.25) * u.Gyr
convolution_config["multiprocessing"] = False

# Store the SFR dict
convolution_config["SFR_info"] = sfr_dict

In [13]:
# convolve
print("starting convolution")
convolve(config=convolution_config)

print("finished convolution")

[convolution_by_sampling.py:127 -       sample_systems ] 2025-04-14 23:41:39,837: Sampling systems. Using yield array [  0.           0.           0.         ... 134.03195135 134.03195135
 134.03195135]


starting convolution


[convolution_by_sampling.py:169 -       sample_systems ] 2025-04-14 23:41:40,657: Sampled (12282219,) systems.
[convolution_by_sampling.py:77 - convolution_by_sampling_post_convolution_hook_wrapper ] 2025-04-14 23:41:40,658: Handling post-convolution function hook call for convolution by sampling
[2101976194.py:67 - post_convolution_function ] 2025-04-14 23:41:40,899: filtered out 11378381 future events



Now that the convolution has finished, we can read out the results. Remember that in this case the bins in which the data are stored indicate the range of lookback times in which the systems were born. All systems have been evolved to the current day, however, so we add the results of all bins together to obtain the complete population of DWD systems that are observable by LISA today.
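Because all bins have already been evolved to the present day, combining them is a plain concatenation. One subtlety is the ordering of the bin keys: the read-out code parses the leading number of each key, which we assume has the form `"<value> <unit>"`. A string sort would misorder such keys, as this toy example with hypothetical per-bin orbital frequencies shows:

```python
import numpy as np

# Hypothetical per-bin results, keyed like the HDF5 groups ("<value> <unit>")
per_bin = {
    "10.125 Gyr": np.array([2e-3]),
    "0.125 Gyr": np.array([1e-3, 4e-3]),
    "2.125 Gyr": np.array([5e-4]),
}

# A plain string sort puts "10.125 Gyr" before "2.125 Gyr"
string_sorted = sorted(per_bin)

# Sort numerically on the leading value instead
keys = sorted(per_bin, key=lambda k: float(k.split(" ")[0]))

# Every bin is already evolved to today, so the current-day population is the union of all bins
f_orb_now_all = np.concatenate([per_bin[k] for k in keys])
```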

In [None]:
# Read back the input data to link the sampled indices to the system properties
df = pd.read_hdf(output_hdf5_filename, key="input_data/example_usecase_UCB_events")

# Per-bin results, concatenated after the loop
f_orb_chunks, ecc_chunks, m1_chunks, m2_chunks, dist_chunks = [], [], [], [], []

# Read out the content of every formation-time bin
with h5py.File(convolution_config["output_filename"], "r") as output_hdf5_file:
    group_key = "output_data/example_usecase_UCB_events/example_usecase_UCB_events/convolution_results"

    # Loop over the formation-time bins in numerical order
    formation_time_bin_keys = sorted(
        output_hdf5_file[group_key].keys(), key=lambda x: float(x.split(" ")[0])
    )
    for formation_time_bin_key in formation_time_bin_keys:
        time_bin_group_key = f"{group_key}/{formation_time_bin_key}"

        # Units of the stored datasets
        unit_dict = json.loads(output_hdf5_file[time_bin_group_key].attrs["units"])
        unit_dict = {key: u.Unit(val) for key, val in unit_dict.items()}

        # Read out the data
        sampled_indices = output_hdf5_file[f"{time_bin_group_key}/sampled_indices"][()]
        dist = output_hdf5_file[f"{time_bin_group_key}/dist"][()] * u.kpc
        f_orb_now = output_hdf5_file[f"{time_bin_group_key}/f_orb_now"][()] * unit_dict["f_orb_now"]

        # Select the system properties with the sampled indices
        selected = df.iloc[sampled_indices]
        f_orb_chunks.append(f_orb_now)
        ecc_chunks.append(selected["eccentricity"].to_numpy())
        m1_chunks.append(selected["mass1"].to_numpy() * u.Msun)
        m2_chunks.append(selected["mass2"].to_numpy() * u.Msun)
        dist_chunks.append(dist)

# Combine the bins
total_forb_array = np.concatenate(f_orb_chunks)
total_ecc_array = np.concatenate(ecc_chunks)
total_m1_array = np.concatenate(m1_chunks)
total_m2_array = np.concatenate(m2_chunks)
total_dist_array = np.concatenate(dist_chunks)

print("Found a total number of {} detectable systems".format(len(total_forb_array)))

#######
# Set up the LEGWORK sources for the combined population
sources = lw.source.Source(
    m_1=total_m1_array,
    m_2=total_m2_array,
    ecc=total_ecc_array,
    f_orb=total_forb_array,
    dist=total_dist_array,
    interpolate_g=True,
)
snr = sources.get_snr(verbose=True)
print("Number of systems with SNR > 7: {}".format(np.count_nonzero(snr > 7)))
fig, ax = sources.plot_source_variables(xstr="f_orb", ystr="snr", disttype="kde", log_scale=(True, True))

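As a quick check of the LEGWORK calls used above, the following cell computes the SNR of a single double white-dwarf binary.

In [None]: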
import astropy.units as u
import legwork as lw

# -----------------------------
# 1) Define the source parameters of a double white-dwarf binary
# -----------------------------
m_1 = 0.3 * u.Msun   # mass of the first WD
m_2 = 0.3 * u.Msun   # mass of the second WD
eccentricity = 0.0

# Orbital frequency (f_orb) of the binary.
# Note that the gravitational-wave frequency of a circular orbit is 2 * f_orb.
f_orb = 1.0e-3 * u.Hz

# Distance to the source
dist = 1000.0 * u.pc

# -----------------------------
# 2) Create a Source object
# -----------------------------
# The Source class in LEGWORK accepts arrays or scalars of parameters.
# Here we use single (scalar) values.
binary = lw.source.Source(
    m_1=m_1,
    m_2=m_2,
    ecc=eccentricity,
    f_orb=f_orb,
    dist=dist,
)

# -----------------------------
# 3) Compute the SNR
# -----------------------------
# By default, LEGWORK uses the standard LISA sensitivity curve from Robson et al. (2019).
snr = binary.get_snr()
print(snr)

# -----------------------------
# 4) Check detectability
# -----------------------------
# Apply the same SNR threshold of 7 as in the post-convolution function
is_detectable = snr > 7.0
print(is_detectable)
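For context, a back-of-the-envelope check of this example binary: its chirp mass and its time to merger for a circular orbit, $\tau = \frac{5}{256}\,(G\mathcal{M}_c/c^3)^{-5/3}\,(\pi f_{\rm GW})^{-8/3}$ with $f_{\rm GW} = 2 f_{\rm orb}$. This is a stand-alone sketch with rounded SI constants, independent of LEGWORK.

```python
import numpy as np

G = 6.674e-11      # m^3 kg^-1 s^-2
C = 2.998e8        # m / s
MSUN = 1.989e30    # kg
YR = 3.156e7       # s

m1 = m2 = 0.3 * MSUN
f_orb = 1.0e-3     # Hz, as in the cell above
f_gw = 2 * f_orb   # dominant GW frequency of a circular binary

m_chirp = (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2

# Time to merger for a circular orbit, in s and in Myr
t_merge = 5.0 / 256.0 * (G * m_chirp / C**3) ** (-5 / 3) * (np.pi * f_gw) ** (-8 / 3)
t_merge_myr = t_merge / YR / 1e6
```

This gives a chirp mass of about 0.26 Msun and a merger time of roughly 3 Myr, far shorter than the 10 Gyr span of the star formation history, which illustrates why the merger check in the post-convolution function matters.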


## Advanced steps
This example is a good step towards a solid predictive calculation of the population of double white dwarf systems observable by the LISA mission, but it lacks some sophistication. In particular, the star formation history is not realistic: it is constant in time and ignores both metallicity and the spatial structure of the Milky Way.

This example can be made more sophisticated by, e.g.:
- Using a spatially resolved star formation history. One can provide a list of star formation histories to the code, each element representing one region of the grid on which the SFR is defined.
- Splitting off systems that may undergo Roche-lobe overflow (RLOF).
- Providing the on-sky position (right ascension and declination), the inclination of the system relative to the line of sight, and the orbital phase for eccentric systems.

See the LEGWORK documentation on position-, inclination- and polarisation-specific sources: https://legwork.readthedocs.io/en/latest/notebooks/Source.html#Position-inclination-polarisation-specfic-sources
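As a starting point for the last item, isotropic orbital orientations can be drawn as below. This is a minimal sketch under stated assumptions: the inclination follows from a uniform $\cos i$, and the polarisation angle is assumed uniform in $[0, \pi)$. Sky positions are deliberately left out, since for a Galactic population they should follow from the spatial model rather than be isotropic. How such values are passed to LEGWORK is described in the documentation linked above.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 1000  # number of systems (hypothetical)

# Isotropic orbital orientations: cos(inclination) uniform in [-1, 1]
inclination = np.arccos(rng.uniform(-1.0, 1.0, n))

# Polarisation angle, assumed uniform in [0, pi)
polarisation = rng.uniform(0.0, np.pi, n)
```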

TODO: determine which of the systems that are (at the present day) in the LISA frequency range should have interacted through RLOF
TODO: of the systems that are not undergoing RLOF and are within the LISA waveband, store the indices and `f_orb_now`; the rest can be retrieved elsewhere
