-
Notifications
You must be signed in to change notification settings - Fork 187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prior statistics and set-up for multidimensional priors #4907
base: main
Are you sure you want to change the base?
Changes from all commits
65d1a8b
6732a8c
9b0433e
527b728
fa2f3c7
0bd7a34
12f92b4
9574159
53d2440
280684c
01b9efa
f194f46
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
cash, | ||
cash_sum_cython, | ||
get_wstat_mu_bkg, | ||
prior_fit_statistic, | ||
wstat, | ||
) | ||
from gammapy.utils.deprecation import deprecated_renamed_argument | ||
|
@@ -220,7 +221,6 @@ def __init__( | |
): | ||
self._name = make_name(name) | ||
self._evaluators = {} | ||
|
||
self.counts = counts | ||
self.exposure = exposure | ||
self.background = background | ||
|
@@ -354,7 +354,6 @@ def models(self, models): | |
use_cache=USE_NPRED_CACHE, | ||
) | ||
self._evaluators[model.name] = evaluator | ||
|
||
self._models = models | ||
|
||
@property | ||
|
@@ -1114,12 +1113,11 @@ def plot_residuals( | |
return ax_spatial, ax_spectral | ||
|
||
def stat_sum(self): | ||
"""Total statistic function value given the current model parameters and priors.""" | ||
"""Total statistic function value given the current model parameters and priors set on the models.""" | ||
counts, npred = self.counts.data.astype(float), self.npred().data | ||
prior_stat_sum = 0.0 | ||
if self.models is not None: | ||
prior_stat_sum = self.models.parameters.prior_stat_sum() | ||
|
||
counts, npred = self.counts.data.astype(float), self.npred().data | ||
prior_stat_sum = prior_fit_statistic(self.models.priors) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Wouldn't this double count the prior, when it is evaluated again at the |
||
|
||
if self.mask is not None: | ||
return ( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
import logging | ||
import numpy as np | ||
import astropy.units as u | ||
from gammapy.modeling import PriorParameter, PriorParameters | ||
from gammapy.modeling import Parameter, Parameters, PriorParameter, PriorParameters | ||
from .core import ModelBase | ||
|
||
log = logging.getLogger(__name__) | ||
|
@@ -32,7 +32,15 @@ | |
class Prior(ModelBase): | ||
_unit = "" | ||
|
||
def __init__(self, **kwargs): | ||
def __init__(self, modelparameters, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think taking the |
||
|
||
if isinstance(modelparameters, Parameter): | ||
self._modelparameters = Parameters([modelparameters]) | ||
elif isinstance(modelparameters, Parameters): | ||
self._modelparameters = modelparameters | ||
else: | ||
raise ValueError(f"Invalid model type {modelparameters}") | ||
|
||
# Copy default parameters from the class to the instance | ||
default_parameters = self.default_parameters.copy() | ||
|
||
|
@@ -52,6 +60,13 @@ | |
else: | ||
self._weight = 1 | ||
|
||
for par in self._modelparameters: | ||
par.prior = self | ||
|
||
@property | ||
def modelparameters(self): | ||
return self._modelparameters | ||
|
||
@property | ||
def parameters(self): | ||
"""PriorParameters (`~gammapy.modeling.PriorParameters`)""" | ||
|
@@ -73,11 +88,11 @@ | |
def weight(self, value): | ||
self._weight = value | ||
|
||
def __call__(self, value): | ||
def __call__(self): | ||
"""Call evaluate method""" | ||
# assuming the same unit as the PriorParamter here | ||
# assuming the same unit as the PriorParameter here | ||
kwargs = {par.name: par.value for par in self.parameters} | ||
return self.weight * self.evaluate(value.value, **kwargs) | ||
return self.weight * self.evaluate(self._modelparameters.value, **kwargs) | ||
|
||
def to_dict(self, full_output=False): | ||
"""Create dict for YAML serialisation""" | ||
|
@@ -99,30 +114,33 @@ | |
): | ||
del par[item] | ||
|
||
data = {"type": tag, "parameters": params, "weight": self.weight} | ||
data = { | ||
"type": tag, | ||
"parameters": params, | ||
"weight": self.weight, | ||
"modelparameters": self._modelparameters, | ||
} | ||
|
||
if self.type is None: | ||
return data | ||
else: | ||
return {self.type: data} | ||
return data | ||
|
||
@classmethod | ||
def from_dict(cls, data): | ||
from . import PRIOR_REGISTRY | ||
|
||
prior_cls = PRIOR_REGISTRY.get_cls(data["type"]) | ||
kwargs = {} | ||
|
||
key0 = next(iter(data)) | ||
if key0 in ["prior"]: | ||
data = data[key0] | ||
if data["type"] not in cls.tag: | ||
if data["type"] not in prior_cls.tag: | ||
raise ValueError( | ||
f"Invalid model type {data['type']} for class {cls.__name__}" | ||
) | ||
|
||
priorparameters = _build_priorparameters_from_dict( | ||
data["parameters"], cls.default_parameters | ||
data["parameters"], prior_cls.default_parameters | ||
) | ||
kwargs["weight"] = data["weight"] | ||
return cls.from_parameters(priorparameters, **kwargs) | ||
kwargs["modelparameters"] = data["modelparameters"] | ||
|
||
return prior_cls.from_parameters(priorparameters, **kwargs) | ||
|
||
|
||
class GaussianPrior(Prior): | ||
|
@@ -144,8 +162,7 @@ | |
mu = PriorParameter(name="mu", value=0) | ||
sigma = PriorParameter(name="sigma", value=1) | ||
|
||
@staticmethod | ||
def evaluate(value, mu, sigma): | ||
def evaluate(self, value, mu, sigma): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
return ((value - mu) / sigma) ** 2 | ||
|
||
|
||
|
@@ -172,8 +189,7 @@ | |
min = PriorParameter(name="min", value=-np.inf, unit="") | ||
max = PriorParameter(name="max", value=np.inf, unit="") | ||
|
||
@staticmethod | ||
def evaluate(value, min, max): | ||
def evaluate(self, value, min, max): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above... |
||
if min < value < max: | ||
return 1.0 | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does one prior count positive, the other negative?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Datasets.models
includes all models, why is there an independent sum for dataset.models?