Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add acceleration option to JointPrimaryMarginalizedModel likelihood #4688

Merged
merged 56 commits into from
Sep 9, 2024
Merged
Changes from 29 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
55e5541
Update hierarchical.py
WuShichao Apr 7, 2024
a6ac76d
Update hierarchical.py
WuShichao Apr 8, 2024
f4ed98d
Update hierarchical.py
WuShichao Apr 8, 2024
41224cc
Update hierarchical.py
WuShichao Apr 10, 2024
7109db9
Update hierarchical.py
WuShichao Apr 10, 2024
0bbe7a4
fix cc issues
WuShichao Apr 10, 2024
6920a8c
Update hierarchical.py
WuShichao Apr 15, 2024
55fba36
Update relbin.py
WuShichao Apr 15, 2024
b993833
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao May 16, 2024
7099515
add complex phase correction for sh_others
WuShichao May 16, 2024
94ba798
Update hierarchical.py
WuShichao May 16, 2024
6afbc4e
Update relbin.py
WuShichao May 16, 2024
e20b6f5
fix cc issues
WuShichao May 16, 2024
25a4562
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao Jun 4, 2024
8fb82a1
make code more general
WuShichao Jun 7, 2024
757a30b
update
WuShichao Jun 13, 2024
a2d64d1
fix
WuShichao Jun 13, 2024
8a9287e
rename
WuShichao Jun 13, 2024
5423d8c
update
WuShichao Jun 15, 2024
0b9d44e
WIP
WuShichao Jun 17, 2024
eeb8890
fix a bug in frame transform
WuShichao Jun 18, 2024
50e3599
fix overwritten issues
WuShichao Jun 19, 2024
dba5292
update
WuShichao Jun 19, 2024
537256e
update
WuShichao Jun 19, 2024
0af3fed
fix reconstruct
WuShichao Jun 19, 2024
0a09b12
make this PR general
WuShichao Jun 28, 2024
eb57268
update
WuShichao Jun 28, 2024
6d856b3
update
WuShichao Jun 28, 2024
075c39a
fix cc issues
WuShichao Jun 28, 2024
9ffbb70
rename
WuShichao Jun 28, 2024
e0f1ec4
rename
WuShichao Jun 28, 2024
6ce67c3
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao Jul 1, 2024
bf105a4
add multiband description
WuShichao Jul 1, 2024
273264f
fix
WuShichao Jul 4, 2024
5226b7d
add comments
WuShichao Jul 4, 2024
21fd035
fix hdf's config
WuShichao Jul 5, 2024
ba3816d
fix
WuShichao Jul 5, 2024
28fc1b2
fix
WuShichao Jul 6, 2024
b06d32e
fix
WuShichao Jul 28, 2024
b4a47af
fix
WuShichao Jul 29, 2024
a5b6d8c
remove print
WuShichao Jul 29, 2024
ca096ec
update for Alex's comments
WuShichao Jul 29, 2024
e8825be
wip
WuShichao Jul 30, 2024
be2b066
update
WuShichao Jul 30, 2024
c03652e
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao Jul 31, 2024
02b6937
fix
WuShichao Jul 31, 2024
87f10cb
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao Jul 31, 2024
cbcd5a2
update
WuShichao Aug 1, 2024
36af111
seems work
WuShichao Aug 12, 2024
c084f4f
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao Aug 12, 2024
709b524
fix CC issue
WuShichao Aug 12, 2024
3a17bf5
fix
WuShichao Aug 12, 2024
3865d5e
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao Aug 12, 2024
0df23f2
fix demargin
WuShichao Aug 25, 2024
da34461
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao Aug 25, 2024
f2b0798
Merge branch 'gwastro:master' into accelerate_multiband
WuShichao Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 67 additions & 35 deletions pycbc/inference/models/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,26 +607,12 @@ def _loglikelihood(self):


class JointPrimaryMarginalizedModel(HierarchicalModel):
""" Hierarchical heterodyne likelihood for coherent multiband
parameter estimation which combines data from space-borne and
ground-based GW detectors coherently. Currently, this only
supports LISA as the space-borne GW detector.

Sub models are treated as if the same GW source (such as a GW
from stellar-mass BBH) is observed in different frequency bands by
space-borne and ground-based GW detectors, then transform all
the parameters into the same frame in the sub model level, use
`HierarchicalModel` to get the joint likelihood, and marginalize
over all the extrinsic parameters supported by `RelativeTimeDom`
or its variants. Note that LISA submodel only supports the `Relative`
for now, for ground-based detectors, please use `RelativeTimeDom`
or its variants.

Although this likelihood model is used for multiband parameter
estimation, users can still use it for other purposes, such as
GW + EM parameter estimation, in this case, please use `RelativeTimeDom`
or its variants for the GW data, for the likelihood of EM data,
there is no restrictions.
"""This likelihood model can be used for cases when one of the submodels
can be marginalized to accelerate the total likelihood. This model
likelihood also allows for further acceleration of other models during
marginalization if some extrinsic parameters can be tightly constrained.
More specifically, such as the EM + GW parameter estimation, the sky
localization can be well measured.
"""
name = 'joint_primary_marginalized'

Expand All @@ -640,6 +626,10 @@ def __init__(self, variable_params, submodels, **kwargs):
self.other_models.pop(kwargs['primary_lbl'][0])
self.other_models = list(self.other_models.values())

# determine whether to accelerate total_loglr
from pycbc.inference.models.tools import str_to_bool
self.accelerate_loglr = str_to_bool(kwargs['acclerate_loglr'][0])

def write_metadata(self, fp, group=None):
"""Adds metadata to the output files

Expand Down Expand Up @@ -686,8 +676,6 @@ def total_loglr(self):
"""
# calculate <d-h|d-h> = <h|h> - 2<h|d> + <d|d> up to a constant

# note that for SOBHB signals, ground-based detectors dominant SNR
# and accuracy of (tc, ra, dec)
self.primary_model.return_sh_hh = True
sh_primary, hh_primary = self.primary_model.loglr
self.primary_model.return_sh_hh = False
Expand All @@ -696,15 +684,37 @@ def total_loglr(self):
self.primary_model.marginalize_vector_params.keys())
if 'logw_partial' in margin_names_vector:
margin_names_vector.remove('logw_partial')

margin_params = {}
nums = 1
for key, value in self.primary_model.current_params.items():
# add marginalize_vector_params
if key in margin_names_vector:
margin_params[key] = value
if isinstance(value, numpy.ndarray):
nums = len(value)

if self.accelerate_loglr:
# Due to the high precision of extrinsic parameters constrined
# by the primary model, the mismatch of wavefroms in others by
# varing those parameters is pretty small, so we can keep them
# static to accelerate total_loglr. Here, we use matched-filering
# SNR instead of lilkelihood, because luminosity distance and
# inclination has a very strong degeneracy, change of inclination
# will change best match distance, so change the amplitude of
# waveform. Using SNR will cancel out the effect of amplitude.err
i_max_extrinsic = numpy.argmax(
numpy.abs(sh_primary) / hh_primary**0.5)
for p in margin_names_vector:
if isinstance(self.primary_model.current_params[p],
numpy.ndarray):
margin_params[p] = \
self.primary_model.current_params[p][i_max_extrinsic]
nums = len(self.primary_model.current_params[p])
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WuShichao This logic should already take care of the distance case, except that later on you assume that if any parameter is a scalar they all are. That's the part you should stop assuming. Don't assume they are any particular mix of scalar or vector.

margin_params[p] = self.primary_model.current_params[p]
nums = 1
else:
for key, value in self.primary_model.current_params.items():
# add marginalize_vector_params
if key in margin_names_vector:
margin_params[key] = value
if isinstance(value, numpy.ndarray):
nums = len(value)
else:
nums = 1
# add distance if it has been marginalized,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, avoid needed to know explicitly about distance here. Think instead about the format that you require. If the format differs how to generically convert.

# use numpy array for it is just let it has the same
# shape as marginalize_vector_params, here we assume
Expand All @@ -713,7 +723,7 @@ def total_loglr(self):
margin_params['distance'] = numpy.full(
nums, self.primary_model.current_params['distance'])

# add likelihood contribution from space-borne detectors, we
# add likelihood contribution from other_models, we
# calculate sh/hh for each marginalized parameter point
sh_others = numpy.full(nums, 0 + 0.0j)
hh_others = numpy.zeros(nums)
Expand All @@ -723,24 +733,43 @@ def total_loglr(self):
# not using self.primary_model.current_params, because others_model
# may have its own static parameters
current_params_other = other_model.current_params.copy()
for i in range(nums):
if not self.accelerate_loglr:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you choose a more descriptive name to what the option does (e.g. how does it make the likelihood faster than than just the fact that it does)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is self.static_margin_params_in_other_models better?

Copy link
Contributor Author

@WuShichao WuShichao Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ahnitz In this PR, we don't apply any amplitude and phase correction (as general as possible), so I will not include these in the option name.

for i in range(nums):
current_params_other.update(
{key: value[i] if isinstance(value, numpy.ndarray)
else value for key, value in margin_params.items()})
other_model.update(**current_params_other)
other_model.return_sh_hh = True
sh_other, hh_other = other_model.loglr
sh_others[i] += sh_other
hh_others[i] += hh_other
other_model.return_sh_hh = False
else:
# use one margin point set to approximate all the others
current_params_other.update(
{key: value[i] if isinstance(value, numpy.ndarray) else
value for key, value in margin_params.items()})
{key: value[0] if isinstance(value, numpy.ndarray)
else value for key, value in margin_params.items()})
other_model.update(**current_params_other)
other_model.return_sh_hh = True
sh_others[i], hh_others[i] = other_model.loglr
sh_other, hh_other = other_model.loglr
other_model.return_sh_hh = False
sh_others += sh_other
hh_others += hh_other

if nums == 1:
# the type of the original sh/hh_others are numpy.array,
# might not the same as sh/hh_primary during reconstruct,
# during reconstruct of distance, sh/hh_others need to be scalar
sh_others = sh_others[0]
hh_others = hh_others[0]
sh_total = sh_primary + sh_others
hh_total = hh_primary + hh_others

# calculate marginalize_vector_weights
self.primary_model.marginalize_vector_weights = \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line shouldn't be here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, removed.

- numpy.log(self.primary_model.vsamples)
loglr = self.primary_model.marginalize_loglr(sh_total, hh_total)

return loglr

def others_lognl(self):
Expand Down Expand Up @@ -804,6 +833,9 @@ def from_config(cls, cp, **kwargs):
submodel_lbls))
sparam_map = map_params(hpiter(cp.options('static_params'),
submodel_lbls))
# get the acceleration label
kwargs['acclerate_loglr'] = shlex.split(
cp.get('model', 'accelerate_others_in_total_loglr'))

# we'll need any waveform transforms for the initializing sub-models,
# as the underlying models will receive the output of those transforms
Expand Down
Loading