-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #216 from sglvladi/gaters
Add Gater classes
- Loading branch information
Showing
12 changed files
with
175 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
Gater | ||
===== | ||
|
||
.. automodule:: stonesoup.gater | ||
:no-members: | ||
|
||
.. automodule:: stonesoup.gater.base | ||
:show-inheritance: | ||
|
||
Distance | ||
-------- | ||
|
||
.. automodule:: stonesoup.gater.distance | ||
:show-inheritance: | ||
|
||
Filtered | ||
-------- | ||
|
||
.. automodule:: stonesoup.gater.filtered | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# -*- coding: utf-8 -*- | ||
from .base import Gater | ||
|
||
__all__ = ['Gater'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# -*- coding: utf-8 -*- | ||
from ..base import Property | ||
from ..hypothesiser import Hypothesiser | ||
|
||
|
||
class Gater(Hypothesiser): | ||
"""Gater base class | ||
Gaters wrap :class:`.Hypothesiser` objects and can be used to modify (typically reduce) the | ||
returned hypotheses. | ||
""" | ||
|
||
hypothesiser = Property( | ||
Hypothesiser, doc="Hypothesiser that is being wrapped.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# -*- coding: utf-8 -*- | ||
from ..base import Property | ||
from ..measures import Measure | ||
from ..types.multihypothesis import MultipleHypothesis | ||
from .base import Gater | ||
|
||
|
||
class DistanceGater(Gater): | ||
""" Distance based gater | ||
Uses a measure to calculate the distance between a hypothesis' measurement prediction and the | ||
hypothised measurement, then removes any hypotheses whose calculated distance exceeds the | ||
specified gate threshold. | ||
""" | ||
measure = Property(Measure, | ||
doc="Measure class used to calculate the distance between the measurement " | ||
"prediction and the hypothesised measurement.") | ||
gate_threshold = Property(float, | ||
doc="The gate threshold. Hypotheses whose calculated distance " | ||
"exceeds this threshold will be filtered out.") | ||
|
||
def hypothesise(self, track, detections, *args, **kwargs): | ||
|
||
hypotheses = self.hypothesiser.hypothesise(track, detections, *args, **kwargs) | ||
|
||
gated_hypotheses = [hypothesis for hypothesis in hypotheses | ||
if (not hypothesis | ||
or self.measure(hypothesis.measurement_prediction, | ||
hypothesis.measurement) < self.gate_threshold)] | ||
|
||
return MultipleHypothesis(sorted(gated_hypotheses, reverse=True)) |
6 changes: 2 additions & 4 deletions
6
stonesoup/hypothesiser/filtered.py → stonesoup/gater/filtered.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# -*- coding: utf-8 -*- | ||
import pytest | ||
|
||
from ...types.prediction import ( | ||
GaussianMeasurementPrediction, GaussianStatePrediction) | ||
|
||
|
||
@pytest.fixture() | ||
def predictor(): | ||
class TestGaussianPredictor: | ||
def predict(self, prior, control_input=None, timestamp=None, **kwargs): | ||
return GaussianStatePrediction(prior.state_vector + 1, | ||
prior.covar * 2, timestamp) | ||
return TestGaussianPredictor() | ||
|
||
|
||
@pytest.fixture() | ||
def updater(): | ||
class TestGaussianUpdater: | ||
def predict_measurement(self, state_prediction, | ||
measurement_model=None, **kwargs): | ||
return GaussianMeasurementPrediction(state_prediction.state_vector, | ||
state_prediction.covar, | ||
state_prediction.timestamp) | ||
return TestGaussianUpdater() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import datetime | ||
import pytest | ||
import numpy as np | ||
|
||
from ..distance import DistanceGater | ||
from ...hypothesiser.probability import PDAHypothesiser | ||
from ...types.detection import Detection | ||
from ...types.hypothesis import SingleHypothesis | ||
from ...types.track import Track | ||
from ...types.update import GaussianStateUpdate | ||
from ... import measures as measures | ||
|
||
measure = measures.Mahalanobis() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"detections, gate_threshold, num_gated", | ||
[ | ||
( # Test 1 | ||
{Detection(np.array([[2]])), Detection(np.array([[3]])), Detection(np.array([[6]])), | ||
Detection(np.array([[0]])), Detection(np.array([[-1]])), Detection(np.array([[-4]]))}, | ||
1, | ||
3 | ||
), | ||
( # Test 2 | ||
{Detection(np.array([[2]])), Detection(np.array([[3]])), Detection(np.array([[6]])), | ||
Detection(np.array([[0]])), Detection(np.array([[-1]])), Detection(np.array([[-4]]))}, | ||
2, | ||
5 | ||
), | ||
( # Test 3 | ||
{Detection(np.array([[2]])), Detection(np.array([[3]])), Detection(np.array([[6]])), | ||
Detection(np.array([[0]])), Detection(np.array([[-1]])), Detection(np.array([[-4]]))}, | ||
4, | ||
7 | ||
) | ||
], | ||
ids=["test1", "test2", "test3"] | ||
) | ||
def test_distance(predictor, updater, detections, gate_threshold, num_gated): | ||
|
||
timestamp = datetime.datetime.now() | ||
|
||
hypothesiser = PDAHypothesiser(predictor, updater, clutter_spatial_density=0.000001) | ||
gater = DistanceGater(hypothesiser, measure=measure, gate_threshold=gate_threshold) | ||
|
||
track = Track([GaussianStateUpdate( | ||
np.array([[0]]), | ||
np.array([[1]]), | ||
SingleHypothesis( | ||
None, | ||
Detection(np.array([[0]]), metadata={"MMSI": 12345})), | ||
timestamp=timestamp)]) | ||
|
||
hypotheses = gater.hypothesise(track, detections, timestamp) | ||
|
||
# The number of gated hypotheses matches the expected | ||
assert len(hypotheses) == num_gated | ||
|
||
# The gated hypotheses are either the null hypothesis or their distance is less than the set | ||
# gate threshold | ||
assert all(not hypothesis.measurement or | ||
measure(hypothesis.measurement_prediction, hypothesis.measurement) < gate_threshold | ||
for hypothesis in hypotheses) | ||
|
||
# There is a SINGLE missed detection hypothesis | ||
assert len([hypothesis for hypothesis in hypotheses if not hypothesis]) == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters