Skip to content

Commit

Permalink
[probability] Implemented exclude option for logpdf of MultivariateNu…
Browse files Browse the repository at this point in the history
…mericalDistribution
  • Loading branch information
DavidMStraub committed Feb 8, 2017
1 parent 394cbd7 commit 0ac8a97
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
25 changes: 18 additions & 7 deletions flavio/statistics/probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,10 @@ def __init__(self, xi, y, central_value=None):
mode = np.asarray(np.meshgrid(*xi, indexing='ij'))[mode_index]
super().__init__(central_value=mode, support=None)
_bin_volume = np.prod([x[1] - x[0] for x in xi])
_y_norm = y / np.sum(y) / _bin_volume # normalize PDF to 1
self.y_norm = y / np.sum(y) / _bin_volume # normalize PDF to 1
# ignore warning from log(0)=-np.inf
with np.errstate(divide='ignore', invalid='ignore'):
self.logpdf_interp = scipy.interpolate.RegularGridInterpolator(xi, np.log(_y_norm),
self.logpdf_interp = scipy.interpolate.RegularGridInterpolator(xi, np.log(self.y_norm),
fill_value=-np.inf, bounds_error=False)
# the following is needed for get_random: initialize to None
self._y_flat = None
Expand Down Expand Up @@ -732,13 +732,24 @@ def logpdf(self, x, exclude=None):
Parameters:
- x: vector; position at which PDF should be evaluated
Note: the exclude parameter is not implemented yet.
- exclude: optional; if an iterable of integers is given, the parameters
at these positions will be ignored by maximizing the likelihood
along the remaining directions, i.e., they will be "profiled out".
"""
if exclude is not None:
raise NotImplementedError(
"Excluding individual parameters from multivariate numerical distributions not implemented")

try:
exclude = tuple(exclude)
except TypeError:
exclude = (exclude,)
xi = np.delete(self.xi, tuple(exclude), axis=0)
y = np.amax(self.y_norm, axis=tuple(exclude))
cv = np.delete(self.central_value, tuple(exclude))
if len(xi) == 1:
# if there is just 1 dimension left, use univariate
dist = NumericalDistribution(xi[0], y, cv)
else:
dist = MultivariateNumericalDistribution(xi, y, cv)
return dist.logpdf(x)
if np.asarray(x).shape == (len(self.central_value),):
# return a scalar
return self.logpdf_interp(x)[0]
Expand Down
16 changes: 10 additions & 6 deletions flavio/statistics/test_probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def test_multiv_numerical(self):
p_num = MultivariateNumericalDistribution((x0, x1), y_crazy)
p_norm = MultivariateNormalDistribution([0, 1], cov)
self.assertAlmostEqual(p_num.logpdf([0.237, 0.346]), p_norm.logpdf([0.237, 0.346]), delta=0.02)
self.assertAlmostEqual(p_num.logpdf([0.237], exclude=(1,)),
p_norm.logpdf([0.237], exclude=(1,)), delta=0.02)
# test exceptions
with self.assertRaises(NotImplementedError):
p_num.error_left
with self.assertRaises(NotImplementedError):
p_num.error_right
with self.assertRaises(NotImplementedError):
p_num.logpdf([0.237, 0.346], exclude=(0))
self.assertEqual(len(p_num.get_random(10)), 10)

def test_numerical_from_analytic(self):
Expand Down Expand Up @@ -271,14 +271,18 @@ def test_vectorize(self):
xr2 = np.random.rand(10, 2)
self.assertEqual(d.logpdf(xr3[0]).shape, ())
self.assertEqual(d.logpdf(xr3).shape, (10,))
self.assertEqual(d.logpdf(xr2[0], exclude=[0]).shape, ())
self.assertEqual(d.logpdf(xr2, exclude=[0]).shape, (10,))
self.assertEqual(d.logpdf(xr[0], exclude=[0, 1]).shape, ())
self.assertEqual(d.logpdf(xr, exclude=[0, 1]).shape, (10,))
self.assertEqual(d.logpdf(xr2[0], exclude=(0)).shape, ())
self.assertEqual(d.logpdf(xr2, exclude=(0)).shape, (10,))
self.assertEqual(d.logpdf(xr[0], exclude=(0, 1)).shape, ())
self.assertEqual(d.logpdf(xr, exclude=(0, 1)).shape, (10,))
xi = [np.linspace(-1,1,5), np.linspace(-1,1,6), np.linspace(-1,1,7)]
y = np.random.rand(5,6,7)
d = MultivariateNumericalDistribution(xi, y)
xr3 = np.random.rand(10, 3)
xr2 = np.random.rand(10, 2)
self.assertEqual(d.logpdf(xr3[0]).shape, ())
self.assertEqual(d.logpdf(xr3).shape, (10,))
self.assertEqual(d.logpdf(xr2[0], exclude=(0)).shape, ())
self.assertEqual(d.logpdf(xr2, exclude=(0)).shape, (10,))
self.assertEqual(d.logpdf(xr[0], exclude=(0, 1)).shape, ())
self.assertEqual(d.logpdf(xr, exclude=(0, 1)).shape, (10,))

0 comments on commit 0ac8a97

Please sign in to comment.