Skip to content

Commit

Permalink
Add visualization of QDax repertoires (#353)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

QDax’s repertoires are not too different from pyribs in terms of the
data stored, which means our visualization tools can easily be used for
both libraries. This PR makes it possible to visualize one of the main
repertoires used in QDax, i.e., the MAPElitesRepertoire. The key idea is
to transform the MAPElitesRepertoire into a CVTArchive and then plot the
result with `cvt_archive_heatmap`. In this PR, this functionality is
exposed via a new function, `ribs.visualize.qdax_repertoire_heatmap`. We
update the tests and CI accordingly.

Note 1: This PR does not introduce a dependency on QDax for pyribs. If a
QDax repertoire is passed into the method introduced here, it can be
plotted, but we do not import qdax itself at any point.

Note 2: QDax does not have a separate grid archive. Instead, both grid
and CVT archives from pyribs are represented with the
MAPElitesRepertoire. Specifically, one can choose Euclidean centroids
when constructing the repertoire in order to create a "grid archive."

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x]  Write/document function
- [x]  Add tests in tests/visualize_qdax
- [x]  Set up CI for testing qdax visualization

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka committed Sep 7, 2023
1 parent b411612 commit 25dc442
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 3 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ jobs:
pytest tests/archives tests/emitters tests/schedulers
- name: Install extras deps
run: pip install -r pinned_reqs/extras_visualize.txt
- name: Test extras
- name: Test visualize extra
run: pytest tests/visualize
- name: Install QDax
run: pip install qdax
- name: Test visualize extra for QDax
run: pytest tests/visualize_qdax
coverage:
runs-on: ubuntu-latest
steps:
Expand All @@ -87,7 +91,11 @@ jobs:
- name: Test coverage
env:
NUMBA_DISABLE_JIT: 1
run: pytest tests
# Exclude `visualize_qdax` since we don't install QDax here. We also
# exclude `tests` since we don't want the base directory here.
run:
pytest $(find tests -maxdepth 1 -type d -not -name 'tests' -not -name
'visualize_qdax')
benchmarks:
runs-on: ubuntu-latest
steps:
Expand Down
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### API

- Drop Python 3.7 support and upgrade dependencies (#350)
- Add visualization of QDax repertoires (#353)

#### Documentation

Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,5 @@
"python": ("https://docs.python.org/3/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"sklearn": ("https://scikit-learn.org/stable/", None),
"qdax": ("https://qdax.readthedocs.io/en/latest/", None),
}
50 changes: 50 additions & 0 deletions ribs/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from matplotlib.cm import ScalarMappable
from scipy.spatial import Voronoi # pylint: disable=no-name-in-module

from ribs.archives import CVTArchive

# Matplotlib functions tend to have a ton of args.
# pylint: disable = too-many-arguments

Expand Down Expand Up @@ -782,3 +784,51 @@ def parallel_axes_plot(archive,
ax=host_ax,
pad=cbar_pad,
orientation=cbar_orientation)


def qdax_repertoire_heatmap(
repertoire,
ranges,
*args,
**kwargs,
):
# pylint: disable = line-too-long
"""Plots a heatmap of a QDax MapElitesRepertoire.
Internally, this function converts a
:class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire` into
a :class:`~ribs.archives.CVTArchive` and plots it with
:meth:`cvt_archive_heatmap`.
Args:
repertoire (qdax.core.containers.mapelites_repertoire.MapElitesRepertoire):
A MAP-Elites repertoire output by an algorithm in QDax.
ranges (array-like of (float, float)): Upper and lower bound of each
dimension of the measure space, e.g. ``[(-1, 1), (-2, 2)]``
indicates the first dimension should have bounds :math:`[-1,1]`
(inclusive), and the second dimension should have bounds
:math:`[-2,2]` (inclusive).
*args: Positional arguments to pass to :meth:`cvt_archive_heatmap`.
**kwargs: Keyword arguments to pass to :meth:`cvt_archive_heatmap`.
"""
# pylint: enable = line-too-long

# Construct a CVTArchive. We set solution_dim to 0 since we are only
# plotting and do not need to have the solutions available.
cvt_archive = CVTArchive(
solution_dim=0,
cells=repertoire.centroids.shape[0],
ranges=ranges,
custom_centroids=repertoire.centroids,
)

# Add everything to the CVTArchive.
occupied = repertoire.fitnesses != -np.inf
cvt_archive.add(
np.empty((occupied.sum(), 0)),
repertoire.fitnesses[occupied],
repertoire.descriptors[occupied],
)

# Plot the archive.
cvt_archive_heatmap(cvt_archive, *args, **kwargs)
7 changes: 6 additions & 1 deletion tests/README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# Tests

This directory contains tests and micro-benchmarks for ribs. The tests mirror
This directory contains tests and micro-benchmarks for pyribs. The tests mirror
the directory structure of `ribs`. To run these tests, install the dev
dependencies for ribs with `pip install ribs[dev]` or `pip install -e .[dev]`
(from the root directory of the repo).

For information on running tests, see [CONTRIBUTING.md](../CONTRIBUTING.md).

## Visualization Tests

We divide the visualization tests into `visualize` and `visualize_qdax`, where
`visualize_qdax` tests visualizations of QDax components.

## Additional Tests

This directory also contains:
Expand Down
Empty file.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 60 additions & 0 deletions tests/visualize_qdax/visualize_qdax_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Tests for ribs.visualize that use qdax.
Instructions are identical as in visualize_test.py, but images are stored in
tests/visualize_qdax_test/baseline_images/visualize_qdax_test instead.
"""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pytest
from matplotlib.testing.decorators import image_comparison
from qdax.core.containers.mapelites_repertoire import (MapElitesRepertoire,
compute_cvt_centroids)

from ribs.visualize import qdax_repertoire_heatmap


@pytest.fixture(autouse=True)
def clean_matplotlib():
"""Cleans up matplotlib figures before and after each test."""
# Before the test.
plt.close("all")

yield

# After the test.
plt.close("all")


@image_comparison(baseline_images=["qdax_repertoire_heatmap"],
remove_text=False,
extensions=["png"])
def test_qdax_repertoire_heatmap():
plt.figure(figsize=(8, 6))

# Compute the CVT centroids.
centroids, _ = compute_cvt_centroids(
num_descriptors=2,
num_init_cvt_samples=1000,
num_centroids=100,
minval=-1,
maxval=1,
random_key=jax.random.PRNGKey(42),
)

# Create initial population.
init_pop_x, init_pop_y = jnp.meshgrid(jnp.linspace(-1, 1, 50),
jnp.linspace(-1, 1, 50))
init_pop = jnp.stack((init_pop_x.flatten(), init_pop_y.flatten()), axis=1)

# Create repertoire with the initial population inserted.
repertoire = MapElitesRepertoire.init(
genotypes=init_pop,
# Negative sphere function.
fitnesses=-jnp.sum(jnp.square(init_pop), axis=1),
descriptors=init_pop,
centroids=centroids,
)

# Plot heatmap.
qdax_repertoire_heatmap(repertoire, ranges=[(-1, 1), (-1, 1)])

0 comments on commit 25dc442

Please sign in to comment.