## 8.3.1 Model scheduling with Hyperband


In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import autokeras as ak
import keras_tuner as kt

(x_train, y_train), (x_test, y_test) = mnist.load_data()


def build_model(hp):
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(
        tf.keras.layers.Dense(
            units=hp.Int("units", min_value=32, max_value=512, step=32),
            activation="relu",
        )
    )
    model.add(tf.keras.layers.Dense(10, activation="softmax"))
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    return model


tuner = kt.Hyperband(
    build_model,
    objective="val_loss",
    max_epochs=10,
    factor=3,
    hyperband_iterations=2,
    directory="result_dir",
    project_name="helloworld",
)

tuner.search(x_train, y_train, epochs=1, validation_data=(x_test, y_test))


Trial 60 Complete [00h 01m 23s]
val_loss: 0.28353795409202576

Best val_loss So Far: 0.2147570550441742
Total elapsed time: 00h 29m 49s


## 8.3.2 Faster convergence with pretrained weights in search space


In [2]:
import tensorflow as tf
import autokeras as ak

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
input_node = ak.ImageInput()
output_node = ak.Normalization()(input_node)
output_node = ak.ImageAugmentation()(output_node)
output_node = ak.ResNetBlock(pretrained=True)(output_node)
output_node = ak.ClassificationHead()(output_node)
model = ak.AutoModel(
    inputs=input_node, outputs=output_node, max_trials=2, overwrite=True
)
model.fit(x_train[:100], y_train[:100], epochs=1)
model.evaluate(x_test, y_test)


Trial 2 Complete [00h 00m 12s]
val_loss: 1.519028902053833

Best val_loss So Far: 1.519028902053833
Total elapsed time: 00h 00m 29s






[2.5178773403167725, 0.12430000305175781]

In [3]:
import tensorflow as tf

resnet = tf.keras.applications.ResNet50(include_top=False, weights="imagenet")
resnet.summary()


Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, None, None,   0           ['input_3[0][0]']                
                                3)                                                                
                                                                                                  
 conv1_conv (Conv2D)            (None, None, None,   9472        ['conv1_pad[0][0]']              
                                64)                                                        

In [4]:
import tensorflow as tf
import kerastuner as kt


def build_model(hp):
    if hp.Boolean("pretrained"):
        weights = "imagenet"
    else:
        weights = None
    resnet = tf.keras.applications.ResNet50(include_top=False, weights=weights)
    if hp.Boolean("freeze"):
        resnet.trainable = False

    input_node = tf.keras.Input(shape=(32, 32, 3))
    output_node = resnet(input_node)
    output_node = tf.keras.layers.Dense(10, activation="softmax")(output_node)
    model = tf.keras.Model(inputs=input_node, outputs=output_node)
    model.compile(loss="sparse_categorical_crossentropy")
    return model


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

tuner = kt.RandomSearch(
    build_model,
    objective="val_loss",
    max_trials=4,
    overwrite=True,
    directory="result_dir",
    project_name="pretrained",
)

tuner.search(
    x_train[:100], y_train[:100], epochs=1, validation_data=(x_test[:100], y_test[:100])
)


Trial 4 Complete [00h 00m 11s]
val_loss: 3.788170099258423

Best val_loss So Far: 3.788170099258423
Total elapsed time: 00h 01m 06s
