-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Major API Changes for Emitter #207
Closed
Closed
Changes from 15 commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
3a30f7f
add _evolution_startegy_emitter
itsdawei 9c857e6
add ranker base and random_direction_ranker
itsdawei 0e80bbb
add two stage imporvement ranker
itsdawei d4ba0d0
update rd_ranker
itsdawei 4d754fe
rename file
itsdawei a94efbf
two_stage ranker
itsdawei 11e0bf1
remove generate_random_direction
itsdawei 489d5ff
finish comments
itsdawei a8b104e
add more comments to ranker classes
itsdawei 269ac3b
comment in ranker_base
itsdawei 50567e9
comment in random_direction_ranker
itsdawei 5bf5795
objective rankers
itsdawei 6686968
add selectors
itsdawei 40e77ad
minor chage in es_emitter
itsdawei 8e85926
rankers.reset() signature change
itsdawei cfcd277
fix requested changes
itsdawei 679b8fe
handle duplicate docstring; fix requested change
itsdawei 2e4c209
minor fix
itsdawei 7478720
remove old docstring
itsdawei 2ab1ec6
use np.lexsort()
itsdawei 429becd
fix call to ranker and selector
itsdawei b98ccaf
combine selectors to selectors.py
itsdawei 183f758
delete old selectors
itsdawei 24f10dc
update
itsdawei e51ffd8
update sphinx doc with selector
itsdawei c1d6e4b
fix sphinx
itsdawei fca5426
use numpy
itsdawei b08fb90
implement: get_ranker
itsdawei df95565
force add docs/api rankers and selctors module
itsdawei File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
"""Provides the EvolutionStrategyEmitter.""" | ||
import itertools | ||
|
||
import numpy as np | ||
|
||
from ribs.emitters._emitter_base import EmitterBase | ||
|
||
|
||
class EvolutionStrategyEmitter(EmitterBase): | ||
"""Adapts a evolution strategy optimizer towards the objective. | ||
|
||
This emitter originates in `Fontaine 2020 | ||
<https://arxiv.org/abs/1912.02400>`_. Initially, it starts at ``x0`` and | ||
uses some evolution strategy (i.e. CMA-ES) to optimize for objective values. | ||
After the evolution strategy converges, the emitter restarts the optimizer. | ||
|
||
Args: | ||
archive (ribs.archives.ArchiveBase): An archive to use when creating and | ||
inserting solutions. For instance, this can be | ||
:class:`ribs.archives.GridArchive`. | ||
x0 (np.ndarray): Initial solution. | ||
ranker (ribs.emitters.rankers.RankerBase): The Ranker object defines | ||
how the generated solutions are ranked and what to do on restart. | ||
selector (ribs.emitters.selectors.Selector): Method for selecting | ||
solutions in CMA-ES. With "mu" selection, the first half of the | ||
solutions will be selected, while in "filter", any solutions that | ||
were added to the archive will be selected. | ||
evolution_strategy (EvolutionStrategy): The evolution strategy to use | ||
:class:`ribs.emitter.opt.CMAEvolutionStrategy` | ||
restart_rule ("no_improvement" or "basic"): Method to use when checking | ||
for restart. With "basic", only the default CMA-ES convergence rules | ||
will be used, while with "no_improvement", the emitter will restart | ||
when none of the proposed solutions were added to the archive. | ||
bounds (None or array-like): Bounds of the solution space. Solutions are | ||
clipped to these bounds. Pass None to indicate there are no bounds. | ||
Alternatively, pass an array-like to specify the bounds for each | ||
dim. Each element in this array-like can be None to indicate no | ||
bound, or a tuple of ``(lower_bound, upper_bound)``, where | ||
``lower_bound`` or ``upper_bound`` may be None to indicate no bound. | ||
batch_size (int): Number of solutions to return in :meth:`ask`. If not | ||
passed in, a batch size will automatically be calculated. | ||
seed (int): Value to seed the random number generator. Set to None to | ||
avoid a fixed seed. | ||
Raises: | ||
ValueError: If ``restart_rule`` is invalid. | ||
""" | ||
|
||
def __init__(self, | ||
archive, | ||
x0, | ||
ranker, | ||
selector, | ||
evolution_strategy, | ||
restart_rule="no_improvement", | ||
bounds=None, | ||
batch_size=None, | ||
seed=None): | ||
self._rng = np.random.default_rng(seed) | ||
self._x0 = np.array(x0, dtype=archive.dtype) | ||
EmitterBase.__init__( | ||
self, | ||
archive, | ||
len(self._x0), | ||
bounds, | ||
) | ||
|
||
if restart_rule not in ["basic", "no_improvement"]: | ||
raise ValueError(f"Invalid restart_rule {restart_rule}") | ||
self._restart_rule = restart_rule | ||
|
||
self.opt = evolution_strategy | ||
self.opt.reset(self._x0) | ||
|
||
self._ranker = ranker | ||
self._ranker.reset(archive, self) | ||
|
||
self._selector = selector | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._batch_size = batch_size | ||
self._restarts = 0 # Currently not exposed publicly. | ||
|
||
@property | ||
def x0(self): | ||
"""numpy.ndarray: Initial solution for the optimizer.""" | ||
return self._x0 | ||
|
||
@property | ||
def batch_size(self): | ||
"""int: Number of solutions to return in :meth:`ask`.""" | ||
return self._batch_size | ||
|
||
def ask(self): | ||
"""Samples new solutions from a multivariate Gaussian. | ||
|
||
The multivariate Gaussian is parameterized by the evolution strategy | ||
optimizer ``self.opt``. | ||
|
||
Returns: | ||
``(batch_size, solution_dim)`` array -- contains ``batch_size`` new | ||
solutions to evaluate. | ||
""" | ||
return self.opt.ask(self.lower_bounds, self.upper_bounds) | ||
|
||
def _check_restart(self, num_parents): | ||
"""Emitter-side checks for restarting the optimizer. | ||
|
||
The optimizer also has its own checks. | ||
""" | ||
if self._restart_rule == "no_improvement": | ||
return num_parents == 0 | ||
return False | ||
|
||
def tell(self, solutions, objective_values, behavior_values, metadata=None): | ||
"""Gives the emitter results from evaluating solutions. | ||
|
||
As we insert solutions into the archive, we record the solutions' | ||
impact on the fitness of the archive. For example, if the added | ||
solution makes an improvement on an existing elite, then we | ||
will record ``(AddStatus.IMPROVED_EXISTING, imporvement_value)`` | ||
|
||
The solutions are ranked based on the `rank()` function defined by | ||
`self._ranker`. | ||
|
||
`self._selector` defines how many top solutions are passed to the | ||
evolution strategy. | ||
|
||
Args: | ||
solutions (numpy.ndarray): Array of solutions generated by this | ||
emitter's :meth:`ask()` method. | ||
objective_values (numpy.ndarray): 1D array containing the objective | ||
function value of each solution. | ||
behavior_values (numpy.ndarray): ``(n, <behavior space dimension>)`` | ||
array with the behavior space coordinates of each solution. | ||
metadata (numpy.ndarray): 1D object array containing a metadata | ||
object for each solution. | ||
""" | ||
# Tuples of (solution was added, projection onto random direction, | ||
# index). | ||
ranking_data = [] | ||
|
||
# Tupe of (add status, add value) | ||
add_results = [] | ||
|
||
metadata = itertools.repeat(None) if metadata is None else metadata | ||
|
||
# Add solutions to the archive. | ||
for i, (sol, obj, beh, meta) in enumerate( | ||
zip(solutions, objective_values, behavior_values, metadata)): | ||
add_results.append(self.archive.add(sol, obj, beh, meta)) | ||
|
||
# Sort the solutions with some Ranker | ||
indices = self._ranker.rank(self, self.archive, solutions, | ||
objective_values, behavior_values, metadata, | ||
add_results[0], add_results[1]) | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Select the number of parents | ||
num_parents = self._selector.select(self, self.archive, solutions, | ||
objective_values, behavior_values, | ||
metadata, add_results[0], | ||
add_results[1]) | ||
|
||
# Update Evolution Strategy | ||
self.opt.tell(solutions[indices], num_parents) | ||
|
||
# Check for reset. | ||
if (self.opt.check_stop([obj for status, obj, i in ranking_data]) or | ||
self._check_restart(new_sols)): | ||
new_x0 = self.archive.sample_elites(1).solution_batch[0] | ||
self.opt.reset(new_x0) | ||
self._ranker.reset(self, self.archive) | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._restarts += 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""Internal subpackage with rankers for use across emitters.""" | ||
from ribs.emitters.rankers._ranker_base import RankerBase | ||
from ribs.emitters.rankers._random_direction_ranker import RandomDirectionRanker | ||
from ribs.emitters.rankers._two_stage_random_direction_ranker import TwoStageRandomDirectionRanker | ||
from ribs.emitters.rankers._objective_ranker import ObjectiveRanker | ||
from ribs.emitters.rankers._two_stage_improvement_ranker import TwoStageImprovementRanker | ||
|
||
__all__ = [ | ||
"RandomDirectionRanker", | ||
"TwoStageRandomDirectionRanker", | ||
"ObjectiveRanker", | ||
"TwoStageImprovementRanker", | ||
"RankerBase", | ||
] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
""" | ||
Provides the ObjectiveRanker | ||
""" | ||
import numpy as np | ||
|
||
from ribs.emitters.rankers._ranker_base import RankerBase | ||
|
||
|
||
class ObjectiveRanker(RankerBase): | ||
"""Ranks the solutions based on the objective values | ||
|
||
This ranker originates in `Fontaine 2020 | ||
<https://arxiv.org/abs/1912.02400>`_ as OptimizingEmitter. | ||
We rank the solutions solely based on their objective values. | ||
""" | ||
|
||
def rank(self, emitter, archive, solutions, objective_values, | ||
behavior_values, metadata, add_statuses, add_values): | ||
"""Ranks the soutions based on their objective values. | ||
|
||
Args: | ||
emitter (ribs.emitters.EmitterBase): | ||
archive (ribs.archives.ArchiveBase): An archive to use when creating | ||
and inserting solutions. For instance, this can be | ||
:class:`ribs.archives.GridArchive`. | ||
solutions (numpy.ndarray): Array of solutions generated by this | ||
emitter's :meth:`ask()` method. | ||
objective_values (numpy.ndarray): 1D array containing the objective | ||
function value of each solution. | ||
behavior_values (numpy.ndarray): ``(n, <behavior space dimension>)`` | ||
array with the behavior space coordinates of each solution. | ||
metadata (numpy.ndarray): 1D object array containing a metadata | ||
object for each solution. | ||
add_statuses (): | ||
add_values (): | ||
|
||
Returns: | ||
indices: indicate order of the solutions in descending order | ||
""" | ||
ranking_data = [] | ||
for i, (obj, status) in enumerate(zip(objective_values, add_statuses)): | ||
added = bool(status) | ||
ranking_data.append((added, obj, i)) | ||
if added: | ||
new_sols += 1 | ||
|
||
|
||
# if self._selection_rule == "filter": | ||
# # Sort by whether the solution was added into the archive, followed | ||
# # by objective value. | ||
# key = lambda x: (x[0], x[1]) | ||
|
||
# Sort only by objective value. | ||
ranking_data.sort(reverse=True, key=lambda x: x[1]) | ||
return [d[2] for d in ranking_data] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
""" | ||
Provides the RandomDirectionRanker | ||
""" | ||
import numpy as np | ||
|
||
from ribs.emitters.rankers._ranker_base import RankerBase | ||
|
||
|
||
class RandomDirectionRanker(RankerBase): | ||
"""Ranks the solutions based on projection onto a direction in | ||
behavior space. | ||
|
||
This ranker originates in `Fontaine 2020 | ||
<https://arxiv.org/abs/1912.02400>`_ as RandomDirectionEmitter. | ||
We rank the solutions solely based on their projection onto a random | ||
direction in behavior space. | ||
|
||
To rank the solutions first by whether they were added, and then by | ||
the projection, refer to | ||
:class:`ribs.emitters.rankers.TwoStageRandomDirectionRanker`. | ||
""" | ||
|
||
def rank(self, emitter, archive, solutions, objective_values, | ||
behavior_values, metadata, add_statuses, add_values): | ||
"""Ranks the soutions based on projection onto a direction in behavior | ||
space. | ||
|
||
Args: | ||
emitter (ribs.emitters.EmitterBase): | ||
archive (ribs.archives.ArchiveBase): An archive to use when creating | ||
and inserting solutions. For instance, this can be | ||
:class:`ribs.archives.GridArchive`. | ||
solutions (numpy.ndarray): Array of solutions generated by this | ||
emitter's :meth:`ask()` method. | ||
objective_values (numpy.ndarray): 1D array containing the objective | ||
function value of each solution. | ||
behavior_values (numpy.ndarray): ``(n, <behavior space dimension>)`` | ||
array with the behavior space coordinates of each solution. | ||
metadata (numpy.ndarray): 1D object array containing a metadata | ||
object for each solution. | ||
add_statuses (): | ||
add_values (): | ||
|
||
Returns: | ||
indices: which represent the descending order of the solutions | ||
""" | ||
ranking_data = [] | ||
for i, (beh, status) in enumerate(zip(behavior_values, add_statuses)): | ||
projection = np.dot(beh, self._target_behavior_dir) | ||
added = bool(status) | ||
|
||
ranking_data.append((added, projection, i)) | ||
if added: | ||
new_sols += 1 | ||
|
||
# Sort only by projection. | ||
ranking_data.sort(reverse=True, key=lambda x: x[1]) | ||
return [d[2] for d in ranking_data] | ||
|
||
def reset(self, archive, emitter): | ||
"""Generates a new random direction in the behavior space. | ||
|
||
The direction is sampled from a standard Gaussian -- since the standard | ||
Gaussian is isotropic, there is equal probability for any direction. The | ||
direction is then scaled to the behavior space bounds. | ||
""" | ||
|
||
ranges = archive.upper_bounds - archive.lower_bounds | ||
behavior_dim = len(ranges) | ||
unscaled_dir = self._rng.standard_normal(behavior_dim) | ||
self._target_behavior_dir = unscaled_dir * ranges |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Provides the RankerBase.""" | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class RankerBase(ABC): | ||
"""Base class for rankers. | ||
|
||
Every ranker has an :meth:`rank` method that returns a list of indices | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
that indicate how the solutions should be ranked and an :meth:`reset` method | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
that resets the internal state of the ranker | ||
(i.e. in :class:`ribs.emitters.rankers._random_direction_ranker`). | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Child classes are only required to override :meth:`rank`. | ||
""" | ||
|
||
@abstractmethod | ||
def rank(self, emitter, archive, solutions, objective_values, | ||
behavior_values, metadata, add_statuses, add_values): | ||
"""Generate a list of indices that represents an ordering of solutions | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
emitter (ribs.emitters.EmitterBase): The emitter that this ranker | ||
object is associated with. | ||
archive (ribs.archives.ArchiveBase): An archive to use when creating | ||
and inserting solutions. For instance, this can be | ||
:class:`ribs.archives.GridArchive`. | ||
solutions (numpy.ndarray): Array of solutions generated by this | ||
emitter's :meth:`ask()` method. | ||
objective_values (numpy.ndarray): 1D array containing the objective | ||
function value of each solution. | ||
behavior_values (numpy.ndarray): ``(n, <behavior space dimension>)`` | ||
array with the behavior space coordinates of each solution. | ||
metadata (numpy.ndarray): 1D object array containing a metadata | ||
object for each solution. | ||
add_statuses (): | ||
add_values (): | ||
|
||
Returns: | ||
indices: which represent the descending order of the solutions | ||
""" | ||
|
||
def reset(self, emitter, archive): | ||
"""Resets the internal state of the ranker | ||
|
||
Args: | ||
emitter (ribs.emitters.EmitterBase): The emitter that this ranker | ||
object is associated with. | ||
archive (ribs.archives.ArchiveBase): An archive to use when creating | ||
and inserting solutions. For instance, this can be | ||
:class:`ribs.archives.GridArchive`. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import numpy as np | ||
|
||
from ribs.archives import AddStatus | ||
from ribs.emitters.rankers._ranker_base import RankerBase | ||
|
||
|
||
class TwoStageImprovementRanker(RankerBase): | ||
|
||
def rank(self, emitter, archive, solutions, objective_values, | ||
behavior_values, metadata, statuses, add_values): | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ranking_data = [] | ||
for i, (status, add_value) in enumerate(zip(statuses, add_value)): | ||
itsdawei marked this conversation as resolved.
Show resolved
Hide resolved
|
||
added = bool(status) | ||
ranking_data.append((added, add_value, i)) | ||
if status in (AddStatus.NEW, AddStatus.IMPROVE_EXISTING): | ||
new_sols += 1 | ||
|
||
# New solutions sort ahead of improved ones, which sort ahead of ones | ||
# that were not added. | ||
ranking_data.sort(reverse=True) | ||
return [d[2] for d in ranking_data] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make a restarter object just like for selectors and rankers.