Skip to content

Commit

Permalink
Refactor visualize tests (#370)
Browse files Browse the repository at this point in the history
## Description

<!-- Provide a brief description of the PR's purpose here. -->

This PR splits the tests for ribs.visualize into multiple files. Test
methods and baseline images are also renamed since their filepaths now
include the name of their respective visualization function.

## TODO

<!-- Notable points that this PR has either accomplished or will
accomplish. -->

- [x] Split up tests
- [x] Check that number of tests matches (93)

## Questions

<!-- Any concerns or points of confusion? -->

## Status

- [x] I have read the guidelines in

[CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog
in
      `HISTORY.md`
- [x] This PR is ready to go
  • Loading branch information
btjanaka committed Sep 13, 2023
1 parent 9f292a6 commit 8efa0f2
Show file tree
Hide file tree
Showing 67 changed files with 1,045 additions and 992 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

- Improve developer workflow with pre-commit ({pr}`351`, {pr}`363`)
- Refactor visualize module into multiple files ({pr}`357`)
- Refactor visualize tests into multiple files ({pr}`370`)
- Add GitHub link roles in documentation ({pr}`361`)
- Refactor argument validation utilities ({pr}`365`)
- Use Conda envs in all CI jobs ({pr}`368`)
Expand Down
20 changes: 20 additions & 0 deletions tests/visualize/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Visualization Tests

This directory contains tests for ribs.visualize. For image comparison tests,
read
[these instructions](https://matplotlib.org/stable/devel/testing.html#writing-an-image-comparison-test).
Essentially, start by writing a test in one of the files in this directory;
let's pick `grid_archive_heatmap_test.py`. After writing this test, run it with
pytest, then go to the _root_ directory of this repo. There, you will find the
output image in `tests/visualize/baseline_images/grid_archive_heatmap_test`.
Copy this image into the `baseline_images/grid_archive_heatmap_test` directory
in this directory. Here's an example cp command:

```
cp result_images/grid_archive_heatmap_test/my_new_test.png \
tests/baseline_images/grid_archive_heatmap_test
```

Assuming the output is as expected (and assuming the code is deterministic), the
test should now pass when it is re-run. The same applies for tests in other
files; for instance, you can do the same for `cvt_archive_heatmap_test`.
95 changes: 95 additions & 0 deletions tests/visualize/args_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Tests to check argument validation in various visualization functions.
Many of the functions share parameters like cbar, so here we check if passing
invalid arguments to these functions results in an error.
"""
import pytest

from ribs.archives import CVTArchive, GridArchive, SlidingBoundariesArchive
from ribs.visualize import (cvt_archive_heatmap, grid_archive_heatmap,
sliding_boundaries_archive_heatmap)


@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"])
def test_heatmap_fails_on_unsupported_dims(archive_type):
archive = {
"grid":
lambda: GridArchive(
solution_dim=2, dims=[20, 20, 20], ranges=[(-1, 1)] * 3),
"cvt":
lambda: CVTArchive(
solution_dim=2,
cells=100,
ranges=[(-1, 1)] * 3,
samples=100,
),
"sliding":
lambda: SlidingBoundariesArchive(
solution_dim=2, dims=[20, 20, 20], ranges=[(-1, 1)] * 3),
}[archive_type]()

with pytest.raises(ValueError):
{
"grid": grid_archive_heatmap,
"cvt": cvt_archive_heatmap,
"sliding": sliding_boundaries_archive_heatmap,
}[archive_type](archive)


@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"])
@pytest.mark.parametrize(
"invalid_arg_cbar",
["None", 3.2, True,
(3.2, None), [3.2, None]]) # some random but invalid inputs
def test_heatmap_fails_on_invalid_cbar_option(archive_type, invalid_arg_cbar):
archive = {
"grid":
lambda: GridArchive(
solution_dim=2, dims=[20, 20, 20], ranges=[(-1, 1)] * 3),
"cvt":
lambda: CVTArchive(
solution_dim=2,
cells=100,
ranges=[(-1, 1)] * 3,
samples=100,
),
"sliding":
lambda: SlidingBoundariesArchive(
solution_dim=2,
dims=[20, 20, 20],
ranges=[(-1, 1)] * 3,
),
}[archive_type]()

with pytest.raises(ValueError):
{
"grid": grid_archive_heatmap,
"cvt": cvt_archive_heatmap,
"sliding": sliding_boundaries_archive_heatmap,
}[archive_type](archive=archive, cbar=invalid_arg_cbar)


@pytest.mark.parametrize("archive_type", ["grid", "cvt", "sliding"])
@pytest.mark.parametrize(
"invalid_arg_aspect",
["None", True, (3.2, None), [3.2, None]]) # some random but invalid inputs
def test_heatmap_fails_on_invalid_aspect_option(archive_type,
invalid_arg_aspect):
archive = {
"grid":
lambda: GridArchive(
solution_dim=2, dims=[20, 20, 20], ranges=[(-1, 1)] * 3),
"cvt":
lambda: CVTArchive(
solution_dim=2, cells=100, ranges=[(-1, 1)] * 3, samples=100),
"sliding":
lambda: SlidingBoundariesArchive(
solution_dim=2, dims=[20, 20, 20], ranges=[(-1, 1)] * 3),
}[archive_type]()

with pytest.raises(ValueError):
{
"grid": grid_archive_heatmap,
"cvt": cvt_archive_heatmap,
"sliding": sliding_boundaries_archive_heatmap,
}[archive_type](archive=archive, aspect=invalid_arg_aspect)
75 changes: 75 additions & 0 deletions tests/visualize/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Utilities for all visualization tests.
See README.md for instructions on writing tests.
"""
import matplotlib.pyplot as plt
import numpy as np
import pytest


@pytest.fixture(autouse=True)
def clean_matplotlib():
"""Cleans up matplotlib figures before and after each test."""
# Before the test.
plt.close("all")

yield

# After the test.
plt.close("all")


def add_uniform_sphere_1d(archive, x_range):
"""Adds points from the negative sphere function in a 1D grid w/ 100 elites.
The solutions are the same as the measures
x_range is a tuple of (lower_bound, upper_bound).
"""
x = np.linspace(x_range[0], x_range[1], 100)
archive.add(
solution_batch=x[:, None],
objective_batch=-x**2,
measures_batch=x[:, None],
)


def add_uniform_sphere_2d(archive, x_range, y_range):
"""Adds points from the negative sphere function in a 100x100 grid.
The solutions are the same as the measures (the (x,y) coordinates).
x_range and y_range are tuples of (lower_bound, upper_bound).
"""
xxs, yys = np.meshgrid(
np.linspace(x_range[0], x_range[1], 100),
np.linspace(y_range[0], y_range[1], 100),
)
xxs, yys = xxs.ravel(), yys.ravel()
coords = np.stack((xxs, yys), axis=1)
archive.add(
solution_batch=coords,
objective_batch=-(xxs**2 + yys**2), # Negative sphere.
measures_batch=coords,
)


def add_uniform_sphere_3d(archive, x_range, y_range, z_range):
"""Adds points from the negative sphere function in a 100x100x100 grid.
The solutions are the same as the measures (the (x,y,z) coordinates).
x_range, y_range, and z_range are tuples of (lower_bound, upper_bound).
"""
xxs, yys, zzs = np.meshgrid(
np.linspace(x_range[0], x_range[1], 40),
np.linspace(y_range[0], y_range[1], 40),
np.linspace(z_range[0], z_range[1], 40),
)
xxs, yys, zzs = xxs.ravel(), yys.ravel(), zzs.ravel()
coords = np.stack((xxs, yys, zzs), axis=1)
archive.add(
solution_batch=coords,
objective_batch=-(xxs**2 + yys**2 + zzs**2), # Negative sphere.
measures_batch=coords,
)

0 comments on commit 8efa0f2

Please sign in to comment.