In [1]:
# Use scikit-learn to grid search the batch size, number of epochs, and optimizer of a Keras model
import random

import numpy
from sklearn.model_selection import GridSearchCV
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier

Using TensorFlow backend.


In [2]:
# Function to create model, required for KerasClassifier
def create_model(optimizer='adam', input_dim=8, hidden_units=12):
    """Build and compile a small binary classifier for use with KerasClassifier.

    The network has one ReLU hidden layer followed by a single sigmoid output
    unit, and is compiled with binary cross-entropy loss and accuracy metric.

    Args:
        optimizer: Optimizer name (or instance) passed to ``model.compile``;
            the grid search varies this through KerasClassifier.
        input_dim: Number of input features per sample (defaults to 8).
        hidden_units: Width of the hidden ReLU layer (defaults to 12).

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential()
    model.add(Dense(hidden_units, input_dim=input_dim, activation='relu'))
    # One sigmoid unit gives a single per-sample score, matching the
    # binary_crossentropy loss used below.
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

In [3]:
# Read the comma-separated Pima Indians diabetes file into a float array;
# the grid-search cell treats columns 0-7 as features and column 8 as the label.
dataset = numpy.loadtxt(fname="./input/pima-indians-diabetes.csv", delimiter=",")

In [5]:
# Sanity check: number of rows (samples) loaded. Call-form print works on Python 2 and 3.
print(len(dataset))

768


In [4]:
# fix random seeds for reproducibility: seed both NumPy's and Python's global RNGs.
# NOTE(review): the TensorFlow backend keeps its own RNG, which is not seeded
# here, so scores may still vary slightly between runs — seed it too if exact
# repeatability is required.
seed = 7
numpy.random.seed(seed)
random.seed(seed)

# split into input (X) and output (Y) variables:
# columns 0-7 are the features, column 8 is the target.
X = dataset[:, 0:8]
Y = dataset[:, 8]

# Wrap the Keras model builder so scikit-learn can clone and fit it;
# verbose=0 silences Keras' own per-epoch training output.
model = KerasClassifier(build_fn=create_model, verbose=0)

# 3 epoch counts x 3 batch sizes x 7 optimizers = 63 candidate settings.
param_grid = dict(
    epochs=[5, 10, 15],
    batch_size=[5, 10, 20],
    optimizer=['SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adam', 'Adamax', 'Nadam'],
)

# cv=3 is pinned explicitly: GridSearchCV's default number of folds changed
# from 3 to 5 in scikit-learn 0.22, which would silently change both the
# runtime and the reported scores.
grid = GridSearchCV(
    estimator=model,
    param_grid=param_grid,
    cv=3,
    n_jobs=1,
    verbose=2,
)

grid_result = grid.fit(X, Y)

# summarize results
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))

Fitting 3 folds for each of 63 candidates, totalling 189 fits
[CV] epochs=5, optimizer=SGD, batch_size=5 ...........................
[CV] ............ epochs=5, optimizer=SGD, batch_size=5, total=   1.3s
[CV] epochs=5, optimizer=SGD, batch_size=5 ...........................


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    1.4s remaining:    0.0s


[CV] ............ epochs=5, optimizer=SGD, batch_size=5, total=   0.9s
[CV] epochs=5, optimizer=SGD, batch_size=5 ...........................
[CV] ............ epochs=5, optimizer=SGD, batch_size=5, total=   1.0s
[CV] epochs=5, optimizer=RMSprop, batch_size=5 .......................
[CV] ........ epochs=5, optimizer=RMSprop, batch_size=5, total=   1.0s
[CV] epochs=5, optimizer=RMSprop, batch_size=5 .......................
[CV] ........ epochs=5, optimizer=RMSprop, batch_size=5, total=   1.0s
[CV] epochs=5, optimizer=RMSprop, batch_size=5 .......................
[CV] ........ epochs=5, optimizer=RMSprop, batch_size=5, total=   1.0s
[CV] epochs=5, optimizer=Adagrad, batch_size=5 .......................
[CV] ........ epochs=5, optimizer=Adagrad, batch_size=5, total=   1.0s
[CV] epochs=5, optimizer=Adagrad, batch_size=5 .......................
[CV] ........ epochs=5, optimizer=Adagrad, batch_size=5, total=   1.0s
[CV] epochs=5, optimizer=Adagrad, batch_size=5 .......................
[CV] .

[Parallel(n_jobs=1)]: Done 189 out of 189 | elapsed:  9.1min finished


Best: 0.669271 using {'epochs': 15, 'optimizer': 'Nadam', 'batch_size': 20}
0.542969 (0.157321) with: {'epochs': 5, 'optimizer': 'SGD', 'batch_size': 5}
0.406250 (0.065907) with: {'epochs': 5, 'optimizer': 'RMSprop', 'batch_size': 5}
0.527344 (0.097003) with: {'epochs': 5, 'optimizer': 'Adagrad', 'batch_size': 5}
0.492188 (0.111082) with: {'epochs': 5, 'optimizer': 'Adadelta', 'batch_size': 5}
0.645833 (0.012890) with: {'epochs': 5, 'optimizer': 'Adam', 'batch_size': 5}
0.522135 (0.141251) with: {'epochs': 5, 'optimizer': 'Adamax', 'batch_size': 5}
0.656250 (0.025315) with: {'epochs': 5, 'optimizer': 'Nadam', 'batch_size': 5}
0.652344 (0.022999) with: {'epochs': 10, 'optimizer': 'SGD', 'batch_size': 5}
0.660156 (0.033603) with: {'epochs': 10, 'optimizer': 'RMSprop', 'batch_size': 5}
0.618490 (0.039879) with: {'epochs': 10, 'optimizer': 'Adagrad', 'batch_size': 5}
0.563802 (0.057262) with: {'epochs': 10, 'optimizer': 'Adadelta', 'batch_size': 5}
0.656250 (0.005524) with: {'epochs': 10, 