In [1]:
import plotly.graph_objects as go
import numpy as np
import torch

In [2]:
def plot_model(model, loss_params=None, data_params=None):
    # ランダムシードの設定
    torch.manual_seed(42)
    np.random.seed(42)

    def func(x):
        if type(x) == torch.Tensor:
            return torch.sin(x)
        return np.sin(x)
    
    if data_params is None:
        data_params = {}
    n_points = data_params.get('n_points', 50)
    noise_std = data_params.get('noise_std', 0.1)
    outlier_ratio = data_params.get('outlier_ratio', 0.03)
    

    # サイン波データの生成
    def generate_sine_data(n_points=50, noise_std=0.1, outlier_ratio=0.03, func=np.sin):
        x = np.linspace(0, 2 * np.pi, n_points)
        y = func(x)
        noise = np.random.normal(0, noise_std, size=y.shape)
        y_noisy = y + noise

        n_outliers = int(n_points * outlier_ratio)
        outlier_indices = np.random.choice(n_points, n_outliers, replace=False)
        y_noisy[outlier_indices] = 5 + np.random.uniform(-0.1, 0.1, n_outliers)

        return torch.tensor(x, dtype=torch.float64), torch.tensor(y_noisy, dtype=torch.float64)

    # データ生成
    train_x, train_y = generate_sine_data(n_points=n_points, noise_std=noise_std, outlier_ratio=outlier_ratio, func=func)

    # データを2D形式に変換
    train_x_2d = train_x.unsqueeze(-1).to(torch.float64)
    train_y_2d = train_y.unsqueeze(-1).to(torch.float64)

    # モデルを訓練
    model.fit(train_x_2d, train_y_2d, loss_params=loss_params)

    # テスト用データの生成
    test_x = torch.linspace(0, 2 * np.pi, 200).unsqueeze(-1).to(torch.float64)
    true_y = func(test_x)

    # 予測
    posterior = model.posterior(test_x)
    mean = posterior.mean
    variance = posterior.variance
    std = torch.sqrt(variance)

    # 信用区間
    lower_bound = mean - 2 * std
    upper_bound = mean + 2 * std

    # プロット作成
    fig = go.Figure()

    # 真の関数
    fig.add_trace(go.Scatter(
        x=test_x.squeeze().numpy(),
        y=true_y.squeeze().numpy(),
        mode='lines',
        name='True Function',
        line=dict(color='red')
    ))

    # 観測データ
    fig.add_trace(go.Scatter(
        x=train_x.numpy(),
        y=train_y.numpy(),
        mode='markers',
        name='Observations',
        marker=dict(color='blue')
    ))

    # 平均関数
    fig.add_trace(go.Scatter(
        x=test_x.squeeze().numpy(),
        y=mean.squeeze().detach().numpy(),
        mode='lines',
        name='Predictive Mean',
        line=dict(color='green')
    ))

    # 信用区間の塗りつぶし
    fig.add_trace(go.Scatter(
        x=np.concatenate([test_x.squeeze().numpy(), test_x.squeeze().numpy()[::-1]]),
        y=np.concatenate([upper_bound.squeeze().detach().numpy(), lower_bound.squeeze().detach().numpy()[::-1]]),
        fill='toself',
        fillcolor="green",
        line=dict(color='rgba(255, 165, 0, 0)'),
        opacity=0.2,
        name='Confidence Interval (2σ)',
        showlegend=True
    ))

    # プロットの装飾
    fig.update_layout(
        title="BNN Regression on Sine Function",
        xaxis_title="x",
        yaxis_title="y",
        legend=dict(orientation="h", y=-0.2),
        template="plotly_white"
    )

    fig.show()

In [3]:
from models.laplace import LaplaceBNN

n = 3

# モデルの初期化
args = {
    "regnet_dims": [64*n, 64*n, 64*n],
    "regnet_activation": "tanh",
    "prior_var": 1.0,
    "noise_var": 0.2,
    "iterative": True,
}
input_dim = 1
output_dim = 1
device = torch.device("cpu")

model = LaplaceBNN(args, input_dim=input_dim, output_dim=output_dim, device=device)


data_params = {
    "n_points": 50,
    "noise_std": 0.9,
    "outlier_ratio": 0.03
}

loss_params = {
    "n_epochs": 1000,
    "weight_decay": 1e-3,
    "artl_weight": 0,
    "lambd": 1e-3,
    "k": (1, 2, 3,), # order
    "q": 2,  # exponent
    "M": 100
}

plot_model(model, loss_params=loss_params, data_params=data_params)

n_epochs: 1000
weight_decay: 0.001
artl_weight: 0
h: 45
lambd: 0.001
k: (1, 2, 3)
q: 2
M: 100


In [40]:
model.posterior

<bound method LaplaceBNN.posterior of LaplaceBNN(
  (nn): RegNet(
    (linear0): Linear(in_features=1, out_features=192, bias=True)
    (tanh0): Tanh()
    (linear1): Linear(in_features=192, out_features=192, bias=True)
    (tanh1): Tanh()
    (linear2): Linear(in_features=192, out_features=192, bias=True)
    (tanh2): Tanh()
    (linear3): Linear(in_features=192, out_features=1, bias=True)
  )
)>

# BoTorch

In [12]:
import torch
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.models.model import Model
from botorch.optim import optimize_acqf
from botorch.sampling import SobolQMCNormalSampler
from botorch.test_functions import Branin
from botorch.utils.transforms import normalize, unnormalize

# Assuming `LaplaceBNN` is already imported
from models.laplace import LaplaceBNN

def run_botorch_with_laplace_bnn():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define the benchmark function (Branin function here)
    branin = Branin().to(device)

    # Define bounds for the optimization problem
    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], device=device, dtype=torch.float64)

    # Normalize initial data
    train_x = torch.rand(5, 2, device=device, dtype=torch.float64)
    train_y = branin(train_x).unsqueeze(-1)

    # Define the model
    args = {
        "regnet_dims": [10, 10],
        "regnet_activation": "relu",
        "prior_var": 1.0,
        "noise_var": 1e-2,
        "iterative": False,
    }
    laplace_bnn = LaplaceBNN(args, input_dim=2, output_dim=1, device=device)

    # Fit the model
    laplace_bnn.fit(train_x, train_y)

    for iteration in range(10):
        print(f"Iteration {iteration + 1}")

        # Define the acquisition function
        sampler = SobolQMCNormalSampler(sample_shape=torch.Size([500]))
        qEI = qLogExpectedImprovement(model=laplace_bnn, best_f=train_y.max(), sampler=sampler)

        # Optimize the acquisition function to find the next point to evaluate
        candidates, _ = optimize_acqf(
            acq_function=qEI,
            bounds=bounds,
            q=1,
            num_restarts=10,
            raw_samples=50,
        )

        # Evaluate the objective at the selected candidate
        new_x = candidates.detach()
        new_y = branin(new_x).unsqueeze(-1)

        # Update the training data
        train_x = torch.cat([train_x, new_x])
        train_y = torch.cat([train_y, new_y])

        # Refit the model with the new data
        laplace_bnn.fit(train_x, train_y)

        print(f"Optimized candidate: {new_x}")
        print(f"Objective value: {new_y}")

if __name__ == "__main__":
    run_botorch_with_laplace_bnn()


n_epochs: 1000
weight_decay: 0
artl_weight: 1.0
h: 4
lambd: 0.001
k: (1, 2, 3)
q: 2
M: 100
Iteration 1
n_epochs: 1000
weight_decay: 0
artl_weight: 1.0
h: 5
lambd: 0.001
k: (1, 2, 3)
q: 2
M: 100
Optimized candidate: tensor([[0., 0.]], dtype=torch.float64)
Objective value: tensor([[55.6021]], dtype=torch.float64)
Iteration 2
n_epochs: 1000
weight_decay: 0
artl_weight: 1.0
h: 6
lambd: 0.001
k: (1, 2, 3)
q: 2
M: 100
Optimized candidate: tensor([[0., 0.]], dtype=torch.float64)
Objective value: tensor([[55.6021]], dtype=torch.float64)
Iteration 3
n_epochs: 1000
weight_decay: 0
artl_weight: 1.0
h: 7
lambd: 0.001
k: (1, 2, 3)
q: 2
M: 100
Optimized candidate: tensor([[0., 0.]], dtype=torch.float64)
Objective value: tensor([[55.6021]], dtype=torch.float64)
Iteration 4
n_epochs: 1000
weight_decay: 0
artl_weight: 1.0
h: 8
lambd: 0.001
k: (1, 2, 3)
q: 2
M: 100
Optimized candidate: tensor([[0., 0.]], dtype=torch.float64)
Objective value: tensor([[55.6021]], dtype=torch.float64)
Iteration 5
n_epochs:

# ax

In [15]:
from ax import (
    Experiment,
    OptimizationConfig,
    Objective,
    ParameterType,
    RangeParameter,
    SearchSpace,
)
from ax.modelbridge.registry import Models
from ax.service.utils.report_utils import exp_to_df
from botorch.test_functions import Branin
from ax.core.metric import Metric
from ax.core.types import ComparisonOp


# 目的関数用のカスタムMetricクラス
class BraninMetric(Metric):
    def fetch_trial_data(self, trial):
        x1 = trial.arm.parameters["x1"]
        x2 = trial.arm.parameters["x2"]
        branin = Branin()
        # Branin関数の評価値を返す
        return {"branin": branin(torch.tensor([x1, x2], dtype=torch.float64)).item()}


# 検索空間を定義
search_space = SearchSpace(
    parameters=[
        RangeParameter(name="x1", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0),
        RangeParameter(name="x2", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0),
    ]
)

# 最適化設定を定義（目的は最小化）
optimization_config = OptimizationConfig(
    objective=Objective(metric=BraninMetric(name="branin"), minimize=True)
)

# Experimentの設定
experiment = Experiment(
    name="branin_experiment",
    search_space=search_space,
    optimization_config=optimization_config,
)

# 初期データの収集
sobol = Models.SOBOL(search_space=search_space)
for _ in range(5):
    trial = experiment.new_trial(generator_run=sobol.gen(1))
    trial.run()

# 結果の表示
df = exp_to_df(experiment)
print(df)


ValueError: No runner set on trial or experiment.

In [16]:
import os
from contextlib import contextmanager, nullcontext

from ax.utils.testing.mock import fast_botorch_optimize_context_manager
import plotly.io as pio

# Ax uses Plotly to produce interactive plots. These are great for viewing and analysis,
# though they also lead to large file sizes, which is not ideal for files living in GH.
# Changing the default to `png` strips the interactive components to get around this.
pio.renderers.default = "png"

SMOKE_TEST = os.environ.get("SMOKE_TEST")
NUM_EVALS = 10 if SMOKE_TEST else 30


from typing import Optional

from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from torch import Tensor


class SimpleCustomGP(ExactGP, GPyTorchModel):

    _num_outputs = 1  # to inform GPyTorchModel API

    def __init__(self, train_X, train_Y, train_Yvar: Optional[Tensor] = None):
        # NOTE: This ignores train_Yvar and uses inferred noise instead.
        # squeeze output dim before passing train_Y to ExactGP
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(
            base_kernel=RBFKernel(ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)  # make sure we're on the right device/dtype

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
    

from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate


ax_model = BoTorchModel(
    surrogate=Surrogate(
        # The model class to use
        botorch_model_class=SimpleCustomGP,
        # Optional, MLL class with which to optimize model parameters
        # mll_class=ExactMarginalLogLikelihood,
        # Optional, dictionary of keyword arguments to model constructor
        # model_options={}
    ),
    # Optional, acquisition function class to use - see custom acquisition tutorial
    # botorch_acqf_class=qExpectedImprovement,
)

from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models


gs = GenerationStrategy(
    steps=[
        # Quasi-random initialization step
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
        ),
        # Bayesian optimization step using the custom acquisition function
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,  # No limitation on how many trials should be produced from this step
            # For `BOTORCH_MODULAR`, we pass in kwargs to specify what surrogate or acquisition function to use.
            model_kwargs={
                "surrogate": Surrogate(SimpleCustomGP),
            },
        ),
    ]
)


import torch
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin


# Initialize the client - AxClient offers a convenient API to control the experiment
ax_client = AxClient(generation_strategy=gs)
# Setup the experiment
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            # It is crucial to use floats for the bounds, i.e., 0.0 rather than 0.
            # Otherwise, the parameter would be inferred as an integer range.
            "bounds": [-5.0, 10.0],
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 15.0],
        },
    ],
    objectives={
        "branin": ObjectiveProperties(minimize=True),
    },
)
# Setup a function to evaluate the trials
branin = Branin()


def evaluate(parameters):
    x = torch.tensor([[parameters.get(f"x{i+1}") for i in range(2)]])
    # The GaussianLikelihood used by our model infers an observation noise level,
    # so we pass an sem value of NaN to indicate that observation noise is unknown
    return {"branin": (branin(x).item(), float("nan"))}


if SMOKE_TEST:
    fast_smoke_test = fast_botorch_optimize_context_manager
else:
    fast_smoke_test = nullcontext

# Set a seed for reproducible tutorial output
torch.manual_seed(0)


with fast_smoke_test():
    for i in range(NUM_EVALS):
        parameters, trial_index = ax_client.get_next_trial()
        # Local evaluation here can be replaced with deployment to external system.
        ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))


ax_client.get_trials_data_frame()

[INFO 12-12 19:28:46] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 12-12 19:28:46] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 19:28:46] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 19:28:46] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).

Encountered exc

Unnamed: 0,trial_index,arm_name,trial_status,generation_method,branin,x1,x2
0,0,0_0,COMPLETED,Sobol,104.365417,0.62583,14.359564
1,1,1_0,COMPLETED,Sobol,2.996862,3.166217,3.867106
2,2,2_0,COMPLETED,Sobol,66.530632,9.560105,10.718323
3,3,3_0,COMPLETED,Sobol,198.850861,-3.878664,0.117947
4,4,4_0,COMPLETED,Sobol,5.811776,-2.362858,8.855021
5,5,5_0,COMPLETED,BoTorch,6.622622,2.562464,4.928441
6,6,6_0,COMPLETED,BoTorch,31.202353,5.498808,4.949998
7,7,7_0,COMPLETED,BoTorch,38.799263,-2.318758,4.441083
8,8,8_0,COMPLETED,BoTorch,12.325905,-1.545733,7.324723
9,9,9_0,COMPLETED,BoTorch,78.91967,-5.0,9.051272


In [25]:
import os
from contextlib import nullcontext

from ax.utils.testing.mock import fast_botorch_optimize_context_manager
import plotly.io as pio
from typing import Optional
from torch import Tensor
import torch

from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin


# Set Plotly rendering to static image to reduce file size
pio.renderers.default = "png"

# Determine smoke test mode for quick execution
SMOKE_TEST = os.environ.get("SMOKE_TEST")
NUM_EVALS = 10 if SMOKE_TEST else 30


class SimpleCustomGP(ExactGP, GPyTorchModel):
    """A simple custom Gaussian Process model using GPyTorch."""

    _num_outputs = 1  # Inform GPyTorchModel API

    def __init__(self, train_X, train_Y, train_Yvar: Optional[Tensor] = None):
        super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(
            base_kernel=RBFKernel(ard_num_dims=train_X.shape[-1]),
        )
        self.to(train_X)

    def forward(self, x):
        """Forward pass through the GP model."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


# Define a generation strategy with a custom model
gs = GenerationStrategy(
    steps=[
        # Sobol initialization
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,
        ),
        # Bayesian optimization step using SimpleCustomGP
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={
                "surrogate": Surrogate(SimpleCustomGP),
            },
        ),
    ]
)

# Initialize the AxClient
ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [-5.0, 10.0],  # Ensure float bounds to avoid integer range
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 15.0],
        },
    ],
    objectives={
        "branin": ObjectiveProperties(minimize=True),
    },
)

# Define the Branin function for evaluation
branin = Branin()


def evaluate(parameters):
    """Evaluate the Branin function for given parameters."""
    x = torch.tensor([[parameters[f"x{i+1}"] for i in range(2)]], dtype=torch.float64)
    return {"branin": (branin(x).item(), float("nan"))}


# Fast execution context for testing
fast_smoke_test = fast_botorch_optimize_context_manager if SMOKE_TEST else nullcontext

# Set a seed for reproducibility
torch.manual_seed(0)

# Execute the optimization loop
with fast_smoke_test():
    for i in range(NUM_EVALS):
        parameters, trial_index = ax_client.get_next_trial()
        raw_data = evaluate(parameters)
        ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

print("Optimization completed.")
ax_client.get_trials_data_frame()

[INFO 12-12 19:59:10] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 12-12 19:59:10] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 19:59:10] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.


[INFO 12-12 19:59:10] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 12-12 19:59:10] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.62583, 'x2': 14.359564} using model Sobol.
[INFO 12-12 19:59:10] ax.service.ax_client: Completed trial 0 with data: {'branin': (104.365427, nan)}.

Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.

[INFO 12-12 19:59:10] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 3.166217, 'x2': 3.867106} using model Sobol.
[INFO 12-12 19:59:10] ax.service.ax_client: Completed trial 1 with data: {'branin': (2.996864, nan)}.

Encountered exception in computing model fit quality: 

Optimization completed.


Unnamed: 0,trial_index,arm_name,trial_status,generation_method,branin,x1,x2
0,0,0_0,COMPLETED,Sobol,104.365427,0.62583,14.359564
1,1,1_0,COMPLETED,Sobol,2.996864,3.166217,3.867106
2,2,2_0,COMPLETED,Sobol,66.530653,9.560105,10.718323
3,3,3_0,COMPLETED,Sobol,198.850837,-3.878664,0.117947
4,4,4_0,COMPLETED,Sobol,5.811777,-2.362858,8.855021
5,5,5_0,COMPLETED,BoTorch,6.622979,2.562422,4.928511
6,6,6_0,COMPLETED,BoTorch,31.202334,5.498725,4.950084
7,7,7_0,COMPLETED,BoTorch,38.810004,-2.319322,4.441081
8,8,8_0,COMPLETED,BoTorch,12.326239,-1.545653,7.324714
9,9,9_0,COMPLETED,BoTorch,78.919329,-5.0,9.051294


In [28]:
ax_client.generation_strategy.model

TorchModelBridge(model=BoTorchModel)

In [34]:
import os
from contextlib import nullcontext

from ax.utils.testing.mock import fast_botorch_optimize_context_manager
import plotly.io as pio
from typing import Optional
from torch import Tensor
import torch
from botorch.models.model import Model
from botorch.utils.datasets import FixedNoiseDataset

from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin

# Assuming `LaplaceBNN` is already implemented and imported
from models.laplace import LaplaceBNN

# Set Plotly rendering to static image to reduce file size
pio.renderers.default = "png"

# Determine smoke test mode for quick execution
SMOKE_TEST = os.environ.get("SMOKE_TEST")
NUM_EVALS = 10 if SMOKE_TEST else 30

# Define a wrapper class for LaplaceBNN to integrate with Ax
class LaplaceBNNWrapper(Model):
    def __init__(self, input_dim, output_dim, device):
        super().__init__()
        self.model = LaplaceBNN(
            args={
                "regnet_dims": [10, 10],
                "regnet_activation": "relu",
                "prior_var": 1.0,
                "noise_var": 1e-2,
                "iterative": False,
            },
            input_dim=input_dim,
            output_dim=output_dim,
            device=device,
        )

    def forward(self, X):
        with torch.no_grad():
            mean, variance = self.model(X)
        return mean, variance

    @staticmethod
    def construct_inputs(training_data: FixedNoiseDataset, **kwargs):
        return {
            "train_X": training_data.X,
            "train_Y": training_data.Y,
            "train_Yvar": training_data.Yvar,
            **kwargs,
        }

    def fit(self, train_X, train_Y, train_Yvar=None):
        self.model.fit(train_X, train_Y)

# Define a generation strategy with LaplaceBNN as the surrogate model
gs = GenerationStrategy(
    steps=[
        # Sobol initialization
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,
        ),
        # Bayesian optimization step using LaplaceBNN
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={
                "surrogate": Surrogate(
                    botorch_model_class=LaplaceBNNWrapper,
                    model_options={"input_dim": 2, "output_dim": 1, "device": "cuda" if torch.cuda.is_available() else "cpu"},
                ),
            },
        ),
    ]
)

# Initialize the AxClient
ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(
    name="laplace_bnn_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [-5.0, 10.0],
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 15.0],
        },
    ],
    objectives={
        "branin": ObjectiveProperties(minimize=True),
    },
)

# Define the Branin function for evaluation
branin = Branin()

def evaluate(parameters):
    x = torch.tensor([[parameters[f"x{i+1}"] for i in range(2)]], dtype=torch.float64)
    return {"branin": (branin(x).item(), float("nan"))}

# Generate initial Sobol samples for the model
torch.manual_seed(0)

for _ in range(5):  # Number of Sobol trials
    parameters, trial_index = ax_client.get_next_trial()
    raw_data = evaluate(parameters)
    ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

# Fast execution context for testing
fast_smoke_test = fast_botorch_optimize_context_manager if SMOKE_TEST else nullcontext

# Execute the optimization loop for the remaining trials
with fast_smoke_test():
    for i in range(NUM_EVALS - 5):  # Subtract initial trials
        parameters, trial_index = ax_client.get_next_trial()
        raw_data = evaluate(parameters)
        ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

print("Optimization completed.")


[INFO 12-12 20:09:55] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 12-12 20:09:55] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 20:09:55] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 20:09:55] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).

Encountered exc

TypeError: Can't instantiate abstract class LaplaceBNNWrapper with abstract method posterior

In [33]:
ax_client.generation_strategy.model

In [None]:
# https://chatgpt.com/c/675a9fbe-164c-8001-944c-6f373ccfb08a


import os
import torch
from contextlib import nullcontext

from ax.utils.testing.mock import fast_botorch_optimize_context_manager
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.models.model import Model
from botorch.posteriors import Posterior
from botorch.utils.datasets import FixedNoiseDataset
from botorch.test_functions import Branin

# Assuming LaplaceBNN and LaplacePosterior are implemented and imported
from models.laplace import LaplaceBNN


class LaplaceBNNWrapper(Model):
    def __init__(self, args, input_dim, output_dim, device):
        super().__init__()
        self.model = LaplaceBNN(args, input_dim, output_dim, device)

    def posterior(
        self, 
        X: torch.Tensor, 
        observation_noise: bool = False, 
        **kwargs
    ) -> Posterior:
        """
        Returns the posterior predictive distribution for the given input `X`.
        """
        return self.model.posterior(X)

    @staticmethod
    def construct_inputs(training_data: FixedNoiseDataset, **kwargs) -> dict:
        """
        Converts Ax training data to LaplaceBNN inputs.
        """
        return {
            "train_X": training_data.X,
            "train_Y": training_data.Y,
            "train_Yvar": training_data.Yvar,
            **kwargs,
        }

    def fit(self, train_X, train_Y, train_Yvar=None, **kwargs):
        """
        Fit the LaplaceBNN model using the provided training data.
        """
        self.model.fit(train_X, train_Y)


# Define the generation strategy
gs = GenerationStrategy(
    steps=[
        # Sobol initialization
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,
        ),
        # Bayesian optimization using LaplaceBNN
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={
                "surrogate": Surrogate(
                    botorch_model_class=LaplaceBNNWrapper,
                    model_options={
                        "args": {
                            "regnet_dims": [10, 10],
                            "regnet_activation": "relu",
                            "prior_var": 1.0,
                            "noise_var": 1e-2,
                            "iterative": False,
                        },
                        "input_dim": 2,
                        "output_dim": 1,
                        "device": "cuda" if torch.cuda.is_available() else "cpu",
                    },
                ),
            },
        ),
    ]
)

# Initialize AxClient
ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(
    name="laplace_bnn_bo_experiment",
    parameters=[
        {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
    ],
    objectives={"branin": ObjectiveProperties(minimize=True)},
)

# Define the Branin function for evaluation
def evaluate(parameters):
    x = torch.tensor([[parameters[f"x{i+1}"] for i in range(2)]], dtype=torch.float64)
    return {"branin": (Branin()(x).item(), float("nan"))}

# Set Plotly rendering and determine execution mode
NUM_EVALS = 30
SMOKE_TEST = os.environ.get("SMOKE_TEST")
fast_smoke_test = fast_botorch_optimize_context_manager if SMOKE_TEST else nullcontext

# Generate initial Sobol samples
torch.manual_seed(0)
for _ in range(5):  # Number of Sobol trials
    parameters, trial_index = ax_client.get_next_trial()
    raw_data = evaluate(parameters)
    ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

# Run Bayesian optimization
with fast_smoke_test():
    for i in range(NUM_EVALS - 5):  # Subtract initial trials
        parameters, trial_index = ax_client.get_next_trial()
        raw_data = evaluate(parameters)
        ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

print("Optimization completed.")


[INFO 12-12 20:42:03] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 12-12 20:42:03] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 20:42:03] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 20:42:03] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).

Encountered exc

TypeError: LaplaceBNNWrapper.__init__() got an unexpected keyword argument 'train_X'