
Building a benchmark dataset with qNegIntegratedPosteriorVariance: use of ScalarizedPosteriorTransform instead of ScalarizedObjective (DEPRECATED) type objectives #1312

Closed
sgbaird opened this issue Dec 10, 2022 · 5 comments

sgbaird commented Dec 10, 2022

I want to build a dataset meant for benchmarking, and I figured that qNIPV would make sense here (#930). The inputs are red, green, and blue LED powers, and the outputs are intensities at eight discrete wavelengths (sparks-baird/self-driving-lab-demo#121). While I could use quasi-random methods to generate the data, experiments in more realistic scenarios can take much longer to run, so creating an example with something more sophisticated seemed like the way to go.

(Aside: maybe it would make sense to use qNEHVI + SAASBO here?)

I noticed that the BoTorch docs say to use `ScalarizedPosteriorTransform` instead. Here is my attempt at a custom input constructor for qNIPV:

from typing import Any, Dict, Optional

from botorch.acquisition.active_learning import (
    MCSampler,
    qNegIntegratedPosteriorVariance,
)

from botorch.acquisition.input_constructors import (
    MaybeDict,
    acqf_input_constructor,
    construct_inputs_mc_base,
)

from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor

from botorch.acquisition.objective import AcquisitionObjective

@acqf_input_constructor(qNegIntegratedPosteriorVariance)
def construct_inputs_qNIPV(
    model: Model,
    mc_points: Tensor,
    training_data: MaybeDict[SupervisedDataset],
    objective: Optional[AcquisitionObjective] = None,
    X_pending: Optional[Tensor] = None,
    sampler: Optional[MCSampler] = None,
    **kwargs: Any,
) -> Dict[str, Any]:

    # With a single-output model there is nothing to scalarize, so drop the objective
    if model.num_outputs == 1:
        objective = None

    base_inputs = construct_inputs_mc_base(
        model=model,
        training_data=training_data,
        sampler=sampler,
        X_pending=X_pending,
        objective=objective,
    )

    return {**base_inputs, "mc_points": mc_points}
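
and then the following Ax setup (`parameters` and `sdl.channel_names` come from the linked self-driving-lab-demo code):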

from typing import Any, Dict, Optional

import torch

from ax.modelbridge import get_sobol
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.service.ax_client import AxClient
from botorch.models.gp_regression import SingleTaskGP
from ax.service.utils.instantiation import ObjectiveProperties

num_sobol = 10
num_qnipv = 290
total_trials = num_sobol + num_qnipv
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ax_client_tmp = AxClient(torch_device=torch_device)
ax_client_tmp.create_experiment(parameters=parameters)
sobol = get_sobol(ax_client_tmp.experiment.search_space)
# 1024 quasi-random points spanning the search space, used as the fixed
# integration points for the posterior-variance integral in qNIPV
mc_points = sobol.gen(1024).param_df.values
mcp = torch.tensor(mc_points)

model_kwargs_val = {
    "surrogate": Surrogate(SingleTaskGP),
    "botorch_acqf_class": qNegIntegratedPosteriorVariance,
    "acquisition_options": {"mc_points": mcp},
}

gs = GenerationStrategy(
    steps=[
        GenerationStep(model=Models.SOBOL, num_trials=num_sobol),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=num_qnipv,
            model_kwargs=model_kwargs_val,
        ),
    ]
)

ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(
    name="clslab-light-experiment",
    parameters=parameters,
    objectives={
        ch_name: ObjectiveProperties(minimize=True) for ch_name in sdl.channel_names
    },
)
...
UnsupportedError                          Traceback (most recent call last)
Cell In [12], line 9
      6     return new_data
      8 for _ in range(20):
----> 9     trial_params, trial_index = ax_client.get_next_trial()
     10     data = evaluate(trial_params)
     11     ax_client.complete_trial(
     12         trial_index=trial_index, raw_data=data
     13     )

File c:\Users\sterg\Miniconda3\envs\sdl-demo\lib\site-packages\ax\utils\common\executils.py:161, in retry_on_exception.<locals>.func_wrapper.<locals>.actual_wrapper(*args, **kwargs)
    157             wait_interval = min(
    158                 MAX_WAIT_SECONDS, initial_wait_seconds * 2 ** (i - 1)
    159             )
    160             time.sleep(wait_interval)
--> 161         return func(*args, **kwargs)
    163 # If we are here, it means the retries were finished but
    164 # The error was suppressed. Hence return the default value provided.
    165 return default_return_on_suppression

File c:\Users\sterg\Miniconda3\envs\sdl-demo\lib\site-packages\ax\service\ax_client.py:480, in AxClient.get_next_trial(self, ttl_seconds, force)
    476         raise OptimizationShouldStop(message=global_stopping_message)
    478 try:
    479     trial = self.experiment.new_trial(
--> 480         generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds
    481     )
    482 except MaxParallelismReachedException as e:
    483     if self._early_stopping_strategy is not None:

File c:\Users\sterg\Miniconda3\envs\sdl-demo\lib\site-packages\ax\service\ax_client.py:1606, in AxClient._gen_new_generator_run(self, n)
   1599 # If random seed is not set for this optimization, context manager does
   1600 # nothing; otherwise, it sets the random seed for torch, but only for the
   1601 # scope of this call. This is important because torch seed is set globally,
   1602 # so if we just set the seed without the context manager, it can have
   1603 # serious negative impact on the performance of the models that employ
   1604 # stochasticity.
   1605 with manual_seed(seed=self._random_seed):
-> 1606     return not_none(self.generation_strategy).gen(
   1607         experiment=self.experiment,
   1608         n=n,
   1609         pending_observations=self._get_pending_observation_features(
   1610             experiment=self.experiment
   1611         ),
   1612     )

File c:\Users\sterg\Miniconda3\envs\sdl-demo\lib\site-packages\ax\modelbridge\generation_strategy.py:334, in GenerationStrategy.gen(self, experiment, data, n, pending_observations, **kwargs)
    297 def gen(
    298     self,
    299     experiment: Experiment,
   (...)
    303     **kwargs: Any,
    304 ) -> GeneratorRun:
    305     """Produce the next points in the experiment. Additional kwargs passed to
    306     this method are propagated directly to the underlying model's `gen`, along
    307     with the `model_gen_kwargs` set on the current generation step.
   (...)
    332             resuggesting points that are currently being evaluated.
    333     """
--> 334     return self._gen_multiple(
    335         experiment=experiment,
    336         num_generator_runs=1,
    337         data=data,
    338         n=n,
    339         pending_observations=pending_observations,
    340         **kwargs,
    341     )[0]

File c:\Users\sterg\Miniconda3\envs\sdl-demo\lib\site-packages\ax\modelbridge\generation_strategy.py:475, in GenerationStrategy._gen_multiple(self, experiment, num_generator_runs, data, n, pending_observations, **kwargs)
    473 for _ in range(num_generator_runs):
    474     try:
--> 475         generator_run = _gen_from_generation_step(
    476             generation_step=self._curr,
    477             input_max_gen_draws=MAX_GEN_DRAWS,
    478             n=n,
    479             pending_observations=pending_observations,
    480             model_gen_kwargs=kwargs,
    481             should_deduplicate=self._curr.should_deduplicate,
    482             arms_by_signature=self.experiment.arms_by_signature,
    483         )
    484         generator_run._generation_step_index = self._curr.index
    485         self._generator_runs.append(generator_run)

File c:\Users\sterg\Miniconda3\envs\sdl-demo\lib\site-packages\ax\modelbridge\generation_strategy.py:842, in _gen_from_generation_step(input_max_gen_draws, generation_step, n, pending_observations, model_gen_kwargs, should_deduplicate, arms_by_signature)
    840 if n_gen_draws > input_max_gen_draws:
    841     raise GenerationStrategyRepeatedPoints(MAX_GEN_DRAWS_EXCEEDED_MESSAGE)
--> 842 generator_run = generation_step.gen(
    843     n=n,
    844     pending_observations=pending_observations,
    845     **model_gen_kwargs,
    846 )
    847 should_generate_run = should_deduplicate and any(
    848     arm.signature in arms_by_signature for arm in generator_run.arms
    849 )
    850 n_gen_draws += 1
...
     66 return ScalarizedPosteriorTransform(
     67     weights=objective.weights, offset=objective.offset
     68 )

UnsupportedError: qNegIntegratedPosteriorVariance only supports ScalarizedObjective (DEPRECATED) type objectives.
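
For reference, at the BoTorch level qNIPV seems to take a `posterior_transform` directly, so my question is mainly how to wire that through Ax. Here is a minimal sketch with a toy two-output model (the data, shapes, and weights are made up for illustration):

import torch

from botorch.acquisition.active_learning import qNegIntegratedPosteriorVariance
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models.gp_regression import SingleTaskGP

# Toy two-output model and quasi-random integration points, for illustration only
train_X = torch.rand(10, 3, dtype=torch.double)
train_Y = torch.rand(10, 2, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)
mc_points = torch.rand(1024, 3, dtype=torch.double)

# Equal-weight scalarization in place of the deprecated ScalarizedObjective
pt = ScalarizedPosteriorTransform(weights=torch.ones(2, dtype=torch.double))
qnipv = qNegIntegratedPosteriorVariance(
    model=model, mc_points=mc_points, posterior_transform=pt
)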

Here is a self-contained example using vanilla single-objective optimization based on #460 (comment):

from typing import Any, Dict, Optional

import torch

# from ax.core.objective import ScalarizedObjective
from ax.modelbridge import get_sobol
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.service.ax_client import AxClient
from botorch.acquisition.active_learning import (
    MCSampler,
    qNegIntegratedPosteriorVariance,
)
from botorch.acquisition.input_constructors import (
    MaybeDict,
    acqf_input_constructor,
    construct_inputs_mc_base,
)
from botorch.acquisition.objective import AcquisitionObjective
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


@acqf_input_constructor(qNegIntegratedPosteriorVariance)
def construct_inputs_qNIPV(
    model: Model,
    mc_points: Tensor,
    training_data: MaybeDict[SupervisedDataset],
    objective: Optional[AcquisitionObjective] = None,
    X_pending: Optional[Tensor] = None,
    sampler: Optional[MCSampler] = None,
    **kwargs: Any,
) -> Dict[str, Any]:

    if model.num_outputs == 1:
        objective = None

    base_inputs = construct_inputs_mc_base(
        model=model,
        training_data=training_data,
        sampler=sampler,
        X_pending=X_pending,
        objective=objective,
    )

    return {**base_inputs, "mc_points": mc_points}


def objective_function(x):
    f = x["x1"] ** 2 + x["x2"] ** 2 + x["x3"] ** 2
    return {"f": (f, None)}


parameters = [
    {"name": "x1", "type": "range", "bounds": [0.0, 1.0], "value_type": "float"},
    {"name": "x2", "type": "range", "bounds": [0.0, 1.0], "value_type": "float"},
    {"name": "x3", "type": "range", "bounds": [0.0, 1.0], "value_type": "float"},
]
num_sobol = 10
num_qnipv = 290

ax_client_tmp = AxClient()
ax_client_tmp.create_experiment(parameters=parameters)
sobol = get_sobol(ax_client_tmp.experiment.search_space)
mc_points = sobol.gen(1024).param_df.values
mcp = torch.tensor(mc_points)

model_kwargs_val = {
    "surrogate": Surrogate(SingleTaskGP),
    "botorch_acqf_class": qNegIntegratedPosteriorVariance,
    "acquisition_options": {"mc_points": mcp},
}

gs = GenerationStrategy(
    steps=[
        GenerationStep(model=Models.SOBOL, num_trials=num_sobol),
        GenerationStep(
            model=Models.BOTORCH_MODULAR, num_trials=num_qnipv, model_kwargs=model_kwargs_val
        ),
    ]
)

ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(
    name="active_learning_experiment",
    parameters=parameters,
    objective_name="f",
    minimize=True,
)

for _ in range(20):
    trial_params, trial_index = ax_client.get_next_trial()
    data = objective_function(trial_params)
    ax_client.complete_trial(trial_index=trial_index, raw_data=data["f"])
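
Each qNIPV trial should pick the candidate that most reduces the model's integrated posterior variance over `mc_points`, i.e., pure exploration rather than optimization, which is the behavior I want when generating a benchmark dataset.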
@sgbaird sgbaird changed the title qNegIntegratedPosteriorVariance only supports ScalarizedObjective (DEPRECATED) type objectives How to use ScalarizedPosteriorTransform instead of ScalarizedObjective (DEPRECATED) type objectives (for qNegIntegratedPosteriorVariance) Dec 10, 2022
@sgbaird sgbaird changed the title How to use ScalarizedPosteriorTransform instead of ScalarizedObjective (DEPRECATED) type objectives (for qNegIntegratedPosteriorVariance) Building a benchmark dataset with qNegIntegratedPosteriorVariance: use of ScalarizedPosteriorTransform instead of ScalarizedObjective (DEPRECATED) type objectives Dec 10, 2022

sgbaird commented Dec 10, 2022

Mostly curious if there's an easy way to plug this in via the Service API, and if not (which is totally fine), whether you'd recommend using the Developer API or BoTorch.

FYI this isn't a major blocker for me - just something I want to keep in mind for later.

For now, I might go with quasi-random generation of candidates.

@danielcohenlive

Thanks for the question @sgbaird. So basically you're asking whether there's a way to use a `ScalarizedObjective`? I would just drop into the dev API for that and use:

ax_client.experiment.objective = ScalarizedObjective(...)

We don't support it from `ax_client.create_experiment()`.
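
Something like this rough sketch, setting it through the experiment's optimization config (the metric names are hypothetical; use the metrics from your own experiment):

from ax.core.metric import Metric
from ax.core.objective import ScalarizedObjective
from ax.core.optimization_config import OptimizationConfig

# Hypothetical metric names; replace with the metrics registered on your experiment
ax_client.experiment.optimization_config = OptimizationConfig(
    objective=ScalarizedObjective(
        metrics=[Metric(name="ch470"), Metric(name="ch550")],
        weights=[0.5, 0.5],
        minimize=True,
    )
)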


sgbaird commented Dec 12, 2022

@danielcohenlive thanks! I'll give that a try

@danielcohenlive danielcohenlive self-assigned this Dec 12, 2022
@danielcohenlive

I'm going to close this for now, but feel free to reopen if you have further questions.

sgbaird added a commit to sgbaird/botorch that referenced this issue Jun 21, 2023
…edPosteriorTransform

The example usage at the end was already updated to use `ScalarizedPosteriorTransform`.

facebook/Ax#1312
facebook-github-bot pushed a commit to pytorch/botorch that referenced this issue Jun 21, 2023
…edPosteriorTransform (#1898)

Summary:

## Motivation

The example usage at the end was already updated to use `ScalarizedPosteriorTransform`.

facebook/Ax#1312

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1898

Test Plan:
N/A

## Related PRs

Doc change

Reviewed By: esantorella

Differential Revision: D46882718

Pulled By: saitcakmak

fbshipit-source-id: 152df53b3185f39734f545b3ed32ffde1e99c5cd

sgbaird commented Jun 21, 2023

@danielcohenlive

Aside:

The docs for ScalarizedObjective say:

DEPRECATED - Use ScalarizedPosteriorTransform instead

For ease, here's the link to ScalarizedPosteriorTransform

The BoTorch tutorial for Writing a custom acquisition function shows usage of ScalarizedPosteriorTransform at the end:

from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.acquisition.analytic import UpperConfidenceBound

pt = ScalarizedPosteriorTransform(weights=torch.tensor([0.1, 0.5]))
SUCB = UpperConfidenceBound(gp, beta=0.1, posterior_transform=pt)
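
For completeness, here's a self-contained sketch of that snippet, with a toy two-output GP standing in for the tutorial's fitted `gp`:

import torch

from botorch.acquisition.analytic import UpperConfidenceBound
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models.gp_regression import SingleTaskGP

# Toy two-output model, for illustration only
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = torch.rand(20, 2, dtype=torch.double)
gp = SingleTaskGP(train_X, train_Y)

# Scalarize the two outputs so the analytic UCB sees a single output
pt = ScalarizedPosteriorTransform(weights=torch.tensor([0.1, 0.5], dtype=torch.double))
SUCB = UpperConfidenceBound(gp, beta=0.1, posterior_transform=pt)
ucb_values = SUCB(torch.rand(5, 1, 2, dtype=torch.double))  # one value per candidate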

I have a notebook that I'd like to get functioning as a demonstration example. I'll give this a go.
