Skip to content

Commit

Permalink
implemented the design
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Jan 1, 2023
1 parent 1b6365f commit c63987c
Showing 1 changed file with 145 additions and 34 deletions.
179 changes: 145 additions & 34 deletions keras_tuner/tuners/gridsearch.py
Expand Up @@ -22,18 +22,6 @@
from keras_tuner.engine import tuner as tuner_module


def compare(a, b):
"""Compare two `HyperParameters` values.
Args:
a:
b:
Returns:
"""
pass


class GridSearchOracle(oracle_module.Oracle):
"""Grid search oracle.
Expand Down Expand Up @@ -92,6 +80,11 @@ def __init__(
max_retries_per_trial=max_retries_per_trial,
max_consecutive_failed_trials=max_consecutive_failed_trials,
)
# List of trial_id sorting in ascending alphabetical order of their hp
# values.
self._ordered_ids = []
# Queue of trial_ids pending to find their next combinations.
self._populate_next = []

def populate_space(self, trial_id):
"""Fill the hyperparameter space with values.
Expand All @@ -107,22 +100,128 @@ def populate_space(self, trial_id):
"STOPPED" (the oracle has finished searching and no new trial should
be created).
"""
if len(self.start_order) > 0:
last_trial = self.trials[self.start_order[-1]]
last_values = last_trial.hyperparameters.values
# The keys (hp names) in the `last_values` are always consistent with
# the hps in `self.get_space().space`, even for newly appeared hps.
# For example, during last trial's `_populate_space()`, a new hp
# has not appeared. During last trial's `HyperModel.build()`, the
# new hp appeared. The `hyperparameters.values` is then updated
# immediately.
values = self._get_next_combination(last_values)
else:
values = None

# See if this is the first trial.
if len(self.start_order) == 0:
# Use all default values for the first trial.
values = {hp.name: hp.default for hp in self.get_space().space}
if values is None:
return {"status": trial_module.TrialStatus.STOPPED, "values": None}
return {"status": trial_module.TrialStatus.RUNNING, "values": values}
self._ordered_ids.append(trial_id)
hps = self.get_space()
values = {
hp.name: hp.default
for hp in self.get_space().space
if hps.is_active(hp)
}
# Although the trial is not finished, we still push it into
# _populate_next to quickly generate values for the first few trials
# for multiple workers. The same trial_id will be pushed into
# _populate_next again when the trial is finished just in case of
# new hps appeared during the trial.
self._populate_next.append(trial_id)

# Pick tried values to create its next combination if not tried.
while len(self._populate_next) > 0 and values is None:
old_trial_id = self._populate_next.pop(0)

# Create its immediate next combination.
old_values = self.trials[old_trial_id].hyperparameters.values
new_values = self._get_next_combination(old_values)
if new_values is None:
continue

print(trial_id, old_trial_id)
print(old_values)
print(new_values)
# Skip if tried next combination.
next_id = self._get_next_id(old_trial_id)
print(next_id, self._ordered_ids)
if next_id is not None:
next_values = self.trials[next_id].hyperparameters.values
print(next_values)
if self._compare(new_values, next_values) >= 0:
continue

if next_id is None:
self._ordered_ids.append(trial_id)
else:
self._ordered_ids.insert(next_id)

values = new_values

if values is not None:
return {
"status": trial_module.TrialStatus.RUNNING,
"values": values,
}

# Wait for the ongoing trials to finish when the values queue is empty
# in case of any new hp discovered.
if len(self.ongoing_trials) > 0:
return {"status": trial_module.TrialStatus.IDLE, "values": None}

# Reaching this point means ongoing_trial, values, populate_next
# are all empty.
return {"status": trial_module.TrialStatus.STOPPED, "values": None}

def _compare(self, a, b):
"""Compare two `HyperParameters`' values.
The smallest index where a differs from b decides which one is larger.
In the values of one `HyperParameter`, the default value is the smallest.
The rest are sorted according to their order in `HyperParameter.values`.
If one value is the prefix of another, the longer one is larger.
Args:
a: Dict. HyperParameters values. Only active values are included.
b: Dict. HyperParameters values. Only active values are included.
Returns:
-1 if a < b, 0 if a == b, 1 if a > b.
"""
hps = self.get_space()
for hp in hps.space:
# The hp is not active in neither a or b.
# Whether it is active should be the same in a and b,
# or the loop have stopped at the parent values which are different.
if hp.name not in a:
continue

if a[hp.name] == b[hp.name]:
continue

# Get a ordered list of the values of the hp.
value_list = list(hp.values)
if hp.default in value_list:
value_list.remove(hp.default)
value_list.insert(0, hp.default)

index_a = value_list.index(a[hp.name])
index_b = value_list.index(b[hp.name])
if index_a == index_b:
continue
return -1 if index_a < index_b else 1

return 0

def _get_next_id(self, trial_id):
"""Get next `Trial`'s ID defined by the sorting order of the values.
Imagine sorting all the created `Trial`s in ascending order according to
their values. Get the next trial ID of a given trial ID. The order is
defined by `self._compare()`.
Args:
trial_id: The trial ID of the previous trial sorting in alphabetical
order according to their values.
Returns:
The trial ID if exist or None.
"""
old_index = self._ordered_ids.index(trial_id)
# Check index out of range.
if not (old_index + 1 < len(self._ordered_ids)):
return None
return self._ordered_ids[old_index + 1]

def _get_next_combination(self, values):
"""Get the next value combination to try.
Expand All @@ -142,7 +241,10 @@ def _get_next_combination(self, values):
from the last trial.
Returns:
Dict. The next possible value combination for the hyperparameters.
Dict or None. The next possible value combination for the
hyperparameters. If no next combination exist (values is the last
combination), it returns None. The return values only include the
active ones.
"""

hps = self.get_space()
Expand All @@ -154,8 +256,7 @@ def _get_next_combination(self, values):
# Put the default value first.
all_values[hp.name] = [hp.default] + value_list
default_values = {hp.name: hp.default for hp in hps.space}
new_values = copy.deepcopy(values)
hps.values = new_values
hps.values = copy.deepcopy(values)

bumped_value = False

Expand All @@ -165,16 +266,26 @@ def _get_next_combination(self, values):
name = hp.name
# Bump up the hp value if possible and active.
if hps.is_active(hp):
value = new_values[name]
value = hps.values[name]
if value != all_values[name][-1]:
index = all_values[name].index(value) + 1
new_values[name] = all_values[name][index]
hps.values[name] = all_values[name][index]
bumped_value = True
break
# Otherwise, reset to its first value.
new_values[name] = default_values[name]
hps.values[name] = default_values[name]

hps.ensure_active_values()
return hps.values if bumped_value else None

def end_trial(self, trial_id, status="COMPLETED", message=None):
super().end_trial(trial_id=trial_id, status=status, message=message)
# It is OK for a trial_id to be pushed into _populate_next multiple
# times. It will be skipped during _populate_space if its next
# combination has been tried.

return new_values if bumped_value else None
# For not blocking _populate_space, we push it regardless of the status.
self._populate_next.append(trial_id)


class GridSearch(tuner_module.Tuner):
Expand Down

0 comments on commit c63987c

Please sign in to comment.