diff --git a/train.py b/train.py
index e6f9280..b9203fd 100755
--- a/train.py
+++ b/train.py
@@ -148,7 +148,7 @@ def get_mnist_cnn():
     return (nb_classes, batch_size, input_shape, x_train, x_test, y_train, y_test, epochs)
 
 
-def compile_model_mlp(geneparam, nb_classes, input_shape):
+def compile_model_mlp(genome, nb_classes, input_shape):
     """Compile a sequential model.
 
     Args:
@@ -159,12 +159,12 @@ def compile_model_mlp(geneparam, nb_classes, input_shape):
 
     """
     # Get our network parameters.
-    nb_layers  = geneparam['nb_layers' ]
-    nb_neurons = geneparam['nb_neurons']
-    activation = geneparam['activation']
-    optimizer  = geneparam['optimizer' ]
+    nb_layers  = genome.geneparam['nb_layers' ]
+    nb_neurons = genome.nb_neurons()
+    activation = genome.geneparam['activation']
+    optimizer  = genome.geneparam['optimizer' ]
 
-    logging.info("Architecture:%d,%s,%s,%d" % (nb_neurons, activation, optimizer, nb_layers))
+    logging.info("Architecture:%s,%s,%s,%d" % (str(nb_neurons), activation, optimizer, nb_layers))
 
     model = Sequential()
 
@@ -173,9 +173,9 @@ def compile_model_mlp(geneparam, nb_classes, input_shape):
 
         # Need input shape for first layer.
         if i == 0:
-            model.add(Dense(nb_neurons, activation=activation, input_shape=input_shape))
+            model.add(Dense(nb_neurons[i], activation=activation, input_shape=input_shape))
         else:
-            model.add(Dense(nb_neurons, activation=activation))
+            model.add(Dense(nb_neurons[i], activation=activation))
 
         model.add(Dropout(0.2))  # hard-coded dropout for each layer
 
@@ -265,11 +265,11 @@ def train_and_score(genome, dataset):
     logging.info("Compling Keras model")
 
     if dataset == 'cifar10_mlp':
-        model = compile_model_mlp(geneparam, nb_classes, input_shape)
+        model = compile_model_mlp(genome, nb_classes, input_shape)
     elif dataset == 'cifar10_cnn':
         model = compile_model_cnn(genome, nb_classes, input_shape)
     elif dataset == 'mnist_mlp':
-        model = compile_model_mlp(geneparam, nb_classes, input_shape)
+        model = compile_model_mlp(genome, nb_classes, input_shape)
     elif dataset == 'mnist_cnn':
         model = compile_model_cnn(genome, nb_classes, input_shape)
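
Note on the change: the hunks above pass the genome object into compile_model_mlp directly, so each hidden layer can take its own width from genome.nb_neurons() instead of one shared nb_neurons value. The sketch below is a minimal, self-contained illustration of that contract only; build_mlp, the SimpleNamespace stand-in for a real Genome instance, and the softmax/compile tail are assumptions for demonstration, not code from this patch.

```python
# Minimal sketch of the reworked compile_model_mlp contract: the genome owns its
# hyperparameters, and nb_neurons() returns one width per layer so each Dense
# layer is sized independently. The SimpleNamespace genome is illustrative only.
from types import SimpleNamespace

from keras.models import Sequential
from keras.layers import Dense, Dropout


def build_mlp(genome, nb_classes, input_shape):
    nb_layers  = genome.geneparam['nb_layers']
    nb_neurons = genome.nb_neurons()           # e.g. [64, 128, 32]
    activation = genome.geneparam['activation']
    optimizer  = genome.geneparam['optimizer']

    model = Sequential()
    for i in range(nb_layers):
        if i == 0:
            # First layer needs the input shape.
            model.add(Dense(nb_neurons[i], activation=activation, input_shape=input_shape))
        else:
            model.add(Dense(nb_neurons[i], activation=activation))
        model.add(Dropout(0.2))  # hard-coded dropout, as in train.py

    # Classifier head and compile settings are assumed here, not part of the diff.
    model.add(Dense(nb_classes, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model


# Hypothetical stand-in for a Genome instance, used only to exercise build_mlp.
fake_genome = SimpleNamespace(
    geneparam={'nb_layers': 3, 'activation': 'relu', 'optimizer': 'adam'},
    nb_neurons=lambda: [64, 128, 32],
)
model = build_mlp(fake_genome, nb_classes=10, input_shape=(784,))
```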