From 6ace4086a40b4d7d7da8478f910540be8c7a59ef Mon Sep 17 00:00:00 2001 From: Miryam S Date: Thu, 17 Aug 2023 15:48:06 -0700 Subject: [PATCH 1/2] fixing bug: array and keys did not match in particle_filter --- src/progpy/state_estimators/particle_filter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/progpy/state_estimators/particle_filter.py b/src/progpy/state_estimators/particle_filter.py index 680c56c6..965c27f9 100644 --- a/src/progpy/state_estimators/particle_filter.py +++ b/src/progpy/state_estimators/particle_filter.py @@ -67,9 +67,8 @@ def __init__(self, model, x0, **kwargs): # Added to avoid float/int issues self.parameters['num_particles'] = int(self.parameters['num_particles']) sample_gen = x0.sample(self.parameters['num_particles']) - samples = [array(sample_gen.key(k), dtype=float64) for k in x0.keys()] - - self.particles = model.StateContainer(array(samples, dtype=float64)) + samples = {k: array(sample_gen.key(k), dtype=float64) for k in x0.keys()} + self.particles = model.StateContainer(samples) if 'R' in self.parameters: # For backwards compatibility From 9ba241a09643f049d92bbe8da1ffce92b13b69cf Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Thu, 17 Aug 2023 16:33:17 -0700 Subject: [PATCH 2/2] Add new test --- tests/test_state_estimators.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_state_estimators.py b/tests/test_state_estimators.py index 75e5a903..d57129b8 100644 --- a/tests/test_state_estimators.py +++ b/tests/test_state_estimators.py @@ -7,7 +7,7 @@ sys.path.append(join(dirname(__file__), "..")) from progpy import PrognosticsModel, LinearModel -from progpy.models import ThrownObject, BatteryElectroChem, PneumaticValveBase +from progpy.models import ThrownObject, BatteryElectroChem, PneumaticValveBase, BatteryElectroChemEOD from progpy.state_estimators import ParticleFilter, KalmanFilter, UnscentedKalmanFilter from progpy.uncertain_data import ScalarData, MultivariateNormalDist, UnweightedSamples @@ -561,6 +561,20 @@ def future_loading(t, x=None): for t, u, z in zip(times, inputs.data, outputs.data): kf.estimate(t, u, z) + def test_PF_particle_ordering(self): + """ + This is testing for a bug found by @mstraut where particle filter was mixing up the keys if users: + 1. Do not call m.initialize(), and instead + 2. provide a state as a dictionary instead of a state container, and + 3. order the states in a different order than m.states + """ + m = BatteryElectroChemEOD() + x0 = m.parameters['x0'] # state as a dictionary with the wrong order + filt = ParticleFilter(m, x0, num_particles=2) + for key in m.states: + self.assertEqual(filt.particles[key][0], x0[key]) + self.assertEqual(filt.particles[key][1], x0[key]) + # This allows the module to be executed directly def main(): # This ensures that the directory containing StateEstimatorTemplate is in the python search directory