-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add visualization of QDax repertoires (#353)
## Description <!-- Provide a brief description of the PR's purpose here. --> QDax’s repertoires are not too different from pyribs in terms of the data stored, which means our visualization tools can easily be used for both libraries. This PR makes it possible to visualize one of the main repertoires used in QDax, i.e., the MAPElitesRepertoire. The key idea is to transform the MAPElitesRepertoire into a CVTArchive and then plot the result with `cvt_archive_heatmap`. In this PR, this functionality is exposed via a new function, `ribs.visualize.qdax_repertoire_heatmap`. We update the tests and CI accordingly. Note 1: This PR does not introduce a dependency on QDax for pyribs. If a QDax repertoire is passed into the method introduced here, it can be plotted, but we do not import qdax itself at any point. Note 2: QDax does not have a separate grid archive. Instead, both grid and CVT archives from pyribs are represented with the MAPElitesRepertoire. Specifically, one can choose Euclidean centroids when constructing the repertoire in order to create a "grid archive." ## TODO <!-- Notable points that this PR has either accomplished or will accomplish. --> - [x] Write/document function - [x] Add tests in tests/visualize_qdax - [x] Set up CI for testing qdax visualization ## Questions <!-- Any concerns or points of confusion? --> ## Status - [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md) - [x] I have formatted my code using `yapf` - [x] I have tested my code by running `pytest` - [x] I have linted my code with `pylint` - [x] I have added a one-line description of my change to the changelog in `HISTORY.md` - [x] This PR is ready to go
- Loading branch information
Showing
8 changed files
with
128 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Binary file added
BIN
+99.8 KB
.../visualize_qdax/baseline_images/visualize_qdax_test/qdax_repertoire_heatmap.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""Tests for ribs.visualize that use qdax. | ||
Instructions are identical as in visualize_test.py, but images are stored in | ||
tests/visualize_qdax_test/baseline_images/visualize_qdax_test instead. | ||
""" | ||
import jax | ||
import jax.numpy as jnp | ||
import matplotlib.pyplot as plt | ||
import pytest | ||
from matplotlib.testing.decorators import image_comparison | ||
from qdax.core.containers.mapelites_repertoire import (MapElitesRepertoire, | ||
compute_cvt_centroids) | ||
|
||
from ribs.visualize import qdax_repertoire_heatmap | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def clean_matplotlib(): | ||
"""Cleans up matplotlib figures before and after each test.""" | ||
# Before the test. | ||
plt.close("all") | ||
|
||
yield | ||
|
||
# After the test. | ||
plt.close("all") | ||
|
||
|
||
@image_comparison(baseline_images=["qdax_repertoire_heatmap"], | ||
remove_text=False, | ||
extensions=["png"]) | ||
def test_qdax_repertoire_heatmap(): | ||
plt.figure(figsize=(8, 6)) | ||
|
||
# Compute the CVT centroids. | ||
centroids, _ = compute_cvt_centroids( | ||
num_descriptors=2, | ||
num_init_cvt_samples=1000, | ||
num_centroids=100, | ||
minval=-1, | ||
maxval=1, | ||
random_key=jax.random.PRNGKey(42), | ||
) | ||
|
||
# Create initial population. | ||
init_pop_x, init_pop_y = jnp.meshgrid(jnp.linspace(-1, 1, 50), | ||
jnp.linspace(-1, 1, 50)) | ||
init_pop = jnp.stack((init_pop_x.flatten(), init_pop_y.flatten()), axis=1) | ||
|
||
# Create repertoire with the initial population inserted. | ||
repertoire = MapElitesRepertoire.init( | ||
genotypes=init_pop, | ||
# Negative sphere function. | ||
fitnesses=-jnp.sum(jnp.square(init_pop), axis=1), | ||
descriptors=init_pop, | ||
centroids=centroids, | ||
) | ||
|
||
# Plot heatmap. | ||
qdax_repertoire_heatmap(repertoire, ranges=[(-1, 1), (-1, 1)]) |