Skip to content

Commit

Permalink
Enable CMA-ME emitters to work with float32 (#74)
Browse files Browse the repository at this point in the history
* Add test that checks for dtype

* Fix dtype in CMA-ME emitters
  • Loading branch information
btjanaka committed Feb 3, 2021
1 parent 14af85f commit 81c38c5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
11 changes: 8 additions & 3 deletions ribs/emitters/opt/_cma_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, dimension, dtype):
self.eigenvalues = np.ones((dimension,), dtype=dtype)
self.condition_number = 1
self.invsqrt = np.eye(dimension, dtype=dtype) # C^(-1/2)
self.dtype = dtype

# The last evaluation on which the eigensystem was updated.
self.updated_eval = 0
Expand All @@ -53,9 +54,10 @@ def update_eigensystem(self, current_eval, lazy_gap_evals):
# Force symmetry.
self.cov = np.maximum(self.cov, self.cov.T)

# Note: eigh returns float64, so we must cast it.
self.eigenvalues, self.eigenbasis = np.linalg.eigh(self.cov)
self.eigenvalues = self.eigenvalues.real
self.eigenbasis = self.eigenbasis.real
self.eigenvalues = self.eigenvalues.real.astype(self.dtype)
self.eigenbasis = self.eigenbasis.real.astype(self.dtype)
self.condition_number = (np.max(self.eigenvalues) /
np.min(self.eigenvalues))
self.invsqrt = (self.eigenbasis *
Expand Down Expand Up @@ -199,7 +201,10 @@ def ask(self, lower_bounds, upper_bounds):
remaining_indices = np.arange(self.batch_size)
while len(remaining_indices) > 0:
unscaled_params = self._rng.normal(
0.0, self.sigma, (len(remaining_indices), self.solution_dim))
0.0,
self.sigma,
(len(remaining_indices), self.solution_dim),
).astype(self.dtype)
new_solutions, out_of_bounds = self._transform_and_check_sol(
unscaled_params, transform_mat, self.mean, lower_bounds,
upper_bounds)
Expand Down
56 changes: 29 additions & 27 deletions tests/core/emitters/cma_me_emitter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,41 @@


@pytest.mark.parametrize(
"emitter_name",
["ImprovementEmitter", "RandomDirectionEmitter", "OptimizingEmitter"])
def test_auto_batch_size(emitter_name):
"emitter_class",
[ImprovementEmitter, RandomDirectionEmitter, OptimizingEmitter])
def test_auto_batch_size(emitter_class):
archive = GridArchive([20, 20], [(-1.0, 1.0)] * 2)

# Batch size is not provided, so it should be auto-generated.
emitter = {
"ImprovementEmitter":
lambda: ImprovementEmitter(archive, np.zeros(20), 1.0),
"RandomDirectionEmitter":
lambda: RandomDirectionEmitter(archive, np.zeros(20), 1.0),
"OptimizingEmitter":
lambda: OptimizingEmitter(archive, np.zeros(20), 1.0),
}[emitter_name]()

emitter = emitter_class(archive, np.zeros(10), 1.0)
assert emitter.batch_size is not None
assert isinstance(emitter.batch_size, int)


@pytest.mark.parametrize(
"emitter_name",
["ImprovementEmitter", "RandomDirectionEmitter", "OptimizingEmitter"])
def test_list_as_initial_solution(emitter_name):
"emitter_class",
[ImprovementEmitter, RandomDirectionEmitter, OptimizingEmitter])
def test_list_as_initial_solution(emitter_class):
archive = GridArchive([20, 20], [(-1.0, 1.0)] * 2)

emitter = {
"ImprovementEmitter":
lambda: ImprovementEmitter(archive, [0.0] * 20, 1.0),
"RandomDirectionEmitter":
lambda: RandomDirectionEmitter(archive, [0.0] * 20, 1.0),
"OptimizingEmitter":
lambda: OptimizingEmitter(archive, [0.0] * 20, 1.0),
}[emitter_name]()
emitter = emitter_class(archive, [0.0] * 10, 1.0)

# The list was passed in but should be converted to a numpy array.
assert (emitter.x0 == np.zeros(20)).all()
assert isinstance(emitter.x0, np.ndarray)
assert (emitter.x0 == np.zeros(10)).all()


@pytest.mark.parametrize(
    "emitter_class",
    [ImprovementEmitter, RandomDirectionEmitter, OptimizingEmitter],
)
@pytest.mark.parametrize(
    "dtype",
    [np.float64, np.float32],
    ids=["float64", "float32"],
)
def test_dtypes(emitter_class, dtype):
    """Check that each emitter works with an archive of the given dtype.

    The emitter is built on an archive created with ``dtype``; its ``x0``
    must report that same dtype, and a short ask/tell loop must run without
    raising.
    """
    archive = GridArchive([20, 20], [(-1.0, 1.0)] * 2, dtype=dtype)
    archive.initialize(10)

    emitter = emitter_class(archive, np.zeros(10), 1.0)
    assert emitter.x0.dtype == dtype

    # Drive the emitter for a few iterations using the negative sphere
    # function as the objective and the first two solution coordinates as
    # the behavior values.
    for _ in range(10):
        solutions = emitter.ask()
        objectives = -np.sum(np.square(solutions), axis=1)
        behavior_values = solutions[:, :2]
        emitter.tell(solutions, objectives, behavior_values)

0 comments on commit 81c38c5

Please sign in to comment.