In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras import Model
from tensorflow.keras.losses import SparseCategoricalCrossentropy

from sklearn.model_selection import GridSearchCV
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from tensorflow.keras import regularizers

In [2]:
# Flattened length of one 28x28 grayscale MNIST image; create_model reads this global.
pixels = 28*28
(xtr, ytr), (xte, yte) = tf.keras.datasets.mnist.load_data()
# Fail loudly if the raw images are not the 28x28 layout `pixels` assumes.
assert xtr.shape[1:] == xte.shape[1:] == (28, 28), "unexpected MNIST image shape"

# Flatten each image to a `pixels`-long vector and scale uint8 [0, 255] to float32 [0, 1].
# -1 lets numpy infer the sample count instead of hard-coding 60000 / 10000.
xtr = xtr.reshape((-1, pixels)).astype(np.float32) / 255.0
xte = xte.reshape((-1, pixels)).astype(np.float32) / 255.0


In [3]:
def create_model(hidden_nodes, dropout_rate, reg_val, num_classes=10, show_summary=True):
    """Build and compile a one-hidden-layer MLP classifier over flattened images.

    Intended as the ``build_fn`` of ``KerasClassifier``. The wrapper forwards
    any of its parameters whose names match these arguments, so grid-search
    keys such as ``hidden_nodes`` / ``dropout_rate`` / ``reg_val`` land here.

    Args:
        hidden_nodes: number of units in the ReLU hidden layer.
        dropout_rate: fraction of hidden activations dropped during training.
        reg_val: L2 penalty factor applied to both the hidden layer's kernel
            and its bias.
        num_classes: number of softmax output units; labels are expected to be
            integer class ids in ``[0, num_classes)``. Defaults to 10.
        show_summary: if True (default), print ``model.summary()`` after the
            model is built. Grid search builds one model per fit, so passing
            False avoids repeated summaries.

    Returns:
        A compiled ``Model`` using the Adam optimizer, sparse categorical
        cross-entropy loss and an accuracy metric.
    """
    # `pixels` is the module-level flattened image length set when the data is loaded.
    inputs = Input(shape=(pixels,), name='images')
    z = Dense(
        hidden_nodes,
        activation='relu',
        kernel_regularizer=regularizers.l2(reg_val),
        bias_regularizer=regularizers.l2(reg_val),
        name='hidden1',
    )(inputs)
    z = Dropout(dropout_rate)(z)
    # Softmax outputs probabilities, matching the loss's default from_logits=False.
    z = Dense(num_classes, activation='softmax')(z)

    our_model = Model(inputs=inputs, outputs=z)
    if show_summary:
        our_model.summary()

    our_model.compile(optimizer='adam', loss=SparseCategoricalCrossentropy(), metrics=['accuracy'])
    return our_model

In [4]:
# Candidate values for each hyperparameter. The model-shape keys must match
# create_model's argument names; `epochs` is a KerasClassifier fit parameter.
dropout_rate = [0.2, 0.3]
hidden_nodes = [32]
reg_val = [1e-3]
epochs = [1, 10]

params = dict(
    hidden_nodes=hidden_nodes,
    dropout_rate=dropout_rate,
    reg_val=reg_val,
    epochs=epochs,
)

# epochs=10 here is only the wrapper's fallback; GridSearchCV overrides it with each value in params['epochs'].
# NOTE(review): tf.keras.wrappers.scikit_learn is deprecated and removed in later TF 2.x releases;
# scikeras.wrappers.KerasClassifier is the maintained replacement.
model = KerasClassifier(
    build_fn=create_model,
    epochs=10,
    batch_size=32,
    verbose=0,
)


In [5]:
# Exhaustive 5-fold cross-validated search over every combination in `params`:
# 2 dropout rates x 1 width x 1 reg value x 2 epoch counts = 4 candidates, 20 fits.
# n_jobs=-1 runs those fits in parallel on all available cores.
# NOTE(review): each parallel worker builds its own TensorFlow model; watch memory/GPU contention.
gridCV = GridSearchCV(estimator=model, param_grid=params, cv=5, n_jobs=-1)

# fit() returns the fitted search object, so best_params_ is read straight off it.
print("best hyperparams:", gridCV.fit(xtr, ytr).best_params_)

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
images (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
hidden1 (Dense)              (None, 32)                25120     
_________________________________________________________________
dropout (Dropout)            (None, 32)                0         
_________________________________________________________________
dense (Dense)                (None, 10)                330       
Total params: 25,450
Trainable params: 25,450
Non-trainable params: 0
_________________________________________________________________
best hyperparams: {'dropout_rate': 0.2, 'epochs': 10, 'hidden_nodes': 32, 'reg_val': 0.001}
