
Add conditions to the search space
haifeng-jin committed Jul 31, 2020
1 parent 71ee98a commit 13aa31b
Showing 4 changed files with 248 additions and 122 deletions.
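
The changes below all apply one pattern: when a hyperparameter is left unspecified by the user, it is registered with hp.Choice (or hp.Boolean), and the sub-blocks that depend on it are built inside hp.conditional_scope, so their hyperparameters are only active in trials where the parent hyperparameter actually takes that value. A minimal sketch of the pattern using Keras Tuner's HyperParameters directly (the names 'block_type', 'units', and 'filters' here are illustrative, not taken from this commit):

import kerastuner

hp = kerastuner.HyperParameters()

block_type = hp.Choice('block_type', ['dense', 'conv'])
with hp.conditional_scope('block_type', [block_type]):
    # Hyperparameters registered inside the scope are tied to the parent
    # value, so the tuner does not vary them in trials where they are
    # inactive.
    if block_type == 'dense':
        units = hp.Int('units', min_value=32, max_value=512, step=32)
    else:
        filters = hp.Int('filters', min_value=16, max_value=256, step=16)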
120 changes: 62 additions & 58 deletions autokeras/blocks/reduction.py
@@ -22,6 +22,11 @@
from autokeras.utils import layer_utils
from autokeras.utils import utils

+REDUCTION_TYPE = 'reduction_type'
+FLATTEN = 'flatten'
+GLOBAL_MAX = 'global_max'
+GLOBAL_AVG = 'global_avg'


def shape_compatible(shape1, shape2):
if len(shape1) != len(shape2):
@@ -53,10 +58,6 @@ def build(self, hp, inputs=None):
if len(inputs) == 1:
return inputs

-        merge_type = self.merge_type or hp.Choice('merge_type',
-                                                  ['add', 'concatenate'],
-                                                  default='add')
-
if not all([shape_compatible(input_node.shape, inputs[0].shape) for
input_node in inputs]):
new_inputs = []
@@ -67,12 +68,20 @@ def build(self, hp, inputs=None):
# TODO: Even inputs have different shape[-1], they can still be Add(
# ) after another layer. Check if the inputs are all of the same
# shape
-        if all([input_node.shape == inputs[0].shape for input_node in inputs]):
+        if self._inputs_same_shape(inputs):
+            merge_type = self.merge_type or hp.Choice(
+                'merge_type', ['add', 'concatenate'], default='add')
            if merge_type == 'add':
-                return layers.Add(inputs)
+                return layers.Add()(inputs)

        return layers.Concatenate()(inputs)

+    def _inputs_same_shape(self, inputs):
+        for input_node in inputs:
+            if input_node.shape.as_list() != inputs[0].shape.as_list():
+                return False
+        return True


class Flatten(block_module.Block):
"""Flatten the input tensor with Keras Flatten layer."""
@@ -86,23 +95,23 @@ def build(self, hp, inputs=None):
return input_node


-class SpatialReduction(block_module.Block):
-    """Reduce the dimension of a spatial tensor, e.g. image, to a vector.
-
-    # Arguments
-        reduction_type: String. 'flatten', 'global_max' or 'global_avg'.
-            If left unspecified, it will be tuned automatically.
-    """
+class Reduction(block_module.Block):

    def __init__(self, reduction_type: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.reduction_type = reduction_type

    def get_config(self):
        config = super().get_config()
-        config.update({'reduction_type': self.reduction_type})
+        config.update({REDUCTION_TYPE: self.reduction_type})
        return config

+    def global_max(self, input_node):
+        raise NotImplementedError
+
+    def global_avg(self, input_node):
+        raise NotImplementedError

def build(self, hp, inputs=None):
inputs = nest.flatten(inputs)
utils.validate_num_inputs(inputs, 1)
@@ -113,62 +122,57 @@ def build(self, hp, inputs=None):
if len(output_node.shape) <= 2:
return output_node

-        reduction_type = self.reduction_type or hp.Choice('reduction_type',
-                                                          ['flatten',
-                                                           'global_max',
-                                                           'global_avg'],
-                                                          default='global_avg')
-        if reduction_type == 'flatten':
+        if self.reduction_type is None:
+            reduction_type = hp.Choice(
+                REDUCTION_TYPE, [FLATTEN, GLOBAL_MAX, GLOBAL_AVG])
+            with hp.conditional_scope(REDUCTION_TYPE, [reduction_type]):
+                return self._build_block(hp, output_node, reduction_type)
+        else:
+            return self._build_block(hp, output_node, self.reduction_type)
+
+    def _build_block(self, hp, output_node, reduction_type):
+        if reduction_type == FLATTEN:
            output_node = Flatten().build(hp, output_node)
-        elif reduction_type == 'global_max':
-            output_node = layer_utils.get_global_max_pooling(
-                output_node.shape)()(output_node)
-        elif reduction_type == 'global_avg':
-            output_node = layer_utils.get_global_average_pooling(
-                output_node.shape)()(output_node)
+        elif reduction_type == GLOBAL_MAX:
+            output_node = self.global_max(output_node)
+        elif reduction_type == GLOBAL_AVG:
+            output_node = self.global_avg(output_node)
        return output_node


-class TemporalReduction(block_module.Block):
-    """Reduce the dimension of a temporal tensor, e.g. output of RNN, to a vector.
+class SpatialReduction(Reduction):
+    """Reduce the dimension of a spatial tensor, e.g. image, to a vector.

    # Arguments
-        reduction_type: String. 'flatten', 'global_max' or 'global_avg'. If left
-            unspecified, it will be tuned automatically.
+        reduction_type: String. 'flatten', 'global_max' or 'global_avg'.
+            If left unspecified, it will be tuned automatically.
    """

    def __init__(self, reduction_type: Optional[str] = None, **kwargs):
-        super().__init__(**kwargs)
-        self.reduction_type = reduction_type
+        super().__init__(reduction_type, **kwargs)

-    def get_config(self):
-        config = super().get_config()
-        config.update({'reduction_type': self.reduction_type})
-        return config
+    def global_max(self, input_node):
+        return layer_utils.get_global_max_pooling(
+            input_node.shape)()(input_node)

-    def build(self, hp, inputs=None):
-        inputs = nest.flatten(inputs)
-        utils.validate_num_inputs(inputs, 1)
-        input_node = inputs[0]
-        output_node = input_node
+    def global_avg(self, input_node):
+        return layer_utils.get_global_average_pooling(
+            input_node.shape)()(input_node)

-        # No need to reduce.
-        if len(output_node.shape) <= 2:
-            return output_node
-
-        reduction_type = self.reduction_type or hp.Choice('reduction_type',
-                                                          ['flatten',
-                                                           'global_max',
-                                                           'global_avg'],
-                                                          default='global_avg')
-
-        if reduction_type == 'flatten':
-            output_node = Flatten().build(hp, output_node)
-        elif reduction_type == 'global_max':
-            output_node = tf.math.reduce_max(output_node, axis=-2)
-        elif reduction_type == 'global_avg':
-            output_node = tf.math.reduce_mean(output_node, axis=-2)
-        elif reduction_type == 'global_min':
-            output_node = tf.math.reduce_min(output_node, axis=-2)
-
-        return output_node
+class TemporalReduction(Reduction):
+    """Reduce the dimension of a temporal tensor, e.g. output of RNN, to a vector.
+
+    # Arguments
+        reduction_type: String. 'flatten', 'global_max' or 'global_avg'. If left
+            unspecified, it will be tuned automatically.
+    """
+
+    def __init__(self, reduction_type: Optional[str] = None, **kwargs):
+        super().__init__(reduction_type, **kwargs)
+
+    def global_max(self, input_node):
+        return tf.math.reduce_max(input_node, axis=-2)
+
+    def global_avg(self, input_node):
+        return tf.math.reduce_mean(input_node, axis=-2)
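
With the split above, Reduction.build owns the conditional tuning logic and the subclasses only supply the pooling primitives. A usage sketch of the resulting blocks (the input shapes and the standalone HyperParameters object are illustrative, not part of the diff):

import kerastuner
import tensorflow as tf
from autokeras.blocks import reduction

hp = kerastuner.HyperParameters()

# Spatial: global average pooling collapses both image axes.
image_node = tf.keras.Input(shape=(32, 32, 3))
out = reduction.SpatialReduction(reduction_type='global_avg').build(hp, image_node)
# out has shape (None, 3)

# Temporal: reduce_max collapses the time axis (axis=-2).
sequence_node = tf.keras.Input(shape=(10, 64))
out = reduction.TemporalReduction(reduction_type='global_max').build(hp, sequence_node)
# out has shape (None, 64)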
136 changes: 72 additions & 64 deletions autokeras/blocks/wrapper.py
@@ -19,6 +19,16 @@
from autokeras.blocks import reduction
from autokeras.engine import block as block_module

+BLOCK_TYPE = 'block_type'
+RESNET = 'resnet'
+XCEPTION = 'xception'
+VANILLA = 'vanilla'
+NORMALIZE = 'normalize'
+AUGMENT = 'augment'
+TRANSFORMER = 'transformer'
+MAX_TOKENS = 'max_tokens'
+NGRAM = 'ngram'
+

class ImageBlock(block_module.Block):
"""Block for image data.
@@ -47,35 +57,43 @@ def __init__(self,

def get_config(self):
config = super().get_config()
-        config.update({'block_type': self.block_type,
-                       'normalize': self.normalize,
-                       'augment': self.augment})
+        config.update({BLOCK_TYPE: self.block_type,
+                       NORMALIZE: self.normalize,
+                       AUGMENT: self.augment})
        return config

+    def _build_block(self, hp, output_node, block_type):
+        if block_type == RESNET:
+            return basic.ResNetBlock().build(hp, output_node)
+        elif block_type == XCEPTION:
+            return basic.XceptionBlock().build(hp, output_node)
+        elif block_type == VANILLA:
+            return basic.ConvBlock().build(hp, output_node)
+
    def build(self, hp, inputs=None):
        input_node = nest.flatten(inputs)[0]
        output_node = input_node

-        block_type = self.block_type or hp.Choice('block_type',
-                                                  ['resnet', 'xception', 'vanilla'],
-                                                  default='vanilla')
-
-        normalize = self.normalize
-        if normalize is None:
-            normalize = hp.Boolean('normalize', default=False)
-        augment = self.augment
-        if augment is None:
-            augment = hp.Boolean('augment', default=False)
-        if normalize:
-            output_node = preprocessing.Normalization().build(hp, output_node)
-        if augment:
-            output_node = preprocessing.ImageAugmentation().build(hp, output_node)
-        if block_type == 'resnet':
-            output_node = basic.ResNetBlock().build(hp, output_node)
-        elif block_type == 'xception':
-            output_node = basic.XceptionBlock().build(hp, output_node)
-        elif block_type == 'vanilla':
-            output_node = basic.ConvBlock().build(hp, output_node)
+        if self.normalize is None and hp.Boolean(NORMALIZE):
+            with hp.conditional_scope(NORMALIZE, [True]):
+                output_node = preprocessing.Normalization().build(hp, output_node)
+        elif self.normalize:
+            output_node = preprocessing.Normalization().build(hp, output_node)
+
+        if self.augment is None and hp.Boolean(AUGMENT):
+            with hp.conditional_scope(AUGMENT, [True]):
+                output_node = preprocessing.ImageAugmentation().build(
+                    hp, output_node)
+        elif self.augment:
+            output_node = preprocessing.ImageAugmentation().build(hp, output_node)
+
+        if self.block_type is None:
+            block_type = hp.Choice(BLOCK_TYPE, [RESNET, XCEPTION, VANILLA])
+            with hp.conditional_scope(BLOCK_TYPE, [block_type]):
+                output_node = self._build_block(hp, output_node, block_type)
+        else:
+            output_node = self._build_block(hp, output_node, self.block_type)

        return output_node


@@ -107,40 +125,44 @@ def __init__(self,
def get_config(self):
config = super().get_config()
config.update({
-            'block_type': self.block_type,
-            'max_tokens': self.max_tokens,
+            BLOCK_TYPE: self.block_type,
+            MAX_TOKENS: self.max_tokens,
            'pretraining': self.pretraining})
        return config

    def build(self, hp, inputs=None):
        input_node = nest.flatten(inputs)[0]
        output_node = input_node
-        block_type = self.block_type or hp.Choice('block_type',
-                                                  ['vanilla',
-                                                   'transformer',
-                                                   'ngram'],
-                                                  default='vanilla')
-        max_tokens = self.max_tokens or hp.Choice('max_tokens',
-                                                  [500, 5000, 20000],
-                                                  default=5000)
-        if block_type == 'ngram':
+        if self.block_type is None:
+            block_type = hp.Choice(BLOCK_TYPE, [VANILLA, TRANSFORMER, NGRAM])
+            with hp.conditional_scope(BLOCK_TYPE, [block_type]):
+                output_node = self._build_block(hp, output_node, block_type)
+        else:
+            output_node = self._build_block(hp, output_node, self.block_type)
+        return output_node
+
+    def _build_block(self, hp, output_node, block_type):
+        max_tokens = self.max_tokens or hp.Choice(
+            MAX_TOKENS, [500, 5000, 20000], default=5000)
+        if block_type == NGRAM:
            output_node = preprocessing.TextToNgramVector(
                max_tokens=max_tokens).build(hp, output_node)
-            output_node = basic.DenseBlock().build(hp, output_node)
-        else:
-            output_node = preprocessing.TextToIntSequence(
-                max_tokens=max_tokens).build(hp, output_node)
-            if block_type == 'transformer':
-                output_node = basic.Transformer(max_features=max_tokens + 1,
-                                                pretraining=self.pretraining
-                                                ).build(hp, output_node)
-            else:
-                output_node = basic.Embedding(
-                    max_features=max_tokens + 1,
-                    pretraining=self.pretraining).build(hp, output_node)
-                output_node = basic.ConvBlock().build(hp, output_node)
-                output_node = reduction.SpatialReduction().build(hp, output_node)
-                output_node = basic.DenseBlock().build(hp, output_node)
+            return basic.DenseBlock().build(hp, output_node)
+        output_node = preprocessing.TextToIntSequence(
+            max_tokens=max_tokens).build(hp, output_node)
+        if block_type == TRANSFORMER:
+            output_node = basic.Transformer(
+                max_features=max_tokens + 1,
+                pretraining=self.pretraining,
+            ).build(hp, output_node)
+        else:
+            output_node = basic.Embedding(
+                max_features=max_tokens + 1,
+                pretraining=self.pretraining,
+            ).build(hp, output_node)
+            output_node = basic.ConvBlock().build(hp, output_node)
+            output_node = reduction.SpatialReduction().build(hp, output_node)
+            output_node = basic.DenseBlock().build(hp, output_node)
        return output_node


@@ -150,7 +172,6 @@ class StructuredDataBlock(block_module.Block):
# Arguments
categorical_encoding: Boolean. Whether to use the CategoricalToNumerical to
encode the categorical features to numerical features. Defaults to True.
-            If specified as None, it will be tuned automatically.
seed: Int. Random seed.
"""

@@ -181,28 +202,15 @@ def get_config(self):
'column_names': self.column_names})
return config

-    def build_categorical_encoding(self, hp, input_node):
-        output_node = input_node
-        categorical_encoding = self.categorical_encoding
-        if categorical_encoding is None:
-            categorical_encoding = hp.Choice('categorical_encoding',
-                                             [True, False],
-                                             default=True)
-        if categorical_encoding:
-            block = preprocessing.CategoricalToNumerical()
-            block.column_types = self.column_types
-            block.column_names = self.column_names
-            output_node = block.build(hp, output_node)
-        return output_node
-
-    def build_body(self, hp, input_node):
-        output_node = basic.DenseBlock().build(hp, input_node)
-        return output_node
-
-    def build(self, hp, inputs=None):
-        input_node = nest.flatten(inputs)[0]
-        output_node = self.build_categorical_encoding(hp, input_node)
-        output_node = self.build_body(hp, output_node)
-        return output_node
+    def build(self, hp, inputs=None):
+        input_node = nest.flatten(inputs)[0]
+        output_node = input_node
+        if self.categorical_encoding:
+            block = preprocessing.CategoricalToNumerical()
+            block.column_types = self.column_types
+            block.column_names = self.column_names
+            output_node = block.build(hp, output_node)
+        output_node = basic.DenseBlock().build(hp, output_node)
+        return output_node
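
The normalize/augment handling in ImageBlock above applies the same conditioning to Boolean flags: the Boolean hyperparameter is registered only when the user left the flag unset, and the preprocessing step is built under a conditional scope keyed on it. Distilled into a standalone helper (the function name and signature are hypothetical, for illustration only):

def maybe_apply(hp, flag_name, user_setting, apply_step, output_node):
    # hp is a kerastuner.HyperParameters instance supplied by the tuner.
    # user_setting is True/False when fixed by the user, or None to let
    # the tuner decide via a Boolean hyperparameter.
    if user_setting is None and hp.Boolean(flag_name):
        with hp.conditional_scope(flag_name, [True]):
            # Hyperparameters created by apply_step are active only in
            # trials where the Boolean is True.
            output_node = apply_step(hp, output_node)
    elif user_setting:
        output_node = apply_step(hp, output_node)
    return output_node

# e.g. the augment branch of ImageBlock.build above corresponds to:
#   output_node = maybe_apply(
#       hp, AUGMENT, self.augment,
#       lambda hp, node: preprocessing.ImageAugmentation().build(hp, node),
#       output_node)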

