## Using a custom BoTorch model with Ax

In this tutorial, we illustrate how to use a custom BoTorch model within Ax's `botorch_modular` API. This allows us to harness the convenience of Ax for running Bayesian Optimization loops, while at the same time maintaining full flexibility in terms of the modeling.

Acquisition functions and strategies for optimizing acquisitions can be swapped out in much the same fashion. See for example the tutorial for [Implementing a custom acquisition function](./custom_acquisition).

If you want to do something non-standard, or would like to have full insight into every aspect of the implementation, please see [this tutorial](./closed_loop_botorch_only) for how to write your own full optimization loop in BoTorch.

The next cell sets up a context manager solely to speed up testing of the notebook. You can safely ignore this cell and the use of the context manager throughout the tutorial.

In [1]:
import os
from contextlib import contextmanager

from ax.utils.testing.mock import fast_botorch_optimize_context_manager


SMOKE_TEST = os.environ.get("SMOKE_TEST")


@contextmanager
def dummy_context_manager():
    yield


if SMOKE_TEST:
    fast_smoke_test = fast_botorch_optimize_context_manager
else:
    fast_smoke_test = dummy_context_manager

### Implementing the custom model

For this tutorial, we implement a very simple GPyTorch exact GP model that uses an RBF kernel (with ARD) and infers a (homoskedastic) noise level.

The model definition is straightforward: we implement a GPyTorch `ExactGP` that also inherits from `GPyTorchModel`, which adds all the API calls that BoTorch expects in its various modules.

*Note:* BoTorch also allows implementing other custom models, as long as they follow the minimal `Model` API. For more information, please see the [Model Documentation](../docs/models).

In [2]:
from botorch.models.gpytorch import GPyTorchModel
from botorch.utils.containers import TrainingData
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP


class SimpleCustomGP(ExactGP, GPyTorchModel):

    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y):
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(
            base_kernel=RBFKernel(ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    @classmethod
    def construct_inputs(cls, training_data: TrainingData, **kwargs):
        r"""Construct kwargs for the `SimpleCustomGP` from `TrainingData` and other options.

        Args:
            training_data: `TrainingData` container with data for single outcome
                or for multiple outcomes for batched multi-output case.
            **kwargs: None expected for this class.
        """
        return {"train_X": training_data.X, "train_Y": training_data.Y}
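To build intuition for what `SimpleCustomGP` computes, here is a minimal NumPy sketch of the exact GP posterior for the same ingredients: a constant mean, a scaled RBF kernel with one lengthscale per input dimension (ARD), and a single homoskedastic noise variance. It is an illustration under stated assumptions, not BoTorch's code: the hyperparameters are fixed by hand, whereas BoTorch fits them by maximizing the marginal log likelihood, and `scaled_rbf_ard` and `gp_posterior` are hypothetical helper names.

```python
import numpy as np


def scaled_rbf_ard(a, b, lengthscales, outputscale):
    """outputscale * exp(-0.5 * ||(a - b) / lengthscales||^2) for all pairs of rows."""
    diff = (a[:, None, :] - b[None, :, :]) / lengthscales  # (n, m, d): one lengthscale per dim
    return outputscale * np.exp(-0.5 * np.sum(diff**2, axis=-1))


def gp_posterior(train_x, train_y, test_x, constant_mean, lengthscales, outputscale, noise):
    """Posterior mean and variance of the latent function at test_x (fixed hyperparameters)."""
    k_train = scaled_rbf_ard(train_x, train_x, lengthscales, outputscale)
    k_train = k_train + noise * np.eye(len(train_x))  # same noise variance for every observation
    k_cross = scaled_rbf_ard(train_x, test_x, lengthscales, outputscale)  # (n, m)
    chol = np.linalg.cholesky(k_train)
    # alpha = (K + noise * I)^{-1} (y - mean), computed through the Cholesky factor
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, train_y - constant_mean))
    mean = constant_mean + k_cross.T @ alpha
    v = np.linalg.solve(chol, k_cross)
    var = outputscale - np.sum(v**2, axis=0)  # prior variance minus the explained part
    return mean, var


# Toy data: a 3x3 grid on the unit square with a smooth objective.
grid = np.linspace(0.0, 1.0, 3)
train_x = np.array([[a, b] for a in grid for b in grid])
train_y = np.sin(3.0 * train_x[:, 0]) + train_x[:, 1]
hypers = dict(
    constant_mean=train_y.mean(), lengthscales=np.array([0.3, 0.5]), outputscale=1.0, noise=1e-6
)

mean, var = gp_posterior(train_x, train_y, train_x, **hypers)
far_mean, far_var = gp_posterior(train_x, train_y, np.array([[10.0, 10.0]]), **hypers)
```

With a tiny noise level the posterior mean nearly interpolates the training targets, while far away from the data it reverts to the constant mean with a variance equal to the output scale.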

### Instantiate a `BoTorchModel` in Ax

A `BoTorchModel` in Ax encapsulates both the surrogate (commonly referred to as `Model` in BoTorch) and an acquisition function. Here, we will only specify the custom surrogate and let Ax choose the default acquisition function.

Most models should work with the base `Surrogate` in Ax, except for BoTorch `ModelListGP`, which works with `ListSurrogate`.
Note that the `Model` (e.g., the `SimpleCustomGP`) must implement `construct_inputs`, as this is used to construct the inputs required for instantiating a `Model` instance from the experiment data.

In case the `Model` requires a complex set of arguments that cannot be constructed using a `construct_inputs` method, one can initialize the `model` and supply it via `Surrogate.from_botorch(model=model, mll_class=<Optional>)`, replacing the `Surrogate(...)` below.
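To make the role of `construct_inputs` concrete, here is a pure-Python sketch of the pattern: a classmethod on the model class turns a generic data container into constructor keyword arguments, and the model is then instantiated from them. `ToyTrainingData`, `ToyModel`, and `build_model` are hypothetical stand-ins; the real `Surrogate` does more work (for example, device and dtype handling), and forwarding `model_options` through `construct_inputs` is an assumption of this sketch, so treat it as a simplified approximation.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ToyTrainingData:
    """Hypothetical stand-in for a training-data container holding inputs X and outcomes Y."""
    X: List[List[float]]
    Y: List[List[float]]


class ToyModel:
    """Hypothetical model whose constructor needs train_X and train_Y, like SimpleCustomGP."""

    def __init__(self, train_X, train_Y, jitter: float = 1e-6):
        self.train_X = train_X
        self.train_Y = train_Y
        self.jitter = jitter

    @classmethod
    def construct_inputs(cls, training_data: ToyTrainingData, **kwargs: Any) -> Dict[str, Any]:
        # Map the generic data container (plus any extra options) onto constructor arguments.
        return {"train_X": training_data.X, "train_Y": training_data.Y, **kwargs}


def build_model(model_class, training_data, model_options=None):
    """Simplified version of what a surrogate does with a model class and its options."""
    inputs = model_class.construct_inputs(training_data, **(model_options or {}))
    return model_class(**inputs)


data = ToyTrainingData(X=[[0.1, 0.2], [0.5, 0.9]], Y=[[1.0], [2.0]])
model = build_model(ToyModel, data, model_options={"jitter": 1e-4})
```

The `Surrogate.from_botorch(...)` route described above skips this step entirely: you build the model yourself and hand over the finished instance.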

In [3]:
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate


ax_model = BoTorchModel(
    surrogate=Surrogate(
        # The model class to use
        botorch_model_class=SimpleCustomGP,
        # Optional, MLL class with which to optimize model parameters
        # mll_class=ExactMarginalLogLikelihood,
        # Optional, dictionary of keyword arguments to model constructor
        # model_options={}
    ),
    # Optional, acquisition function class to use - see custom acquisition tutorial
    # botorch_acqf_class=qExpectedImprovement,
)

### Combine with a `ModelBridge`

`Model`s in Ax require a `ModelBridge` to interface with `Experiment`s. A `ModelBridge` takes the inputs supplied by the `Experiment` and converts them to the inputs expected by the `Model`. For a `BoTorchModel`, we use `TorchModelBridge`. The usage is as follows:

```
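from ax.modelbridge import TorchModelBridge

model_bridge = TorchModelBridge(
    experiment=experiment,  # an `Experiment`
    search_space=search_space,  # a `SearchSpace`
    data=data,  # a `Data` object with the observations so far
    model=ax_model,  # a `TorchModel`, e.g., the `BoTorchModel` defined above
    transforms=transforms,  # a `List[Type[Transform]]`
    # ...plus additional optional arguments.
)
# Generate a `GeneratorRun` with one candidate; attach it to the experiment via `new_trial`.
generator_run = model_bridge.gen(1)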
```

For the modular BoTorch interface, we can combine the creation of the `BoTorchModel` and the `TorchModelBridge` into a single step as follows:

```
from ax.modelbridge.registry import Models

model_bridge = Models.BOTORCH_MODULAR(
    experiment=experiment,
    data=data,
    surrogate=Surrogate(SimpleCustomGP),  # Optional, will use default if unspecified
    # Optional, will use default if unspecified
    # botorch_acqf_class=qNoisyExpectedImprovement,
)
# Generate a `GeneratorRun` with one candidate
generator_run = model_bridge.gen(1)
```
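A model bridge's job of converting experiment inputs into model inputs can be made concrete with two typical conversions: mapping range parameters to the unit cube and standardizing outcomes. Ax implements conversions of this kind as transforms (e.g., `UnitX` and `StandardizeY`); which transforms actually run depends on the registry entry and the Ax version. The pure-Python sketch below is illustrative only, and its helper names are hypothetical.

```python
from statistics import mean, pstdev
from typing import Dict, List, Tuple

# Bounds of the Branin search space used later in this tutorial.
BOUNDS: Dict[str, Tuple[float, float]] = {"x1": (-5.0, 10.0), "x2": (0.0, 15.0)}


def to_unit_cube(params: Dict[str, float], bounds=BOUNDS) -> List[float]:
    """Experiment parameters (a dict) -> model inputs (a list in [0, 1]^d, fixed order)."""
    return [(params[name] - lo) / (hi - lo) for name, (lo, hi) in bounds.items()]


def from_unit_cube(z: List[float], bounds=BOUNDS) -> Dict[str, float]:
    """Model candidate (a point in [0, 1]^d) -> parameters the experiment understands."""
    return {name: lo + zi * (hi - lo) for zi, (name, (lo, hi)) in zip(z, bounds.items())}


def standardize(ys: List[float]) -> Tuple[List[float], Tuple[float, float]]:
    """Shift and scale outcomes to mean 0 and (population) standard deviation 1."""
    mu, sd = mean(ys), pstdev(ys)
    return [(y - mu) / sd for y in ys], (mu, sd)


z = to_unit_cube({"x1": 1.053491, "x2": 9.695006})  # trial 0 of the Service API run below
y_std, (mu, sd) = standardize([42.0839, 6.161984, 4.775831, 20.177181, 3.536641])
```

Keeping `(mu, sd)` around is what lets predictions be mapped back to the original outcome scale.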


# Using the custom model in Ax to optimize the Branin function

We will demonstrate this with both the Service API (simpler, easier to use) and the Developer API (advanced, more customizable).

## Optimization with the Service API

A detailed tutorial on the Service API can be found [here](https://ax.dev/tutorials/gpei_hartmann_service.html).

To customize how candidates are created in the Service API, we need to construct a new `GenerationStrategy` and pass it into `AxClient`.

In [4]:
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models


gs = GenerationStrategy(
    steps=[
        # Quasi-random initialization step
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
        ),
        # Bayesian optimization step using the custom surrogate model
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,  # No limitation on how many trials should be produced from this step
            # For `BOTORCH_MODULAR`, we pass in kwargs to specify what surrogate or acquisition function to use.
            model_kwargs={
                "surrogate": Surrogate(SimpleCustomGP),
            },
        ),
    ]
)
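The two steps above hand off by trial count: the first 5 trials come from Sobol and every later trial from `BOTORCH_MODULAR`, which is what the `generation_method` column of the trials table further down shows. The sketch below models only that count-based hand-off; `ToyStep` and `model_for_trial` are hypothetical, and a real `GenerationStrategy` can also gate transitions on other conditions (for example, how many trials have been observed).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyStep:
    """Hypothetical mirror of a GenerationStep: a model name and how many trials it produces."""
    model: str
    num_trials: int  # -1 means "no limit", as in the cell above


def model_for_trial(steps: List[ToyStep], trial_index: int) -> str:
    """Return the model a purely count-based strategy would use for a 0-based trial index."""
    start = 0
    for step in steps:
        if step.num_trials == -1 or trial_index < start + step.num_trials:
            return step.model
        start += step.num_trials
    raise ValueError("all generation steps are exhausted")


steps = [ToyStep("SOBOL", 5), ToyStep("BOTORCH_MODULAR", -1)]
schedule = [model_for_trial(steps, i) for i in range(7)]
```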

### Setting up the experiment

In order to use the `GenerationStrategy` we just created, we will pass it into the `AxClient`.

In [5]:
import torch
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin


# Initialize the client - AxClient offers a convenient API to control the experiment
ax_client = AxClient(generation_strategy=gs)
# Setup the experiment
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            # It is crucial to use floats for the bounds, i.e., 0.0 rather than 0.
            # Otherwise, the parameter would be inferred as an integer range.
            "bounds": [-5.0, 10.0],
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 15.0],
        },
    ],
    objectives={
        "branin": ObjectiveProperties(minimize=True),
    },
)
# Setup a function to evaluate the trials
branin = Branin()


def evaluate(parameters):
    x = torch.tensor([[parameters.get(f"x{i+1}") for i in range(2)]])
    # In our case, standard error is 0, since we are computing a synthetic function.
    # Our custom model does not utilize the SEM, so this (passing 0) has no effect.
    return {"branin": (branin(x).item(), 0.0)}

[INFO 03-08 21:09:59] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.


[INFO 03-08 21:09:59] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.


[INFO 03-08 21:09:59] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.


[INFO 03-08 21:09:59] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
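The `evaluate` function above delegates to BoTorch's `Branin` test function. For sanity-checking logged values without PyTorch, here is the standard Branin formula in pure Python. It assumes BoTorch uses the textbook constants (a = 1, b = 5.1/(4π²), c = 5/π, r = 6, s = 10, t = 1/(8π)), and `branin` and `evaluate_pure` are hypothetical helpers. The global minimum of about 0.397887, attained for example at (π, 2.275), is the `objective_optimum` used in the plots later on.

```python
import math


def branin(x1: float, x2: float) -> float:
    """Standard two-dimensional Branin function on x1 in [-5, 10], x2 in [0, 15]."""
    a, b, c = 1.0, 5.1 / (4 * math.pi**2), 5.0 / math.pi
    r, s, t = 6.0, 10.0, 1.0 / (8 * math.pi)
    return a * (x2 - b * x1**2 + c * x1 - r) ** 2 + s * (1 - t) * math.cos(x1) + s


def evaluate_pure(parameters):
    """Pure-Python analogue of `evaluate`: maps metric name to (mean, SEM)."""
    return {"branin": (branin(parameters["x1"], parameters["x2"]), 0.0)}


first = evaluate_pure({"x1": 1.053491, "x2": 9.695006})  # trial 0 in the log below
```

Small differences in the last digits relative to the log are expected, because the tutorial's `torch.tensor` inputs default to single precision.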


### Running the BO loop

In [6]:
with fast_smoke_test():
    for i in range(30):
        parameters, trial_index = ax_client.get_next_trial()
        # Local evaluation here can be replaced with deployment to external system.
        ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))

[INFO 03-08 21:09:59] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 1.053491, 'x2': 9.695006}.


[INFO 03-08 21:09:59] ax.service.ax_client: Completed trial 0 with data: {'branin': (42.0839, 0.0)}.


[INFO 03-08 21:09:59] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 9.673876, 'x2': 5.031455}.


[INFO 03-08 21:09:59] ax.service.ax_client: Completed trial 1 with data: {'branin': (6.161984, 0.0)}.


[INFO 03-08 21:09:59] ax.service.ax_client: Generated new trial 2 with parameters {'x1': -2.698363, 'x2': 13.092639}.


[INFO 03-08 21:09:59] ax.service.ax_client: Completed trial 2 with data: {'branin': (4.775831, 0.0)}.


[INFO 03-08 21:09:59] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 4.255731, 'x2': 5.362701}.


[INFO 03-08 21:09:59] ax.service.ax_client: Completed trial 3 with data: {'branin': (20.177181, 0.0)}.


[INFO 03-08 21:09:59] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 9.896219, 'x2': 4.347518}.


[INFO 03-08 21:09:59] ax.service.ax_client: Completed trial 4 with data: {'branin': (3.536641, 0.0)}.


[INFO 03-08 21:10:00] ax.service.ax_client: Generated new trial 5 with parameters {'x1': -3.947539, 'x2': 12.530772}.


[INFO 03-08 21:10:00] ax.service.ax_client: Completed trial 5 with data: {'branin': (6.466525, 0.0)}.


[INFO 03-08 21:10:02] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 10.0, 'x2': 1.062917}.


[INFO 03-08 21:10:02] ax.service.ax_client: Completed trial 6 with data: {'branin': (5.706899, 0.0)}.


[INFO 03-08 21:10:03] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 7.218442, 'x2': 0.418715}.


[INFO 03-08 21:10:03] ax.service.ax_client: Completed trial 7 with data: {'branin': (16.378994, 0.0)}.


[INFO 03-08 21:10:04] ax.service.ax_client: Generated new trial 8 with parameters {'x1': -3.19908, 'x2': 0.113994}.


[INFO 03-08 21:10:04] ax.service.ax_client: Completed trial 8 with data: {'branin': (151.693604, 0.0)}.


[INFO 03-08 21:10:06] ax.service.ax_client: Generated new trial 9 with parameters {'x1': -5.0, 'x2': 15.0}.


[INFO 03-08 21:10:06] ax.service.ax_client: Completed trial 9 with data: {'branin': (17.508297, 0.0)}.


[INFO 03-08 21:10:08] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 8.022765, 'x2': 3.542146}.


[INFO 03-08 21:10:08] ax.service.ax_client: Completed trial 10 with data: {'branin': (12.370395, 0.0)}.


[INFO 03-08 21:10:09] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 5.533142, 'x2': 15.0}.


[INFO 03-08 21:10:09] ax.service.ax_client: Completed trial 11 with data: {'branin': (208.881226, 0.0)}.


[INFO 03-08 21:10:10] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 2.911042, 'x2': 2.161366}.


[INFO 03-08 21:10:10] ax.service.ax_client: Completed trial 12 with data: {'branin': (0.74213, 0.0)}.


[INFO 03-08 21:10:11] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 3.596695, 'x2': 0.0}.


[INFO 03-08 21:10:11] ax.service.ax_client: Completed trial 13 with data: {'branin': (5.165417, 0.0)}.


[INFO 03-08 21:10:13] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 3.632046, 'x2': 2.031652}.


[INFO 03-08 21:10:13] ax.service.ax_client: Completed trial 14 with data: {'branin': (1.541467, 0.0)}.


[INFO 03-08 21:10:16] ax.service.ax_client: Generated new trial 15 with parameters {'x1': -2.117163, 'x2': 9.798724}.


[INFO 03-08 21:10:16] ax.service.ax_client: Completed trial 15 with data: {'branin': (5.033342, 0.0)}.


[INFO 03-08 21:10:20] ax.service.ax_client: Generated new trial 16 with parameters {'x1': 10.0, 'x2': 3.305778}.


[INFO 03-08 21:10:20] ax.service.ax_client: Completed trial 16 with data: {'branin': (2.034841, 0.0)}.


[INFO 03-08 21:10:22] ax.service.ax_client: Generated new trial 17 with parameters {'x1': 2.52183, 'x2': 2.546804}.


[INFO 03-08 21:10:22] ax.service.ax_client: Completed trial 17 with data: {'branin': (2.251922, 0.0)}.


[INFO 03-08 21:10:24] ax.service.ax_client: Generated new trial 18 with parameters {'x1': 3.258201, 'x2': 2.051793}.


[INFO 03-08 21:10:24] ax.service.ax_client: Completed trial 18 with data: {'branin': (0.481058, 0.0)}.


[INFO 03-08 21:10:26] ax.service.ax_client: Generated new trial 19 with parameters {'x1': 3.416109, 'x2': 2.004387}.


[INFO 03-08 21:10:26] ax.service.ax_client: Completed trial 19 with data: {'branin': (0.761814, 0.0)}.


[INFO 03-08 21:10:27] ax.service.ax_client: Generated new trial 20 with parameters {'x1': 3.143472, 'x2': 2.288195}.


[INFO 03-08 21:10:27] ax.service.ax_client: Completed trial 20 with data: {'branin': (0.398119, 0.0)}.


[INFO 03-08 21:10:29] ax.service.ax_client: Generated new trial 21 with parameters {'x1': 3.151661, 'x2': 1.988985}.


[INFO 03-08 21:10:29] ax.service.ax_client: Completed trial 21 with data: {'branin': (0.475756, 0.0)}.


[INFO 03-08 21:10:30] ax.service.ax_client: Generated new trial 22 with parameters {'x1': 3.320216, 'x2': 2.199366}.


[INFO 03-08 21:10:30] ax.service.ax_client: Completed trial 22 with data: {'branin': (0.55421, 0.0)}.


[INFO 03-08 21:10:31] ax.service.ax_client: Generated new trial 23 with parameters {'x1': 3.35505, 'x2': 1.886039}.


[INFO 03-08 21:10:31] ax.service.ax_client: Completed trial 23 with data: {'branin': (0.66797, 0.0)}.


[INFO 03-08 21:10:32] ax.service.ax_client: Generated new trial 24 with parameters {'x1': 2.995448, 'x2': 2.447697}.


[INFO 03-08 21:10:32] ax.service.ax_client: Completed trial 24 with data: {'branin': (0.503379, 0.0)}.


[INFO 03-08 21:10:32] ax.service.ax_client: Generated new trial 25 with parameters {'x1': -3.575003, 'x2': 15.0}.


[INFO 03-08 21:10:32] ax.service.ax_client: Completed trial 25 with data: {'branin': (4.038473, 0.0)}.


[INFO 03-08 21:10:33] ax.service.ax_client: Generated new trial 26 with parameters {'x1': 3.10412, 'x2': 2.184513}.


[INFO 03-08 21:10:33] ax.service.ax_client: Completed trial 26 with data: {'branin': (0.419002, 0.0)}.


[INFO 03-08 21:10:34] ax.service.ax_client: Generated new trial 27 with parameters {'x1': 3.409809, 'x2': 2.126092}.


[INFO 03-08 21:10:34] ax.service.ax_client: Completed trial 27 with data: {'branin': (0.743807, 0.0)}.


[INFO 03-08 21:10:35] ax.service.ax_client: Generated new trial 28 with parameters {'x1': 3.21861, 'x2': 2.25319}.


[INFO 03-08 21:10:36] ax.service.ax_client: Completed trial 28 with data: {'branin': (0.427756, 0.0)}.


[INFO 03-08 21:10:36] ax.service.ax_client: Generated new trial 29 with parameters {'x1': 3.133336, 'x2': 2.442237}.


[INFO 03-08 21:10:36] ax.service.ax_client: Completed trial 29 with data: {'branin': (0.424067, 0.0)}.


### Viewing the evaluated trials

In [7]:
ax_client.get_trials_data_frame()

|    | branin | trial_index | arm_name | x1 | x2 | trial_status | generation_method |
|----|--------|-------------|----------|----|----|--------------|-------------------|
| 0 | 42.0839 | 0 | 0_0 | 1.053491 | 9.695006 | COMPLETED | Sobol |
| 3 | 6.161984 | 1 | 1_0 | 9.673876 | 5.031455 | COMPLETED | Sobol |
| 15 | 4.775831 | 2 | 2_0 | -2.698363 | 13.092639 | COMPLETED | Sobol |
| 23 | 20.177181 | 3 | 3_0 | 4.255731 | 5.362701 | COMPLETED | Sobol |
| 24 | 3.536641 | 4 | 4_0 | 9.896219 | 4.347518 | COMPLETED | Sobol |
| 25 | 6.466525 | 5 | 5_0 | -3.947539 | 12.530772 | COMPLETED | BoTorch |
| 26 | 5.706899 | 6 | 6_0 | 10.0 | 1.062917 | COMPLETED | BoTorch |
| 27 | 16.378994 | 7 | 7_0 | 7.218442 | 0.418715 | COMPLETED | BoTorch |
| 28 | 151.693604 | 8 | 8_0 | -3.19908 | 0.113994 | COMPLETED | BoTorch |
| 29 | 17.508297 | 9 | 9_0 | -5.0 | 15.0 | COMPLETED | BoTorch |


In [8]:
parameters, values = ax_client.get_best_parameters()
print(f"Best parameters: {parameters}")
print(f"Corresponding mean: {values[0]}, covariance: {values[1]}")

Best parameters: {'x1': 3.143471996153947, 'x2': 2.2881945397553003}
Corresponding mean: {'branin': 0.3981189727783203}, covariance: {'branin': {'branin': 0.0}}
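As a cross-check, the best *observed* trial can be read straight off the log. In this run it is trial 20 (x1 = 3.143472, x2 = 2.288195, value 0.398119), which matches the parameters and mean reported by `get_best_parameters`. Depending on the Ax version and settings, `get_best_parameters` may rely on model predictions rather than raw observations, so the two need not always agree. The snippet is plain Python over values copied from the log.

```python
# Branin values of trials 0-29, copied in order from the "Completed trial" log lines above.
branin_values = [
    42.0839, 6.161984, 4.775831, 20.177181, 3.536641,
    6.466525, 5.706899, 16.378994, 151.693604, 17.508297,
    12.370395, 208.881226, 0.74213, 5.165417, 1.541467,
    5.033342, 2.034841, 2.251922, 0.481058, 0.761814,
    0.398119, 0.475756, 0.55421, 0.66797, 0.503379,
    4.038473, 0.419002, 0.743807, 0.427756, 0.424067,
]
KNOWN_OPTIMUM = 0.397887  # Branin minimum, as passed to `get_optimization_trace` below

best_index = min(range(len(branin_values)), key=branin_values.__getitem__)
best_value = branin_values[best_index]
gap = best_value - KNOWN_OPTIMUM  # distance of the best observed value from the true minimum
```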


### Plotting the response surface and optimization progress

In [9]:
from ax.utils.notebook.plotting import render
render(ax_client.get_contour_plot())

[INFO 03-08 21:10:37] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'branin'. Remaining parameters are affixed to the middle of their range.


In [10]:
best_parameters, values = ax_client.get_best_parameters()
best_parameters, values[0]

({'x1': 3.143471996153947, 'x2': 2.2881945397553003},
 {'branin': 0.3981189727783203})

In [11]:
render(ax_client.get_optimization_trace(objective_optimum=0.397887))

## Optimization with the Developer API

A detailed tutorial on the Developer API can be found [here](https://ax.dev/tutorials/gpei_hartmann_developer.html).

### Set up the Experiment in Ax

We need 3 inputs for an Ax `Experiment`:
- A search space to optimize over;
- An optimization config specifying the objective / metrics to optimize, and optional outcome constraints;
- A runner that handles the deployment of trials. For a synthetic optimization problem, such as here, this only returns simple metadata about the trial.

In [12]:
import pandas as pd
import torch
from ax import (
    Data,
    Experiment,
    Metric,
    Objective,
    OptimizationConfig,
    ParameterType,
    RangeParameter,
    Runner,
    SearchSpace,
)
from botorch.test_functions import Branin


branin_func = Branin()

# For our purposes, the metric is a wrapper that structures the function output.
class BraninMetric(Metric):
    def fetch_trial_data(self, trial):
        records = []
        for arm_name, arm in trial.arms_by_name.items():
            params = arm.parameters
            tensor_params = torch.tensor([params["x1"], params["x2"]])
            records.append(
                {
                    "arm_name": arm_name,
                    "metric_name": self.name,
                    "trial_index": trial.index,
                    "mean": float(branin_func(tensor_params)),  # plain float rather than a 0-d tensor
                    "sem": 0.0,  # SEM - standard error of the mean - corresponds to Yvar in BoTorch.
                }
            )
        return Data(df=pd.DataFrame.from_records(records))


# Search space defines the parameters, their types, and acceptable values.
search_space = SearchSpace(
    parameters=[
        RangeParameter(name="x1", parameter_type=ParameterType.FLOAT, lower=-5, upper=10),
        RangeParameter(name="x2", parameter_type=ParameterType.FLOAT, lower=0, upper=15),
    ]
)

optimization_config = OptimizationConfig(
    objective=Objective(
        metric=BraninMetric(name="branin_metric", lower_is_better=True),
        minimize=True,  # This is optional since we specified `lower_is_better=True`
    )
)


class MyRunner(Runner):
    def run(self, trial):
        trial_metadata = {"name": str(trial.index)}
        return trial_metadata


exp = Experiment(
    name="branin_experiment",
    search_space=search_space,
    optimization_config=optimization_config,
    runner=MyRunner(),
)

### Run the BO loop

First, we use the Sobol generator to create 5 (quasi-)random initial points in the search space. Ax controls objective evaluations via `Trial`s.
- We generate a `Trial` from a generator run, produced below by the `Sobol` generator. A `Trial` specifies relevant metadata as well as the parameters to be evaluated. At this point, the `Trial` is at the `CANDIDATE` stage.
- We run the `Trial` using `Trial.run()`. In our example, this serves to mark the `Trial` as `RUNNING`. In an advanced application, this can be used to dispatch the `Trial` for evaluation on a remote server.
- Once the `Trial` is done running, we mark it as `COMPLETED`. This tells the `Experiment` that it can fetch the `Trial` data.

A `Trial` supports evaluation of a single parameterization. For parallel evaluations, see [`BatchTrial`](https://ax.dev/docs/core.html#trial-vs-batch-trial).
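The three bullet points above describe a small state machine. Here is a pure-Python sketch of it; `ToyStatus` and `ToyTrial` are hypothetical stand-ins for Ax's `TrialStatus` and `Trial`, and they only model the three states named above (real Ax trials have additional statuses, such as failed or abandoned). The runner is passed as a plain function that returns metadata, mirroring `MyRunner.run`.

```python
from enum import Enum


class ToyStatus(Enum):
    CANDIDATE = "CANDIDATE"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"


# Transitions allowed in this simplified model only.
ALLOWED = {
    ToyStatus.CANDIDATE: {ToyStatus.RUNNING},
    ToyStatus.RUNNING: {ToyStatus.COMPLETED},
    ToyStatus.COMPLETED: set(),
}


class ToyTrial:
    """Hypothetical trial holding a single parameterization."""

    def __init__(self, index, parameters):
        self.index = index
        self.parameters = parameters
        self.status = ToyStatus.CANDIDATE  # freshly generated trials start as candidates
        self.run_metadata = {}

    def _check(self, new_status):
        if new_status not in ALLOWED[self.status]:
            raise ValueError(f"cannot go from {self.status.name} to {new_status.name}")

    def run(self, runner):
        self._check(ToyStatus.RUNNING)
        self.run_metadata = runner(self)  # like `Runner.run`, returns trial metadata
        self.status = ToyStatus.RUNNING
        return self

    def mark_completed(self):
        self._check(ToyStatus.COMPLETED)
        self.status = ToyStatus.COMPLETED
        return self


trial = ToyTrial(0, {"x1": 1.0, "x2": 2.0})
trial.run(lambda t: {"name": str(t.index)}).mark_completed()
```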

In [13]:
from ax.modelbridge.registry import Models


sobol = Models.SOBOL(exp.search_space)

for i in range(5):
    trial = exp.new_trial(generator_run=sobol.gen(1))
    trial.run()
    trial.mark_completed()

Once the initial (quasi-) random stage is completed, we can use our `SimpleCustomGP` with the default acquisition function chosen by `Ax` to run the BO loop.

In [14]:
with fast_smoke_test():
    for i in range(25):
        model_bridge = Models.BOTORCH_MODULAR(
            experiment=exp,
            data=exp.fetch_data(),
            surrogate=Surrogate(SimpleCustomGP),
        )
        trial = exp.new_trial(generator_run=model_bridge.gen(1))
        trial.run()
        trial.mark_completed()
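Each pass through the loop above builds a new `BOTORCH_MODULAR` model bridge from all data fetched so far, so the surrogate is refit before every candidate is generated. The NumPy sketch below mimics that propose, evaluate, append cycle under explicit assumptions: the GP lengthscales are fixed rather than fitted, analytic expected improvement maximized over random candidates stands in for whatever acquisition function and optimizer Ax actually picks by default, random initial points replace Sobol, and all helper names are hypothetical.

```python
import math

import numpy as np


def rbf(a, b, lengthscales):
    """Unit-outputscale RBF kernel with one lengthscale per input dimension."""
    diff = (a[:, None, :] - b[None, :, :]) / lengthscales
    return np.exp(-0.5 * np.sum(diff**2, axis=-1))


def posterior(train_x, train_y, test_x, lengthscales, noise=1e-6):
    """Zero-mean exact GP posterior mean and standard deviation at test_x."""
    chol = np.linalg.cholesky(rbf(train_x, train_x, lengthscales) + noise * np.eye(len(train_x)))
    k_star = rbf(train_x, test_x, lengthscales)
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, train_y))
    v = np.linalg.solve(chol, k_star)
    var = np.clip(1.0 - np.sum(v**2, axis=0), 1e-12, None)  # guard against round-off
    return k_star.T @ alpha, np.sqrt(var)


def expected_improvement(mean, std, best):
    """Analytic EI for minimization: E[max(best - f, 0)] with f ~ N(mean, std^2)."""
    z = (best - mean) / std
    cdf = 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))
    pdf = np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)
    return (best - mean) * cdf + std * pdf


def bo_step(train_x, train_y, rng, lengthscales=(0.2, 0.2), n_candidates=2048):
    """Condition on all data so far, then return the random candidate with the highest EI."""
    y_norm = (train_y - train_y.mean()) / train_y.std()  # standardize outcomes
    candidates = rng.uniform(size=(n_candidates, train_x.shape[1]))
    mean, std = posterior(train_x, y_norm, candidates, np.asarray(lengthscales))
    ei = expected_improvement(mean, std, best=y_norm.min())
    return candidates[np.argmax(ei)]


def branin_unit(z):
    """Branin on the unit square, rescaled to x1 in [-5, 10] and x2 in [0, 15]."""
    x1, x2 = -5.0 + 15.0 * z[0], 15.0 * z[1]
    b, c, t = 5.1 / (4 * math.pi**2), 5.0 / math.pi, 1.0 / (8 * math.pi)
    return (x2 - b * x1**2 + c * x1 - 6.0) ** 2 + 10.0 * (1 - t) * math.cos(x1) + 10.0


rng = np.random.default_rng(0)
train_x = rng.uniform(size=(5, 2))  # random stand-in for the 5 Sobol points
train_y = np.array([branin_unit(z) for z in train_x])
for _ in range(10):  # propose one point, evaluate it, append it, repeat
    x_next = bo_step(train_x, train_y, rng)
    train_x = np.vstack([train_x, x_next])
    train_y = np.append(train_y, branin_unit(x_next))
```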

View the trials attached to the `Experiment`.

In [15]:
exp.trials

{0: Trial(experiment_name='branin_experiment', index=0, status=TrialStatus.COMPLETED, arm=Arm(name='0_0', parameters={'x1': 3.9835897088050842, 'x2': 12.342978715896606})),
 1: Trial(experiment_name='branin_experiment', index=1, status=TrialStatus.COMPLETED, arm=Arm(name='1_0', parameters={'x1': -3.593796114437282, 'x2': 0.6191267631947994})),
 2: Trial(experiment_name='branin_experiment', index=2, status=TrialStatus.COMPLETED, arm=Arm(name='2_0', parameters={'x1': 1.2375280819833279, 'x2': 10.779688828624785})),
 3: Trial(experiment_name='branin_experiment', index=3, status=TrialStatus.COMPLETED, arm=Arm(name='3_0', parameters={'x1': 8.4092768561095, 'x2': 6.551370648667216})),
 4: Trial(experiment_name='branin_experiment', index=4, status=TrialStatus.COMPLETED, arm=Arm(name='4_0', parameters={'x1': 6.420887252315879, 'x2': 7.687858119606972})),
 5: Trial(experiment_name='branin_experiment', index=5, status=TrialStatus.COMPLETED, arm=Arm(name='5_0', parameters={'x1': -5.0, 'x2': 5.964

View the evaluation data about these trials.

In [16]:
exp.fetch_data().df

|    | arm_name | metric_name | mean | sem | trial_index |
|----|----------|-------------|------|-----|-------------|
| 0 | 0_0 | branin_metric | 116.666603 | 0.0 | 0 |
| 1 | 1_0 | branin_metric | 164.411484 | 0.0 | 1 |
| 2 | 2_0 | branin_metric | 56.06245 | 0.0 | 2 |
| 3 | 3_0 | branin_metric | 27.97538 | 0.0 | 3 |
| 4 | 4_0 | branin_metric | 62.821098 | 0.0 | 4 |
| 5 | 5_0 | branin_metric | 138.68338 | 0.0 | 5 |
| 6 | 6_0 | branin_metric | 121.225121 | 0.0 | 6 |
| 7 | 7_0 | branin_metric | 4.042157 | 0.0 | 7 |
| 8 | 8_0 | branin_metric | 5.210078 | 0.0 | 8 |
| 9 | 9_0 | branin_metric | 0.560489 | 0.0 | 9 |


### Plot results

We can use convenient Ax utilities for plotting the results.

In [17]:
import numpy as np
from ax.plot.trace import optimization_trace_single_method
from ax.utils.notebook.plotting import render


# `optimization_trace_single_method` expects a 2-d array of means, because it averages the means
# of multiple optimization runs, so we wrap our array of objective means in another array.
objective_means = np.array([[trial.objective_mean for trial in exp.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.minimum.accumulate(objective_means, axis=1),
    optimum=0.397887,  # Known minimum objective for Branin function.
)
render(best_objective_plot)
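The plotted curve is a running minimum of the observed objective values. The NumPy sketch below reproduces that computation on the 10 objective means visible in the `exp.fetch_data().df` output above (the printed table shows only those rows), keeping the extra leading axis that stands for the optimization run.

```python
import numpy as np

# Objective means of trials 0-9, copied from the `exp.fetch_data().df` output above.
objective_means = np.array([[116.666603, 164.411484, 56.06245, 27.97538, 62.821098,
                             138.68338, 121.225121, 4.042157, 5.210078, 0.560489]])
best_so_far = np.minimum.accumulate(objective_means, axis=1)  # shape (1, 10): one row per run
regret = best_so_far - 0.397887  # distance from the known Branin minimum
```

Stacking the rows of several independent runs along the first axis is what lets the plotting utility average them.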