
Refactor Prediction #60

Merged: 9 commits from feature/refactor_prediction into dev on Oct 13, 2021
Conversation

@teubert (Collaborator) commented Oct 5, 2021

Refactor Prediction to support non-sample-based predictors.

Refactor Prediction (an unreleased feature, so breaking changes are acceptable) to support non-sample-based predictors such as the Unscented Kalman Predictor in PR #40. In that case the predicted states are of type MultivariateNormalDistribution, not unweighted samples.

The new prediction class includes a "mean" function. This is most useful for MultivariateNormalPrediction, but it is useful for the sample-based prediction as well.
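For illustration, a rough sketch of the two cases this mean could cover; the names and structure here are assumptions, not the actual prog_algs implementation:

```python
# Hypothetical sketch, not the actual prog_algs code.
import numpy as np

def prediction_mean(snapshot):
    """Mean state at one predicted time step."""
    if hasattr(snapshot, 'mean'):
        # e.g. a multivariate-normal-style object that already carries its mean
        return snapshot.mean
    # Otherwise treat it as unweighted samples: a list of {state_name: value} dicts,
    # skipping samples that have already reached EOL (None entries)
    return {key: float(np.mean([s[key] for s in snapshot if s is not None]))
            for key in snapshot[0]}

# Example with unweighted samples
print(prediction_mean([{'SOC': 0.51}, {'SOC': 0.49}, None]))  # {'SOC': 0.5}
```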

This change also fixes a bug in UnweightedSamplesPrediction where getting a snapshot at a later timestep, after some samples have reached EOL, raised an exception. The implementation now fills the missing entries with None, so the result looks like the table below (a small sketch of the behavior follows the table):

|         | t1  | t2   | t3   | t4   |
|---------|-----|------|------|------|
| sample1 | 1.2 | 1.3  | 1.4  | 1.5  |
| sample2 | 1.4 | 1.6  | None | None |
| sample3 | 1.3 | 1.4  | 1.6  | None |
| sample4 | 1.1 | 1.15 | 1.3  | 1.45 |
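A minimal sketch of the padding behavior, using made-up helper names rather than the actual UnweightedSamplesPrediction implementation:

```python
# Illustrative sketch; names are assumptions, not the real prog_algs API.
def snapshot(data, time_index):
    """States of every sample at one time step, padding early-EOL samples with None."""
    return [states[time_index] if time_index < len(states) else None
            for states in data]

# data[sample] holds one sample's trajectory; sample2 and sample3 reach EOL early
data = [
    [1.2, 1.3, 1.4, 1.5],    # sample1
    [1.4, 1.6],              # sample2
    [1.3, 1.4, 1.6],         # sample3
    [1.1, 1.15, 1.3, 1.45],  # sample4
]
print(snapshot(data, 3))  # [1.5, None, None, 1.45]  (the t4 column above)
```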

With this change I also updated the predict step to return a single list of times (times[time_index]), rather than a list of lists of times of format times[sample][time_index]. NOTE: THIS IS A BREAKING CHANGE THAT WILL REQUIRE USERS TO UPDATE THEIR CODE. But the previous format makes no sense for a non-sample-based approach.
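For users updating their code, the change looks roughly like this (the values below are made up for illustration; only the indexing change comes from this PR):

```python
# Hypothetical illustration of the times indexing change.
times = [0.0, 0.1, 0.2, 0.3]          # new format: times[time_index]
# old format was one list per sample, e.g.
# times_old = [[0.0, 0.1, 0.2, 0.3], [0.0, 0.1, 0.2, 0.3], ...]

label = "t={}".format(int(times[2]))  # previously int(times[0][2])
print(label)                          # t=0
```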

@teubert added the enhancement (New feature or request) label on Oct 5, 2021
```diff
@@ -88,7 +88,7 @@ def mean_square_error(values, ground_truth):
     Returns:
         float: mean square error of eol predictions
     """
-    return sum([(x.mean() - ground_truth)**2 for x in values])/len(values)
+    return sum([(sum(x)/len(x) - ground_truth)**2 for x in values])/len(values)
```
@teubert (Collaborator, Author) commented:

This change supports the case where x is not UnweightedSamples, but is instead a plain list (some users keep doing that).

A contributor commented:

Good change, Chris.

A contributor commented:

Python has a builtin mean function: https://docs.python.org/3/library/statistics.html#statistics.mean. I've never used it, but I expect it would work here.

@teubert (Collaborator, Author) replied:

I updated it to use numpy.mean so it'll be compatible with the upcoming vectorization changes.
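A sketch of what the numpy-based version would presumably look like (not the merged diff itself):

```python
# Hedged sketch of the numpy.mean variant discussed above.
import numpy as np

def mean_square_error(values, ground_truth):
    """Mean square error of EOL predictions.

    np.mean accepts plain lists and array-likes, which covers the
    "x is a list" case this comment thread mentions.
    """
    return sum((np.mean(x) - ground_truth) ** 2 for x in values) / len(values)

# Example with made-up EOL samples
print(mean_square_error([[3.0, 3.5], [2.5, 3.5]], 3.0))  # 0.03125
```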

```diff
@@ -40,7 +40,7 @@ def future_loading(t, x = None):
     if isinstance(filt, state_estimators.UnscentedKalmanFilter):
         samples = filt.x.sample(20)
     else: # Particle Filter
-        samples = filt.x.raw_samples()
+        samples = filt.x
```
@teubert (Collaborator, Author) commented:
With the refactor, raw_samples is now a deprecated call, so I'm updating the example to stop using it.

```diff
-fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0][0])))
-states.snapshot(10).plot_scatter(fig = fig, label = "t={}".format(int(times[0][10])))
-states.snapshot(50).plot_scatter(fig = fig, label = "t={}".format(int(times[0][50])))
+fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0])))
```
@teubert (Collaborator, Author) commented:
The times returned from prediction are now indexed as times[time_index] instead of times[sample_id][time_index]. This is a breaking change that requires users to update how they use it (as I do here).

@teubert added this to the v1.1 milestone on Oct 5, 2021
@teubert (Collaborator, Author) commented Oct 5, 2021:

Because of other project deadlines the review for this will be delayed by 1 week.

@teubert force-pushed the feature/refactor_prediction branch from a6e4fe5 to af5325c on October 5, 2021 16:54

Three additional review threads on src/prog_algs/predictors/prediction.py were marked outdated and resolved.
teubert and others added 5 commits on October 13, 2021 12:39; several are co-authored by Jason Watkins <jason.watkins@nasa.gov>.
@teubert merged commit 570359e into dev on Oct 13, 2021
@teubert deleted the feature/refactor_prediction branch on October 13, 2021 20:07