Skip to content

Commit

Permalink
final fit (#268)
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Oct 21, 2018
1 parent d5f3b36 commit f6ca5a7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions autokeras/cnn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def final_fit(self, train_data, test_data, trainer_args=None, retrain=False):
if retrain:
graph.weighted = False
_, _1, graph = train((graph, train_data, test_data, trainer_args, None, self.metric, self.loss, self.verbose))
searcher.replace_model(graph, searcher.get_best_model_id())

@property
def best_model(self):
Expand Down
6 changes: 3 additions & 3 deletions autokeras/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, n_output_node, input_shape, path, metric, loss, verbose,
self.bo = BayesianOptimizer(self, t_min, metric, kernel_lambda, beta)

def load_model_by_id(self, model_id):
return pickle_from_file(os.path.join(self.path, str(model_id) + '.h5'))
return pickle_from_file(os.path.join(self.path, str(model_id) + '.graph'))

def load_best_model(self):
return self.load_model_by_id(self.get_best_model_id())
Expand All @@ -106,13 +106,13 @@ def get_best_model_id(self):
return min(self.history, key=lambda x: x['metric_value'])['model_id']

def replace_model(self, graph, model_id):
pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.h5'))
pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.graph'))

def add_model(self, metric_value, loss, graph, model_id):
if self.verbose:
print('\nSaving model.')

pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.h5'))
pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.graph'))

# Update best_model text file
ret = {'model_id': model_id, 'loss': loss, 'metric_value': metric_value}
Expand Down
4 changes: 2 additions & 2 deletions tests/image/test_image_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_export_keras_model(_):
score = clf.evaluate(train_x, train_y)
assert score <= 1.0

model_file_name = os.path.join(path, 'test_keras_model.h5')
model_file_name = os.path.join(path, 'test_keras_model.graph')
clf.export_keras_model(model_file_name)
from keras.models import load_model
model = load_model(model_file_name)
Expand All @@ -242,7 +242,7 @@ def test_export_keras_model(_):
score = clf.evaluate(train_x, train_y)
assert score >= 0.0

model_file_name = os.path.join(path, 'test_keras_model.h5')
model_file_name = os.path.join(path, 'test_keras_model.graph')
clf.export_keras_model(model_file_name)
from keras.models import load_model
model = load_model(model_file_name)
Expand Down

0 comments on commit f6ca5a7

Please sign in to comment.