Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correcting prior and adding pv2 #35

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,5 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
return samples # TODO: remember to cast this to a named array

def log_prob(self, x: Array) -> Float:
output = jax.lax.cond(not jnp.where((x>=self.xmax) | (x<=self.xmin))[0].any(), lambda: 0., lambda: -jnp.inf)
output = jnp.sum(jnp.where((x>=self.xmax) | (x<=self.xmin), jnp.zeros_like(x)-jnp.inf, jnp.zeros_like(x)))
return output + jnp.sum(jnp.log(1./(self.xmax-self.xmin)))
23 changes: 20 additions & 3 deletions src/jimgw/waveform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from jaxtyping import Array
from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_polar
from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc
from ripple.waveforms.IMRPhenomPv2 import gen_IMRPhenomPv2_hphc
import jax.numpy as jnp
from abc import ABC

Expand All @@ -22,10 +23,26 @@ def __call__(self, frequency: Array, params: dict) -> dict:
output = {}
ra = params['ra']
dec = params['dec']
theta = [params['M_c'], params['eta'], params['s1_z'], params['s2_z'], params['d_L'], 0, params['phase_c'], params['iota'], params['psi'], ra, dec]
hp, hc = gen_IMRPhenomD_polar(frequency, theta, self.f_ref)
theta = [params['M_c'], params['eta'], params['s1_z'], params['s2_z'], params['d_L'], 0, params['phase_c'], params['iota']]
hp, hc = gen_IMRPhenomD_hphc(frequency, theta, self.f_ref)
output['p'] = hp
output['c'] = hc
return output

class RippleIMRPhenomPv2(Waveform):

f_ref: float

def __init__(self, f_ref: float = 20.0):
self.f_ref = f_ref

def __call__(self, frequency: Array, params: dict) -> Array:
output = {}
theta = [params['M_c'], params['eta'], 0.0, 0.0, params['s1_z'],
0.0, 0.0, params['s2_z'],
params['d_L'], 0, params['phase_c'], params['iota']]
hp, hc = gen_IMRPhenomPv2_hphc(frequency, theta, self.f_ref)
output['p'] = hp
output['c'] = hc
return output