-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24 from nasa/feature/scatter
Scatter Plot Feature
- Loading branch information
Showing
8 changed files
with
320 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved. | ||
|
||
from .plot_scatter import plot_scatter | ||
__all__ = ['plot_scatter'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved. | ||
from matplotlib.collections import PathCollection | ||
import matplotlib.pyplot as plt | ||
from math import sqrt | ||
|
||
def plot_scatter(samples, fig = None, keys = None, legend = 'auto', **kwargs): | ||
""" | ||
Produce a scatter plot for a given list of states | ||
Args: | ||
samples ([dict]): Non-empty list of states where each element is a dictionary containing a single sample | ||
fig (Figure, optional): Existing figure previously used to plot states. If passed a figure argument additional data will be added to the plot. Defaults to creating new figure | ||
keys (list of strings, optional): Keys to plot. Defaults to all keys. | ||
legend (optional): When the legend should be shown, options: | ||
False: Dont show legend | ||
"auto": Show legend automatically if more than one data set | ||
True: Always show legend | ||
**kwargs (optional): Additional keyword arguments passed to scatter function. Includes those supported by scatter | ||
Returns: | ||
Figure | ||
Example: | ||
states = UnweightedSamples([1, 2, 3, 4, 5]) | ||
plot_scatter(states.sample(100)) # With 100 samples | ||
plot_scatter(states.sample(100), keys=['state1', 'state2']) # only plot those keys | ||
""" | ||
# Input checks | ||
if len(samples) <= 0: | ||
raise Exception('Must include atleast one sample to plot') | ||
|
||
if keys is not None: | ||
try: | ||
iter(keys) | ||
except TypeError: | ||
raise TypeError("Keys should be a list of strings (e.g., ['state1', 'state2'], was {}".format(type(keys))) | ||
|
||
for key in keys: | ||
if key not in samples[0].keys(): | ||
raise TypeError("Key {} was not present in samples (keys: {})".format(key, list(samples[0].keys()))) | ||
|
||
# Handle input | ||
parameters = { # defaults | ||
'alpha': 0.5 | ||
} | ||
parameters.update(kwargs) | ||
|
||
if keys is None: | ||
keys = samples[0].keys() | ||
keys = list(keys) | ||
|
||
n = len(keys) | ||
if n < 2: | ||
raise Exception("At least 2 states required for scatter, got {}".format(n)) | ||
|
||
if fig is None: | ||
# If no figure provided, create one | ||
fig = plt.figure() | ||
axes = [[fig.add_subplot(n-1, n-1, 1 + i + j*(n-1)) for i in range(n-1)] for j in range(n-1)] | ||
else: | ||
# Check size of axes | ||
if len(fig.axes) != (n-1)*(n-1): | ||
raise Exception("Cannot use existing figure - Existing figure graphs {} states, data includes {} states".format(sqrt(len(fig.axes))+1, n)) | ||
|
||
# Unpack axes | ||
axes = [[fig.axes[i + j*(n-1)] for i in range(n-1)] for j in range(n-1)] | ||
|
||
for i in range(n-1): | ||
# For each column | ||
x_key = keys[i] | ||
|
||
# Set labels on extremes | ||
axes[-1][i].set_xlabel(x_key) # Bottom row | ||
axes[i][0].set_ylabel(keys[i+1]) # Left column | ||
|
||
# plot | ||
for j in range(i, n-1): | ||
# for each row | ||
y_key = keys[j+1] | ||
x1 = [x[x_key] for x in samples] | ||
x2 = [x[y_key] for x in samples] | ||
axes[j][i].scatter(x1, x2, **parameters) | ||
|
||
# Hide axes not used in plots | ||
for j in range(0, i): | ||
axes[j][i].set_visible(False) | ||
|
||
# Set legend | ||
if legend == 'auto' or legend: | ||
labels = [thing.get_label() for thing in axes[0][0].get_children() | ||
if isinstance(thing, PathCollection)] | ||
if legend == 'auto' and len(labels) > 0 or legend: | ||
fig.legend().remove() # Remove any existing legend - prevents "ghost effect" | ||
fig.legend(labels=labels, loc='upper right') | ||
|
||
return fig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved. | ||
|
||
__all__ = ['test_state_estimators', 'test_predictors', 'test_examples', 'test_integration', 'test_uncertain_data'] | ||
__all__ = ['test_state_estimators', 'test_predictors', 'test_examples', 'test_integration', 'test_uncertain_data', 'test_visualize'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import unittest | ||
from prog_algs.visualize import plot_scatter | ||
|
||
|
||
class TestVisualize(unittest.TestCase): | ||
def test_scatter(self): | ||
# Nominal | ||
data = [{'x': 1, 'y': 2, 'z': 3}, {'x': 1.5, 'y': 2.2, 'z': -1}, {'x': 0.9, 'y': 2.1, 'z': 7}] | ||
fig = plot_scatter(data) | ||
fig = plot_scatter(data, fig=fig) # Add to figure | ||
plot_scatter(data, fig=fig, keys=['x', 'y', 'z']) # All keys | ||
plot_scatter(data, keys=['y', 'z']) # Subset of keys | ||
|
||
# Incorrect keys | ||
try: | ||
plot_scatter(data, keys=7) # Not iterable | ||
self.fail() | ||
except Exception: | ||
pass | ||
|
||
try: | ||
plot_scatter(data, keys=['x', 'i']) # Not present | ||
self.fail() | ||
except Exception: | ||
pass | ||
|
||
# Changing number of keys | ||
fig = plot_scatter(data) | ||
try: | ||
plot_scatter(data, fig=fig, keys=['y', 'z']) # Different number of keys | ||
self.fail() | ||
except Exception: | ||
pass | ||
|
||
# Too few keys | ||
try: | ||
plot_scatter(data, keys=['y']) # Only one key | ||
self.fail() | ||
except Exception: | ||
pass | ||
|
||
# This allows the module to be executed directly | ||
def run_tests(): | ||
l = unittest.TestLoader() | ||
runner = unittest.TextTestRunner() | ||
print("\n\nTesting Visualize") | ||
result = runner.run(l.loadTestsFromTestCase(TestVisualize)).wasSuccessful() | ||
|
||
if not result: | ||
raise Exception("Failed test") | ||
|
||
if __name__ == '__main__': | ||
run_tests() |
Oops, something went wrong.