Skip to content

Commit

Permalink
Merge pull request #835 from dstl/NoHistoryMultiMeasurementInitiator
Browse files Browse the repository at this point in the history
Minor Addition and Changes to MultiMeasurementInitiator
  • Loading branch information
sdhiscocks committed Oct 24, 2023
2 parents 71b47f9 + 6070eb0 commit 609e820
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 7 deletions.
21 changes: 21 additions & 0 deletions stonesoup/initiator/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class MultiMeasurementInitiator(GaussianInitiator):
doc="Initiator used to create tracks. If None, a :class:`SimpleMeasurementInitiator` will "
"be created using :attr:`prior_state` and :attr:`measurement_model`. Otherwise, these "
"attributes are ignored.")
skip_non_reversible: bool = Property(
default=False, doc="Skip measurements that do not have a reversible measurement model. "
"Only allow measurements with a measurement model that is an instance "
"of a :class:`~.LinearModel` or a :class:`~.ReversibleModel`.")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -196,6 +200,10 @@ def initiate(self, detections, timestamp, **kwargs):

associated_detections = set()

if self.skip_non_reversible:
detections = {det for det in detections
if isinstance(det.measurement_model, (ReversibleModel, LinearModel))}

if self.holding_tracks:
associations = self.data_associator.associate(
self.holding_tracks, detections, timestamp)
Expand All @@ -221,6 +229,19 @@ def initiate(self, detections, timestamp, **kwargs):
return sure_tracks


class NoHistoryMultiMeasurementInitiator(MultiMeasurementInitiator):
"""
This initiator is very similar to :class:`MultiMeasurementInitiator`. The only difference
being that the holding track’s history is moved to the metadata so that initialised tracks
only have one state.
"""
def initiate(self, *args, **kwargs):
tracks = super().initiate(*args, **kwargs)
return {Track(id=track.id, states=[track.state],
init_metadata=dict(holding_track=track, **track.metadata))
for track in tracks}


class GaussianParticleInitiator(ParticleInitiator):
"""Gaussian Particle Initiator class
Expand Down
85 changes: 78 additions & 7 deletions stonesoup/initiator/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from ...types.update import ParticleStateUpdate, Update
from ..simple import (
SinglePointInitiator, SimpleMeasurementInitiator,
MultiMeasurementInitiator, GaussianParticleInitiator
MultiMeasurementInitiator, GaussianParticleInitiator,
NoHistoryMultiMeasurementInitiator
)


Expand Down Expand Up @@ -278,8 +279,7 @@ def rvs(self):
measurement_model.matrix().T == approx(measurement_model.covar())


@pytest.mark.parametrize('updates_only', [False, True])
def test_multi_measurement(updates_only):
def create_multi_measurement_initiator(obj_class, **kwargs):
transition_model = CombinedLinearGaussianTransitionModel(
(ConstantVelocity(0.05), ConstantVelocity(0.05)))
measurement_model = LinearGaussian(
Expand All @@ -292,10 +292,22 @@ def test_multi_measurement(updates_only):
data_associator = NearestNeighbour(hypothesiser)
deleter = UpdateTimeDeleter(datetime.timedelta(seconds=59))

measurement_initiator = MultiMeasurementInitiator(
GaussianState([[0], [0], [0], [0]], np.diag([0, 15, 0, 15])),
deleter, data_associator, updater,
measurement_model=measurement_model, updates_only=updates_only)
obj_kwargs = dict(prior_state=GaussianState([[0], [0], [0], [0]], np.diag([0, 15, 0, 15])),
deleter=deleter,
data_associator=data_associator,
updater=updater,
measurement_model=measurement_model)

obj_kwargs.update(kwargs)
measurement_initiator = obj_class(**obj_kwargs)

return measurement_initiator


@pytest.mark.parametrize('updates_only', [False, True])
def test_multi_measurement(updates_only):
measurement_initiator = create_multi_measurement_initiator(MultiMeasurementInitiator,
updates_only=updates_only)

timestamp = datetime.datetime.now()
first_detections = {Detection(np.array([[5], [2]]), timestamp),
Expand All @@ -319,6 +331,65 @@ def test_multi_measurement(updates_only):
assert len(measurement_initiator.holding_tracks) == 0


def test_no_history_multi_measurement():
measurement_initiator = create_multi_measurement_initiator(NoHistoryMultiMeasurementInitiator,
updates_only=False)

timestamp = datetime.datetime.now()
first_detections = {Detection(np.array([[5], [2]]), timestamp),
Detection(np.array([[-5], [-2]]), timestamp)}

first_tracks = measurement_initiator.initiate(first_detections, timestamp)
assert len(first_tracks) == 0
assert len(measurement_initiator.holding_tracks) == 2

timestamp = datetime.datetime.now() + datetime.timedelta(seconds=60)
second_detections = {Detection(np.array([[5], [3]]), timestamp)}

second_tracks = measurement_initiator.initiate(second_detections, timestamp)
for track in second_tracks:
assert len(track) == 1


@pytest.mark.parametrize("measurement_model_class",
(CartesianToBearingRange, Cartesian2DToBearing, LinearGaussian))
@pytest.mark.parametrize("skip_non_reversible", (True, False))
def test_skip_in_multi_measurement(measurement_model_class, skip_non_reversible):
timestamp = datetime.datetime.now()

if measurement_model_class == Cartesian2DToBearing:
state_len = 1
else:
state_len = 2

measurement_model = measurement_model_class(ndim_state=2, mapping=(0, 1),
noise_covar=np.diag([1]*state_len))

det = Detection(state_vector=np.array([[2]*state_len]),
timestamp=timestamp,
measurement_model=measurement_model
)

interal_initiator = SinglePointInitiator(
prior_state=GaussianState([[0], [0]], np.diag([15, 15])))

measurement_initiator = create_multi_measurement_initiator(
MultiMeasurementInitiator,
initiator=interal_initiator,
measurement_model=None,
skip_non_reversible=skip_non_reversible)

measurement_initiator.initiate({det}, timestamp)
holding_tracks = measurement_initiator.holding_tracks

if isinstance(measurement_model, Cartesian2DToBearing) and skip_non_reversible:
assert len(holding_tracks) == 0
else:
assert len(holding_tracks) == 1
for track in holding_tracks:
assert track.timestamp == timestamp


@pytest.mark.parametrize("initiator", [
SinglePointInitiator(
GaussianState(np.array([[0]]), np.array([[100]]))
Expand Down

0 comments on commit 609e820

Please sign in to comment.