In [2]:
# Ax wrappers for BoTorch components
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.list_surrogate import ListSurrogate

# Ax data transformation layer
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.registry import Cont_X_trans, Y_trans, Models

# Test Ax objects
from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_data

# BoTorch components
from botorch.models.model import Model
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
from botorch.acquisition.monte_carlo import qExpectedImprovement, qNoisyExpectedImprovement
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

# Setup and Usage of BoTorch Models in Ax

Ax provides a set of flexible wrapper abstractions for mixing and matching BoTorch components such as `Model` and `AcquisitionFunction` and combining them into a single `Model` object in Ax. These wrapper abstractions (`Surrogate`, `Acquisition`, and `BoTorchModel`) are located in the `ax/models/torch/botorch_modular` directory and aim to encapsulate the boilerplate code that interfaces between Ax and BoTorch. This functionality is in beta release and still evolving.

This tutorial walks through setting up a custom combination of BoTorch components in Ax in the following steps:

1. **Quick-start example of `BoTorchModel` use**
2. **`BoTorchModel` = `Surrogate` + `Acquisition` (overview)**
   1. Example showing all possible options
   2. Example with minimal options that uses the defaults
   3. Using a pre-constructed BoTorch `Model` (e.g. in research or development)
   4. `Surrogate` and `Acquisition` Q&A
3. **I know which BoTorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do I set this up?**
   1. Making a `Surrogate` from a BoTorch `Model`
   2. Using an arbitrary BoTorch `AcquisitionFunction` in Ax
4. **Using `Models.BOTORCH_MODULAR`** (convenience wrapper that enables storage and resumability)
5. **Utilizing `BoTorchModel` in generation strategies** (abstraction that allows chaining models together and using them in the Ax Service API, etc.)
6. **Customizing a `Surrogate` or `Acquisition`** (for cases where existing subcomponent classes are not sufficient)

## 1. Quick-start example

Here we set up a `BoTorchModel` with `SingleTaskGP` and `qNoisyExpectedImprovement`, one of the most popular combinations in Ax:

In [12]:
experiment = get_branin_experiment(with_trial=True)
data = get_branin_data(trials=[experiment.trials[0]])

In [13]:
# `Models` automatically selects a model + model bridge combination. 
# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.
model_bridge_with_GPKG = Models.BOTORCH_MODULAR(
    experiment=experiment,
    data=data,
    surrogate=Surrogate(SingleTaskGP),  # Optional, will use default if unspecified
    botorch_acqf_class=qNoisyExpectedImprovement,  # Optional, will use default if unspecified
)

[INFO 06-09 17:03:23] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.


Now we can use this model to generate candidates (`gen`), predict the outcome at a point (`predict`), or evaluate the acquisition function value at a given point (`evaluate_acquisition_function`).

In [15]:
generator_run = model_bridge_with_GPKG.gen(n=1)
generator_run.arms[0]

Arm(parameters={'x1': 10.0, 'x2': 15.0})

-----
Before you read the rest of this tutorial:

- Note that the concept of ‘model’ in Ax is something of a misnomer; we use ['model'](https://ax.dev/docs/glossary.html#model) to refer to an optimization setup capable of producing candidate points for optimization (and often capable of being fit to data, with the exception of quasi-random generators). See the [Models documentation page](https://ax.dev/docs/models.html) for more information.
- Learn about `ModelBridge` in Ax, as users should rarely be interacting with a `Model` object directly (more about ModelBridge, a data transformation layer in Ax, [here](https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack)).

## 2. BoTorchModel = Surrogate + Acquisition

A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class.
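
To make the division of labor between the two subcomponents concrete, here is a deliberately simplified, pure-Python sketch of the composition pattern. It is not Ax code: `ToySurrogate`, `ToyAcquisition`, and `ToyModel` are hypothetical stand-ins (a nearest-neighbour predictor and a scorer that prefers low predicted outcomes), and they only mirror the general way a surrogate and an acquisition function cooperate inside one model object.

```python
from typing import List, Sequence, Tuple, Type

Point = Tuple[float, ...]


class ToySurrogate:
    """Hypothetical surrogate: predicts the outcome of the nearest observed point."""

    def __init__(self) -> None:
        self.X: List[Point] = []
        self.Y: List[float] = []

    def fit(self, X: Sequence[Point], Y: Sequence[float]) -> None:
        self.X, self.Y = list(X), list(Y)

    def predict(self, x: Point) -> float:
        # Squared Euclidean distance from `x` to every training point.
        dists = [sum((a - b) ** 2 for a, b in zip(x, xi)) for xi in self.X]
        return self.Y[dists.index(min(dists))]


class ToyAcquisition:
    """Hypothetical acquisition: scores a candidate using the surrogate's prediction."""

    def __init__(self, surrogate: ToySurrogate) -> None:
        self.surrogate = surrogate

    def evaluate(self, x: Point) -> float:
        # We minimize the outcome, so a lower prediction yields a higher score.
        return -self.surrogate.predict(x)


class ToyModel:
    """Hypothetical model: one surrogate plus one acquisition class."""

    def __init__(self, surrogate: ToySurrogate, acquisition_class: Type[ToyAcquisition]) -> None:
        self.surrogate = surrogate
        self.acquisition_class = acquisition_class

    def fit(self, X: Sequence[Point], Y: Sequence[float]) -> None:
        self.surrogate.fit(X, Y)

    def gen(self, candidates: Sequence[Point]) -> Point:
        # Build the acquisition from the fitted surrogate, then pick the best candidate.
        acqf = self.acquisition_class(self.surrogate)
        return max(candidates, key=acqf.evaluate)


toy_model = ToyModel(surrogate=ToySurrogate(), acquisition_class=ToyAcquisition)
toy_model.fit(X=[(0.0, 0.0), (1.0, 1.0)], Y=[5.0, 1.0])
best = toy_model.gen(candidates=[(0.1, 0.1), (0.9, 0.9)])
```

In the real classes, fitting and candidate generation are far more involved; the sketch only shows which component owns the data and which one ranks candidates.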

### 2A. Example with all the options
Below is the full set of configurable settings of a `BoTorchModel`, with their descriptions:

In [None]:
model = BoTorchModel(
    # Optional `Surrogate` specification to use instead of default
    surrogate=Surrogate(
        # BoTorch `Model` type
        botorch_model_class=FixedNoiseGP,
        # Optional, MLL class with which to optimize model parameters
        mll_class=ExactMarginalLogLikelihood,
        # Optional, dictionary of keyword arguments to underlying 
        # BoTorch `Model` constructor
        model_options={}
    ),
    # Optional options to pass to auto-picked `Surrogate` if not
    # specifying the `surrogate` argument
    surrogate_options={},
    
    # Optional BoTorch `AcquisitionFunction` to use instead of default
    botorch_acqf_class=qExpectedImprovement,
    # Optional dict of keyword arguments, passed to the input 
    # constructor for the given BoTorch `AcquisitionFunction`
    acquisition_options={},
    # Optional Ax `Acquisition` subclass, if the given BoTorch
    # `AcquisitionFunction` requires one
    acquisition_class=None,
    
    # Less common model settings shown with default values, refer
    # to `BoTorchModel` documentation for detail
    refit_on_update=True,
    refit_on_cv=False,
    warm_start_refit=True,
)

### 2B. Example that uses defaults and requires no options

`BoTorchModel` does not always require a surrogate and acquisition specification. If it is instantiated without one or both components specified, defaults are selected based on properties of the experiment and data (see Appendix 2 for auto-selection logic).

In [None]:
# The surrogate is not specified, so it will be auto-selected
# during `model.fit`.
GPEI_model = BoTorchModel(botorch_acqf_class=qExpectedImprovement)

# The acquisition class is not specified, so it will be
# auto-selected during `model.gen` or `model.evaluate_acquisition_function`
GPEI_model = BoTorchModel(surrogate=Surrogate(FixedNoiseGP))

# Both the surrogate and acquisition class will be auto-selected.
GPEI_model = BoTorchModel()

### 2C. `Surrogate` from pre-instantiated BoTorch `Model`

Alternatively, for BoTorch `Model`-s that require complex instantiation procedures (or are still in development), use the `from_botorch` instantiation method of `Surrogate`:

In [19]:
from_botorch_model = BoTorchModel(
    surrogate=Surrogate.from_botorch(
        # BoTorch `Model` instance, with training data already set
        model=...,  
        # Optional, MLL class with which to optimize model parameters
        mll_class=ExactMarginalLogLikelihood,
    )
)

### 2D. `Surrogate` and `Acquisition` Q&A

**Why is the `surrogate` argument expected to be an instance, while `botorch_acqf_class` is a class?** Because a BoTorch `AcquisitionFunction` object (and therefore its Ax wrapper, `Acquisition`) is ephemeral: it is constructed, immediately used, and destroyed during `BoTorchModel.gen`, so there is no reason to keep an `Acquisition` instance around. A `Surrogate`, on the other hand, is kept in memory for as long as its parent `BoTorchModel` is.
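
The lifetime difference can be illustrated with a small pure-Python sketch. The classes below are hypothetical stand-ins, not Ax or BoTorch types; the point is only that the long-lived component is passed as an instance, while the short-lived one is passed as a class and instantiated inside each `gen` call.

```python
class StubSurrogate:
    """Hypothetical long-lived component: keeps fitted state between calls."""

    def __init__(self) -> None:
        self.n_fits = 0

    def fit(self) -> None:
        self.n_fits += 1


class StubAcquisition:
    """Hypothetical short-lived component: built from the current surrogate, used once."""

    instances_created = 0  # class-level counter, for illustration only

    def __init__(self, surrogate: StubSurrogate) -> None:
        type(self).instances_created += 1
        self.surrogate = surrogate


class StubModel:
    def __init__(self, surrogate: StubSurrogate, acqf_class: type) -> None:
        self.surrogate = surrogate    # instance, kept for the model's lifetime
        self.acqf_class = acqf_class  # class, instantiated on demand

    def gen(self) -> int:
        acqf = self.acqf_class(self.surrogate)  # fresh object for this call only
        # `acqf` goes out of scope on return; only the surrogate persists.
        return acqf.surrogate.n_fits


stub_model = StubModel(surrogate=StubSurrogate(), acqf_class=StubAcquisition)
stub_model.surrogate.fit()
stub_model.gen()
stub_model.gen()
```

After the two `gen` calls, two `StubAcquisition` objects have been created and discarded, while the single `StubSurrogate` instance is still attached to the model.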

**How do I know when to specify `acquisition_class` (and thereby a non-default `Acquisition` type) instead of just passing in `botorch_acqf_class`?** In short, custom `Acquisition` subclasses are needed when a given `AcquisitionFunction` in BoTorch needs some non-standard subcomponents or inputs (e.g. a custom BoTorch `AcquisitionObjective`). <TODO>

**Please post any other questions you have to our dedicated issue on GitHub: https://github.com/facebook/Ax/issues/363.** This functionality is in beta release and your feedback will be of great help to us!

## 3. I know which BoTorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do I set this up?

### 3A. Making a `Surrogate` from a BoTorch `Model`
Most models should work with the base `Surrogate` in Ax, except for BoTorch's `ModelListGP`, which works with `ListSurrogate`. `ModelListGP` is a special case because its purpose is to combine multiple sub-models into a single `Model` in BoTorch. It is most commonly used when several outcomes are modeled independently, for example in multi-objective or constrained optimization.

If your `Model` is not a `ModelListGP`, the steps to set it up as a `Surrogate` are:
1. Implement a [`construct_inputs` class method](https://github.com/pytorch/botorch/blob/master/botorch/models/model.py#L143). The purpose of this method is to produce arguments to a particular model from a standardized set of inputs passed to BoTorch `Model`-s from [`Surrogate.construct`](https://github.com/facebook/Ax/blob/master/ax/models/torch/botorch_modular/surrogate.py#L148) in Ax. It should accept `TrainingData` and optionally other keyword arguments and produce a dictionary of arguments to `__init__` of the `Model`. See [`SingleTaskMultiFidelityGP.construct_inputs`](https://github.com/pytorch/botorch/blob/5b3172f3daa22f6ea2f6f4d1d0a378a9518dcd8d/botorch/models/gp_regression_fidelity.py#L131) for an example.
2. Pass any additional keyword arguments needed by the `Model` constructor (those that cannot be constructed from `TrainingData` and the other arguments to `construct_inputs`) via the `model_options` argument to `Surrogate`.

In [22]:
surrogate = Surrogate(
    botorch_model_class=MyModelClass,  # Must implement `construct_inputs`
    # Optional dict of additional keyword arguments to `MyModelClass`
    model_options={},
)
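
The contract behind the two steps above can be sketched in plain Python. Everything here is a hypothetical stand-in: `ToyTrainingData` imitates the role of `TrainingData`, `ToyModelClass` imitates a BoTorch `Model`, and `build_model` is an assumed, simplified version of how constructed inputs and `model_options` might be merged. The actual `Surrogate.construct` logic may differ.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ToyTrainingData:
    """Hypothetical stand-in for a standardized training-data container."""

    X: List[List[float]]
    Y: List[float]


class ToyModelClass:
    """Hypothetical model whose constructor needs more than raw X and Y."""

    def __init__(
        self,
        train_X: List[List[float]],
        train_Y: List[float],
        fidelity_column: int,
        lengthscale: float = 1.0,
    ) -> None:
        self.train_X = train_X
        self.train_Y = train_Y
        self.fidelity_column = fidelity_column
        self.lengthscale = lengthscale

    @classmethod
    def construct_inputs(cls, training_data: ToyTrainingData, **kwargs: Any) -> Dict[str, Any]:
        # Step 1: translate standardized inputs into this model's `__init__` arguments.
        return {
            "train_X": training_data.X,
            "train_Y": training_data.Y,
            "fidelity_column": kwargs["fidelity_column"],
        }


def build_model(
    model_class: type,
    training_data: ToyTrainingData,
    model_options: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:
    """Assumed merge logic: constructed inputs first, then extra `model_options`."""
    inputs = model_class.construct_inputs(training_data, **kwargs)
    # Step 2: options that cannot be derived from the training data.
    inputs.update(model_options or {})
    return model_class(**inputs)


toy_data = ToyTrainingData(X=[[0.0, 1.0], [0.5, 2.0]], Y=[1.2, 0.7])
toy = build_model(
    ToyModelClass,
    toy_data,
    model_options={"lengthscale": 0.3},
    fidelity_column=1,
)
```

Keeping the translation logic in a class method means each model class documents its own input requirements, so the wrapper does not need model-specific branches.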

For a `ModelListGP`, the setup is similar, except that the surrogate is defined in terms of sub-models rather than one model. Both of the following options will work:

In [None]:
surrogate = ListSurrogate(
    botorch_submodel_class_per_outcome={
        "metric_a": MyModelClass, 
        "metric_b": MyOtherModelClass,
    },
    submodel_options_per_outcome={"metric_a": {}, "metric_b": {}},
)

In [None]:
surrogate = ListSurrogate(
    # Shortcut if all submodels are the same type
    botorch_submodel_class=MyModelClass,
    # Shortcut if all submodel options are the same
    submodel_options={},
)
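
The relationship between the per-outcome form and the shortcut form can be sketched as follows. `resolve_submodel_classes` is a hypothetical helper, not part of Ax, and the precedence it uses (a per-outcome entry wins over the shared class) is an assumption made for illustration.

```python
from typing import Dict, Iterable, Optional


class ModelA:
    """Hypothetical submodel class."""


class ModelB:
    """Another hypothetical submodel class."""


def resolve_submodel_classes(
    outcomes: Iterable[str],
    per_outcome: Optional[Dict[str, type]] = None,
    shared: Optional[type] = None,
) -> Dict[str, type]:
    """Map every outcome name to a submodel class (assumed precedence: per-outcome, then shared)."""
    per_outcome = per_outcome or {}
    resolved: Dict[str, type] = {}
    for name in outcomes:
        submodel_class = per_outcome.get(name, shared)
        if submodel_class is None:
            raise ValueError(f"No submodel class specified for outcome {name!r}.")
        resolved[name] = submodel_class
    return resolved


outcome_names = ["metric_a", "metric_b"]
explicit = resolve_submodel_classes(
    outcome_names, per_outcome={"metric_a": ModelA, "metric_b": ModelB}
)
shortcut = resolve_submodel_classes(outcome_names, shared=ModelA)
```

Under these assumptions, the shortcut form is equivalent to writing out a per-outcome dictionary in which every outcome maps to the same class.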

### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax