Skip to content

Commit

Permalink
Add emitter_kwargs to optimizer ask and tell (#159)
Browse files Browse the repository at this point in the history
- Add emitter_kwargs
- Update docstrings
- Add tests
  • Loading branch information
btjanaka committed Jul 7, 2021
1 parent 1d8df94 commit 05f4133
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 9 deletions.
55 changes: 47 additions & 8 deletions ribs/optimizers/_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Provides the Optimizer."""
import itertools

import numpy as np
from threadpoolctl import threadpool_limits

Expand Down Expand Up @@ -81,12 +83,30 @@ def emitters(self):
in this optimizer."""
return self._emitters

def ask(self):
@staticmethod
def _process_emitter_kwargs(emitter_kwargs):
"""Converts emitter_kwargs to an iterable so it can zip with the
emitters."""
if emitter_kwargs is None:
return itertools.repeat({})
if isinstance(emitter_kwargs, dict):
return itertools.repeat(emitter_kwargs)
return emitter_kwargs # Assume it is a list/iterable of dicts.

def ask(self, emitter_kwargs=None):
"""Generates a batch of solutions by calling ask() on all emitters.
.. note:: The order of the solutions returned from this method is
important, so do not rearrange them.
Args:
emitter_kwargs (dict or list of dict): kwargs to pass to the
emitters' :meth:`~ribs.emitters.EmitterBase.ask` method. If one
dict is passed in, its kwargs are passed to all the emitters. If
a list of dicts is passed in, each dict is passed to each
emitter (e.g. ``dict[0]`` goes to :attr:`emitters` [0]).
Emitters are in the same order as they were when the optimizer
was constructed.
Returns:
(n_solutions, dim) array: An array of n solutions to evaluate. Each
row contains a single solution.
Expand All @@ -99,19 +119,25 @@ def ask(self):
self._asked = True

self._solutions = []
emitter_kwargs = self._process_emitter_kwargs(emitter_kwargs)

# Limit OpenBLAS to single thread. This is typically faster than
# multithreading because our data is too small.
with threadpool_limits(limits=1, user_api="blas"):
for i, emitter in enumerate(self._emitters):
emitter_sols = emitter.ask()
for i, (emitter,
kwargs) in enumerate(zip(self._emitters, emitter_kwargs)):
emitter_sols = emitter.ask(**kwargs)
self._solutions.append(emitter_sols)
self._num_emitted[i] = len(emitter_sols)

self._solutions = np.concatenate(self._solutions, axis=0)
return self._solutions

def tell(self, objective_values, behavior_values, metadata=None):
def tell(self,
objective_values,
behavior_values,
metadata=None,
emitter_kwargs=None):
"""Returns info for solutions from :meth:`ask`.
.. note:: The objective values, behavior values, and metadata must be in
Expand All @@ -127,6 +153,13 @@ def tell(self, objective_values, behavior_values, metadata=None):
this array contains a solution's coordinates in behavior space.
metadata ((n_solutions,) array): Each entry of this array contains
an object holding metadata for a solution.
emitter_kwargs (dict or list of dict): kwargs to pass to the
emitters' :meth:`~ribs.emitters.EmitterBase.tell` method. If one
dict is passed in, its kwargs are passed to all the emitters. If
a list of dicts is passed in, each dict is passed to each
emitter (e.g. ``dict[0]`` goes to :attr:`emitters` [0]).
Emitters are in the same order as they were when the optimizer
was constructed.
Raises:
RuntimeError: This method is called without first calling
:meth:`ask`.
Expand All @@ -135,6 +168,7 @@ def tell(self, objective_values, behavior_values, metadata=None):
raise RuntimeError("tell() was called without calling ask().")
self._asked = False

emitter_kwargs = self._process_emitter_kwargs(emitter_kwargs)
objective_values = np.asarray(objective_values)
behavior_values = np.asarray(behavior_values)
metadata = (np.empty(len(self._solutions), dtype=object)
Expand All @@ -145,9 +179,14 @@ def tell(self, objective_values, behavior_values, metadata=None):
with threadpool_limits(limits=1, user_api="blas"):
# Keep track of pos because emitters may have different batch sizes.
pos = 0
for emitter, n in zip(self._emitters, self._num_emitted):
for emitter, n, kwargs in zip(self._emitters, self._num_emitted,
emitter_kwargs):
end = pos + n
emitter.tell(self._solutions[pos:end],
objective_values[pos:end],
behavior_values[pos:end], metadata[pos:end])
emitter.tell(
self._solutions[pos:end],
objective_values[pos:end],
behavior_values[pos:end],
metadata[pos:end],
**kwargs,
)
pos = end
89 changes: 88 additions & 1 deletion tests/optimizers/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from ribs.archives import GridArchive
from ribs.emitters import GaussianEmitter
from ribs.emitters import EmitterBase, GaussianEmitter
from ribs.optimizers import Optimizer

# pylint: disable = redefined-outer-name
Expand Down Expand Up @@ -126,3 +126,90 @@ def test_tell_fails_when_ask_not_called(optimizer_fixture):
optimizer, *_ = optimizer_fixture
with pytest.raises(RuntimeError):
optimizer.tell(None, None)


@pytest.fixture
def kwargs_fixture():
"""Fixture for testing emitter_kwargs in the optimizer."""

class KwargsEmitter(EmitterBase):
"""Emitter which takes in kwargs in its ask() and tell() methods.
ask() and tell() simply set self.arg to be the value of arg.
"""

def __init__(self, archive):
EmitterBase.__init__(self, archive, 3, None)
self.arg = None

def ask(self, arg=None):
self.arg = arg
return []

def tell(self,
solutions,
objective_values,
behavior_values,
metadata=None,
arg=None):
self.arg = arg

archive = GridArchive([100, 100], [(-1, 1), (-1, 1)])
emitters = [KwargsEmitter(archive) for _ in range(3)]
return emitters, Optimizer(archive, emitters)


def test_ask_with_no_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask(emitter_kwargs=None)
for e in emitters:
assert e.arg is None


def test_ask_with_dict_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask(emitter_kwargs={"arg": 42})
for e in emitters:
assert e.arg == 42


def test_ask_with_list_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask(emitter_kwargs=[{"arg": 1}, {"arg": 2}, {"arg": 3}])
for e, val in zip(emitters, [1, 2, 3]):
assert e.arg == val


def test_tell_with_no_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask()
optimizer.tell([], [], [], emitter_kwargs=None)
for e in emitters:
assert e.arg is None


def test_tell_with_dict_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask()
optimizer.tell([], [], [], emitter_kwargs={"arg": 42})
for e in emitters:
assert e.arg == 42


def test_tell_with_list_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask()
optimizer.tell(
[],
[],
[],
emitter_kwargs=[{
"arg": 1
}, {
"arg": 2
}, {
"arg": 3
}],
)
for e, val in zip(emitters, [1, 2, 3]):
assert e.arg == val

0 comments on commit 05f4133

Please sign in to comment.