diff --git a/flavio/statistics/probability.py b/flavio/statistics/probability.py index 4f688901..b885a8aa 100644 --- a/flavio/statistics/probability.py +++ b/flavio/statistics/probability.py @@ -678,10 +678,10 @@ def __init__(self, xi, y, central_value=None): mode = np.asarray(np.meshgrid(*xi, indexing='ij'))[mode_index] super().__init__(central_value=mode, support=None) _bin_volume = np.prod([x[1] - x[0] for x in xi]) - _y_norm = y / np.sum(y) / _bin_volume # normalize PDF to 1 + self.y_norm = y / np.sum(y) / _bin_volume # normalize PDF to 1 # ignore warning from log(0)=-np.inf with np.errstate(divide='ignore', invalid='ignore'): - self.logpdf_interp = scipy.interpolate.RegularGridInterpolator(xi, np.log(_y_norm), + self.logpdf_interp = scipy.interpolate.RegularGridInterpolator(xi, np.log(self.y_norm), fill_value=-np.inf, bounds_error=False) # the following is needed for get_random: initialize to None self._y_flat = None @@ -732,13 +732,24 @@ def logpdf(self, x, exclude=None): Parameters: - x: vector; position at which PDF should be evaluated - - Note: the exclude parameter is not implemented yet. + - exclude: optional; if an iterable of integers is given, the parameters + at these positions will be ignored by maximizing the likelihood + along the remaining directions, i.e., they will be "profiled out". """ if exclude is not None: - raise NotImplementedError( - "Excluding individual parameters from multivariate numerical distributions not implemented") - + try: + exclude = tuple(exclude) + except TypeError: + exclude = (exclude,) + xi = np.delete(self.xi, tuple(exclude), axis=0) + y = np.amax(self.y_norm, axis=tuple(exclude)) + cv = np.delete(self.central_value, tuple(exclude)) + if len(xi) == 1: + # if there is just 1 dimension left, use univariate + dist = NumericalDistribution(xi[0], y, cv) + else: + dist = MultivariateNumericalDistribution(xi, y, cv) + return dist.logpdf(x) if np.asarray(x).shape == (len(self.central_value),): # return a scalar return self.logpdf_interp(x)[0] diff --git a/flavio/statistics/test_probability.py b/flavio/statistics/test_probability.py index 66afd62f..437b3b34 100644 --- a/flavio/statistics/test_probability.py +++ b/flavio/statistics/test_probability.py @@ -99,13 +99,13 @@ def test_multiv_numerical(self): p_num = MultivariateNumericalDistribution((x0, x1), y_crazy) p_norm = MultivariateNormalDistribution([0, 1], cov) self.assertAlmostEqual(p_num.logpdf([0.237, 0.346]), p_norm.logpdf([0.237, 0.346]), delta=0.02) + self.assertAlmostEqual(p_num.logpdf([0.237], exclude=(1,)), + p_norm.logpdf([0.237], exclude=(1,)), delta=0.02) # test exceptions with self.assertRaises(NotImplementedError): p_num.error_left with self.assertRaises(NotImplementedError): p_num.error_right - with self.assertRaises(NotImplementedError): - p_num.logpdf([0.237, 0.346], exclude=(0)) self.assertEqual(len(p_num.get_random(10)), 10) def test_numerical_from_analytic(self): @@ -271,10 +271,10 @@ def test_vectorize(self): xr2 = np.random.rand(10, 2) self.assertEqual(d.logpdf(xr3[0]).shape, ()) self.assertEqual(d.logpdf(xr3).shape, (10,)) - self.assertEqual(d.logpdf(xr2[0], exclude=[0]).shape, ()) - self.assertEqual(d.logpdf(xr2, exclude=[0]).shape, (10,)) - self.assertEqual(d.logpdf(xr[0], exclude=[0, 1]).shape, ()) - self.assertEqual(d.logpdf(xr, exclude=[0, 1]).shape, (10,)) + self.assertEqual(d.logpdf(xr2[0], exclude=(0)).shape, ()) + self.assertEqual(d.logpdf(xr2, exclude=(0)).shape, (10,)) + self.assertEqual(d.logpdf(xr[0], exclude=(0, 1)).shape, ()) + self.assertEqual(d.logpdf(xr, exclude=(0, 1)).shape, (10,)) xi = [np.linspace(-1,1,5), np.linspace(-1,1,6), np.linspace(-1,1,7)] y = np.random.rand(5,6,7) d = MultivariateNumericalDistribution(xi, y) @@ -282,3 +282,7 @@ def test_vectorize(self): xr2 = np.random.rand(10, 2) self.assertEqual(d.logpdf(xr3[0]).shape, ()) self.assertEqual(d.logpdf(xr3).shape, (10,)) + self.assertEqual(d.logpdf(xr2[0], exclude=(0)).shape, ()) + self.assertEqual(d.logpdf(xr2, exclude=(0)).shape, (10,)) + self.assertEqual(d.logpdf(xr[0], exclude=(0, 1)).shape, ()) + self.assertEqual(d.logpdf(xr, exclude=(0, 1)).shape, (10,))