In [1]:
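# Tutorial from https://machinelearningmastery.com/grid-search-hyperparameters-deep-learning-models-python-keras/
# Note: keras.wrappers.scikit_learn only exists in older Keras releases; it was
# deprecated and later removed in favour of the separate SciKeras package.
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier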

Using TensorFlow backend.


## 1. How to Use Keras Models in scikit-learn

Keras models can be used in scikit-learn by wrapping them with the KerasClassifier class (for classification) or the KerasRegressor class (for regression).

To use these wrappers, define a function that creates, compiles and returns your Keras Sequential model, then pass this function itself (not a model it has already built) to the build_fn argument when constructing a KerasClassifier.
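Passing a function rather than a ready-made model lets the wrapper build a fresh, untrained model every time it is fitted (in the old keras.wrappers.scikit_learn code, fit() calls build_fn), so repeated fits such as cross-validation folds never share weights. The plain-Python sketch below illustrates only that idea; FactoryWrapper, MeanModel and create_toy_model are hypothetical stand-ins, not Keras classes, and the real KerasClassifier does considerably more.

```python
class MeanModel:
    """Hypothetical toy model: always predicts the mean of the training targets."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


class FactoryWrapper:
    """Hypothetical sketch of the build_fn pattern (not the real KerasClassifier)."""

    def __init__(self, build_fn):
        # Keep the factory function itself; no model exists yet.
        self.build_fn = build_fn
        self.model = None

    def fit(self, X, y):
        # Call the factory on every fit, so each fit starts from a fresh model.
        self.model = self.build_fn()
        self.model.fit(X, y)
        return self

    def predict(self, X):
        return self.model.predict(X)


def create_toy_model():
    return MeanModel()


wrapper = FactoryWrapper(build_fn=create_toy_model)
first = wrapper.fit([[0], [1]], [0.0, 2.0]).model
second = wrapper.fit([[0], [1]], [4.0, 6.0]).model
print(first is second)              # False: a new model was built on the second fit
print(wrapper.predict([[5], [7]]))  # [5.0, 5.0]
```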

In [2]:
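def create_model():
    model = Sequential()
    # Hidden layer: 8 ReLU units reading 4 input features
    model.add(Dense(8, input_dim=4, activation='relu'))
    # Output layer: one softmax probability for each of the 3 classes
    model.add(Dense(3, activation='softmax'))
    # The wrapper expects build_fn to return an already compiled model
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model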

The constructor for the KerasClassifier class can take default arguments that are passed on to the calls to **model.fit()**, such as the number of epochs and the batch size.
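To see how a constructor argument such as epochs ends up in model.fit(), here is a minimal sketch of the routing idea, assuming (as in the old Keras wrapper) that arguments are matched against the target function's signature and that values passed directly to fit() override the constructor defaults. filter_params and fake_fit are hypothetical helpers, not Keras APIs; fake_fit only borrows keyword names from Keras' Model.fit().

```python
import inspect


def filter_params(fn, params, override=None):
    """Hypothetical helper: keep the entries of `params` that `fn` accepts,
    then apply any per-call overrides on top."""
    accepted = inspect.signature(fn).parameters
    selected = {name: value for name, value in params.items() if name in accepted}
    selected.update(override or {})
    return selected


# Hypothetical stand-in using the same keyword names as Keras' Model.fit().
def fake_fit(x=None, y=None, batch_size=None, epochs=1, verbose=1):
    return {"batch_size": batch_size, "epochs": epochs}


# Arguments given to the wrapper's constructor, e.g. KerasClassifier(..., epochs=10).
constructor_params = {"epochs": 10, "batch_size": 16, "dropout_rate": 0.2}

fit_kwargs = filter_params(fake_fit, constructor_params)
print(fit_kwargs)  # {'epochs': 10, 'batch_size': 16}  (dropout_rate is not a fit argument)

# A value passed directly at fit time overrides the constructor default.
print(filter_params(fake_fit, constructor_params, override={"epochs": 3}))
# {'epochs': 3, 'batch_size': 16}
```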

In [3]:
model = KerasClassifier(build_fn=create_model, epochs=10)

The KerasClassifier constructor can also take new arguments that are passed on to your custom **create_model()** function. These arguments must also appear in the signature of **create_model()**, each with a default value, so that the function still works when no value is supplied. This makes it easy to search for the best number of neurons, the best activation function, and so on: loop over a list of candidate values, build and evaluate one model per value, and compare the results after the loop.
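The loop described above can be sketched in plain Python. Everything here is hypothetical: create_model_stub returns a description instead of a compiled Keras model, evaluate is a toy scoring formula with no real meaning (in practice you would fit the KerasClassifier and measure accuracy on held-out data), and build_with is a simplified check that, like the old Keras wrapper, rejects argument names the builder does not declare (the real wrapper also accepted names from fit(), predict() and evaluate()).

```python
import inspect


def create_model_stub(dropout_rate=0.0, neurons=8):
    # Hypothetical stand-in for create_model(): describes the network
    # instead of building a Keras model.
    return {"dropout_rate": dropout_rate, "neurons": neurons}


def build_with(build_fn, **params):
    """Call build_fn with params, rejecting names its signature does not declare."""
    accepted = inspect.signature(build_fn).parameters
    unknown = [name for name in params if name not in accepted]
    if unknown:
        raise ValueError(f"{unknown[0]} is not a legal parameter")
    return build_fn(**params)


def evaluate(candidate):
    # Toy score for illustration only; NOT a real accuracy measurement.
    return candidate["neurons"] / 10 - candidate["dropout_rate"]


results = []
for neurons in [4, 8, 16]:
    for dropout_rate in [0.0, 0.2]:
        candidate = build_with(create_model_stub, neurons=neurons, dropout_rate=dropout_rate)
        results.append((evaluate(candidate), neurons, dropout_rate))

# Compare all candidates after the loop and keep the highest score.
best_score, best_neurons, best_dropout = max(results)
print(best_neurons, best_dropout)  # 16 0.0
```

For more than one or two hyperparameters, scikit-learn's GridSearchCV automates this kind of loop and adds cross-validation.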

In [4]:
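from keras.layers import Dropout

def create_model(dropout_rate=0.0):
    model = Sequential()
    model.add(Dense(8, input_dim=4, activation='relu'))
    # Use the value received from KerasClassifier(dropout_rate=...)
    model.add(Dropout(dropout_rate))
    model.add(Dense(3, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model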

In [5]:
model = KerasClassifier(build_fn=create_model, dropout_rate=0.2)

You can learn more about the scikit-learn wrapper in the [Keras API documentation](https://keras.io/scikit-learn-api/).