In [3]:

from pairwise_probit import PairwiseGP
import torch
from variational_gp import BinaryClassificationGP
from aepsych.config import Config
from gpytorch.likelihoods import BernoulliLikelihood
from botorch.models.likelihoods.pairwise import PairwiseProbitLikelihood, PairwiseLikelihood
from aepsych.models.base import AEPsychModel
from aepsych.models.surrogate import AEPsychSurrogate
from aepsych.models.pairwise_probit import PairwiseProbitModel
from sklearn.datasets import make_classification
from aepsych.utils import promote_0d, _process_bounds
from scipy.stats import norm

import time
from typing import Any, Dict, Optional
from aepsych.config import Config
from aepsych.utils_logging import getLogger
from botorch.fit import fit_gpytorch_mll
from botorch.models import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
from botorch.models.transforms.input import Normalize
# Module-level logger; PairwiseGPModel.fit reports fit progress through it.
logger = getLogger()

class PairwiseGPModel(PairwiseGP, AEPsychModel, AEPsychSurrogate):
    """Pairwise-comparison GP (botorch ``PairwiseGP``) configured from an AEPsych ``Config``.

    Stimulus bounds are read from the ``[common]`` section (``lb``/``ub``) and
    used to normalize inputs into the unit cube before they reach the kernel.
    """

    name = "PairwiseProbitModel"
    outcome_type = "binary"
    # NOTE(review): ``fit`` consumes trials as stimulus *pairs* (via
    # ``_pairs_to_comparisons``), so 2 looks like the intended value here —
    # confirm before other code relies on this attribute.
    stimuli_per_trial = 1

    def __init__(
        self,
        config: Config,
        datapoints: Optional[torch.Tensor] = None,
        comparisons: Optional[torch.Tensor] = None,
    ):
        """Build the model from ``config``.

        Args:
            config: AEPsych config; must define ``lb`` and ``ub`` under ``[common]``.
            datapoints: Optional initial stimuli, forwarded to ``PairwiseGP``.
            comparisons: Optional initial comparisons (rows of indices into
                ``datapoints``), forwarded to ``PairwiseGP``.
        """
        config_opts = self.get_config_options(config)
        lb = config.getlist("common", "lb", element_type=float)
        ub = config.getlist("common", "ub", element_type=float)
        self.lb, self.ub, dim = _process_bounds(lb, ub, None)

        # Rescale raw stimuli into the unit cube using the configured bounds.
        bounds = torch.stack((self.lb, self.ub))
        input_transform = Normalize(d=dim, bounds=bounds)

        super().__init__(
            datapoints=datapoints,
            comparisons=comparisons,
            covar_module=config_opts["covar_module"],
            jitter=1e-3,
            input_transform=input_transform,
        )

        self.dim = dim  # The Pairwise constructor sets self.dim = None.
        # ``fit`` reads this budget; it was parsed into ``config_opts`` but
        # never stored on the instance before.
        self.max_fit_time = config_opts.get("max_fit_time")

    @classmethod
    def get_mll_class(cls):
        """Marginal log-likelihood class used to fit this model."""
        return PairwiseLaplaceMarginalLogLikelihood

    def predict(
        self, x, probability_space=False, num_samples=1000, rereference="x_min"
    ):
        """Mean and variance of the latent utility at ``x``.

        Args:
            x: Points at which to evaluate the model.
            probability_space: If True, pass both outputs through the standard
                normal CDF (scipy ``norm.cdf``, so NumPy values are returned).
            num_samples: Number of samples drawn when ``rereference`` is set.
            rereference: If not None, estimate the moments from
                ``self.sample(x, num_samples, rereference)``; otherwise use the
                analytic posterior.

        Returns:
            Tuple ``(mean, variance)``.
        """
        if rereference is not None:
            # NOTE(review): ``sample`` is not defined in this class; presumably
            # inherited from a base class — verify its signature.
            samps = self.sample(x, num_samples, rereference)
            fmean, fvar = samps.mean(0).squeeze(), samps.var(0).squeeze()
        else:
            post = self.posterior(x)
            fmean, fvar = post.mean.squeeze(), post.variance.squeeze()

        if probability_space:
            # NOTE(review): the normal CDF of a *variance* is not the variance in
            # probability space; kept for compatibility — consider deriving it
            # from CDF-transformed samples instead.
            return (
                promote_0d(norm.cdf(fmean)),
                promote_0d(norm.cdf(fvar)),
            )
        return fmean, fvar

    def fit(
        self,
        train_x: torch.Tensor,
        train_y: torch.Tensor,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Fit the model hyperparameters on paired trials.

        Args:
            train_x: Trials as stimulus pairs; converted to
                ``(datapoints, comparisons)`` by
                ``PairwiseProbitModel._pairs_to_comparisons``.
            train_y: Judgement for each pair.
            optimizer_kwargs: Keyword arguments for the optimizer run by
                ``fit_gpytorch_mll`` (e.g. an ``options`` dict). Not mutated.
            **kwargs: ``max_fit_time`` (seconds) overrides ``self.max_fit_time``;
                all other keys are forwarded to ``fit_gpytorch_mll``.
        """
        self.train()
        mll = PairwiseLaplaceMarginalLogLikelihood(self.likelihood, self)
        # ``_pairs_to_comparisons`` is an instance method (it reads ``self.dim``),
        # so an instance must be supplied explicitly when calling via the class.
        datapoints, comparisons = PairwiseProbitModel._pairs_to_comparisons(
            self, train_x, train_y
        )
        self.set_train_data(datapoints, comparisons)

        optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs.copy()
        max_fit_time = kwargs.pop("max_fit_time", getattr(self, "max_fit_time", None))
        if max_fit_time is not None:
            # Time one MLL evaluation to convert the wall-clock budget into a
            # cap on the number of optimizer function evaluations.
            starttime = time.perf_counter()
            _ = mll(self(datapoints), comparisons)
            # Epsilon guards against a zero reading from a coarse clock.
            single_eval_time = max(time.perf_counter() - starttime, 1e-6)
            n_eval = max(1, int(max_fit_time / single_eval_time))
            # The default scipy optimizer takes ``maxfun`` inside ``options``;
            # copy that dict so the caller's nested options are not mutated.
            options = dict(optimizer_kwargs.get("options") or {})
            options["maxfun"] = n_eval
            optimizer_kwargs["options"] = options
            logger.info(f"fit maxfun is {n_eval}")

        logger.info("Starting fit...")
        starttime = time.time()
        # Optimizer settings must travel through ``optimizer_kwargs``: unpacked
        # as top-level kwargs, ``fit_gpytorch_mll`` silently ignores them.
        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs, **kwargs)
        logger.info(f"Fit done, time={time.time()-starttime}")

    @classmethod
    def get_config_options(cls, config: Config, name: Optional[str] = None):
        """Return model options parsed from ``config``.

        Extends the parent's options with inducing-point settings (read from the
        config section named after this class) and a ``PairwiseProbitLikelihood``.

        Args:
            config: AEPsych config to read from.
            name: Forwarded unchanged to the parent implementation.

        Returns:
            Dict of options.
        """
        options = super().get_config_options(config, name)
        # ``cls`` is already the class object; ``cls.__class__`` is its
        # metaclass, whose name (e.g. "type") is not a config section.
        classname = cls.__name__

        # NOTE(review): the inducing-point options are not used by ``__init__``
        # in this class — confirm whether they are needed.
        inducing_point_method = config.get(
            classname, "inducing_point_method", fallback="auto"
        )
        inducing_size = config.getint(classname, "inducing_size", fallback=10)
        learn_inducing_points = config.getboolean(
            classname, "learn_inducing_points", fallback=False
        )

        options.update(
            {
                "inducing_size": inducing_size,
                "inducing_point_method": inducing_point_method,
                "learn_inducing_points": learn_inducing_points,
                "likelihood": PairwiseProbitLikelihood(),
            }
        )

        return options

In [5]:

# Toy binary-classification data: 10 samples, 1 feature, fixed seed.
# NOTE(review): X and y are not used in this cell — confirm a later cell needs them.
X, y = make_classification(
            n_samples=10,
            n_features=1,
            n_redundant=0,
            n_informative=1,
            random_state=1,
            n_clusters_per_class=1,
        )
X, y = torch.Tensor(X), torch.Tensor(y)

# Three 3-column points plus two comparisons whose rows index into `datapoints`
# (presumably (preferred, other) per botorch's PairwiseGP convention — TODO confirm).
# NOTE(review): `comparisons` is never passed to the model below, and index
# tensors are usually integer (long) rather than float. With only `datapoints`
# given, the recorded output shows the model's `datapoints` buffer is still None
# (no training data set). The printed lb/ub are 2-D while these points have 3
# columns; the two must agree before comparisons can be supplied.
datapoints = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
comparisons = torch.Tensor([[0, 1], [2, 1]])


# Relative path: resolves only when the notebook runs from its own directory.
config_file = "../../configs/ax_pairwise_opt_example.ini"
config = Config(config_fnames=[config_file])
pairwise_model = PairwiseGPModel(config, datapoints= datapoints)

# Inspect the parsed options and the constructed model's attributes.
config_opts = pairwise_model.get_config_options(config)
print(config_opts)
vars(pairwise_model)

Normalize()
{'likelihood': PairwiseProbitLikelihood(), 'covar_module': ScaleKernel(
  (base_kernel): RBFKernel(
    (lengthscale_prior): GammaPrior()
    (raw_lengthscale_constraint): GreaterThan(1.000E-04)
  )
  (outputscale_prior): SmoothedBoxPrior()
  (raw_outputscale_constraint): Positive()
), 'mean_module': ConstantMean(), 'max_fit_time': None, 'inducing_size': 10, 'inducing_point_method': 'auto', 'learn_inducing_points': False}


{'lb': tensor([0., 0.], dtype=torch.float32),
 'ub': tensor([1., 1.], dtype=torch.float32),
 'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict([('datapoints', None),
              ('comparisons', None),
              ('D', None),
              ('DT', None),
              ('utility', None),
              ('covar_chol', None),
              ('likelihood_hess', None),
              ('hlcov_eye', None),
              ('covar', None),
              ('covar_inv', None)]),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict([(29,
               <torch.nn.modules.module._WrappedHook at 0x1fee0e333d0>)]),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('likelihood', PairwiseProbitLikelihood()),
              ('mean_module', ConstantMean()),
 

In [6]:
# Same seeded 1-feature, 10-sample binary problem, used to build a variational
# binary-classification GP.
binary_X, binary_y = make_classification(
    n_samples=10,
    n_features=1,
    n_informative=1,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=1,
)
# Float tensors; the 1-D labels become a (10, 1) column for train_Y.
binary_X = torch.Tensor(binary_X)
binary_y = torch.Tensor(binary_y).unsqueeze(-1)

binary_model = BinaryClassificationGP(
    train_X=binary_X,
    train_Y=binary_y,
    likelihood=BernoulliLikelihood(),
    inducing_points=10,
)

vars(binary_model)


Input data is not contained to the unit cube. Please consider min-max scaling the input data.


Input data is not standardized. Please consider scaling the input to zero mean and unit variance.



{'_num_outputs': 1,
 '_input_batch_shape': torch.Size([]),
 '_aug_batch_shape': torch.Size([]),
 '_is_custom_likelihood': True,
 '_inducing_point_allocator': <botorch.models.utils.inducing_point_allocators.GreedyVarianceReduction at 0x1fee0e2f280>,
 'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('model',
               _SingleTaskVariationalGP(
                 (variational_strategy): VariationalStrategy(
                   (_variational_distribution): CholeskyVariationalDistribution()
                 )
                 (mean_module): ConstantMean()
                 (covar_module): ScaleKernel(
                   (base_kernel