Skip to content

Commit

Permalink
Merge pull request #216 from sglvladi/gaters
Browse files Browse the repository at this point in the history
Add Gater classes
  • Loading branch information
sdhiscocks committed May 26, 2020
2 parents b500e1d + 0210713 commit e7b6a8c
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 26 deletions.
20 changes: 20 additions & 0 deletions docs/source/stonesoup.gater.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Gater
=====

.. automodule:: stonesoup.gater
:no-members:

.. automodule:: stonesoup.gater.base
:show-inheritance:

Distance
--------

.. automodule:: stonesoup.gater.distance
:show-inheritance:

Filtered
--------

.. automodule:: stonesoup.gater.filtered
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/stonesoup.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Tracker Components

stonesoup.dataassociator
stonesoup.deleter
stonesoup.gater
stonesoup.hypothesiser
stonesoup.initiator
stonesoup.mixturereducer
Expand Down
16 changes: 3 additions & 13 deletions stonesoup/dataassociator/probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ class JPDA(DataAssociator):
hypothesiser = Property(
PDAHypothesiser,
doc="Generate a set of hypotheses for each prediction-detection pair")
gate_ratio = Property(
float,
doc="If probability of Detection/Track association is less than this "
"many times less than probability of MissedDetection, treat "
"probability of association as 0."
)

def associate(self, tracks, detections, time):
"""Associate detections with predicted states.
Expand Down Expand Up @@ -109,7 +103,7 @@ def associate(self, tracks, detections, time):

# enumerate the Joint Hypotheses of track/detection associations
joint_hypotheses = \
self.enumerate_JPDA_hypotheses(tracks, hypotheses, self.gate_ratio)
self.enumerate_JPDA_hypotheses(tracks, hypotheses)

# Calculate MultiMeasurementHypothesis for each Track over all
# available Detections with probabilities drawn from JointHypotheses
Expand Down Expand Up @@ -156,7 +150,7 @@ def associate(self, tracks, detections, time):
return new_hypotheses

@classmethod
def enumerate_JPDA_hypotheses(cls, tracks, multihypths, gate_ratio):
def enumerate_JPDA_hypotheses(cls, tracks, multihypths):

joint_hypotheses = list()

Expand All @@ -171,13 +165,9 @@ def enumerate_JPDA_hypotheses(cls, tracks, multihypths, gate_ratio):

for track in tracks:
track_possible_assoc = list()
missed_probability = \
multihypths[track].get_missed_detection_probability()
missed_gate = missed_probability/gate_ratio
for hypothesis in multihypths[track]:
# Always include missed detection (gate ratio < 1)
if not hypothesis or hypothesis.probability >= missed_gate:
track_possible_assoc.append(hypothesis)
track_possible_assoc.append(hypothesis)
possible_assoc.append(track_possible_assoc)

# enumerate all valid JPDA joint hypotheses
Expand Down
2 changes: 1 addition & 1 deletion stonesoup/dataassociator/tests/test_probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def associator(request, probability_hypothesiser):
if request.param is PDA:
return request.param(probability_hypothesiser)
elif request.param is JPDA:
return request.param(probability_hypothesiser, 5)
return request.param(probability_hypothesiser)


def test_probability(associator):
Expand Down
4 changes: 4 additions & 0 deletions stonesoup/gater/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
from .base import Gater

__all__ = ['Gater']
14 changes: 14 additions & 0 deletions stonesoup/gater/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from ..base import Property
from ..hypothesiser import Hypothesiser


class Gater(Hypothesiser):
"""Gater base class
Gaters wrap :class:`.Hypothesiser` objects and can be used to modify (typically reduce) the
returned hypotheses.
"""

hypothesiser = Property(
Hypothesiser, doc="Hypothesiser that is being wrapped.")
31 changes: 31 additions & 0 deletions stonesoup/gater/distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
from ..base import Property
from ..measures import Measure
from ..types.multihypothesis import MultipleHypothesis
from .base import Gater


class DistanceGater(Gater):
""" Distance based gater
Uses a measure to calculate the distance between a hypothesis' measurement prediction and the
hypothised measurement, then removes any hypotheses whose calculated distance exceeds the
specified gate threshold.
"""
measure = Property(Measure,
doc="Measure class used to calculate the distance between the measurement "
"prediction and the hypothesised measurement.")
gate_threshold = Property(float,
doc="The gate threshold. Hypotheses whose calculated distance "
"exceeds this threshold will be filtered out.")

def hypothesise(self, track, detections, *args, **kwargs):

hypotheses = self.hypothesiser.hypothesise(track, detections, *args, **kwargs)

gated_hypotheses = [hypothesis for hypothesis in hypotheses
if (not hypothesis
or self.measure(hypothesis.measurement_prediction,
hypothesis.measurement) < self.gate_threshold)]

return MultipleHypothesis(sorted(gated_hypotheses, reverse=True))
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
from .base import Hypothesiser
from .base import Gater
from ..base import Property


class FilteredDetectionsHypothesiser(Hypothesiser):
class FilteredDetectionsGater(Gater):
"""Wrapper for Hypothesisers - filters input data
Wrapper for any type of hypothesiser - filters the 'detections' before
they are fed into the hypothesiser.
"""

hypothesiser = Property(
Hypothesiser, doc="Hypothesiser that is being wrapped.")
metadata_filter = Property(
str, doc="Metadata attribute used to filter which detections "
"tracks are valid for association.")
Expand Down
Empty file.
25 changes: 25 additions & 0 deletions stonesoup/gater/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
import pytest

from ...types.prediction import (
GaussianMeasurementPrediction, GaussianStatePrediction)


@pytest.fixture()
def predictor():
class TestGaussianPredictor:
def predict(self, prior, control_input=None, timestamp=None, **kwargs):
return GaussianStatePrediction(prior.state_vector + 1,
prior.covar * 2, timestamp)
return TestGaussianPredictor()


@pytest.fixture()
def updater():
class TestGaussianUpdater:
def predict_measurement(self, state_prediction,
measurement_model=None, **kwargs):
return GaussianMeasurementPrediction(state_prediction.state_vector,
state_prediction.covar,
state_prediction.timestamp)
return TestGaussianUpdater()
67 changes: 67 additions & 0 deletions stonesoup/gater/tests/test_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import datetime
import pytest
import numpy as np

from ..distance import DistanceGater
from ...hypothesiser.probability import PDAHypothesiser
from ...types.detection import Detection
from ...types.hypothesis import SingleHypothesis
from ...types.track import Track
from ...types.update import GaussianStateUpdate
from ... import measures as measures

measure = measures.Mahalanobis()


@pytest.mark.parametrize(
"detections, gate_threshold, num_gated",
[
( # Test 1
{Detection(np.array([[2]])), Detection(np.array([[3]])), Detection(np.array([[6]])),
Detection(np.array([[0]])), Detection(np.array([[-1]])), Detection(np.array([[-4]]))},
1,
3
),
( # Test 2
{Detection(np.array([[2]])), Detection(np.array([[3]])), Detection(np.array([[6]])),
Detection(np.array([[0]])), Detection(np.array([[-1]])), Detection(np.array([[-4]]))},
2,
5
),
( # Test 3
{Detection(np.array([[2]])), Detection(np.array([[3]])), Detection(np.array([[6]])),
Detection(np.array([[0]])), Detection(np.array([[-1]])), Detection(np.array([[-4]]))},
4,
7
)
],
ids=["test1", "test2", "test3"]
)
def test_distance(predictor, updater, detections, gate_threshold, num_gated):

timestamp = datetime.datetime.now()

hypothesiser = PDAHypothesiser(predictor, updater, clutter_spatial_density=0.000001)
gater = DistanceGater(hypothesiser, measure=measure, gate_threshold=gate_threshold)

track = Track([GaussianStateUpdate(
np.array([[0]]),
np.array([[1]]),
SingleHypothesis(
None,
Detection(np.array([[0]]), metadata={"MMSI": 12345})),
timestamp=timestamp)])

hypotheses = gater.hypothesise(track, detections, timestamp)

# The number of gated hypotheses matches the expected
assert len(hypotheses) == num_gated

# The gated hypotheses are either the null hypothesis or their distance is less than the set
# gate threshold
assert all(not hypothesis.measurement or
measure(hypothesis.measurement_prediction, hypothesis.measurement) < gate_threshold
for hypothesis in hypotheses)

# There is a SINGLE missed detection hypothesis
assert len([hypothesis for hypothesis in hypotheses if not hypothesis]) == 1
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import datetime
from operator import attrgetter

import numpy as np
from operator import attrgetter

from ..distance import DistanceHypothesiser
from ..filtered import FilteredDetectionsHypothesiser
from ..filtered import FilteredDetectionsGater
from ...hypothesiser.distance import DistanceHypothesiser
from ...types.detection import Detection
from ...types.hypothesis import SingleHypothesis
from ...types.track import Track
Expand All @@ -24,7 +23,7 @@ def test_filtereddetections(predictor, updater):
hypothesiser = DistanceHypothesiser(
predictor, updater, measure=measure, missed_distance=0.2,
include_all=True)
hypothesiser_wrapper = FilteredDetectionsHypothesiser(
hypothesiser_wrapper = FilteredDetectionsGater(
hypothesiser, "MMSI", match_missing=True)

track = Track([GaussianStateUpdate(
Expand Down Expand Up @@ -63,7 +62,7 @@ def test_filtereddetections_empty_detections(predictor, updater):
timestamp = datetime.datetime.now()
hypothesiser = DistanceHypothesiser(predictor, updater,
measure=measure, missed_distance=0.2)
hypothesiser_wrapper = FilteredDetectionsHypothesiser(
hypothesiser_wrapper = FilteredDetectionsGater(
hypothesiser, "MMSI", match_missing=False)

track = Track([GaussianStateUpdate(
Expand Down Expand Up @@ -94,7 +93,7 @@ def test_filtereddetections_no_track_metadata(predictor, updater):
hypothesiser = DistanceHypothesiser(
predictor, updater, measure=measure, missed_distance=0.2,
include_all=True)
hypothesiser_wrapper = FilteredDetectionsHypothesiser(
hypothesiser_wrapper = FilteredDetectionsGater(
hypothesiser, "MMSI", match_missing=True)

track = Track([GaussianStateUpdate(
Expand Down Expand Up @@ -135,7 +134,7 @@ def test_filtereddetections_no_matching_metadata(predictor, updater):
timestamp = datetime.datetime.now()
hypothesiser = DistanceHypothesiser(predictor, updater,
measure=measure, missed_distance=0.2)
hypothesiser_wrapper = FilteredDetectionsHypothesiser(
hypothesiser_wrapper = FilteredDetectionsGater(
hypothesiser, "MMSI", match_missing=True)

track = Track([GaussianStateUpdate(
Expand Down

0 comments on commit e7b6a8c

Please sign in to comment.