Skip to content

Commit

Permalink
Adapt gammapy/cube/fit.py
Browse files Browse the repository at this point in the history
  • Loading branch information
adonath committed Apr 3, 2019
1 parent ccdcb73 commit f98867d
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions gammapy/cube/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import astropy.units as u
from astropy.nddata.utils import NoOverlapError
from ..utils.fitting import Parameters, Dataset
from ..stats import cash, cstat
from ..stats import cash, cstat, cash_sum_cython, cstat_sum_cython
from ..maps import Map, MapAxis
from .models import SkyModel, SkyModels

Expand Down Expand Up @@ -73,8 +73,10 @@ def __init__(

if likelihood == "cash":
self._stat = cash
self._stat_sum = cash_sum_cython
elif likelihood == "cstat":
self._stat = cstat
self._stat_sum = cstat_sum_cython
else:
raise ValueError("Invalid likelihood: {!r}".format(likelihood))

Expand Down Expand Up @@ -149,6 +151,10 @@ def likelihood_per_bin(self):
"""Likelihood per bin given the current model parameters"""
return self._stat(n_on=self.counts.data, mu_on=self.npred().data)

@lazyproperty
def _counts_data(self):
return self.counts.data.astype(float)

def likelihood(self, parameters, mask=None):
"""Total likelihood given the current model parameters.
Expand All @@ -157,16 +163,19 @@ def likelihood(self, parameters, mask=None):
mask : `~numpy.ndarray`
Mask to be combined with the dataset mask.
"""
counts, npred = self._counts_data, self.npred().data

if self.mask is None and mask is None:
stat = self.likelihood_per_bin()
stat = self._stat_sum(counts.ravel(), npred.ravel())
elif self.mask is None:
stat = self.likelihood_per_bin()[mask]
stat = self._stat_sum(counts[mask], npred[mask])
elif mask is None:
stat = self.likelihood_per_bin()[self.mask.data]
stat = self._stat_sum(counts[self.mask.data], npred[self.mask.data])
else:
stat = self.likelihood_per_bin()[mask & self.mask.data]
mask_joined = mask & self.mask.data
stat = self._stat_sum(counts[mask_joined], npred[mask_joined])

return np.sum(stat, dtype=np.float64)
return stat


class MapEvaluator:
Expand Down

0 comments on commit f98867d

Please sign in to comment.