# Hyperparameter Optimization on Slurm via SubmitIt

This notebook serves as a quickstart guide for using the Ax library with the SubmitIt library in an ask-tell loop. [SubmitIt](https://github.com/facebookincubator/submitit/) is a Python toolbox for submitting jobs to [Slurm](https://slurm.schedmd.com/quickstart.html). 

The notebook demonstrates how to use the Ax client in an ask-tell loop where each trial is scheduled to run on a Slurm cluster asynchronously.

To use this notebook, run it on a Slurm node, either interactively as a notebook or by exporting it as a Python script and submitting that script as a Slurm job.

## Importing Necessary Libraries
Let's start by importing the necessary libraries.

In [1]:
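from ax.service.ax_client import AxClient, ObjectiveProperties
from submitit import AutoExecutor, LocalJob, DebugJob
import time
from IPython.display import display  # Built into notebooks; needed when running as a plain Python script.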

## Defining the Function to Optimize
We'll define a simple function to optimize. This function takes two parameters and returns a single metric.

In [2]:
def evaluate(parameters):
    x = parameters["x"]
    y = parameters["y"]
    return {"result": (x - 3)**2 + (y - 4)**2}
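Ax also accepts a `(mean, sem)` tuple per metric instead of a bare float. As a minimal sketch, assuming you want to tell Ax that this deterministic toy function is noise-free, a variant could report a standard error of `0.0`. When only a float is returned, as in `evaluate` above, the SEM appears as `None` in the logs further down. The name `evaluate_with_sem` is hypothetical.

```python
def evaluate_with_sem(parameters):
    """Hypothetical variant of `evaluate` that also reports a standard error.

    Each metric maps to a (mean, sem) tuple. Because this toy function is
    deterministic, the SEM is reported as 0.0 (noise-free observations).
    """
    x = parameters["x"]
    y = parameters["y"]
    mean = (x - 3) ** 2 + (y - 4) ** 2
    return {"result": (mean, 0.0)}


print(evaluate_with_sem({"x": 3.0, "y": 4.0}))  # {'result': (0.0, 0.0)}
```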

Note 1: SubmitIt's [CommandFunction](https://github.com/facebookincubator/submitit/blob/main/docs/examples.md#working-with-commands) lets you run a shell command on the node and returns its standard output as the job result.
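Since `CommandFunction` returns the command's standard output as a string, the metric still has to be parsed out of it before it can be passed to Ax. The sketch below is a local stand-in that uses `subprocess.run` instead of SubmitIt so that it runs anywhere; the inline "training script", the `result=<float>` output convention, and the helper names `parse_result` and `run_command` are all hypothetical. On a cluster you would instead submit something like `submitit.helpers.CommandFunction([...])` and pass `job.result()` to `parse_result`.

```python
import re
import subprocess
import sys


def parse_result(stdout: str) -> dict:
    """Extract the last 'result=<float>' line from a command's stdout.

    The 'result=<float>' format is a hypothetical convention for this sketch;
    adapt the regex to whatever your training script actually prints.
    """
    matches = re.findall(r"^result=([-+0-9.eE]+)$", stdout, flags=re.MULTILINE)
    if not matches:
        raise ValueError("no 'result=' line found in command output")
    return {"result": float(matches[-1])}


def run_command(parameters: dict) -> dict:
    """Local stand-in for a CommandFunction job: run a command, capture stdout."""
    # Hypothetical "training script" that logs progress and then the metric.
    script = (
        "import sys; x, y = map(float, sys.argv[1:3]); "
        "print('epoch 1 done'); print(f'result={(x - 3) ** 2 + (y - 4) ** 2}')"
    )
    cmd = [sys.executable, "-c", script, str(parameters["x"]), str(parameters["y"])]
    stdout = subprocess.run(cmd, capture_output=True, text=True, check=True).stdout
    return parse_result(stdout)


print(run_command({"x": 0.5, "y": 1.5}))  # {'result': 12.5}
```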

Note 2: If you are using Hydra to manage configs, SubmitIt also has [Hydra integration](https://hydra.cc/docs/plugins/submitit_launcher/).

## Setting up Ax
We'll use Ax's Service API for this example. We start by initializing an `AxClient` and creating an experiment.

In [3]:
ax_client = AxClient()
ax_client.create_experiment(
    name="my_experiment",
    parameters=[
        {"name": "x", "type": "range", "bounds": [-10.0, 10.0]},
        {"name": "y", "type": "range", "bounds": [-10.0, 10.0]},
    ],
    objectives={"result": ObjectiveProperties(minimize=True)},
    parameter_constraints=["x + y <= 2.0"],  # Optional. Similarly, one can also define constraints on the outcome.
)

[INFO 01-10 18:30:07] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 01-10 18:30:07] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 01-10 18:30:07] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter y. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 01-10 18:30:07] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[-10.0, 10.0]), RangeParameter(name='y', parameter_type=FLOAT, range=[-10.0, 10.0])], parameter_constraints=[ParameterConstraint(1
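Because of the constraint, the best achievable value here is not 0. The unconstrained minimizer (3, 4) violates `x + y <= 2.0` (3 + 4 = 7), so the constrained optimum lies on the boundary x + y = 2, where the objective's gradient 2(x - 3, y - 4) must be parallel to the constraint normal (1, 1). That gives x - 3 = y - 4, hence (x, y) = (0.5, 1.5) with an objective value of 12.5. The short check below confirms this with a brute-force grid search; it only derives a reference value for judging the optimizer's results and is not part of the Ax workflow.

```python
def objective(x: float, y: float) -> float:
    # Same function as `evaluate`, written with plain arguments.
    return (x - 3) ** 2 + (y - 4) ** 2


# Closed form: on x + y = 2 with x - 3 = y - 4 (i.e. x = y - 1),
# substituting gives (y - 1) + y = 2, so y = 1.5 and x = 0.5.
y_star = (2.0 + 1.0) / 2.0
x_star = y_star - 1.0
f_star = objective(x_star, y_star)

# Brute-force sanity check on a 0.05-spaced grid over [-10, 10]^2,
# keeping only points that satisfy the constraint x + y <= 2.
grid = [i / 20 for i in range(-200, 201)]
best = min((objective(x, y), x, y) for x in grid for y in grid if x + y <= 2.0)

print(x_star, y_star, f_star)  # 0.5 1.5 12.5
print(best)                    # (12.5, 0.5, 1.5)
```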

Other commonly used [parameter types](https://ax.dev/docs/glossary.html#parameter) include `choice` and `fixed` parameters. Tip: you can give Ax additional information about a parameter, such as `log_scale` for a range parameter that operates on a log scale, or `is_ordered` for a choice parameter whose values have a meaningful ordering.
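For illustration, a parameter list mixing these types might look like the following. The parameter names and values (`lr`, `num_layers`, `optimizer`, `batch_size`) are hypothetical; the dictionary keys follow the same Service API format as the `x`/`y` range parameters above, and the list would be passed as `parameters=` to `ax_client.create_experiment`.

```python
# Hypothetical search space mixing Ax parameter types.
# Usage: ax_client.create_experiment(name=..., parameters=parameters, ...)
parameters = [
    # Range parameter searched on a log scale (common for learning rates).
    {"name": "lr", "type": "range", "bounds": [1e-5, 1e-1],
     "value_type": "float", "log_scale": True},
    # Choice parameter whose values have a meaningful order.
    {"name": "num_layers", "type": "choice", "values": [2, 4, 8],
     "value_type": "int", "is_ordered": True},
    # Unordered choice parameter.
    {"name": "optimizer", "type": "choice", "values": ["adam", "sgd"],
     "value_type": "str", "is_ordered": False},
    # Fixed parameter: every trial receives the same value.
    {"name": "batch_size", "type": "fixed", "value": 32, "value_type": "int"},
]

for p in parameters:
    print(p["name"], p["type"])
```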

Advanced tip: Ax is an excellent choice for multi-objective optimization problems, where there are multiple competing objectives and the goal is to approximate the set of Pareto-optimal solutions, i.e., solutions for which no objective can be improved without worsening another.
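To make "Pareto-optimal" concrete, here is a brute-force sketch of the definition for minimized objectives. It is only an illustration of the concept on hypothetical objective values, not how Ax computes or models the Pareto frontier internally.

```python
from typing import List, Sequence, Tuple


def dominates(a: Sequence[float], b: Sequence[float]) -> bool:
    """Return True if `a` Pareto-dominates `b` when every objective is minimized:
    `a` is no worse in all objectives and strictly better in at least one."""
    no_worse = all(ai <= bi for ai, bi in zip(a, b))
    strictly_better = any(ai < bi for ai, bi in zip(a, b))
    return no_worse and strictly_better


def pareto_front(points: List[Tuple[float, ...]]) -> List[Tuple[float, ...]]:
    """Brute-force O(n^2) filter that keeps only the non-dominated points."""
    return [p for p in points if not any(dominates(q, p) for q in points)]


# Hypothetical (objective_1, objective_2) values for four trials.
observed = [(1.0, 5.0), (2.0, 3.0), (3.0, 4.0), (4.0, 1.0)]
print(pareto_front(observed))  # [(1.0, 5.0), (2.0, 3.0), (4.0, 1.0)]
```

Here `(3.0, 4.0)` is dropped because `(2.0, 3.0)` is better in both objectives; the remaining three points each trade one objective against the other.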

## Setting up SubmitIt
We'll use SubmitIt's `AutoExecutor` for this example. We start by initializing an `AutoExecutor` and setting a few commonly used parameters. The full list of parameters is available in the SubmitIt documentation.

In [4]:
# Log folder and cluster. Specify cluster='local' or cluster='debug' to run the jobs locally during development.
executor = AutoExecutor(folder="submitit_runs", cluster='slurm')
executor.update_parameters(timeout_min=60)  # Timeout of the Slurm job in minutes, not including time spent waiting in the Slurm queue.
executor.update_parameters(cpus_per_task=2)

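Other commonly used Slurm parameters include `partition`, `ntasks_per_node`, `cpus_per_gpu`, `gpus_per_node`, `gpus_per_task`, `qos`, `mem`, `mem_per_gpu`, `mem_per_cpu`, and `account`. When set through `AutoExecutor`, Slurm-specific parameters are passed with a `slurm_` prefix (for example, `executor.update_parameters(slurm_partition="dev")`), while cluster-agnostic parameters such as `timeout_min` and `cpus_per_task` are passed without a prefix.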

## Running the Optimization Loop
Now we're ready to run the optimization loop. We'll use an ask-tell loop, where we ask Ax for a suggestion, evaluate it using our function, and then tell Ax the result.

The example loop schedules new jobs whenever a slot becomes available. For tasks that take roughly the same amount of time regardless of the parameters, it may make more sense to wait for the whole batch to finish before scheduling the next one, so that Ax can make better-informed parameter choices.
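A batch-synchronous variant can be sketched as follows. To keep it runnable without Ax or a cluster, this sketch swaps in plain random search (the hypothetical `suggest` helper) for Ax's `get_next_trials`, and `concurrent.futures.ThreadPoolExecutor` for SubmitIt's executor; the structure (ask for a batch, launch it, wait for all of it, then tell) is what carries over. With the real libraries, you would call `ax_client.get_next_trials(max_trials=n)`, submit each trial with `executor.submit`, block on every `job.result()`, and then call `ax_client.complete_trial` for each one.

```python
import random
from concurrent.futures import ThreadPoolExecutor, wait


def evaluate(parameters):
    x, y = parameters["x"], parameters["y"]
    return {"result": (x - 3) ** 2 + (y - 4) ** 2}


def suggest(rng):
    """Hypothetical stand-in for Ax: uniform random search over [-10, 10]^2,
    rejecting points that violate the constraint x + y <= 2."""
    while True:
        x, y = rng.uniform(-10, 10), rng.uniform(-10, 10)
        if x + y <= 2.0:
            return {"x": x, "y": y}


def run_batched(total_budget=10, batch_size=3, seed=0):
    rng = random.Random(seed)
    completed = []  # (parameters, raw_data) pairs, i.e. what you would "tell" Ax
    with ThreadPoolExecutor(max_workers=batch_size) as pool:
        while len(completed) < total_budget:
            n = min(batch_size, total_budget - len(completed))
            batch = [suggest(rng) for _ in range(n)]             # ask
            futures = [pool.submit(evaluate, p) for p in batch]  # launch
            wait(futures)  # block until the whole batch has finished
            completed.extend((p, f.result()) for p, f in zip(batch, futures))  # tell
    return completed


results = run_batched()
print(min(r["result"] for _, r in results))
```

The trade-off is idle capacity: while the slowest job in a batch is still running, the other slots sit empty, which is why the asynchronous loop below is the better default when run times vary.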

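Note that `get_next_trials` may return fewer trials than the number of free `num_parallel_jobs` slots, for example when the generation strategy needs more completed trials before it can move on to its next step. In the output below, the second round generates only trials 3 and 4, the last two trials of the initial Sobol step.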

In [5]:
total_budget = 10
num_parallel_jobs = 3

jobs = []
submitted_jobs = 0
# Run until all the jobs have finished and our budget is used up.
while submitted_jobs < total_budget or jobs:
    for job, trial_index in jobs[:]:
        # Check if any jobs completed.
        # Local and debug jobs don't run until .result() is called.
        if job.done() or isinstance(job, (LocalJob, DebugJob)):
            try:
                result = job.result()
            except Exception as e:
                # A crashed job raises here; mark the trial as failed instead of stopping the loop.
                print(f"Trial {trial_index} failed: {e}")
                ax_client.log_trial_failure(trial_index=trial_index)
            else:
                ax_client.complete_trial(trial_index=trial_index, raw_data=result)
            jobs.remove((job, trial_index))

    # Schedule new jobs if there is availability.
    max_trials = min(num_parallel_jobs - len(jobs), total_budget - submitted_jobs)
    if max_trials > 0:
        trial_index_to_param, _ = ax_client.get_next_trials(max_trials=max_trials)
        for trial_index, parameters in trial_index_to_param.items():
            job = executor.submit(evaluate, parameters)
            submitted_jobs += 1
            jobs.append((job, trial_index))
            time.sleep(1)

    # Display the current trials.
    display(ax_client.generation_strategy.trials_as_df)

    # Sleep for a bit before checking the jobs again to avoid overloading the cluster.
    # If you have a large number of jobs, consider adding a sleep statement in the job polling loop as well.
    time.sleep(30)

[INFO 01-10 18:30:25] ax.service.ax_client: Generated new trial 0 with parameters {'x': -9.681339, 'y': 6.24629}.
[INFO 01-10 18:30:25] ax.service.ax_client: Generated new trial 1 with parameters {'x': -5.6056, 'y': -3.068945}.
[INFO 01-10 18:30:25] ax.service.ax_client: Generated new trial 2 with parameters {'x': -7.548772, 'y': -3.560617}.
[INFO 01-10 18:30:26] ax.modelbridge.generation_strategy: Note that parameter values in dataframe are rounded to 2 decimal points; the values in the dataframe are thus not the exact ones suggested by Ax in trials.


| | Generation Step | Generation Model | Trial Index | Trial Status | Arm Parameterizations |
|---|---|---|---|---|---|
| 0 | GenerationStep_0 | Sobol | 0 | RUNNING | {'0_0': {'x': -9.68, 'y': 6.25}} |
| 1 | GenerationStep_0 | Sobol | 1 | RUNNING | {'1_0': {'x': -5.61, 'y': -3.07}} |
| 2 | GenerationStep_0 | Sobol | 2 | RUNNING | {'2_0': {'x': -7.55, 'y': -3.56}} |


[INFO 01-10 18:30:56] ax.service.ax_client: Completed trial 0 with data: {'result': (165.862173, None)}.
[INFO 01-10 18:30:56] ax.service.ax_client: Completed trial 1 with data: {'result': (124.026332, None)}.
[INFO 01-10 18:30:56] ax.service.ax_client: Completed trial 2 with data: {'result': (168.439503, None)}.
[INFO 01-10 18:30:56] ax.service.ax_client: Generated new trial 3 with parameters {'x': 6.138409, 'y': -6.487855}.
[INFO 01-10 18:30:56] ax.service.ax_client: Generated new trial 4 with parameters {'x': -7.564852, 'y': -7.215186}.
[INFO 01-10 18:30:56] ax.modelbridge.generation_strategy: Note that parameter values in dataframe are rounded to 2 decimal points; the values in the dataframe are thus not the exact ones suggested by Ax in trials.


| | Generation Step | Generation Model | Trial Index | Trial Status | Arm Parameterizations |
|---|---|---|---|---|---|
| 0 | GenerationStep_0 | Sobol | 0 | COMPLETED | {'0_0': {'x': -9.68, 'y': 6.25}} |
| 1 | GenerationStep_0 | Sobol | 1 | COMPLETED | {'1_0': {'x': -5.61, 'y': -3.07}} |
| 2 | GenerationStep_0 | Sobol | 2 | COMPLETED | {'2_0': {'x': -7.55, 'y': -3.56}} |
| 3 | GenerationStep_0 | Sobol | 3 | RUNNING | {'3_0': {'x': 6.14, 'y': -6.49}} |
| 4 | GenerationStep_0 | Sobol | 4 | RUNNING | {'4_0': {'x': -7.56, 'y': -7.22}} |


[INFO 01-10 18:31:26] ax.service.ax_client: Completed trial 3 with data: {'result': (119.844715, None)}.
[INFO 01-10 18:31:26] ax.service.ax_client: Completed trial 4 with data: {'result': (237.396504, None)}.
[INFO 01-10 18:31:33] ax.service.ax_client: Generated new trial 5 with parameters {'x': -0.610494, 'y': -0.808445}.
[INFO 01-10 18:31:33] ax.modelbridge.torch: The observations are identical to the last set of observations used to fit the model. Skipping model fitting.
[INFO 01-10 18:31:38] ax.service.ax_client: Generated new trial 6 with parameters {'x': 10.0, 'y': -10.0}.
[INFO 01-10 18:31:38] ax.modelbridge.torch: The observations are identical to the last set of observations used to fit the model. Skipping model fitting.
[INFO 01-10 18:31:44] ax.service.ax_client: Generated new trial 7 with parameters {'x': 4.48408, 'y': -2.48408}.
[INFO 01-10 18:31:45] ax.modelbridge.generat

| | Generation Step | Generation Model | Trial Index | Trial Status | Arm Parameterizations |
|---|---|---|---|---|---|
| 0 | GenerationStep_0 | Sobol | 0 | COMPLETED | {'0_0': {'x': -9.68, 'y': 6.25}} |
| 1 | GenerationStep_0 | Sobol | 1 | COMPLETED | {'1_0': {'x': -5.61, 'y': -3.07}} |
| 2 | GenerationStep_0 | Sobol | 2 | COMPLETED | {'2_0': {'x': -7.55, 'y': -3.56}} |
| 3 | GenerationStep_0 | Sobol | 3 | COMPLETED | {'3_0': {'x': 6.14, 'y': -6.49}} |
| 4 | GenerationStep_0 | Sobol | 4 | COMPLETED | {'4_0': {'x': -7.56, 'y': -7.22}} |
| 5 | GenerationStep_1 | BoTorch | 5 | RUNNING | {'5_0': {'x': -0.61, 'y': -0.81}} |
| 6 | GenerationStep_1 | BoTorch | 6 | RUNNING | {'6_0': {'x': 10.0, 'y': -10.0}} |
| 7 | GenerationStep_1 | BoTorch | 7 | RUNNING | {'7_0': {'x': 4.48, 'y': -2.48}} |


[INFO 01-10 18:32:15] ax.service.ax_client: Completed trial 5 with data: {'result': (36.156811, None)}.
[INFO 01-10 18:32:15] ax.service.ax_client: Completed trial 6 with data: {'result': (245.0, None)}.
[INFO 01-10 18:32:15] ax.service.ax_client: Completed trial 7 with data: {'result': (44.245782, None)}.
[INFO 01-10 18:32:21] ax.service.ax_client: Generated new trial 8 with parameters {'x': 1.757354, 'y': 0.242646}.
[INFO 01-10 18:32:21] ax.modelbridge.torch: The observations are identical to the last set of observations used to fit the model. Skipping model fitting.
[INFO 01-10 18:32:27] ax.service.ax_client: Generated new trial 9 with parameters {'x': 1.364128, 'y': -2.222973}.
[INFO 01-10 18:32:27] ax.modelbridge.generation_strategy: Note that parameter values in dataframe are rounded to 2 decimal points; the values in the dataframe are thus not the exact ones suggested by Ax in trials.


| | Generation Step | Generation Model | Trial Index | Trial Status | Arm Parameterizations |
|---|---|---|---|---|---|
| 0 | GenerationStep_0 | Sobol | 0 | COMPLETED | {'0_0': {'x': -9.68, 'y': 6.25}} |
| 1 | GenerationStep_0 | Sobol | 1 | COMPLETED | {'1_0': {'x': -5.61, 'y': -3.07}} |
| 2 | GenerationStep_0 | Sobol | 2 | COMPLETED | {'2_0': {'x': -7.55, 'y': -3.56}} |
| 3 | GenerationStep_0 | Sobol | 3 | COMPLETED | {'3_0': {'x': 6.14, 'y': -6.49}} |
| 4 | GenerationStep_0 | Sobol | 4 | COMPLETED | {'4_0': {'x': -7.56, 'y': -7.22}} |
| 5 | GenerationStep_1 | BoTorch | 5 | COMPLETED | {'5_0': {'x': -0.61, 'y': -0.81}} |
| 6 | GenerationStep_1 | BoTorch | 6 | COMPLETED | {'6_0': {'x': 10.0, 'y': -10.0}} |
| 7 | GenerationStep_1 | BoTorch | 7 | COMPLETED | {'7_0': {'x': 4.48, 'y': -2.48}} |
| 8 | GenerationStep_1 | BoTorch | 8 | RUNNING | {'8_0': {'x': 1.76, 'y': 0.24}} |
| 9 | GenerationStep_1 | BoTorch | 9 | RUNNING | {'9_0': {'x': 1.36, 'y': -2.22}} |


[INFO 01-10 18:32:57] ax.service.ax_client: Completed trial 8 with data: {'result': (15.661878, None)}.
[INFO 01-10 18:32:57] ax.service.ax_client: Completed trial 9 with data: {'result': (41.401477, None)}.
[INFO 01-10 18:32:57] ax.modelbridge.generation_strategy: Note that parameter values in dataframe are rounded to 2 decimal points; the values in the dataframe are thus not the exact ones suggested by Ax in trials.


| | Generation Step | Generation Model | Trial Index | Trial Status | Arm Parameterizations |
|---|---|---|---|---|---|
| 0 | GenerationStep_0 | Sobol | 0 | COMPLETED | {'0_0': {'x': -9.68, 'y': 6.25}} |
| 1 | GenerationStep_0 | Sobol | 1 | COMPLETED | {'1_0': {'x': -5.61, 'y': -3.07}} |
| 2 | GenerationStep_0 | Sobol | 2 | COMPLETED | {'2_0': {'x': -7.55, 'y': -3.56}} |
| 3 | GenerationStep_0 | Sobol | 3 | COMPLETED | {'3_0': {'x': 6.14, 'y': -6.49}} |
| 4 | GenerationStep_0 | Sobol | 4 | COMPLETED | {'4_0': {'x': -7.56, 'y': -7.22}} |
| 5 | GenerationStep_1 | BoTorch | 5 | COMPLETED | {'5_0': {'x': -0.61, 'y': -0.81}} |
| 6 | GenerationStep_1 | BoTorch | 6 | COMPLETED | {'6_0': {'x': 10.0, 'y': -10.0}} |
| 7 | GenerationStep_1 | BoTorch | 7 | COMPLETED | {'7_0': {'x': 4.48, 'y': -2.48}} |
| 8 | GenerationStep_1 | BoTorch | 8 | COMPLETED | {'8_0': {'x': 1.76, 'y': 0.24}} |
| 9 | GenerationStep_1 | BoTorch | 9 | COMPLETED | {'9_0': {'x': 1.36, 'y': -2.22}} |



## Retrieving the Best Parameters
Finally, we can retrieve the best parameters found by Ax.

In [6]:
best_parameters, (means, covariances) = ax_client.get_best_parameters()
print(f'Best set of parameters: {best_parameters}')
print(f'Mean objective value: {means}')
print(f'Covariance between objectives: {covariances}')



Best set of parameters: {'x': 1.7573539166976264, 'y': 0.2426460833022439}
Mean objective value: {'result': 15.661877743670697}
Covariance between objectives: {'result': {'result': nan}}
