diff --git a/src/progpy/state_estimators/particle_filter.py b/src/progpy/state_estimators/particle_filter.py index 680c56c6..965c27f9 100644 --- a/src/progpy/state_estimators/particle_filter.py +++ b/src/progpy/state_estimators/particle_filter.py @@ -67,9 +67,8 @@ def __init__(self, model, x0, **kwargs): # Added to avoid float/int issues self.parameters['num_particles'] = int(self.parameters['num_particles']) sample_gen = x0.sample(self.parameters['num_particles']) - samples = [array(sample_gen.key(k), dtype=float64) for k in x0.keys()] - - self.particles = model.StateContainer(array(samples, dtype=float64)) + samples = {k: array(sample_gen.key(k), dtype=float64) for k in x0.keys()} + self.particles = model.StateContainer(samples) if 'R' in self.parameters: # For backwards compatibility diff --git a/tests/test_state_estimators.py b/tests/test_state_estimators.py index 75e5a903..d57129b8 100644 --- a/tests/test_state_estimators.py +++ b/tests/test_state_estimators.py @@ -7,7 +7,7 @@ sys.path.append(join(dirname(__file__), "..")) from progpy import PrognosticsModel, LinearModel -from progpy.models import ThrownObject, BatteryElectroChem, PneumaticValveBase +from progpy.models import ThrownObject, BatteryElectroChem, PneumaticValveBase, BatteryElectroChemEOD from progpy.state_estimators import ParticleFilter, KalmanFilter, UnscentedKalmanFilter from progpy.uncertain_data import ScalarData, MultivariateNormalDist, UnweightedSamples @@ -561,6 +561,20 @@ def future_loading(t, x=None): for t, u, z in zip(times, inputs.data, outputs.data): kf.estimate(t, u, z) + def test_PF_particle_ordering(self): + """ + This is testing for a bug found by @mstraut where particle filter was mixing up the keys if users: + 1. Do not call m.initialize(), and instead + 2. provide a state as a dictionary instead of a state container, and + 3. order the states in a different order than m.states + """ + m = BatteryElectroChemEOD() + x0 = m.parameters['x0'] # state as a dictionary with the wrong order + filt = ParticleFilter(m, x0, num_particles=2) + for key in m.states: + self.assertEqual(filt.particles[key][0], x0[key]) + self.assertEqual(filt.particles[key][1], x0[key]) + # This allows the module to be executed directly def main(): # This ensures that the directory containing StateEstimatorTemplate is in the python search directory