Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Prediction Return Type #25

Merged
merged 10 commits into from
Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions examples/basic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,28 @@ def future_loading(t, x = None):
filt.estimate(t, load, {'t': 32.2, 'v': 3.915})
print("Posterior State:", filt.x.mean)
print('\tSOC: ', batt.event_state(filt.x.mean)['EOD'])

## Prediction - Predict EOD given current state
# Setup prediction
mc = predictors.MonteCarlo(batt)
if isinstance(filt, state_estimators.UnscentedKalmanFilter):
samples = filt.x.sample(20)
else: # Particle Filter
samples = filt.x.raw_samples()

# Predict with a step size of 0.1
(times, inputs, states, outputs, event_states, eol) = mc.predict(samples, future_loading, dt=0.1)

# The results of prediction can be accessed by sample, e.g.,
times_sample_1 = times[1]
states_sample_1 = states[1]
# now states_sample_1[n] corresponds to time_sample_1[n]
# you can also plot the results (state_sample_1.plot())

# You can also access a state at a specific time using the .snapshot function
states_time_1 = states.snapshot(1)
# now you have all the samples from the times[sample][1]

## Print Metrics
print("\nEOD Predictions (s):")
from prog_algs.metrics import samples as metrics
Expand Down
11 changes: 6 additions & 5 deletions src/prog_algs/predictors/monte_carlo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from . import predictor
from .prediction import Prediction
from numpy import empty
from ..exceptions import ProgAlgTypeError
from copy import deepcopy
Expand Down Expand Up @@ -96,9 +97,9 @@ def predict(self, state_samples, future_loading_eqn, **kwargs):

result = [pred_fcn(sample) for sample in state_samples]
times_all = [tmp[0] for tmp in result]
inputs_all = [tmp[1] for tmp in result]
states_all = [tmp[2] for tmp in result]
outputs_all = [tmp[3] for tmp in result]
event_states_all = [tmp[4] for tmp in result]
time_of_event = [tmp[5] for tmp in result]
inputs_all = Prediction(times_all, [tmp[1] for tmp in result])
states_all = Prediction(times_all, [tmp[2] for tmp in result])
outputs_all = Prediction(times_all, [tmp[3] for tmp in result])
event_states_all = Prediction(times_all, [tmp[4] for tmp in result])
time_of_event = Prediction(times_all, [tmp[5] for tmp in result])
return (times_all, inputs_all, states_all, outputs_all, event_states_all, time_of_event)
62 changes: 62 additions & 0 deletions src/prog_algs/predictors/prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from collections import UserList
from ..uncertain_data import UnweightedSamples


class Prediction(UserList):
"""
Result of a prediction
"""
__slots__ = ['times', 'data'] # Optimization

def __init__(self, times, data):
"""
Args:
times (array(float)): Times for each data point where times[n] corresponds to data[n]
data (array(dict)): Data points where data[n] corresponds to times[n]
"""
self.times = times
self.data = data

def __eq__(self, other):
"""Compare 2 Predictions

Args:
other (Precition)

Returns:
bool: If the two Predictions are equal
"""
return self.times == other.times and self.data == other.data

def sample(self, index):
"""Get sample by sample_id, equivalent to prediction[index]

Args:
index (int): index of sample

Returns:
SimResult: Values for that sample at different times where result[i] corresponds to time[i]
"""
return self[index]

def snapshot(self, index):
"""Get all samples from a specific timestep

Args:
index (int): Timestep (index number from times)

Returns:
UnweightedSamples: Samples for time corresponding to times[timestep]
"""
return UnweightedSamples([sample[index] for sample in self.data])

def time(self, index):
"""Get time for data point at index `index`

Args:
index (int)

Returns:
float: Time for which the data point at index `index` corresponds
"""
return self.times[index]
6 changes: 6 additions & 0 deletions src/prog_algs/uncertain_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __init__(self, state):
"""
self.__state = state

def __eq__(self, other):
return isinstance(other, ScalarData) and other.mean() == self.__state

@property
def mean(self):
return self.__state
Expand Down Expand Up @@ -87,6 +90,9 @@ def sample(self, num_samples = 1):
# Completely random resample
return choice(self.__samples, num_samples)

def __eq__(self, other):
return isinstance(other, UnweightedSamples) and self.__samples == other.raw_samples()

@property
def mean(self):
mean = {}
Expand Down
55 changes: 53 additions & 2 deletions tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,56 @@ def future_loading(t, x={}):
else:
return {'i1': -4, 'i2': 2.5}



def test_prediction(self):
from prog_algs.predictors.prediction import Prediction
from prog_algs.uncertain_data import UnweightedSamples
times = [list(range(10))]*3
states = [list(range(10)), list(range(1, 11)), list(range(-1, 9))]
p = Prediction(times, states)

self.assertEqual(p[0], states[0])
self.assertEqual(p.sample(0), states[0])
self.assertEqual(p.sample(-1), states[-1])
self.assertEqual(p.snapshot(0), UnweightedSamples([0, 1, -1]))
self.assertEqual(p.snapshot(-1), UnweightedSamples([9, 10, 8]))
self.assertEqual(p.time(0), times[0])
self.assertEqual(p.time(-1), times[-1])

# Out of range
try:
tmp = p[10]
self.fail()
except Exception:
pass

try:
tmp = p.sample(10)
self.fail()
except Exception:
pass

try:
tmp = p.time(10)
self.fail()
except Exception:
pass

# Bad type
try:
tmp = p.sample('abc')
self.fail()
except Exception:
pass

# This allows the module to be executed directly
def run_tests():
l = unittest.TestLoader()
runner = unittest.TextTestRunner()
print("\n\nTesting Predictor")
result = runner.run(l.loadTestsFromTestCase(TestPredictors)).wasSuccessful()

if not result:
raise Exception("Failed test")

if __name__ == '__main__':
run_tests()