Add different options to compute stat_array on FluxPointsDatasets #5135

Merged (25 commits) on May 17, 2024
98 changes: 84 additions & 14 deletions examples/tutorials/analysis-1d/spectral_analysis.py
@@ -380,23 +380,13 @@
#
# To round up our analysis we can compute flux points by fitting the norm
# of the global model in energy bands.
-# We can utilise the `~gammapy.estimators.utils.resample_energy_edges`
-# for defining the energy bins in which the minimum number of `sqrt_ts` is 2.
-# To do so we first stack the individual datasets, only for obtaining the energies:
-#
-
-dataset_stacked = Datasets(datasets).stack_reduce()
-energy_edges = resample_energy_edges(dataset_stacked, conditions={"sqrt_ts_min": 2})
-
-
-######################################################################
-# Now we create an instance of the
+# We create an instance of the
# `~gammapy.estimators.FluxPointsEstimator`, by passing the dataset and
# the energy binning:
#

fpe = FluxPointsEstimator(
-energy_edges=energy_edges, source="crab", selection_optional="all"
+energy_edges=energy_axis.edges, source="crab", selection_optional="all"
)
flux_points = fpe.run(datasets=datasets)
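
######################################################################
# As a quick, optional check (an illustration only, not needed for the rest
# of the analysis), the estimated flux points can be displayed as a table:
#

print(flux_points.to_table(sed_type="dnde", formatted=True))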

@@ -416,6 +406,7 @@
fig, ax = plt.subplots()
flux_points.plot(ax=ax, sed_type="e2dnde", color="darkorange")
flux_points.plot_ts_profiles(ax=ax, sed_type="e2dnde")
ax.set_xlim(0.6, 40)
plt.show()

######################################################################
@@ -430,8 +421,11 @@
# quickly made like this:
#

-flux_points_dataset = FluxPointsDataset(data=flux_points, models=model_best_joint)
-flux_points_dataset.plot_fit()
+flux_points_dataset = FluxPointsDataset(
+data=flux_points, models=model_best_joint.copy()
+)
+ax, _ = flux_points_dataset.plot_fit()
+ax.set_xlim(0.6, 40)
plt.show()


@@ -503,6 +497,82 @@

# sphinx_gallery_thumbnail_number = 5

######################################################################
# A note on statistics
# --------------------
#
# Different statistics are available for the `FluxPointsDataset`:
#
# * chi2 : estimate from chi2 statistics.
# * profile : estimate from interpolation of the likelihood profile.
# * distrib : estimate from probability distributions,
#   assuming that flux points correspond to asymmetric Gaussians
#   and upper limits to complementary error functions.
#
# The default is `chi2`; in that case upper limits are ignored and the mean
# of the asymmetric errors is used. It is therefore recommended to use
# `profile` if `stat_scan` is available on the flux points. The `distrib`
# case provides an approximation when the profile is not available, while
# still taking upper limits and asymmetric errors into account.
#
# In the example below we can see that the `profile` case matches exactly
# the result from the joint analysis of the ON/OFF datasets using `wstat`
# (as labelled).


def plot_stat(fp_dataset):
fig, ax = plt.subplots()

plot_kwargs = {
"energy_bounds": [0.1, 30] * u.TeV,
"sed_type": "e2dnde",
"ax": ax,
}

fp_dataset.data.plot(energy_power=2, ax=ax)
model_best_joint.spectral_model.plot(
color="b", lw=0.5, **plot_kwargs, label="wstat"
)

stat_types = ["chi2", "profile", "distrib"]
colors = ["red", "g", "c"]
lss = ["--", ":", "--"]

for ks, stat in enumerate(stat_types):

fp_dataset.stat_type = stat

fit = Fit()
fit.run([fp_dataset])

fp_dataset.models[0].spectral_model.plot(
color=colors[ks], ls=lss[ks], **plot_kwargs, label=stat
)
fp_dataset.models[0].spectral_model.plot_error(
facecolor=colors[ks], **plot_kwargs
)
plt.legend()


plot_stat(flux_points_dataset)
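
######################################################################
# As a compact illustration of the same comparison (reusing the objects
# defined above), the statistic can also be switched directly on the dataset
# and the total fit statistic printed for each option. The `profile` option
# relies on the stored likelihood profiles (`stat_scan`), which are available
# here because the flux points were computed with `selection_optional="all"`:
#

for stat in flux_points_dataset.available_stat_type:
    flux_points_dataset.stat_type = stat
    print(stat, flux_points_dataset.stat_sum())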

######################################################################

# To avoid discrepancies due to the treatment of upper limits,
# we can utilise `~gammapy.estimators.utils.resample_energy_edges`
# to define energy bins in which the minimum `sqrt_ts` is 2.
# In that case all the statistic definitions give equivalent results.
#

dataset_stacked = Datasets(datasets).stack_reduce()
energy_edges = resample_energy_edges(dataset_stacked, conditions={"sqrt_ts_min": 2})

fpe_no_ul = FluxPointsEstimator(
energy_edges=energy_edges, source="crab", selection_optional="all"
)
flux_points_no_ul = fpe_no_ul.run(datasets=datasets)
flux_points_dataset_no_ul = FluxPointsDataset(
data=flux_points_no_ul,
models=model_best_joint.copy(),
)

plot_stat(flux_points_dataset_no_ul)
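
######################################################################
# For reference (purely illustrative), one can print the resampled energy
# edges and check whether any of the new flux points is still flagged as an
# upper limit:
#

print(energy_edges)
print(flux_points_no_ul.is_ul.data.squeeze())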

######################################################################
# Exercises
# ---------
154 changes: 148 additions & 6 deletions gammapy/datasets/flux_points.py
@@ -1,6 +1,8 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import logging
import numpy as np
from scipy.special import erfc
from scipy.stats import norm
from astropy import units as u
from astropy.io import fits
from astropy.table import Table
@@ -14,6 +16,7 @@
SkyModel,
TemplateSpatialModel,
)
from gammapy.utils.interpolation import interpolate_profile
from gammapy.utils.scripts import make_name, make_path
from .core import Dataset

@@ -55,6 +58,25 @@
meta_table : `~astropy.table.Table`
Table listing information on observations used to create the dataset.
One line per observation for stacked datasets.
stat_type : str
Method used to compute the statistics:
* chi2 : estimate from chi2 statistics.
* profile : estimate from interpolation of the likelihood profile.
* distrib : assuming Gaussian errors, the likelihood is given by the
probability density function of the normal distribution.
For the upper limit case it is necessary to marginalize over the unknown measurement,
so we integrate the normal distribution up to the upper limit value,
which gives the complementary error function.
See eq. C7 of Mohanty et al. (2013):
https://iopscience.iop.org/article/10.1088/0004-637X/773/2/168/pdf

Default is `chi2`; in that case upper limits are ignored and the mean of the asymmetric errors is used.
However, it is recommended to use `profile` if `stat_scan` is available on the flux points.
The `distrib` case provides an approximation if the profile is not available.
stat_kwargs : dict
Extra arguments specifying the interpolation scheme of the likelihood profile.
Used only if `stat_type == "profile"`. In that case the default is
`stat_kwargs={"interp_scale": "sqrt", "extrapolate": True}`.

Examples
--------
@@ -102,7 +124,6 @@
``gammapy download datasets --tests --out $GAMMAPY_DATA``
"""

-stat_type = "chi2"
tag = "FluxPointsDataset"

def __init__(
@@ -113,6 +134,8 @@
mask_safe=None,
name=None,
meta_table=None,
stat_type="chi2",
stat_kwargs=None,
):
if not data.geom.has_energy_axis:
raise ValueError("FluxPointsDataset needs an energy axis")
@@ -122,11 +145,62 @@
self.models = models
self.meta_table = meta_table

-if mask_safe is None:
-mask_safe = (~data.is_ul).data
self._available_stat_type = dict(
chi2=self._stat_array_chi2,
profile=self._stat_array_profile,
distrib=self._stat_array_distrib,
)

if stat_kwargs is None:
stat_kwargs = dict()
self.stat_kwargs = stat_kwargs

self.stat_type = stat_type

if mask_safe is None:
mask_safe = np.ones(self.data.dnde.data.shape, dtype=bool)
self.mask_safe = mask_safe

@property
def available_stat_type(self):
return list(self._available_stat_type.keys())

@property
def stat_type(self):
return self._stat_type

@stat_type.setter
def stat_type(self, stat_type):

if stat_type not in self.available_stat_type:
raise ValueError(

f"Invalid stat_type: possible options are {self.available_stat_type}"
)

if stat_type == "chi2":
self._mask_valid = (~self.data.is_ul).data & np.isfinite(self.data.dnde)
elif stat_type == "distrib":
self._mask_valid = (
self.data.is_ul.data & np.isfinite(self.data.dnde_ul)
) | np.isfinite(self.data.dnde)
elif stat_type == "profile":
self.stat_kwargs.setdefault("interp_scale", "sqrt")
self.stat_kwargs.setdefault("extrapolate", True)
self._profile_interpolators = self._get_valid_profile_interpolators()
self._stat_type = stat_type

@property
def mask_valid(self):
return self._mask_valid

@property
def mask_safe(self):
return self._mask_safe & self.mask_valid

@mask_safe.setter
def mask_safe(self, mask_safe):
self._mask_safe = mask_safe

@property
def name(self):
return self._name
@@ -310,13 +384,81 @@

def stat_array(self):
"""Fit statistic array."""
return self._available_stat_type[self.stat_type]()

def _stat_array_chi2(self):
"""Chi2 statistics."""
model = self.flux_pred()
data = self.data.dnde.quantity
try:
-sigma = self.data.dnde_err
+sigma = self.data.dnde_err.quantity
except AttributeError:
-sigma = (self.data.dnde_errn + self.data.dnde_errp) / 2
-return ((data - model) / sigma.quantity).to_value("") ** 2
+sigma = (self.data.dnde_errn + self.data.dnde_errp).quantity / 2

+return ((data - model) / sigma).to_value("") ** 2

def _stat_array_profile(self):
"""Estimate statistic from interpolation of the likelihood profile."""
model = np.zeros(self.data.dnde.data.shape) + (
self.flux_pred() / self.data.dnde_ref
).to_value("")
stat = np.zeros(model.shape)
for idx in np.ndindex(self._profile_interpolators.shape):
stat[idx] = self._profile_interpolators[idx](model[idx])
return stat

def _get_valid_profile_interpolators(self):
value_scan = self.data.stat_scan.geom.axes["norm"].center
shape_axes = self.data.stat_scan.geom._shape[slice(3, None)][::-1]
interpolators = np.empty(shape_axes, dtype=object)
self._mask_valid = np.ones(self.data.dnde.data.shape, dtype=bool)
for idx in np.ndindex(shape_axes):
stat_scan = np.abs(
self.data.stat_scan.data[idx].squeeze()
- self.data.stat.data[idx].squeeze()
)
self._mask_valid[idx] = np.all(np.isfinite(stat_scan))
interpolators[idx] = interpolate_profile(
value_scan,
stat_scan,
interp_scale=self.stat_kwargs["interp_scale"],
extrapolate=self.stat_kwargs["extrapolate"],
)
return interpolators

def _stat_array_distrib(self):
"""Estimate statistic from probability distributions,
assuming that flux points correspond to asymmetric Gaussians
and upper limits to complementary error functions.
"""
model = np.zeros(self.data.dnde.data.shape) + self.flux_pred().to_value(
self.data.dnde.unit
)

stat = np.zeros(model.shape)

mask_valid = ~np.isnan(self.data.dnde.data)
loc = self.data.dnde.data[mask_valid]
value = model[mask_valid]
try:
mask_p = (model >= self.data.dnde.data)[mask_valid]
scale = np.zeros(mask_p.shape)
scale[mask_p] = self.data.dnde_errp.data[mask_valid][mask_p]
scale[~mask_p] = self.data.dnde_errn.data[mask_valid][~mask_p]
except AttributeError:
scale = self.data.dnde_err.data[mask_valid]
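# Measured points enter as -2 log of the ratio of the Gaussian pdf at the
# model value to its maximum at the measured value (the asymmetric errors
# above pick errp or errn depending on the sign of model - data).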
stat[mask_valid] = -2 * np.log(
norm.pdf(value, loc=loc, scale=scale) / norm.pdf(loc, loc=loc, scale=scale)
)

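# Upper limits enter through a ratio of complementary error functions,
# following eq. C7 of Mohanty et al. (2013) as described in the class docstring.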
mask_ul = self.data.is_ul.data & ~np.isnan(self.data.dnde_ul.data)
value = model[mask_ul]
loc_ul = self.data.dnde_ul.data[mask_ul]
scale_ul = self.data.dnde_ul.data[mask_ul]
stat[mask_ul] = 2 * np.log(
(erfc((loc_ul - value) / scale_ul) / 2)
/ (erfc((loc_ul - 0) / scale_ul) / 2)
)
return stat

def residuals(self, method="diff"):
"""Compute flux point residuals.
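The new constructor arguments and the `stat_type` setter can be exercised as in the following sketch (illustrative only: `flux_points` stands for any `FluxPoints` object that stores a likelihood profile in `stat_scan`, and `model` for a `SkyModel`; neither is defined in this diff):

from gammapy.datasets import FluxPointsDataset

dataset = FluxPointsDataset(
    data=flux_points,  # FluxPoints with a stored stat_scan
    models=model,  # SkyModel whose spectral parameters will be fitted
    stat_type="profile",
    stat_kwargs={"interp_scale": "sqrt", "extrapolate": True},
)
print(dataset.available_stat_type)  # ['chi2', 'profile', 'distrib']
dataset.stat_type = "distrib"  # the statistic can also be switched afterwards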
16 changes: 16 additions & 0 deletions gammapy/datasets/tests/test_flux_points.py
@@ -102,6 +102,19 @@ def test_flux_point_dataset_flux_pred(dataset):
assert_allclose(dataset.flux_pred()[0].value, 0.000472, rtol=1e-3)


@requires_data()
def test_flux_point_dataset_stat(dataset):
dataset.stat_type = "chi2"
fit = Fit()
fit.run([dataset])
assert_allclose(dataset.stat_sum(), 25.205933, rtol=1e-3)

dataset.stat_type = "distrib"
fit = Fit()
fit.run([dataset])
assert_allclose(dataset.stat_sum(), 36.153428, rtol=1e-3)


def test_flux_point_dataset_with_time_axis(tmp_path):
meta = dict(TIMESYS="utc", SED_TYPE="flux")

@@ -158,6 +171,9 @@ def test_flux_point_dataset_with_time_axis(tmp_path):
with pytest.raises(ValueError):
flux_points_dataset.plot_residuals()

flux_points_dataset.stat_type = "distrib"
assert_allclose(flux_points_dataset.stat_sum(), 193.8093, rtol=1e-3)


@requires_data()
class TestFluxPointFit: