## Wrappers for the Scikit-Learn API

You can use Sequential Keras models (single-input only) as part of your Scikit-Learn workflow via the wrappers in the `keras.wrappers.scikit_learn` module.

There are two wrappers available:

* `keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)`, which implements the Scikit-Learn classifier interface;
* `keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params)`, which implements the Scikit-Learn regressor interface.

#### Arguments

* `build_fn`: a callable function or class instance
* `sk_params`: model parameters and fitting parameters

`build_fn` should construct, compile and return a Keras model, which will then be used to fit/predict. One of the following three values can be passed to `build_fn`:

1. A function
2. An instance of a class that implements the `__call__` method
3. `None`. This means you implement a class that inherits from either `KerasClassifier` or `KerasRegressor`. The `__call__` method of that subclass will then be treated as the default `build_fn`.
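
The dispatch between these three options can be sketched in plain Python. `ToyWrapper`, `build_function`, `Builder`, and `SubclassedWrapper` are hypothetical names, the builders return tuples instead of compiled Keras models, and unlike the real wrapper this sketch forwards every `sk_params` entry without filtering. The type checks follow what the Keras 2.x source appears to do, but treat this as an illustration rather than the library's code.

```python
import types


class ToyWrapper:
    """Hypothetical stand-in for the wrapper's build_fn dispatch (not the Keras class)."""

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def _resolve_build_fn(self):
        # Option 3: no build_fn, so the subclass's own __call__ builds the model.
        if self.build_fn is None:
            return self.__call__
        # Option 2: an object (not a plain function or bound method) with a __call__ method.
        if not isinstance(self.build_fn, (types.FunctionType, types.MethodType)):
            return self.build_fn.__call__
        # Option 1: a plain function.
        return self.build_fn

    def build(self):
        # Simplification: every sk_params entry is forwarded to the builder.
        return self._resolve_build_fn()(**self.sk_params)


def build_function(units=8):
    # Returns a description instead of a compiled Keras model, for illustration only.
    return ('function', units)


class Builder:
    def __call__(self, units=8):
        return ('instance', units)


class SubclassedWrapper(ToyWrapper):
    def __call__(self, units=8):
        return ('subclass', units)


print(ToyWrapper(build_function, units=16).build())  # ('function', 16)
print(ToyWrapper(Builder()).build())                  # ('instance', 8)
print(SubclassedWrapper(units=4).build())             # ('subclass', 4)
```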

`sk_params` takes both model parameters and fitting parameters. Legal model parameters are the arguments of `build_fn`. Note that, like all other estimators in scikit-learn, `build_fn` should provide default values for its arguments, so that you can create the estimator without passing any values to `sk_params`.
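
Selecting the model parameters out of `sk_params` amounts to keeping the keys that match `build_fn`'s argument names. A minimal sketch with `inspect.signature`; `filter_params` and `toy_build_model` are hypothetical names (the latter mirrors the notebook's `build_model` below but returns its configuration instead of a model), and the helper does not treat a `**kwargs` catch-all as accepting arbitrary names.

```python
import inspect


def filter_params(fn, params):
    """Keep only the entries of `params` that `fn` accepts as named arguments."""
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in params.items() if name in accepted}


def toy_build_model(optimizer='rmsprop', dense_dims=32):
    # Hypothetical stand-in: returns its config instead of a compiled Keras model.
    return {'optimizer': optimizer, 'dense_dims': dense_dims}


sk_params = {'dense_dims': 16, 'epochs': 2, 'batch_size': 64}

# Only dense_dims reaches the builder; epochs and batch_size are left for fit.
print(toy_build_model(**filter_params(toy_build_model, sk_params)))
# {'optimizer': 'rmsprop', 'dense_dims': 16}

# Because every argument has a default, an empty sk_params still works.
print(toy_build_model(**filter_params(toy_build_model, {})))
# {'optimizer': 'rmsprop', 'dense_dims': 32}
```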

`sk_params` can also accept parameters for calling the `fit`, `predict`, `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`). Fitting (predicting) parameters are selected in the following order:

1. Values passed as keyword arguments to the `fit`, `predict`, `predict_proba`, and `score` methods
2. Values passed to `sk_params`
3. The default values of the corresponding `keras.models.Sequential` methods (`fit`, `predict`, `predict_proba`, and `evaluate`, which backs `score`)
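
The lookup order can be sketched as a dictionary merge in which each level overrides the one before it. `resolve_call_args` is a hypothetical helper, not the wrapper's code: as far as the Keras 2.x source shows, level 3 happens implicitly there, because an argument found at neither level 1 nor level 2 is simply not passed and Keras applies its own default. `toy_fit` stands in for `Sequential.fit` with made-up defaults.

```python
import inspect


def resolve_call_args(method, sk_params, **call_kwargs):
    """Hypothetical sketch of the three-level lookup for fit/predict arguments."""
    params = inspect.signature(method).parameters
    # Level 3: the method's own defaults.
    resolved = {name: p.default for name, p in params.items()
                if p.default is not inspect.Parameter.empty}
    # Level 2: sk_params entries the method accepts (model params such as dense_dims are skipped).
    resolved.update({k: v for k, v in sk_params.items() if k in params})
    # Level 1: keyword arguments passed directly to the call win.
    resolved.update(call_kwargs)
    return resolved


def toy_fit(batch_size=32, epochs=1, verbose=1):
    """Stand-in for Sequential.fit; these default values are illustrative only."""


sk_params = {'epochs': 2, 'dense_dims': 16}
print(resolve_call_args(toy_fit, sk_params))            # {'batch_size': 32, 'epochs': 2, 'verbose': 1}
print(resolve_call_args(toy_fit, sk_params, epochs=5))  # {'batch_size': 32, 'epochs': 5, 'verbose': 1}
```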

When using scikit-learn's grid search API (for example `sklearn.model_selection.GridSearchCV`), the legal tunable parameters are those you could pass to `sk_params`, including fitting parameters. In other words, you can use a grid search to find the best `batch_size` or `epochs` as well as the best model parameters.
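
Fitting parameters are tunable because the wrapper reports every `sk_params` entry through `get_params` and stores whatever `set_params` receives back into `sk_params`. A grid search clones the estimator from `get_params()` and then applies one candidate with `set_params`. The sketch below imitates that round trip; `ToyWrapper` and `toy_build_model` are hypothetical, and the method bodies are a simplified reading of the Keras 2.x wrapper rather than its exact code.

```python
import copy


class ToyWrapper:
    """Hypothetical wrapper showing why every sk_params key is tunable."""

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self, deep=True):
        # Report build_fn plus every sk_params entry as a hyperparameter.
        params = copy.deepcopy(self.sk_params)
        params['build_fn'] = self.build_fn
        return params

    def set_params(self, **params):
        # A candidate such as {'epochs': 3, 'dense_dims': 16} lands in sk_params.
        self.sk_params.update(params)
        return self


def toy_build_model(optimizer='rmsprop', dense_dims=32):
    return {'optimizer': optimizer, 'dense_dims': dense_dims}


base = ToyWrapper(toy_build_model, epochs=2)
# Clone from get_params(), then apply one grid candidate to the clone.
candidate = ToyWrapper(**base.get_params()).set_params(epochs=3, dense_dims=16)
print(candidate.sk_params)  # {'epochs': 3, 'dense_dims': 16}
print(base.sk_params)       # {'epochs': 2}
```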

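```python
from keras.models import Sequential
from keras.layers import Dense


def build_model(optimizer='rmsprop', dense_dims=32):
    # Every argument has a default, so the wrapper can build a model without sk_params.
    model = Sequential()
    model.add(Dense(dense_dims, activation='relu', input_dim=100))
    model.add(Dense(1, activation='sigmoid'))  # one sigmoid unit for binary labels
    # KerasClassifier.score reports accuracy, so the model must compute it.
    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
```

```
Using TensorFlow backend.
```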

```python
from keras.wrappers.scikit_learn import KerasClassifier

# epochs=2 is a fitting parameter passed through sk_params
keras_classifier = KerasClassifier(build_fn=build_model, epochs=2)
```

```python
import numpy as np

# Random features and random binary labels: there is nothing real to learn here
data = np.random.random((1000, 100))
labels = np.random.randint(2, size=(1000, 1))

keras_classifier.fit(data, labels)
```

```
Epoch 1/2
Epoch 2/2
<keras.callbacks.History at 0x112b699d0>
```
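
The labels above are a `(1000, 1)` column of 0s and 1s. In Keras 2.x, `KerasClassifier.fit` appears to store the sorted distinct labels as `classes_` and train on each label's index within that array, and `predict` maps predicted indices back through `classes_` (an implementation detail that may differ between versions). A numpy sketch of that mapping; `encode_labels` is a hypothetical helper:

```python
import numpy as np


def encode_labels(y):
    """Map arbitrary class labels to integer indices (hypothetical helper)."""
    y = np.asarray(y)
    if y.ndim == 2 and y.shape[1] == 1:
        y = y.ravel()                      # accept (n, 1) column vectors like `labels` above
    classes = np.unique(y)                 # sorted distinct labels, playing the role of classes_
    indices = np.searchsorted(classes, y)  # position of each label within classes
    return classes, indices


classes, indices = encode_labels(np.array([[1], [0], [1], [1]]))
print(classes, indices)  # [0 1] [1 0 1 1]

classes, indices = encode_labels(['dog', 'cat', 'dog'])
print(classes[indices])  # ['dog' 'cat' 'dog']  (indices map back to the labels)
```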

```python
keras_classifier.predict_proba(data[:2])
```

```
array([[ 0.49557364,  0.50442636],
       [ 0.47164112,  0.52835888]], dtype=float32)
```
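
The model ends in a single sigmoid unit, yet `predict_proba` returns two columns that sum to 1. This is consistent with the wrapper stacking `1 - p` next to the sigmoid output `p`, giving P(class 0) and P(class 1) per row, which is what the Keras 2.x source appears to do for one-column outputs. A numpy sketch of that step, taking the second column shown above as the sigmoid outputs:

```python
import numpy as np

# Positive-class probabilities from a single sigmoid unit, shape (n, 1).
p = np.array([[0.50442636], [0.52835888]], dtype=np.float32)

# Put P(class 0) = 1 - p next to P(class 1) = p, giving shape (n, 2).
probs = np.hstack([1 - p, p])
print(probs.shape)                          # (2, 2)
print(np.allclose(probs.sum(axis=1), 1.0))  # True
```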

```python
from sklearn.model_selection import GridSearchCV

# Tune a fitting parameter (epochs) and a build_fn argument (dense_dims) together
gs = GridSearchCV(keras_classifier, {'epochs': [2, 3], 'dense_dims': [16, 32]})
```
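
With no `scoring` argument, `GridSearchCV` ranks candidates with the estimator's own `score` method. For `KerasClassifier` in Keras 2.x, `score` evaluates the model and returns its accuracy metric, raising an error if the model was not compiled with one; that is why `build_model` passes `metrics=['accuracy']`. The sketch isolates that selection step. `pick_accuracy` is a hypothetical name, and the numbers in the usage line are placeholder inputs, not results.

```python
def pick_accuracy(metrics_names, outputs):
    """Return the accuracy entry from an evaluate()-style result (hypothetical sketch).

    `metrics_names` and `outputs` mirror model.metrics_names and the value
    returned by model.evaluate(...).
    """
    if not isinstance(outputs, list):
        outputs = [outputs]  # a bare loss value when no metrics were compiled
    for name, value in zip(metrics_names, outputs):
        if name in ('acc', 'accuracy'):  # the metric name varies across Keras versions
            return value
    raise ValueError("compile the model with metrics=['accuracy'] to use score()")


print(pick_accuracy(['loss', 'acc'], [0.69, 0.5]))  # 0.5
```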

```python
gs.fit(data, labels)
```

```
Epoch 1/2
Epoch 2/2
 32/666 [>.............................] - ETA: 0s
Epoch 1/2
Epoch 2/2
 32/667 [>.............................] - ETA: 0s
Epoch 1/2
Epoch 2/2
 32/667 [>.............................] - ETA: 0s
Epoch 1/3
Epoch 2/3
Epoch 3/3
 32/666 [>.............................] - ETA: 0s
Epoch 1/3
Epoch 2/3
Epoch 3/3
 32/667 [>.............................] - ETA: 0s
Epoch 1/3
Epoch 2/3
Epoch 3/3
 32/667 [>.............................] - ETA: 0s
Epoch 1/2
Epoch 2/2
 32/666 [>.............................] - ETA: 0s
Epoch 1/2
Epoch 2/2
 32/667 [>.............................] - ETA: 0s
Epoch 1/2
Epoch 2/2
 32/667 [>.............................] - ETA: 0s
Epoch 1/3
Epoch 2/3
Epoch 3/3
 32/666 [>.............................] - ETA: 0s
Epoch 1/3
Epoch 2/3
Epoch 3/3
 32/667 [>.............................] - ETA: 0s
Epoch 1/3
Epoch 2/3
Epoch 3/3
 32/667 [>.............................] - ETA: 0s
Epoch 1/2
Epoch 2/2
```

```
GridSearchCV(cv=None, error_score='raise',
       estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x106a7b390>,
       fit_params={}, iid=True, n_jobs=1,
       param_grid={'epochs': [2, 3], 'dense_dims': [16, 32]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
       scoring=None, verbose=0)
```
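
The training log above contains 13 runs. Assuming `cv=None` meant 3-fold cross-validation in the scikit-learn version used here (the default became 5 folds in scikit-learn 0.22), both the count and the order of the epoch settings can be reproduced: the grid is enumerated with its keys sorted, each candidate is trained once per fold, and `refit=True` adds one final run with the winning parameters, `{'dense_dims': 32, 'epochs': 2}`. A plain-Python sketch that imitates `ParameterGrid`'s ordering with `itertools.product`:

```python
from itertools import product

param_grid = {'epochs': [2, 3], 'dense_dims': [16, 32]}
n_folds = 3  # assumption: what cv=None meant in the scikit-learn version shown

# Enumerate candidates with keys in sorted order, as ParameterGrid does.
keys = sorted(param_grid)
candidates = [dict(zip(keys, values)) for values in product(*(param_grid[k] for k in keys))]

# One training run per (candidate, fold), plus a final refit because refit=True.
epochs_per_run = [c['epochs'] for c in candidates for _ in range(n_folds)]
best = {'dense_dims': 32, 'epochs': 2}  # the best_params_ reported by the notebook
epochs_per_run.append(best['epochs'])

print(len(candidates), len(epochs_per_run))  # 4 13
print(epochs_per_run)  # [2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, 2]
```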

```python
gs.best_params_
```

```
{'dense_dims': 32, 'epochs': 2}
```