Skip to content

Commit

Permalink
bayesian searcher tested
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Mar 15, 2018
1 parent 50d8eeb commit 84900b6
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from autokeras import constant
import numpy as np

from tests.common import clean_dir

default_test_path = 'tests/resources/temp'


def simple_transform(_):
generator = RandomConvClassifierGenerator(input_shape=(28, 28, 1), n_classes=3)
Expand All @@ -12,15 +16,17 @@ def simple_transform(_):

@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
def test_hill_climbing_classifier_searcher(_, _1):
def test_hill_climbing_searcher(_, _1):
x_train = np.random.rand(2, 28, 28, 1)
y_train = np.random.rand(2, 3)
x_test = np.random.rand(1, 28, 28, 1)
y_test = np.random.rand(1, 3)

constant.MAX_MODEL_NUM = 10
generator = HillClimbingSearcher(3, (28, 28, 1), verbose=False, path=constant.DEFAULT_SAVE_PATH)
clean_dir(default_test_path)
generator = HillClimbingSearcher(3, (28, 28, 1), verbose=False, path=default_test_path)
generator.search(x_train, y_train, x_test, y_test)
clean_dir(default_test_path)
assert len(generator.history) == len(generator.history_configs)


Expand All @@ -32,12 +38,28 @@ def test_random_searcher(_):
y_test = np.random.rand(1, 3)

constant.MAX_MODEL_NUM = 3
generator = RandomSearcher(3, (28, 28, 1), verbose=False, path=constant.DEFAULT_SAVE_PATH)
clean_dir(default_test_path)
generator = RandomSearcher(3, (28, 28, 1), verbose=False, path=default_test_path)
generator.search(x_train, y_train, x_test, y_test)
clean_dir(default_test_path)
assert len(generator.history) == len(generator.history_configs)


# TODO: Test Bayesian Search
@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
def test_bayesian_searcher(_, _1):
x_train = np.random.rand(2, 28, 28, 1)
y_train = np.random.rand(2, 3)
x_test = np.random.rand(1, 28, 28, 1)
y_test = np.random.rand(1, 3)

constant.MAX_MODEL_NUM = 10
clean_dir(default_test_path)
generator = BayesianSearcher(3, (28, 28, 1), verbose=False, path=default_test_path)
generator.search(x_train, y_train, x_test, y_test)
clean_dir(default_test_path)
assert len(generator.history) == len(generator.history_configs)


def test_search_tree():
tree = SearchTree()
Expand Down

0 comments on commit 84900b6

Please sign in to comment.