Skip to content

Commit

Permalink
Speed up 2D cvt_archive_heatmap by order of magnitude (#355)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

Currently, cvt_archive_heatmap plots individual polygons via ax.fill .
We can speed this up by instead using a
[`PolyCollection`](https://matplotlib.org/stable/api/collections_api.html#matplotlib.collections.PolyCollection)
to add all the polygons at once. This is similar to using a
`PatchCollection` as shown here:
https://matplotlib.org/stable/gallery/shapes_and_collections/patch_collection.html.

Benchmark for plotting a `CVTArchive` with 10,000 cells:

- Before: 14.9 sec
- After: 0.6 sec

I used the following code to benchmark the implementation:

```python
"""Driver for cvt heatmap experiments."""

import time

import fire
import matplotlib.pyplot as plt
import numpy as np

from ribs.archives import CVTArchive
from ribs.visualize import cvt_archive_heatmap


def main(n_cells=10000):
    """Creates the archive and plots it."""
    np.random.seed(42)

    archive = CVTArchive(
        solution_dim=3,
        cells=n_cells,
        ranges=[(-1, 1), (-1, 1)],
        custom_centroids=np.random.uniform(-1, 1, (n_cells, 2)),
    )

    archive.add(
        np.random.uniform(-1, 1, (20000, 3)),
        np.random.standard_normal(20000),
        np.random.uniform(-1, 1, (20000, 2)),
    )

    plt.figure(figsize=(8, 6))

    start_time = time.time()
    cvt_archive_heatmap(archive)
    print("Plot time", time.time() - start_time)

    plt.savefig("cvt.png")


if __name__ == "__main__":
    fire.Fire(main)
```

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Speed up 2D polygon plotting by using matplotlib PolyCollection —
note that I initially used a PatchCollection with individual Polygon
patches, but PolyCollection is much faster because we do not have to
construct the individual `Polygon` patches in Python.
- [x]  Compute facecolors in a batch instead of individually
- [x] Fix test errors — it seems the images changed slightly due to the
new implementation, so we now allow a slight tolerance for cvt heatmap
images

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka committed Sep 7, 2023
1 parent d4b6ea4 commit 7048c36
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 33 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Drop Python 3.7 support and upgrade dependencies (#350)
- Add visualization of QDax repertoires (#353)
- Improve cvt_archive_heatmap flexibility (#354)
- Speed up 2D cvt_archive_heatmap by order of magnitude (#355)

#### Documentation

Expand Down
69 changes: 49 additions & 20 deletions ribs/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,27 +422,56 @@ def cvt_archive_heatmap(archive,
if min_obj == max_obj:
min_obj, max_obj = min_obj - 0.01, max_obj + 0.01

# Shade the regions.
#
# Note: by default, the first region will be an empty list -- see:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.Voronoi.html
# However, this empty region is ignored by ax.fill since `polygon` is also
# an empty list in this case.
# Vertices of all cells.
vertices = []
# The facecolor of each cell. Shape (n_regions, 4) for RGBA format, but we
# do not know n_regions in advance.
facecolors = []
# Boolean array indicating which of the facecolors needs to be computed with
# the cmap. The other colors correspond to empty cells. Shape (n_regions,)
facecolor_cmap_mask = []
# The objective corresponding to the regions which must be passed through
# the cmap. Shape (sum(facecolor_cmap_mask),)
facecolor_objs = []

# Cycle through the regions to set up polygon vertices and facecolors.
for region, objective in zip(vor.regions, region_obj):
# This check is O(n), but n is typically small, and creating
# `polygon` is also O(n) anyway.
if -1 not in region:
if objective is None:
# Transparent white (RGBA format) -- this ensures that if a
# figure is saved with a transparent background, the empty cells
# will also be transparent.
color = (1.0, 1.0, 1.0, 0.0)
else:
normalized_obj = np.clip(
(objective - min_obj) / (max_obj - min_obj), 0.0, 1.0)
color = cmap(normalized_obj)
polygon = vor.vertices[region]
ax.fill(*zip(*polygon), color=color, ec=ec, lw=lw)
# Checking for -1 is O(n), but n is typically small.
#
# We check length since the first region is an empty list by default:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.Voronoi.html
if -1 in region or len(region) == 0:
continue

if objective is None:
# Transparent white (RGBA format) -- this ensures that if a figure
# is saved with a transparent background, the empty cells will also
# be transparent.
facecolors.append(np.array([1.0, 1.0, 1.0, 0.0]))
facecolor_cmap_mask.append(False)
else:
facecolors.append(np.empty(4))
facecolor_cmap_mask.append(True)
facecolor_objs.append(objective)

vertices.append(vor.vertices[region])

# Compute facecolors from the cmap. We first normalize the objectives and
# clip them to [0, 1].
normalized_objs = np.clip(
(np.asarray(facecolor_objs) - min_obj) / (max_obj - min_obj), 0.0, 1.0)
facecolors = np.asarray(facecolors)
facecolors[facecolor_cmap_mask] = cmap(normalized_objs)

# Plot the collection on the axes. Note that this is faster than plotting
# each polygon individually with ax.fill().
ax.add_collection(
matplotlib.collections.PolyCollection(
vertices,
edgecolors=ec,
facecolors=facecolors,
linewidths=lw,
))

# Create a colorbar.
mappable = ScalarMappable(cmap=cmap)
Expand Down
41 changes: 29 additions & 12 deletions tests/visualize/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@

# pylint: disable = redefined-outer-name

# Tolerance for root mean square difference between the pixels of the images,
# where 255 is the max value. We only have tolerance for `cvt_archive_heatmap`
# since it is a bit more finicky than the other plots.
CVT_IMAGE_TOLERANCE = 0.1


@pytest.fixture(autouse=True)
def clean_matplotlib():
Expand Down Expand Up @@ -379,7 +384,8 @@ def test_heatmap_archive__grid_custom_cbar_axis(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_archive__cvt(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive)
Expand All @@ -403,7 +409,8 @@ def test_heatmap_with_custom_axis__grid(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_with_custom_axis__cvt(cvt_archive):
_, ax = plt.subplots(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, ax=ax)
Expand All @@ -427,7 +434,8 @@ def test_heatmap_long__grid(long_grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_long"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_long__cvt(long_cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(long_cvt_archive)
Expand All @@ -451,7 +459,8 @@ def test_heatmap_long_square__grid(long_grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_long_square"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_long_square__cvt(long_cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(long_cvt_archive, aspect="equal")
Expand All @@ -475,7 +484,8 @@ def test_heatmap_long_transpose__grid(long_grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_long_transpose"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_long_transpose__cvt(long_cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(long_cvt_archive, transpose_measures=True)
Expand All @@ -502,7 +512,8 @@ def test_heatmap_with_limits__grid(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_with_limits"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_with_limits__cvt(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, vmin=-1.0, vmax=-0.5)
Expand All @@ -527,7 +538,8 @@ def test_heatmap_listed_cmap__grid(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_with_listed_cmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_listed_cmap__cvt(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, cmap=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])
Expand All @@ -553,7 +565,8 @@ def test_heatmap_coolwarm_cmap__grid(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_with_coolwarm_cmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_coolwarm_cmap__cvt(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, cmap="coolwarm")
Expand Down Expand Up @@ -614,23 +627,26 @@ def test_sliding_archive_mismatch_xy_with_boundaries():

@image_comparison(baseline_images=["cvt_archive_heatmap_vmin_equals_vmax"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_vmin_equals_vmax(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, vmin=-0.5, vmax=-0.5)


@image_comparison(baseline_images=["cvt_archive_heatmap_with_centroids"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_with_centroids(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, plot_centroids=True)


@image_comparison(baseline_images=["cvt_archive_heatmap_with_samples"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_with_samples(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, plot_samples=True)
Expand All @@ -650,7 +666,8 @@ def test_cvt_archive_heatmap_no_samples_error():

@image_comparison(baseline_images=["cvt_archive_heatmap_voronoi_style"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_voronoi_style(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, lw=3.0, ec="grey")
Expand Down
3 changes: 2 additions & 1 deletion tests/visualize_qdax/visualize_qdax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def clean_matplotlib():

@image_comparison(baseline_images=["qdax_repertoire_heatmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=0.1) # See CVT_IMAGE_TOLERANCE in visualize_test.py
def test_qdax_repertoire_heatmap():
plt.figure(figsize=(8, 6))

Expand Down

0 comments on commit 7048c36

Please sign in to comment.