Skip to content

Commit

Permalink
[MRG] Simple version Image Pipeline Implemented (#650)
Browse files Browse the repository at this point in the history
* legacy

* new files

* Update demo.py

* test tuner bug fixed

* classes and functions in the demo created

* network implemented not tested

* travis to ignore legacy tests

* test connectedHyperparameter

* test fixed

* some basic blocks implemented

* still debugging

* basic test for hypergraph passed

* merge layer implemented

* hyper_graph fully tested

* more args added to auto model

* test fixed

* local changes

* test auto_model

* auto_model fit signature changed for validation data

* super classes extending object

* change n_ to num_

* refactored automodel to extend hypermodel and removed tuner from signature

* rename HyperNode to Node. TextInput added

* rename HyperGraph to GraphAutoModel extending AutoModel

* refactor hyperhead, removed tensor heads

* removed unnecessary blocks, rename build_output to build

* changed some functions and attributes to private

* remove trails from AutoModel public API

* test cases changed accordingly

* import modules instead of objects

* tuner deleted from AutoModel constructor

* change trails to num_trials

* use the same quote sign

* revised auto pipeline docs

* removed compile from AutoModel

* loss and metrics moved to hyper heads

* do not flatten by default in hyperheads

* inputs and outputs down to GraphAutoModel

* changed AutoModel

* name_scope changed to tf 2.0 and moved to hyperparameters and hypermodel

* remove HierarchicalHyperParameters

* renaming the tests

* auto_pipeline

* image module tested

* image regressor tested

* removed some requirements

* update tf to 2.0 beta

* Refactor (#646)

* super classes extending object

* change n_ to num_

* refactored automodel to extend hypermodel and removed tuner from signature

* rename HyperNode to Node. TextInput added

* rename HyperGraph to GraphAutoModel extending AutoModel

* refactor hyperhead, removed tensor heads

* removed unnecessary blocks, rename build_output to build

* changed some functions and attributes to private

* remove trails from AutoModel public API

* test cases changed accordingly

* import modules instead of objects

* tuner deleted from AutoModel constructor

* change trails to num_trials

* use the same quote sign

* revised auto pipeline docs

* removed compile from AutoModel

* loss and metrics moved to hyper heads

* do not flatten by default in hyperheads

* inputs and outputs down to GraphAutoModel

* changed AutoModel

* name_scope changed to tf 2.0 and moved to hyperparameters and hypermodel

* remove HierarchicalHyperParameters

* renaming some variables and make private some members

* demo update

* remove legacy

* test fixed

* changed setup.py

* Update hyper_block.py

Improve the DenseBlock: add some other layer categories to it.

* Update hyper_block.py

* pep8 style fix

* changed docstrings to use markdown

* renaming some of the variables and classes

* extracted check new and old search space to tuner base class

* fixing some pylint issues

* change the normalization to use mean and stddev only

* dependency

* dependency changed to keras-tuner

* pep8 formatting
  • Loading branch information
haifeng-jin committed Jun 25, 2019
1 parent 1e76e7f commit d6947d6
Show file tree
Hide file tree
Showing 20 changed files with 477 additions and 509 deletions.
1 change: 0 additions & 1 deletion autokeras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from autokeras.hypermodel import *
from autokeras.auto.auto_model import AutoModel
from autokeras.auto.preprocessor import image_augment
88 changes: 47 additions & 41 deletions autokeras/auto/auto_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from queue import Queue

import kerastuner
import numpy as np
import tensorflow as tf
from tensorflow.python.util import nest

from autokeras.hypermodel import hypermodel, hyper_head
from autokeras import layer_utils
from autokeras import tuner
from autokeras.hypermodel import hyper_head
from autokeras import layer_utils, const


class AutoModel(hypermodel.HyperModel):
class AutoModel(kerastuner.HyperModel):
""" A AutoModel should be an AutoML solution.
It contains the HyperModels and the Tuner.
Expand All @@ -28,8 +27,7 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.inputs = []
self.outputs = []
self.tuner = tuner.SequentialRandomSearch(self,
objective=self._get_metrics())
self.tuner = None

def build(self, hp):
raise NotImplementedError
Expand All @@ -44,8 +42,8 @@ def fit(self,
x = layer_utils.format_inputs(x, 'train_x')
y = layer_utils.format_inputs(y, 'train_y')

# TODO: Set the shapes only if they are not provided by the user
# when initiating the HyperHead or Block.
# TODO: Set the shapes only if they are not provided by the user when
# initiating the HyperHead or Block.
for x_input, input_node in zip(x, self.inputs):
input_node.shape = x_input.shape[1:]
for y_input, output_node in zip(y, self.outputs):
Expand All @@ -59,28 +57,25 @@ def fit(self,
(x, y), (x_val, y_val) = layer_utils.split_train_to_valid(x, y)
validation_data = x_val, y_val

self.tuner = kerastuner.RandomSearch(
hypermodel=self,
objective='val_loss',
max_trials=trials or const.Constant.NUM_TRAILS)

# TODO: allow early stop if epochs is not specified.
self.tuner.search(trials=trials,
x=x,
self.tuner.search(x=x,
y=y,
validation_data=validation_data,
**kwargs)

def predict(self, x, **kwargs):
"""Predict the output for a given testing data. """
return self.tuner.best_model.predict(x, **kwargs)

def _get_loss(self):
loss = nest.flatten([output_node.in_hypermodels[0].loss
for output_node in self.outputs
if isinstance(output_node.in_hypermodels[0],
hyper_head.HyperHead)])
return loss
return self.tuner.get_best_models(1)[0].predict(x, **kwargs)

def _get_metrics(self):
metrics = []
for metrics_list in [output_node.in_hypermodels[0].metrics
for output_node in self.outputs
for metrics_list in [output_node.in_hypermodels[0].metrics for
output_node in self.outputs
if isinstance(output_node.in_hypermodels[0],
hyper_head.HyperHead)]:
metrics += metrics_list
Expand Down Expand Up @@ -108,22 +103,25 @@ def build(self, hp):
node_id = self._node_to_id[input_node]
real_nodes[node_id] = input_node.build(hp)
for hypermodel in self._hypermodels:
outputs = hypermodel.build(
hp,
inputs=[real_nodes[self._node_to_id[input_node]]
for input_node in hypermodel.inputs])
temp_inputs = [real_nodes[self._node_to_id[input_node]]
for input_node in hypermodel.inputs]
outputs = hypermodel.build(hp,
inputs=temp_inputs)
outputs = layer_utils.format_inputs(outputs, hypermodel.name)
for output_node, real_output_node in zip(hypermodel.outputs, outputs):
for output_node, real_output_node in zip(hypermodel.outputs,
outputs):
real_nodes[self._node_to_id[output_node]] = real_output_node
model = tf.keras.Model([real_nodes[self._node_to_id[input_node]]
for input_node in self.inputs],
[real_nodes[self._node_to_id[output_node]]
for output_node in self.outputs])
model = tf.keras.Model(
[real_nodes[self._node_to_id[input_node]] for input_node in
self.inputs],
[real_nodes[self._node_to_id[output_node]] for output_node in
self.outputs])

# Specify hyperparameters from compile(...)
optimizer = hp.Choice('optimizer',
[tf.keras.optimizers.Adam,
tf.keras.optimizers.Adadelta,
tf.keras.optimizers.SGD])()
['adam',
'adadelta',
'sgd'])

model.compile(optimizer=optimizer,
metrics=self._get_metrics(),
Expand Down Expand Up @@ -155,8 +153,8 @@ def _build_network(self):
while not queue.empty():
input_node = queue.get()
for hypermodel in input_node.out_hypermodels:
# Check at least one output node of the hypermodel
# is in the interested nodes.
# Check at least one output node of the hypermodel is in the
# interested nodes.
if not any([output_node in self._node_to_id for output_node in
hypermodel.outputs]):
continue
Expand All @@ -171,7 +169,8 @@ def _build_network(self):
hypermodel = output_node.in_hypermodels[0]
hypermodel.output_shape = output_node.shape

def _search_network(self, input_node, outputs, in_stack_nodes, visited_nodes):
def _search_network(self, input_node, outputs, in_stack_nodes,
visited_nodes):
visited_nodes.add(input_node)
in_stack_nodes.add(input_node)

Expand All @@ -184,9 +183,7 @@ def _search_network(self, input_node, outputs, in_stack_nodes, visited_nodes):
if output_node in in_stack_nodes:
raise ValueError('The network has a cycle.')
if output_node not in visited_nodes:
self._search_network(output_node,
outputs,
in_stack_nodes,
self._search_network(output_node, outputs, in_stack_nodes,
visited_nodes)
if output_node in self._node_to_id.keys():
outputs_reached = True
Expand All @@ -205,9 +202,18 @@ def _add_hypermodel(self, hypermodel):
self._add_node(output_node)
for input_node in hypermodel.inputs:
if input_node not in self._node_to_id:
raise ValueError('A required input is missing for HyperModel {name}.'
.format(name=hypermodel.name))
raise ValueError(
'A required input is missing '
'for HyperModel {name}.'.format(
name=hypermodel.name))

def _add_node(self, input_node):
if input_node not in self._node_to_id:
self._node_to_id[input_node] = len(self._node_to_id)

def _get_loss(self):
loss = nest.flatten([output_node.in_hypermodels[0].loss
for output_node in self.outputs
if isinstance(output_node.in_hypermodels[0],
hyper_head.HyperHead)])
return loss
32 changes: 0 additions & 32 deletions autokeras/auto/auto_pipeline.py

This file was deleted.

65 changes: 65 additions & 0 deletions autokeras/auto/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import kerastuner
import tensorflow as tf

from autokeras import const
from autokeras.auto import processor, auto_model
from autokeras.hypermodel import hyper_block, hyper_node
from autokeras.hypermodel import hyper_head


class SupervisedImagePipeline(auto_model.AutoModel):
    """Base AutoModel for supervised learning on images.

    The pipeline normalizes the images, feeds them through an `ImageBlock`
    and finishes with a task-specific head that subclasses must assign to
    `self.head` before `fit` or `build` is called.

    # Attributes
        image_block: The `hyper_block.ImageBlock` producing image features.
        head: The hyper head defining output layer, loss and metrics.
            `None` in this base class; set by subclasses.
        normalizer: The `processor.Normalizer` fitted on the training images
            and re-applied to the images passed to `predict`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.image_block = hyper_block.ImageBlock()
        self.head = None
        self.normalizer = processor.Normalizer()

    def fit(self, x=None, y=None, trials=None, **kwargs):
        """Fit the normalizer on `x`, then search for the best model.

        # Arguments
            x: The training images.
            y: The training targets.
            trials: The number of trials for the search. Forwarded to
                `AutoModel.fit`, which falls back to its default when None.
            **kwargs: Any other arguments forwarded to `AutoModel.fit`.
        """
        self.normalizer.fit(x)
        self.inputs = [hyper_node.ImageInput()]
        # `trials` must be forwarded explicitly: it is a named parameter
        # here, so it is not part of `kwargs` and would otherwise be lost.
        super().fit(x=self.normalizer.transform(x),
                    y=y,
                    trials=trials,
                    **kwargs)

    def predict(self, x, **kwargs):
        """Predict on `x` after applying the normalization learned in `fit`.

        The model was trained on normalized images, so the same
        transformation has to be applied at inference time.

        # Arguments
            x: The images to predict on.
            **kwargs: Any other arguments forwarded to `AutoModel.predict`.

        # Returns
            The raw output of the best model found by the search.
        """
        return super().predict(self.normalizer.transform(x), **kwargs)

    def build(self, hp):
        """Build and compile the Keras model for one set of hyperparameters.

        # Arguments
            hp: The hyperparameters container supplied by the tuner.

        # Returns
            A compiled `tf.keras.Model`.
        """
        input_node = self.inputs[0].build(hp)
        output_node = self.image_block.build(hp, input_node)
        output_node = self.head.build(hp, output_node)
        model = tf.keras.Model(input_node, output_node)
        # The optimizer is itself a hyperparameter, chosen by name.
        optimizer = hp.Choice('optimizer',
                              ['adam',
                               'adadelta',
                               'sgd'])

        model.compile(optimizer=optimizer,
                      loss=self.head.loss,
                      metrics=self.head.metrics)

        return model


class ImageClassifier(SupervisedImagePipeline):
    """Image pipeline with a classification head and one-hot encoded labels.

    # Attributes
        head: A `hyper_head.ClassificationHead`.
        label_encoder: The `processor.OneHotEncoder` used to convert labels
            to one-hot vectors for training and back again for prediction.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.head = hyper_head.ClassificationHead()
        self.label_encoder = processor.OneHotEncoder()

    def fit(self, x=None, y=None, **kwargs):
        """Encode the labels, size the head accordingly and run the search.

        # Arguments
            x: The training images.
            y: The training labels.
            **kwargs: Any other arguments forwarded to the parent `fit`.
        """
        encoder = self.label_encoder
        encoder.fit(y)
        # One output unit per distinct label seen by the encoder.
        self.head.output_shape = (encoder.num_classes,)
        one_hot_y = encoder.transform(y)
        super().fit(x=x, y=one_hot_y, **kwargs)

    def predict(self, x, **kwargs):
        """Predict the labels of `x`.

        # Arguments
            x: The images to predict on.
            **kwargs: Any other arguments forwarded to the parent `predict`.

        # Returns
            The model outputs decoded back to labels by the label encoder.
        """
        scores = super().predict(x, **kwargs)
        return self.label_encoder.inverse_transform(scores)


class ImageRegressor(SupervisedImagePipeline):
    """Image pipeline with a regression head producing a single value.

    # Attributes
        head: A `hyper_head.RegressionHead`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.head = hyper_head.RegressionHead()

    def fit(self, x=None, y=None, **kwargs):
        """Set the head to a single output and run the search.

        # Arguments
            x: The training images.
            y: The training targets.
            **kwargs: Any other arguments forwarded to the parent `fit`.
        """
        single_output = (1,)
        self.head.output_shape = single_output
        super().fit(x=x, y=y, **kwargs)

    def predict(self, x, **kwargs):
        """Predict the targets of `x`.

        # Arguments
            x: The images to predict on.
            **kwargs: Any other arguments forwarded to the parent `predict`.

        # Returns
            The flattened output of the parent `predict`.
        """
        predictions = super().predict(x, **kwargs)
        return predictions.flatten()
78 changes: 78 additions & 0 deletions autokeras/auto/processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np


class OneHotEncoder(object):
    """Convert classification labels to one-hot vectors and back.

    # Attributes
        data: Placeholder that is initialized to None and not used by any
            method of this class.
        num_classes: Int. The number of distinct labels seen by `fit`.
        labels: Set of the distinct labels seen by `fit`.
        label_to_vec: Dict mapping each label to its one-hot vector.
        int_to_label: Dict mapping each class index to its label.
    """

    def __init__(self):
        """Initialize an unfitted OneHotEncoder."""
        self.data = None
        self.num_classes = 0
        self.labels = None
        self.label_to_vec = {}
        self.int_to_label = {}

    def fit(self, data):
        """Create the mappings between labels, class indices and vectors.

        Class indices follow the sorted order of the labels, so the encoding
        is the same on every run. Calling `fit` again discards the mappings
        of the previous call.

        # Arguments
            data: Array-like of labels. It is flattened before use.
        """
        flat = np.array(data).flatten()
        self.labels = set(flat)
        self.num_classes = len(self.labels)
        # Iterating a set directly gives an order that depends on hashing
        # (and changes between processes for strings), so sort the labels.
        try:
            ordered = sorted(self.labels)
        except TypeError:
            # Labels that cannot be compared with each other: keep set order.
            ordered = list(self.labels)
        # Start from empty mappings so a re-fit leaves no stale entries.
        self.label_to_vec = {}
        self.int_to_label = {}
        for index, label in enumerate(ordered):
            vec = np.zeros(self.num_classes, dtype=int)
            vec[index] = 1
            self.label_to_vec[label] = vec
            self.int_to_label[index] = label

    def transform(self, data):
        """Get the one-hot vector for every element in the data array.

        # Arguments
            data: Array-like of labels. It is flattened before use.

        # Returns
            Numpy array holding one one-hot row per label.

        # Raises
            KeyError: If a label was not seen by `fit`.
        """
        flat = np.array(data).flatten()
        return np.array([self.label_to_vec[label] for label in flat])

    def inverse_transform(self, data):
        """Get the label for every row in data.

        # Arguments
            data: 2-D array-like with one row of scores per sample. The
                position of the largest value in a row selects the label.

        # Returns
            Numpy array holding one label per row.
        """
        indices = np.argmax(np.array(data), axis=1)
        return np.array([self.int_to_label[index] for index in indices])


class Normalizer(object):
    """Standardize data with the mean and standard deviation learned by `fit`.

    The statistics are reduced over the first three axes of the data.
    NOTE(review): this presumably expects images laid out as
    (samples, height, width, channels), giving one mean and one standard
    deviation per channel - confirm against the callers.

    # Attributes
        mean: Numpy array. The mean computed by `fit`; None before `fit`.
        std: Numpy array. The standard deviation computed by `fit`; None
            before `fit`.
    """

    def __init__(self):
        self.mean = None
        self.std = None

    def fit(self, data):
        """Compute the mean and standard deviation of the data.

        # Arguments
            data: Numpy array with at least three dimensions.
        """
        self.mean = np.mean(data, axis=(0, 1, 2), keepdims=True).flatten()
        self.std = np.std(data, axis=(0, 1, 2), keepdims=True).flatten()

    def transform(self, data):
        """Normalize the data with the statistics learned by `fit`.

        # Arguments
            data: Numpy array. The data to be transformed.

        # Returns
            Numpy array. The data with the mean subtracted and divided by the
            standard deviation.
        """
        # A constant channel has a standard deviation of zero; divide by one
        # there instead, so the result is zero rather than NaN or inf.
        safe_std = np.where(self.std == 0, 1, self.std)
        return (data - self.mean) / safe_std
1 change: 1 addition & 0 deletions autokeras/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ class Constant(object):
VALIDATION_SET_SIZE = 0.08333
# TODO: Change it to random and configurable.
SEED = 42
BATCH_SIZE = 128
2 changes: 0 additions & 2 deletions autokeras/hypermodel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
from autokeras.hypermodel.hypermodel import HyperModel
# from autokeras.hypermodel.resnet_block import *

0 comments on commit d6947d6

Please sign in to comment.