Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introducing union support in ObservationFilter #4616

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 79 additions & 13 deletions gammapy/data/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import copy
import logging
from itertools import groupby
import numpy as np
from astropy.table import unique, vstack

__all__ = ["ObservationFilter"]

Expand All @@ -23,8 +26,10 @@ class ObservationFilter:
class that corresponds to the filter type
(see `~gammapy.data.ObservationFilter.EVENT_FILTER_TYPES`)

The filtered event list will be an intersection of all filters. A union
of filters is not supported yet.
event_filter_method: str
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add an example with this argument in the docstring exemple?

The filterring method to use on the event_filters. Either "intersect" or "union". Default is "intersect".
If "intersect", the filtered event list will be the intersection of all the event filters.
If "union", the filtered event list will be the union of all the event fileters, under uniquness condition.

Examples
--------
Expand All @@ -44,9 +49,13 @@ class that corresponds to the filter type

EVENT_FILTER_TYPES = dict(sky_region="select_region", custom="select_parameter")

def __init__(self, time_filter=None, event_filters=None):
def __init__(
self, time_filter=None, event_filters=None, event_filter_method="intersect"
):
self.time_filter = time_filter
self.event_filters = event_filters or []
event_filters = event_filters or []
self.event_filters = self._merge_overlapping(event_filters)
self.event_filter_method = event_filter_method

@property
def livetime_fraction(self):
Expand All @@ -66,13 +75,38 @@ def filter_events(self, events):
filtered_events : `~gammapy.data.EventListBase`
The filtered event list
"""
from gammapy.data import EventList

filtered_events = self._filter_by_time(events)

for f in self.event_filters:
method_str = self.EVENT_FILTER_TYPES[f["type"]]
filtered_events = getattr(filtered_events, method_str)(**f["opts"])
if self.event_filter_method == "intersect":

for f in self.event_filters:
method_str = self.EVENT_FILTER_TYPES[f["type"]]
filtered_events = getattr(filtered_events, method_str)(**f["opts"])

return filtered_events

elif self.event_filter_method == "union":

filtered_events_list = []
for f in self.event_filters:
method_str = self.EVENT_FILTER_TYPES[f["type"]]
filtered_events_list.append(
getattr(filtered_events, method_str)(**f["opts"]).table
)

return filtered_events
table = unique(
vstack(filtered_events_list, join_type="exact").sort("TIME"),
keys="TIME",
)
tot_filtered_events = EventList(table)
return tot_filtered_events

else:
raise ValueError(
"event_filter_method has to be either 'intersect' or 'union'."
)

def filter_gti(self, gti):
"""Apply filters to a GTI table.
Expand Down Expand Up @@ -105,11 +139,43 @@ def copy(self):

@staticmethod
def _check_filter_phase(event_filter):
if not event_filter:
return 1
fraction = 0
for f in event_filter:
if f.get("opts").get("parameter") == "PHASE":
band = f.get("opts").get("band")
return band[1] - band[0]
else:
return 1
fraction += band[1] - band[0]

return 1 if fraction == 0 else fraction

@staticmethod
def _merge_overlapping(event_filter):

group_list = []
not_custom_list = []

sky_region_indices = [
idx for idx, f in enumerate(event_filter) if f["type"] == "sky_region"
]
for idx in reversed(sky_region_indices):
not_custom_list.append(event_filter.pop(idx))

for _, value in groupby(event_filter, lambda k: k["opts"]["parameter"]):
group_list.append(list(value))

new_event_filter = []
for group in group_list:
group.sort(key=lambda interval: interval["opts"]["band"][0])
merged = [group[0]]
for dictio in group:
previous = merged[-1]
if dictio["opts"]["band"][0] <= previous["opts"]["band"][1]:
band_max = max(
previous["opts"]["band"][1], dictio["opts"]["band"][1]
)
previous["opts"]["band"] = (previous["opts"]["band"][0], band_max)
else:
merged.append(dictio)
new_event_filter.append(merged)

new_event_filter.append(not_custom_list)
return np.concatenate(new_event_filter)
7 changes: 7 additions & 0 deletions gammapy/data/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def test_filter_gti(observation):
"p_in": [],
"p_out": 1,
},
{
"p_in": [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add other test with an union and an intersection?

{"type": "custom", "opts": dict(parameter="PHASE", band=(0.2, 0.4))},
{"type": "custom", "opts": dict(parameter="PHASE", band=(0.6, 0.8))},
],
"p_out": 0.4,
},
],
)
def test_check_filter_phase(pars):
Expand Down
Loading