Skip to content

Commit

Permalink
Merge pull request #24 from nasa/feature/scatter
Browse files Browse the repository at this point in the history
Scatter Plot Feature
  • Loading branch information
teubert committed Sep 30, 2021
2 parents 2450f78 + 27aa9e5 commit 6b70c1a
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 109 deletions.
16 changes: 14 additions & 2 deletions examples/basic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ def future_loading(t, x = None):
## State Estimation - perform a single ukf state estimate step
# filt = state_estimators.UnscentedKalmanFilter(batt, batt.parameters['x0'])
filt = state_estimators.ParticleFilter(batt, batt.parameters['x0'])

import matplotlib.pyplot as plt # For plotting
print("Prior State:", filt.x.mean)
print('\tSOC: ', batt.event_state(filt.x.mean)['EOD'])
fig = filt.x.plot_scatter(label='prior')
example_measurements = {'t': 32.2, 'v': 3.915}
t = 0.1
load = future_loading(t)
filt.estimate(t, load, {'t': 32.2, 'v': 3.915})
filt.estimate(t, future_loading(t), example_measurements)
print("Posterior State:", filt.x.mean)
print('\tSOC: ', batt.event_state(filt.x.mean)['EOD'])
filt.x.plot_scatter(fig= fig, label='posterior')

## Prediction - Predict EOD given current state
# Setup prediction
Expand Down Expand Up @@ -58,6 +62,14 @@ def future_loading(t, x = None):
print('\tAssuming ground truth 3002.25: ', metrics.eol_metrics(eol, 3005.25))
print('\tP(Success) if mission ends at 3002.25: ', metrics.prob_success(eol, 3005.25))

# Plot state transition
fig = states.snapshot(0).plot_scatter(label = "t={}".format(int(times[0][0])))
states.snapshot(10).plot_scatter(fig = fig, label = "t={}".format(int(times[0][10])))
states.snapshot(50).plot_scatter(fig = fig, label = "t={}".format(int(times[0][50])))

states.snapshot(-1).plot_scatter(fig = fig, label = "t={}".format(int(times[0][-1])))
plt.show()

# This allows the module to be executed directly
if __name__ == '__main__':
run_example()
45 changes: 40 additions & 5 deletions src/prog_algs/uncertain_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from abc import ABC, abstractmethod, abstractproperty
from numpy.random import choice, multivariate_normal
from numpy import array, append, delete, cov
# from prog_algs.visualize import plot_scatter
from prog_algs.visualize import plot_scatter
from collections import UserList
import warnings


class UncertainData(ABC):
"""
Abstract base class for data with uncertainty. Any new uncertainty type must implement this class
Expand Down Expand Up @@ -46,8 +45,38 @@ def cov(self):
[[float]]: covariance matrix
"""

@abstractmethod
def keys(self):
"""Get the keys for the property represented
Returns:
[string]: keys
"""
# TODO(CT): Consider median

def plot_scatter(self, fig = None, keys = None, num_samples = 100, **kwargs):
"""
Produce a scatter plot
Args:
fig (Figure, optional): Existing figure previously used to plot states. If passed a figure argument additional data will be added to the plot. Defaults to creating new figure
keys (list of strings, optional): Keys to plot. Defaults to all keys.
num_samples (int, optional): Number of samples to plot. Defaults to 100
**kwargs (optional): Additional keyword arguments passed to scatter function.
Returns:
Figure
Example:
states = UnweightedSamples([1, 2, 3, 4, 5])
states.plot_scatter() # With 100 samples
states.plot_scatter(num_samples=5) # Specifying the number of samples to plot
states.plot_scatter(keys=['state1', 'state2']) # only plot those keys
"""
samples = self.sample(num_samples)
return plot_scatter(samples, fig=fig, keys=keys, **kwargs)


class ScalarData(UncertainData):
"""
Data without uncertainty- single value
Expand All @@ -69,8 +98,11 @@ def mean(self):

@property
def cov(self):
return [[0 for i in self.__state] for j in self.__state]
return [[0]]

def keys(self):
return self.__state.keys()

def sample(self, num_samples = 1):
return array([self.__state] * num_samples)

Expand All @@ -95,8 +127,8 @@ def sample(self, num_samples = 1):
return choice(self.data, num_samples)

def keys(self):
if len(self.data) == 0:
return []
if len(self.__samples) == 0:
return [[]]
return self[0].keys()

@property
Expand Down Expand Up @@ -154,6 +186,9 @@ def sample(self, num_samples = 1):
samples = array([{key: value for (key, value) in zip(self.__labels, x)} for x in samples])
return samples

def keys(self):
return self.__labels

@property
def mean(self):
return {key: value for (key, value) in zip(self.__labels, self.__mean)}
Expand Down
4 changes: 4 additions & 0 deletions src/prog_algs/visualize/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from .plot_scatter import plot_scatter
__all__ = ['plot_scatter']
96 changes: 96 additions & 0 deletions src/prog_algs/visualize/plot_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.
from matplotlib.collections import PathCollection
import matplotlib.pyplot as plt
from math import sqrt

def plot_scatter(samples, fig = None, keys = None, legend = 'auto', **kwargs):
"""
Produce a scatter plot for a given list of states
Args:
samples ([dict]): Non-empty list of states where each element is a dictionary containing a single sample
fig (Figure, optional): Existing figure previously used to plot states. If passed a figure argument additional data will be added to the plot. Defaults to creating new figure
keys (list of strings, optional): Keys to plot. Defaults to all keys.
legend (optional): When the legend should be shown, options:
False: Dont show legend
"auto": Show legend automatically if more than one data set
True: Always show legend
**kwargs (optional): Additional keyword arguments passed to scatter function. Includes those supported by scatter
Returns:
Figure
Example:
states = UnweightedSamples([1, 2, 3, 4, 5])
plot_scatter(states.sample(100)) # With 100 samples
plot_scatter(states.sample(100), keys=['state1', 'state2']) # only plot those keys
"""
# Input checks
if len(samples) <= 0:
raise Exception('Must include atleast one sample to plot')

if keys is not None:
try:
iter(keys)
except TypeError:
raise TypeError("Keys should be a list of strings (e.g., ['state1', 'state2'], was {}".format(type(keys)))

for key in keys:
if key not in samples[0].keys():
raise TypeError("Key {} was not present in samples (keys: {})".format(key, list(samples[0].keys())))

# Handle input
parameters = { # defaults
'alpha': 0.5
}
parameters.update(kwargs)

if keys is None:
keys = samples[0].keys()
keys = list(keys)

n = len(keys)
if n < 2:
raise Exception("At least 2 states required for scatter, got {}".format(n))

if fig is None:
# If no figure provided, create one
fig = plt.figure()
axes = [[fig.add_subplot(n-1, n-1, 1 + i + j*(n-1)) for i in range(n-1)] for j in range(n-1)]
else:
# Check size of axes
if len(fig.axes) != (n-1)*(n-1):
raise Exception("Cannot use existing figure - Existing figure graphs {} states, data includes {} states".format(sqrt(len(fig.axes))+1, n))

# Unpack axes
axes = [[fig.axes[i + j*(n-1)] for i in range(n-1)] for j in range(n-1)]

for i in range(n-1):
# For each column
x_key = keys[i]

# Set labels on extremes
axes[-1][i].set_xlabel(x_key) # Bottom row
axes[i][0].set_ylabel(keys[i+1]) # Left column

# plot
for j in range(i, n-1):
# for each row
y_key = keys[j+1]
x1 = [x[x_key] for x in samples]
x2 = [x[y_key] for x in samples]
axes[j][i].scatter(x1, x2, **parameters)

# Hide axes not used in plots
for j in range(0, i):
axes[j][i].set_visible(False)

# Set legend
if legend == 'auto' or legend:
labels = [thing.get_label() for thing in axes[0][0].get_children()
if isinstance(thing, PathCollection)]
if legend == 'auto' and len(labels) > 0 or legend:
fig.legend().remove() # Remove any existing legend - prevents "ghost effect"
fig.legend(labels=labels, loc='upper right')

return fig
2 changes: 1 addition & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

__all__ = ['test_state_estimators', 'test_predictors', 'test_examples', 'test_integration', 'test_uncertain_data']
__all__ = ['test_state_estimators', 'test_predictors', 'test_examples', 'test_integration', 'test_uncertain_data', 'test_visualize']
6 changes: 5 additions & 1 deletion tests/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .test_uncertain_data import TestUncertainData
from .test_examples import TestExamples
from .test_misc import TestMisc
from .test_visualize import TestVisualize
import unittest
import sys
from examples import basic_example
Expand Down Expand Up @@ -35,11 +36,14 @@ def run_basic_ex():
print('\n\nUncertain Data Tests')
unittest.TextTestRunner().run(l.loadTestsFromTestCase(TestUncertainData))

print('\n\nVisualize Tests')
unittest.TextTestRunner().run(l.loadTestsFromTestCase(TestVisualize))

print('\n\nExamples Tests')
unittest.TextTestRunner().run(l.loadTestsFromTestCase(TestExamples))

print('\n\nIntegration Tests')
unittest.TextTestRunner().run(l.loadTestsFromTestCase(TestIntegration))

print('\n\nMisc Tests')
unittest.TextTestRunner().run(l.loadTestsFromTestCase(TestMisc))
unittest.TextTestRunner().run(l.loadTestsFromTestCase(TestMisc))
53 changes: 53 additions & 0 deletions tests/test_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest
from prog_algs.visualize import plot_scatter


class TestVisualize(unittest.TestCase):
def test_scatter(self):
# Nominal
data = [{'x': 1, 'y': 2, 'z': 3}, {'x': 1.5, 'y': 2.2, 'z': -1}, {'x': 0.9, 'y': 2.1, 'z': 7}]
fig = plot_scatter(data)
fig = plot_scatter(data, fig=fig) # Add to figure
plot_scatter(data, fig=fig, keys=['x', 'y', 'z']) # All keys
plot_scatter(data, keys=['y', 'z']) # Subset of keys

# Incorrect keys
try:
plot_scatter(data, keys=7) # Not iterable
self.fail()
except Exception:
pass

try:
plot_scatter(data, keys=['x', 'i']) # Not present
self.fail()
except Exception:
pass

# Changing number of keys
fig = plot_scatter(data)
try:
plot_scatter(data, fig=fig, keys=['y', 'z']) # Different number of keys
self.fail()
except Exception:
pass

# Too few keys
try:
plot_scatter(data, keys=['y']) # Only one key
self.fail()
except Exception:
pass

# This allows the module to be executed directly
def run_tests():
l = unittest.TestLoader()
runner = unittest.TextTestRunner()
print("\n\nTesting Visualize")
result = runner.run(l.loadTestsFromTestCase(TestVisualize)).wasSuccessful()

if not result:
raise Exception("Failed test")

if __name__ == '__main__':
run_tests()

0 comments on commit 6b70c1a

Please sign in to comment.