Skip to content

Commit

Permalink
self review update
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Jul 30, 2019
1 parent aa3a6c6 commit 6b42b51
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 89 deletions.
6 changes: 3 additions & 3 deletions autokeras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from autokeras.hypermodel.node import ImageInput
from autokeras.hypermodel.node import Input
from autokeras.hypermodel.node import TextInput
from autokeras.hypermodel.processor import Normalize
from autokeras.hypermodel.processor import TextToIntSequence
from autokeras.hypermodel.processor import TextToNgramVector
from autokeras.hypermodel.preprocessor import Normalize
from autokeras.hypermodel.preprocessor import TextToIntSequence
from autokeras.hypermodel.preprocessor import TextToNgramVector
from autokeras.task import ImageClassifier
from autokeras.task import ImageRegressor
from autokeras.task import TextClassifier
Expand Down
11 changes: 5 additions & 6 deletions autokeras/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import tensorflow as tf
from tensorflow.python.util import nest

import autokeras.utils
from autokeras import meta_model
from autokeras import tuner
from autokeras import utils
from autokeras.hypermodel import graph
from autokeras.hypermodel import head
from autokeras.hypermodel import node
from autokeras.hypermodel import processor
from autokeras.hypermodel import preprocessor


class AutoModel(object):
Expand Down Expand Up @@ -157,12 +158,10 @@ def predict(self, x, batch_size=32, **kwargs):
**kwargs: Any arguments supported by keras.Model.predict.
"""
best_model = self.tuner.get_best_models(1)[0]
best_hp = self.tuner.get_best_hp(1)[0]
best_trial = self.tuner.get_best_trials(1)[0]
filename = '%s-preprocessors' % best_trial.trial_id
path = os.path.join(best_trial.directory, filename)
best_hp = best_trial.hyperparameters

self.hypermodel.load_preprocessors(path)
self.tuner.load_trial(best_trial)
x = utils.prepare_preprocess(x, x)
x = self.hypermodel.preprocess(best_hp, x)
x = x.batch(batch_size)
Expand All @@ -182,7 +181,7 @@ def _label_encoding(self, y):
hyper_head = output_node.in_blocks[0]
if (isinstance(hyper_head, head.ClassificationHead) and
utils.is_label(temp_y)):
label_encoder = processor.OneHotEncoder()
label_encoder = utils.OneHotEncoder()
label_encoder.fit(y)
new_y.append(label_encoder.transform(y))
self._label_encoders.append(label_encoder)
Expand Down
21 changes: 11 additions & 10 deletions autokeras/hypermodel/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from autokeras import utils
from autokeras.hypermodel import head
from autokeras.hypermodel import hyperblock
from autokeras.hypermodel import processor
from autokeras.hypermodel import preprocessor


class GraphHyperModel(kerastuner.HyperModel):
Expand Down Expand Up @@ -73,7 +73,7 @@ def build(self, hp):
node_id = self._node_to_id[input_node]
real_nodes[node_id] = input_node.build()
for block in self._blocks:
if isinstance(block, processor.HyperPreprocessor):
if isinstance(block, preprocessor.Preprocessor):
continue
temp_inputs = [real_nodes[self._node_to_id[input_node]]
for input_node in block.inputs]
Expand Down Expand Up @@ -261,7 +261,7 @@ def preprocess(self, hp, dataset, validation_data=None, fit=False):
if self.contains_hyper_block():
return self._plain_graph_hm.preprocess(hp, dataset, validation_data, fit)
for block in self._blocks:
if isinstance(block, processor.HyperPreprocessor):
if isinstance(block, preprocessor.Preprocessor):
block.set_hp(hp)
dataset = self._preprocess(dataset, fit=fit)
if not validation_data:
Expand All @@ -280,7 +280,7 @@ def _preprocess(self, dataset, fit=False):
for block in self._blocks:
if (self._block_topo_depth[
self._block_to_id[block]] == depth and
isinstance(block, processor.HyperPreprocessor)):
isinstance(block, preprocessor.Preprocessor)):
temp_blocks.append(block)
if not temp_blocks:
break
Expand Down Expand Up @@ -312,14 +312,15 @@ def _preprocess(self, dataset, fit=False):
dataset = dataset.map(functools.partial(
self._preprocess_transform,
input_node_ids=input_node_ids,
blocks=blocks))
blocks=blocks,
fit=fit))

# Build input_node_ids for next depth.
input_node_ids = list(sorted([self._node_to_id[block.outputs[0]]
for block in blocks]))
return dataset

def _preprocess_transform(self, x, y, input_node_ids, blocks):
def _preprocess_transform(self, x, y, input_node_ids, blocks, fit=False):
x = nest.flatten(x)
id_to_data = {
node_id: temp_x
Expand All @@ -334,7 +335,7 @@ def _preprocess_transform(self, x, y, input_node_ids, blocks):
for hm in blocks:
data = [id_to_data[self._node_to_id[input_node]]
for input_node in hm.inputs]
data = tf.py_function(hm.transform,
data = tf.py_function(functools.partial(hm.transform, fit=fit),
inp=nest.flatten(data),
Tout=hm.output_types())
data = nest.flatten(data)[0]
Expand All @@ -346,10 +347,10 @@ def _preprocess_transform(self, x, y, input_node_ids, blocks):
@staticmethod
def _is_model_inputs(node):
for block in node.in_blocks:
if not isinstance(block, processor.HyperPreprocessor):
if not isinstance(block, preprocessor.Preprocessor):
return False
for block in node.out_blocks:
if not isinstance(block, processor.HyperPreprocessor):
if not isinstance(block, preprocessor.Preprocessor):
return True
return False

Expand All @@ -364,7 +365,7 @@ def save_preprocessors(self, path):
return
preprocessors = {}
for block in self._blocks:
if isinstance(block, processor.HyperPreprocessor):
if isinstance(block, preprocessor.Preprocessor):
preprocessors[block.name] = block.get_weights()
utils.pickle_to_file(preprocessors, path)

Expand Down
6 changes: 3 additions & 3 deletions autokeras/hypermodel/hyperblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from autokeras.hypermodel import block
from autokeras.hypermodel import node
from autokeras.hypermodel import processor
from autokeras.hypermodel import preprocessor


class HyperBlock(block.Block):
Expand Down Expand Up @@ -65,10 +65,10 @@ def build(self, hp, inputs=None):
if not isinstance(input_node, node.TextNode):
raise ValueError('The input_node should be a TextNode.')
if vectorizer == 'ngram':
output_node = processor.TextToNgramVector()(output_node)
output_node = preprocessor.TextToNgramVector()(output_node)
output_node = block.DenseBlock()(output_node)
else:
output_node = processor.TextToIntSequence()(output_node)
output_node = preprocessor.TextToIntSequence()(output_node)
output_node = block.EmbeddingBlock(
pretraining=self.pretraining)(output_node)
output_node = block.ConvBlock(separable=True)(output_node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,9 @@
from autokeras.hypermodel import block


class HyperPreprocessor(block.Block):
class Preprocessor(block.Block):
"""Hyper preprocessing block base class."""

# TODO: Implement save and load, Since each trial they may be fit with different
# data because the preprocessors before the current preprocessor may change. So
# they need to be saved and loaded for prediction, otherwise the prediction
# cannot be done.

# TODO: It needs to know if it is in fit mode or predict mode. the behavior may
# be different. e.g. Image Augmentation should only augment the dataset when in
# fit mode.

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._hp = None
Expand Down Expand Up @@ -53,11 +44,12 @@ def update(self, x):
"""
raise NotImplementedError

def transform(self, x):
def transform(self, x, fit=False):
"""Incrementally fit the preprocessor with a single training instance.
# Arguments
x: EagerTensor. A single instance in the training dataset.
fit: Boolean. Whether it is in fit mode.
Returns:
A transformed instanced which can be converted to a tf.Tensor.
Expand Down Expand Up @@ -98,53 +90,7 @@ def set_weights(self, weights):
pass


class OneHotEncoder(object):
"""A class that can format data.
This class provides ways to transform data's classification label into
vector.
# Arguments
data: The input data
num_classes: The number of classes in the classification problem.
labels: The number of labels.
label_to_vec: Mapping from label to vector.
int_to_label: Mapping from int to label.
"""

def __init__(self):
"""Initialize a OneHotEncoder"""
self.data = None
self.num_classes = 0
self.labels = None
self.label_to_vec = {}
self.int_to_label = {}

def fit(self, data):
"""Create mapping from label to vector, and vector to label."""
data = np.array(data).flatten()
self.labels = set(data)
self.num_classes = len(self.labels)
for index, label in enumerate(self.labels):
vec = np.array([0] * self.num_classes)
vec[index] = 1
self.label_to_vec[label] = vec
self.int_to_label[index] = label

def transform(self, data):
"""Get vector for every element in the data array."""
data = np.array(data)
if len(data.shape) > 1:
data = data.flatten()
return np.array(list(map(lambda x: self.label_to_vec[x], data)))

def inverse_transform(self, data):
"""Get label for every element in data."""
return np.array(list(map(lambda x: self.int_to_label[x],
np.argmax(np.array(data), axis=1))))


class Normalize(HyperPreprocessor):
class Normalize(Preprocessor):
""" Perform basic image transformation and augmentation.
# Arguments
Expand Down Expand Up @@ -174,7 +120,7 @@ def finalize(self):
square_mean = np.mean(self.square_sum / self.count, axis=axis)
self.std = np.sqrt(square_mean - np.square(self.mean))

def transform(self, x):
def transform(self, x, fit=False):
""" Transform the test data, perform normalization.
# Arguments
Expand Down Expand Up @@ -210,7 +156,7 @@ def set_weights(self, weights):
self._shape = weights['_shape']


class TextToIntSequence(HyperPreprocessor):
class TextToIntSequence(Preprocessor):
"""Convert raw texts to sequences of word indices."""

def __init__(self, max_len=None, **kwargs):
Expand All @@ -227,7 +173,7 @@ def update(self, x):
if self.max_len is None:
self._max_len = max(self._max_len, len(sequence))

def transform(self, x):
def transform(self, x, fit=False):
sentence = nest.flatten(x)[0].numpy().decode('utf-8')
sequence = self._tokenizer.texts_to_sequences(sentence)[0]
sequence = tf.keras.preprocessing.sequence.pad_sequences(
Expand All @@ -253,7 +199,7 @@ def set_weights(self, weights):
self._tokenizer = weights['_tokenizer']


class TextToNgramVector(HyperPreprocessor):
class TextToNgramVector(Preprocessor):
"""Convert raw texts to n-gram vectors."""
# TODO: Implement save and load.

Expand Down Expand Up @@ -288,7 +234,7 @@ def finalize(self):
k=min(self._max_features, data.shape[1]))
self.selector.fit(data, self.labels)

def transform(self, x):
def transform(self, x, fit=False):
sentence = nest.flatten(x)[0].numpy().decode('utf-8')
data = self._vectorizer.transform([sentence]).toarray()
if self.selector:
Expand All @@ -301,3 +247,20 @@ def output_types(self):
@property
def output_shape(self):
return self._shape

def get_weights(self):
return {'_vectorizer': self._vectorizer,
'selector': self.selector,
'labels': self.labels,
'_max_features': self._max_features,
'_texts': self._texts,
'_shape': self._shape}

def set_weights(self, weights):
self._vectorizer = weights['_vectorizer']
self.selector = weights['selector']
self.labels = weights['labels']
self._max_features = weights['_max_features']
self._vectorizer.max_features = self._max_features
self._texts = weights['_texts']
self._shape = weights['_shape']
6 changes: 6 additions & 0 deletions autokeras/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def _save_preprocessors(self, trial_id, base_directory='.'):
def get_best_trials(self, num_trials=1):
return super()._get_best_trials(num_trials)

def load_trial(self, trial):
self.hypermodel.hyper_build(trial.hyperparameters)
filename = '%s-preprocessors' % trial.trial_id
path = os.path.join(trial.directory, filename)
self.hypermodel.load_preprocessors(path)


class RandomSearch(AutoTuner, kerastuner.RandomSearch):
"""KerasTuner RandomSearch with preprocessing layer tuning."""
Expand Down
46 changes: 46 additions & 0 deletions autokeras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,49 @@ def pickle_from_file(path):
def pickle_to_file(obj, path):
"""Save the pickle file to the specified path."""
pickle.dump(obj, open(path, 'wb'))


class OneHotEncoder(object):
"""A class that can format data.
This class provides ways to transform data's classification label into
vector.
# Arguments
data: The input data
num_classes: The number of classes in the classification problem.
labels: The number of labels.
label_to_vec: Mapping from label to vector.
int_to_label: Mapping from int to label.
"""

def __init__(self):
"""Initialize a OneHotEncoder"""
self.data = None
self.num_classes = 0
self.labels = None
self.label_to_vec = {}
self.int_to_label = {}

def fit(self, data):
"""Create mapping from label to vector, and vector to label."""
data = np.array(data).flatten()
self.labels = set(data)
self.num_classes = len(self.labels)
for index, label in enumerate(self.labels):
vec = np.array([0] * self.num_classes)
vec[index] = 1
self.label_to_vec[label] = vec
self.int_to_label[index] = label

def transform(self, data):
"""Get vector for every element in the data array."""
data = np.array(data)
if len(data.shape) > 1:
data = data.flatten()
return np.array(list(map(lambda x: self.label_to_vec[x], data)))

def inverse_transform(self, data):
"""Get label for every element in data."""
return np.array(list(map(lambda x: self.int_to_label[x],
np.argmax(np.array(data), axis=1))))

0 comments on commit 6b42b51

Please sign in to comment.