# Hyperparameter Optimization on Slurm via SubmitIt

This notebook serves as a quickstart guide for using the Ax library with the SubmitIt library in an ask-tell loop. [SubmitIt](https://github.com/facebookincubator/submitit/) is a Python toolbox for submitting jobs to [Slurm](https://slurm.schedmd.com/quickstart.html). 

The notebook demonstrates how to use the Ax client in an ask-tell loop where each trial is scheduled to run on a Slurm cluster asynchronously.

To use this notebook, run it on a node with access to your Slurm cluster, either interactively as a notebook or by exporting it as a Python script and running that script as a Slurm job.

## Importing Necessary Libraries
Let's start by importing the necessary libraries.

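```python
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.notebook.plotting import render
from ax.service.utils.report_utils import exp_to_df
from submitit import AutoExecutor, LocalJob, DebugJob
# Notebooks provide display() automatically; the import is needed when running as a script.
from IPython.display import display
import time
```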

## Defining the Function to Optimize
We'll define a simple function to optimize. It takes a dictionary with two parameters, `x` and `y`, and returns a dictionary mapping the metric name `result` to its value.

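```python
def evaluate(parameters):
    # Ax passes the trial's parameters as a dict keyed by parameter name.
    x = parameters["x"]
    y = parameters["y"]
    # Return {metric_name: value}; Ax treats the value as the observed mean.
    return {"result": (x - 3)**2 + (y - 4)**2}
```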
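Because the objective is a simple quadratic, we can work out the answer Ax should be approaching. Without constraints the minimum is 0 at (3, 4), but that point violates the constraint `x + y <= 2.0` used in the experiment below, so the constrained minimum lies on the line x + y = 2. Projecting (3, 4) onto that line gives (0.5, 1.5) with value 12.5. The sketch below checks this with a coarse grid search; it is only a sanity check and not part of the Ax workflow.

```python
def evaluate(parameters):
    x = parameters["x"]
    y = parameters["y"]
    return {"result": (x - 3)**2 + (y - 4)**2}

# Unconstrained minimum: (x, y) = (3, 4).
assert evaluate({"x": 3.0, "y": 4.0}) == {"result": 0.0}

# Project (3, 4) onto the line x + y = 2 along the direction (1, 1):
# (3 - t) + (4 - t) = 2  =>  t = 2.5  =>  (0.5, 1.5).
t = (3 + 4 - 2) / 2
x_star, y_star = 3 - t, 4 - t
f_star = evaluate({"x": x_star, "y": y_star})["result"]

# Brute-force check on a 0.05-spaced grid over the search space [-10, 10]^2,
# keeping only points that satisfy the constraint (with a tiny float tolerance).
step = 0.05
grid = [-10 + i * step for i in range(int(20 / step) + 1)]
best = min(
    (evaluate({"x": x, "y": y})["result"], x, y)
    for x in grid
    for y in grid
    if x + y <= 2.0 + 1e-9
)
print(f"analytic: ({x_star}, {y_star}) -> {f_star}")
print(f"grid:     ({best[1]:.2f}, {best[2]:.2f}) -> {best[0]:.4f}")
```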

Note 1: SubmitIt's [CommandFunction](https://github.com/facebookincubator/submitit/blob/main/docs/examples.md#working-with-commands) allows you to define commands to run on the node and then redirects the standard output.
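As a rough local illustration of the command-based pattern, the sketch below runs an external program for one trial, captures its standard output, and parses the last printed line into the metric dictionary Ax expects. It uses the standard library's `subprocess` module as a stand-in for SubmitIt, and the "training program" is a hypothetical one-liner run with the current Python interpreter; the function name `evaluate_command` and the "metric on the last line" convention are assumptions for illustration.

```python
import subprocess
import sys

def evaluate_command(parameters):
    """Hypothetical evaluator: run an external program and parse its printed metric."""
    # Hypothetical training program that just prints the objective value.
    # Replace with your own command, e.g. ["python", "train.py", "--x", ...].
    program = (
        "import sys; x, y = float(sys.argv[1]), float(sys.argv[2]); "
        "print('training...'); print((x - 3) ** 2 + (y - 4) ** 2)"
    )
    command = [sys.executable, "-c", program, str(parameters["x"]), str(parameters["y"])]
    # check=True raises CalledProcessError if the program exits with a non-zero status.
    completed = subprocess.run(command, capture_output=True, text=True, check=True)
    # Assumption: the metric is printed on the last line of stdout.
    last_line = completed.stdout.strip().splitlines()[-1]
    return {"result": float(last_line)}

print(evaluate_command({"x": 1.0, "y": 2.0}))
```

Keeping the parsing step separate from the command makes it easy to swap the local `subprocess` call for a cluster submission while leaving the metric format unchanged.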

Note 2: If you are using Hydra to manage configs, SubmitIt also has [Hydra integration](https://hydra.cc/docs/plugins/submitit_launcher/).

## Setting up Ax
We'll use Ax's Service API for this example. We start by initializing an AxClient and creating an experiment.

```python
ax_client = AxClient()
ax_client.create_experiment(
    name="my_experiment",
    parameters=[
        {"name": "x", "type": "range", "bounds": [-10.0, 10.0]},
        {"name": "y", "type": "range", "bounds": [-10.0, 10.0]},
    ],
    objectives={"result": ObjectiveProperties(minimize=True)},
    parameter_constraints=["x + y <= 2.0"],  # Optional.
)
```

[INFO 01-11 16:18:48] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 01-11 16:18:48] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 01-11 16:18:48] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter y. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 01-11 16:18:48] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[-10.0, 10.0]), RangeParameter(name='y', parameter_type=FLOAT, range=[-10.0, 10.0])], parameter_constraints=[ParameterConstraint(1

Other commonly used [parameter types](https://ax.dev/docs/glossary.html#parameter) include `choice` parameters and `fixed` parameters.

Tip 1: You can specify additional attributes for parameters, such as `log_scale` for range parameters that should be searched on a log scale, and `is_ordered` for choice parameters whose values have a meaningful ordering.
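For illustration, here is how the parameter types and attributes above might look as Service API parameter dictionaries, in the same format passed to `create_experiment`. The parameter names (`learning_rate`, `optimizer`, `num_layers`, `batch_size`) and their values are hypothetical; the keys (`type`, `bounds`, `values`, `value`, `value_type`, `log_scale`, `is_ordered`) are meant to mirror Ax's format, so check the Ax documentation for your version before relying on them.

```python
# Hypothetical search space illustrating the parameter types mentioned above.
parameters = [
    # Range parameter searched on a log scale (common for learning rates).
    {"name": "learning_rate", "type": "range", "bounds": [1e-5, 1e-1], "log_scale": True},
    # Unordered choice parameter: the values are plain categories.
    {"name": "optimizer", "type": "choice", "values": ["adam", "sgd"], "is_ordered": False},
    # Ordered choice parameter: the ordering 2 < 4 < 8 is meaningful to the model.
    {"name": "num_layers", "type": "choice", "values": [2, 4, 8], "is_ordered": True, "value_type": "int"},
    # Fixed parameter: recorded with every trial but never tuned.
    {"name": "batch_size", "type": "fixed", "value": 64},
]

for p in parameters:
    print(p["name"], p["type"])
```

A list like this would replace the `parameters=[...]` argument in the `create_experiment` call above.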

Tip 2: Ax also supports multi-objective optimization: when there are several competing objectives, the goal is to find the set of Pareto-optimal solutions rather than a single best point.
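To make "Pareto-optimal" concrete: with two objectives to minimize, a point is Pareto-optimal if no other point is at least as good on both objectives and strictly better on at least one. The sketch below filters a small, made-up set of (loss, latency) results down to that frontier. It is a plain-Python illustration of the definition, not how Ax models or computes the frontier.

```python
from typing import List, Tuple

Point = Tuple[float, float]

def dominates(a: Point, b: Point) -> bool:
    """True if a is no worse than b on every objective and strictly better on one (minimization)."""
    return all(ai <= bi for ai, bi in zip(a, b)) and any(ai < bi for ai, bi in zip(a, b))

def pareto_front(points: List[Point]) -> List[Point]:
    """Return the points that no other point dominates, in their original order."""
    return [p for p in points if not any(dominates(q, p) for q in points)]

# Hypothetical (loss, latency_ms) pairs for five trials.
results = [(0.30, 12.0), (0.25, 20.0), (0.35, 10.0), (0.30, 15.0), (0.40, 11.0)]
print(pareto_front(results))
```

Here (0.30, 15.0) is dominated by (0.30, 12.0) and (0.40, 11.0) by (0.35, 10.0); the remaining three points each trade loss against latency.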

Tip 3: Constraints can be defined on both the parameters (`parameter_constraints`, as above) and the outcomes (`outcome_constraints`).
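As a quick check of the parameter constraint defined above, the snippet below takes the ten trial parameters reported in the results table of the run further down and verifies that each one satisfies the bounds and `x + y <= 2.0`. Several BoTorch trials sit right on the boundary, which is expected here because the unconstrained optimum (3, 4) lies outside the feasible region. The `1e-6` tolerance is an assumption to absorb the rounding of the logged values.

```python
# Trial parameters (x, y) copied from the results table of the run below.
trials = {
    0: (-9.878379, -2.01497),
    1: (1.162749, -8.129749),
    2: (-8.151591, -7.791025),
    3: (-4.992916, 3.784711),
    4: (-6.408905, -7.56304),
    5: (-6.685398, 8.685398),
    6: (-6.690148, 8.690148),
    7: (-6.671281, 8.671281),
    8: (-1.287043, 3.287043),
    9: (-2.552892, 0.471419),
}

def is_feasible(x, y, tol=1e-6):
    """Check the box bounds [-10, 10] and the linear constraint x + y <= 2.0."""
    in_bounds = -10.0 <= x <= 10.0 and -10.0 <= y <= 10.0
    return in_bounds and x + y <= 2.0 + tol

assert all(is_feasible(x, y) for x, y in trials.values())
on_boundary = [i for i, (x, y) in trials.items() if abs(x + y - 2.0) < 1e-6]
print("trials on the boundary x + y = 2:", on_boundary)
```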

## Setting up SubmitIt
We'll use SubmitIt's `AutoExecutor` for this example. We start by initializing an `AutoExecutor` and setting a few commonly used parameters. The full list of parameters is available in the SubmitIt documentation.

```python
# Log folder and cluster. Specify cluster='local' or cluster='debug' to run the jobs locally during development.
# When we're ready for deployment, switch to cluster='slurm'.
executor = AutoExecutor(folder="/tmp/submitit_runs", cluster='local')
executor.update_parameters(timeout_min=60)  # Timeout of the Slurm job, not including Slurm scheduling delay.
executor.update_parameters(cpus_per_task=2)
```

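Other commonly used Slurm parameters include `partition`, `ntasks_per_node`, `cpus_per_task`, `cpus_per_gpu`, `gpus_per_node`, `gpus_per_task`, `qos`, `mem`, `mem_per_gpu`, `mem_per_cpu`, and `account`. With `AutoExecutor`, Slurm-specific parameters are passed with a `slurm_` prefix (for example, `slurm_partition`), while generic ones such as `timeout_min` and `cpus_per_task` are passed as-is.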
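One way to keep the same notebook usable both during development and on the cluster is to choose the cluster and the Slurm-specific settings in one place. The helper below is a hypothetical sketch: `executor_settings`, the `SUBMITIT_CLUSTER` environment variable, and the partition name `my_partition` are illustrative, and it assumes `AutoExecutor`'s convention of prefixing Slurm-only options with `slurm_`. It only builds keyword arguments, so it runs without SubmitIt installed.

```python
import os
from typing import Optional

def executor_settings(cluster: str, partition: Optional[str] = None) -> dict:
    """Hypothetical helper: build keyword arguments for executor.update_parameters."""
    # Generic options shared by all cluster types.
    params = {"timeout_min": 60, "cpus_per_task": 2}
    # Slurm-only options (slurm_ prefix) are added only when targeting Slurm.
    if cluster == "slurm" and partition is not None:
        params["slurm_partition"] = partition
    return params

# Default to local execution unless the (hypothetical) environment variable says otherwise.
cluster = os.environ.get("SUBMITIT_CLUSTER", "local")
print(cluster, executor_settings(cluster, partition="my_partition"))
```

With SubmitIt installed, this could be used as `executor = AutoExecutor(folder="/tmp/submitit_runs", cluster=cluster)` followed by `executor.update_parameters(**executor_settings(cluster, partition=...))`.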

## Running the Optimization Loop
Now we're ready to run the optimization loop. We'll use an ask-tell loop: we ask Ax for a suggestion, submit it to the cluster for evaluation, and then tell Ax the result.

The example loop schedules new jobs whenever a slot is free. For tasks that take a similar amount of time regardless of the parameters, it may make more sense to wait for the whole batch to finish before scheduling the next one, so that Ax can make better-informed parameter choices.

Note that `get_next_trials` may return fewer trials than requested (here, up to `num_parallel_jobs`), for example when the generation strategy needs more completed trials before it can move on to its next step.
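The two scheduling styles can be compared without a cluster. The sketch below replaces SubmitIt with a `concurrent.futures.ThreadPoolExecutor` and replaces Ax with uniform random sampling that respects `x + y <= 2.0`; both are stand-ins, not the real APIs, and the helper names `suggest` and `run` are hypothetical. With `wait_for_batch=False` it refills free slots as soon as any job finishes, like the loop below; with `wait_for_batch=True` it waits for the whole batch first.

```python
import random
from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED, ThreadPoolExecutor, wait

def evaluate(parameters):
    x, y = parameters["x"], parameters["y"]
    return {"result": (x - 3) ** 2 + (y - 4) ** 2}

def suggest(rng):
    """Stand-in for Ax: rejection-sample a point in [-10, 10]^2 with x + y <= 2.0."""
    while True:
        x, y = rng.uniform(-10, 10), rng.uniform(-10, 10)
        if x + y <= 2.0:
            return {"x": x, "y": y}

def run(total_budget=10, num_parallel_jobs=3, wait_for_batch=False, seed=0):
    rng = random.Random(seed)
    completed = {}  # trial_index -> result
    running = {}    # future -> trial_index
    submitted = 0
    with ThreadPoolExecutor(max_workers=num_parallel_jobs) as pool:
        while submitted < total_budget or running:
            # Fill free slots without exceeding the total budget.
            free = min(num_parallel_jobs - len(running), total_budget - submitted)
            for _ in range(free):
                running[pool.submit(evaluate, suggest(rng))] = submitted
                submitted += 1
            # Asynchronous: react to the first finished job; batch: wait for all of them.
            mode = ALL_COMPLETED if wait_for_batch else FIRST_COMPLETED
            done, _ = wait(running, return_when=mode)
            for future in done:
                completed[running.pop(future)] = future.result()["result"]
    return completed

results = run()
best_trial = min(results, key=results.get)
print(f"{len(results)} trials, best trial {best_trial}: {results[best_trial]:.3f}")
```

Because the random stand-in ignores completed results, both modes evaluate the same points here. With a model-based optimizer like Ax, the batch mode lets the model see more data before each new suggestion, at the cost of idle slots while the slowest job in a batch finishes.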

```python
total_budget = 10
num_parallel_jobs = 3

jobs = []
submitted_jobs = 0
# Run until all the jobs have finished and our budget is used up.
while submitted_jobs < total_budget or jobs:
    for job, trial_index in jobs[:]:
        # Poll if any jobs completed.
        # Local and debug jobs don't run until .result() is called.
        if job.done() or type(job) in [LocalJob, DebugJob]:
            try:
                result = job.result()
            except Exception:
                # The job raised or was killed: mark the trial as failed instead of crashing the loop.
                ax_client.log_trial_failure(trial_index=trial_index)
            else:
                ax_client.complete_trial(trial_index=trial_index, raw_data=result)
            jobs.remove((job, trial_index))

    # Schedule new jobs if there is availability.
    trial_index_to_param, _ = ax_client.get_next_trials(
        max_trials=min(num_parallel_jobs - len(jobs), total_budget - submitted_jobs))
    for trial_index, parameters in trial_index_to_param.items():
        job = executor.submit(evaluate, parameters)
        submitted_jobs += 1
        jobs.append((job, trial_index))
        time.sleep(1)

    # Display the current trials.
    display(exp_to_df(ax_client.experiment))

    # Sleep for a bit before checking the jobs again to avoid overloading the cluster.
    # If you have a large number of jobs, consider adding a sleep statement in the job polling loop as well.
    time.sleep(30)
```

[INFO 01-11 16:18:48] ax.service.ax_client: Generated new trial 0 with parameters {'x': -9.878379, 'y': -2.01497}.
[INFO 01-11 16:18:48] ax.service.ax_client: Generated new trial 1 with parameters {'x': 1.162749, 'y': -8.129749}.
[INFO 01-11 16:18:48] ax.service.ax_client: Generated new trial 2 with parameters {'x': -8.151591, 'y': -7.791025}.


[INFO 01-11 16:18:51] ax.service.utils.report_utils: No results present for the specified metrics `[Metric('result')]`. Returning arm parameters and metadata only.


| | trial_index | arm_name | trial_status | generation_method | x | y |
|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | RUNNING | Sobol | -9.878379 | -2.01497 |
| 1 | 1 | 1_0 | RUNNING | Sobol | 1.162749 | -8.129749 |
| 2 | 2 | 2_0 | RUNNING | Sobol | -8.151591 | -7.791025 |


[INFO 01-11 16:19:21] ax.service.ax_client: Completed trial 0 with data: {'result': (202.032516, None)}.
[INFO 01-11 16:19:21] ax.service.ax_client: Completed trial 1 with data: {'result': (150.50631, None)}.
[INFO 01-11 16:19:21] ax.service.ax_client: Completed trial 2 with data: {'result': (263.386245, None)}.
[INFO 01-11 16:19:21] ax.service.ax_client: Generated new trial 3 with parameters {'x': -4.992916, 'y': 3.784711}.


[INFO 01-11 16:19:21] ax.service.ax_client: Generated new trial 4 with parameters {'x': -6.408905, 'y': -7.56304}.


| | trial_index | arm_name | trial_status | generation_method | result | x | y |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 202.032516 | -9.878379 | -2.01497 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 150.50631 | 1.162749 | -8.129749 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 263.386245 | -8.151591 | -7.791025 |
| 3 | 3 | 3_0 | RUNNING | Sobol | | -4.992916 | 3.784711 |
| 4 | 4 | 4_0 | RUNNING | Sobol | | -6.408905 | -7.56304 |


[INFO 01-11 16:19:54] ax.service.ax_client: Completed trial 3 with data: {'result': (63.933052, None)}.
[INFO 01-11 16:19:54] ax.service.ax_client: Completed trial 4 with data: {'result': (222.231402, None)}.
[INFO 01-11 16:19:59] ax.service.ax_client: Generated new trial 5 with parameters {'x': -6.685398, 'y': 8.685398}.


[INFO 01-11 16:19:59] ax.modelbridge.torch: The observations are identical to the last set of observations used to fit the model. Skipping model fitting.
[INFO 01-11 16:20:04] ax.service.ax_client: Generated new trial 6 with parameters {'x': -6.690148, 'y': 8.690148}.


| | trial_index | arm_name | trial_status | generation_method | result | x | y |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 202.032516 | -9.878379 | -2.01497 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 150.50631 | 1.162749 | -8.129749 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 263.386245 | -8.151591 | -7.791025 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 63.933052 | -4.992916 | 3.784711 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 222.231402 | -6.408905 | -7.56304 |
| 5 | 5 | 5_0 | RUNNING | BoTorch | | -6.685398 | 8.685398 |
| 6 | 6 | 6_0 | RUNNING | BoTorch | | -6.690148 | 8.690148 |
| 7 | 7 | 7_0 | RUNNING | BoTorch | | -6.671281 | 8.671281 |


[INFO 01-11 16:20:42] ax.service.ax_client: Completed trial 5 with data: {'result': (115.759888, None)}.
[INFO 01-11 16:20:42] ax.service.ax_client: Completed trial 6 with data: {'result': (115.896449, None)}.
[INFO 01-11 16:20:42] ax.service.ax_client: Completed trial 7 with data: {'result': (115.354542, None)}.
[INFO 01-11 16:20:47] ax.service.ax_client: Generated new trial 8 with parameters {'x': -1.287043, 'y': 3.287043}.


[INFO 01-11 16:20:47] ax.modelbridge.torch: The observations are identical to the last set of observations used to fit the model. Skipping model fitting.
[INFO 01-11 16:20:53] ax.service.ax_client: Generated new trial 9 with parameters {'x': -2.552892, 'y': 0.471419}.


| | trial_index | arm_name | trial_status | generation_method | result | x | y |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 202.032516 | -9.878379 | -2.01497 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 150.50631 | 1.162749 | -8.129749 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 263.386245 | -8.151591 | -7.791025 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 63.933052 | -4.992916 | 3.784711 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 222.231402 | -6.408905 | -7.56304 |
| 5 | 5 | 5_0 | COMPLETED | BoTorch | 115.759888 | -6.685398 | 8.685398 |
| 6 | 6 | 6_0 | COMPLETED | BoTorch | 115.896449 | -6.690148 | 8.690148 |
| 7 | 7 | 7_0 | COMPLETED | BoTorch | 115.354542 | -6.671281 | 8.671281 |
| 8 | 8 | 8_0 | RUNNING | BoTorch | | -1.287043 | 3.287043 |
| 9 | 9 | 9_0 | RUNNING | BoTorch | | -2.552892 | 0.471419 |


[INFO 01-11 16:21:25] ax.service.ax_client: Completed trial 8 with data: {'result': (18.887042, None)}.
[INFO 01-11 16:21:25] ax.service.ax_client: Completed trial 9 with data: {'result': (43.285501, None)}.


| | trial_index | arm_name | trial_status | generation_method | result | x | y |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 202.032516 | -9.878379 | -2.01497 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 150.50631 | 1.162749 | -8.129749 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 263.386245 | -8.151591 | -7.791025 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 63.933052 | -4.992916 | 3.784711 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 222.231402 | -6.408905 | -7.56304 |
| 5 | 5 | 5_0 | COMPLETED | BoTorch | 115.759888 | -6.685398 | 8.685398 |
| 6 | 6 | 6_0 | COMPLETED | BoTorch | 115.896449 | -6.690148 | 8.690148 |
| 7 | 7 | 7_0 | COMPLETED | BoTorch | 115.354542 | -6.671281 | 8.671281 |
| 8 | 8 | 8_0 | COMPLETED | BoTorch | 18.887042 | -1.287043 | 3.287043 |
| 9 | 9 | 9_0 | COMPLETED | BoTorch | 43.285501 | -2.552892 | 0.471419 |



## Retrieving the Best Parameters

We can retrieve the best parameters and render the response surface.

```python
best_parameters, (means, covariances) = ax_client.get_best_parameters()
print(f'Best set of parameters: {best_parameters}')
print(f'Mean objective value: {means}')
# `covariances` holds the model's predicted covariance between metrics;
# with a single objective it only contains the variance of 'result'.

render(ax_client.get_contour_plot())
```


[INFO 01-11 16:21:55] ax.service.ax_client: Retrieving contour plot with parameter 'x' on X-axis and 'y' on Y-axis, for metric 'result'. Remaining parameters are affixed to the middle of their range.


Best set of parameters: {'x': -1.2870425447033522, 'y': 3.2870425447033504}
Mean objective value: {'result': 18.902048490067955}
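The `means` value returned by `get_best_parameters` is the model's prediction at the chosen arm, so it can differ slightly from what was actually observed. The plain-Python sketch below recovers the best observed trial from the final results table (values copied from this run), recomputes its objective from the rounded parameters, and compares it with the model's mean and with the analytic constrained optimum of 12.5 at (0.5, 1.5). With a budget of ten trials, the best trial is still some distance from that optimum.

```python
# Completed results from the final table of this run: trial_index -> (result, x, y).
results = {
    0: (202.032516, -9.878379, -2.01497),
    1: (150.50631, 1.162749, -8.129749),
    2: (263.386245, -8.151591, -7.791025),
    3: (63.933052, -4.992916, 3.784711),
    4: (222.231402, -6.408905, -7.56304),
    5: (115.759888, -6.685398, 8.685398),
    6: (115.896449, -6.690148, 8.690148),
    7: (115.354542, -6.671281, 8.671281),
    8: (18.887042, -1.287043, 3.287043),
    9: (43.285501, -2.552892, 0.471419),
}

best_trial = min(results, key=lambda i: results[i][0])
best_value, best_x, best_y = results[best_trial]
# Recompute the objective from the (rounded) logged parameters.
recomputed = (best_x - 3) ** 2 + (best_y - 4) ** 2

model_mean = 18.902048490067955  # 'result' mean printed by get_best_parameters above
optimum = 12.5  # analytic minimum of (x - 3)^2 + (y - 4)^2 subject to x + y <= 2

print(f"best observed: trial {best_trial} at ({best_x}, {best_y}) -> {best_value}")
print(f"recomputed from rounded parameters: {recomputed:.6f}")
print(f"model mean minus observed: {model_mean - best_value:.4f}")
print(f"gap to constrained optimum: {best_value - optimum:.4f}")
```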
