Skip to content

Commit

Permalink
fixing impact parameter shape inference
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Sep 19, 2019
1 parent 2e8c760 commit 91fc861
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
16 changes: 15 additions & 1 deletion exoplanet/distributions/physical.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ def logp(self, value):


class ImpactParameter(pm.Flat):
"""The impact parameter distribution for a transiting planet
Args:
ror: A scalar, tensor, or PyMC3 distribution representing the radius
ratio between the planet and star. Conditioned on a value of
``ror``, this will be uniformly distributed between ``0`` and
``1+ror``.
"""

def __init__(self, ror=None, **kwargs):
if ror is None:
raise ValueError("missing required parameter 'ror'")
Expand All @@ -76,11 +86,15 @@ def __init__(self, ror=None, **kwargs):
"transform", tr.ImpactParameterTransform(self.ror)
)

shape = kwargs.get("shape", None)
try:
shape = kwargs.get("shape", self.ror.distribution.shape)
except AttributeError:
shape = None
if shape is None:
testval = 0.5
else:
testval = 0.5 + np.zeros(shape)
kwargs["shape"] = shape
kwargs["testval"] = kwargs.pop("testval", testval)

super(ImpactParameter, self).__init__(**kwargs)
Expand Down
8 changes: 4 additions & 4 deletions exoplanet/distributions/physical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .physical import ImpactParameter, QuadLimbDark


class TestPhyscial(_Base):
class TestPhysical(_Base):
random_seed = 19860925

def test_quad_limb_dark(self):
Expand Down Expand Up @@ -46,8 +46,8 @@ def test_impact(self):
lower = 0.1
upper = 1.0
with self._model():
ror = pm.Uniform("ror", lower=lower, upper=upper)
dist = ImpactParameter("b", ror=ror, shape=(5, 2))
ror = pm.Uniform("ror", lower=lower, upper=upper, shape=(5, 2))
dist = ImpactParameter("b", ror=ror)

# Test random sampling
samples = dist.random(size=100)
Expand All @@ -63,4 +63,4 @@ def test_impact(self):
s, p = kstest(u[:, i], cdf)
assert s < 0.05

assert np.all(trace["b"] <= 1 + trace["ror"][:, None, None])
assert np.all(trace["b"] <= 1 + trace["ror"])

0 comments on commit 91fc861

Please sign in to comment.