Skip to content

Commit

Permalink
Improve visualization docs examples (#372)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

Miscellaneous edits to the examples in the visualization docs.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Use np.random instead of linspace/meshgrid to make it easier to
understand
- [x] Add captions to plots
- [x] Add 1D grid_archive_heatmap example

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka committed Sep 14, 2023
1 parent 251a9b0 commit b843a0d
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 20 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
- pip install swig before gymnasium[box2d] in lunar lander tutorial ({pr}`346`)
- Fix lunar lander dependency issues ({pr}`366`, {pr}`367`)
- Simplify DQD tutorial imports ({pr}`369`)
- Improve visualization docs examples ({pr}`372`)

#### Improvements

Expand Down
4 changes: 2 additions & 2 deletions ribs/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
ribs.visualize.cvt_archive_heatmap
ribs.visualize.grid_archive_heatmap
ribs.visualize.parallel_axes_plot
ribs.visualize.qdax_repertoire_heatmap
ribs.visualize.sliding_boundaries_archive_heatmap
ribs.visualize.qdax_repertoire_heatmap
"""
from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot
from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap
Expand All @@ -35,6 +35,6 @@
"cvt_archive_heatmap",
"grid_archive_heatmap",
"parallel_axes_plot",
"qdax_repertoire_heatmap",
"sliding_boundaries_archive_heatmap",
"qdax_repertoire_heatmap",
]
17 changes: 10 additions & 7 deletions ribs/visualize/_cvt_archive_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,20 @@ def cvt_archive_heatmap(archive,
.. plot::
:context: close-figs
Heatmap of a 2D CVTArchive
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from ribs.archives import CVTArchive
>>> from ribs.visualize import cvt_archive_heatmap
>>> # Populate the archive with the negative sphere function.
>>> archive = CVTArchive(solution_dim=2,
... cells=100, ranges=[(-1, 1), (-1, 1)])
>>> x = y = np.linspace(-1, 1, 100)
>>> xxs, yys = np.meshgrid(x, y)
>>> xxs, yys = xxs.flatten(), yys.flatten()
>>> archive.add(solution_batch=np.stack((xxs, yys), axis=1),
... objective_batch=-(xxs**2 + yys**2),
... measures_batch=np.stack((xxs, yys), axis=1))
>>> x = np.random.uniform(-1, 1, 10000)
>>> y = np.random.uniform(-1, 1, 10000)
>>> archive.add(solution_batch=np.stack((x, y), axis=1),
... objective_batch=-(x**2 + y**2),
... measures_batch=np.stack((x, y), axis=1))
>>> # Plot a heatmap of the archive.
>>> plt.figure(figsize=(8, 6))
>>> cvt_archive_heatmap(archive)
Expand All @@ -73,14 +74,16 @@ def cvt_archive_heatmap(archive,
.. plot::
:context: close-figs
Heatmap of a 1D CVTArchive
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from ribs.archives import CVTArchive
>>> from ribs.visualize import cvt_archive_heatmap
>>> # Populate the archive with the negative sphere function.
>>> archive = CVTArchive(solution_dim=2,
... cells=20, ranges=[(-1, 1)])
>>> x = np.linspace(-1, 1, 100)
>>> x = np.random.uniform(-1, 1, 1000)
>>> archive.add(solution_batch=np.stack((x, x), axis=1),
... objective_batch=-x**2,
... measures_batch=x[:, None])
Expand Down
36 changes: 30 additions & 6 deletions ribs/visualize/_grid_archive_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def grid_archive_heatmap(archive,
.. plot::
:context: close-figs
Heatmap of a 2D GridArchive
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from ribs.archives import GridArchive
Expand All @@ -44,12 +46,11 @@ def grid_archive_heatmap(archive,
>>> archive = GridArchive(solution_dim=2,
... dims=[20, 20],
... ranges=[(-1, 1), (-1, 1)])
>>> x = y = np.linspace(-1, 1, 100)
>>> xxs, yys = np.meshgrid(x, y)
>>> xxs, yys = xxs.flatten(), yys.flatten()
>>> archive.add(solution_batch=np.stack((xxs, yys), axis=1),
... objective_batch=-(xxs**2 + yys**2),
... measures_batch=np.stack((xxs, yys), axis=1))
>>> x = np.random.uniform(-1, 1, 10000)
>>> y = np.random.uniform(-1, 1, 10000)
>>> archive.add(solution_batch=np.stack((x, y), axis=1),
... objective_batch=-(x**2 + y**2),
... measures_batch=np.stack((x, y), axis=1))
>>> # Plot a heatmap of the archive.
>>> plt.figure(figsize=(8, 6))
>>> grid_archive_heatmap(archive)
Expand All @@ -58,6 +59,29 @@ def grid_archive_heatmap(archive,
>>> plt.ylabel("y coords")
>>> plt.show()
.. plot::
:context: close-figs
Heatmap of a 1D GridArchive
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from ribs.archives import GridArchive
>>> from ribs.visualize import grid_archive_heatmap
>>> # Populate the archive with the negative sphere function.
>>> archive = GridArchive(solution_dim=2,
... dims=[20], ranges=[(-1, 1)])
>>> x = np.random.uniform(-1, 1, 1000)
>>> archive.add(solution_batch=np.stack((x, x), axis=1),
... objective_batch=-x**2,
... measures_batch=x[:, None])
>>> # Plot a heatmap of the archive.
>>> plt.figure(figsize=(8, 6))
>>> grid_archive_heatmap(archive)
>>> plt.title("Negative sphere function with 1D measures")
>>> plt.xlabel("x coords")
>>> plt.show()
Args:
archive (GridArchive): A 1D or 2D :class:`~ribs.archives.GridArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
Expand Down
9 changes: 4 additions & 5 deletions ribs/visualize/_sliding_boundaries_archive_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,10 @@ def sliding_boundaries_archive_heatmap(archive,
... ranges=[(-1, 1), (-1, 1)],
... seed=42)
>>> # Populate the archive with the negative sphere function.
>>> rng = np.random.default_rng(seed=10)
>>> coords = np.clip(rng.standard_normal((1000, 2)), -1.5, 1.5)
>>> archive.add(solution_batch=coords,
... objective_batch=-np.sum(coords**2, axis=1),
... measures_batch=coords)
>>> xy = np.clip(np.random.standard_normal((1000, 2)), -1.5, 1.5)
>>> archive.add(solution_batch=xy,
... objective_batch=-np.sum(xy**2, axis=1),
... measures_batch=xy)
>>> # Plot heatmaps of the archive.
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16,6))
>>> fig.suptitle("Negative sphere function")
Expand Down

0 comments on commit b843a0d

Please sign in to comment.