Skip to content

Commit

Permalink
fixing prediction for multi io
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed May 24, 2020
1 parent 3ffe63e commit 41542e4
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 50 deletions.
68 changes: 50 additions & 18 deletions autokeras/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,40 +270,51 @@ def _adapt(sources, fit, hms, adapters):
for source, hm, adapter in zip(sources, hms, adapters):
if fit:
source = adapter.fit_transform(source)
hm.config_from_adapter(adapter)
else:
source = adapter.transform(source)
adapted.append(source)
hm.config_from_adapter(adapter)
if len(adapted) == 1:
return adapted[0]
return tf.data.Dataset.zip(tuple(adapted))

def _process_xy(self, x, y, fit=False, validation=False, predict=False):
    """Convert x, y to tf.data.Dataset.

    # Arguments
        x: Any type allowed by the corresponding input node.
        y: Any type allowed by the corresponding head.
        fit: Boolean. Whether to fit the type converter with the provided data.
        validation: Boolean. Whether it is validation data or not.
        predict: Boolean. True means the data doesn't contain y.

    # Returns
        A tf.data.Dataset containing both x and y, or containing only x
        when `predict` is True.
    """
    self._check_data_format(x, y, validation=validation, predict=predict)
    if isinstance(x, tf.data.Dataset):
        dataset = x
        if not predict:
            # The dataset yields (x, y) pairs: separate the two halves and
            # split y into one dataset per output.
            y = dataset.map(lambda a, b: b)
            y = [y.map(lambda *a: nest.flatten(a)[index])
                 for index in range(len(self.outputs))]
            x = dataset.map(lambda a, b: a)
        # Split x into one dataset per input. In predict mode x is still the
        # dataset that was passed in, which holds no y.
        x = [x.map(lambda *a: nest.flatten(a)[index])
             for index in range(len(self.inputs))]

    x = self._adapt(x, fit, self.inputs, self._input_adapters)
    if not predict:
        y = self._adapt(y, fit, self._heads, self._output_adapters)
        return tf.data.Dataset.zip((x, y))

    if len(self.inputs) == 1:
        return x

    # Wrap the multiple inputs in a 1-tuple.
    # NOTE(review): presumably so a consumer does not read a 2-input element
    # as an (x, y) pair -- confirm against the caller.
    return x.map(lambda *x: (x, ))

def _check_data_format(self, x, y, validation=False, predict=False):
"""Check if the dataset has the same number of IOs with the model."""
if validation:
in_val = ' in validation_data'
Expand All @@ -315,12 +326,16 @@ def _check_data_format(self, x, y, validation=False):
'tf.data.Dataset{in_val}.'.format(in_val=in_val))

if isinstance(x, tf.data.Dataset):
x_shapes, y_shapes = data_utils.dataset_shape(x)
x_shapes = nest.flatten(x_shapes)
y_shapes = nest.flatten(y_shapes)
if not predict:
x_shapes, y_shapes = data_utils.dataset_shape(x)
x_shapes = nest.flatten(x_shapes)
y_shapes = nest.flatten(y_shapes)
else:
x_shapes = nest.flatten(data_utils.dataset_shape(x))
else:
x_shapes = [a.shape for a in nest.flatten(x)]
y_shapes = [a.shape for a in nest.flatten(y)]
if not predict:
y_shapes = [a.shape for a in nest.flatten(y)]

if len(x_shapes) != len(self.inputs):
raise ValueError(
Expand All @@ -329,7 +344,7 @@ def _check_data_format(self, x, y, validation=False):
in_val=in_val,
input_num=len(self.inputs),
data_num=len(x_shapes)))
if len(y_shapes) != len(self.outputs):
if not predict and len(y_shapes) != len(self.outputs):
raise ValueError(
'Expect y{in_val} to have {output_num} arrays, '
'but got {data_num}'.format(
Expand All @@ -346,16 +361,15 @@ def _prepare_data(self, x, y, validation_data, validation_split):
# TODO: Handle other types of input, zip dataset, tensor, dict.
# Prepare the dataset.
self._check_data_format(x, y)
dataset = self._process_xy(x, y, True)
dataset = self._process_xy(x, y, fit=True)
if validation_data:
self._split_dataset = False
if isinstance(validation_data, tf.data.Dataset):
x_val = validation_data
y_val = None
else:
x_val, y_val = validation_data
self._check_data_format(x_val, y_val, validation=True)
validation_data = self._process_xy(x_val, y_val, False)
validation_data = self._process_xy(x_val, y_val, validation=True)
# Split the data with validation_split.
if validation_data is None and validation_split:
self._split_dataset = True
Expand All @@ -364,6 +378,23 @@ def _prepare_data(self, x, y, validation_data, validation_split):
validation_split)
return dataset, validation_data

def _get_x(self, dataset):
    """Remove y from the tf.data.Dataset if exists.

    # Arguments
        dataset: tf.data.Dataset. Its elements are either x alone or an
            (x, y) pair.

    # Returns
        A tf.data.Dataset expected to contain only x.
    """
    shapes = data_utils.dataset_shape(dataset)
    # NOTE(review): assumes dataset_shape returns the dataset's output-shape
    # nest, so a dataset with a single un-nested component yields a bare
    # tf.TensorShape. len() of a TensorShape is its rank, not a component
    # count, so it must not reach the length checks below. Such a dataset
    # has no y to remove.
    if isinstance(shapes, tf.TensorShape):
        return dataset
    # Only one or less element in the first level.
    if len(shapes) <= 1:
        return dataset.map(lambda *x: x[0])
    # The first level has more than 1 element.
    # The nest has 2 levels.
    if any(isinstance(shape, tuple) for shape in shapes):
        return dataset.map(lambda x, y: x)
    # The nest has one level.
    # It matches the single IO case.
    single_io = len(self.inputs) == 1 and len(self.outputs) == 1
    if len(shapes) == 2 and single_io:
        return dataset.map(lambda x, y: x)
    # Otherwise every first-level element is taken to be an input.
    return dataset

def predict(self, x, batch_size=32, **kwargs):
"""Predict the output for a given testing data.
Expand All @@ -377,7 +408,8 @@ def predict(self, x, batch_size=32, **kwargs):
The predicted results.
"""
if isinstance(x, tf.data.Dataset):
dataset = self._process_xy(x, None, False)
x = self._get_x(x)
dataset = self._process_xy(x, None, predict=True)
else:
dataset = self._adapt(x, False, self.inputs, self._input_adapters)
dataset = dataset.batch(batch_size)
Expand Down
90 changes: 58 additions & 32 deletions tests/autokeras/auto_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ def test_evaluate(tuner_fn, tmp_path):
assert tuner_fn.called


def get_single_io_auto_model(tmp_path):
    """Build the single-input, single-output AutoModel shared by the tests.

    # Arguments
        tmp_path: Directory passed to the AutoModel as `directory`.

    # Returns
        An ak.AutoModel with one ImageInput and one RegressionHead.
    """
    # overwrite=False is passed explicitly, as in get_multi_io_auto_model:
    # test_overwrite asserts on the value forwarded to the tuner class, so
    # the helper must not depend on the library default.
    return ak.AutoModel(ak.ImageInput(),
                        ak.RegressionHead(),
                        directory=tmp_path,
                        max_trials=2,
                        overwrite=False)


@mock.patch('autokeras.auto_model.get_tuner_class')
def test_auto_model_predict(tuner_fn, tmp_path):
x_train = np.random.rand(100, 32, 32, 3)
y_train = np.random.rand(100, 1)

auto_model = ak.AutoModel(ak.ImageInput(),
ak.RegressionHead(),
directory=tmp_path,
max_trials=2)
auto_model = get_single_io_auto_model(tmp_path)
auto_model.fit(x_train, y_train, epochs=2, validation_split=0.2)
auto_model.predict(x_train)
assert tuner_fn.called
Expand All @@ -48,10 +52,7 @@ def test_final_fit_concat(tuner_fn, tmp_path):
x_train = np.random.rand(100, 32, 32, 3)
y_train = np.random.rand(100, 1)

auto_model = ak.AutoModel(ak.ImageInput(),
ak.RegressionHead(),
directory=tmp_path,
max_trials=2)
auto_model = get_single_io_auto_model(tmp_path)
auto_model.fit(x_train, y_train, epochs=2, validation_split=0.2)
assert auto_model._split_dataset
assert tuner.search.call_args_list[0][1]['fit_on_val_data']
Expand All @@ -64,10 +65,7 @@ def test_final_fit_not_concat(tuner_fn, tmp_path):
x_train = np.random.rand(100, 32, 32, 3)
y_train = np.random.rand(100, 1)

auto_model = ak.AutoModel(ak.ImageInput(),
ak.RegressionHead(),
directory=tmp_path,
max_trials=2)
auto_model = get_single_io_auto_model(tmp_path)
auto_model.fit(x_train, y_train, epochs=2, validation_data=(x_train, y_train))
assert not auto_model._split_dataset
assert not tuner.search.call_args_list[0][1]['fit_on_val_data']
Expand All @@ -80,11 +78,7 @@ def test_overwrite(tuner_fn, tmp_path):
x_train = np.random.rand(100, 32, 32, 3)
y_train = np.random.rand(100, 1)

auto_model = ak.AutoModel(ak.ImageInput(),
ak.RegressionHead(),
directory=tmp_path,
max_trials=2,
overwrite=False)
auto_model = get_single_io_auto_model(tmp_path)
auto_model.fit(x_train, y_train, epochs=2, validation_data=(x_train, y_train))
assert not tuner_class.call_args_list[0][1]['overwrite']

Expand All @@ -97,23 +91,23 @@ def test_export_model(tuner_fn, tmp_path):
x_train = np.random.rand(100, 32, 32, 3)
y_train = np.random.rand(100, 1)

auto_model = ak.AutoModel(ak.ImageInput(),
ak.RegressionHead(),
directory=tmp_path,
max_trials=2,
overwrite=False)
auto_model = get_single_io_auto_model(tmp_path)
auto_model.fit(x_train, y_train, epochs=2, validation_data=(x_train, y_train))
auto_model.export_model()
assert tuner.get_best_model.called


def get_multi_io_auto_model(tmp_path):
    """Build the two-input, two-output AutoModel shared by the tests.

    # Arguments
        tmp_path: Directory passed to the AutoModel as `directory`.

    # Returns
        An ak.AutoModel with two ImageInput nodes and two RegressionHeads.
    """
    input_nodes = [ak.ImageInput(), ak.ImageInput()]
    output_heads = [ak.RegressionHead(), ak.RegressionHead()]
    auto_model = ak.AutoModel(
        input_nodes,
        output_heads,
        directory=tmp_path,
        max_trials=2,
        overwrite=False,
    )
    return auto_model


@mock.patch('autokeras.auto_model.get_tuner_class')
def test_multi_io_with_tf_dataset(tuner_fn, tmp_path):
auto_model = ak.AutoModel([ak.ImageInput(), ak.ImageInput()],
[ak.RegressionHead(), ak.RegressionHead()],
directory=tmp_path,
max_trials=2,
overwrite=False)
auto_model = get_multi_io_auto_model(tmp_path)
x1 = utils.generate_data()
y1 = utils.generate_data(shape=(1,))
dataset = tf.data.Dataset.from_tensor_slices(((x1, x1), (y1, y1)))
Expand All @@ -140,11 +134,7 @@ def test_single_nested_dataset(tuner_fn, tmp_path):


def dataset_error(x, y, validation_data, message, tmp_path):
auto_model = ak.AutoModel([ak.ImageInput(), ak.ImageInput()],
[ak.RegressionHead(), ak.RegressionHead()],
directory=tmp_path,
max_trials=2,
overwrite=False)
auto_model = get_multi_io_auto_model(tmp_path)
with pytest.raises(ValueError) as info:
auto_model.fit(x, y, epochs=2, validation_data=validation_data)
assert message in str(info.value)
Expand Down Expand Up @@ -185,3 +175,39 @@ def test_dataset_and_y(tuner_fn, tmp_path):
val_dataset = tf.data.Dataset.from_tensor_slices(((x1,), (y1, y1)))
dataset_error(x, y, val_dataset,
'Expect y is None', tmp_path)


@mock.patch('autokeras.auto_model.get_tuner_class')
def test_multi_input_predict(tuner_fn, tmp_path):
    """Predict on a y-less dataset whose two inputs are nested in a 1-tuple."""
    auto_model = get_multi_io_auto_model(tmp_path)
    images = utils.generate_data()
    targets = utils.generate_data(shape=(1,))

    # Fit and validate on the same ((x, x), (y, y)) dataset.
    train_elements = ((images, images), (targets, targets))
    train_data = tf.data.Dataset.from_tensor_slices(train_elements)
    auto_model.fit(train_data, None, epochs=2, validation_data=train_data)

    # The prediction dataset holds only the inputs: ((x, x),).
    predict_elements = ((images, images),)
    predict_data = tf.data.Dataset.from_tensor_slices(predict_elements)
    auto_model.predict(predict_data)


@mock.patch('autokeras.auto_model.get_tuner_class')
def test_multi_input_predict2(tuner_fn, tmp_path):
    """Predict on a y-less dataset whose two inputs form a flat (x, x) tuple."""
    auto_model = get_multi_io_auto_model(tmp_path)
    images = utils.generate_data()
    targets = utils.generate_data(shape=(1,))

    # Fit and validate on the same ((x, x), (y, y)) dataset.
    train_elements = ((images, images), (targets, targets))
    train_data = tf.data.Dataset.from_tensor_slices(train_elements)
    auto_model.fit(train_data, None, epochs=2, validation_data=train_data)

    # The prediction dataset is a flat pair of inputs, with no y.
    predict_elements = (images, images)
    predict_data = tf.data.Dataset.from_tensor_slices(predict_elements)
    auto_model.predict(predict_data)


@mock.patch('autokeras.auto_model.get_tuner_class')
def test_single_input_predict(tuner_fn, tmp_path):
    """Predict with a single-IO model on a dataset that still contains y."""
    auto_model = get_single_io_auto_model(tmp_path)
    images = utils.generate_data()
    targets = utils.generate_data(shape=(1,))

    # Fit and validate on the same (x, y) dataset.
    train_elements = (images, targets)
    train_data = tf.data.Dataset.from_tensor_slices(train_elements)
    auto_model.fit(train_data, None, epochs=2, validation_data=train_data)

    # The prediction dataset is built from the same (x, y) pair, so y is
    # still present when predict() is called.
    predict_elements = (images, targets)
    predict_data = tf.data.Dataset.from_tensor_slices(predict_elements)
    auto_model.predict(predict_data)

0 comments on commit 41542e4

Please sign in to comment.