In [None]:
import sys

# Make the project code under /mnt/code/ importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not
# pile up duplicate entries on sys.path.
if "/mnt/code/" not in sys.path:
    sys.path.append("/mnt/code/")


In [None]:
import argparse
import os
import time

import mlflow
import tensorflow
from filelock import FileLock
from tensorflow.keras.datasets import mnist

import ray
from ray import train, tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.air.integrations.keras import ReportCheckpointCallback
from domino_mlflow_utils.mlflow_utilities import DominoMLflowUtilities

In [None]:
# Connect this notebook to the Ray cluster whose head node is advertised
# through the RAY_HEAD_SERVICE_HOST / RAY_HEAD_SERVICE_PORT environment variables.
service_host = os.environ["RAY_HEAD_SERVICE_HOST"]
service_port = os.environ["RAY_HEAD_SERVICE_PORT"]
print(ray.is_initialized())

# Only initialise once so that re-running the cell does not try to reconnect.
if not ray.is_initialized():
    address = f"ray://{service_host}:{service_port}"
    temp_dir = '/mnt/data/{}/'.format(os.environ['DOMINO_PROJECT_NAME'])  # set to a dataset
    print(temp_dir)
    # py_modules ships the local domino_mlflow_utils package to the Ray workers.
    ray.init(address=address, _temp_dir=temp_dir, runtime_env={"py_modules": ['/mnt/code/domino_mlflow_utils']})

print('Ray Initialized')
print(f'Ray Host={service_host} and Ray Port={service_port}')

In [None]:
# Experiment name is scoped per user and per project.
experiment_name = f"RAY-TUNE-KERAS-{os.environ['DOMINO_STARTING_USERNAME']}-{os.environ['DOMINO_PROJECT_NAME']}"
mlflow_tracking_uri = os.environ['CLUSTER_MLFLOW_TRACKING_URI']
client = mlflow.tracking.MlflowClient()

# Look the experiment up by name and create it on first use.
experiment = client.get_experiment_by_name(name=experiment_name)
if experiment is None:
    print('Creating experiment ')
    client.create_experiment(name=experiment_name)
    experiment = client.get_experiment_by_name(name=experiment_name)

print(experiment_name)
mlflow.set_experiment(experiment_name=experiment_name)


In [None]:
from ray.air._internal.mlflow import _MLflowLoggerUtil

# Module-level helper used by initialize_run() to configure MLflow and create runs.
mlflow_util = _MLflowLoggerUtil()


def initialize_run():
    """Create a root MLflow run for the notebook's experiment and return its ID.

    Points the shared ``mlflow_util`` helper at the tracking server named by the
    ``CLUSTER_MLFLOW_TRACKING_URI`` environment variable and at the module-level
    ``experiment_name``, then starts a run named ``root-<UTC timestamp>``.

    Returns:
        The ``run_id`` of the run returned by ``mlflow_util.start_run``.

    Raises:
        KeyError: if ``CLUSTER_MLFLOW_TRACKING_URI`` is not set.
    """
    mlflow_tracking_uri = os.environ['CLUSTER_MLFLOW_TRACKING_URI']

    mlflow_util.setup_mlflow(
        tracking_uri=mlflow_tracking_uri,
        experiment_name=experiment_name,
    )

    # Whole-second UTC timestamp keeps the run name readable.
    now = round(time.time())
    now_str = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(now))

    # Keep the run object: the original discarded it and then read an
    # undefined `run`, which raised NameError.
    run = mlflow_util.start_run(tags={}, run_name=f"root-{now_str}")
    return run.info.run_id

In [None]:
import tensorflow as tf

# Storage location for Ray Tune results; read by tune_mnist() as storage_path.
temp_dir = f"/mnt/data/{os.environ['DOMINO_PROJECT_NAME']}/"  # set to a dataset
client = mlflow.tracking.MlflowClient()
mlflow_tracking_uri = os.environ['CLUSTER_MLFLOW_TRACKING_URI']

def train_mnist(config):
    """Train a small dense MNIST classifier for one Ray Tune trial.

    Args:
        config: Trial configuration. Keys read here are ``"hidden"`` (units in
            the hidden Dense layer), ``"lr"`` and ``"momentum"`` (SGD settings)
            and ``"parent_run_id"`` (MLflow run ID stored in the
            ``mlflow.parentRunId`` tag of this trial's run).

    Per-epoch training accuracy is reported to Ray Tune as ``mean_accuracy``
    through ``ReportCheckpointCallback``.
    """
    # Original author's reference: https://github.com/tensorflow/tensorflow/issues/32159

    batch_size = 128
    num_classes = 10
    epochs = 12
    parent_run_id = config['parent_run_id']

    # File lock around the dataset load; presumably prevents concurrent trials
    # on the same node from racing on the MNIST download/cache - TODO confirm.
    with FileLock(os.path.expanduser("~/.data.lock")):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(config["hidden"], activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )

    model.compile(
        loss="sparse_categorical_crossentropy",
        # `learning_rate` is the supported keyword; the deprecated `lr` alias is
        # warned about or rejected by newer Keras releases, so the tuned value
        # may never reach the optimizer.
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=config["lr"], momentum=config["momentum"]
        ),
        metrics=["accuracy"],
    )

    mlflow.tensorflow.autolog()
    run_tags = {"mlflow.parentRunId": parent_run_id}
    mlflow_utils = DominoMLflowUtilities()
    mlflow_utils.init(experiment_name, config, run_tags=run_tags)

    try:
        model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=epochs,
            verbose=0,
            validation_data=(x_test, y_test),
            callbacks=[ReportCheckpointCallback(metrics={"mean_accuracy": "accuracy"})],
        )
    finally:
        # Always pair init() with finish(), even when training fails.
        mlflow_utils.finish()


def tune_mnist(parent_run_id):
    """Run the Ray Tune hyperparameter search over ``train_mnist``.

    Args:
        parent_run_id: MLflow run ID passed to every trial through the
            ``"parent_run_id"`` entry of the parameter space.

    Prints the configuration of the best trial (highest ``mean_accuracy``).
    Results are stored under the module-level ``temp_dir``.
    """
    scheduler = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=400, grace_period=20
    )

    # One CPU and no GPU per trial.
    trainable = tune.with_resources(train_mnist, resources={"cpu": 1, "gpu": 0})

    tune_config = tune.TuneConfig(
        metric="mean_accuracy",
        mode="max",
        scheduler=scheduler,
        num_samples=10,
    )

    run_config = train.RunConfig(
        name="exp",
        stop={"mean_accuracy": 0.99},
        storage_path=temp_dir,
    )

    search_space = {
        "threads": 2,
        "lr": tune.uniform(0.001, 0.1),
        "momentum": tune.uniform(0.1, 0.9),
        "hidden": tune.randint(32, 512),
        "parent_run_id": parent_run_id,
    }

    tuner = tune.Tuner(
        trainable,
        tune_config=tune_config,
        run_config=run_config,
        param_space=search_space,
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)
   

In [None]:
# Open a parent MLflow run and hand its ID to the tuning job; train_mnist()
# stores that ID in each trial's "mlflow.parentRunId" tag.
with mlflow.start_run() as run:
    parent_run_id = run.info.run_id
    print(parent_run_id)
    tune_mnist(parent_run_id)


**What the final output looks like**

```
(TunerInternal pid=1977) Trial status: 10 TERMINATED
(TunerInternal pid=1977) Current time: 2024-03-04 04:11:39. Total running time: 3min 52s
(TunerInternal pid=1977) Logical resource usage: 1.0/4 CPUs, 0/0 GPUs
(TunerInternal pid=1977) Current best trial: ceeae_00004 with mean_accuracy=0.9462000131607056 and params={'threads': 2, 'lr': 0.07199524345173318, 'momentum': 0.5562304366072593, 'hidden': 496, 'parent_run_id': 'e836dcb261914663946a937639f9933f'}
(TunerInternal pid=1977) ╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
(TunerInternal pid=1977) │ Trial name                status               lr     momentum     hidden        acc     iter     total time (s) │
(TunerInternal pid=1977) ├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
(TunerInternal pid=1977) │ train_mnist_ceeae_00000   TERMINATED   0.0986987      0.112969         89   0.912367       12            33.5511 │
(TunerInternal pid=1977) │ train_mnist_ceeae_00001   TERMINATED   0.0351482      0.585167         64   0.9278         12            32.0017 │
(TunerInternal pid=1977) │ train_mnist_ceeae_00002   TERMINATED   0.0474651      0.123255         91   0.914867       12            39.5255 │
(TunerInternal pid=1977) │ train_mnist_ceeae_00003   TERMINATED   0.0353766      0.49207         440   0.941683       12           101.955  │
(TunerInternal pid=1977) │ train_mnist_ceeae_00004   TERMINATED   0.0719952      0.55623         496   0.9462         12            78.4446 │
(TunerInternal pid=1977) │ train_mnist_ceeae_00005   TERMINATED   0.0388087      0.123387        505   0.926833       12            83.2023 │
(TunerInternal pid=1977) │ train_mnist_ceeae_00006   TERMINATED   0.0747164      0.44443         429   0.9388         12            82.2907 │
(TunerInternal pid=1977) │ train_mnist_ceeae_00007   TERMINATED   0.017199       0.513358        484   0.943233       12           102.385  │
(TunerInternal pid=1977) │ train_mnist_ceeae_00008   TERMINATED   0.0109866      0.126216        141   0.919483       12            41.9762 │
(TunerInternal pid=1977) │ train_mnist_ceeae_00009   TERMINATED   0.00986737     0.645657         83   0.938217       12            33.1493 │
(TunerInternal pid=1977) ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
(TunerInternal pid=1977) 
Best hyperparameters found were:  {'threads': 2, 'lr': 0.07199524345173318, 'momentum': 0.5562304366072593, 'hidden': 496, 'parent_run_id': 'e836dcb261914663946a937639f9933f'}
```