Skip to content

Commit

Permalink
Merge a8943f1 into 2fcc50b
Browse files Browse the repository at this point in the history
  • Loading branch information
droidadroit committed Nov 16, 2018
2 parents 2fcc50b + a8943f1 commit cbec9df
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions autokeras/nn/generator.py
Expand Up @@ -7,7 +7,23 @@


class NetworkGenerator:
"""The base class for generating a network.
It can be used to generate a CNN or Multi-Layer Perceptron.
Attributes:
n_output_node: Number of output nodes in the network.
input_shape: A tuple to represent the input shape.
"""
def __init__(self, n_output_node, input_shape):
"""Initialize the instance.
Sets the parameters `n_output_node` and `input_shape` for the instance.
Args:
n_output_node: An integer. Number of output nodes in the network.
input_shape: A tuple. Input shape of the network.
"""
self.n_output_node = n_output_node
self.input_shape = input_shape

Expand All @@ -17,7 +33,24 @@ def generate(self, model_len, model_width):


class CnnGenerator(NetworkGenerator):
"""A class to generate CNN.
Attributes:
n_dim: `len(self.input_shape) - 1`
conv: A class that represents `(n_dim-1)` dimensional convolution.
dropout: A class that represents `(n_dim-1)` dimensional dropout.
global_avg_pooling: A class that represents `(n_dim-1)` dimensional Global Average Pooling.
pooling: A class that represents `(n_dim-1)` dimensional pooling.
batch_norm: A class that represents `(n_dim-1)` dimensional batch normalization.
"""

def __init__(self, n_output_node, input_shape):
"""Initialize the instance.
Args:
n_output_node: An integer. Number of output nodes in the network.
input_shape: A tuple. Input shape of the network.
"""
super(CnnGenerator, self).__init__(n_output_node, input_shape)
self.n_dim = len(self.input_shape) - 1
if len(self.input_shape) > 4:
Expand All @@ -31,6 +64,12 @@ def __init__(self, n_output_node, input_shape):
self.batch_norm = get_batch_norm_class(self.n_dim)

def generate(self, model_len=Constant.MODEL_LEN, model_width=Constant.MODEL_WIDTH):
"""Generates a CNN.
Args:
model_len: An integer. Number of convolutional layers.
model_width: An integer. Number of filters for the convolutional layers.
"""
pooling_len = int(model_len / 4)
graph = Graph(self.input_shape, False)
temp_input_channel = self.input_shape[-1]
Expand All @@ -53,12 +92,30 @@ def generate(self, model_len=Constant.MODEL_LEN, model_width=Constant.MODEL_WIDT


class MlpGenerator(NetworkGenerator):
"""A class to generate Multi-Layer Perceptron.
"""

def __init__(self, n_output_node, input_shape):
"""Initialize the instance.
Args:
n_output_node: An integer. Number of output nodes in the network.
input_shape: A tuple. Input shape of the network. If it is 1D, ensure the value is appended by a comma
in the tuple.
"""
super(MlpGenerator, self).__init__(n_output_node, input_shape)
if len(self.input_shape) > 1:
raise ValueError('The input dimension is too high.')

def generate(self, model_len=Constant.MLP_MODEL_LEN, model_width=Constant.MLP_MODEL_WIDTH):
"""Generates a Multi-Layer Perceptron.
Args:
model_len: An integer. Number of hidden layers.
model_width: An integer or a list of integers of length `model_len`. If it is a list, it represents the
number of nodes in each hidden layer. If it is an integer, all hidden layers have nodes equal to this
value.
"""
if type(model_width) is list and not len(model_width) == model_len:
raise ValueError('The length of \'model_width\' does not match \'model_len\'')
elif type(model_width) is int:
Expand Down

0 comments on commit cbec9df

Please sign in to comment.