# Using external methods for candidate generation in Ax

Out of the box, Ax offers many options for candidate generation, most of which use Bayesian optimization algorithms built on [BoTorch](https://botorch.org/).
Users can let Ax automatically select which BoTorch methods to use for optimization (optionally guiding the choice by calling [`configure_generation_strategy`](../../recipes/influence-gs-choice)), or direct Ax to use specific BoTorch components via the [Modular BoTorch framework](../modular_botorch/index.mdx).
Users who want to use Ax for experiment state management and orchestration via the `Client`, along with other features (e.g., early stopping, analyses, storage), while relying on candidate generation methods not implemented in BoTorch can implement their own `ExternalGenerationNode`.

A `GenerationNode` is a building block of a `GenerationStrategy`. Nodes can be combined to use different methods for generating candidates at different stages of an experiment. `ExternalGenerationNode` exposes a lightweight interface that lets users integrate their own methods into Ax and use them either standalone or alongside other `GenerationNode`s in a `GenerationStrategy`.

In this tutorial, we will implement a simple generation node using `RandomForestRegressor` from sklearn, and combine it with Sobol (for initialization) to optimize the Hartmann6 problem.

NOTE: This generation node using `RandomForestRegressor` is a naive implementation and for demonstration purposes only; it is not representative of how more sophisticated random forest Bayesian optimization algorithms perform. We do not recommend using this strategy as it typically does not perform well compared to Ax's default algorithms due to its overly greedy behavior.

In [None]:
import time
from typing import Any

import numpy as np

from ax.adapter.registry import Generators
from ax.api.client import Client
from ax.api.configs import RangeParameterConfig
from ax.analysis.plotly.arm_effects import ArmEffectsPlot
from ax.analysis.summary import Summary
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.parameter import RangeParameter
from ax.core.types import TParameterization
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
from ax.generation_strategy.generation_node import GenerationNode
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.generation_strategy.generator_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import MaxTrials

from sklearn.ensemble import RandomForestRegressor

In [None]:
class RandomForestGenerationNode(ExternalGenerationNode):
    """A generation node that uses the RandomForestRegressor
    from sklearn to predict candidate performance and picks the
    next point as the random sample that has the best prediction.

    To leverage external methods for candidate generation, the user must
    create a subclass that implements ``update_generator_state`` and
    ``get_next_candidate`` methods. This can then be provided
    as a node into a ``GenerationStrategy``, either as standalone or as
    part of a larger generation strategy with other generation nodes,
    e.g., with a Sobol node for initialization.
    """

    def __init__(self, num_samples: int, regressor_options: dict[str, Any]) -> None:
        """Initialize the generation node.

        Args:
            regressor_options: Options to pass to the random forest regressor.
            num_samples: Number of random samples from the search space
                used during candidate generation. The sample with the best
                prediction is recommended as the next candidate.
        """
        t_init_start = time.monotonic()
        super().__init__(node_name="RandomForest")
        self.num_samples: int = num_samples
        self.regressor: RandomForestRegressor = RandomForestRegressor(
            **regressor_options
        )
        # We will set these later when updating the state.
        # Alternatively, we could have required experiment as an input
        # and extracted them here.
        self.parameters: dict[str, RangeParameter] | None = None
        self.minimize: bool | None = None
        # Recording time spent in initializing the generator. This is
        # used to compute the time spent in candidate generation.
        self.fit_time_since_gen: float = time.monotonic() - t_init_start

    def update_generator_state(self, experiment: Experiment, data: Data) -> None:
        """A method used to update the state of the generator. This includes any
        models, predictors or any other custom state used by the generation node.
        This method will be called with the up-to-date experiment and data before
        ``get_next_candidate`` is called to generate the next trial(s). Note
        that ``get_next_candidate`` may be called multiple times (to generate
        multiple candidates) after a call to ``update_generator_state``.

        For this example, we will train the regressor using the latest data from
        the experiment.

        Args:
            experiment: The ``Experiment`` object representing the current state of the
                experiment. Key properties include ``trials``, ``search_space``,
                and ``optimization_config``. The data is provided as a separate arg.
            data: The data / metrics collected on the experiment so far.
        """
        search_space = experiment.search_space
        parameter_names = list(search_space.parameters.keys())
        metric_names = list(experiment.optimization_config.metrics.keys())
        if any(
            not isinstance(p, RangeParameter) for p in search_space.parameters.values()
        ):
            raise NotImplementedError(
                "This example only supports RangeParameters in the search space."
            )
        if search_space.parameter_constraints:
            raise NotImplementedError(
                "This example does not support parameter constraints."
            )
        if len(metric_names) != 1:
            raise NotImplementedError(
                "This example only supports single-objective optimization."
            )
        # Get the data for the completed trials. Rows are indexed by position
        # rather than by trial index, so non-completed trials leave no gaps.
        completed_trials = experiment.trials_by_status[TrialStatus.COMPLETED]
        x = np.zeros([len(completed_trials), len(parameter_names)])
        y = np.zeros(len(completed_trials))
        for row, trial in enumerate(completed_trials):
            trial_parameters = trial.arm.parameters
            x[row, :] = np.array([trial_parameters[p] for p in parameter_names])
            trial_df = data.df[data.df["trial_index"] == trial.index]
            y[row] = trial_df[trial_df["metric_name"] == metric_names[0]][
                "mean"
            ].item()

        # Train the regressor. ``y`` is 1-D, which is the shape sklearn expects
        # for a single-output regression target.
        self.regressor.fit(x, y)
        # Update the attributes not set in __init__.
        self.parameters = search_space.parameters
        self.minimize = experiment.optimization_config.objective.minimize

    def get_next_candidate(
        self, pending_parameters: list[TParameterization]
    ) -> TParameterization:
        """Get the parameters for the next candidate configuration to evaluate.

        We will draw ``self.num_samples`` random samples from the search space
        and predict the objective value for each sample. We will then return
        the sample with the best predicted value.

        Args:
            pending_parameters: A list of parameters of the candidates pending
                evaluation. This is often used to avoid generating duplicate candidates.
                We ignore this here for simplicity.

        Returns:
            A dictionary mapping parameter names to parameter values for the next
            candidate suggested by the method.
        """
        bounds = np.array([[p.lower, p.upper] for p in self.parameters.values()])
        unit_samples = np.random.random_sample([self.num_samples, len(bounds)])
        samples = bounds[:, 0] + (bounds[:, 1] - bounds[:, 0]) * unit_samples
        # Predict the objective value for each sample.
        y_pred = self.regressor.predict(samples)
        # Find the best sample.
        best_idx = np.argmin(y_pred) if self.minimize else np.argmax(y_pred)
        best_sample = samples[best_idx, :]
        # Convert the sample to a parameterization.
        candidate = {
            p_name: best_sample[i].item()
            for i, p_name in enumerate(self.parameters.keys())
        }
        return candidate
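
The sampling and selection step in `get_next_candidate` can be tried in isolation. The following is a minimal sketch, not part of Ax: `pick_best_sample` and `toy_predict` are hypothetical names introduced here, and `toy_predict` stands in for a fitted regressor's `predict` method. It shows how broadcasting rescales unit-cube samples to per-parameter bounds and how the greedy choice depends on the optimization direction.

```python
import numpy as np


def pick_best_sample(predict, bounds, num_samples, minimize, rng):
    """Draw uniform samples inside `bounds` and return the best-predicted one."""
    bounds = np.asarray(bounds, dtype=float)  # shape (d, 2): [lower, upper] per dimension
    unit_samples = rng.random((num_samples, len(bounds)))  # uniform in [0, 1)
    # Broadcasting rescales each column to its own [lower, upper) range.
    samples = bounds[:, 0] + (bounds[:, 1] - bounds[:, 0]) * unit_samples
    y_pred = predict(samples)
    # Greedy choice: lowest prediction when minimizing, highest otherwise.
    best_idx = np.argmin(y_pred) if minimize else np.argmax(y_pred)
    return samples[best_idx], float(y_pred[best_idx])


def toy_predict(samples):
    """Hypothetical stand-in for `regressor.predict`: squared distance to (1, -2)."""
    return np.sum((samples - np.array([1.0, -2.0])) ** 2, axis=1)


rng = np.random.default_rng(0)
bounds = [(0.0, 2.0), (-3.0, 0.0)]
best_x, best_y = pick_best_sample(toy_predict, bounds, 128, minimize=True, rng=rng)
worst_x, worst_y = pick_best_sample(toy_predict, bounds, 128, minimize=False, rng=rng)
print(best_x, best_y)
```

Because the choice is purely greedy with respect to the predictions, it never deliberately explores uncertain regions, which is the limitation mentioned in the note at the top of this tutorial.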

## Construct the GenerationStrategy

We will use Sobol for the first 25 trials and defer to random forest for the rest.

In [None]:
generation_strategy = GenerationStrategy(
    name="Sobol+RandomForest",
    nodes=[
        GenerationNode(
            node_name="Sobol",
            generator_specs=[GeneratorSpec(Generators.SOBOL)],
            transition_criteria=[
                MaxTrials(
                    # Generate at most `threshold` trials from this node, then
                    # transition to the node named in `transition_to`.
                    threshold=25,
                    block_transition_if_unmet=True,
                    transition_to="RandomForest",
                )
            ],
        ),
        RandomForestGenerationNode(num_samples=128, regressor_options={}),
    ],
)
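
To make the hand-off concrete, here is a minimal pure-Python sketch of a trial-count threshold deciding which node generates each trial. `ToyNode` and `node_schedule` are hypothetical names invented for this illustration and are not Ax classes; Ax's actual transition logic handles more cases (for example, trial statuses), which this sketch ignores.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyNode:
    """Hypothetical stand-in for a generation node; not an Ax class."""

    name: str
    max_trials: Optional[int] = None  # None means the node never hands off
    transition_to: Optional[str] = None


def node_schedule(nodes, num_trials):
    """Return the name of the node that would generate each trial."""
    by_name = {node.name: node for node in nodes}
    generated = {node.name: 0 for node in nodes}
    current = nodes[0]
    schedule = []
    for _ in range(num_trials):
        # Hand off once the current node has used up its trial budget.
        while (
            current.max_trials is not None
            and generated[current.name] >= current.max_trials
        ):
            current = by_name[current.transition_to]
        generated[current.name] += 1
        schedule.append(current.name)
    return schedule


nodes = [
    ToyNode("Sobol", max_trials=25, transition_to="RandomForest"),
    ToyNode("RandomForest"),
]
schedule = node_schedule(nodes, num_trials=50)
print(schedule.count("Sobol"), schedule.count("RandomForest"))
```

With a threshold of 25 and 50 trials in total, the toy schedule assigns the first 25 trials to the Sobol stand-in and the remaining 25 to the random forest stand-in.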

## Run a simple experiment using the `Client`

More details on how to use the `Client` can be found in the following [tutorial](../getting_started/index.mdx).

In [None]:
# Define the Hartmann6 function which we will minimize.
def hartmann6(x1, x2, x3, x4, x5, x6):
    alpha = np.array([1.0, 1.2, 3.0, 3.2])
    A = np.array(
        [
            [10, 3, 17, 3.5, 1.7, 8],
            [0.05, 10, 17, 0.1, 8, 14],
            [3, 3.5, 1.7, 10, 17, 8],
            [17, 8, 0.05, 10, 0.1, 14],
        ]
    )
    P = 10**-4 * np.array(
        [
            [1312, 1696, 5569, 124, 8283, 5886],
            [2329, 4135, 8307, 3736, 1004, 9991],
            [2348, 1451, 3522, 2883, 3047, 6650],
            [4047, 8828, 8732, 5743, 1091, 381],
        ]
    )

    outer = 0.0
    for i in range(4):
        inner = 0.0
        for j, x in enumerate([x1, x2, x3, x4, x5, x6]):
            inner += A[i, j] * (x - P[i, j]) ** 2
        outer += alpha[i] * np.exp(-inner)
    return -outer
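
As a quick sanity check on a function like the one above, the sketch below re-implements Hartmann6 in vectorized NumPy and evaluates it at the minimizer commonly cited in the optimization literature, where the value should be about -3.32237. The names `hartmann6_vectorized` and `X_STAR` are introduced here for illustration only.

```python
import numpy as np

ALPHA = np.array([1.0, 1.2, 3.0, 3.2])
A = np.array(
    [
        [10, 3, 17, 3.5, 1.7, 8],
        [0.05, 10, 17, 0.1, 8, 14],
        [3, 3.5, 1.7, 10, 17, 8],
        [17, 8, 0.05, 10, 0.1, 14],
    ]
)
P = 1e-4 * np.array(
    [
        [1312, 1696, 5569, 124, 8283, 5886],
        [2329, 4135, 8307, 3736, 1004, 9991],
        [2348, 1451, 3522, 2883, 3047, 6650],
        [4047, 8828, 8732, 5743, 1091, 381],
    ]
)


def hartmann6_vectorized(x):
    """Evaluate Hartmann6 at a single 6-dimensional point."""
    x = np.asarray(x, dtype=float)
    # One weighted squared distance per row of A and P, giving shape (4,).
    inner = np.sum(A * (x - P) ** 2, axis=1)
    return float(-np.sum(ALPHA * np.exp(-inner)))


# Commonly cited (approximate) global minimizer of Hartmann6 on [0, 1]^6.
X_STAR = np.array([0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573])

f_star = hartmann6_vectorized(X_STAR)
f_center = hartmann6_vectorized(np.full(6, 0.5))
print(f_star, f_center)
```

Comparing the best value found by an optimization run against this reference gives a rough sense of how close the run came to the global minimum.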

Configure the experiment and set its `GenerationStrategy` to the Sobol+RandomForest strategy defined previously.

In [None]:
client = Client()

client.configure_experiment(
    parameters=[
        RangeParameterConfig(name=f"x{i}", parameter_type="float", bounds=(0, 1))
        for i in range(1, 7)
    ]
)
client.configure_optimization(objective="-1 * hartmann6")

client.set_generation_strategy(generation_strategy=generation_strategy)

Run the optimization loop as you would for any other Ax experiment. Under the hood, Ax will dispatch to your custom `ExternalGenerationNode` when generating new candidate trials.

In [None]:
for _ in range(50):  # Run 50 rounds of trials (25 Sobol points, 25 with the RandomForest node)
    trials = client.get_next_trials(max_trials=1)

    for trial_index, parameters in trials.items():
        x1 = parameters["x1"]
        x2 = parameters["x2"]
        x3 = parameters["x3"]
        x4 = parameters["x4"]
        x5 = parameters["x5"]
        x6 = parameters["x6"]

        # Set raw_data as a dictionary with metric names as keys and results as values
        raw_data = {"hartmann6": hartmann6(x1, x2, x3, x4, x5, x6)}

        # Complete the trial with the result
        client.complete_trial(trial_index=trial_index, raw_data=raw_data)
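
One simple way to follow progress in a loop like this is to track the best objective value seen so far. The sketch below assumes the observed values have been collected into a plain list (for example, by appending each `hartmann6` result inside the loop); `best_so_far` is a hypothetical helper, and the numbers are made up for illustration rather than taken from an actual run.

```python
from itertools import accumulate


def best_so_far(values, minimize=True):
    """Return the running best (min or max) of a sequence of observations."""
    return list(accumulate(values, min if minimize else max))


# Made-up objective values standing in for results of five trials.
observed = [-0.4, -1.2, -0.9, -2.1, -1.7]
trace = best_so_far(observed)
print(trace)
```

A trace that stops improving after the switch from Sobol to the random forest node would be consistent with the overly greedy behavior noted at the start of this tutorial.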

View your results by computing Analyses. Not all Analyses implemented by Ax are compatible with an `ExternalGenerationNode` (specifically, those that rely on an Ax `Adapter` for their input data), though many that rely solely on raw observations work out of the box.

In [None]:
client.compute_analyses(
    analyses=[
        ArmEffectsPlot(metric_name="hartmann6", use_model_predictions=False),
        Summary(),
    ]
)