Skip to content

Commit

Permalink
greedy oracle only generate active hps
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Aug 18, 2020
1 parent ede75e5 commit c270e59
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
57 changes: 32 additions & 25 deletions autokeras/tuners/greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import kerastuner
import numpy as np

Expand All @@ -25,7 +23,7 @@ def __init__(self):
super().__init__()
self.num_leaves = 0
self.children = {}
self.hp = None
self.hp_name = None

def is_leaf(self):
return len(self.children) == 0
Expand All @@ -36,10 +34,8 @@ def __init__(self):
super().__init__()
self.root = TrieNode()

def insert(self, hp):
name = hp.name
long_name = name
names = long_name.split("/")
def insert(self, hp_name):
names = hp_name.split("/")

new_word = False
current_node = self.root
Expand All @@ -50,7 +46,7 @@ def insert(self, hp):
new_word = True
current_node = current_node.children[name]
nodes_on_path.append(current_node)
current_node.hp = hp
current_node.hp_name = hp_name

if new_word:
for node in nodes_on_path:
Expand All @@ -66,12 +62,12 @@ def _get_all_nodes(self, node):
ret += self._get_all_nodes(value)
return ret

def get_hps(self, node):
def get_hp_names(self, node):
if node.is_leaf():
return [node.hp]
return [node.hp_name]
ret = []
for key, value in node.children.items():
ret += self.get_hps(value)
ret += self.get_hp_names(value)
return ret


Expand Down Expand Up @@ -113,10 +109,11 @@ def set_state(self, state):
self._tried_initial_hps = state["tried_initial_hps"]

def _select_hps(self):
# TODO: consider condition_scopes.
trie = Trie()
for hp in self.hyperparameters.space:
trie.insert(hp)
best_hps = self._get_best_hps()
for hp in best_hps.space:
if best_hps.is_active(hp):
trie.insert(hp.name)
all_nodes = trie.nodes

if len(all_nodes) <= 1:
Expand All @@ -127,7 +124,7 @@ def _select_hps(self):
probabilities = probabilities / sum_p
node = np.random.choice(all_nodes, p=probabilities)

return trie.get_hps(node)
return trie.get_hp_names(node)

def _next_initial_hps(self):
for index, hps in enumerate(self.initial_hps):
Expand All @@ -144,8 +141,8 @@ def _populate_space(self, trial_id):
}

for i in range(self._max_collisions):
hp_list = self._select_hps()
values = self._generate_hp_values(hp_list)
hp_names = self._select_hps()
values = self._generate_hp_values(hp_names)
# Reached max collisions.
if values is None:
continue
Expand All @@ -160,21 +157,31 @@ def _populate_space(self, trial_id):
"values": None,
}

def _generate_hp_values(self, hp_list):
def _get_best_hps(self):
best_trials = self.get_best_trials()
if best_trials:
best_hps = best_trials[0].hyperparameters
return best_trials[0].hyperparameters
else:
best_hps = self.hyperparameters
return self.hyperparameters

def _generate_hp_values(self, hp_names):
best_hps = self._get_best_hps()

collisions = 0
while True:
hps = copy.deepcopy(best_hps)
hps = kerastuner.HyperParameters()
# Generate a set of random values.
for hp in hp_list:
# TODO: Check is_active for hp.
hps.values[hp.name] = hp.random_sample(self._seed_state)
self._seed_state += 1
for hp in best_hps.space:
hps.merge([hp])
# if not active, do nothing.
# if active, check if selected to be changed.
if hps.is_active(hp):
# if was active and not selected, do nothing.
if best_hps.is_active(hp.name) and hp.name not in hp_names:
continue
# if was not active or selected, sample.
hps.values[hp.name] = hp.random_sample(self._seed_state)
self._seed_state += 1
values = hps.values
# Keep trying until the set of values is unique,
# or until we exit due to too many collisions.
Expand Down
1 change: 0 additions & 1 deletion tests/autokeras/tuners/greedy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def test_greedy_oracle_populate_different_values(get_best_trials):
values_a = oracle._populate_space("a")["values"]
values_b = oracle._populate_space("b")["values"]

assert set(values_a.keys()) == set(values_b.keys())
assert not all([values_a[key] == values_b[key] for key in values_a])


Expand Down

0 comments on commit c270e59

Please sign in to comment.