Skip to content

Commit

Permalink
Feature/linear deterministic (#385)
Browse files Browse the repository at this point in the history
* add linear deterministic surrogate
  • Loading branch information
jduerholt committed Mar 22, 2024
1 parent 25112bf commit feec334
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 0 deletions.
3 changes: 3 additions & 0 deletions bofire/data_models/surrogates/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
AnyBotorchSurrogate,
BotorchSurrogates,
)
from bofire.data_models.surrogates.deterministic import LinearDeterministicSurrogate
from bofire.data_models.surrogates.empirical import EmpiricalSurrogate
from bofire.data_models.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate
from bofire.data_models.surrogates.linear import LinearSurrogate
Expand Down Expand Up @@ -46,6 +47,7 @@
LinearSurrogate,
PolynomialSurrogate,
TanimotoGPSurrogate,
LinearDeterministicSurrogate,
]

AnyTrainableSurrogate = Union[
Expand Down Expand Up @@ -74,6 +76,7 @@
LinearSurrogate,
PolynomialSurrogate,
TanimotoGPSurrogate,
LinearDeterministicSurrogate,
]

AnyClassificationSurrogate = ClassificationMLPEnsemble
2 changes: 2 additions & 0 deletions bofire/data_models/surrogates/botorch_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from bofire.data_models.base import BaseModel
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.surrogates.deterministic import LinearDeterministicSurrogate
from bofire.data_models.surrogates.empirical import EmpiricalSurrogate
from bofire.data_models.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate
from bofire.data_models.surrogates.linear import LinearSurrogate
Expand Down Expand Up @@ -34,6 +35,7 @@
TanimotoGPSurrogate,
LinearSurrogate,
PolynomialSurrogate,
LinearDeterministicSurrogate,
]


Expand Down
41 changes: 41 additions & 0 deletions bofire/data_models/surrogates/deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Annotated, Dict, Literal, Type

from pydantic import Field, model_validator

from bofire.data_models.features.api import (
AnyOutput,
ContinuousInput,
ContinuousOutput,
DiscreteInput,
)
from bofire.data_models.surrogates.botorch import BotorchSurrogate


class LinearDeterministicSurrogate(BotorchSurrogate):
type: Literal["LinearDeterministicSurrogate"] = "LinearDeterministicSurrogate"
coefficients: Annotated[Dict[str, float], Field(min_length=1)]
intercept: float

@classmethod
def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool:
"""Abstract method to check output type for surrogate models
Args:
my_type: continuous or categorical output
Returns:
bool: True if the output type is valid for the surrogate chosen, False otherwise
"""
return isinstance(my_type, type(ContinuousOutput))

@model_validator(mode="after")
def validate_input_types(self):
if len(self.inputs.get([ContinuousInput, DiscreteInput])) != len(self.inputs):
raise ValueError(
"Only numerical inputs are suppoerted for the `LinearDeterministicSurrogate`"
)
return self

@model_validator(mode="after")
def validate_coefficients(self):
if sorted(self.inputs.get_keys()) != sorted(self.coefficients.keys()):
raise ValueError("coefficient keys do not match input feature keys.")
return self
1 change: 1 addition & 0 deletions bofire/surrogates/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from bofire.surrogates.botorch_surrogates import BotorchSurrogates
from bofire.surrogates.deterministic import LinearDeterministicSurrogate
from bofire.surrogates.empirical import EmpiricalSurrogate
from bofire.surrogates.mapper import map
from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate
Expand Down
25 changes: 25 additions & 0 deletions bofire/surrogates/deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
from botorch.models.deterministic import AffineDeterministicModel

from bofire.data_models.surrogates.api import LinearDeterministicSurrogate as DataModel
from bofire.surrogates.botorch import BotorchSurrogate
from bofire.utils.torch_tools import tkwargs


class LinearDeterministicSurrogate(BotorchSurrogate):
def __init__(
self,
data_model: DataModel,
**kwargs,
):
self.intercept = data_model.intercept
self.coefficients = data_model.coefficients
super().__init__(data_model=data_model, **kwargs)
self.model = AffineDeterministicModel(
b=data_model.intercept,
a=torch.tensor(
[data_model.coefficients[key] for key in self.inputs.get_keys()]
)
.to(**tkwargs)
.unsqueeze(-1),
)
2 changes: 2 additions & 0 deletions bofire/surrogates/mapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Type

from bofire.data_models.surrogates import api as data_models
from bofire.surrogates.deterministic import LinearDeterministicSurrogate
from bofire.surrogates.empirical import EmpiricalSurrogate
from bofire.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate
from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate
Expand All @@ -24,6 +25,7 @@
data_models.LinearSurrogate: SingleTaskGPSurrogate,
data_models.PolynomialSurrogate: SingleTaskGPSurrogate,
data_models.TanimotoGPSurrogate: SingleTaskGPSurrogate,
data_models.LinearDeterministicSurrogate: LinearDeterministicSurrogate,
}


Expand Down
67 changes: 67 additions & 0 deletions tests/bofire/data_models/specs/surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,70 @@
"hyperconfig": None,
},
)

specs.add_valid(
models.LinearDeterministicSurrogate,
lambda: {
"inputs": Inputs(
features=[
ContinuousInput(key="a", bounds=(0, 1)),
ContinuousInput(key="b", bounds=(0, 1)),
]
).model_dump(),
"outputs": Outputs(
features=[
features.valid(ContinuousOutput).obj(),
]
).model_dump(),
"intercept": 5.0,
"coefficients": {"a": 2.0, "b": -3.0},
"input_preprocessing_specs": {},
"dump": None,
},
)

specs.add_invalid(
models.LinearDeterministicSurrogate,
lambda: {
"inputs": Inputs(
features=[
ContinuousInput(key="a", bounds=(0, 1)),
ContinuousInput(key="b", bounds=(0, 1)),
]
).model_dump(),
"outputs": Outputs(
features=[
features.valid(ContinuousOutput).obj(),
]
).model_dump(),
"intercept": 5.0,
"coefficients": {"a": 2.0, "b": -3.0, "c": 5.0},
"input_preprocessing_specs": {},
"dump": None,
},
error=ValueError,
message="coefficient keys do not match input feature keys.",
)

specs.add_invalid(
models.LinearDeterministicSurrogate,
lambda: {
"inputs": Inputs(
features=[
ContinuousInput(key="a", bounds=(0, 1)),
CategoricalInput(key="b", categories=["a", "b"]),
]
).model_dump(),
"outputs": Outputs(
features=[
features.valid(ContinuousOutput).obj(),
]
).model_dump(),
"intercept": 5.0,
"coefficients": {"a": 2.0, "b": -3.0},
"input_preprocessing_specs": {},
"dump": None,
},
error=ValueError,
message="Only numerical inputs are suppoerted for the `LinearDeterministicSurrogate`",
)
26 changes: 26 additions & 0 deletions tests/bofire/surrogates/test_deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pandas as pd
from pandas.testing import assert_frame_equal

import bofire.surrogates.api as surrogates
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
from bofire.data_models.surrogates.api import LinearDeterministicSurrogate


def test_linear_deterministic_surrogate():
surrogate_data = LinearDeterministicSurrogate(
inputs=Inputs(
features=[
ContinuousInput(key="a", bounds=(0, 1)),
ContinuousInput(key="b", bounds=(0, 1)),
]
),
outputs=Outputs(features=[ContinuousOutput(key="y")]),
intercept=2.0,
coefficients={"b": 3.0, "a": -2.0},
)
surrogate = surrogates.map(surrogate_data)
assert surrogate.input_preprocessing_specs == {}
experiments = pd.DataFrame(data={"a": [1.0, 2.0], "b": [0.5, 4.0]})
preds = surrogate.predict(experiments)
assert_frame_equal(preds, pd.DataFrame(data={"y_pred": [1.5, 10.0], "y_sd": 0.0}))

0 comments on commit feec334

Please sign in to comment.