Skip to content

Commit

Permalink
Merge pull request #16 from mj-will/option-to-skip-caching
Browse files Browse the repository at this point in the history
Add option to skip generator caching
  • Loading branch information
WuShichao authored May 22, 2024
2 parents c3d0805 + 4ce6a41 commit 8e35f56
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
25 changes: 21 additions & 4 deletions BBHX_Phenom.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from warnings import warn


@functools.lru_cache(maxsize=128)
def get_waveform_genner(log_mf_min, run_phenomd=True):
# See below where this function is called for description of how we handle
# log_mf_min.
Expand All @@ -19,6 +18,12 @@ def get_waveform_genner(log_mf_min, run_phenomd=True):
return wave_gen


@functools.lru_cache(maxsize=128)
def cached_get_waveform_genner(log_mf_fin, run_phenomd=True):
"""Cached version of get_waveform_genner"""
return get_waveform_genner(log_mf_fin, run_phenomd)


@functools.lru_cache(maxsize=10)
def cached_arange(start, stop, spacing):
return np.arange(start, stop, spacing)
Expand Down Expand Up @@ -129,6 +134,7 @@ def _bbhx_fd(
direct=False,
num_interp=100,
interp_f_lower=1e-4,
cache_generator=True,
**params
):

Expand Down Expand Up @@ -157,6 +163,8 @@ def _bbhx_fd(
interp_f_lower : float
Lower frequency cutoff used for interpolation when computing the
chirp time.
cache_generator : bool
If true, the BBHx waveform generator is cached based on
Returns
-------
Expand Down Expand Up @@ -269,9 +277,18 @@ def _bbhx_fd(
# To solve this we *round* the *logarithm* of this mass-dependent start
# frequency. The factor of 25 ensures reasonable spacing while doing this.
# So we round down to the nearest 1/25 of the logarithm of the frequency
log_mf_min = int(math.log(f_min*MTSUN_SI*(m1+m2)) * 25)

wave_gen = get_waveform_genner(log_mf_min, run_phenomd=run_phenomd)
log_mf_min = math.log(f_min*MTSUN_SI*(m1+m2)) * 25
if cache_generator:
# Use int to round down
wave_gen = cached_get_waveform_genner(
int(log_mf_min),
run_phenomd=run_phenomd,
)
else:
wave_gen = get_waveform_genner(
log_mf_min,
run_phenomd=run_phenomd,
)

if sample_points is None:
if 'delta_f' in params and params['delta_f'] > 0:
Expand Down
25 changes: 25 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,31 @@ def test_phenomhm_mode_array(params, mode_array):
assert len(wf) == 3


@pytest.mark.parametrize("cache_generator", [False, True])
def test_cache_generator(params, cache_generator):
from BBHX_Phenom import cached_get_waveform_genner

# Clear cache for these tests
cached_get_waveform_genner.cache_clear()

params["approximant"] = "BBHX_PhenomD"
params["cache_generator"] = cache_generator

# Build cache if using it
get_fd_det_waveform(**params)

n_calls = 2
for _ in range(n_calls):
get_fd_det_waveform(**params)

cache_info = cached_get_waveform_genner.cache_info()
if cache_generator:
assert cache_info.hits == n_calls
else:
assert cache_info.hits == 0



def test_length_in_time(params, approximant):
params["approximant"] = approximant
# print(params)
Expand Down

0 comments on commit 8e35f56

Please sign in to comment.