Skip to content

Commit

Permalink
Merge 11b70bd into 5f12932
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Sep 24, 2019
2 parents 5f12932 + 11b70bd commit 3e571b5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 17 deletions.
48 changes: 39 additions & 9 deletions autokeras/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,22 +194,19 @@ def predict(self, x, batch_size=32, **kwargs):
"""Predict the output for a given testing data.
# Arguments
x: tf.data.Dataset or numpy.ndarray. Testing data.
x: Any allowed types according to the input node. Testing data.
batch_size: Int. Defaults to 32.
**kwargs: Any arguments supported by keras.Model.predict.
# Returns
A list of numpy.ndarray objects or a single numpy.ndarray.
The predicted results.
"""
best_model = self.tuner.get_best_models(1)[0]
best_trial = self.tuner.get_best_trials(1)[0]
best_hp = best_trial.hyperparameters

self.tuner.load_trial(best_trial)
x = self._process_xy(x, predict=True)
x = self.hypermodel.preprocess(best_hp, x)
x = x.batch(batch_size)
best_model, x = self._prepare_best_model_and_data(
x=x,
y=None,
batch_size=batch_size,
predict=True)
y = best_model.predict(x, **kwargs)
y = self._postprocess(y)
if isinstance(y, list) and len(y) == 1:
Expand All @@ -225,6 +222,39 @@ def _postprocess(self, y):
new_y.append(temp_y)
return new_y

def evaluate(self, x, y=None, batch_size=32, **kwargs):
"""Evaluate the best model for the given data.
# Arguments
x: Any allowed types according to the input node. Testing data.
y: Any allowed types according to the head. Testing targets.
Defaults to None.
batch_size: Int. Defaults to 32.
**kwargs: Any arguments supported by keras.Model.evaluate.
# Returns
Scalar test loss (if the model has a single output and no metrics) or
list of scalars (if the model has multiple outputs and/or metrics).
The attribute model.metrics_names will give you the display labels for
the scalar outputs.
"""
best_model, data = self._prepare_best_model_and_data(
x=x,
y=y,
batch_size=batch_size)
return best_model.evaluate(data, **kwargs)

def _prepare_best_model_and_data(self, x, y, batch_size, predict=False):
best_model = self.tuner.get_best_models(1)[0]
best_trial = self.tuner.get_best_trials(1)[0]
best_hp = best_trial.hyperparameters

self.tuner.load_trial(best_trial)
x = self._process_xy(x, y, predict=predict)
x = self.hypermodel.preprocess(best_hp, x)
x = x.batch(batch_size)
return best_model, x


class GraphAutoModel(AutoModel):
"""A HyperModel defined by a graph of HyperBlocks.
Expand Down
32 changes: 24 additions & 8 deletions tests/test_auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def tmp_dir(tmpdir_factory):
return tmpdir_factory.mktemp('test_auto_model')


def test_graph_auto_model_basic(tmp_dir):
def test_basics(tmp_dir):
x_train = np.random.rand(100, 32)
y_train = np.random.rand(100, 1)

Expand All @@ -21,16 +21,32 @@ def test_graph_auto_model_basic(tmp_dir):
output_node = ak.DenseBlock()(output_node)
output_node = ak.RegressionHead()(output_node)

graph = ak.GraphAutoModel(input_node,
output_node,
directory=tmp_dir,
max_trials=1)
graph.fit(x_train, y_train, epochs=1, validation_data=(x_train, y_train))
result = graph.predict(x_train)

auto_model = ak.GraphAutoModel(input_node,
output_node,
directory=tmp_dir,
max_trials=1)
auto_model.fit(x_train, y_train, epochs=1, validation_data=(x_train, y_train))
result = auto_model.predict(x_train)
assert result.shape == (100, 1)


def test_evaluate(tmp_dir):
x_train = np.random.rand(100, 32)
y_train = np.random.rand(100, 1)

input_node = ak.Input()
output_node = input_node
output_node = ak.DenseBlock()(output_node)
output_node = ak.RegressionHead()(output_node)

auto_model = ak.GraphAutoModel(input_node,
output_node,
directory=tmp_dir,
max_trials=1)
auto_model.fit(x_train, y_train, epochs=1, validation_data=(x_train, y_train))
auto_model.evaluate(x_train, y_train)


def test_merge(tmp_dir):
x_train = np.random.rand(100, 33)
y_train = np.random.rand(100, 1)
Expand Down

0 comments on commit 3e571b5

Please sign in to comment.