-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor Prediction #60
Conversation
src/prog_algs/metrics/samples.py
Outdated
@@ -88,7 +88,7 @@ def mean_square_error(values, ground_truth): | |||
Returns: | |||
float: mean square error of eol predictions | |||
""" | |||
return sum([(x.mean() - ground_truth)**2 for x in values])/len(values) | |||
return sum([(sum(x)/len(x) - ground_truth)**2 for x in values])/len(values) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change supports the case where x is not unweighted_samples, but instead is a list (some users keep doing that).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good change Chris.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Python has a builtin mean
function: https://docs.python.org/3/library/statistics.html#statistics.mean. I've never used it, but I expect it would work here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated to use numpy::mean so it'll be compatible with the upcoming vectorization changes.
@@ -40,7 +40,7 @@ def future_loading(t, x = None): | |||
if isinstance(filt, state_estimators.UnscentedKalmanFilter): | |||
samples = filt.x.sample(20) | |||
else: # Particle Filter | |||
samples = filt.x.raw_samples() | |||
samples = filt.x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the refactor I make raw_samples an depreciated call, so I'm updating the example to not use it
fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0][0]))) | ||
states.snapshot(10).plot_scatter(fig = fig, label = "t={}".format(int(times[0][10]))) | ||
states.snapshot(50).plot_scatter(fig = fig, label = "t={}".format(int(times[0][50]))) | ||
fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0]))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
times returned from prediction is now times[time_index] instead of times[sample_id][time_index]. This is a breaking change that requires users to update how they're using it (like I do here).
Because of other project deadlines the review for this will be delayed by 1 week. |
a6e4fe5
to
af5325c
Compare
src/prog_algs/metrics/samples.py
Outdated
@@ -88,7 +88,7 @@ def mean_square_error(values, ground_truth): | |||
Returns: | |||
float: mean square error of eol predictions | |||
""" | |||
return sum([(x.mean() - ground_truth)**2 for x in values])/len(values) | |||
return sum([(sum(x)/len(x) - ground_truth)**2 for x in values])/len(values) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Python has a builtin mean
function: https://docs.python.org/3/library/statistics.html#statistics.mean. I've never used it, but I expect it would work here.
Co-authored-by: Jason Watkins <jason.watkins@nasa.gov>
Co-authored-by: Jason Watkins <jason.watkins@nasa.gov>
Refactor Prediction to support non-sample based predictors.
Refactor Prediction (an unreleased feature, so breaking changes are acceptable) to support non-sample based predictor such as the Unscented Kalman Predictor in PR #40. In this case predicted states would be of the type MultivariateNormalDistribution, not Unweighted samples.
New prediction class includes "mean" function. This is most useful for the MultivariateNormalPrediction, but has use for sample-based as well.
This change also fixes a bug in UnweightedSamplesPrediction where getting a snapshot at later timesteps after some samples have reached EOL results in an exception. Implementation fills them with none, so the result will look like this:
With this change I updated the predict step to return a single list of times (times[time_index]), rather than a list of list of times of format times[sample][time_index]. NOTE: THIS IS A BREAKING CHANGE THAT WILL REQUIRE USERS TO MAKE CHANGES. But the previous implementation makes no sense for a non sample-based approach.