The code below defines a Gaussian process (GP) network model and its supporting functions, adapted from [`bofn/models/gp_network.py`](https://github.com/RaulAstudillo06/BOFN/blob/main/bofn/models/gp_network.py) in the BOFN repository.

In [None]:
# @title
from __future__ import annotations
import torch
from typing import Any, Tuple
from botorch.models.model import Model
from botorch.models import FixedNoiseGP
from botorch import fit_gpytorch_model
from botorch.posteriors import Posterior
from botorch.models.transforms import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood
from torch import Tensor
from typing import List, Optional

class GaussianProcessNetwork(Model):
    r"""Gaussian process network: one GP surrogate per node of a DAG.

    Root nodes are modeled by a GP on their active design inputs. Every other
    node is modeled by a GP whose inputs are its active design inputs
    concatenated with the outputs of its parent nodes; each parent output is
    min-max normalized with constants computed from the training data.
    """

    def __init__(self, train_X, train_Y, dag, active_input_indices, train_Yvar=None, node_GPs=None, normalization_constant_lower=None, normalization_constant_upper=None) -> None:
        r"""Build the network, fitting one GP per node unless ``node_GPs`` is given.

        Args:
            train_X: Design inputs, presumably ``(batch_shape) x n x d``.
            train_Y: Observed node outputs; column ``k`` (last dimension) holds
                the output of node ``k``.
            dag: ``DAG`` describing the parent nodes of every node.
            active_input_indices: Per-node list of the columns of ``train_X``
                the node depends on, or ``None`` to give every node all of
                ``train_X``.
            train_Yvar: Stored but unused; every node GP uses the fixed noise
                level ``self.noise_var``.
            node_GPs: Already-built node models (e.g. fantasy models). When
                given, nothing is fitted and the normalization constants passed
                alongside are reused as-is.
            normalization_constant_lower: Per-node, per-parent minima used to
                normalize parent outputs (only read when ``node_GPs`` is given).
            normalization_constant_upper: Per-node, per-parent maxima, as above.
        """
        self.train_X = train_X
        self.train_Y = train_Y
        self.dag = dag
        self.n_nodes = dag.get_n_nodes()
        self.root_nodes = dag.get_root_nodes()
        self.active_input_indices = active_input_indices
        self.train_Yvar = train_Yvar
        # Small fixed observation noise shared by every node GP and fantasy.
        self.noise_var = 1e-5

        if node_GPs is not None:
            self.node_GPs = node_GPs
            self.normalization_constant_lower = normalization_constant_lower
            self.normalization_constant_upper = normalization_constant_upper
            return

        self.node_GPs = [None] * self.n_nodes
        self.node_mlls = [None] * self.n_nodes
        self.normalization_constant_lower = [[None] * len(self.dag.get_parent_nodes(k)) for k in range(self.n_nodes)]
        self.normalization_constant_upper = [[None] * len(self.dag.get_parent_nodes(k)) for k in range(self.n_nodes)]

        # Root nodes first: their GPs depend only on the design inputs.
        for k in self.root_nodes:
            self._fit_node_gp(k, self._design_inputs(train_X, k), train_Y[..., [k]])

        # Remaining nodes: design inputs plus normalized parent outputs.
        for k in range(self.n_nodes):
            if self.node_GPs[k] is not None:
                continue
            parents = self.dag.get_parent_nodes(k)
            parent_outputs = train_Y[..., parents]
            for j in range(len(parents)):
                lower = torch.min(parent_outputs[..., j])
                upper = torch.max(parent_outputs[..., j])
                if upper == lower:
                    # A constant parent output would make the min-max scaling
                    # divide by zero; fall back to a unit range (pure shift).
                    upper = lower + 1.0
                self.normalization_constant_lower[k][j] = lower
                self.normalization_constant_upper[k][j] = upper
            normalized = self._normalize_parent_outputs(parent_outputs, k)
            train_X_node_k = torch.cat([self._design_inputs(train_X, k), normalized], -1)
            self._fit_node_gp(k, train_X_node_k, train_Y[..., [k]])

    def _design_inputs(self, X: Tensor, k: int) -> Tensor:
        """Return the columns of ``X`` that node ``k`` uses (all of ``X`` if no indices are set)."""
        if self.active_input_indices is None:
            return X
        return X[..., self.active_input_indices[k]]

    def _normalize_parent_outputs(self, parent_outputs: Tensor, k: int) -> Tensor:
        """Min-max normalize node ``k``'s parent outputs with the stored constants (returns a copy)."""
        normalized = parent_outputs.clone()
        for j in range(normalized.shape[-1]):
            lower = self.normalization_constant_lower[k][j]
            upper = self.normalization_constant_upper[k][j]
            normalized[..., j] = (normalized[..., j] - lower) / (upper - lower)
        return normalized

    def _fit_node_gp(self, k: int, train_X_node_k: Tensor, train_Y_node_k: Tensor) -> None:
        """Fit a fixed-noise, output-standardized GP for node ``k`` and store it with its MLL."""
        self.node_GPs[k] = FixedNoiseGP(
            train_X=train_X_node_k,
            train_Y=train_Y_node_k,
            train_Yvar=torch.full_like(train_Y_node_k, self.noise_var),
            outcome_transform=Standardize(m=1),
        )
        self.node_mlls[k] = ExactMarginalLogLikelihood(self.node_GPs[k].likelihood, self.node_GPs[k])
        fit_gpytorch_model(self.node_mlls[k])

    def posterior(self, X: Tensor, posterior_transform=None, observation_noise=False) -> MultivariateNormalNetwork:
        r"""Return the (sample-based) posterior over all node outputs at ``X``.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
                of the feature space and `q` is the number of points considered
                jointly.
            posterior_transform: Accepted for API compatibility; ignored.
            observation_noise: Accepted for API compatibility; ignored (no
                observation noise is added).

        Returns:
            A ``MultivariateNormalNetwork`` built from the node GPs, the DAG,
            ``X``, the active input indices and the normalization constants.
        """
        return MultivariateNormalNetwork(self.node_GPs, self.dag, X, self.active_input_indices, self.normalization_constant_lower, self.normalization_constant_upper)

    def forward(self, x: Tensor) -> MultivariateNormalNetwork:
        r"""Same as :meth:`posterior` (previously referenced a nonexistent attribute)."""
        return self.posterior(x)

    def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
        r"""Condition the model on new observations.

        Args:
            X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
                the feature space, `n'` is the number of points per batch, and
                `batch_shape` is the batch shape (must be compatible with the
                batch shape of the model).
            Y: A `batch_shape' x n' x m`-dim Tensor, where `m` is the number of
                model outputs, `n'` is the number of points per batch, and
                `batch_shape'` is the batch shape of the observations. For
                non-root nodes, ``Y`` is expected to carry one extra leading
                (fantasy) dimension relative to ``X``; ``X`` is replicated
                along it.
            **kwargs: Ignored.

        Returns:
            A ``GaussianProcessNetwork`` whose node models are the per-node
            fantasy models, reusing this model's normalization constants.
        """
        fantasy_models = [None] * self.n_nodes

        for k in self.root_nodes:
            Y_node_k = Y[..., [k]]
            noise = torch.full(Y_node_k.shape[1:], self.noise_var, dtype=Y_node_k.dtype, device=Y_node_k.device)
            fantasy_models[k] = self.node_GPs[k].condition_on_observations(self._design_inputs(X, k), Y_node_k, noise=noise)

        for k in range(self.n_nodes):
            if fantasy_models[k] is not None:
                continue
            parents = self.dag.get_parent_nodes(k)
            normalized = self._normalize_parent_outputs(Y[..., parents], k)
            X_design = self._design_inputs(X, k)
            # Replicate X along Y's leading fantasy dimension before concatenating.
            X_design = X_design.unsqueeze(0).repeat(normalized.shape[0], *([1] * X_design.ndim))
            X_node_k = torch.cat([X_design, normalized], -1)
            Y_node_k = Y[..., [k]]
            noise = torch.full(Y_node_k.shape[1:], self.noise_var, dtype=Y_node_k.dtype, device=Y_node_k.device)
            fantasy_models[k] = self.node_GPs[k].condition_on_observations(X_node_k, Y_node_k, noise=noise)

        return GaussianProcessNetwork(dag=self.dag, train_X=X, train_Y=Y, active_input_indices=self.active_input_indices, node_GPs=fantasy_models, normalization_constant_lower=self.normalization_constant_lower, normalization_constant_upper=self.normalization_constant_upper)


class MultivariateNormalNetwork(Posterior):
    r"""Sample-based posterior over the outputs of every node in a GP network.

    Only sampling is supported. Nodes are sampled in dependency order: root
    nodes from their GP posteriors at their design inputs, every other node
    from its GP posterior at its design inputs concatenated with the
    min-max normalized samples of its parents.
    """

    def __init__(self, node_GPs, dag, X, indices_X=None, normalization_constant_lower=None, normalization_constant_upper=None):
        r"""Store everything needed to sample the network at ``X``.

        Args:
            node_GPs: One fitted model per node, indexed like the DAG nodes.
            dag: ``DAG`` giving the parents of every node.
            X: Design inputs, presumably ``(batch_shape) x q x d``.
            indices_X: Per-node active input indices, or ``None`` to give
                every node all of ``X``.
            normalization_constant_lower: Per-node, per-parent minima used to
                normalize parent samples.
            normalization_constant_upper: Per-node, per-parent maxima.
        """
        self.node_GPs = node_GPs
        self.dag = dag
        self.n_nodes = dag.get_n_nodes()
        self.root_nodes = dag.get_root_nodes()
        self.X = X
        self.active_input_indices = indices_X
        self.normalization_constant_lower = normalization_constant_lower
        self.normalization_constant_upper = normalization_constant_upper

    @property
    def device(self) -> torch.device:
        r"""The torch device of the posterior (the device of ``X``)."""
        return self.X.device

    @property
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the posterior; samples are always float64."""
        return torch.double

    @property
    def event_shape(self) -> torch.Size:
        r"""Shape of one sample: ``X.shape`` with its last dimension replaced by ``n_nodes``."""
        return self.X.shape[:-1] + torch.Size([self.n_nodes])

    @property
    def base_sample_shape(self) -> torch.Size:
        r"""The base shape of the base samples expected in `rsample`.

        Informs the sampler to produce base samples of shape
        `sample_shape x base_sample_shape`.
        """
        return torch.Size([1, 1, self.n_nodes])

    @property
    def batch_range(self) -> Tuple[int, int]:
        r"""The t-batch range.

        This is used in samplers to identify the t-batch component of the
        `base_sample_shape`. The base samples are expanded over the t-batches to
        provide consistency in the acquisition values, i.e., to ensure that a
        candidate produces same value regardless of its position on the t-batch.
        """
        return (0, -1)

    def rsample_from_base_samples(self, sample_shape: torch.Size, base_samples: Tensor) -> Tensor:
        """Draw samples driven by the given base samples (delegates to :meth:`rsample`)."""
        return self.rsample(sample_shape, base_samples)

    def _design_inputs(self, k: int) -> Tensor:
        """Return the columns of ``X`` that node ``k`` uses (all of ``X`` if no indices are set)."""
        if self.active_input_indices is None:
            return self.X
        return self.X[..., self.active_input_indices[k]]

    def _normalize_parent_samples(self, parent_samples: Tensor, k: int) -> Tensor:
        """Min-max normalize node ``k``'s parent samples with the stored constants (returns a copy)."""
        normalized = parent_samples.clone()
        for j in range(normalized.shape[-1]):
            lower = self.normalization_constant_lower[k][j]
            upper = self.normalization_constant_upper[k][j]
            normalized[..., j] = (normalized[..., j] - lower) / (upper - lower)
        return normalized

    def rsample(self, sample_shape=torch.Size(), base_samples=None):
        """Jointly sample all node outputs.

        Args:
            sample_shape: Leading shape of the returned samples (may be empty).
            base_samples: Optional base samples; for node ``k`` the slice
                ``base_samples[..., k]`` drives that node's draws.

        Returns:
            A float64 tensor of shape ``sample_shape + event_shape``.

        Raises:
            ValueError: If some nodes can never be sampled because the parent
                structure contains a cycle.
        """
        sample_shape = torch.Size(sample_shape)
        nodes_samples = torch.empty(sample_shape + self.event_shape, dtype=torch.double, device=self.device)
        available = [False] * self.n_nodes

        for k in self.root_nodes:
            posterior_k = self.node_GPs[k].posterior(self._design_inputs(k))
            if base_samples is not None:
                nodes_samples[..., k] = posterior_k.rsample(sample_shape, base_samples=base_samples[..., [k]])[..., 0]
            else:
                nodes_samples[..., k] = posterior_k.rsample(sample_shape)[..., 0]
            available[k] = True

        while not all(available):
            progressed = False
            for k in range(self.n_nodes):
                parent_nodes = self.dag.get_parent_nodes(k)
                if available[k] or not all(available[j] for j in parent_nodes):
                    continue
                parent_samples = self._normalize_parent_samples(nodes_samples[..., parent_nodes], k)
                # Give the design inputs the same leading sample dimensions as the parent samples.
                X_design = self._design_inputs(k)
                X_design = X_design.expand(sample_shape + X_design.shape)
                posterior_k = self.node_GPs[k].posterior(torch.cat([X_design, parent_samples], -1))
                if base_samples is not None:
                    # Reparameterized draw: mean + std * eps, one eps per MC sample,
                    # broadcast over every non-sample dimension.
                    std = torch.sqrt(posterior_k.variance)
                    eps_shape = sample_shape + torch.Size([1] * (std.ndim - len(sample_shape)))
                    eps = base_samples[..., k].reshape(eps_shape)
                    nodes_samples[..., k] = (posterior_k.mean + std * eps)[..., 0]
                else:
                    # The posterior is already batched over the sample dimensions, so one draw suffices.
                    nodes_samples[..., k] = posterior_k.rsample()[0, ..., 0]
                available[k] = True
                progressed = True
            if not progressed:
                raise ValueError("Cannot sample every node: the DAG's parent structure contains a cycle.")
        return nodes_samples

class DAG:
    """Directed acyclic graph described by a list of parent lists.

    ``parent_nodes[k]`` holds the indices of the parents of node ``k``; a node
    whose parent list is empty is a root node.
    """

    def __init__(self, parent_nodes:List[List[Optional[int]]]):
        self.parent_nodes = parent_nodes
        self.n_nodes = len(parent_nodes)
        # Root nodes in increasing index order.
        self.root_nodes = [node for node, parents in enumerate(parent_nodes) if len(parents) == 0]

    def get_n_nodes(self):
        """Number of nodes in the graph."""
        return self.n_nodes

    def get_parent_nodes(self, k):
        """Parent list of node ``k`` (the stored list itself, not a copy)."""
        return self.parent_nodes[k]

    def get_root_nodes(self):
        """Indices of the nodes without parents."""
        return self.root_nodes

def generate_initial_design(num_samples: int, input_dim: int, seed=None):
    """Draw ``num_samples`` points uniformly at random from ``[0, 1)^input_dim``.

    Args:
        num_samples: Number of design points (rows).
        input_dim: Dimension of each point (columns).
        seed: Optional integer seed. When given, the draw is reproducible and
            the global torch RNG state is left untouched.

    Returns:
        A ``num_samples x input_dim`` tensor in the default dtype.
    """
    if seed is None:
        return torch.rand([num_samples, input_dim])
    # A private generator gives the same values as seeding the global CPU RNG,
    # without reseeding (and having to restore) any global state.
    generator = torch.Generator().manual_seed(seed)
    return torch.rand([num_samples, input_dim], generator=generator)

In [None]:
# @title
from botorch.acquisition import MCAcquisitionFunction
from botorch.acquisition.objective import MCAcquisitionObjective
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor
from typing import Optional

class PosteriorMean(MCAcquisitionFunction):
    """Monte Carlo estimate of the posterior mean of the objective.

    Averages the objective over the sampler's draws, which makes it usable with
    models whose posterior is only available through samples.
    """

    def __init__(
        self,
        model: Model,
        sampler: Optional[MCSampler] = None,
        objective: Optional[MCAcquisitionObjective] = None,
        X_pending: Optional[Tensor] = None,
    ) -> None:
        r"""Forward every argument unchanged to ``MCAcquisitionFunction``."""
        super().__init__(
            model=model,
            sampler=sampler,
            objective=objective,
            X_pending=X_pending,
        )

    @concatenate_pending_points
    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Sample-averaged objective for each t-batch of ``X``.

        The average over draws (dim 0) is taken first; the value at the first
        of the q points is returned for each t-batch.
        """
        draws = self.sampler(self.model.posterior(X))
        per_draw_objective = self.objective(draws)
        return per_draw_objective.mean(dim=0)[..., 0]

In [None]:
# @title
import torch
from torch import Tensor
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.optim import optimize_acqf
from botorch.optim.initializers import (
    gen_batch_initial_conditions,
    gen_one_shot_kg_initial_conditions,
)

def optimize_acqf_and_get_suggested_point(
    acq_func,
    bounds,
    batch_size,
    posterior_mean=None,
    ) -> Tensor:
    """Optimize an acquisition function and return the suggested candidate(s).

    Restart and raw-sample counts scale with the input dimension. When
    ``posterior_mean`` is given, its maximizer is computed first. It is added as
    an extra initial condition for the acquisition optimization, and it is
    returned instead of the optimizer's result if it scores at least as well
    under ``acq_func``.

    Args:
        acq_func: Acquisition function to maximize. A ``qKnowledgeGradient``
            gets one-shot initial conditions over its augmented q-batch.
        bounds: ``2 x d`` tensor of lower / upper bounds.
        batch_size: Number of points ``q`` to suggest.
        posterior_mean: Optional acquisition function whose maximizer serves
            as the baseline candidate.

    Returns:
        A ``batch_size x d`` tensor with the suggested point(s).
    """
    input_dim = bounds.shape[1]
    num_restarts = 10 * input_dim
    raw_samples = 100 * input_dim
    is_kg = isinstance(acq_func, qKnowledgeGradient)

    ic_gen = gen_one_shot_kg_initial_conditions if is_kg else gen_batch_initial_conditions
    batch_initial_conditions = ic_gen(
        acq_function=acq_func,
        bounds=bounds,
        q=batch_size,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
        options={"batch_limit": num_restarts},
    )

    # Un-augmented (q x d) baseline point, kept so it can be returned directly.
    baseline_candidate = None
    if posterior_mean is not None:
        baseline_candidate, _ = optimize_acqf(
            acq_function=posterior_mean,
            bounds=bounds,
            q=batch_size,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options={"batch_limit": 5},
        )
        baseline_candidate = baseline_candidate.detach()

        if is_kg:
            # The KG input has get_augmented_q_batch_size(q) rows; the rows past
            # the q baseline points start at the first baseline point.
            augmented_q_batch_size = acq_func.get_augmented_q_batch_size(batch_size)
            extra_rows = baseline_candidate[:1].expand(augmented_q_batch_size - batch_size, input_dim)
            baseline_ic = torch.cat([baseline_candidate, extra_rows], dim=0).unsqueeze(0)
        else:
            baseline_ic = baseline_candidate.view(torch.Size([1, batch_size, input_dim]))

        batch_initial_conditions = torch.cat([batch_initial_conditions, baseline_ic], 0)
        num_restarts += 1

    candidate, acq_value = optimize_acqf(
        acq_function=acq_func,
        bounds=bounds,
        q=batch_size,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
        batch_initial_conditions=batch_initial_conditions,
        options={"batch_limit": 2},
    )

    if baseline_candidate is not None:
        baseline_acq_value = acq_func.forward(baseline_ic)[0].detach()
        print('Test begins')
        print(f"Best acquisition value: {acq_value}")
        print(f"Base acquisition value: {baseline_acq_value}")
        print('Test ends')
        if baseline_acq_value >= acq_value:
            print('Baseline candidate was best found.')
            candidate = baseline_candidate

    new_x = candidate.detach().view([batch_size, input_dim])
    return new_x

def get_new_suggested_point(
    algo: str,
    X: Tensor,
    network_output_at_X: Tensor,
    objective_at_X: Tensor,
    network_to_objective_transform: Callable,
    dag: DAG,
    active_input_indices: Optional[List[List[int]]],
) -> Tensor:
    """Suggest the next point in ``[0, 1]^d`` with the chosen algorithm.

    Args:
        algo: One of ``"Random"``, ``"EIFN"`` (GP network + qEI), ``"EICF"``
            (``fit_gp_model`` on the network outputs + qEI), ``"EI"`` or
            ``"KG"`` (``fit_gp_model`` on the objective values).
        X: Evaluated design points, ``n x d``.
        network_output_at_X: Node outputs observed at ``X``.
        objective_at_X: Objective values observed at ``X``.
        network_to_objective_transform: MC objective mapping node outputs to
            the objective.
        dag: Network structure (used by ``"EIFN"``).
        active_input_indices: Per-node active input indices (used by ``"EIFN"``).

    Returns:
        A ``1 x d`` tensor with the suggested point.

    Raises:
        ValueError: If ``algo`` is not one of the supported names.
    """
    input_dim = X.shape[-1]

    if algo == "Random":
        return torch.rand([1, input_dim])

    if algo in ("EIFN", "EICF"):
        if algo == "EIFN":
            model = GaussianProcessNetwork(train_X=X, train_Y=network_output_at_X, dag=dag,
                                           active_input_indices=active_input_indices)
        else:
            # NOTE(review): fit_gp_model is not defined in this file; confirm it exists.
            model = fit_gp_model(X=X, Y=network_output_at_X)
        # Both branches use the sample_shape-based sampler API.
        qmc_sampler = SobolQMCNormalSampler(torch.Size([128]))
        acquisition_function = qExpectedImprovement(
            model=model,
            best_f=objective_at_X.max().item(),
            sampler=qmc_sampler,
            objective=network_to_objective_transform,
        )
        posterior_mean_function = PosteriorMean(
            model=model,
            sampler=qmc_sampler,
            objective=network_to_objective_transform,
        )
    elif algo == "EI":
        model = fit_gp_model(X=X, Y=objective_at_X)
        acquisition_function = ExpectedImprovement(
            model=model, best_f=objective_at_X.max().item())
        posterior_mean_function = GPPosteriorMean(model=model)
    elif algo == "KG":
        model = fit_gp_model(X=X, Y=objective_at_X)
        acquisition_function = qKnowledgeGradient(
            model=model, num_fantasies=8)
        posterior_mean_function = GPPosteriorMean(model=model)
    else:
        raise ValueError(
            f"Unknown algorithm {algo!r}; expected one of 'Random', 'EIFN', 'EICF', 'EI', 'KG'."
        )

    new_x = optimize_acqf_and_get_suggested_point(
        acq_func=acquisition_function,
        bounds=torch.stack([torch.zeros(input_dim), torch.ones(input_dim)]),
        batch_size=1,
        posterior_mean=posterior_mean_function,
    )

    return new_x

In [None]:
# Gaussian process network example
import torch
from botorch.acquisition.objective import GenericMCObjective
from botorch.settings import debug
from torch import Tensor
from botorch.acquisition import ExpectedImprovement, qExpectedImprovement
from botorch.acquisition import PosteriorMean as GPPosteriorMean
from botorch.sampling.normal import SobolQMCNormalSampler
import time

class Dropwave:
    """Drop-Wave test function written as a two-node function network.

    Inputs in ``[0, 1]^2`` are rescaled to ``[-5.12, 5.12]^2``. Node 0 is the
    Euclidean norm ``r`` of the rescaled input and node 1 is
    ``(1 + cos(12 r)) / (2 + 0.5 r^2)``.
    """

    def __init__(self):
        self.input_dim = 2
        self.n_nodes = 2

    def evaluate(self, X):
        """Evaluate both nodes at ``X`` (last dimension = 2); returns ``X.shape[:-1] + (n_nodes,)``."""
        radius = torch.norm(10.24 * X - 5.12, dim=-1)
        wave = (1.0 + torch.cos(12.0 * radius)) / (2.0 + 0.5 * (radius ** 2))
        nodes = torch.empty(radius.shape + torch.Size([self.n_nodes]))
        nodes[..., 0] = radius
        nodes[..., 1] = wave
        return nodes

# Work in double precision throughout.
torch.set_default_dtype(torch.float64)
# Enable BoTorch's debug flag (via the private setter of botorch.settings.debug).
debug._set_state(True)

dropwave = Dropwave()
input_dim = dropwave.input_dim
# Derived from the problem instead of hard-coded, so the two cannot drift apart.
n_nodes = dropwave.n_nodes
problem = 'dropwave'

def function_network(X: Tensor):
    """Evaluate every node of the Drop-Wave function network at ``X``."""
    network = dropwave
    return network.evaluate(X=X)

# Underlying DAG: node 0 has no parents; node 1 has node 0 as its only parent.
parent_nodes = [[], [0]]
dag = DAG(parent_nodes=parent_nodes)

# Active input indices: node 0 uses both design inputs; node 1 uses none
# (its GP only sees the output of its parent).
active_input_indices = [[0, 1], []]

# Function that maps the network output to the objective value (last node).
# The optional ``X`` argument keeps the callable valid whether or not
# GenericMCObjective forwards the inputs alongside the samples.
network_to_objective_transform = GenericMCObjective(lambda Y, X=None: Y[..., -1])

# Generate initial data: 2 * (d + 1) points from a fixed seed.
n_init_evals = 2 * (input_dim + 1)
trial = 42
X = generate_initial_design(num_samples=n_init_evals, input_dim=input_dim, seed=trial)
network_output_at_X = function_network(X)
objective_at_X = network_to_objective_transform(network_output_at_X)

# Algorithm
algo = 'EIFN'

# New suggested point, timed with a monotonic high-resolution clock.
t0 = time.perf_counter()
new_x = get_new_suggested_point(
    algo=algo,
    X=X,
    network_output_at_X=network_output_at_X,
    objective_at_X=objective_at_X,
    network_to_objective_transform=network_to_objective_transform,
    dag=dag,
    active_input_indices=active_input_indices,
)
t1 = time.perf_counter()
print(f"New point: {new_x}")
print(f"Took {t1 - t0} seconds")

# # Fit the GPNetwork model
# model = GaussianProcessNetwork(train_X=X, train_Y=network_output_at_X, dag=dag, active_input_indices=active_input_indices)

# # Sampler
# qmc_sampler = SobolQMCNormalSampler(torch.Size([128]))

# # Number of test points
# Ntest = 500

# if True:
#     # Define the mean function
#     mean_function = PosteriorMean(model, qmc_sampler, network_to_objective_transform)

#     # Calculate mean of GPnetwork over set of test points
#     Xtest = torch.rand((Ntest,1,input_dim))
#     objective_mean_at_Xtest = mean_function(Xtest)

#     # Calculate the maximum point
#     print(f"Max found objective mean is: {objective_mean_at_Xtest.max()}")

# else:
#     # Define the acquisition function
#     acquisition_function = qExpectedImprovement(
#         model=model,
#         best_f=objective_at_X.max().item(),
#         sampler=qmc_sampler,
#         objective=network_to_objective_transform,
#     )

#     # Calculate qEI for GPnetwork over set of test points
#     Xtest = torch.rand((Ntest,1,input_dim))
#     acquisition_func_at_Xtest = acquisition_function(Xtest)

#     # Calculate the maximum point
#     print(f"Max found acq func is: {acquisition_func_at_Xtest.max()}")

Test begins
Best acquisition value: 0.018715645602991826
Base acquisition value: 0.0013770879083504395
Test ends
New point: tensor([[0.5490, 0.4243]])
Took 5.683061122894287 seconds
