changed dropout_rate to dropout
achaar authored and haifeng-jin committed Jul 13, 2020
1 parent 1fa743d commit ff3cf11
Showing 4 changed files with 73 additions and 73 deletions.
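
The commit renames the `dropout_rate` keyword argument, and the corresponding tuned hyperparameter name, to `dropout` across DenseBlock, ConvBlock, Transformer, Embedding, and the classification, regression, and segmentation heads. A minimal usage sketch of the new keyword, adapted from the Transformer docstring in this diff (`ak.TextInput` and `ak.TextToIntSequence` are assumed from the AutoKeras 1.x functional API and are not part of this change):

```python
import autokeras as ak
from tensorflow.keras import losses

text_input = ak.TextInput()
output_node = ak.TextToIntSequence()(text_input)
output_node = ak.Transformer(embedding_dim=32,
                             pretraining='none',
                             num_heads=2,
                             dense_dim=32,
                             dropout=0.25)(output_node)   # previously dropout_rate=0.25
output_node = ak.SpatialReduction(reduction_type='global_avg')(output_node)
output_node = ak.DenseBlock(num_layers=1, use_batchnorm=False)(output_node)
output_node = ak.ClassificationHead(
    loss=losses.SparseCategoricalCrossentropy(),
    dropout=0.25)(output_node)                            # previously dropout_rate=0.25
clf = ak.AutoModel(inputs=text_input, outputs=output_node, max_trials=2)
```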
82 changes: 41 additions & 41 deletions autokeras/blocks/basic.py
@@ -25,26 +25,26 @@ class DenseBlock(block_module.Block):
If left unspecified, it will be tuned automatically.
use_bn: Boolean. Whether to use BatchNormalization layers.
If left unspecified, it will be tuned automatically.
- dropout_rate: Float. The dropout rate for the layers.
+ dropout: Float. The dropout rate for the layers.
If left unspecified, it will be tuned automatically.
"""

def __init__(self,
num_layers: Optional[int] = None,
use_batchnorm: Optional[bool] = None,
- dropout_rate: Optional[float] = None,
+ dropout: Optional[float] = None,
**kwargs):
super().__init__(**kwargs)
self.num_layers = num_layers
self.use_batchnorm = use_batchnorm
- self.dropout_rate = dropout_rate
+ self.dropout = dropout

def get_config(self):
config = super().get_config()
config.update({
'num_layers': self.num_layers,
'use_batchnorm': self.use_batchnorm,
- 'dropout_rate': self.dropout_rate})
+ 'dropout': self.dropout})
return config

def build(self, hp, inputs=None):
@@ -58,10 +58,10 @@ def build(self, hp, inputs=None):
use_batchnorm = self.use_batchnorm
if use_batchnorm is None:
use_batchnorm = hp.Boolean('use_batchnorm', default=False)
- if self.dropout_rate is not None:
- dropout_rate = self.dropout_rate
+ if self.dropout is not None:
+ dropout = self.dropout
else:
- dropout_rate = hp.Choice('dropout_rate', [0.0, 0.25, 0.5], default=0)
+ dropout = hp.Choice('dropout', [0.0, 0.25, 0.5], default=0)

for i in range(num_layers):
units = hp.Choice(
@@ -72,8 +72,8 @@ def build(self, hp, inputs=None):
if use_batchnorm:
output_node = layers.BatchNormalization()(output_node)
output_node = layers.ReLU()(output_node)
- if dropout_rate > 0:
- output_node = layers.Dropout(dropout_rate)(output_node)
+ if dropout > 0:
+ output_node = layers.Dropout(dropout)(output_node)
return output_node


@@ -169,7 +169,7 @@ class ConvBlock(block_module.Block):
unspecified, it will be tuned automatically.
separable: Boolean. Whether to use separable conv layers.
If left unspecified, it will be tuned automatically.
- dropout_rate: Float. Between 0 and 1. The dropout rate for after the
+ dropout: Float. Between 0 and 1. The dropout rate for after the
convolutional layers. If left unspecified, it will be tuned
automatically.
"""
@@ -180,15 +180,15 @@ def __init__(self,
num_layers: Optional[int] = None,
max_pooling: Optional[bool] = None,
separable: Optional[bool] = None,
- dropout_rate: Optional[float] = None,
+ dropout: Optional[float] = None,
**kwargs):
super().__init__(**kwargs)
self.kernel_size = kernel_size
self.num_blocks = num_blocks
self.num_layers = num_layers
self.max_pooling = max_pooling
self.separable = separable
- self.dropout_rate = dropout_rate
+ self.dropout = dropout

def get_config(self):
config = super().get_config()
@@ -198,7 +198,7 @@ def get_config(self):
'num_layers': self.num_layers,
'max_pooling': self.max_pooling,
'separable': self.separable,
- 'dropout_rate': self.dropout_rate})
+ 'dropout': self.dropout})
return config

def build(self, hp, inputs=None):
@@ -230,10 +230,10 @@ def build(self, hp, inputs=None):
max_pooling = hp.Boolean('max_pooling', default=True)
pool = layer_utils.get_max_pooling(input_node.shape)

- if self.dropout_rate is not None:
- dropout_rate = self.dropout_rate
+ if self.dropout is not None:
+ dropout = self.dropout
else:
- dropout_rate = hp.Choice('dropout_rate', [0.0, 0.25, 0.5], default=0)
+ dropout = hp.Choice('dropout', [0.0, 0.25, 0.5], default=0)

for i in range(num_blocks):
for j in range(num_layers):
@@ -249,8 +249,8 @@ def build(self, hp, inputs=None):
kernel_size - 1,
padding=self._get_padding(kernel_size - 1,
output_node))(output_node)
- if dropout_rate > 0:
- output_node = layers.Dropout(dropout_rate)(output_node)
+ if dropout > 0:
+ output_node = layers.Dropout(dropout)(output_node)
return output_node

@staticmethod
@@ -374,12 +374,12 @@ class Transformer(block_module.Block):
pretraining='none',
num_heads=2,
dense_dim=32,
- dropout_rate = 0.25)(output_node)
+ dropout = 0.25)(output_node)
output_node = ak.SpatialReduction(reduction_type='global_avg')(output_node)
output_node = ak.DenseBlock(num_layers=1, use_batchnorm = False)(output_node)
output_node = ak.ClassificationHead(
loss=losses.SparseCategoricalCrossentropy(),
- dropout_rate = 0.25)(output_node)
+ dropout = 0.25)(output_node)
clf = ak.AutoModel(inputs=text_input, outputs=output_node, max_trials=2)
```
# Arguments
@@ -394,7 +394,7 @@ class Transformer(block_module.Block):
it will be tuned automatically.
dense_dim: Int. The output dimension of the Feed-Forward Network. If left
unspecified, it will be tuned automatically.
- dropout_rate: Float. Between 0 and 1. If left unspecified, it will be
+ dropout: Float. Between 0 and 1. If left unspecified, it will be
tuned automatically.
"""

@@ -404,15 +404,15 @@ def __init__(self,
embedding_dim: Optional[int] = None,
num_heads: Optional[int] = None,
dense_dim: Optional[int] = None,
- dropout_rate: Optional[int] = None,
+ dropout: Optional[int] = None,
**kwargs):
super().__init__(**kwargs)
self.max_features = max_features
self.pretraining = pretraining
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.dense_dim = dense_dim
- self.dropout_rate = dropout_rate
+ self.dropout = dropout

def get_config(self):
config = super().get_config()
@@ -422,7 +422,7 @@ def get_config(self):
'embedding_dim': self.embedding_dim,
'num_heads': self.num_heads,
'dense_dim': self.dense_dim,
- 'dropout_rate': self.dropout_rate})
+ 'dropout': self.dropout})
return config

def build(self, hp, inputs=None):
@@ -449,9 +449,9 @@ def build(self, hp, inputs=None):
dense_dim = self.dense_dim or hp.Choice('dense_dim',
[128, 256, 512, 1024, 2048],
default=2048)
- dropout_rate = self.dropout_rate or hp.Choice('dropout_rate',
- [0.0, 0.25, 0.5],
- default=0)
+ dropout = self.dropout or hp.Choice('dropout',
+ [0.0, 0.25, 0.5],
+ default=0)

ffn = tf.keras.Sequential(
[layers.Dense(dense_dim, activation="relu"),
@@ -460,22 +460,22 @@

layernorm1 = layers.LayerNormalization(epsilon=1e-6)
layernorm2 = layers.LayerNormalization(epsilon=1e-6)
- dropout1 = layers.Dropout(dropout_rate)
- dropout2 = layers.Dropout(dropout_rate)
+ dropout1 = layers.Dropout(dropout)
+ dropout2 = layers.Dropout(dropout)
# Token and Position Embeddings
input_node = nest.flatten(inputs)[0]
token_embedding = Embedding(max_features=self.max_features,
pretraining=pretraining,
embedding_dim=embedding_dim,
- dropout_rate=dropout_rate).build(hp, input_node)
+ dropout=dropout).build(hp, input_node)
maxlen = input_node.shape[-1]
batch_size = tf.shape(input_node)[0]
positions = self.pos_array_funct(maxlen, batch_size)
position_embedding = Embedding(max_features=maxlen,
pretraining=pretraining,
embedding_dim=embedding_dim,
- dropout_rate=dropout_rate).build(hp,
- positions)
+ dropout=dropout).build(hp,
+ positions)
output_node = tf.keras.layers.Add()([token_embedding,
position_embedding])
attn_output = MultiHeadSelfAttention(
@@ -626,29 +626,29 @@ class Embedding(block_module.Block):
model), 'glove', 'fasttext' or 'word2vec'. Use pretrained word embedding.
If left unspecified, it will be tuned automatically.
embedding_dim: Int. If left unspecified, it will be tuned automatically.
- dropout_rate: Float. The dropout rate for after the Embedding layer.
+ dropout: Float. The dropout rate for after the Embedding layer.
If left unspecified, it will be tuned automatically.
"""

def __init__(self,
max_features: int = 20001,
pretraining: Optional[str] = None,
embedding_dim: Optional[int] = None,
- dropout_rate: Optional[float] = None,
+ dropout: Optional[float] = None,
**kwargs):
super().__init__(**kwargs)
self.max_features = max_features
self.pretraining = pretraining
self.embedding_dim = embedding_dim
- self.dropout_rate = dropout_rate
+ self.dropout = dropout

def get_config(self):
config = super().get_config()
config.update({
'max_features': self.max_features,
'pretraining': self.pretraining,
'embedding_dim': self.embedding_dim,
- 'dropout_rate': self.dropout_rate})
+ 'dropout': self.dropout})
return config

def build(self, hp, inputs=None):
@@ -678,10 +678,10 @@ def build(self, hp, inputs=None):
# input_length=input_node.shape[1],
# trainable=True)
output_node = layer(input_node)
- if self.dropout_rate is not None:
- dropout_rate = self.dropout_rate
+ if self.dropout is not None:
+ dropout = self.dropout
else:
- dropout_rate = hp.Choice('dropout_rate', [0.0, 0.25, 0.5], default=0.25)
- if dropout_rate > 0:
- output_node = layers.Dropout(dropout_rate)(output_node)
+ dropout = hp.Choice('dropout', [0.0, 0.25, 0.5], default=0.25)
+ if dropout > 0:
+ output_node = layers.Dropout(dropout)(output_node)
return output_node
42 changes: 21 additions & 21 deletions autokeras/blocks/heads.py
@@ -34,7 +34,7 @@ class ClassificationHead(head_module.Head):
loss: A Keras loss function. Defaults to use `binary_crossentropy` or
`categorical_crossentropy` based on the number of classes.
metrics: A list of Keras metrics. Defaults to use 'accuracy'.
- dropout_rate: Float. The dropout rate for the layers.
+ dropout: Float. The dropout rate for the layers.
If left unspecified, it will be tuned automatically.
"""

@@ -43,11 +43,11 @@ def __init__(self,
multi_label: bool = False,
loss: Optional[types.LossType] = None,
metrics: Optional[types.MetricsType] = None,
- dropout_rate: Optional[float] = None,
+ dropout: Optional[float] = None,
**kwargs):
self.num_classes = num_classes
self.multi_label = multi_label
- self.dropout_rate = dropout_rate
+ self.dropout = dropout
if metrics is None:
metrics = ['accuracy']
if loss is None:
@@ -68,7 +68,7 @@ def get_config(self):
config.update({
'num_classes': self.num_classes,
'multi_label': self.multi_label,
- 'dropout_rate': self.dropout_rate})
+ 'dropout': self.dropout})
return config

def build(self, hp, inputs=None):
@@ -81,13 +81,13 @@ def build(self, hp, inputs=None):
if len(output_node.shape) > 2:
output_node = reduction.SpatialReduction().build(hp, output_node)

- if self.dropout_rate is not None:
- dropout_rate = self.dropout_rate
+ if self.dropout is not None:
+ dropout = self.dropout
else:
- dropout_rate = hp.Choice('dropout_rate', [0.0, 0.25, 0.5], default=0)
+ dropout = hp.Choice('dropout', [0.0, 0.25, 0.5], default=0)

- if dropout_rate > 0:
- output_node = layers.Dropout(dropout_rate)(output_node)
+ if dropout > 0:
+ output_node = layers.Dropout(dropout)(output_node)
output_node = layers.Dense(self.output_shape[-1])(output_node)
if isinstance(self.loss, tf.keras.losses.BinaryCrossentropy):
output_node = layers.Activation(activations.sigmoid,
@@ -119,29 +119,29 @@ class RegressionHead(head_module.Head):
multi_label: Boolean. Defaults to False.
loss: A Keras loss function. Defaults to use `mean_squared_error`.
metrics: A list of Keras metrics. Defaults to use `mean_squared_error`.
- dropout_rate: Float. The dropout rate for the layers.
+ dropout: Float. The dropout rate for the layers.
If left unspecified, it will be tuned automatically.
"""

def __init__(self,
output_dim: Optional[int] = None,
loss: types.LossType = 'mean_squared_error',
metrics: Optional[types.MetricsType] = None,
- dropout_rate: Optional[float] = None,
+ dropout: Optional[float] = None,
**kwargs):
if metrics is None:
metrics = ['mean_squared_error']
super().__init__(loss=loss,
metrics=metrics,
**kwargs)
self.output_dim = output_dim
- self.dropout_rate = dropout_rate
+ self.dropout = dropout

def get_config(self):
config = super().get_config()
config.update({
'output_dim': self.output_dim,
- 'dropout_rate': self.dropout_rate})
+ 'dropout': self.dropout})
return config

def build(self, hp, inputs=None):
@@ -155,12 +155,12 @@ def build(self, hp, inputs=None):
input_node = inputs[0]
output_node = input_node

- dropout_rate = self.dropout_rate or hp.Choice('dropout_rate',
- [0.0, 0.25, 0.5],
- default=0)
+ dropout = self.dropout or hp.Choice('dropout',
+ [0.0, 0.25, 0.5],
+ default=0)

- if dropout_rate > 0:
- output_node = layers.Dropout(dropout_rate)(output_node)
+ if dropout > 0:
+ output_node = layers.Dropout(dropout)(output_node)
output_node = reduction.Flatten().build(hp, output_node)
output_node = layers.Dense(self.output_shape[-1],
name=self.name)(output_node)
@@ -191,20 +191,20 @@ class SegmentationHead(ClassificationHead):
loss: A Keras loss function. Defaults to use `binary_crossentropy` or
`categorical_crossentropy` based on the number of classes.
metrics: A list of Keras metrics. Defaults to use 'accuracy'.
- dropout_rate: Float. The dropout rate for the layers.
+ dropout: Float. The dropout rate for the layers.
If left unspecified, it will be tuned automatically.
"""

def __init__(self,
num_classes: Optional[int] = None,
loss: Optional[types.LossType] = None,
metrics: Optional[types.MetricsType] = None,
- dropout_rate: Optional[float] = None,
+ dropout: Optional[float] = None,
**kwargs):
super().__init__(loss=loss,
metrics=metrics,
num_classes=num_classes,
- dropout_rate=dropout_rate,
+ dropout=dropout,
**kwargs)

def build(self, hp, inputs):
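
Every block touched by this commit keeps the same fixed-versus-tuned pattern: a caller-supplied value is used directly, otherwise the rate is registered as a tunable choice, now under the hyperparameter name 'dropout' instead of 'dropout_rate', so tuner state recorded under the old name will not carry over. A standalone sketch of that pattern, assuming the KerasTuner package; the helper name `resolve_dropout` is hypothetical and not part of AutoKeras:

```python
import keras_tuner as kt  # 2020-era releases ship as `kerastuner` instead


def resolve_dropout(dropout, hp):
    """Return a caller-fixed dropout rate, or register it as a tunable choice."""
    if dropout is not None:
        return dropout
    # The hyperparameter is now recorded as 'dropout' rather than 'dropout_rate'.
    return hp.Choice('dropout', [0.0, 0.25, 0.5], default=0.0)


hp = kt.HyperParameters()
print(resolve_dropout(None, hp))   # 0.0  -- falls back to the tuned choice's default
print(resolve_dropout(0.25, hp))   # 0.25 -- fixed by the caller, nothing is tuned
```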
