Skip to content

Commit

Permalink
bug fix using dataset for text input
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Jun 9, 2020
1 parent de11192 commit dd3469e
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 11 deletions.
13 changes: 8 additions & 5 deletions autokeras/adapters/input_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tensorflow.python.util import nest

from autokeras.engine import adapter as adapter_module
from autokeras.utils import data_utils

CATEGORICAL = 'categorical'
NUMERICAL = 'numerical'
Expand Down Expand Up @@ -38,8 +39,10 @@ def check(self, x):

def convert_to_dataset(self, x):
    """Normalize NumPy input, then defer to the parent conversion.

    A NumPy array with three dimensions gets a trailing axis appended,
    and every NumPy array is cast to float32. Any other input is passed
    to the parent class's ``convert_to_dataset`` unchanged.
    """
    if not isinstance(x, np.ndarray):
        return super().convert_to_dataset(x)
    # TODO: expand the dims after converting to Dataset.
    if x.ndim == 3:
        x = np.expand_dims(x, axis=-1)
    return super().convert_to_dataset(x.astype(np.float32))


Expand All @@ -61,11 +64,11 @@ def check(self, x):
'{type}.'.format(type=x.dtype))

def convert_to_dataset(self, x):
if len(x.shape) == 1:
x = x.reshape(-1, 1)
if isinstance(x, np.ndarray):
x = tf.data.Dataset.from_tensor_slices(x)
return super().convert_to_dataset(x)
x = super().convert_to_dataset(x)
shape = data_utils.dataset_shape(x)
if len(shape) == 1:
x = x.map(lambda a: tf.reshape(a, [-1, 1]))
return x


class StructuredDataInputAdapter(adapter_module.Adapter):
Expand Down
3 changes: 1 addition & 2 deletions autokeras/engine/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def convert_to_dataset(self, dataset):
tf.data.Dataset. The converted dataset.
"""
if isinstance(dataset, np.ndarray):
dataset = tf.data.Dataset.from_tensor_slices(
dataset.astype(np.float32))
dataset = tf.data.Dataset.from_tensor_slices(dataset)
return data_utils.batch_dataset(dataset, self.batch_size)

def fit(self, dataset):
Expand Down
12 changes: 8 additions & 4 deletions autokeras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
from tensorflow.python.util import nest


def batch_dataset(dataset, batch_size):
def batched(dataset):
shape = nest.flatten(dataset_shape(dataset))[0]
if shape[0] is not None:
return dataset.batch(batch_size)
return dataset
return len(shape) > 0 and shape[0] is None


def batch_dataset(dataset, batch_size):
if batched(dataset):
return dataset
return dataset.batch(batch_size)


def split_dataset(dataset, validation_split):
Expand Down
34 changes: 34 additions & 0 deletions tests/autokeras/adapters/input_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tensorflow as tf

from autokeras.adapters import input_adapter
from autokeras.utils import data_utils
from tests import utils


Expand Down Expand Up @@ -124,6 +125,39 @@ def test_input_numerical():
assert 'Expect the data to Input to be numerical' in str(info.value)


def test_text_dataset():
    """Transforming an unbatched string dataset gives shape [None, 1]."""
    sentences = np.array(['a b c', 'b b c'])
    dataset = tf.data.Dataset.from_tensor_slices(sentences)
    transformed = input_adapter.TextInputAdapter().transform(dataset)
    assert data_utils.dataset_shape(transformed).as_list() == [None, 1]
    assert isinstance(transformed, tf.data.Dataset)


def test_text_dataset_batch():
    """Transforming a pre-batched string dataset gives shape [None, 1]."""
    sentences = np.array(['a b c', 'b b c'])
    dataset = tf.data.Dataset.from_tensor_slices(sentences).batch(32)
    transformed = input_adapter.TextInputAdapter().transform(dataset)
    assert data_utils.dataset_shape(transformed).as_list() == [None, 1]
    assert isinstance(transformed, tf.data.Dataset)


def test_text_np():
    """Transforming a NumPy string array gives a dataset of shape [None, 1]."""
    sentences = np.array(['a b c', 'b b c'])
    transformed = input_adapter.TextInputAdapter().transform(sentences)
    assert data_utils.dataset_shape(transformed).as_list() == [None, 1]
    assert isinstance(transformed, tf.data.Dataset)


def test_text_input_type_error():
x = 'unknown'
adapter = input_adapter.TextInputAdapter()
Expand Down

0 comments on commit dd3469e

Please sign in to comment.