Refactor Prediction #60

Merged: 9 commits, Oct 13, 2021
Changes from 4 commits
10 changes: 5 additions & 5 deletions examples/basic_example.py
@@ -40,7 +40,7 @@ def future_loading(t, x = None):
if isinstance(filt, state_estimators.UnscentedKalmanFilter):
samples = filt.x.sample(20)
else: # Particle Filter
samples = filt.x.raw_samples()
samples = filt.x
Collaborator (Author):
With the refactor, raw_samples becomes a deprecated call, so I'm updating the example to stop using it.
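For context, a minimal sketch of the new usage (not part of this diff, reusing filt from basic_example.py); it assumes the particle filter's filt.x already behaves like a list of state dictionaries, which is what raw_samples() previously returned:

# Sketch: the particle-filter state is already a sample collection, so no conversion call is needed
samples = filt.x                  # previously: samples = filt.x.raw_samples()
first_sample = samples[0]         # each entry is a dict of state values for one particle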


# Predict with a step size of 0.1
(times, inputs, states, outputs, event_states, eol) = mc.predict(samples, future_loading, dt=0.1)
@@ -63,11 +63,11 @@ def future_loading(t, x = None):
print('\tP(Success) if mission ends at 3002.25: ', metrics.prob_success(eol, 3005.25))

# Plot state transition
fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0][0])))
states.snapshot(10).plot_scatter(fig = fig, label = "t={}".format(int(times[0][10])))
states.snapshot(50).plot_scatter(fig = fig, label = "t={}".format(int(times[0][50])))
fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0])))
Collaborator (Author):
The times returned from prediction are now indexed as times[time_index] instead of times[sample_id][time_index]. This is a breaking change that requires users to update their code (as I do here).
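A minimal before/after sketch of the breaking change (illustrative only, not part of this diff):

# Before the refactor: times was indexed by sample id, then by time step
t_first = times[0][0]    # first saved time for sample 0
# After the refactor: predict returns a single shared time vector
t_first = times[0]       # first saved time, shared by all samples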

states.snapshot(10).plot_scatter(fig = fig, label = "t={}".format(int(times[10])))
states.snapshot(50).plot_scatter(fig = fig, label = "t={}".format(int(times[50])))

states.snapshot(-1).plot_scatter(fig = fig, label = "t={}".format(int(times[0][-1])))
states.snapshot(-1).plot_scatter(fig = fig, label = "t={}".format(int(times[-1])))
plt.show()

# This allows the module to be executed directly
24 changes: 12 additions & 12 deletions src/prog_algs/__init__.py
@@ -18,23 +18,23 @@ def run_prog_playback(obs, pred, future_loading, output_measurements, **kwargs):
config.update(kwargs)

next_predict = output_measurements[0][0] + config['predict_rate']
times = np.empty((len(output_measurements), config['num_samples']), dtype=object)
inputs = np.empty((len(output_measurements), config['num_samples']), dtype=object)
states = np.empty((len(output_measurements), config['num_samples']), dtype=object)
outputs = np.empty((len(output_measurements), config['num_samples']), dtype=object)
event_states = np.empty((len(output_measurements), config['num_samples']), dtype=object)
eols = np.empty((len(output_measurements), config['num_samples']), dtype=object)
times = []
inputs = []
states = []
outputs = []
event_states = []
eols = []
index = 0
for (t, measurement) in output_measurements:
obs.estimate(t, future_loading(t), measurement)
if t >= next_predict:
(t, u, x, z, es, eol) = pred.predict(obs.x.sample(config['num_samples']), future_loading, **config['predict_config'])
times[index, :] = t
inputs[index, :] = u
states[index, :] = x
outputs[index, :] = z
event_states[index, :] = es
eols[index, :] = eol
times.append(t)
inputs.append(u)
states.append(x)
outputs.append(z)
event_states.append(es)
eols.append(eol)
index += 1
next_predict += config['predict_rate']
return (times, inputs, states, outputs, event_states, eols)
2 changes: 1 addition & 1 deletion src/prog_algs/metrics/samples.py
@@ -88,7 +88,7 @@ def mean_square_error(values, ground_truth):
Returns:
float: mean square error of eol predictions
"""
return sum([(x.mean() - ground_truth)**2 for x in values])/len(values)
return sum([(sum(x)/len(x) - ground_truth)**2 for x in values])/len(values)
Collaborator (Author):
This change supports the case where x is not an UnweightedSamples object but a plain list (some users keep doing that).
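For illustration, a hedged sketch of the case this covers, with made-up numbers; it assumes mean_square_error is importable from prog_algs.metrics.samples and that each entry of values is a plain list of end-of-life samples rather than an UnweightedSamples object:

# Sketch: list-of-lists input now works because sum(x)/len(x) avoids the .mean() method
from prog_algs.metrics.samples import mean_square_error

values = [[3005.2, 3005.6, 3004.9], [3005.0, 3005.4, 3005.3]]   # made-up EOL samples
mse = mean_square_error(values, 3005.25)                        # ground truth of 3005.25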

Contributor:
Good change, Chris.

Contributor:
Python has a built-in mean function: https://docs.python.org/3/library/statistics.html#statistics.mean. I've never used it, but I expect it would work here.

Collaborator (Author):
I updated it to use numpy.mean so it will be compatible with the upcoming vectorization changes.
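A hedged sketch of the numpy-based version described here (the actual committed line may differ):

from numpy import mean

def mean_square_error(values, ground_truth):
    # mean() accepts plain lists as well as array-likes, so list inputs keep working
    return sum([(mean(x) - ground_truth)**2 for x in values]) / len(values)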


def eol_profile_metrics(eol, ground_truth):
"""Calculate eol profile metrics
18 changes: 11 additions & 7 deletions src/prog_algs/predictors/monte_carlo.py
@@ -1,6 +1,6 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from .prediction import Prediction
from .prediction import UnweightedSamplesPrediction
from .predictor import Predictor
from ..exceptions import ProgAlgTypeError
from copy import deepcopy
@@ -89,9 +89,13 @@ def predict(self, state_samples, future_loading_eqn, **kwargs):
result = [pred_fcn(sample) for sample in state_samples]
times_all, inputs_all, states_all, outputs_all, event_states_all, time_of_event = map(list, zip(*result))

inputs_all = Prediction(times_all, inputs_all)
states_all = Prediction(times_all, states_all)
outputs_all = Prediction(times_all, outputs_all)
event_states_all = Prediction(times_all, event_states_all)
time_of_event = UnweightedSamples(time_of_event)
return (times_all, inputs_all, states_all, outputs_all, event_states_all, time_of_event)
# Return longest time array
times_length = [len(t) for t in times_all]
times_max_len = max(times_length)
times = times_all[times_length.index(times_max_len)]

inputs_all = UnweightedSamplesPrediction(times, inputs_all)
states_all = UnweightedSamplesPrediction(times, states_all)
outputs_all = UnweightedSamplesPrediction(times, outputs_all)
event_states_all = UnweightedSamplesPrediction(times, event_states_all)
return (times, inputs_all, states_all, outputs_all, event_states_all, time_of_event)
129 changes: 113 additions & 16 deletions src/prog_algs/predictors/prediction.py
@@ -1,18 +1,23 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from abc import ABC, abstractmethod, abstractproperty
from collections import UserList
from ..uncertain_data import UnweightedSamples
from warnings import warn

from ..uncertain_data import UnweightedSamples, MultivariateNormalDist

class Prediction(UserList):
"""
Result of a prediction

class Prediction(ABC):
"""
__slots__ = ['times', 'data'] # Optimization
Parent class for the result of a prediction. Is returned by the predict method of a predictor. Defines the interface for operations on a prediction data object.

Note: This class is not intended to be instantiated directly, instead subclasses should be used
"""
def __init__(self, times, data):
"""
Args:
times (array(float)): Times for each data point where times[n] corresponds to data[n]
data (array(dict)): Data points where data[n] corresponds to times[n]
data
"""
self.times = times
self.data = data
@@ -28,18 +28,33 @@ def __eq__(self, other):
"""
return self.times == other.times and self.data == other.data

def sample(self, index):
"""Get sample by sample_id, equivalent to prediction[index]
@abstractmethod
def snapshot(self, time_index):
"""Get all samples from a specific timestep

Args:
index (int): Timestep (index number from times)

Returns:
UnweightedSamples: Samples for time corresponding to times[timestep]
"""
pass

@property
@abstractproperty
def mean(self):
"""Mean estimate

Example:
mean_value = data.mean
"""
pass

def time(self, index):
"""Get time for data point at index `index`

Args:
index (int)

Returns:
float: Time for which the data point at index `index` corresponds
"""
warn("Depreciated. Please use prediction.times[index] instead.")
return self.times[index]


class UnweightedSamplesPrediction(Prediction, UserList):
"""
Data class for the result of a prediction, where the predictions are stored as UnweightedSamples. Is returned from the predict method of a sample based prediction class (e.g., MonteCarlo).
"""
def __init__(self, times, data):
"""
Initialize UnweightedSamplesPrediction

Args:
times (array(float)): Times for each data point where times[n] corresponds to data[n]
data (array(dict)): Data points where data[n] corresponds to times[n]
"""
super(UnweightedSamplesPrediction, self).__init__(times, data)
self.__transformed = False # If transform has been calculated

def __calculate_tranform(self):
"""
Calculate tranform of the data from data[sample_id][time_id] to data[time_id][sample_id]. Result is cached as self.__transform and is used in methods which look at a snapshot for a specific time
"""
# Lazy calculation of tranform - only if needed
# Note: prediction stops when event is reached, so for the length of all will not be the same.
# If the prediction doesn't go this far, then the value is set to None
self.__transform = [UnweightedSamples([sample[time_index] if len(sample) > time_index else None for sample in self.data]) for time_index in range(len(self.times))]
self.__transformed = True

def __str__(self):
return "UnweightedSamplesPrediction with {} savepoints".format(len(self.times))

@property
def mean(self):
if not self.__transformed:
self.__calculate_tranform()
return [dist.mean for dist in self.__transform]

def sample(self, sample_id):
"""Get sample by sample_id, equivalent to prediction[index]. Depreciated in favor of prediction[id]

Args:
index (int): index of sample

Returns:
SimResult: Values for that sample at different times where result[i] corresponds to time[i]
"""
return self[index]
warn("Depreciated. Please use prediction[sample_id] instead.")
return self[sample_id]

def snapshot(self, index):
def snapshot(self, time_index):
"""Get all samples from a specific timestep

Args:
@@ -48,15 +48,37 @@ def snapshot(self, index):
Returns:
UnweightedSamples: Samples for time corresponding to times[timestep]
"""
return UnweightedSamples([sample[index] for sample in self.data])
if not self.__transformed:
self.__calculate_tranform()
return self.__transform[time_index]

def time(self, index):
"""Get time for data point at index `index`
def __not_implemented(self, *args, **kw):
"""
Called for not implemented functions. These functions are not used to make the class immutable
"""
raise ValueError("UnweightedSamplesPrediction is immutable (i.e., read only)")

append = pop = __setitem__ = __setslice__ = __delitem__ = __not_implemented


class MultivariateNormalDistPrediction(Prediction):
"""
Data class for the result of a prediction, where the predictions are stored as MultivariateNormalDist. Is returned from the predict method of a MultivariateNormalDist-based prediction class (e.g., Unscented Kalman Predictor).
"""
def __str__(self):
return "MultivariateNormalDistPrediction with {} savepoints".format(len(self.times))

@property
def mean(self):
return [dist.mean for dist in self.data]

def snapshot(self, time_index):
"""Get all samples from a specific timestep

Args:
index (int)
index (int): Timestep (index number from times)

Returns:
float: Time for which the data point at index `index` corresponds
UnweightedSamples: Samples for time corresponding to times[timestep]
"""
return self.times[index]
return self.data[time_index]
9 changes: 6 additions & 3 deletions src/prog_algs/uncertain_data/unweighted_samples.py
@@ -25,20 +25,23 @@ def sample(self, num_samples = 1):
def keys(self):
if len(self.data) == 0:
return [[]]
return self[0].keys()
for sample in self:
if sample is not None:
return sample.keys()
return []

@property
def mean(self):
mean = {}
for key in self.keys():
mean[key] = array([x[key] for x in self.data]).mean()
mean[key] = array([x[key] for x in self.data if x is not None]).mean()
return mean

@property
def cov(self):
if len(self.data) == 0:
return [[]]
unlabeled_samples = array([[x[key] for x in self.data] for key in self.data[0].keys()])
unlabeled_samples = array([[x[key] for x in self.data if x is not None] for key in self.keys()])
return cov(unlabeled_samples)

def __str__(self):
8 changes: 2 additions & 6 deletions src/prog_algs/visualize/plot_scatter.py
@@ -34,10 +34,6 @@ def plot_scatter(samples, fig = None, keys = None, legend = 'auto', **kwargs):
iter(keys)
except TypeError:
raise TypeError("Keys should be a list of strings (e.g., ['state1', 'state2'], was {}".format(type(keys)))

for key in keys:
if key not in samples[0].keys():
raise TypeError("Key {} was not present in samples (keys: {})".format(key, list(samples[0].keys())))

# Handle input
parameters = { # defaults
@@ -77,8 +73,8 @@ def plot_scatter(samples, fig = None, keys = None, legend = 'auto', **kwargs):
for j in range(i, n-1):
# for each row
y_key = keys[j+1]
x1 = [x[x_key] for x in samples]
x2 = [x[y_key] for x in samples]
x1 = [x[x_key] for x in samples if x is not None]
x2 = [x[y_key] for x in samples if x is not None]
axes[j][i].scatter(x1, x2, **parameters)

# Hide axes not used in plots
33 changes: 29 additions & 4 deletions tests/test_predictors.py
@@ -60,19 +60,44 @@ def future_loading(t, x={}):
else:
return {'i1': -4, 'i2': 2.5}

def test_prediction(self):
from prog_algs.predictors.prediction import Prediction
def test_prediction_mvnormaldist(self):
from prog_algs.predictors.prediction import MultivariateNormalDistPrediction
from prog_algs.uncertain_data import MultivariateNormalDist
times = list(range(10))
covar = [[0.1, 0.01], [0.01, 0.1]]
means = [{'a': 1+i/10, 'b': 2-i/5} for i in range(10)]
states = [MultivariateNormalDist(means[i].keys(), means[i].values(), covar) for i in range(10)]
p = MultivariateNormalDistPrediction(times, states)

self.assertEqual(p.mean, means)
self.assertEqual(p.snapshot(0), states[0])
self.assertEqual(p.snapshot(-1), states[-1])
self.assertEqual(p.time(0), times[0])
self.assertEqual(p.times[0], times[0])
self.assertEqual(p.time(-1), times[-1])
self.assertEqual(p.times[-1], times[-1])

# Out of range
try:
tmp = p.time(10)
self.fail()
except Exception:
pass

def test_prediction_uwsamples(self):
from prog_algs.predictors.prediction import UnweightedSamplesPrediction
from prog_algs.uncertain_data import UnweightedSamples
times = [list(range(10))]*3
times = list(range(10))
states = [list(range(10)), list(range(1, 11)), list(range(-1, 9))]
p = Prediction(times, states)
p = UnweightedSamplesPrediction(times, states)

self.assertEqual(p[0], states[0])
self.assertEqual(p.sample(0), states[0])
self.assertEqual(p.sample(-1), states[-1])
self.assertEqual(p.snapshot(0), UnweightedSamples([0, 1, -1]))
self.assertEqual(p.snapshot(-1), UnweightedSamples([9, 10, 8]))
self.assertEqual(p.time(0), times[0])
self.assertEqual(p.times[0], times[0])
self.assertEqual(p.time(-1), times[-1])

# Out of range