Multi-task BO with Service API #2546

Closed · sgbaird opened this issue Jun 26, 2024 · 6 comments
Labels: question (Further information is requested)

sgbaird (Contributor) commented Jun 26, 2024

Should I perish the thought?

xref: #1038

Fairly naive, non-functional starting point:

import numpy as np
from ax.service.ax_client import AxClient, ObjectiveProperties

from ax.modelbridge.factory import Models
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy

from ax.modelbridge.registry import ST_MTGP_trans

obj1_name = "branin"

def branin(x1, x2):
    y = float(
        (x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2
        + 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)
        + 10
    )

    return y

gs = GenerationStrategy(
    steps=[
        GenerationStep(
            model=Models.ST_MTGP,
            num_trials=-1,
            max_parallelism=3,
            model_kwargs={"transforms": ST_MTGP_trans, "transform_configs": None},
        ),
    ]
)

ax_client = AxClient(generation_strategy=gs, random_seed=42)

ax_client.create_experiment(
    parameters=[
        {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 10.0]},
    ],
    objectives={
        obj1_name: ObjectiveProperties(minimize=True),
    },
)

for _ in range(10):

    parameterization, trial_index = ax_client.get_next_trial()

    # extract parameters
    x1 = parameterization["x1"]
    x2 = parameterization["x2"]

    results = branin(x1, x2)
    ax_client.complete_trial(trial_index=trial_index, raw_data=results)


best_parameters, metrics = ax_client.get_best_parameters()
Output:

(honegumi) PS C:\Users\sterg\Documents\GitHub\sgbaird\honegumi> & C:/Users/sterg/miniforge3/envs/honegumi/python.exe c:/Users/sterg/Documents/GitHub/sgbaird/honegumi/scripts/refreshers/multi_task.py
[INFO 06-26 16:25:51] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[WARNING 06-26 16:25:51] ax.service.ax_client: Random seed set to 42. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 06-26 16:25:51] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 06-26 16:25:51] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 06-26 16:25:51] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 10.0])], parameter_constraints=[]).
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': -3.0, 'x2': 5.0}] as trial 0.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 0 with data: {'branin': (48.620235, None)}.
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': 0.0, 'x2': 6.2}] as trial 1.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 1 with data: {'branin': (19.642113, None)}.
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': 5.9, 'x2': 2.0}] as trial 2.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 2 with data: {'branin': (19.70361, None)}.
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': 1.5, 'x2': 2.0}] as trial 3.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 3 with data: {'branin': (14.301934, None)}.
[INFO 06-26 16:25:51] ax.core.experiment: Attached custom parameterizations [{'x1': 1.0, 'x2': 9.0}] as trial 4.
[INFO 06-26 16:25:51] ax.service.ax_client: Completed trial 4 with data: {'branin': (35.100744, None)}.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '0') is constant, within tolerance.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '1') is constant, within tolerance.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '2') is constant, within tolerance.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '3') is constant, within tolerance.
[INFO 06-26 16:25:51] ax.modelbridge.transforms.standardize_y: Outcome ('branin', '4') is constant, within tolerance.
C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\linear_operator\utils\interpolation.py:71: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated.  Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at ..\torch\csrc\utils\tensor_new.cpp:620.)
  summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\botorch\acquisition\cached_cholesky.py:89: RuntimeWarning: `cache_root` is only supported for GPyTorchModels that are not MultiTask models and don't produce a TransformedPosterior. Got a model of type <class 'botorch.models.model_list_gp_regression.ModelListGP'>. Setting `cache_root = False`.
  warnings.warn(
Traceback (most recent call last):
  File "c:\Users\sterg\Documents\GitHub\sgbaird\honegumi\scripts\refreshers\multi_task.py", line 114, in <module>
    parameterization, trial_index = ax_client.get_next_trial()
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\utils\common\executils.py", line 163, in actual_wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\service\ax_client.py", line 539, in get_next_trial
    generator_run=self._gen_new_generator_run(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\service\ax_client.py", line 1790, in _gen_new_generator_run
    return not_none(self.generation_strategy).gen(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_strategy.py", line 370, in gen
    return self._gen_multiple(
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_strategy.py", line 683, in _gen_multiple
    generator_run = self._curr.gen(
                    ^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_node.py", line 712, in gen
    gr = super().gen(
         ^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_node.py", line 272, in gen
    generator_run = self._gen(
                    ^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\generation_node.py", line 334, in _gen
    return model_spec.gen(
           ^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\model_spec.py", line 221, in gen
    return fitted_model.gen(**model_gen_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\base.py", line 786, in gen
    gen_results = self._gen(
                  ^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\torch.py", line 686, in _gen
    gen_results = not_none(self.model).gen(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\models\torch\botorch_modular\model.py", line 428, in gen
    candidates, expected_acquisition_value, weights = acqf.optimize(
                                                      ^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\models\torch\botorch_modular\acquisition.py", line 450, in optimize
    candidates, acqf_values = optimize_acqf(
                              ^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\botorch\optim\optimize.py", line 567, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\botorch\optim\optimize.py", line 588, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\botorch\optim\optimize.py", line 400, in _optimize_acqf_batch
    batch_candidates = opt_inputs.post_processing_func(batch_candidates)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\models\torch\botorch.py", line 531, in botorch_rounding_func
    [rounding_func(x) for x in X.view(-1, d)]  # pyre-ignore: [16]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\models\torch\botorch.py", line 531, in <listcomp>
    [rounding_func(x) for x in X.view(-1, d)]  # pyre-ignore: [16]
     ^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\torch.py", line 287, in <lambda>
    self._array_to_tensor(array_func(x.detach().cpu().clone().numpy()))
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\modelbridge_utils.py", line 664, in _roundtrip_transform
    observation_features = t.untransform_observation_features(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sterg\miniforge3\envs\honegumi\Lib\site-packages\ax\modelbridge\transforms\choice_encode.py", line 122, in untransform_observation_features
    obsf.parameters[p_name] = reverse_transform[pval]
                              ~~~~~~~~~~~~~~~~~^^^^^^
KeyError: 3.7423404678702354

Would this be something for BOTORCH_MODULAR in a custom generation strategy instead?

danielcohenlive commented

Great question @sgbaird! This is something we do internally in the service API with batch trials. We do have future plans to open source our AxBatchClient, but it's unfortunately not out yet.

With batch trials, you'd have a generation strategy (GS) consisting of:

  1. SOBOL
  2. GPEI or BOTORCH_MODULAR, without fixed_features and status_quo_features. (At this point there's only one trial, so you can't do multi-task yet.)
  3. ST_MTGP or BOTORCH_MODULAR, with fixed_features and status_quo_features (sketched below). These point to a trial index, so you'd want both to point at the target trial, probably the most recent one. I'm not aware of any way to group non-batch trials into tasks.
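For concreteness, a minimal, untested sketch of that strategy. The step counts are illustrative, and passing the target trial via fixed_features at generation time is an assumption based on how GenerationStrategy.gen is used later in this thread, not the canonical AxBatchClient setup:

from ax.core.observation import ObservationFeatures
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

gs = GenerationStrategy(
    steps=[
        # 1. Quasi-random initialization.
        GenerationStep(model=Models.SOBOL, num_trials=5),
        # 2. Single-task BO while only one batch trial exists.
        GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=1),
        # 3. Multi-task BO once several batch trials are available.
        GenerationStep(model=Models.ST_MTGP, num_trials=-1),
    ]
)

# At generation time, point the multi-task step at the target trial
# (hypothetical usage; `experiment` and `target_trial_index` are placeholders):
# generator_run = gs.gen(
#     experiment=experiment,
#     n=3,
#     fixed_features=ObservationFeatures(parameters={}, trial_index=target_trial_index),
# )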

What are you trying to do? I noticed honegumi in the prompt. Is this for the Honegumi interface or a real-world use case? Is it intentional or accidental that this use case has non-batch trials?

saitcakmak (Contributor) commented

Hi @sgbaird. Multi-task BO can take many forms depending on what you're trying to achieve. In any case, you need to give Ax a way to identify which task each trial belongs to. One way to do this is to add a task parameter to your search space. When generating new trials, you can then specify which task they should be generated for. Here's an example using AxClient.

import numpy as np
from ax.core.observation import ObservationFeatures
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice
from ax.modelbridge.transforms.unit_x import UnitX
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.common.typeutils import not_none
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import init_notebook_plotting, render

init_notebook_plotting()


# Update as needed. See ax/modelbridge/registry for default list of transforms.
transforms = [
    TaskChoiceToIntTaskChoice,  # Since we're using a string-valued task parameter.
    UnitX,
]

# Custom generation strategy to support the multi-task search space.
generation_strategy = GenerationStrategy(
    name="MultiTaskMBM",
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,
            model_kwargs={"deduplicate": True, "transforms": transforms},
        ),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={"transforms": transforms},
        ),
    ],
)

ax_client = AxClient(generation_strategy=generation_strategy)

ax_client.create_experiment(
    name="hartmann_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x3",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x4",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x5",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x6",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        # Add the task parameter!
        {
            "name": "task",
            "type": "choice",
            "values": ["base", "shifted"],
            "is_task": True,
            "target_value": "base"
        }
    ],
    objectives={"hartmann6": ObjectiveProperties(minimize=True)},
)


# Evaluation produces different results based on task value.
def evaluate(parameterization):
    x = np.array([parameterization.get(f"x{i+1}") for i in range(6)])
    value = hartmann6(x)
    if parameterization.get("task") == "shifted":
        value += 100
    # In our case, standard error is 0, since we are computing a synthetic function.
    return {"hartmann6": (value, 0.0)}



for i in range(10):
    trial = ax_client.experiment.new_trial(
        generator_run=ax_client.generation_strategy.gen(
            experiment=ax_client.experiment,
            n=1,
            pending_observations=ax_client._get_pending_observation_features(
                experiment=ax_client.experiment
            ),
            # Need to specify what task we want to generate from. Switching between the two here.
            fixed_features=ObservationFeatures(
                {"task": "base" if i % 2 else "shifted"}
            ),
        )
    )
    trial.mark_running(no_runner_required=True)
    parameterization, trial_index = not_none(trial.arm).parameters, trial.index
    ax_client.complete_trial(
        trial_index=trial_index, raw_data=evaluate(parameterization)
    )

# We can verify that the model is a ModelListGP of MultiTaskGP.
mb = ax_client.generation_strategy.model
mb.model.surrogate.model
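
For an explicit check outside a notebook, something like the following should work. This is a sketch that assumes the attribute path above and BoTorch's ModelListGP / MultiTaskGP classes; the layout may vary across Ax versions:

from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP

# Sanity check: the surrogate should be a ModelListGP whose per-metric
# sub-models are MultiTaskGPs.
surrogate_model = ax_client.generation_strategy.model.model.surrogate.model
assert isinstance(surrogate_model, ModelListGP)
assert all(isinstance(m, MultiTaskGP) for m in surrogate_model.models)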

sgbaird (Contributor, Author) commented Jul 29, 2024

> Great question @sgbaird! This is something we do internally in the service API with batch trials. […] Is it intentional or accidental that this use case has non-batch trials?

Sorry this took so long to get back to you! I lost track of this. Yes, the idea was for Honegumi and a BO tutorial on multi-task optimization, since there are a lot of chemistry and materials science use cases like this. I was trying to keep it to the simplest case, so I hadn't considered/included batch trials. If there's not an easy way to support batch trials, then I'd set it up so the batch option within the synchrony row in Honegumi is crossed out when multi-task is set to True.

sgbaird (Contributor, Author) commented Jul 29, 2024

> Hi @sgbaird. Multi-task BO can take many forms depending on what you're trying to achieve. […] Here's an example using AxClient.

Thank you for this! I was able to run it and plan to do some additional testing.

sgbaird added a commit to sgbaird/honegumi that referenced this issue Jul 29, 2024
lena-kashtelyan (Contributor) commented

@sgbaird is this still open or resolved? : )

lena-kashtelyan added the "question" label on Jul 31, 2024
sgbaird (Contributor, Author) commented Jul 31, 2024

I'll consider this solved for now, and will post back/reopen if I run into any issues. Thanks all for the help! 🙂
