In [1]:
from azureml.core import Workspace, Experiment
from azureml.train.dnn import TensorFlow
from azureml.widgets import RunDetails

In [2]:
from azureml.train.hyperdrive import RandomParameterSampling, BanditPolicy, HyperDriveConfig, PrimaryMetricGoal
from azureml.train.hyperdrive import choice, loguniform

In [3]:
# Load the workspace from the local Azure ML config file.
ws = Workspace.from_config()

# Look up the training cluster and the image datastore by their workspace names.
compute_targets = ws.compute_targets
gpu_cluster = compute_targets['gpu-cluster']

datastores = ws.datastores
food_data = datastores['food_images']

In [12]:
# Named options handed to the training script: the mounted image datastore
# and the number of epochs.
script_arguments = {}
script_arguments["--data-dir"] = food_data.as_mount()
script_arguments["--epochs"] = 50

# Food categories to train on (a subset of the full dataset).
food_categories = [
    'apple_pie',
    'baby_back_ribs',
    'baklava',
    'beef_carpaccio',
]

# Estimator settings: source tree root, entry script (relative to that root),
# target cluster, and the extra conda/pip packages for the run environment.
estimator_settings = dict(
    source_directory="..",
    entry_script='code/train/train.py',
    script_params=script_arguments,
    compute_target=gpu_cluster,
    conda_packages=['pillow', 'pandas'],
    pip_packages=['click', 'seaborn'],
    use_docker=True,
    use_gpu=True,
)
tf_config = TensorFlow(**estimator_settings)

# Append the category names to the run's argument list, after the named options above.
tf_config.run_config.arguments.extend(food_categories)



In [13]:
# Hyperparameter search space, keyed by the command-line option each value is
# passed under.
# NOTE(review): the loguniform bounds are presumably log-space exponents rather
# than raw learning rates -- confirm against the azureml hyperdrive docs.
search_space = {
    '--minibatch-size': choice(16, 32, 64, 128),
    '--learning-rate': loguniform(-6, -1),
    '--optimizer': choice('adadelta', 'rmsprop', 'adagrad', 'adam'),
}

# Draw configurations at random from the search space.
param_sampler = RandomParameterSampling(search_space)

# Policy object passed to the sweep as its `policy` (see HyperDriveConfig).
etpolicy = BanditPolicy(slack_factor=0.2, evaluation_interval=2)

In [22]:
# Sweep definition: which estimator to run, how to sample and terminate runs,
# which metric to maximize, and the total/concurrent run budget.
# NOTE(review): 'acc' presumably has to match the metric name logged by the
# training script -- verify against train.py.
sweep_settings = dict(
    estimator=tf_config,
    hyperparameter_sampling=param_sampler,
    policy=etpolicy,
    primary_metric_name='acc',
    primary_metric_goal=PrimaryMetricGoal.MAXIMIZE,
    max_total_runs=50,
    max_concurrent_runs=5,
)
hdc = HyperDriveConfig(**sweep_settings)

In [15]:
# Experiment under which the hyperparameter sweep runs are recorded.
hd_experiment = Experiment(workspace=ws, name='hyperparameter_search')

In [23]:
# Submit the sweep configuration and keep the returned run handle.
hd_run = hd_experiment.submit(config=hdc)

In [24]:
# Show the run-details widget for the submitted sweep.
run_widget = RunDetails(hd_run)
run_widget.show()

A Jupyter Widget