In [1]:
from __future__ import annotations

from pairwise_probit import PairwiseGP
import torch
from variational_gp import BinaryClassificationGP
from aepsych.config import Config
from gpytorch.likelihoods import BernoulliLikelihood
from botorch.models.likelihoods.pairwise import PairwiseProbitLikelihood, PairwiseLikelihood
from aepsych.models.base import AEPsychModel
from aepsych.models.surrogate import AEPsychSurrogate
from aepsych.models.pairwise_probit import PairwiseProbitModel
from sklearn.datasets import make_classification
from aepsych.utils import promote_0d, _process_bounds
from scipy.stats import norm
from aepsych.models.utils import select_inducing_points

import time, gpytorch, numpy as np
from typing import Any, Dict, Optional, Union
from aepsych.factory import default_mean_covar_factory
from aepsych.config import Config
from aepsych.utils_logging import getLogger
from botorch.fit import fit_gpytorch_mll
from botorch.models import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
from botorch.models.transforms.input import Normalize


import dataclasses
import time
from typing import Dict, List, Optional
from aepsych.utils_logging import getLogger
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from botorch.fit import fit_gpytorch_mll
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor

logger = getLogger()

class PairwiseGPModel(PairwiseGP, AEPsychModel, AEPsychSurrogate):
    name = "PairwiseProbitModel"
    outcome_type = "binary"
    stimuli_per_trial = 1

    def __init__(
        self,
        lb: Union[np.ndarray, torch.Tensor],
        ub: Union[np.ndarray, torch.Tensor],
        dim: Optional[int] = None,
        covar_module: Optional[gpytorch.kernels.Kernel] = None,
        max_fit_time: Optional[float] = None,
    ):
        self.lb, self.ub, dim = _process_bounds(lb, ub, dim)

        self.max_fit_time = max_fit_time

        bounds = torch.stack((self.lb, self.ub))
        input_transform = Normalize(d=dim, bounds=bounds)
        if covar_module is None:
            config = Config(
                config_dict={
                    "default_mean_covar_factory": {
                        "lb": str(self.lb.tolist()),
                        "ub": str(self.ub.tolist()),
                    }
                }
            )  # type: ignore
            _, covar_module = default_mean_covar_factory(config)
        
        self.botorch_model_class = None

        super().__init__(
            datapoints=None,
            comparisons=None,
            covar_module=covar_module,
            jitter=1e-3,
            input_transform=input_transform,
        )

        self.dim = dim  # The Pairwise constructor sets self.dim = None.
            

    @classmethod
    def get_mll_class(cls):
        return PairwiseLaplaceMarginalLogLikelihood
    
    def predict(
        self, x, probability_space=False, num_samples=1000, rereference="x_min"
    ):
        if rereference is not None:
            samps = self.sample(x, num_samples, rereference)
            fmean, fvar = samps.mean(0).squeeze(), samps.var(0).squeeze()
        else:
            post = self.posterior(x)
            fmean, fvar = post.mean.squeeze(), post.variance.squeeze()

        if probability_space:
            return (
                promote_0d(norm.cdf(fmean)),
                promote_0d(norm.cdf(fvar)),
            )
        else:
            return fmean, fvar
    
    # def fit(
    #     self,
    #     train_x: torch.Tensor,
    #     train_y: torch.Tensor,
    #     optimizer_kwargs: Optional[Dict[str, Any]] = None,
    #     **kwargs,
    # ):
    #     self.train()
    #     mll = PairwiseLaplaceMarginalLogLikelihood(self.likelihood, self)
    #     datapoints, comparisons = PairwiseProbitModel._pairs_to_comparisons(train_x, train_y)
    #     self.set_train_data(datapoints, comparisons)

    #     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs.copy()
    #     max_fit_time = kwargs.pop("max_fit_time", self.max_fit_time)
    #     if max_fit_time is not None:
    #         # figure out how long evaluating a single samp
    #         starttime = time.time()
    #         _ = mll(self(datapoints), comparisons)
    #         single_eval_time = time.time() - starttime
    #         n_eval = int(max_fit_time / single_eval_time)
    #         optimizer_kwargs["maxfun"] = n_eval
    #         logger.info(f"fit maxfun is {n_eval}")

    #     logger.info("Starting fit...")
    #     starttime = time.time()
    #     fit_gpytorch_mll(mll, **kwargs, **optimizer_kwargs)
    #     logger.info(f"Fit done, time={time.time()-starttime}")

    def fit(
        self,
        datasets: List[SupervisedDataset],
        metric_names: List[str],
        search_space_digest: SearchSpaceDigest,
        candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
        state_dict: Optional[Dict[str, Tensor]] = None,
        refit: bool = True,
        **kwargs,
    ) -> None:
        self.train()
        self._outcomes = metric_names
        if state_dict:
            self.model.load_state_dict(state_dict)

        if state_dict is None or refit:
            mll = self.get_mll_class()(self.likelihood, self)
            optimizer_kwargs = {}
            if self.max_fit_time is not None:
                # figure out how long evaluating a single samp
                starttime = time.time()
                
                if isinstance(self, PairwiseGPModel):
                    datapoints, comparisons = PairwiseProbitModel._pairs_to_comparisons(datasets[0].X(), datasets[0].Y().squeeze())
                    self.set_train_data(datapoints, comparisons)
                    _ = mll(self.model(datapoints), comparisons)
                else:
                    _ = mll(self.model(datasets[0].X()), datasets[0].Y().squeeze())
                single_eval_time = time.time() - starttime
                n_eval = int(self.max_fit_time / single_eval_time)
                logger.info(f"fit maxfun is {n_eval}")
                optimizer_kwargs["options"] = {"maxfun": n_eval}

            logger.info("Starting fit...")
            starttime = time.time()
            fit_gpytorch_mll(
                mll, optimizer_kwargs=optimizer_kwargs
            )  # TODO: Support flexible optimizers
            logger.info(f"Fit done, time={time.time()-starttime}")

    @classmethod
    def construct_inputs(cls, training_data, **kwargs):
        inputs = super().construct_inputs(training_data=training_data, **kwargs)

        inducing_size = kwargs.get("inducing_size")
        inducing_point_method = kwargs.get("inducing_point_method")
        bounds = kwargs.get("bounds")
        inducing_points = select_inducing_points(
            inducing_size,
            inputs["covar_module"],
            inputs["train_X"],
            bounds,
            inducing_point_method,
        )

        inputs.update(
            {
                "inducing_points": inducing_points,
            }
        )

        return inputs

    @classmethod
    def get_config_options(cls, config: Config, name: str = None):
        options = super().get_config_options(config, name)
        classname = cls.__class__.__name__

        inducing_point_method = config.get(
            classname, "inducing_point_method", fallback="auto"
        )
        inducing_size = config.getint(classname, "inducing_size", fallback=10)
        learn_inducing_points = config.getboolean(
            classname, "learn_inducing_points", fallback=False
        )

        options.update(
            {
                "inducing_size": inducing_size,
                "inducing_point_method": inducing_point_method,
                "learn_inducing_points": learn_inducing_points,
                "likelihood" : PairwiseProbitLikelihood()
            }
        )

        return options

In [2]:
X, y = make_classification(
            n_samples=10,
            n_features=1,
            n_redundant=0,
            n_informative=1,
            random_state=1,
            n_clusters_per_class=1,
        )
x, y = torch.Tensor(X), torch.Tensor(y)
y = y.reshape(-1, 1)

datapoints = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
comparisons = torch.Tensor([[0, 1], [2, 1]])

seed = 1
torch.manual_seed(seed)
np.random.seed(seed)

config_file = "../../configs/ax_pairwise_opt_example.ini"
config = Config(config_fnames=[config_file])
lb = config.getlist("common", "lb", element_type=float)
ub = config.getlist("common", "ub", element_type=float)

pairwise_model = PairwiseGPModel(lb, ub)

dataset = SupervisedDataset(x, y)
search_space_digest = SearchSpaceDigest(lb, ub)

pairwise_model.fit([dataset], ["y"], search_space_digest)
# vars(pairwise_model)

2023-03-22 13:36:17,908 [INFO   ] Starting fit...


TypeError: unhashable type: 'PairwiseGPModel'