From 3eefe9776018246a8bcc1754ac5be64f0d63a032 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Wed, 3 Apr 2019 15:53:33 +0200 Subject: [PATCH] Add cash_sum_cython and cstat_sum_cython reference tests --- gammapy/cube/fit.py | 2 ++ gammapy/stats/tests/test_fit_statistics.py | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/gammapy/cube/fit.py b/gammapy/cube/fit.py index d1249664ba..e7a23f296a 100644 --- a/gammapy/cube/fit.py +++ b/gammapy/cube/fit.py @@ -165,6 +165,8 @@ def likelihood(self, parameters, mask=None): """ counts, npred = self._counts_data, self.npred().data + #TODO: add mask handling to _stat_sum, so that the temp copy + # created by the fancy indexing is avoided if self.mask is None and mask is None: stat = self._stat_sum(counts.ravel(), npred.ravel()) elif self.mask is None: diff --git a/gammapy/stats/tests/test_fit_statistics.py b/gammapy/stats/tests/test_fit_statistics.py index 36ab5a5f46..0104579807 100644 --- a/gammapy/stats/tests/test_fit_statistics.py +++ b/gammapy/stats/tests/test_fit_statistics.py @@ -105,18 +105,32 @@ def test_wstat(test_data, reference_values): assert_allclose(statsvec, reference_values["wstat"]) -@requires_dependency("sherpa") def test_cash(test_data, reference_values): statsvec = stats.cash(n_on=test_data["n_on"], mu_on=test_data["mu_sig"]) assert_allclose(statsvec, reference_values["cash"]) -@requires_dependency("sherpa") def test_cstat(test_data, reference_values): statsvec = stats.cstat(n_on=test_data["n_on"], mu_on=test_data["mu_sig"]) assert_allclose(statsvec, reference_values["cstat"]) +def test_cash_sum_cython(test_data): + counts = np.array(test_data["n_on"], dtype=float) + npred = np.array(test_data["mu_sig"], dtype=float) + stat = stats.cash_sum_cython(counts=counts, npred=npred) + ref = stats.cash(counts, npred).sum() + assert_allclose(stat, ref) + + +def test_ctstat_sum_cython(test_data): + counts = np.array(test_data["n_on"], dtype=float) + npred = np.array(test_data["mu_sig"], dtype=float) + stat = stats.cstat_sum_cython(counts=counts, npred=npred) + ref = stats.cstat(counts, npred).sum() + assert_allclose(stat, ref) + + def test_wstat_corner_cases(): """test WSTAT formulae for corner cases""" n_on = 0