Skip to content

Commit

Permalink
Added tests for skip_non_reversible in MultiMeasurementInitiator
Browse files Browse the repository at this point in the history
  • Loading branch information
gawebb-dstl committed Oct 4, 2023
1 parent 540b96e commit 6070eb0
Showing 1 changed file with 50 additions and 6 deletions.
56 changes: 50 additions & 6 deletions stonesoup/initiator/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def rvs(self):
measurement_model.matrix().T == approx(measurement_model.covar())


def create_multi_measurement_initiator(obj_class, updates_only: bool):
def create_multi_measurement_initiator(obj_class, **kwargs):
transition_model = CombinedLinearGaussianTransitionModel(
(ConstantVelocity(0.05), ConstantVelocity(0.05)))
measurement_model = LinearGaussian(
Expand All @@ -292,10 +292,15 @@ def create_multi_measurement_initiator(obj_class, updates_only: bool):
data_associator = NearestNeighbour(hypothesiser)
deleter = UpdateTimeDeleter(datetime.timedelta(seconds=59))

measurement_initiator = obj_class(
prior_state=GaussianState([[0], [0], [0], [0]], np.diag([0, 15, 0, 15])),
deleter=deleter, data_associator=data_associator, updater=updater,
measurement_model=measurement_model, updates_only=updates_only)
obj_kwargs = dict(prior_state=GaussianState([[0], [0], [0], [0]], np.diag([0, 15, 0, 15])),
deleter=deleter,
data_associator=data_associator,
updater=updater,
measurement_model=measurement_model)

obj_kwargs.update(kwargs)
measurement_initiator = obj_class(**obj_kwargs)

return measurement_initiator


Expand Down Expand Up @@ -326,7 +331,7 @@ def test_multi_measurement(updates_only):
assert len(measurement_initiator.holding_tracks) == 0


def test_multi_measurement2():
def test_no_history_multi_measurement():
measurement_initiator = create_multi_measurement_initiator(NoHistoryMultiMeasurementInitiator,
updates_only=False)

Expand All @@ -346,6 +351,45 @@ def test_multi_measurement2():
assert len(track) == 1


@pytest.mark.parametrize("measurement_model_class",
(CartesianToBearingRange, Cartesian2DToBearing, LinearGaussian))
@pytest.mark.parametrize("skip_non_reversible", (True, False))
def test_skip_in_multi_measurement(measurement_model_class, skip_non_reversible):
timestamp = datetime.datetime.now()

if measurement_model_class == Cartesian2DToBearing:
state_len = 1
else:
state_len = 2

measurement_model = measurement_model_class(ndim_state=2, mapping=(0, 1),
noise_covar=np.diag([1]*state_len))

det = Detection(state_vector=np.array([[2]*state_len]),
timestamp=timestamp,
measurement_model=measurement_model
)

interal_initiator = SinglePointInitiator(
prior_state=GaussianState([[0], [0]], np.diag([15, 15])))

measurement_initiator = create_multi_measurement_initiator(
MultiMeasurementInitiator,
initiator=interal_initiator,
measurement_model=None,
skip_non_reversible=skip_non_reversible)

measurement_initiator.initiate({det}, timestamp)
holding_tracks = measurement_initiator.holding_tracks

if isinstance(measurement_model, Cartesian2DToBearing) and skip_non_reversible:
assert len(holding_tracks) == 0
else:
assert len(holding_tracks) == 1
for track in holding_tracks:
assert track.timestamp == timestamp


@pytest.mark.parametrize("initiator", [
SinglePointInitiator(
GaussianState(np.array([[0]]), np.array([[100]]))
Expand Down

0 comments on commit 6070eb0

Please sign in to comment.