# Model Building with BoFire

This notebook shows how to set up and analyze models trained with BoFire. It is still a work in progress (WIP).

## Imports

In [None]:
import bofire.surrogates.api as surrogates
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.enum import RegressionMetricsEnum
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
from bofire.data_models.surrogates.api import SingleTaskGPSurrogate
from bofire.plot.feature_importance import plot_feature_importance_by_feature_plotly
from bofire.surrogates.feature_importance import (
    combine_lengthscale_importances,
    combine_permutation_importances,
    lengthscale_importance_hook,
    permutation_importance_hook,
)

## Problem Setup

For didactic purposes, we sample data from a Himmelblau benchmark function and use them to train a SingleTaskGP.

In [None]:
# TODO: replace this after JDs PR is ready.
# Three continuous inputs (x_1, x_2, x_3), each with bounds (-4, 4), and one
# continuous output "y".
input_features = Inputs(
    features=[ContinuousInput(key=f"x_{i+1}", bounds=(-4, 4)) for i in range(3)],
)
output_features = Outputs(features=[ContinuousOutput(key="y")])

# Sample 50 points and derive the response columns as one non-mutating chain
# (no `inplace=True`), so re-running the cell always rebuilds `experiments`
# from a freshly sampled frame.
# NOTE: the expression below only reads x_1 and x_2; x_3 does not influence y.
experiments = (
    input_features.sample(n=50)
    .eval("y=((x_1**2 + x_2 - 11)**2+(x_1 + x_2**2 -7)**2)")
    # Constant 1 for every row; presumably BoFire's validity flag for output
    # "y" -- confirm against the BoFire experiment format.
    .assign(valid_y=1)
)

## Cross Validation
### Run the cross validation

In [None]:
# Data model of the surrogate, built from the inputs/outputs defined earlier.
data_model = SingleTaskGPSurrogate(inputs=input_features, outputs=output_features)

# Map the data model to the surrogate object that exposes `cross_validate`.
model = surrogates.map(data_model=data_model)

# Hooks handed to the cross validation; their results are read back from `pi`
# under the same keys in the following cell.
cv_hooks = {
    "permutation_importance": permutation_importance_hook,
    "lengthscale_importance": lengthscale_importance_hook,
}

# 5-fold cross validation on the sampled experiments. The call yields three
# values: train results, test results and the collected hook outputs (`pi`).
train_cv, test_cv, pi = model.cross_validate(experiments, folds=5, hooks=cv_hooks)

In [None]:
# Hook outputs collected during the cross validation, keyed by hook name.
permutation_results = pi["permutation_importance"]

# One `.describe()` summary of the permutation importances per regression
# metric, keyed by the metric's enum name.
combined_importances = {
    metric.name: combine_permutation_importances(permutation_results, metric).describe()
    for metric in RegressionMetricsEnum
}

# Add the lengthscale-based importances as one more entry of the same mapping.
lengthscale_results = pi["lengthscale_importance"]
combined_importances["lengthscale"] = combine_lengthscale_importances(
    lengthscale_results,
).describe()

# NOTE(review): caption and measure label say "Permutation" although the
# "lengthscale" summary is plotted as well -- confirm the labels are intended.
plot_feature_importance_by_feature_plotly(
    combined_importances,
    importance_measure="Permutation Feature Importance",
    caption="Permutation Feature Importances",
    relative=False,
    show_std=True,
)

### Analyze the cross validation

Plots will be added in a future PR.

In [None]:
# Performance on the test sets of the cross validation, requested with
# combine_folds=True (presumably one set of metrics over all folds together --
# confirm against the BoFire docs). As the last expression of the cell, the
# result is rendered as the cell output.
test_cv.get_metrics(combine_folds=True)

In [None]:
# Test-set metrics requested with combine_folds=False. Compute them once and
# reuse the result instead of calling get_metrics twice with identical
# arguments.
fold_metrics = test_cv.get_metrics(combine_folds=False)

# Show the raw metrics first, then their summary statistics.
display(fold_metrics)
display(fold_metrics.describe())