Skip to content

Commit

Permalink
use linked list instead of list
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Jan 7, 2023
1 parent 0464937 commit 9a42ef8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 27 deletions.
90 changes: 63 additions & 27 deletions keras_tuner/tuners/gridsearch.py
Expand Up @@ -15,13 +15,72 @@
"Basic exhaustive search tuner."


import collections
import copy

from keras_tuner.engine import oracle as oracle_module
from keras_tuner.engine import trial as trial_module
from keras_tuner.engine import tuner as tuner_module


class LinkedList:
"""A simplified linked list with limited supported operations.
It doesn't copy any data pass to it but directly refer to it.
"""

def __init__(self):
# _memory is a list to store data.
# Its index is the address for the linked list.
# index to data
self._memory = []
self._data_to_index = {}
# index to index
self._next_index = collections.defaultdict(lambda: None)
self._last_index = None

def insert(self, data, data_pre=None):
"""Insert data after another data.
`data` is inserted after `data_pre` in the linked list.
Args:
data: The data to insert.
data_pre: Optional. The data marking the insertion location. If left
unspecified, the data will be appended to the rear of the linked
list.
"""
self._memory.append(data)
new_index = len(self._memory) - 1
self._data_to_index[data] = new_index

index = (
self._last_index if data_pre is None else self._data_to_index[data_pre]
)

self._next_index[new_index] = self._next_index[index]
self._next_index[index] = new_index

# Update self._last_index.
while self._next_index[self._last_index] is not None:
self._last_index = self._next_index[self._last_index]

def next(self, data):
"""Get the next data for a given data.
Args:
data: The data used to get its next data in the linked list.
Returns:
The next data if exists. Otherwise, return None.
"""
index = self._data_to_index[data]
next_index = self._next_index[index]
if next_index is None:
return None
return self._memory[next_index]


class GridSearchOracle(oracle_module.Oracle):
"""Grid search oracle.
Expand Down Expand Up @@ -82,7 +141,7 @@ def __init__(
)
# List of trial_id sorting in ascending alphabetical order of their hp
# values.
self._ordered_ids = []
self._ordered_ids = LinkedList()
# Queue of trial_ids pending to find their next combinations.
self._populate_next = []

Expand All @@ -105,7 +164,7 @@ def populate_space(self, trial_id):
# See if this is the first trial.
if len(self.start_order) == 0:
# Use all default values for the first trial.
self._ordered_ids.append(trial_id)
self._ordered_ids.insert(trial_id)
hps = self.get_space()
values = {
hp.name: hp.default
Expand All @@ -130,16 +189,13 @@ def populate_space(self, trial_id):
continue

# Skip if tried next combination.
next_id = self._get_next_id(old_trial_id)
next_id = self._ordered_ids.next(old_trial_id)
if next_id is not None:
next_values = self.trials[next_id].hyperparameters.values
if self._compare(new_values, next_values) >= 0:
continue

if next_id is None:
self._ordered_ids.append(trial_id)
else:
self._ordered_ids.insert(self._ordered_ids.index(next_id), trial_id)
self._ordered_ids.insert(trial_id, old_trial_id)

values = new_values

Expand Down Expand Up @@ -196,26 +252,6 @@ def _compare(self, a, b):

return 0

def _get_next_id(self, trial_id):
"""Get next `Trial`'s ID defined by the sorting order of the values.
Imagine sorting all the created `Trial`s in ascending order according to
their values. Get the next trial ID of a given trial ID. The order is
defined by `self._compare()`.
Args:
trial_id: The trial ID of the previous trial sorting in alphabetical
order according to their values.
Returns:
The trial ID if exist or None.
"""
old_index = self._ordered_ids.index(trial_id)
# Check index out of range.
if not (old_index + 1 < len(self._ordered_ids)):
return None
return self._ordered_ids[old_index + 1]

def _get_next_combination(self, values):
"""Get the next value combination to try.
Expand Down
20 changes: 20 additions & 0 deletions keras_tuner/tuners/gridsearch_test.py
Expand Up @@ -203,3 +203,23 @@ def end_trial(trial):
end_trial(trial_4)
trial_5 = oracle.create_trial(tuner_id="5")
assert trial_5.status == trial_module.TrialStatus.STOPPED


def test_linked_list():
linked_list = gridsearch.LinkedList()
linked_list.insert("0")
assert linked_list.next("0") is None
linked_list.insert("1")
assert linked_list.next("0") == "1"
assert linked_list.next("1") is None
linked_list.insert("2", "0")
assert linked_list.next("0") == "2"
assert linked_list.next("2") == "1"
assert linked_list.next("1") is None
linked_list.insert("3", "1")
linked_list.insert("4")
assert linked_list.next("0") == "2"
assert linked_list.next("2") == "1"
assert linked_list.next("1") == "3"
assert linked_list.next("3") == "4"
assert linked_list.next("4") is None

0 comments on commit 9a42ef8

Please sign in to comment.