Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major API Changes for Emitter #207

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
3a30f7f
add _evolution_startegy_emitter
itsdawei Jun 13, 2022
9c857e6
add ranker base and random_direction_ranker
itsdawei Jun 13, 2022
0e80bbb
add two stage imporvement ranker
itsdawei Jun 13, 2022
d4ba0d0
update rd_ranker
itsdawei Jun 13, 2022
4d754fe
rename file
itsdawei Jun 13, 2022
a94efbf
two_stage ranker
itsdawei Jun 13, 2022
11e0bf1
remove generate_random_direction
itsdawei Jun 13, 2022
489d5ff
finish comments
itsdawei Jun 14, 2022
a8b104e
add more comments to ranker classes
itsdawei Jun 14, 2022
269ac3b
comment in ranker_base
itsdawei Jun 14, 2022
50567e9
comment in random_direction_ranker
itsdawei Jun 14, 2022
5bf5795
objective rankers
itsdawei Jun 14, 2022
6686968
add selectors
itsdawei Jun 15, 2022
40e77ad
minor chage in es_emitter
itsdawei Jun 15, 2022
8e85926
rankers.reset() signature change
itsdawei Jun 15, 2022
cfcd277
fix requested changes
itsdawei Jun 16, 2022
679b8fe
handle duplicate docstring; fix requested change
itsdawei Jun 17, 2022
2e4c209
minor fix
itsdawei Jun 17, 2022
7478720
remove old docstring
itsdawei Jun 17, 2022
2ab1ec6
use np.lexsort()
itsdawei Jun 17, 2022
429becd
fix call to ranker and selector
itsdawei Jun 17, 2022
b98ccaf
combine selectors to selectors.py
itsdawei Jun 17, 2022
183f758
delete old selectors
itsdawei Jun 17, 2022
24f10dc
update
itsdawei Jun 17, 2022
e51ffd8
update sphinx doc with selector
itsdawei Jun 17, 2022
c1d6e4b
fix sphinx
itsdawei Jun 17, 2022
fca5426
use numpy
itsdawei Jun 17, 2022
b08fb90
implement: get_ranker
itsdawei Jun 17, 2022
df95565
force add docs/api rankers and selctors module
itsdawei Jun 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/api/ribs.emitters.rankers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
ribs.emitters.rankers
=============

.. automodule:: ribs.emitters.rankers
:no-members:
:no-inherited-members:
:no-special-members:
7 changes: 7 additions & 0 deletions docs/api/ribs.emitters.selectors.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
ribs.emitters.selectors
=============

.. automodule:: ribs.emitters.selectors
:no-members:
:no-inherited-members:
:no-special-members:
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ examples

api/ribs.archives
api/ribs.emitters
api/ribs.emitters.rankers
api/ribs.emitters.selectors
api/ribs.optimizers
api/ribs.factory
api/ribs.visualize
Expand Down
111 changes: 111 additions & 0 deletions ribs/_docstrings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""This provides the common docstrings that is used in the project
"""
itsdawei marked this conversation as resolved.
Show resolved Hide resolved

import re


class DocstringComponents:
# Taken from
itsdawei marked this conversation as resolved.
Show resolved Hide resolved
itsdawei marked this conversation as resolved.
Show resolved Hide resolved
# github.com/mwaskom/seaborn/blob/9d8ce6ad4ab213994f0bc84d0c46869df7be0b49/seaborn/_docstrings.py
regexp = re.compile(r"\n((\n|.)+)\n\s*", re.MULTILINE)

def __init__(self, comp_dict, strip_whitespace=True):
"""Read entries from a dict, optionally stripping outer whitespace."""
if strip_whitespace:
entries = {}
for key, val in comp_dict.items():
m = re.match(self.regexp, val)
if m is None:
entries[key] = val
else:
entries[key] = m.group(1)
else:
entries = comp_dict.copy()

self.entries = entries

def __getattr__(self, attr):
"""Provide dot access to entries for clean raw docstrings."""
if attr in self.entries:
return self.entries[attr]
try:
return self.__getattribute__(attr)
except AttributeError as err:
# If Python is run with -OO, it will strip docstrings and our lookup
# from self.entries will fail. We check for __debug__, which is actually
# set to False by -O (it is True for normal execution).
# But we only want to see an error when building the docs;
# not something users should see, so this slight inconsistency is fine.
if __debug__:
raise err
else:
pass

@classmethod
def from_nested_components(cls, **kwargs):
"""Add multiple sub-sets of components."""
return cls(kwargs, strip_whitespace=False)

@classmethod
def from_function_params(cls, func):
"""Use the numpydoc parser to extract components from existing func."""
params = NumpyDocString(pydoc.getdoc(func))["Parameters"]
itsdawei marked this conversation as resolved.
Show resolved Hide resolved
comp_dict = {}
for p in params:
name = p.name
type = p.type
desc = "\n ".join(p.desc)
comp_dict[name] = f"{name} : {type}\n {desc}"

return cls(comp_dict)


_core_args = dict(
emitter="""
emitter (ribs.emitters.EmitterBase): Emitter to use for generating
solutions and updating the archive.
""",
archive="""
archive (ribs.archives.ArchiveBase): Archive to use when creating
and inserting solutions. For instance, this can be
:class:`ribs.archives.GridArchive`.
""",
solutions="""
solutions (numpy.ndarray): Array of solutions generated by the
emitter's :meth:`ask()` method.
""",
objective_values="""
objective_values (numpy.ndarray): 1D array containing the objective
function value of each solution.
""",
behavior_values="""
behavior_values (numpy.ndarray): ``(n, <behavior space dimension>)``
array with the behavior space coordinates of each solution.
""",
metadata="""
metadata (numpy.ndarray): 1D object array containing a metadata
object for each solution.
""",
add_statuses="""
add_statuses (numpy.ndarray): 1D array of :class:`ribs.archive.AddStatus`
returned by a series of calls to archive's :meth:`add()` method.
""",
add_values="""
add_values (numpy.ndarray): 1D array of floats returned by a series of
calls to archive's :meth:`add()` method. For what these floats
represent, refer to :meth:`ribs.archives.add()`
""",
seed="""
seed (int): Value to seed the random number generator. Set to None to
avoid a fixed seed.
"""
)

_core_returns = dict(something="""
something
""")

_core_docs = dict(
args=DocstringComponents(_core_args),
returns=DocstringComponents(_core_returns),
)
187 changes: 187 additions & 0 deletions ribs/emitters/_evolution_strategy_emitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""Provides the EvolutionStrategyEmitter."""
import itertools

import numpy as np

from ribs.emitters._emitter_base import EmitterBase
from ribs.emitters.rankers import RankerBase


itsdawei marked this conversation as resolved.
Show resolved Hide resolved
def get_ranker(ranker_name):
itsdawei marked this conversation as resolved.
Show resolved Hide resolved
try:
ranker = globals()[ranker_name]()
itsdawei marked this conversation as resolved.
Show resolved Hide resolved
if issubclass(ranker, RankerBase):
itsdawei marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError
return ranker
except KeyError as key_error:
raise RuntimeError("Cannot find class" + ranker_name) from key_error
except TypeError as type_error:
raise RuntimeError(ranker_name +
"is not a subclass of RankerBase") from type_error


class EvolutionStrategyEmitter(EmitterBase):
"""Adapts a evolution strategy optimizer towards the objective.

This emitter originates in `Fontaine 2020
<https://arxiv.org/abs/1912.02400>`_. Initially, it starts at ``x0`` and
uses some evolution strategy (i.e. CMA-ES) to optimize for objective values.
After the evolution strategy converges, the emitter restarts the optimizer.

Args:
archive (ribs.archives.ArchiveBase): An archive to use when creating and
inserting solutions. For instance, this can be
:class:`ribs.archives.GridArchive`.
x0 (np.ndarray): Initial solution.
ranker (ribs.emitters.rankers.RankerBase or str): The Ranker object
defines how the generated solutions are ranked and what to do on
restart. If passing in the name of the ranker as a string,
the corresponding ranker will be created in the constructor.
selector (ribs.emitters.selectors.Selector): Method for selecting
solutions in CMA-ES. With "mu" selection, the first half of the
solutions will be selected, while in "filter", any solutions that
were added to the archive will be selected.
evolution_strategy (EvolutionStrategy): The evolution strategy to use
:class:`ribs.emitter.opt.CMAEvolutionStrategy`
restart_rule ("no_improvement" or "basic"): Method to use when checking
for restart. With "basic", only the default CMA-ES convergence rules
will be used, while with "no_improvement", the emitter will restart
when none of the proposed solutions were added to the archive.
bounds (None or array-like): Bounds of the solution space. Solutions are
clipped to these bounds. Pass None to indicate there are no bounds.
Alternatively, pass an array-like to specify the bounds for each
dim. Each element in this array-like can be None to indicate no
bound, or a tuple of ``(lower_bound, upper_bound)``, where
``lower_bound`` or ``upper_bound`` may be None to indicate no bound.
batch_size (int): Number of solutions to return in :meth:`ask`. If not
passed in, a batch size will automatically be calculated.
seed (int): Value to seed the random number generator. Set to None to
avoid a fixed seed.
Raises:
ValueError: If ``restart_rule`` is invalid.
"""

def __init__(self,
archive,
x0,
ranker,
selector,
evolution_strategy,
restart_rule="no_improvement",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make a restarter object just like for selectors and rankers.

bounds=None,
batch_size=None,
seed=None):
self._rng = np.random.default_rng(seed)
self._x0 = np.array(x0, dtype=archive.dtype)
EmitterBase.__init__(
self,
archive,
len(self._x0),
bounds,
)

if restart_rule not in ["basic", "no_improvement"]:
raise ValueError(f"Invalid restart_rule {restart_rule}")
self._restart_rule = restart_rule

self.opt = evolution_strategy
self.opt.reset(self._x0)

self._ranker = get_ranker(ranker) if isinstance(ranker, str) else ranker
self._ranker.reset(self, archive)

self._selector = selector
itsdawei marked this conversation as resolved.
Show resolved Hide resolved
self._ranker.reset(self, archive)

self._batch_size = batch_size
self._restarts = 0 # Currently not exposed publicly.

@property
def x0(self):
"""numpy.ndarray: Initial solution for the optimizer."""
return self._x0

@property
def batch_size(self):
"""int: Number of solutions to return in :meth:`ask`."""
return self._batch_size

def ask(self):
"""Samples new solutions from a multivariate Gaussian.

The multivariate Gaussian is parameterized by the evolution strategy
optimizer ``self.opt``.

Returns:
``(batch_size, solution_dim)`` array -- contains ``batch_size`` new
solutions to evaluate.
"""
return self.opt.ask(self.lower_bounds, self.upper_bounds)

def _check_restart(self, num_parents):
"""Emitter-side checks for restarting the optimizer.

The optimizer also has its own checks.
"""
if self._restart_rule == "no_improvement":
return num_parents == 0
return False

def tell(self, solutions, objective_values, behavior_values, metadata=None):
"""Gives the emitter results from evaluating solutions.

As we insert solutions into the archive, we record the solutions'
impact on the fitness of the archive. For example, if the added
solution makes an improvement on an existing elite, then we
will record ``(AddStatus.IMPROVED_EXISTING, imporvement_value)``

The solutions are ranked based on the `rank()` function defined by
`self._ranker`.

`self._selector` defines how many top solutions are passed to the
evolution strategy.

Args:
solutions (numpy.ndarray): Array of solutions generated by this
emitter's :meth:`ask()` method.
objective_values (numpy.ndarray): 1D array containing the objective
function value of each solution.
behavior_values (numpy.ndarray): ``(n, <behavior space dimension>)``
array with the behavior space coordinates of each solution.
metadata (numpy.ndarray): 1D object array containing a metadata
object for each solution.
"""
add_statues = []
add_values = []

metadata = itertools.repeat(None) if metadata is None else metadata

# Add solutions to the archive.
for i, (sol, obj, beh, meta) in enumerate(
zip(solutions, objective_values, behavior_values, metadata)):
status, value = self.archive.add(sol, obj, beh, meta)
add_statues.append(status)
add_values.append(value)

# Sort the solutions with some Ranker
indices = self._ranker.rank(self, self.archive, solutions,
objective_values, behavior_values, metadata,
add_statues, add_values)

# Select the number of parents
num_parents = self._selector.select(self, self.archive, solutions,
objective_values, behavior_values,
metadata, add_statues, add_values)

# Update Evolution Strategy
self.opt.tell(solutions[indices], num_parents)

# Check for reset.
# TODO bug: no access to ranking_data and new_sols
if (self.opt.check_stop([obj for status, obj, i in ranking_data]) or
self._check_restart(new_sols)):
new_x0 = self.archive.sample_elites(1).solution_batch[0]
self.opt.reset(new_x0)
self._ranker.reset(self, self.archive)
itsdawei marked this conversation as resolved.
Show resolved Hide resolved
self._selector.reset(self, self.archive)
self._restarts += 1