Skip to content

Commit

Permalink
Merge pull request #7 from SaumyaTiwari01/master
Browse files Browse the repository at this point in the history
Trivial code correction to run mlp
  • Loading branch information
jliphard committed Nov 15, 2018
2 parents 73b2204 + b3f53ad commit de5b2dd
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions train.py
Expand Up @@ -148,7 +148,7 @@ def get_mnist_cnn():

return (nb_classes, batch_size, input_shape, x_train, x_test, y_train, y_test, epochs)

def compile_model_mlp(geneparam, nb_classes, input_shape):
def compile_model_mlp(genome, nb_classes, input_shape):
"""Compile a sequential model.
Args:
Expand All @@ -159,12 +159,12 @@ def compile_model_mlp(geneparam, nb_classes, input_shape):
"""
# Get our network parameters.
nb_layers = geneparam['nb_layers' ]
nb_neurons = geneparam['nb_neurons']
activation = geneparam['activation']
optimizer = geneparam['optimizer' ]
nb_layers = genome.geneparam['nb_layers' ]
nb_neurons = genome.nb_neurons()
activation = genome.geneparam['activation']
optimizer = genome.geneparam['optimizer' ]

logging.info("Architecture:%d,%s,%s,%d" % (nb_neurons, activation, optimizer, nb_layers))
logging.info("Architecture:%s,%s,%s,%d" % (str(nb_neurons), activation, optimizer, nb_layers))

model = Sequential()

Expand All @@ -173,9 +173,9 @@ def compile_model_mlp(geneparam, nb_classes, input_shape):

# Need input shape for first layer.
if i == 0:
model.add(Dense(nb_neurons, activation=activation, input_shape=input_shape))
model.add(Dense(nb_neurons[i], activation=activation, input_shape=input_shape))
else:
model.add(Dense(nb_neurons, activation=activation))
model.add(Dense(nb_neurons[i], activation=activation))

model.add(Dropout(0.2)) # hard-coded dropout for each layer

Expand Down Expand Up @@ -265,11 +265,11 @@ def train_and_score(genome, dataset):
logging.info("Compling Keras model")

if dataset == 'cifar10_mlp':
model = compile_model_mlp(geneparam, nb_classes, input_shape)
model = compile_model_mlp(genome, nb_classes, input_shape)
elif dataset == 'cifar10_cnn':
model = compile_model_cnn(genome, nb_classes, input_shape)
elif dataset == 'mnist_mlp':
model = compile_model_mlp(geneparam, nb_classes, input_shape)
model = compile_model_mlp(genome, nb_classes, input_shape)
elif dataset == 'mnist_cnn':
model = compile_model_cnn(genome, nb_classes, input_shape)

Expand Down

0 comments on commit de5b2dd

Please sign in to comment.