From 765bba0dd2ab6c2a7ce5157940364b924b3f21bd Mon Sep 17 00:00:00 2001 From: Timothy Glover Date: Thu, 20 Jul 2023 15:21:14 +0100 Subject: [PATCH 1/3] Mode an ammendment to the regulariser. Incorporated the regulariser into the base ParticleUpdater --- stonesoup/predictor/particle.py | 1 + stonesoup/regulariser/particle.py | 28 ++++++---- stonesoup/regulariser/tests/test_particle.py | 18 +++++-- stonesoup/updater/particle.py | 9 +++- stonesoup/updater/tests/test_particle.py | 56 +++++++++++++++++++- 5 files changed, 95 insertions(+), 17 deletions(-) diff --git a/stonesoup/predictor/particle.py b/stonesoup/predictor/particle.py index 9b22b8f3b..306ae40e0 100644 --- a/stonesoup/predictor/particle.py +++ b/stonesoup/predictor/particle.py @@ -341,6 +341,7 @@ def predict(self, prior, timestamp=None, **kwargs): existence_probability=predicted_existence, parent=untransitioned_state, timestamp=timestamp, + transition_model=self.transition_model ) return new_particle_state diff --git a/stonesoup/regulariser/particle.py b/stonesoup/regulariser/particle.py index 0d560bbc3..3d905f118 100644 --- a/stonesoup/regulariser/particle.py +++ b/stonesoup/regulariser/particle.py @@ -14,7 +14,7 @@ class MCMCRegulariser(Regulariser): of effectiveness. Sometimes this is not desirable, or possible, when a particular algorithm requires the introduction of new samples as part of the filtering process for example. - This is a particlar implementation of a MCMC move step that uses the Metropolis-Hastings + This is a particular implementation of a MCMC move step that uses the Metropolis-Hastings algorithm [1]_. After resampling, particles are moved a small amount, according do a Gaussian kernel, to a new state only if the Metropolis-Hastings acceptance probability is met by a random number assigned to each particle from a uniform random distribution, otherwise they @@ -24,21 +24,23 @@ class MCMCRegulariser(Regulariser): ---------- .. [1] Robert, Christian P. & Casella, George, Monte Carlo Statistical Methods, Springer, 1999. - .. [2] Ristic, Branco & Arulampalam, Sanjeev & Gordon, Neil, Beyond the Kalman Filter: + .. [2] Ristic, Branko & Arulampalam, Sanjeev & Gordon, Neil, Beyond the Kalman Filter: Particle Filters for Target Tracking Applications, Artech House, 2004. """ - def regularise(self, prior, posterior, detections): + def regularise(self, prior, posterior, detections, transition_model=None): """Regularise the particles Parameters ---------- prior : :class:`~.ParticleState` type or list of :class:`~.Particle` - prior particle distribution + prior particle distribution. posterior : :class:`~.ParticleState` type or list of :class:`~.Particle` posterior particle distribution detections : set of :class:`~.Detection` set of detections containing clutter, true detections or both + transition_model : :class:`~.TransitionModel` + Transition model used in the prediction step. Returns ------- @@ -54,6 +56,14 @@ def regularise(self, prior, posterior, detections): regularised_particles = copy.copy(posterior) moved_particles = copy.copy(posterior) + transitioned_prior = copy.copy(prior) + + if transition_model is not None: + time_interval = posterior.timestamp - prior.timestamp + new_state_vector = transition_model.function(prior, + noise=False, + time_interval=time_interval) + transitioned_prior.state_vector = new_state_vector if detections is not None: ndim = prior.state_vector.shape[0] @@ -70,13 +80,11 @@ def regularise(self, prior, posterior, detections): hopt * cholesky_eps(covar_est) @ np.random.randn(ndim, nparticles) # Evaluate likelihoods - part_diff = moved_particles.state_vector - prior.state_vector - part_diff_mean = np.average(part_diff, axis=1) - move_likelihood = multivariate_normal.logpdf((part_diff - part_diff_mean).T, + part_diff = moved_particles.state_vector - transitioned_prior.state_vector + move_likelihood = multivariate_normal.logpdf(part_diff.T, cov=covar_est) - post_part_diff = posterior.state_vector - prior.state_vector - post_part_diff_mean = np.average(post_part_diff, axis=1) - post_likelihood = multivariate_normal.logpdf((post_part_diff - post_part_diff_mean).T, + post_part_diff = posterior.state_vector - transitioned_prior.state_vector + post_likelihood = multivariate_normal.logpdf(post_part_diff.T, cov=covar_est) # Evaluate measurement likelihoods diff --git a/stonesoup/regulariser/tests/test_particle.py b/stonesoup/regulariser/tests/test_particle.py index 673855f49..6e2c3295c 100644 --- a/stonesoup/regulariser/tests/test_particle.py +++ b/stonesoup/regulariser/tests/test_particle.py @@ -6,12 +6,15 @@ from ...types.hypothesis import SingleHypothesis from ...types.prediction import ParticleStatePrediction, ParticleMeasurementPrediction from ...models.measurement.linear import LinearGaussian +from ...models.transition.linear import CombinedLinearGaussianTransitionModel, ConstantVelocity from ...types.detection import Detection from ...types.update import ParticleStateUpdate from ..particle import MCMCRegulariser def test_regulariser(): + transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]) + particles = ParticleState(state_vector=None, particle_list=[Particle(np.array([[10], [10]]), 1 / 9), Particle(np.array([[10], [20]]), @@ -32,20 +35,25 @@ def test_regulariser(): 1 / 9), ]) timestamp = datetime.datetime.now() - prediction = ParticleStatePrediction(None, particle_list=particles.particle_list, - timestamp=timestamp) - meas_pred = ParticleMeasurementPrediction(None, particle_list=particles, timestamp=timestamp) + new_state_vector = transition_model.function(particles, + noise=True, + time_interval=datetime.timedelta(seconds=1)) + prediction = ParticleStatePrediction(new_state_vector, + timestamp=timestamp, + transition_model=transition_model) + meas_pred = ParticleMeasurementPrediction(prediction, timestamp=timestamp) measurement_model = LinearGaussian(ndim_state=2, mapping=(0, 1), noise_covar=np.eye(2)) measurement = [Detection(state_vector=np.array([[5], [7]]), timestamp=timestamp, measurement_model=measurement_model)] state_update = ParticleStateUpdate(None, SingleHypothesis(prediction=prediction, measurement=measurement, measurement_prediction=meas_pred), - particle_list=particles.particle_list, timestamp=timestamp) + particle_list=particles.particle_list, + timestamp=timestamp+datetime.timedelta(seconds=1)) regulariser = MCMCRegulariser() # state check - new_particles = regulariser.regularise(particles, state_update, measurement) + new_particles = regulariser.regularise(prediction, state_update, measurement, transition_model) # Check the shape of the new state vector assert new_particles.state_vector.shape == state_update.state_vector.shape # Check weights are unchanged diff --git a/stonesoup/updater/particle.py b/stonesoup/updater/particle.py index 245b5da75..d9aa4f780 100644 --- a/stonesoup/updater/particle.py +++ b/stonesoup/updater/particle.py @@ -65,6 +65,12 @@ def update(self, hypothesis, **kwargs): if self.resampler is not None: predicted_state = self.resampler.resample(predicted_state) + if self.regulariser is not None: + predicted_state = self.regulariser.regularise(predicted_state.parent, + predicted_state, + [hypothesis.measurement], + hypothesis.prediction.transition_model) + return Update.from_state( state=hypothesis.prediction, state_vector=predicted_state.state_vector, @@ -465,7 +471,8 @@ def update(self, hypotheses, **kwargs): if self.regulariser is not None: regularised_parts = self.regulariser.regularise(updated_state.parent, updated_state, - detections) + detections, + prediction.transition_model) updated_state.state_vector = regularised_parts.state_vector return Update.from_state( diff --git a/stonesoup/updater/tests/test_particle.py b/stonesoup/updater/tests/test_particle.py index 4f9f421ab..0d7630cba 100644 --- a/stonesoup/updater/tests/test_particle.py +++ b/stonesoup/updater/tests/test_particle.py @@ -12,13 +12,14 @@ from ...types.hypothesis import SingleHypothesis from ...types.multihypothesis import MultipleHypothesis from ...types.particle import Particle +from ...types.state import ParticleState from ...types.prediction import ( ParticleStatePrediction, ParticleMeasurementPrediction) from ...updater.particle import ( ParticleUpdater, GromovFlowParticleUpdater, GromovFlowKalmanParticleUpdater, BernoulliParticleUpdater) from ...predictor.particle import BernoulliParticlePredictor -from ...models.transition.linear import ConstantVelocity +from ...models.transition.linear import ConstantVelocity, CombinedLinearGaussianTransitionModel from ...types.update import BernoulliParticleStateUpdate from ...regulariser.particle import MCMCRegulariser from ...sampler.particle import ParticleSampler @@ -170,3 +171,56 @@ def test_bernoulli_particle(): assert update.hypothesis == hypotheses # Check that the existence probability is returned assert update.existence_probability is not None + + +def test_regularised_particle(): + + transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]) + measurement_model = LinearGaussian( + ndim_state=2, mapping=[0], noise_covar=np.array([[10]])) + + updater = ParticleUpdater(regulariser=MCMCRegulariser(), + measurement_model=measurement_model) + # Measurement model + timestamp = datetime.datetime.now() + particles = [Particle([[10], [10]], 1 / 9), + Particle([[10], [20]], 1 / 9), + Particle([[10], [30]], 1 / 9), + Particle([[20], [10]], 1 / 9), + Particle([[20], [20]], 1 / 9), + Particle([[20], [30]], 1 / 9), + Particle([[30], [10]], 1 / 9), + Particle([[30], [20]], 1 / 9), + Particle([[30], [30]], 1 / 9), + ] + + particles = ParticleState(None, particle_list=particles, timestamp=timestamp) + predicted_state = transition_model.function(particles, + noise=True, + time_interval=datetime.timedelta(seconds=1)) + prediction = ParticleStatePrediction(predicted_state, + weight=np.array([1/9]*9), + timestamp=timestamp, + transition_model=transition_model, + parent=particles) + + measurement = Detection([[40.0]], timestamp=timestamp, measurement_model=measurement_model) + eval_measurement_prediction = ParticleMeasurementPrediction( + StateVectors([prediction.state_vector[0, :]]), timestamp=timestamp) + + measurement_prediction = updater.predict_measurement(prediction) + + assert np.all(eval_measurement_prediction.state_vector == measurement_prediction.state_vector) + assert measurement_prediction.timestamp == timestamp + + updated_state = updater.update(SingleHypothesis( + prediction, measurement, measurement_prediction)) + + # Don't know what the particles will exactly be due to randomness so check + # some obvious properties + + assert np.all(weight == 1 / 9 for weight in updated_state.weight) + assert updated_state.timestamp == timestamp + assert updated_state.hypothesis.measurement_prediction == measurement_prediction + assert updated_state.hypothesis.prediction == prediction + assert updated_state.hypothesis.measurement == measurement From 62d43e6d4b85f2f16d9ac61ad488e9fe8e4d50eb Mon Sep 17 00:00:00 2001 From: Timothy Glover Date: Thu, 20 Jul 2023 16:18:22 +0100 Subject: [PATCH 2/3] Drop support for particle lists and move the transition model to a property of MCMCRegulariser() --- stonesoup/regulariser/particle.py | 24 ++++++++++--------- stonesoup/regulariser/tests/test_particle.py | 25 ++++++++++++-------- stonesoup/updater/particle.py | 6 ++--- stonesoup/updater/tests/test_particle.py | 4 ++-- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/stonesoup/regulariser/particle.py b/stonesoup/regulariser/particle.py index 3d905f118..0de4ccc85 100644 --- a/stonesoup/regulariser/particle.py +++ b/stonesoup/regulariser/particle.py @@ -5,6 +5,8 @@ from .base import Regulariser from ..functions import cholesky_eps from ..types.state import ParticleState +from ..models.transition import TransitionModel +from ..base import Property class MCMCRegulariser(Regulariser): @@ -27,20 +29,20 @@ class MCMCRegulariser(Regulariser): .. [2] Ristic, Branko & Arulampalam, Sanjeev & Gordon, Neil, Beyond the Kalman Filter: Particle Filters for Target Tracking Applications, Artech House, 2004. """ - def regularise(self, prior, posterior, detections, transition_model=None): + transition_model: TransitionModel = Property(doc="Transition model used for prediction") + + def regularise(self, prior, posterior, detections): """Regularise the particles Parameters ---------- - prior : :class:`~.ParticleState` type or list of :class:`~.Particle` + prior : :class:`~.ParticleState` type prior particle distribution. - posterior : :class:`~.ParticleState` type or list of :class:`~.Particle` + posterior : :class:`~.ParticleState` type posterior particle distribution detections : set of :class:`~.Detection` set of detections containing clutter, true detections or both - transition_model : :class:`~.TransitionModel` - Transition model used in the prediction step. Returns ------- @@ -49,20 +51,20 @@ def regularise(self, prior, posterior, detections, transition_model=None): """ if not isinstance(posterior, ParticleState): - posterior = ParticleState(None, particle_list=posterior) + raise TypeError('Only ParticleState type is supported!') if not isinstance(prior, ParticleState): - prior = ParticleState(None, particle_list=prior) + raise TypeError('Only ParticleState type is supported!') regularised_particles = copy.copy(posterior) moved_particles = copy.copy(posterior) transitioned_prior = copy.copy(prior) - if transition_model is not None: + if self.transition_model is not None: time_interval = posterior.timestamp - prior.timestamp - new_state_vector = transition_model.function(prior, - noise=False, - time_interval=time_interval) + new_state_vector = self.transition_model.function(prior, + noise=False, + time_interval=time_interval) transitioned_prior.state_vector = new_state_vector if detections is not None: diff --git a/stonesoup/regulariser/tests/test_particle.py b/stonesoup/regulariser/tests/test_particle.py index 6e2c3295c..7c37fe167 100644 --- a/stonesoup/regulariser/tests/test_particle.py +++ b/stonesoup/regulariser/tests/test_particle.py @@ -1,5 +1,6 @@ import numpy as np import datetime +import pytest from ...types.state import ParticleState from ...types.particle import Particle @@ -50,10 +51,10 @@ def test_regulariser(): measurement_prediction=meas_pred), particle_list=particles.particle_list, timestamp=timestamp+datetime.timedelta(seconds=1)) - regulariser = MCMCRegulariser() + regulariser = MCMCRegulariser(transition_model=transition_model) # state check - new_particles = regulariser.regularise(prediction, state_update, measurement, transition_model) + new_particles = regulariser.regularise(prediction, state_update, measurement) # Check the shape of the new state vector assert new_particles.state_vector.shape == state_update.state_vector.shape # Check weights are unchanged @@ -61,13 +62,17 @@ def test_regulariser(): # Check that the timestamp is the same assert new_particles.timestamp == state_update.timestamp - # list check - new_particles = regulariser.regularise(particles.particle_list, state_update.particle_list, - measurement) - # Check the shape of the new state vector - assert new_particles.state_vector.shape == state_update.state_vector.shape - # Check weights are unchanged - assert any(new_particles.weight == state_update.weight) + # list check3 + with pytest.raises(TypeError) as e: + new_particles = regulariser.regularise(particles.particle_list, + state_update, + measurement) + assert "Only ParticleState type is supported!" in str(e.value) + with pytest.raises(Exception) as e: + new_particles = regulariser.regularise(particles, + state_update.particle_list, + measurement) + assert "Only ParticleState type is supported!" in str(e.value) def test_no_measurement(): @@ -98,7 +103,7 @@ def test_no_measurement(): measurement=None, measurement_prediction=meas_pred), particle_list=particles.particle_list, timestamp=timestamp) - regulariser = MCMCRegulariser() + regulariser = MCMCRegulariser(transition_model=None) new_particles = regulariser.regularise(particles, state_update, detections=None) diff --git a/stonesoup/updater/particle.py b/stonesoup/updater/particle.py index d9aa4f780..5728da2b4 100644 --- a/stonesoup/updater/particle.py +++ b/stonesoup/updater/particle.py @@ -68,8 +68,7 @@ def update(self, hypothesis, **kwargs): if self.regulariser is not None: predicted_state = self.regulariser.regularise(predicted_state.parent, predicted_state, - [hypothesis.measurement], - hypothesis.prediction.transition_model) + [hypothesis.measurement]) return Update.from_state( state=hypothesis.prediction, @@ -471,8 +470,7 @@ def update(self, hypotheses, **kwargs): if self.regulariser is not None: regularised_parts = self.regulariser.regularise(updated_state.parent, updated_state, - detections, - prediction.transition_model) + detections) updated_state.state_vector = regularised_parts.state_vector return Update.from_state( diff --git a/stonesoup/updater/tests/test_particle.py b/stonesoup/updater/tests/test_particle.py index 0d7630cba..3fd7d763b 100644 --- a/stonesoup/updater/tests/test_particle.py +++ b/stonesoup/updater/tests/test_particle.py @@ -140,7 +140,7 @@ def test_bernoulli_particle(): prediction = predictor.predict(prior, timestamp=new_timestamp) resampler = SystematicResampler() - regulariser = MCMCRegulariser() + regulariser = MCMCRegulariser(transition_model=cv) updater = BernoulliParticleUpdater(measurement_model=None, resampler=resampler, @@ -179,7 +179,7 @@ def test_regularised_particle(): measurement_model = LinearGaussian( ndim_state=2, mapping=[0], noise_covar=np.array([[10]])) - updater = ParticleUpdater(regulariser=MCMCRegulariser(), + updater = ParticleUpdater(regulariser=MCMCRegulariser(transition_model=transition_model), measurement_model=measurement_model) # Measurement model timestamp = datetime.datetime.now() From 465b167f478f0227aa6ebae581fca4a3529ba679 Mon Sep 17 00:00:00 2001 From: Timothy Glover Date: Mon, 21 Aug 2023 16:07:35 +0100 Subject: [PATCH 3/3] Implement suggested comments and optimise user interface of regularise method --- stonesoup/predictor/particle.py | 1 + stonesoup/regulariser/particle.py | 27 ++++---- stonesoup/regulariser/tests/test_particle.py | 73 ++++++++++++++------ stonesoup/types/multihypothesis.py | 12 ++-- stonesoup/updater/particle.py | 43 ++++++------ stonesoup/updater/tests/test_particle.py | 38 +++++++--- 6 files changed, 121 insertions(+), 73 deletions(-) diff --git a/stonesoup/predictor/particle.py b/stonesoup/predictor/particle.py index 306ae40e0..936c92b98 100644 --- a/stonesoup/predictor/particle.py +++ b/stonesoup/predictor/particle.py @@ -54,6 +54,7 @@ def predict(self, prior, timestamp=None, **kwargs): **kwargs) return Prediction.from_state(prior, + parent=prior, state_vector=new_state_vector, timestamp=timestamp, transition_model=self.transition_model) diff --git a/stonesoup/regulariser/particle.py b/stonesoup/regulariser/particle.py index 0de4ccc85..8ac602af6 100644 --- a/stonesoup/regulariser/particle.py +++ b/stonesoup/regulariser/particle.py @@ -1,6 +1,7 @@ import copy import numpy as np from scipy.stats import multivariate_normal, uniform +from typing import Sequence from .base import Regulariser from ..functions import cholesky_eps @@ -29,9 +30,10 @@ class MCMCRegulariser(Regulariser): .. [2] Ristic, Branko & Arulampalam, Sanjeev & Gordon, Neil, Beyond the Kalman Filter: Particle Filters for Target Tracking Applications, Artech House, 2004. """ - transition_model: TransitionModel = Property(doc="Transition model used for prediction") + transition_model: TransitionModel = Property(doc="Transition model used for prediction", + default=None) - def regularise(self, prior, posterior, detections): + def regularise(self, prior, posterior): """Regularise the particles Parameters @@ -39,10 +41,7 @@ def regularise(self, prior, posterior, detections): prior : :class:`~.ParticleState` type prior particle distribution. posterior : :class:`~.ParticleState` type - posterior particle distribution - detections : set of :class:`~.Detection` - set of detections containing clutter, - true detections or both + posterior particle distribution. Returns ------- @@ -60,14 +59,18 @@ def regularise(self, prior, posterior, detections): moved_particles = copy.copy(posterior) transitioned_prior = copy.copy(prior) - if self.transition_model is not None: + hypotheses = posterior.hypothesis if isinstance(posterior.hypothesis, Sequence) \ + else [posterior.hypothesis] + + transition_model = hypotheses[0].prediction.transition_model or self.transition_model + if transition_model is not None: time_interval = posterior.timestamp - prior.timestamp - new_state_vector = self.transition_model.function(prior, - noise=False, - time_interval=time_interval) - transitioned_prior.state_vector = new_state_vector + transitioned_prior.state_vector = \ + transition_model.function(prior, noise=False, time_interval=time_interval) + + detections = {hypothesis.measurement for hypothesis in hypotheses if hypothesis} - if detections is not None: + if detections: ndim = prior.state_vector.shape[0] nparticles = len(posterior) diff --git a/stonesoup/regulariser/tests/test_particle.py b/stonesoup/regulariser/tests/test_particle.py index 7c37fe167..70b601de2 100644 --- a/stonesoup/regulariser/tests/test_particle.py +++ b/stonesoup/regulariser/tests/test_particle.py @@ -9,13 +9,29 @@ from ...models.measurement.linear import LinearGaussian from ...models.transition.linear import CombinedLinearGaussianTransitionModel, ConstantVelocity from ...types.detection import Detection -from ...types.update import ParticleStateUpdate +from ...types.update import Update, ParticleStateUpdate from ..particle import MCMCRegulariser -def test_regulariser(): - transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]) - +@pytest.mark.parametrize( + "transition_model, model_flag", + [ + ( + CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model + False, # model_flag + ), + ( + CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model + True, # model_flag + ), + ( + None, # transition_model + False, # model_flag + ) + ], + ids=["with_transition_model_init", "without_transition_model_init", "no_transition_model"] +) +def test_regulariser(transition_model, model_flag): particles = ParticleState(state_vector=None, particle_list=[Particle(np.array([[10], [10]]), 1 / 9), Particle(np.array([[10], [20]]), @@ -36,25 +52,38 @@ def test_regulariser(): 1 / 9), ]) timestamp = datetime.datetime.now() - new_state_vector = transition_model.function(particles, - noise=True, - time_interval=datetime.timedelta(seconds=1)) + if transition_model is not None: + new_state_vector = transition_model.function(particles, + noise=True, + time_interval=datetime.timedelta(seconds=1)) + else: + new_state_vector = particles.state_vector + prediction = ParticleStatePrediction(new_state_vector, timestamp=timestamp, transition_model=transition_model) - meas_pred = ParticleMeasurementPrediction(prediction, timestamp=timestamp) + measurement_model = LinearGaussian(ndim_state=2, mapping=(0, 1), noise_covar=np.eye(2)) - measurement = [Detection(state_vector=np.array([[5], [7]]), - timestamp=timestamp, measurement_model=measurement_model)] - state_update = ParticleStateUpdate(None, SingleHypothesis(prediction=prediction, - measurement=measurement, - measurement_prediction=meas_pred), - particle_list=particles.particle_list, - timestamp=timestamp+datetime.timedelta(seconds=1)) - regulariser = MCMCRegulariser(transition_model=transition_model) + measurement = Detection(state_vector=np.array([[5], [7]]), + timestamp=timestamp, measurement_model=measurement_model) + hypothesis = SingleHypothesis(prediction=prediction, + measurement=measurement, + measurement_prediction=None) + + state_update = Update.from_state(state=prediction, + hypothesis=hypothesis, + timestamp=timestamp+datetime.timedelta(seconds=1)) + # A PredictedParticleState is used here as the point at which the regulariser is implemented + # in the updater is before the updated state has taken the updated state type. + state_update.weight = np.array([1/6, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48]) + + if model_flag: + regulariser = MCMCRegulariser() + else: + regulariser = MCMCRegulariser(transition_model=transition_model) # state check - new_particles = regulariser.regularise(prediction, state_update, measurement) + new_particles = regulariser.regularise(prediction, state_update) # Check the shape of the new state vector assert new_particles.state_vector.shape == state_update.state_vector.shape # Check weights are unchanged @@ -65,13 +94,11 @@ def test_regulariser(): # list check3 with pytest.raises(TypeError) as e: new_particles = regulariser.regularise(particles.particle_list, - state_update, - measurement) + state_update) assert "Only ParticleState type is supported!" in str(e.value) with pytest.raises(Exception) as e: new_particles = regulariser.regularise(particles, - state_update.particle_list, - measurement) + state_update.particle_list) assert "Only ParticleState type is supported!" in str(e.value) @@ -103,9 +130,9 @@ def test_no_measurement(): measurement=None, measurement_prediction=meas_pred), particle_list=particles.particle_list, timestamp=timestamp) - regulariser = MCMCRegulariser(transition_model=None) + regulariser = MCMCRegulariser() - new_particles = regulariser.regularise(particles, state_update, detections=None) + new_particles = regulariser.regularise(particles, state_update) # Check the shape of the new state vector assert new_particles.state_vector.shape == state_update.state_vector.shape diff --git a/stonesoup/types/multihypothesis.py b/stonesoup/types/multihypothesis.py index bf4e8589e..c862d9916 100644 --- a/stonesoup/types/multihypothesis.py +++ b/stonesoup/types/multihypothesis.py @@ -1,5 +1,5 @@ -from collections.abc import Sized, Iterable, Container -from typing import Sequence +from collections.abc import Sequence +import typing from .detection import MissedDetection from .numeric import Probability @@ -10,13 +10,13 @@ from ..types.prediction import Prediction -class MultipleHypothesis(Type, Sized, Iterable, Container): +class MultipleHypothesis(Type, Sequence): """Multiple Hypothesis base type A Multiple Hypothesis is a container to store a collection of hypotheses. """ - single_hypotheses: Sequence[SingleHypothesis] = Property( + single_hypotheses: typing.Sequence[SingleHypothesis] = Property( default=None, doc="The initial list of :class:`~.SingleHypothesis`. Default `None` " "which initialises with empty list.") @@ -119,7 +119,7 @@ def get_missed_detection_probability(self): return None -class MultipleCompositeHypothesis(Type, Sized, Iterable, Container): +class MultipleCompositeHypothesis(Type, Sequence): """Multiple composite hypothesis type A Multiple Composite Hypothesis is a container to store a collection of composite hypotheses. @@ -128,7 +128,7 @@ class MultipleCompositeHypothesis(Type, Sized, Iterable, Container): redefined. """ - single_hypotheses: Sequence[CompositeHypothesis] = Property( + single_hypotheses: typing.Sequence[CompositeHypothesis] = Property( default=None, doc="The initial list of :class:`~.CompositeHypothesis`. Default `None` which initialises " "with empty list.") diff --git a/stonesoup/updater/particle.py b/stonesoup/updater/particle.py index 5728da2b4..545b712a4 100644 --- a/stonesoup/updater/particle.py +++ b/stonesoup/updater/particle.py @@ -46,7 +46,12 @@ def update(self, hypothesis, **kwargs): : :class:`~.ParticleState` The state posterior """ - predicted_state = copy.copy(hypothesis.prediction) + + predicted_state = Update.from_state( + state=hypothesis.prediction, + hypothesis=hypothesis, + timestamp=hypothesis.prediction.timestamp + ) if hypothesis.measurement.measurement_model is None: measurement_model = self.measurement_model @@ -66,17 +71,11 @@ def update(self, hypothesis, **kwargs): predicted_state = self.resampler.resample(predicted_state) if self.regulariser is not None: - predicted_state = self.regulariser.regularise(predicted_state.parent, - predicted_state, - [hypothesis.measurement]) + prior = hypothesis.prediction.parent + predicted_state = self.regulariser.regularise(prior, + predicted_state) - return Update.from_state( - state=hypothesis.prediction, - state_vector=predicted_state.state_vector, - log_weight=predicted_state.log_weight, - hypothesis=hypothesis, - timestamp=hypothesis.measurement.timestamp, - ) + return predicted_state @lru_cache() def predict_measurement(self, state_prediction, measurement_model=None, @@ -419,8 +418,12 @@ def update(self, hypotheses, **kwargs): # copy prediction prediction = hypotheses.single_hypotheses[0].prediction - updated_state = copy.copy(prediction) - + # updated_state = copy.copy(prediction) + updated_state = Update.from_state( + state=prediction, + hypothesis=hypotheses, + timestamp=prediction.timestamp + ) if any(hypotheses): detections = [single_hypothesis.measurement for single_hypothesis in hypotheses.single_hypotheses] @@ -468,16 +471,10 @@ def update(self, hypotheses, **kwargs): if any(hypotheses): # Regularisation if self.regulariser is not None: - regularised_parts = self.regulariser.regularise(updated_state.parent, - updated_state, - detections) - updated_state.state_vector = regularised_parts.state_vector - - return Update.from_state( - updated_state, - timestamp=updated_state.timestamp, - hypothesis=hypotheses, - ) + updated_state = self.regulariser.regularise(updated_state.parent, + updated_state) + + return updated_state @staticmethod def _log_space_product(A, B): diff --git a/stonesoup/updater/tests/test_particle.py b/stonesoup/updater/tests/test_particle.py index 3fd7d763b..1ed009eba 100644 --- a/stonesoup/updater/tests/test_particle.py +++ b/stonesoup/updater/tests/test_particle.py @@ -173,14 +173,28 @@ def test_bernoulli_particle(): assert update.existence_probability is not None -def test_regularised_particle(): +@pytest.mark.parametrize("transition_model, model_flag", [ + ( + CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model + False # model_flag + ), + ( + CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model + True # model_flag + ) + ], ids=["with_transition_model_init", "without_transition_model_init"] +) +def test_regularised_particle(transition_model, model_flag): - transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]) measurement_model = LinearGaussian( ndim_state=2, mapping=[0], noise_covar=np.array([[10]])) - updater = ParticleUpdater(regulariser=MCMCRegulariser(transition_model=transition_model), - measurement_model=measurement_model) + if model_flag: + updater = ParticleUpdater(regulariser=MCMCRegulariser(), + measurement_model=measurement_model) + else: + updater = ParticleUpdater(regulariser=MCMCRegulariser(transition_model=transition_model), + measurement_model=measurement_model) # Measurement model timestamp = datetime.datetime.now() particles = [Particle([[10], [10]], 1 / 9), @@ -198,11 +212,17 @@ def test_regularised_particle(): predicted_state = transition_model.function(particles, noise=True, time_interval=datetime.timedelta(seconds=1)) - prediction = ParticleStatePrediction(predicted_state, - weight=np.array([1/9]*9), - timestamp=timestamp, - transition_model=transition_model, - parent=particles) + if not model_flag: + prediction = ParticleStatePrediction(predicted_state, + weight=np.array([1/9]*9), + timestamp=timestamp, + parent=particles) + else: + prediction = ParticleStatePrediction(predicted_state, + weight=np.array([1 / 9] * 9), + timestamp=timestamp, + transition_model=transition_model, + parent=particles) measurement = Detection([[40.0]], timestamp=timestamp, measurement_model=measurement_model) eval_measurement_prediction = ParticleMeasurementPrediction(