From 0f61c2bd1c8113972e7754fa1f4bfcf7f27f6908 Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Thu, 17 Aug 2023 10:43:42 -0700 Subject: [PATCH 1/2] Port over changes --- .github/workflows/python-package.yml | 18 +++++++ examples/horizon.py | 37 +++++--------- src/progpy/predictors/monte_carlo.py | 6 +++ tests/__main__.py | 6 +++ tests/test_horizon.py | 75 ++++++++++++++++++++++++++++ 5 files changed, 117 insertions(+), 25 deletions(-) create mode 100644 tests/test_horizon.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 18ed3e53..fc9d98bd 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -228,6 +228,24 @@ jobs: run: pip install --upgrade --upgrade-strategy eager -e . - name: Run tests run: python -m tests.test_examples + test_horizon: + timeout-minutes: 5 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.7' + - name: Install dependencies cache + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: pip-cache + - name: Update + run: pip install --upgrade --upgrade-strategy eager -e . + - name: Run tests + run: python -m tests.test_horizon test_linear_model: timeout-minutes: 5 runs-on: ubuntu-latest diff --git a/examples/horizon.py b/examples/horizon.py index ec081147..b4f46139 100644 --- a/examples/horizon.py +++ b/examples/horizon.py @@ -12,48 +12,35 @@ ii) Time event is predicted to occur (with uncertainty) """ +import numpy as np from progpy.models.thrown_object import ThrownObject -from progpy import * +from progpy.predictors import MonteCarlo +from progpy.uncertain_data import MultivariateNormalDist from pprint import pprint def run_example(): - # Step 1: Setup model & future loading + # Step 1: Setup model, future loading, and state def future_loading(t, x = None): return {} - m = ThrownObject(process_noise = 0.25, measurement_noise = 0.2) + m = ThrownObject(process_noise = 0.5, measurement_noise = 0.15) initial_state = m.initialize() - # Step 2: Demonstrating state estimator - print("\nPerforming State Estimation Step...") - - # Step 2a: Setup NUM_SAMPLES = 1000 - filt = state_estimators.ParticleFilter(m, initial_state, num_particles = NUM_SAMPLES) - # VVV Uncomment this to use UKF State Estimator VVV - # filt = state_estimators.UnscentedKalmanFilter(batt, initial_state) - - # Step 2b: One step of state estimator - u = m.InputContainer({}) # No input for ThrownObject - filt.estimate(0.1, u, m.output(initial_state)) - - # Note: in a prognostic application the above state estimation - # step would be repeated each time there is new data. - # Here we're doing one step to demonstrate how the state estimator is used + x = MultivariateNormalDist(initial_state.keys(), initial_state.values(), np.diag([x_i*0.01 for x_i in initial_state.values()])) - # Step 3: Demonstrating Predictor + # Step 2: Demonstrating Predictor print("\nPerforming Prediction Step...") - # Step 3a: Setup Predictor - mc = predictors.MonteCarlo(m) + # Step 2a: Setup Predictor + mc = MonteCarlo(m) - # Step 3b: Perform a prediction + # Step 2b: Perform a prediction # THIS IS WHERE WE DIVERGE FROM THE THROWN_OBJECT_EXAMPLE # Here we set a prediction horizon # We're saying we are not interested in any events that occur after this time - PREDICTION_HORIZON = 7.75 - samples = filt.x # Since we're using a particle filter, which is also sample-based, we can directly use the samples, without changes + PREDICTION_HORIZON = 7.7 STEP_SIZE = 0.01 - mc_results = mc.predict(samples, future_loading, dt=STEP_SIZE, horizon = PREDICTION_HORIZON) + mc_results = mc.predict(x, future_loading, n_samples=NUM_SAMPLES,dt=STEP_SIZE, horizon = PREDICTION_HORIZON) print("\nPredicted Time of Event:") metrics = mc_results.time_of_event.metrics() pprint(metrics) # Note this takes some time diff --git a/src/progpy/predictors/monte_carlo.py b/src/progpy/predictors/monte_carlo.py index a1bdc5e4..af7a9f84 100644 --- a/src/progpy/predictors/monte_carlo.py +++ b/src/progpy/predictors/monte_carlo.py @@ -74,6 +74,7 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs) # Perform prediction t0 = params.get('t0', 0) + HORIZON = params.get('horizon', float('inf')) # Save the horizon to be used later for x in state: first_output = self.model.output(x) @@ -82,6 +83,7 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs) params['t0'] = t0 params['x'] = x + params['horizon'] = HORIZON # reset to initial horizon if 'save_freq' in params and not isinstance(params['save_freq'], tuple): params['save_freq'] = (params['t0'], params['save_freq']) @@ -103,6 +105,10 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs) # Non-vectorized prediction while len(events_remaining) > 0: # Still events to predict + # Since horizon is relative to t0 (the simulation starting point), + # we must subtract the difference in current t0 from the initial (i.e., prediction t0) + # each subsequent simulation + params['horizon'] = HORIZON - (params['t0'] - t0) (t, u, xi, z, es) = simulate_to_threshold(future_loading_eqn, first_output, threshold_keys = events_remaining, diff --git a/tests/__main__.py b/tests/__main__.py index 5c07a7ad..b1faad58 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -11,6 +11,7 @@ from tests.test_ensemble import main as ensemble_main from tests.test_estimate_params import main as estimate_params_main from tests.test_examples import main as examples_main +from tests.test_horizon import main as horizon_main from tests.test_linear_model import main as linear_main from tests.test_metrics import main as metrics_main from tests.test_pneumatic_valve import main as pneumatic_valve_main @@ -77,6 +78,11 @@ except Exception: was_successful = False + try: + horizon_main() + except Exception: + was_successful = False + try: linear_main() except Exception: diff --git a/tests/test_horizon.py b/tests/test_horizon.py new file mode 100644 index 00000000..c83731e4 --- /dev/null +++ b/tests/test_horizon.py @@ -0,0 +1,75 @@ +from io import StringIO +import sys +import unittest + +from progpy import predictors +from progpy.models import ThrownObject + +class TestHorizon(unittest.TestCase): + def setUp(self): + # set stdout (so it won't print) + sys.stdout = StringIO() + + def tearDown(self): + sys.stdout = sys.__stdout__ + + def test_horizon_ex(self): + # Setup model + m = ThrownObject(process_noise=0.25, measurement_noise=0.2) + # Change parameters (to make simulation faster) + m.parameters['thrower_height'] = 1.0 + m.parameters['throwing_speed'] = 10.0 + initial_state = m.initialize() + + # Define future loading (necessary for prediction call) + def future_loading(t, x=None): + return {} + + # Setup Predictor (smaller sample size for efficiency) + mc = predictors.MonteCarlo(m) + mc.parameters['n_samples'] = 50 + + # Perform a prediction + # With this horizon, all samples will reach 'falling', but only some will reach 'impact' + PREDICTION_HORIZON = 2.127 + STEP_SIZE = 0.001 + mc_results = mc.predict(initial_state, future_loading, dt=STEP_SIZE, horizon = PREDICTION_HORIZON) + + # 'falling' happens before the horizon is met + falling_res = [mc_results.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results.time_of_event[iter]['falling'] is not None] + self.assertEqual(len(falling_res), mc.parameters['n_samples']) + + # 'impact' happens around the horizon, so some samples have reached this event and others haven't + impact_res = [mc_results.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results.time_of_event[iter]['impact'] is not None] + self.assertLess(len(impact_res), mc.parameters['n_samples']) + + # Try again with very low prediction_horizon, where no events are reached + # Note: here we count how many None values there are for each event (in the above and below examples, we count values that are NOT None) + mc_results_no_event = mc.predict(initial_state, future_loading, dt=STEP_SIZE, horizon=0.3) + falling_res_no_event = [mc_results_no_event.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results_no_event.time_of_event[iter]['falling'] is None] + impact_res_no_event = [mc_results_no_event.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results_no_event.time_of_event[iter]['impact'] is None] + self.assertEqual(len(falling_res_no_event), mc.parameters['n_samples']) + self.assertEqual(len(impact_res_no_event), mc.parameters['n_samples']) + + # Finally, try without horizon, all events should be reached for all samples + mc_results_all_event = mc.predict(initial_state, future_loading, dt=STEP_SIZE) + falling_res_all_event = [mc_results_all_event.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results_all_event.time_of_event[iter]['falling'] is not None] + impact_res_all_event = [mc_results_all_event.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results_all_event.time_of_event[iter]['impact'] is not None] + self.assertEqual(len(falling_res_all_event), mc.parameters['n_samples']) + self.assertEqual(len(impact_res_all_event), mc.parameters['n_samples']) + +# This allows the module to be executed directly +def run_tests(): + unittest.main() + +def main(): + load_test = unittest.TestLoader() + runner = unittest.TextTestRunner() + print("\n\nTesting Horizon functionality") + result = runner.run(load_test.loadTestsFromTestCase(TestHorizon)).wasSuccessful() + + if not result: + raise Exception("Failed test") + +if __name__ == '__main__': + main() From 464bae101796b948edafaa4ed9c43ba3c18b14b8 Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Sun, 20 Aug 2023 15:32:21 -0700 Subject: [PATCH 2/2] Fix copyright --- tests/test_horizon.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_horizon.py b/tests/test_horizon.py index c83731e4..c9c226f7 100644 --- a/tests/test_horizon.py +++ b/tests/test_horizon.py @@ -1,3 +1,6 @@ +# Copyright © 2021 United States Government as represented by the Administrator of the +# National Aeronautics and Space Administration. All Rights Reserved. + from io import StringIO import sys import unittest