Skip to content

Commit

Permalink
Prevent Optimizer from taking in non-unique emitter instances (#75)
Browse files Browse the repository at this point in the history
* Add test for non-unique emitters

* Throw ValueError on non-unique emitter instances in Optimizer
  • Loading branch information
btjanaka committed Feb 3, 2021
1 parent 81c38c5 commit 5d36987
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
18 changes: 18 additions & 0 deletions ribs/optimizers/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ class Optimizer:
solutions with the same dimension (that is, their ``solution_dim`` attribute
must be the same).
.. warning:: If you are constructing many emitters at once, do not do
something like ``[EmitterClass(...)] * 5``, as this creates a list with
the same instance of ``EmitterClass`` in each position. Instead, use
``[EmitterClass(...) for _ in range 5]``, which creates 5 unique
instances of ``EmitterClass``.
Args:
archive (ribs.archives.ArchiveBase): An archive object, selected from
:mod:`ribs.archives`.
Expand All @@ -25,11 +31,23 @@ class Optimizer:
ValueError: The emitters passed in do not have the same solution
dimensions.
ValueError: There is no emitter passed in.
ValueError: The same emitter instance was passed in multiple times. Each
emitter should be a unique instance (see the warning above).
"""

def __init__(self, archive, emitters):
if len(emitters) == 0:
raise ValueError("Pass in at least one emitter to the optimizer.")

emitter_ids = set(id(e) for e in emitters)
if len(emitter_ids) != len(emitters):
raise ValueError(
"Not all emitters passed in were unique (i.e. some emitters "
"had the same id). If emitters were created with something "
"like [EmitterClass(...)] * n, instead use "
"[EmitterClass(...) for _ in range(n)] so that all emitters "
"are unique instances.")

self._solution_dim = emitters[0].solution_dim

for idx, emitter in enumerate(emitters[1:]):
Expand Down
13 changes: 12 additions & 1 deletion tests/core/optimizers/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ def test_init_fails_with_no_emitters():
Optimizer(archive, emitters)


def test_init_fails_on_non_unique_emitter_instances():
archive = GridArchive([100, 100], [(-1, 1), (-1, 1)])

# All emitters are the same instance. This is bad because the same emitter
# gets called multiple times.
emitters = [GaussianEmitter(archive, [0.0, 0.0], 1, batch_size=1)] * 5

with pytest.raises(ValueError):
Optimizer(archive, emitters)


def test_init_fails_with_mismatched_emitters():
archive = GridArchive([100, 100], [(-1, 1), (-1, 1)])
emitters = [
Expand Down Expand Up @@ -64,7 +75,7 @@ def test_tell_inserts_solutions_into_archive(_optimizer_fixture):
assert len(optimizer.archive.as_pandas()) == num_solutions


def test_tell_inserts_solutions_with_multiple_emitters(_optimizer_fixture):
def test_tell_inserts_solutions_with_multiple_emitters():
archive = GridArchive([100, 100], [(-1, 1), (-1, 1)])
emitters = [
GaussianEmitter(archive, [0.0, 0.0], 1, batch_size=1),
Expand Down

0 comments on commit 5d36987

Please sign in to comment.