Skip to content

Commit

Permalink
Merge pull request #1012 from dstl/predict_meas_noise
Browse files Browse the repository at this point in the history
Allow generation of predicted measurement without measurement noise
  • Loading branch information
sdhiscocks committed May 13, 2024
2 parents ab6af2b + 32d1af0 commit 48865c8
Show file tree
Hide file tree
Showing 18 changed files with 182 additions and 58 deletions.
4 changes: 2 additions & 2 deletions stonesoup/hypothesiser/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class HMMHypothesiser(Hypothesiser):
doc="Gate Probability - prob. gate contains true "
"measurement if detected")

def hypothesise(self, track, detections, timestamp):
def hypothesise(self, track, detections, timestamp, **kwargs):
""" Evaluate and return all track association hypotheses.
For a given track and a set of N available detections, return a MultipleHypothesis object
Expand Down Expand Up @@ -70,7 +70,7 @@ def hypothesise(self, track, detections, timestamp):
measurement_prediction = self.updater.predict_measurement(
predicted_state=prediction,
measurement_model=detection.measurement_model,
noise=False
measurement_noise=False
)

probability = self.measure(measurement_prediction, detection)
Expand Down
11 changes: 10 additions & 1 deletion stonesoup/updater/alphabeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,28 @@ class AlphaBetaUpdater(Updater):
"the position elements in the state vector.")

@lru_cache()
def predict_measurement(self, prediction, measurement_model=None, **kwargs):
def predict_measurement(self, prediction, measurement_model=None, measurement_noise=False,
**kwargs):
"""Return the predicted measurement
Parameters
----------
prediction : :class:`~.StatePrediction`
The state prediction
measurement_model : :class:`~.MeasurementModel`
The measurement model. If omitted, the model in the updater object
is used
measurement_noise : bool
Whether to include measurement noise, in this case on `False` is valid.
Default `False`
Returns
-------
: :class:`~.StateVector`
The predicted measurement
"""
if measurement_noise:
raise ValueError("measurement noise must be False")
# This necessary if predict_measurement called on its own
measurement_model = self._check_measurement_model(measurement_model)
pred_meas = measurement_model.matrix(**kwargs) @ prediction.state_vector
Expand Down
8 changes: 6 additions & 2 deletions stonesoup/updater/asd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ASDKalmanUpdater(KalmanUpdater):
vol. 47, no. 4, pp. 2766-2778, OCTOBER 2011, doi: 10.1109/TAES.2011.6034663.
"""
@lru_cache()
def predict_measurement(self, predicted_state, measurement_model=None,
def predict_measurement(self, predicted_state, measurement_model=None, measurement_noise=True,
**kwargs):
r"""Predict the measurement implied by the predicted state mean
Expand All @@ -33,6 +33,8 @@ def predict_measurement(self, predicted_state, measurement_model=None,
measurement_model : :class:`~.MeasurementModel`
The measurement model. If omitted, the model in the updater
object is used
measurement_noise : bool
Whether to include measurement noise :math:`R` with innovation covariance
**kwargs : various
These are passed to :meth:`~.MeasurementModel.function` and
:meth:`~.MeasurementModel.matrix`
Expand All @@ -53,7 +55,9 @@ def predict_measurement(self, predicted_state, measurement_model=None,
hh = self._measurement_matrix(predicted_state=state_at_t,
measurement_model=measurement_model,
**kwargs)
innov_cov = hh@state_at_t.covar@hh.T + measurement_model.covar()
innov_cov = hh@state_at_t.covar@hh.T
if measurement_noise:
innov_cov += measurement_model.covar()

t2t_plus = slice(t_index * predicted_state.ndim, (t_index+1) * predicted_state.ndim)
meas_cross_cov = predicted_state.multi_covar[:, t2t_plus] @ hh.T
Expand Down
6 changes: 4 additions & 2 deletions stonesoup/updater/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,21 @@ def _check_measurement_model(self, measurement_model):

@abstractmethod
def predict_measurement(
self, state_prediction, measurement_model=None, **kwargs):
self, predicted_state, measurement_model=None, measurement_noise=True, **kwargs):
"""Get measurement prediction from state prediction
Parameters
----------
state_prediction : :class:`~.StatePrediction`
predicted_state : :class:`~.StatePrediction`
The state prediction
measurement_model: :class:`~.MeasurementModel`, optional
The measurement model used to generate the measurement prediction.
Should be used in cases where the measurement model is dependent
on the received measurement. The default is `None`, in which case
the updater will use the measurement model specified on
initialisation
measurement_noise : bool
Whether to include measurement noise predicted measurement. Default `True`
Returns
-------
Expand Down
9 changes: 5 additions & 4 deletions stonesoup/updater/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def _check_measurement_model(self, measurement_model):

return measurement_model

def predict_measurement(self, predicted_state, measurement_model, **kwargs):
def predict_measurement(self, predicted_state, measurement_model=None, measurement_noise=False,
**kwargs):
r"""Predict the measurement implied by the predicted state.
Parameters
Expand All @@ -110,8 +111,8 @@ def predict_measurement(self, predicted_state, measurement_model, **kwargs):
The predicted state.
measurement_model : :class:`~.MeasurementModel`
The measurement model. If omitted, the model in the updater object is used.
measurement : :class:`~.CategoricalState`.
The measurement.
measurement_noise : bool
Whether to include measurement noise. Default `False`
**kwargs : various
These are passed to :meth:`~.MeasurementModel.function`.
Expand All @@ -123,7 +124,7 @@ def predict_measurement(self, predicted_state, measurement_model, **kwargs):

measurement_model = self._check_measurement_model(measurement_model)

pred_meas = measurement_model.function(predicted_state, **kwargs)
pred_meas = measurement_model.function(predicted_state, noise=measurement_noise, **kwargs)

return MeasurementPrediction.from_state(
predicted_state,
Expand Down
15 changes: 11 additions & 4 deletions stonesoup/updater/chernoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ class ChernoffUpdater(Updater):
doc="A weighting parameter in the range :math:`(0,1]`")

@lru_cache()
def predict_measurement(self, predicted_state, measurement_model=None, **kwargs):
def predict_measurement(self, predicted_state, measurement_model=None, measurement_noise=True,
**kwargs):
r"""
This function predicts the measurement of a state in situations where measurements consist
of a covariance and state vector.
Expand All @@ -93,6 +94,9 @@ def predict_measurement(self, predicted_state, measurement_model=None, **kwargs
measurement_model : :class:`~.MeasurementModel`
The measurement model. If omitted, the updater will use the model that was specified
on initialization.
measurement_noise : bool
Whether to include measurement noise. Default `True`. Where `False` the
predicted state covariance is used directly without omega factor.
Returns
-------
Expand All @@ -102,9 +106,12 @@ def predict_measurement(self, predicted_state, measurement_model=None, **kwargs

measurement_model = self._check_measurement_model(measurement_model)

# The innovation covariance uses the noise covariance from the measurement model
state_covar_m = measurement_model.noise_covar
innov_covar = 1/(1-self.omega)*state_covar_m + 1/self.omega*predicted_state.covar
if measurement_noise:
# The innovation covariance uses the noise covariance from the measurement model
state_covar_m = measurement_model.noise_covar
innov_covar = 1/(1-self.omega)*state_covar_m + 1/self.omega*predicted_state.covar
else:
innov_covar = predicted_state.covar

# The predicted measurement and measurement cross covariance can be taken from
# the predicted state
Expand Down
11 changes: 6 additions & 5 deletions stonesoup/updater/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _check_measurement_prediction(self, hypothesis, **kwargs):
return hypothesis

@lru_cache()
def predict_measurement(self, predicted_state, measurement_model=None,
def predict_measurement(self, predicted_state, measurement_model=None, measurement_noise=True,
**kwargs):
r"""Predict the measurement implied by the predicted state mean
Expand All @@ -119,7 +119,8 @@ def predict_measurement(self, predicted_state, measurement_model=None,
measurement_model : :class:`~.MeasurementModel`
The measurement model. If omitted, the model in the updater object
is used
measurement_noise : bool
Whether to include measurement noise :math:`R` when generating ensemble. Default `True`
Returns
-------
Expand All @@ -132,10 +133,10 @@ def predict_measurement(self, predicted_state, measurement_model=None,
measurement_model = self._check_measurement_model(measurement_model)

# Propagate each vector through the measurement model.
pred_meas_ensemble = measurement_model.function(predicted_state, noise=True)
pred_meas_ensemble = measurement_model.function(
predicted_state, noise=measurement_noise, **kwargs)

return MeasurementPrediction.from_state(
predicted_state, pred_meas_ensemble)
return MeasurementPrediction.from_state(predicted_state, state_vector=pred_meas_ensemble)

def update(self, hypothesis, **kwargs):
r"""The Ensemble Kalman update method. The Ensemble Kalman filter
Expand Down
12 changes: 9 additions & 3 deletions stonesoup/updater/information.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,23 @@ def _inverse_measurement_covar(self, measurement_model, **kwargs):
return inv_measurement_covar

@lru_cache()
def predict_measurement(self, predicted_state, measurement_model=None, **kwargs):
def predict_measurement(self, predicted_state, measurement_model=None, measurement_noise=True,
**kwargs):
r"""There's no direct analogue of a predicted measurement in the information form. This
method is therefore provided to return the predicted measurement as would the standard
Kalman updater. This is mainly for compatibility as it's not anticipated that it would
be used in the usual operation of the information filter.
Parameters
----------
predicted_information_state : :class:`~.State`
predicted_state : :class:`~.State`
The predicted state in information form :math:`\mathbf{y}_{k|k-1}`
measurement_model : :class:`~.MeasurementModel`
The measurement model. If omitted, the model in the updater object
is used
measurement_noise : bool
Whether to include measurement noise :math:`R` with innovation covariance.
Default `True`
**kwargs : various
These are passed to :meth:`~.MeasurementModel.matrix()`
Expand All @@ -96,7 +100,9 @@ def predict_measurement(self, predicted_state, measurement_model=None, **kwargs)
predicted_state_mean = predicted_covariance @ predicted_state.state_vector

predicted_measurement = hh @ predicted_state_mean
innovation_covariance = hh @ predicted_covariance @ hh.T + measurement_model.covar()
innovation_covariance = hh @ predicted_covariance @ hh.T
if measurement_noise:
innovation_covariance += measurement_model.covar(**kwargs)

return GaussianMeasurementPrediction(predicted_measurement, innovation_covariance,
predicted_state.timestamp,
Expand Down
56 changes: 36 additions & 20 deletions stonesoup/updater/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _measurement_cross_covariance(self, predicted_state, measurement_matrix):
"""
return predicted_state.covar @ measurement_matrix.T

def _innovation_covariance(self, m_cross_cov, meas_mat, meas_mod, **kwargs):
def _innovation_covariance(self, m_cross_cov, meas_mat, meas_mod, measurement_noise, **kwargs):
"""Compute the innovation covariance
Parameters
Expand All @@ -126,14 +126,19 @@ def _innovation_covariance(self, m_cross_cov, meas_mat, meas_mod, **kwargs):
Measurement matrix
meas_mod : :class:~.MeasurementModel`
Measurement model
measurement_noise : bool
Include measurement noise or not
Returns
-------
: numpy.ndarray
The innovation covariance
"""
return meas_mat @ m_cross_cov + meas_mod.covar()
innov_covar = meas_mat @ m_cross_cov
if measurement_noise:
innov_covar += meas_mod.covar(**kwargs)
return innov_covar

def _posterior_mean(self, predicted_state, kalman_gain, measurement, measurement_prediction):
r"""Compute the posterior mean, :math:`\mathbf{x}_{k|k} = \mathbf{x}_{k|k-1} + K_k
Expand Down Expand Up @@ -189,7 +194,7 @@ def _posterior_covariance(self, hypothesis):
return post_cov.view(CovarianceMatrix), kalman_gain

@lru_cache()
def predict_measurement(self, predicted_state, measurement_model=None,
def predict_measurement(self, predicted_state, measurement_model=None, measurement_noise=True,
**kwargs):
r"""Predict the measurement implied by the predicted state mean
Expand All @@ -200,6 +205,9 @@ def predict_measurement(self, predicted_state, measurement_model=None,
measurement_model : :class:`~.MeasurementModel`
The measurement model. If omitted, the model in the updater object
is used
measurement_noise : bool
Whether to include measurement noise :math:`R` with innovation covariance.
Default `True`
**kwargs : various
These are passed to :meth:`~.MeasurementModel.function` and
:meth:`~.MeasurementModel.matrix`
Expand All @@ -222,7 +230,8 @@ def predict_measurement(self, predicted_state, measurement_model=None,

# The measurement cross covariance and innovation covariance
meas_cross_cov = self._measurement_cross_covariance(predicted_state, hh)
innov_cov = self._innovation_covariance(meas_cross_cov, hh, measurement_model, **kwargs)
innov_cov = self._innovation_covariance(
meas_cross_cov, hh, measurement_model, measurement_noise, **kwargs)

return MeasurementPrediction.from_state(
predicted_state, pred_meas, innov_cov, cross_covar=meas_cross_cov)
Expand Down Expand Up @@ -331,8 +340,7 @@ def _measurement_matrix(self, predicted_state, measurement_model=None,
else:
if linearisation_point is None:
linearisation_point = predicted_state
return measurement_model.jacobian(linearisation_point,
**kwargs)
return measurement_model.jacobian(linearisation_point, **kwargs)


class UnscentedKalmanUpdater(KalmanUpdater):
Expand Down Expand Up @@ -365,7 +373,8 @@ class UnscentedKalmanUpdater(KalmanUpdater):
"3-Ns")

@lru_cache()
def predict_measurement(self, predicted_state, measurement_model=None):
def predict_measurement(self, predicted_state, measurement_model=None, measurement_noise=True,
**kwargs):
"""Unscented Kalman Filter measurement prediction step. Uses the
unscented transform to estimate a Gauss-distributed predicted
measurement.
Expand All @@ -380,6 +389,8 @@ def predict_measurement(self, predicted_state, measurement_model=None):
dependent on the received measurement (the default is `None`, in
which case the updater will use the measurement model specified on
initialisation)
measurement_noise : bool
Whether to include measurement noise :math:`R` with innovation covariance
Returns
-------
Expand All @@ -394,10 +405,10 @@ def predict_measurement(self, predicted_state, measurement_model=None):
gauss2sigma(predicted_state,
self.alpha, self.beta, self.kappa)

meas_pred_mean, meas_pred_covar, cross_covar, _, _, _ = \
covar_noise = measurement_model.covar(**kwargs) if measurement_noise else None
meas_pred_mean, meas_pred_covar, cross_covar, *_ = \
unscented_transform(sigma_points, mean_weights, covar_weights,
measurement_model.function,
covar_noise=measurement_model.covar())
measurement_model.function, covar_noise=covar_noise)

return MeasurementPrediction.from_state(
predicted_state, meas_pred_mean, meas_pred_covar, cross_covar=cross_covar)
Expand Down Expand Up @@ -451,33 +462,38 @@ def _measurement_cross_covariance(self, predicted_state, measurement_matrix):
"""
return predicted_state.sqrt_covar.T @ measurement_matrix.T

def _innovation_covariance(self, m_cross_cov, meas_mat, meas_mod):
def _innovation_covariance(self, m_cross_cov, meas_mat, meas_mod, measurement_noise, **kwargs):
"""Compute the innovation covariance
Parameters
----------
m_cross_cov : numpy.array
m_cross_cov : numpy.ndarray
The measurement cross covariance matrix
meas_mat : numpy.array
meas_mat : numpy.ndarray
The measurement matrix. Not required in this instance. Ignored.
meas_mod : :class:`~.MeasurementModel`
Measurement model. The class attribute :attr:`sqrt_covar` indicates whether this is
passed in square root form. If it doesn't exist then :attr:`covar` is assumed to exist
and is used instead.
measurement_noise : bool
Include measurement noise or not
Returns
-------
: numpy.ndarray
The innovation covariance
"""
# If the measurement covariance matrix is square root then square it
try:
meas_cov = meas_mod.sqrt_covar @ meas_mod.sqrt_covar.T
except AttributeError:
meas_cov = meas_mod.covar()
innov_covar = m_cross_cov.T @ m_cross_cov
if measurement_noise:
# If the measurement covariance matrix is square root then square it
try:
meas_cov = meas_mod.sqrt_covar @ meas_mod.sqrt_covar.T
except AttributeError:
meas_cov = meas_mod.covar(**kwargs)
innov_covar += meas_cov

return m_cross_cov.T @ m_cross_cov + meas_cov
return innov_covar

def _posterior_covariance(self, hypothesis):
"""
Expand Down Expand Up @@ -645,7 +661,7 @@ def update(self, hypothesis, **kwargs):
cross_cov = self._measurement_cross_covariance(hypothesis.prediction, hh)
post_state.hypothesis.measurement_prediction.cross_covar = cross_cov
post_state.hypothesis.measurement_prediction.covar = \
self._innovation_covariance(cross_cov, hh, measurement_model)
self._innovation_covariance(cross_cov, hh, measurement_model, True)

prev_state = post_state
post_state = super().update(post_state.hypothesis, **kwargs)
Expand Down
Loading

0 comments on commit 48865c8

Please sign in to comment.