Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions compiler_opt/es/blackbox_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,15 @@ def test_save_best_model(self):
count=1,
pickle_func=cloudpickle.dumps,
worker_args=(),
worker_kwargs={}) as pool:
worker_kwargs={
'delta': -1.0,
'initial_value': 5
}) as pool:
self._learner.set_baseline(pool)
self._learner.run_step(pool)
self.assertEqual(len(self._saved_policies), 1)
self.assertIn('iteration0', self._saved_policies)
self.assertIn('best_policy_1.01_step_0', self._saved_policies)
self._learner.run_step(pool)
self.assertIn('iteration1', self._saved_policies)
self.assertIn('best_policy_1.07_step_1', self._saved_policies)

def test_save_best_model_only_saves_best(self):
with local_worker_manager.LocalWorkerPoolManager(
Expand All @@ -191,9 +193,15 @@ def test_save_best_model_only_saves_best(self):
pickle_func=cloudpickle.dumps,
worker_args=(),
worker_kwargs={
'delta': -1.0,
'delta': 1.0,
'initial_value': 5
}) as pool:
self._learner.set_baseline(pool)
self._learner.run_step(pool)
self.assertIn('best_policy_100.0_step_0', self._saved_policies)
self.assertIn('best_policy_0.94_step_0', self._saved_policies)

# Check that the within the next step we only get a new iteration
# policy and do not save any new best.
current_policies_count = len(self._saved_policies)
self._learner.run_step(pool)
self.assertLen(self._saved_policies, current_policies_count + 1)
8 changes: 6 additions & 2 deletions compiler_opt/es/blackbox_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from compiler_opt.distributed import worker
from compiler_opt.rl import corpus
from compiler_opt.rl import policy_saver
from compiler_opt.rl import constant


@gin.configurable
Expand All @@ -34,11 +35,14 @@ def __init__(self, *, delta=1.0, initial_value=0.0):

def compile(self, policy: policy_saver.Policy,
modules: list[corpus.LoadedModuleSpec]) -> float:
# We return the values with constant.DELTA subtracted so that we get
# exact values we can assert against when writing tests that only see
# the relative reward.
if policy and modules:
self.function_value += self._delta
return self.function_value
return self.function_value - constant.DELTA
else:
return 0.0
return 100 - constant.DELTA


class SizeReturningESWorker(worker.Worker):
Expand Down
Loading