In [None]:
import numpy as np
import tensorflow as tf
from scipy.stats import reciprocal
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow import keras

How do we know which combination of hyperparameters is best for our task? One option is simply to try many combinations of hyperparameters and see which one works best on the validation set (or use K-fold cross-validation). For example, we can use `GridSearchCV` or `RandomizedSearchCV` to explore the hyperparameter space.

To do this, we need to wrap our Keras models in objects that mimic Scikit-Learn regressors. The first step is to create a function that will build and compile a Keras model, given a set of hyperparameters.

In [None]:
# Load the California housing data, hold out 25% (train_test_split's default)
# as the test set, then 25% of the remainder as a validation set. Fixed
# random_state values keep the splits identical across Restart & Run All.
housing = fetch_california_housing()
X_train_full, X_test_raw, y_train_full, y_test = train_test_split(
    housing.data, housing.target, random_state=42
)
X_train_raw, X_valid_raw, y_train, y_valid = train_test_split(
    X_train_full, y_train_full, random_state=42
)

# The raw features live on very different scales (e.g. population counts vs.
# median income), which can make plain SGD diverge. Standardize with statistics
# fitted on the training split only, so validation/test data do not leak in.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_valid = scaler.transform(X_valid_raw)
X_test = scaler.transform(X_test_raw)

In [None]:
def build_model(n_hidden=1, n_neurons=30, learning_rate=3e-3, input_shape=(8,)):
    """Build and compile a fully connected regression network.

    Parameters
    ----------
    n_hidden : int
        Number of hidden ReLU ``Dense`` layers (0 yields a linear model).
    n_neurons : int
        Number of units in each hidden layer.
    learning_rate : float
        Learning rate of the SGD optimizer.
    input_shape : sequence of int
        Shape of a single input sample, excluding the batch dimension. The
        default ``(8,)`` matches the 8 California housing features.

    Returns
    -------
    keras.models.Sequential
        A model ending in a single linear output unit, compiled with MSE loss
        and SGD.
    """
    model = keras.models.Sequential()
    model.add(keras.layers.InputLayer(input_shape=input_shape))
    for _ in range(n_hidden):
        model.add(keras.layers.Dense(n_neurons, activation="relu"))
    model.add(keras.layers.Dense(1))
    # `learning_rate=` is the supported keyword; the old `lr=` alias is
    # rejected by the optimizers shipped with TF >= 2.11.
    optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
    model.compile(loss="mse", optimizer=optimizer)
    return model

Now, let's create a KerasRegressor based on this build_model() function:

In [None]:
# Wrap the model factory so scikit-learn utilities (e.g. RandomizedSearchCV)
# can treat it as an ordinary regressor; hyperparameters chosen by the search
# are passed on to build_model() by name.
# NOTE(review): keras.wrappers.scikit_learn was removed in TF 2.12; newer
# environments need the `scikeras` package's KerasRegressor instead.
keras_reg = keras.wrappers.scikit_learn.KerasRegressor(build_fn=build_model)

In [None]:
# Search space for RandomizedSearchCV. Discrete options are plain sequences
# (sampled uniformly); the learning rate is drawn from a reciprocal
# (log-uniform) distribution so each order of magnitude between 3e-4 and
# 3e-2 is explored equally often.
params_distribs = dict(
    n_hidden=(0, 1, 2, 3),
    n_neurons=np.arange(1, 100),  # 1..99 inclusive
    learning_rate=reciprocal(3e-4, 3e-2),
)

In [None]:
# Try 10 hyperparameter combinations sampled from params_distribs, scoring
# each with 3-fold cross-validation. random_state fixes which combinations are
# sampled, so the search is reproducible.
# NOTE(review): Keras weight initialisation is not covered by this seed
# (that needs tf.random.set_seed), so individual scores may still vary.
rnd_search_cv = RandomizedSearchCV(
    keras_reg, params_distribs, n_iter=10, cv=3, random_state=42
)

In [None]:
# Every sampled candidate is trained on each CV fold for at most 100 epochs;
# the separate validation set drives early stopping (10 epochs of patience).
# The fitted search object is the cell's displayed output.
rnd_search_cv.fit(
    X_train,
    y_train,
    epochs=100,
    validation_data=(X_valid, y_valid),
    callbacks=[
        keras.callbacks.EarlyStopping(patience=10),
    ],
)