Skip to content

Commit

Permalink
Add visualization of 3D QDax repertoires (#373)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

Similar to #353 that introduced `qdax_repertoire_heatmap`, this PR
introduces `qdax_repertoire_3d_plot` for using `cvt_archive_3d_plot`
from #371 to visualize QDax repertoires with 3D measure space.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Implement
- [x] Test
- [x] Refactor to put all qdax vis in one file

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka committed Sep 14, 2023
1 parent b843a0d commit 4f8637c
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 24 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Add `rasterized` arg for heatmaps (#359)
- Support 1D cvt_archive_heatmap ({pr}`362`)
- Add 3D plots for CVTArchive ({pr}`371`)
- Add visualization of 3D QDax repertoires ({pr}`372`)

#### Documentation

Expand Down
5 changes: 4 additions & 1 deletion ribs/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,24 @@
ribs.visualize.grid_archive_heatmap
ribs.visualize.parallel_axes_plot
ribs.visualize.sliding_boundaries_archive_heatmap
ribs.visualize.qdax_repertoire_3d_plot
ribs.visualize.qdax_repertoire_heatmap
"""
from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot
from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap
from ribs.visualize._grid_archive_heatmap import grid_archive_heatmap
from ribs.visualize._parallel_axes_plot import parallel_axes_plot
from ribs.visualize._qdax_repertoire_heatmap import qdax_repertoire_heatmap
from ribs.visualize._sliding_boundaries_archive_heatmap import \
sliding_boundaries_archive_heatmap
from ribs.visualize._visualize_qdax import (qdax_repertoire_3d_plot,
qdax_repertoire_heatmap)

__all__ = [
"cvt_archive_3d_plot",
"cvt_archive_heatmap",
"grid_archive_heatmap",
"parallel_axes_plot",
"sliding_boundaries_archive_heatmap",
"qdax_repertoire_3d_plot",
"qdax_repertoire_heatmap",
]
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
"""Provides qdax_repertoire_heatmap."""
"""Provides visualization functions for QDax repertoires."""
import numpy as np

from ribs.archives import CVTArchive
from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot
from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap


def _as_cvt_archive(repertoire, ranges):
"""Converts a QDax repertoire into a CVTArchive."""

# Construct a CVTArchive. We set solution_dim to 0 since we are only
# plotting and do not need to have the solutions available.
cvt_archive = CVTArchive(
solution_dim=0,
cells=repertoire.centroids.shape[0],
ranges=ranges,
custom_centroids=repertoire.centroids,
)

# Add everything to the CVTArchive.
occupied = repertoire.fitnesses != -np.inf
cvt_archive.add(
np.empty((occupied.sum(), 0)),
repertoire.fitnesses[occupied],
repertoire.descriptors[occupied],
)

return cvt_archive


def qdax_repertoire_heatmap(
repertoire,
ranges,
Expand Down Expand Up @@ -33,22 +57,36 @@ def qdax_repertoire_heatmap(
"""
# pylint: enable = line-too-long

# Construct a CVTArchive. We set solution_dim to 0 since we are only
# plotting and do not need to have the solutions available.
cvt_archive = CVTArchive(
solution_dim=0,
cells=repertoire.centroids.shape[0],
ranges=ranges,
custom_centroids=repertoire.centroids,
)
cvt_archive_heatmap(_as_cvt_archive(repertoire, ranges), *args, **kwargs)

# Add everything to the CVTArchive.
occupied = repertoire.fitnesses != -np.inf
cvt_archive.add(
np.empty((occupied.sum(), 0)),
repertoire.fitnesses[occupied],
repertoire.descriptors[occupied],
)

# Plot the archive.
cvt_archive_heatmap(cvt_archive, *args, **kwargs)
def qdax_repertoire_3d_plot(
repertoire,
ranges,
*args,
**kwargs,
):
# pylint: disable = line-too-long
"""Plots a QDax MapElitesRepertoire with 3D measure space.
Internally, this function converts a
:class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire` into
a :class:`~ribs.archives.CVTArchive` and plots it with
:meth:`cvt_archive_3d_plot`.
Args:
repertoire (qdax.core.containers.mapelites_repertoire.MapElitesRepertoire):
A MAP-Elites repertoire output by an algorithm in QDax.
ranges (array-like of (float, float)): Upper and lower bound of each
dimension of the measure space, e.g. ``[(-1, 1), (-2, 2), (-3, 3)]``
indicates the first dimension should have bounds :math:`[-1,1]`
(inclusive), the second dimension should have bounds :math:`[-2,2]`,
and the third dimension should have bounds :math:`[-3,3]`
(inclusive). This is needed since the MapElitesRepertoire does not
store measure space bounds.
*args: Positional arguments to pass to :meth:`cvt_archive_3d_plot`.
**kwargs: Keyword arguments to pass to :meth:`cvt_archive_3d_plot`.
"""
# pylint: enable = line-too-long

cvt_archive_3d_plot(_as_cvt_archive(repertoire, ranges), *args, **kwargs)
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
52 changes: 47 additions & 5 deletions tests/visualize_qdax/visualize_qdax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from qdax.core.containers.mapelites_repertoire import (MapElitesRepertoire,
compute_cvt_centroids)

from ribs.visualize import qdax_repertoire_heatmap
from ribs.visualize import qdax_repertoire_3d_plot, qdax_repertoire_heatmap


@pytest.fixture(autouse=True)
Expand All @@ -26,10 +26,11 @@ def clean_matplotlib():
plt.close("all")


@image_comparison(baseline_images=["qdax_repertoire_heatmap"],
remove_text=False,
extensions=["png"],
tol=0.1) # See CVT_IMAGE_TOLERANCE in visualize_test.py
@image_comparison(
baseline_images=["qdax_repertoire_heatmap"],
remove_text=False,
extensions=["png"],
tol=0.1) # See CVT_IMAGE_TOLERANCE in cvt_archive_heatmap_test.py
def test_qdax_repertoire_heatmap():
plt.figure(figsize=(8, 6))

Expand Down Expand Up @@ -59,3 +60,44 @@ def test_qdax_repertoire_heatmap():

# Plot heatmap.
qdax_repertoire_heatmap(repertoire, ranges=[(-1, 1), (-1, 1)])


@image_comparison(
baseline_images=["qdax_repertoire_3d_plot"],
remove_text=False,
extensions=["png"],
tol=0.1) # See CVT_IMAGE_TOLERANCE in cvt_archive_3d_plot_test.py
def test_qdax_repertoire_3d_plot():
plt.figure(figsize=(8, 6))

random_key = jax.random.PRNGKey(42)

# Compute the CVT centroids.
random_key, subkey = jax.random.split(random_key)
centroids, _ = compute_cvt_centroids(
num_descriptors=3,
num_init_cvt_samples=1000,
num_centroids=500,
minval=-1,
maxval=1,
random_key=subkey,
)

# Create initial population.
random_key, *subkeys = jax.random.split(random_key, 4)
x = jax.random.uniform(subkeys[0], (10000,), minval=-1.0, maxval=1.0)
y = jax.random.uniform(subkeys[1], (10000,), minval=-1.0, maxval=1.0)
z = jax.random.uniform(subkeys[2], (10000,), minval=-1.0, maxval=1.0)
init_pop = jnp.stack((x, y, z), axis=1)

# Create repertoire with the initial population inserted.
repertoire = MapElitesRepertoire.init(
genotypes=init_pop,
# Negative sphere function.
fitnesses=-jnp.sum(jnp.square(init_pop), axis=1),
descriptors=init_pop,
centroids=centroids,
)

# Plot heatmap.
qdax_repertoire_3d_plot(repertoire, ranges=[(-1, 1), (-1, 1), (-1, 1)])

0 comments on commit 4f8637c

Please sign in to comment.