Skip to content

Commit

Permalink
test stub and search tree
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Mar 11, 2018
1 parent ef1a41a commit 44f5e54
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 30 deletions.
22 changes: 17 additions & 5 deletions autokeras/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ def search(self, x_train, y_train, x_test, y_test):
return self.load_best_model()

def maximize_acq(self, model_ids):
# TODO: implement it
overall_max_acq_value = 0
father_id = None
target_graph = None
Expand Down Expand Up @@ -211,16 +210,29 @@ def maximize_acq(self, model_ids):
return nm_graph.produce_model(), father_id

def _acq(self, graph):
print(self)
print(graph)
return 0


class SearchTree:
# TODO: implement search tree
def __init__(self):
self.nodes = None
self.root = None
self.adj_list = {}

def add_child(self, u, v):
pass
if u == -1:
self.root = v
self.adj_list[v] = []
return
if v not in self.adj_list[u]:
self.adj_list[u].append(v)
if v not in self.adj_list:
self.adj_list[v] = []

def get_leaves(self):
return self.nodes
ret = []
for key, value in self.adj_list.items():
if not value:
ret.append(key)
return ret
51 changes: 28 additions & 23 deletions autokeras/stub.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,49 @@
from keras.layers import Dense, Concatenate, BatchNormalization, Activation
from keras.engine import InputLayer
from keras.layers import Dense, Concatenate, BatchNormalization, Activation, Flatten

from autokeras.layers import WeightedAdd
from autokeras.utils import is_conv_layer


class StubBatchNormalization:
class StubLayer:
def __init__(self, input_node=None, output_node=None):
self.input = input_node
self.output = output_node


class StubDense:
class StubBatchNormalization(StubLayer):
pass


class StubDense(StubLayer):
def __init__(self, units, input_node=None, output_node=None):
super().__init__(input_node, output_node)
self.units = units
self.input = input_node
self.output = output_node


class StubConv:
class StubConv(StubLayer):
def __init__(self, filters, input_node=None, output_node=None):
super().__init__(input_node, output_node)
self.filters = filters
self.input = input_node
self.output = output_node


class StubWeightedAdd:
class StubAggregateLayer(StubLayer):
def __init__(self, input_nodes=None, output_node=None):
if input_nodes is None:
input_nodes = []
self.input = input_nodes
self.output = output_node
super().__init__(input_nodes, output_node)


class StubConcatenate:
def __init__(self, input_nodes=None, output_node=None):
if input_nodes is None:
input_nodes = []
self.input = input_nodes
self.output = output_node
class StubConcatenate(StubAggregateLayer):
pass


class StubActivation:
def __init__(self, input_node=None, output_node=None):
self.input = input_node
self.output = output_node
class StubWeightedAdd(StubAggregateLayer):
pass


class StubActivation(StubLayer):
pass


class StubModel:
Expand Down Expand Up @@ -77,7 +77,6 @@ def to_stub_model(model):
input_id = node_to_id[layer.input]
output_id = node_to_id[layer.output]

temp_stub_layer = None
if is_conv_layer(layer):
temp_stub_layer = StubConv(layer.filters, input_id, output_id)
elif isinstance(layer, Dense):
Expand All @@ -90,6 +89,12 @@ def to_stub_model(model):
temp_stub_layer = StubBatchNormalization(input_id, output_id)
elif isinstance(layer, Activation):
temp_stub_layer = StubActivation(input_id, output_id)
elif isinstance(layer, InputLayer):
temp_stub_layer = StubLayer(input_id, output_id)
elif isinstance(layer, Flatten):
temp_stub_layer = StubLayer(input_id, output_id)
else:
raise TypeError("The layer {} is illegal.".format(layer))
ret.add_layer(temp_stub_layer)

return ret
return ret
9 changes: 9 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,12 @@ def test_random_searcher(_):
generator.search(x_train, y_train, x_test, y_test)
assert len(generator.history) == len(generator.history_configs)


# TODO: Test Bayesian Search

def test_search_tree():
tree = SearchTree()
tree.add_child(-1, 0)
tree.add_child(0, 1)
tree.add_child(0, 2)
assert len(tree.adj_list) == 3
22 changes: 20 additions & 2 deletions tests/test_stub.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
import pytest

from autokeras.stub import *
from tests.common import get_add_skip_model, get_concat_skip_model


def test_to_stub_model():
# TODO: Implement
pass
model = get_add_skip_model()
stub_model = to_stub_model(model)
assert len(stub_model.layers) == 18


def test_to_stub_model2():
model = get_concat_skip_model()
stub_model = to_stub_model(model)
assert len(stub_model.layers) == 18


def test_to_stub_model_exception():
model = get_concat_skip_model()
stub_model = to_stub_model(model)
with pytest.raises(Exception) as e:
to_stub_model(stub_model)
assert e.type is TypeError

0 comments on commit 44f5e54

Please sign in to comment.