Skip to content

Commit

Permalink
Improve cvt_archive_heatmap flexibility (#354)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

This PR makes cvt_archive_heatmap more flexible by adding an `ec`
parameter to control the edge color of the polygons.

We also rearrange several existing arguments to more closely mirror
`grid_archive_heatmap` — specifically, we move `plot_samples`,
`plot_centroids`, and `ms` to the end of the argument list. This
reordering should not break anyone since we require keyword arguments
for this method by putting a `*` in the signature.

Regarding style arguments for samples and centroids: I have decided to
leave the sample and centroid style arguments (i.e., the arguments to
`ax.plot`) as is for now, as it seems pretty rare to plot
centroids/samples, so the added complexity of making them configurable
may not be worth it. We can always add this later if there is a need for
it.

This PR also fixes several bugs encountered while making these API
changes.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x]  Add ec parameter for controlling color of Voronoi cell boundaries
- [x]  Test style for Voronoi cells (`ec` and `lw`)
- [x] Rearrange parameters to line up more closely with
`grid_archive_heatmap`
- [x] Make background color of empty cells be transparent — this is
important when someone wants to save a transparent image — in this case,
we don’t want the background to be white
- [x] Throw an error if attempting to pass `plot_samples` when the
archive has no samples
- [x] Fix bug with division by zero error when `min_obj` and `max_obj`
are identical — since we normalize by `max_obj - min_obj` , we were
getting such errors if `min_obj` equaled `max_obj`

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
        `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka committed Sep 7, 2023
1 parent 25dc442 commit d4b6ea4
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 19 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

- Drop Python 3.7 support and upgrade dependencies (#350)
- Add visualization of QDax repertoires (#353)
- Improve cvt_archive_heatmap flexibility (#354)

#### Documentation

Expand Down
59 changes: 40 additions & 19 deletions ribs/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import axes
from matplotlib.cm import ScalarMappable
from scipy.spatial import Voronoi # pylint: disable=no-name-in-module

Expand Down Expand Up @@ -69,7 +68,8 @@ def _validate_heatmap_visual_args(aspect, cbar, measure_dim, valid_dims,
f"Invalid arg aspect='{aspect}'; must be 'auto', 'equal', or float")
if measure_dim not in valid_dims:
raise ValueError(error_msg_measure_dim)
if not (cbar == "auto" or isinstance(cbar, axes.Axes) or cbar is None):
if not (cbar == "auto" or isinstance(cbar, matplotlib.axes.Axes) or
cbar is None):
raise ValueError(f"Invalid arg cbar={cbar}; must be 'auto', None, "
"or matplotlib.axes.Axes")

Expand All @@ -79,7 +79,7 @@ def _set_cbar(t, ax, cbar, cbar_kwargs):
cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
if cbar == "auto":
ax.figure.colorbar(t, ax=ax, **cbar_kwargs)
elif isinstance(cbar, axes.Axes):
elif isinstance(cbar, matplotlib.axes.Axes):
cbar.figure.colorbar(t, ax=cbar, **cbar_kwargs)


Expand Down Expand Up @@ -261,17 +261,18 @@ def grid_archive_heatmap(archive,
def cvt_archive_heatmap(archive,
ax=None,
*,
plot_centroids=False,
plot_samples=False,
transpose_measures=False,
cmap="magma",
aspect="auto",
ms=1,
lw=0.5,
ec="black",
vmin=None,
vmax=None,
cbar="auto",
cbar_kwargs=None):
cbar_kwargs=None,
plot_centroids=False,
plot_samples=False,
ms=1):
"""Plots heatmap of a :class:`~ribs.archives.CVTArchive` with 2D measure
space.
Expand Down Expand Up @@ -314,9 +315,6 @@ def cvt_archive_heatmap(archive,
archive (CVTArchive): A 2D :class:`~ribs.archives.CVTArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
If ``None``, the current axis will be used.
plot_centroids (bool): Whether to plot the cluster centroids.
plot_samples (bool): Whether to plot the samples used when generating
the clusters.
transpose_measures (bool): By default, the first measure in the archive
will appear along the x-axis, and the second will be along the
y-axis. To switch this behavior (i.e. to transpose the axes), set
Expand All @@ -329,8 +327,11 @@ def cvt_archive_heatmap(archive,
aspect ('auto', 'equal', float): The aspect ratio of the heatmap (i.e.
height/width). Defaults to ``'auto'``. ``'equal'`` is the same as
``aspect=1``.
ms (float): Marker size for both centroids and samples.
lw (float): Line width when plotting the voronoi diagram.
lw (float): Line width when plotting the Voronoi diagram.
ec (matplotlib color): Edge color of the cells in the Voronoi diagram.
See `here
<https://matplotlib.org/stable/tutorials/colors/colors.html>`_ for
more info on specifying colors in Matplotlib.
vmin (float): Minimum objective value to use in the plot. If ``None``,
the minimum objective value in the archive is used.
vmax (float): Maximum objective value to use in the plot. If ``None``,
Expand All @@ -342,14 +343,24 @@ def cvt_archive_heatmap(archive,
the colorbar on the specified Axes.
cbar_kwargs (dict): Additional kwargs to pass to
:func:`~matplotlib.pyplot.colorbar`.
plot_centroids (bool): Whether to plot the cluster centroids.
plot_samples (bool): Whether to plot the samples used when generating
the clusters.
ms (float): Marker size for both centroids and samples.
Raises:
ValueError: The archive is not 2D.
ValueError: ``plot_samples`` is passed in but the archive does not have
samples (e.g., due to using custom centroids during construction).
"""
_validate_heatmap_visual_args(
aspect, cbar, archive.measure_dim, [2],
"Heatmaps can only be plotted for 2D CVTArchive")

if plot_samples and archive.samples is None:
raise ValueError("Samples are not available for this archive, but "
"`plot_samples` was passed in.")

if aspect is None:
aspect = "auto"

Expand All @@ -360,12 +371,15 @@ def cvt_archive_heatmap(archive,
lower_bounds = archive.lower_bounds
upper_bounds = archive.upper_bounds
centroids = archive.centroids
samples = archive.samples
if transpose_measures:
lower_bounds = np.flip(lower_bounds)
upper_bounds = np.flip(upper_bounds)
centroids = np.flip(centroids, axis=1)
samples = np.flip(samples, axis=1)

if plot_samples:
samples = archive.samples
if transpose_measures:
samples = np.flip(samples, axis=1)

# Retrieve and initialize the axis.
ax = plt.gca() if ax is None else ax
Expand Down Expand Up @@ -404,6 +418,10 @@ def cvt_archive_heatmap(archive,
min_obj = min_obj if vmin is None else vmin
max_obj = max_obj if vmax is None else vmax

# If the min and max are the same, we set a sensible default range.
if min_obj == max_obj:
min_obj, max_obj = min_obj - 0.01, max_obj + 0.01

# Shade the regions.
#
# Note: by default, the first region will be an empty list -- see:
Expand All @@ -415,23 +433,26 @@ def cvt_archive_heatmap(archive,
# `polygon` is also O(n) anyway.
if -1 not in region:
if objective is None:
color = "white"
# Transparent white (RGBA format) -- this ensures that if a
# figure is saved with a transparent background, the empty cells
# will also be transparent.
color = (1.0, 1.0, 1.0, 0.0)
else:
normalized_obj = np.clip(
(objective - min_obj) / (max_obj - min_obj), 0.0, 1.0)
color = cmap(normalized_obj)
polygon = [vor.vertices[i] for i in region]
ax.fill(*zip(*polygon), color=color, ec="k", lw=lw)
polygon = vor.vertices[region]
ax.fill(*zip(*polygon), color=color, ec=ec, lw=lw)

# Create a colorbar.
mappable = ScalarMappable(cmap=cmap)
mappable.set_clim(min_obj, max_obj)

# Plot the sample points and centroids.
if plot_samples:
ax.plot(samples[:, 0], samples[:, 1], "o", c="gray", ms=ms)
ax.plot(samples[:, 0], samples[:, 1], "o", c="grey", ms=ms)
if plot_centroids:
ax.plot(centroids[:, 0], centroids[:, 1], "ko", ms=ms)
ax.plot(centroids[:, 0], centroids[:, 1], "o", c="black", ms=ms)

# Create color bar.
_set_cbar(mappable, ax, cbar, cbar_kwargs)
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions tests/visualize/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,14 @@ def test_sliding_archive_mismatch_xy_with_boundaries():
sliding_boundaries_archive_heatmap(archive, boundary_lw=0.5)


@image_comparison(baseline_images=["cvt_archive_heatmap_vmin_equals_vmax"],
remove_text=False,
extensions=["png"])
def test_cvt_archive_heatmap_vmin_equals_vmax(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, vmin=-0.5, vmax=-0.5)


@image_comparison(baseline_images=["cvt_archive_heatmap_with_centroids"],
remove_text=False,
extensions=["png"])
Expand All @@ -628,6 +636,26 @@ def test_cvt_archive_heatmap_with_samples(cvt_archive):
cvt_archive_heatmap(cvt_archive, plot_samples=True)


def test_cvt_archive_heatmap_no_samples_error():
# This archive has no samples since custom centroids were passed in.
archive = CVTArchive(solution_dim=2,
cells=2,
ranges=[(-1, 1), (-1, 1)],
custom_centroids=[[0, 0], [1, 1]])

# Thus, plotting samples on this archive should fail.
with pytest.raises(ValueError):
cvt_archive_heatmap(archive, lw=3.0, ec="grey", plot_samples=True)


@image_comparison(baseline_images=["cvt_archive_heatmap_voronoi_style"],
remove_text=False,
extensions=["png"])
def test_cvt_archive_heatmap_voronoi_style(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, lw=3.0, ec="grey")


#
# Parallel coordinate plot test
#
Expand Down

0 comments on commit d4b6ea4

Please sign in to comment.