From 7453529b653f2ea333848adf500dab4c6800d48c Mon Sep 17 00:00:00 2001 From: Haifeng Jin Date: Sat, 20 Oct 2018 19:55:59 -0500 Subject: [PATCH] final fit --- autokeras/cnn_module.py | 1 + autokeras/search.py | 6 +++--- tests/image/test_image_supervised.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/autokeras/cnn_module.py b/autokeras/cnn_module.py index 77044ee37..74abbf89f 100644 --- a/autokeras/cnn_module.py +++ b/autokeras/cnn_module.py @@ -80,6 +80,7 @@ def final_fit(self, train_data, test_data, trainer_args=None, retrain=False): if retrain: graph.weighted = False _, _1, graph = train((graph, train_data, test_data, trainer_args, None, self.metric, self.loss, self.verbose)) + searcher.replace_model(graph, searcher.get_best_model_id()) @property def best_model(self): diff --git a/autokeras/search.py b/autokeras/search.py index 3dbc626dc..a38f39afc 100644 --- a/autokeras/search.py +++ b/autokeras/search.py @@ -89,7 +89,7 @@ def __init__(self, n_output_node, input_shape, path, metric, loss, verbose, self.bo = BayesianOptimizer(self, t_min, metric, kernel_lambda, beta) def load_model_by_id(self, model_id): - return pickle_from_file(os.path.join(self.path, str(model_id) + '.h5')) + return pickle_from_file(os.path.join(self.path, str(model_id) + '.graph')) def load_best_model(self): return self.load_model_by_id(self.get_best_model_id()) @@ -106,13 +106,13 @@ def get_best_model_id(self): return min(self.history, key=lambda x: x['metric_value'])['model_id'] def replace_model(self, graph, model_id): - pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.h5')) + pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.graph')) def add_model(self, metric_value, loss, graph, model_id): if self.verbose: print('\nSaving model.') - pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.h5')) + pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.graph')) # Update best_model text file ret = {'model_id': model_id, 'loss': loss, 'metric_value': metric_value} diff --git a/tests/image/test_image_supervised.py b/tests/image/test_image_supervised.py index e96068759..286f566bc 100644 --- a/tests/image/test_image_supervised.py +++ b/tests/image/test_image_supervised.py @@ -218,7 +218,7 @@ def test_export_keras_model(_): score = clf.evaluate(train_x, train_y) assert score <= 1.0 - model_file_name = os.path.join(path, 'test_keras_model.h5') + model_file_name = os.path.join(path, 'test_keras_model.graph') clf.export_keras_model(model_file_name) from keras.models import load_model model = load_model(model_file_name) @@ -242,7 +242,7 @@ def test_export_keras_model(_): score = clf.evaluate(train_x, train_y) assert score >= 0.0 - model_file_name = os.path.join(path, 'test_keras_model.h5') + model_file_name = os.path.join(path, 'test_keras_model.graph') clf.export_keras_model(model_file_name) from keras.models import load_model model = load_model(model_file_name)