-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor Prediction #60
Changes from 4 commits
1715976
f3f5fff
0f0b975
af5325c
2fdeb2d
27bc560
a923b6a
5888662
f1aa965
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,7 +40,7 @@ def future_loading(t, x = None): | |
if isinstance(filt, state_estimators.UnscentedKalmanFilter): | ||
samples = filt.x.sample(20) | ||
else: # Particle Filter | ||
samples = filt.x.raw_samples() | ||
samples = filt.x | ||
|
||
# Predict with a step size of 0.1 | ||
(times, inputs, states, outputs, event_states, eol) = mc.predict(samples, future_loading, dt=0.1) | ||
|
@@ -63,11 +63,11 @@ def future_loading(t, x = None): | |
print('\tP(Success) if mission ends at 3002.25: ', metrics.prob_success(eol, 3005.25)) | ||
|
||
# Plot state transition | ||
fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0][0]))) | ||
states.snapshot(10).plot_scatter(fig = fig, label = "t={}".format(int(times[0][10]))) | ||
states.snapshot(50).plot_scatter(fig = fig, label = "t={}".format(int(times[0][50]))) | ||
fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0]))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. times returned from prediction is now times[time_index] instead of times[sample_id][time_index]. This is a breaking change that requires users to update how they're using it (like I do here). |
||
states.snapshot(10).plot_scatter(fig = fig, label = "t={}".format(int(times[10]))) | ||
states.snapshot(50).plot_scatter(fig = fig, label = "t={}".format(int(times[50]))) | ||
|
||
states.snapshot(-1).plot_scatter(fig = fig, label = "t={}".format(int(times[0][-1]))) | ||
states.snapshot(-1).plot_scatter(fig = fig, label = "t={}".format(int(times[-1]))) | ||
plt.show() | ||
|
||
# This allows the module to be executed directly | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,7 +88,7 @@ def mean_square_error(values, ground_truth): | |
Returns: | ||
float: mean square error of eol predictions | ||
""" | ||
return sum([(x.mean() - ground_truth)**2 for x in values])/len(values) | ||
return sum([(sum(x)/len(x) - ground_truth)**2 for x in values])/len(values) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change supports the case where x is not unweighted_samples, but instead is a list (some users keep doing that). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good change Chris. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Python has a builtin There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated to use numpy::mean so it'll be compatible with the upcoming vectorization changes. |
||
|
||
def eol_profile_metrics(eol, ground_truth): | ||
"""Calculate eol profile metrics | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the refactor I make raw_samples an depreciated call, so I'm updating the example to not use it