Skip to content

Commit

Permalink
add a check for overlap of phase band and docstring
Browse files Browse the repository at this point in the history
Signed-off-by: Maxime Regeard <regeard@apc.in2p3.fr>
  • Loading branch information
MRegeard committed Jun 28, 2023
1 parent 6db2866 commit 255da27
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions gammapy/data/filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import copy
import logging
import numpy as np
from astropy.table import unique, vstack
import pandas as pd

__all__ = ["ObservationFilter"]

Expand All @@ -24,8 +26,10 @@ class ObservationFilter:
class that corresponds to the filter type
(see `~gammapy.data.ObservationFilter.EVENT_FILTER_TYPES`)
The filtered event list will be an intersection of all filters. A union
of filters is not supported yet.
event_filter_method: str
The filterring method to use on the event_filters. Either "intersect" or "union". Default is "intersect".
If "intersect", the filtered event list will be the intersection of all the event filters.
If "union", the filtered event list will be the union of all the event fileters, under uniquness condition.
Examples
--------
Expand All @@ -45,10 +49,12 @@ class that corresponds to the filter type

EVENT_FILTER_TYPES = dict(sky_region="select_region", custom="select_parameter")

def __init__(self, time_filter=None, event_filters=None, event_logic="and"):
def __init__(
self, time_filter=None, event_filters=None, event_filter_method="intersect"
):
self.time_filter = time_filter
self.event_filters = event_filters or []
self.event_logic = event_logic
self.event_filters = self._check_overlap_phase(event_filters) or []
self.event_filter_method = event_filter_method

@property
def livetime_fraction(self):
Expand All @@ -72,15 +78,15 @@ def filter_events(self, events):

filtered_events = self._filter_by_time(events)

if self.event_logic == "and":
if self.event_filter_method == "intersect":

for f in self.event_filters:
method_str = self.EVENT_FILTER_TYPES[f["type"]]
filtered_events = getattr(filtered_events, method_str)(**f["opts"])

return filtered_events

elif self.event_logic == "or":
elif self.event_filter_method == "union":

filtered_events = []
for f in self.event_filters:
Expand All @@ -93,6 +99,11 @@ def filter_events(self, events):
tot_filtered_events = EventList(table)
return tot_filtered_events

else:
raise ValueError(
"event_filter_method has to be either 'intersect' or 'union'."
)

def filter_gti(self, gti):
"""Apply filters to a GTI table.
Expand Down Expand Up @@ -131,3 +142,22 @@ def _check_filter_phase(event_filter):
fraction += band[1] - band[0]

return 1 if fraction == 0 else fraction

@staticmethod
def _check_overlap_phase(event_filter):
bands = []
for f in event_filter:
if f.get("opts").get("parameter") == "PHASE":
bands.append(f.get("opts").get("band"))

intervals = pd.arrays.IntervalArray.from_tuples(bands)
interval_matrix = []
for b in bands:
interval_matrix.append(intervals.overlaps(pd.Interval(b[0], b[1])))

interval_matrix = np.array(interval_matrix)
overlap_array = interval_matrix[~np.eye(interval_matrix.shape[0], dtype=bool)]
if True in overlap_array:
raise ValueError(
"Overlapping bands in event_filters that apply to pulsar phase are not allowed."
)

0 comments on commit 255da27

Please sign in to comment.