In [1]:
from data import load, split, ScaleAbsOne
from gpytorch.constraints import Positive

In [2]:
df = load()
train, dev, test = split(df, "new_dose-response_matrices", 1, 1)

100%|██████████| 10/10 [00:06<00:00,  1.47it/s]
100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


In [3]:
train_target = train.pop("PercentageGrowth")
dev_target = dev.pop("PercentageGrowth")
test_target = test.pop("PercentageGrowth")
scaler = ScaleAbsOne()
train = scaler.fit_transform(train.values)
dev = scaler.transform(dev.values)
test = scaler.transform(test.values)

In [4]:
import torch
import gpytorch

In [5]:
sub = torch.tensor(train[:1500]).float()
subt = torch.tensor(train_target[:1500].values).float()

In [6]:
from gpytorch.constraints import Positive

class MultiLinearKernel(gpytorch.kernels.Kernel):
    # the sinc kernel is stationary
    is_stationary = False

    # We will register the parameter when initializing the kernel
    def __init__(self, length_prior=None, length_constraint=None, **kwargs):
        super().__init__(**kwargs)

        # register the raw parameter
        self.register_parameter(
            name='raw_length', parameter=torch.nn.Parameter(torch.ones(*self.batch_shape, 1, 1))
        )

        # set the parameter constraint to be positive, when nothing is specified
        if length_constraint is None:
            length_constraint = Positive()

        # register the constraint
        self.register_constraint("raw_length", length_constraint)

        # set the parameter prior, see
        # https://docs.gpytorch.ai/en/latest/module.html#gpytorch.Module.register_prior
        if length_prior is not None:
            self.register_prior(
                "length_prior",
                length_prior,
                lambda m: m.length,
                lambda m, v : m._set_length(v),
            )

    # now set up the 'actual' paramter
    @property
    def length(self):
        # when accessing the parameter, apply the constraint transform
        return self.raw_length_constraint.transform(self.raw_length)

    @length.setter
    def length(self, value):
        return self._set_length(value)

    def _set_length(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_length)
        # when setting the paramater, transform the actual value to a raw one by applying the inverse transform
        self.initialize(raw_length=self.raw_length_constraint.inverse_transform(value))

    # this is the kernel function
    def forward(self, x1, x2, **params):
        prod = torch.einsum("nd, md -> nmd", x1, x2)
        frac = (self.length**2 + prod) / (1 + self.length**2)
        return frac.prod(-1)

In [7]:
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
#        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.covar_module = MultiLinearKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(sub, subt, likelihood)

In [8]:
model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

In [9]:
for i in range(50):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(sub)
    # Calc loss and backprop gradients
    loss = -mll(output, subt)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, 50, loss.item(),
        model.covar_module.base_kernel.lengthscale.item(),
        model.likelihood.noise.item()
    ))
    optimizer.step()

ModuleAttributeError: 'MultiLinearKernel' object has no attribute 'base_kernel'