Skip to content

Commit

Permalink
random searcher
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Dec 22, 2017
1 parent c00cdea commit 640509a
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 18 deletions.
2 changes: 1 addition & 1 deletion autokeras/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def fit(self, x_train, y_train):
x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.33, random_state=42)

pickle.dump(self, open(os.path.join(self.path, 'classifier'), 'wb'))
self.model_id = self.searcher.generate(x_train, y_train, x_test, y_test)
self.model_id = self.searcher.search(x_train, y_train, x_test, y_test)

def predict(self, x_test):
model = self.searcher.load_best_model()
Expand Down
56 changes: 40 additions & 16 deletions autokeras/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def generate(self):
return model


class HillClimbingSearcher:
class Searcher:
def __init__(self, n_classes, input_shape, path, verbose):
self.n_classes = n_classes
self.input_shape = input_shape
Expand All @@ -82,6 +82,43 @@ def __init__(self, n_classes, input_shape, path, verbose):
self.path = path
self.model_count = 0

def search(self, x_train, y_train, x_test, y_test):
    """Placeholder search hook.

    The base implementation performs no work and returns ``None``; the
    concrete searchers in this module (``RandomSearcher`` and
    ``HillClimbingSearcher``) supply their own ``search``.
    """
    return None

def load_best_model(self):
    """Placeholder hook for loading the best model.

    The base implementation performs no work and returns ``None``.
    """
    return None

def add_model(self, model, x_train, y_train, x_test, y_test):
    """Train, evaluate and record one candidate model.

    The model is compiled with categorical cross-entropy loss, the
    Adadelta optimizer and the accuracy metric, trained through
    ``ModelTrainer``, and evaluated on ``(x_test, y_test)``. It is then
    saved as ``<path>/<model_count>.h5``, its loss/accuracy and extracted
    config are appended to the searcher's history, and the searcher
    itself is pickled to ``<path>/searcher``.

    Args:
        model: The model to compile, train and evaluate.
        x_train: Training inputs passed to ``ModelTrainer``.
        y_train: Training targets passed to ``ModelTrainer``.
        x_test: Inputs used for training-time validation and evaluation.
        y_test: Targets used for training-time validation and evaluation.
    """
    model.compile(loss=categorical_crossentropy,
                  optimizer=Adadelta(),
                  metrics=['accuracy'])
    model.summary()
    ModelTrainer(model, x_train, y_train, x_test, y_test, self.verbose).train_model()
    loss, accuracy = model.evaluate(x_test, y_test)

    # The current model_count doubles as the model id and the file name.
    model.save(os.path.join(self.path, str(self.model_count) + '.h5'))
    self.history.append({'model_id': self.model_count, 'loss': loss, 'accuracy': accuracy})
    self.history_configs.append(extract_config(model))
    self.model_count += 1

    # Persist the searcher state after every model; the context manager
    # guarantees the file handle is flushed and closed.
    with open(os.path.join(self.path, 'searcher'), 'wb') as searcher_file:
        pickle.dump(self, searcher_file)


class RandomSearcher(Searcher):
    """Searcher that trains independently generated random models.

    Models are produced by ``RandomConvClassifierGenerator`` until
    ``constant.MAX_MODEL_NUM`` models have been recorded, after which the
    model with the highest recorded accuracy is loaded and returned.
    """

    def __init__(self, n_classes, input_shape, path, verbose):
        super().__init__(n_classes, input_shape, path, verbose)

    def search(self, x_train, y_train, x_test, y_test):
        """Train random models up to the model budget and return the best.

        Returns:
            The saved model with the highest recorded accuracy, or ``None``
            if no model has been recorded.
        """
        # Every model is randomly generated; there is no refinement step.
        while self.model_count < constant.MAX_MODEL_NUM:
            model = RandomConvClassifierGenerator(self.n_classes, self.input_shape).generate()
            self.add_model(model, x_train, y_train, x_test, y_test)

        return self.load_best_model()

    def load_best_model(self):
        """Load the saved model with the highest recorded accuracy.

        The base-class hook is a no-op that returns ``None``, so without
        this override ``search`` could never hand back a model.

        Returns:
            The loaded model, or ``None`` when the history is empty.
        """
        if not self.history:
            return None
        model_id = max(self.history, key=lambda x: x['accuracy'])['model_id']
        return load_model(os.path.join(self.path, str(model_id) + '.h5'))


class HillClimbingSearcher(Searcher):
def __init__(self, n_classes, input_shape, path, verbose):
super().__init__(n_classes, input_shape, path, verbose)

def _remove_duplicate(self, models):
"""
Remove the duplicate in the history_models
Expand All @@ -95,9 +132,9 @@ def _remove_duplicate(self, models):
ans.append(model_a)
return ans

def generate(self, x_train, y_train, x_test, y_test):
def search(self, x_train, y_train, x_test, y_test):
# First model is randomly generated.
if not self.history:
# First model is randomly generated.
model = RandomConvClassifierGenerator(self.n_classes, self.input_shape).generate()
self.add_model(model, x_train, y_train, x_test, y_test)

Expand All @@ -119,19 +156,6 @@ def generate(self, x_train, y_train, x_test, y_test):

return self.load_best_model()

def add_model(self, model, x_train, y_train, x_test, y_test):
    """Compile, train, evaluate and persist a single model.

    After training via ``ModelTrainer`` the model is evaluated on
    ``(x_test, y_test)``, written to ``<path>/<model_count>.h5``, and its
    metrics and extracted config are appended to the history lists. The
    searcher is then pickled to ``<path>/searcher``.

    Args:
        model: The model to compile, train and evaluate.
        x_train: Training inputs.
        y_train: Training targets.
        x_test: Evaluation inputs.
        y_test: Evaluation targets.
    """
    model.compile(loss=categorical_crossentropy,
                  optimizer=Adadelta(),
                  metrics=['accuracy'])
    model.summary()
    ModelTrainer(model, x_train, y_train, x_test, y_test, self.verbose).train_model()
    loss, accuracy = model.evaluate(x_test, y_test)
    model.save(os.path.join(self.path, str(self.model_count) + '.h5'))
    self.history.append({'model_id': self.model_count, 'loss': loss, 'accuracy': accuracy})
    self.history_configs.append(extract_config(model))
    self.model_count += 1

    # Close the pickle file deterministically instead of leaking the handle.
    with open(os.path.join(self.path, 'searcher'), 'wb') as out_file:
        pickle.dump(self, out_file)

def load_best_model(self):
    """Load the saved model whose history record has the highest accuracy."""
    best_record = max(self.history, key=lambda record: record['accuracy'])
    model_file = os.path.join(self.path, str(best_record['model_id']) + '.h5')
    return load_model(model_file)
15 changes: 14 additions & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,17 @@ def test_hill_climbing_classifier_generator(_):
constant.MAX_ITER_NUM = 1
constant.MAX_MODEL_NUM = 1
generator = HillClimbingSearcher(3, (28, 28, 1), verbose=False, path=constant.DEFAULT_SAVE_PATH)
generator.generate(x_train, y_train, x_test, y_test)
generator.search(x_train, y_train, x_test, y_test)


def test_random_searcher():
    """Smoke-test RandomSearcher.search on tiny random data with a budget of one model."""
    # Two training samples and one test sample of 28x28x1 inputs, 3 classes.
    x_train, y_train = np.random.rand(2, 28, 28, 1), np.random.rand(2, 3)
    x_test, y_test = np.random.rand(1, 28, 28, 1), np.random.rand(1, 3)

    # Keep the search as short as possible.
    constant.MAX_ITER_NUM = 1
    constant.MAX_MODEL_NUM = 1

    searcher = RandomSearcher(3, (28, 28, 1), verbose=False, path=constant.DEFAULT_SAVE_PATH)
    searcher.search(x_train, y_train, x_test, y_test)

0 comments on commit 640509a

Please sign in to comment.