In [None]:
# Install the Kubeflow Katib Python SDK; it is imported further down as kubeflow.katib.
!pip install kubeflow-katib

In [None]:
# Install TensorFlow pinned to 2.9.1; train_mnist_model below imports it as tf.
!pip install tensorflow==2.9.1

In [None]:
def train_mnist_model(parameters):
    """Train a small CNN on MNIST and log per-epoch metrics.

    All imports and helpers are kept inside the function body.
    NOTE(review): this function is passed as the objective to
    KatibClient.tune(); presumably Katib ships only this function's source
    to each Trial, so nothing here may depend on notebook globals - confirm
    against the Katib SDK documentation.

    Args:
        parameters: dict with the following keys.
            "lr": learning rate for the SGD optimizer (converted with float).
            "num_epoch": number of training epochs (converted with int).
            "is_dist": truthy to build the model under a
                MultiWorkerMirroredStrategy scope, falsy for a plain
                single-worker model.
            "num_workers": worker count used to scale the global batch
                size (converted with int).

    Raises:
        KeyError: if one of the keys above is missing.
        ValueError: if a numeric parameter cannot be converted.
    """
    import tensorflow as tf
    import numpy as np
    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%SZ",
        level=logging.INFO,
    )
    separator = "--------------------------------------------------------------------------------------"
    logging.info(separator)
    logging.info(f"Input Parameters: {parameters}")
    logging.info(separator + "\n\n")

    # Get HyperParameters from the input params dict.
    lr = float(parameters["lr"])
    num_epoch = int(parameters["num_epoch"])

    # Set dist parameters.
    is_dist = parameters["is_dist"]
    # Coerce like the other numeric parameters: a string value would
    # otherwise turn the multiplication below into string repetition.
    num_workers = int(parameters["num_workers"])
    batch_size_per_worker = 64
    batch_size_global = batch_size_per_worker * num_workers

    # Only set up collective communication when distributed training was
    # requested. The strategy is still created before any dataset or model
    # is built.
    strategy = None
    if is_dist:
        strategy = tf.distribute.MultiWorkerMirroredStrategy(
            communication_options=tf.distribute.experimental.CommunicationOptions(
                implementation=tf.distribute.experimental.CollectiveCommunication.RING
            )
        )

    # Callback class for logging training.
    # Katib parses metrics in this format: <metric-name>=<metric-value>.
    class CustomCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            # Keras epochs are zero-based; report them one-based.
            logging.info(
                "Epoch {}/{}. accuracy={:.4f} - loss={:.4f}".format(
                    epoch + 1, num_epoch, logs["accuracy"], logs["loss"]
                )
            )

    # Prepare MNIST Dataset.
    def mnist_dataset(batch_size):
        # Only the training split is used; the test split is discarded.
        (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
        # Scale pixel values by 1/255.
        x_train = x_train / np.float32(255)
        y_train = y_train.astype(np.int64)
        # The dataset repeats forever, so fit() below bounds each epoch
        # with steps_per_epoch.
        train_dataset = (
            tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(60000)
            .repeat()
            .batch(batch_size)
        )
        return train_dataset

    # Build and compile CNN Model.
    def build_and_compile_cnn_model():
        model = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(28, 28)),
                tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
                tf.keras.layers.Conv2D(32, 3, activation="relu"),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(128, activation="relu"),
                # No activation: the loss is configured with from_logits=True.
                tf.keras.layers.Dense(10),
            ]
        )
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=tf.keras.optimizers.SGD(learning_rate=lr),
            metrics=["accuracy"],
        )
        return model

    # Download Dataset.
    dataset = mnist_dataset(batch_size_global)

    # For dist strategy we should build model under scope().
    if is_dist:
        logging.info("Running Distributed Training")
        logging.info(separator + "\n\n")
        with strategy.scope():
            model = build_and_compile_cnn_model()
    else:
        logging.info("Running Single Worker Training")
        logging.info(separator + "\n\n")
        model = build_and_compile_cnn_model()

    # Start Training.
    # verbose=0 silences the Keras progress output so that the log lines
    # come from CustomCallback.
    model.fit(
        dataset,
        epochs=num_epoch,
        steps_per_epoch=70,
        callbacks=[CustomCallback()],
        verbose=0,
    )
    
import kubeflow.katib as katib

# Name under which the Katib Experiment is created; later cells look the
# Experiment up by this name.
exp_name = "tune-mnist"

# Client used to create and query the Experiment.
katib_client = katib.KatibClient()

# Input for train_mnist_model. "lr" and "num_epoch" are search ranges that
# Katib samples from; "is_dist" and "num_workers" are fixed values.
parameters = dict(
    lr=katib.search.double(min=0.1, max=0.2),
    num_epoch=katib.search.int(min=10, max=15),
    is_dist=False,
    num_workers=1,
)

# Launch the Experiment: "accuracy" is the objective metric and "loss" is
# collected as an additional metric. Both names match the key=value pairs
# logged by train_mnist_model.
katib_client.tune(
    name=exp_name,
    objective=train_mnist_model,
    parameters=parameters,
    algorithm_name="cmaes",
    objective_metric_name="accuracy",
    additional_metric_names=["loss"],
    max_trial_count=12,
    parallel_trial_count=2,
)

In [None]:
# Report whether the Experiment has succeeded.
status = katib_client.is_experiment_succeeded(exp_name)
print(f"Katib Experiment is Succeeded: {status}\n")

# Fetch the current optimal Trial; the code below treats None as
# "no optimal Trial available".
best_hps = katib_client.get_optimal_hyperparameters(exp_name)

# Identity check: None is a singleton, and "!=" could be overridden by the
# returned object's __ne__.
if best_hps is not None:
    print("Current Optimal Trial\n")
    print(best_hps)

    for hp in best_hps.parameter_assignments:
        if hp.name == "lr":
            best_lr = hp.value
            print(f"Best LR: {best_lr}")
        # Match the name explicitly so that any other assignment is not
        # mislabelled as the epoch count.
        elif hp.name == "num_epoch":
            best_num_epoch = hp.value
            print(f"Best Num Epochs: {best_num_epoch}")