Skip to content

Commit

Permalink
Adopting the tf.keras preprocessing layers (#922)
Browse files Browse the repository at this point in the history
* temp

* temp

* temp

* update

* fixing tests

* layer

* update

* update

* bug fix

* update

* refactor and fixed integration tests

* update unit tests

* update ci to use tf-nightly

* Adapter (#930)

* adapter

* update

* update

* update

* update

* update

* Adapter (#932)

* adapter

* update

* update

* update

* update

* update

* docs

* addressing comments
  • Loading branch information
haifeng-jin committed Jan 30, 2020
1 parent d63648e commit 6d7208e
Show file tree
Hide file tree
Showing 84 changed files with 3,561 additions and 4,758 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Expand Up @@ -18,7 +18,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .[tests] --progress-bar off
pip install tensorflow-cpu
pip install tf-nightly
- name: Lint with flake8
run: |
flake8
Expand Down
68 changes: 33 additions & 35 deletions autokeras/__init__.py
@@ -1,39 +1,37 @@
from autokeras.auto_model import AutoModel
from autokeras.hypermodel.base import Block
from autokeras.hypermodel.base import Head
from autokeras.hypermodel.base import HyperBlock
from autokeras.hypermodel.base import Node
from autokeras.hypermodel.base import Preprocessor
from autokeras.hypermodel.block import ConvBlock
from autokeras.hypermodel.block import DenseBlock
from autokeras.hypermodel.block import EmbeddingBlock
from autokeras.hypermodel.block import Merge
from autokeras.hypermodel.block import ResNetBlock
from autokeras.hypermodel.block import RNNBlock
from autokeras.hypermodel.block import SpatialReduction
from autokeras.hypermodel.block import TemporalReduction
from autokeras.hypermodel.block import XceptionBlock
from autokeras.hypermodel.head import ClassificationHead
from autokeras.hypermodel.head import RegressionHead
from autokeras.hypermodel.hyperblock import ImageBlock
from autokeras.hypermodel.hyperblock import StructuredDataBlock
from autokeras.hypermodel.hyperblock import TextBlock
from autokeras.hypermodel.node import ImageInput
from autokeras.hypermodel.node import Input
from autokeras.hypermodel.node import StructuredDataInput
from autokeras.hypermodel.node import TextInput
from autokeras.hypermodel.preprocessor import FeatureEngineering
from autokeras.hypermodel.preprocessor import ImageAugmentation
from autokeras.hypermodel.preprocessor import LightGBM
from autokeras.hypermodel.preprocessor import Normalization
from autokeras.hypermodel.preprocessor import TextToIntSequence
from autokeras.hypermodel.preprocessor import TextToNgramVector
from autokeras.task import ImageClassifier
from autokeras.task import ImageRegressor
from autokeras.task import StructuredDataClassifier
from autokeras.task import StructuredDataRegressor
from autokeras.task import TextClassifier
from autokeras.task import TextRegressor
from autokeras.engine.block import Block
from autokeras.engine.head import Head
from autokeras.engine.node import Node
from autokeras.hypermodels import CategoricalToNumerical
from autokeras.hypermodels import ClassificationHead
from autokeras.hypermodels import ConvBlock
from autokeras.hypermodels import DenseBlock
from autokeras.hypermodels import Embedding
from autokeras.hypermodels import Flatten
from autokeras.hypermodels import ImageAugmentation
from autokeras.hypermodels import ImageBlock
from autokeras.hypermodels import Merge
from autokeras.hypermodels import Normalization
from autokeras.hypermodels import RegressionHead
from autokeras.hypermodels import ResNetBlock
from autokeras.hypermodels import RNNBlock
from autokeras.hypermodels import SpatialReduction
from autokeras.hypermodels import StructuredDataBlock
from autokeras.hypermodels import TemporalReduction
from autokeras.hypermodels import TextBlock
from autokeras.hypermodels import TextToIntSequence
from autokeras.hypermodels import TextToNgramVector
from autokeras.hypermodels import XceptionBlock
from autokeras.nodes import ImageInput
from autokeras.nodes import Input
from autokeras.nodes import StructuredDataInput
from autokeras.nodes import TextInput
from autokeras.tasks import ImageClassifier
from autokeras.tasks import ImageRegressor
from autokeras.tasks import StructuredDataClassifier
from autokeras.tasks import StructuredDataRegressor
from autokeras.tasks import TextClassifier
from autokeras.tasks import TextRegressor

from .utils import check_tf_version

Expand Down
8 changes: 8 additions & 0 deletions autokeras/adapters/__init__.py
@@ -0,0 +1,8 @@
from autokeras.adapters.input_adapter import CATEGORICAL
from autokeras.adapters.input_adapter import NUMERICAL
from autokeras.adapters.input_adapter import ImageInputAdapter
from autokeras.adapters.input_adapter import InputAdapter
from autokeras.adapters.input_adapter import StructuredDataInputAdapter
from autokeras.adapters.input_adapter import TextInputAdapter
from autokeras.adapters.output_adapter import ClassificationHeadAdapter
from autokeras.adapters.output_adapter import RegressionHeadAdapter
193 changes: 193 additions & 0 deletions autokeras/adapters/input_adapter.py
@@ -0,0 +1,193 @@
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.python.util import nest

from autokeras.engine import adapter as adapter_module

CATEGORICAL = 'categorical'
NUMERICAL = 'numerical'


class InputAdapter(adapter_module.Adapter):

def check(self, x):
"""Record any information needed by transform."""
if not isinstance(x, (np.ndarray, tf.data.Dataset)):
raise TypeError('Expect the data to Input to be numpy.ndarray or '
'tf.data.Dataset, but got {type}.'.format(type=type(x)))
if isinstance(x, np.ndarray) and not np.issubdtype(x.dtype, np.number):
raise TypeError('Expect the data to Input to be numerical, but got '
'{type}.'.format(type=x.dtype))


class ImageInputAdapter(adapter_module.Adapter):

def check(self, x):
"""Record any information needed by transform."""
if not isinstance(x, (np.ndarray, tf.data.Dataset)):
raise TypeError('Expect the data to ImageInput to be numpy.ndarray or '
'tf.data.Dataset, but got {type}.'.format(type=type(x)))
if isinstance(x, np.ndarray) and x.ndim not in [3, 4]:
raise ValueError('Expect the data to ImageInput to have 3 or 4 '
'dimensions, but got input shape {shape} with {ndim} '
'dimensions'.format(shape=x.shape, ndim=x.ndim))
if isinstance(x, np.ndarray) and not np.issubdtype(x.dtype, np.number):
raise TypeError('Expect the data to ImageInput to be numerical, but got '
'{type}.'.format(type=x.dtype))

def convert_to_dataset(self, x):
if isinstance(x, np.ndarray):
if x.ndim == 3:
x = np.expand_dims(x, axis=3)
return super().convert_to_dataset(x)


class TextInputAdapter(adapter_module.Adapter):

def check(self, x):
"""Record any information needed by transform."""
if not isinstance(x, (np.ndarray, tf.data.Dataset)):
raise TypeError('Expect the data to TextInput to be numpy.ndarray or '
'tf.data.Dataset, but got {type}.'.format(type=type(x)))

if isinstance(x, np.ndarray) and x.ndim != 1:
raise ValueError('Expect the data to TextInput to have 1 dimension, but '
'got input shape {shape} with {ndim} dimensions'.format(
shape=x.shape,
ndim=x.ndim))
if isinstance(x, np.ndarray) and not np.issubdtype(x.dtype, np.character):
raise TypeError('Expect the data to TextInput to be strings, but got '
'{type}.'.format(type=x.dtype))

def convert_to_dataset(self, x):
if len(x.shape) == 1:
x = x.reshape(-1, 1)
if isinstance(x, np.ndarray):
x = tf.data.Dataset.from_tensor_slices(x)
return x


class StructuredDataInputAdapter(adapter_module.Adapter):

def __init__(self, column_names=None, column_types=None, **kwargs):
super().__init__(**kwargs)
self.column_names = column_names
self.column_types = column_types
# Variables for inferring column types.
self.count_nan = None
self.count_numerical = None
self.count_categorical = None
self.count_unique_numerical = []
self.num_col = None

def get_config(self):
config = super().get_config()
config.update({
'count_nan': self.count_nan,
'count_numerical': self.count_numerical,
'count_categorical': self.count_categorical,
'count_unique_numerical': self.count_unique_numerical,
'num_col': self.num_col
})
return config

@classmethod
def from_config(cls, config):
obj = super().from_config(config)
obj.count_nan = config['count_nan']
obj.count_numerical = config['count_numerical']
obj.count_categorical = config['count_categorical']
obj.count_unique_numerical = config['count_unique_numerical']
obj.num_col = config['num_col']

def check(self, x):
if not isinstance(x, (pd.DataFrame, np.ndarray)):
raise TypeError('Unsupported type {type} for '
'{name}.'.format(type=type(x),
name=self.__class__.__name__))

# Extract column_names from pd.DataFrame.
if isinstance(x, pd.DataFrame) and self.column_names is None:
self.column_names = list(x.columns)
# column_types is provided by user
if self.column_types:
for column_name in self.column_types:
if column_name not in self.column_names:
raise ValueError('Column_names and column_types are '
'mismatched. Cannot find column name '
'{name} in the data.'.format(
name=column_name))

# Generate column_names.
if self.column_names is None:
if self.column_types:
raise ValueError('Column names must be specified.')
self.column_names = [index for index in range(x.shape[1])]

# Check if column_names has the correct length.
if len(self.column_names) != x.shape[1]:
raise ValueError('Expect column_names to have length {expect} '
'but got {actual}.'.format(
expect=x.shape[1],
actual=len(self.column_names)))

def convert_to_dataset(self, x):
if isinstance(x, pd.DataFrame):
# Convert x, y, validation_data to tf.Dataset.
x = x.values.astype(np.unicode)
if isinstance(x, np.ndarray):
x = x.astype(np.unicode)
dataset = tf.data.Dataset.from_tensor_slices(x)
return dataset

def fit(self, dataset):
super().fit(dataset)
for x in dataset:
self.update(x)
self.infer_column_types()

def update(self, x):
# Calculate the statistics.
x = nest.flatten(x)[0].numpy()
if self.num_col is None:
self.num_col = len(x)
self.count_nan = np.zeros(self.num_col)
self.count_numerical = np.zeros(self.num_col)
self.count_categorical = np.zeros(self.num_col)
for i in range(len(x)):
self.count_unique_numerical.append({})
for i in range(self.num_col):
x[i] = x[i].decode('utf-8')
if x[i] == 'nan':
self.count_nan[i] += 1
elif x[i] == 'True':
self.count_categorical[i] += 1
elif x[i] == 'False':
self.count_categorical[i] += 1
else:
try:
tmp_num = float(x[i])
self.count_numerical[i] += 1
if tmp_num not in self.count_unique_numerical[i]:
self.count_unique_numerical[i][tmp_num] = 1
else:
self.count_unique_numerical[i][tmp_num] += 1
except ValueError:
self.count_categorical[i] += 1

def infer_column_types(self):
column_types = {}
for i in range(self.num_col):
if self.count_categorical[i] > 0:
column_types[self.column_names[i]] = CATEGORICAL
elif len(self.count_unique_numerical[i])/self.count_numerical[i] < 0.05:
column_types[self.column_names[i]] = CATEGORICAL
else:
column_types[self.column_names[i]] = NUMERICAL
# Partial column_types is provided.
if self.column_types is None:
self.column_types = {}
for key, value in column_types.items():
if key not in self.column_types:
self.column_types[key] = value

0 comments on commit 6d7208e

Please sign in to comment.