In [1]:
from data import load, split, ScaleAbsOne
from gpytorch.constraints import Positive
import torch
import math
import gpytorch
from matplotlib import pyplot as plt

In [2]:
df = load()
train, dev, test = split(df, "new_dose-response_matrices", inner_fold=1, outer_fold=1)

100%|██████████| 10/10 [00:07<00:00,  1.30it/s]
100%|██████████| 10/10 [00:07<00:00,  1.34it/s]


In [3]:
train_target = train.pop("PercentageGrowth")
dev_target = dev.pop("PercentageGrowth")
test_target = test.pop("PercentageGrowth")
scaler = ScaleAbsOne()
train = scaler.fit_transform(train.values)
dev = scaler.transform(dev.values)
test = scaler.transform(test.values)

In [4]:
sub = torch.tensor(train[:150]).float()
subt = torch.tensor(train_target[:150].values).float()

In [5]:
def train(model, likelihood, training_iter, x, y):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for i in range(training_iter):
        optimizer.zero_grad()
        output = model(x)
        loss = -mll(output, y)
        print(loss.item())
        loss.backward()
        optimizer.step()

In [6]:
class MultiLinearKernel(gpytorch.kernels.Kernel):
    is_stationary = False
    def __init__(self, length_prior=None, length_constraint=None, **kwargs):
        super().__init__(**kwargs)
        
        self.register_parameter(
            name='raw_length', 
            parameter=torch.nn.Parameter(torch.ones(*self.batch_shape, 1, 1))
        )
        if length_constraint is None:
            length_constraint = Positive()
        self.register_constraint("raw_length", length_constraint)

        if length_prior is not None:
            self.register_prior(
                "length_prior",
                length_prior,
                lambda m: m.length,
                lambda m, v : m._set_length(v),
            )

    @property
    def length(self):
        return self.raw_length_constraint.transform(self.raw_length)

    @length.setter
    def length(self, value):
        return self._set_length(value)

    def _set_length(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_length)
        self.initialize(raw_length=self.raw_length_constraint.inverse_transform(value))

    def forward(self, x1, x2, **params):
        prod = torch.einsum("nd, md -> nmd", x1, x2)
        frac = (self.length**2 + prod) / (1 + self.length**2)
        return frac.prod(-1)
    
    
class MultiLinearGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = MultiLinearKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [8]:
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = MultiLinearGPModel(sub, subt, likelihood)

model.train()
likelihood.train()
train(model, likelihood, 1000, sub, subt)

model.eval()
likelihood.eval()

4466.27685546875
4149.9970703125
3863.368408203125
3603.886962890625
3369.152587890625
3156.885986328125
2964.94970703125
2791.355712890625
2634.272216796875
2492.019775390625
2363.072509765625
2246.047607421875
2139.700927734375
2042.9156494140625
1954.690185546875
1874.1331787109375
1800.4478759765625
1732.9251708984375
1670.9365234375
1613.9205322265625
1561.37939453125
1512.86962890625
1467.99755859375
1426.41162109375
1387.79931640625
1351.8817138671875
1318.409423828125
1287.16015625
1257.934326171875
1230.5531005859375
1204.85595703125
1180.6988525390625
1157.95166015625
1136.4976806640625
1116.231201171875
1097.0565185546875
1078.887451171875
1061.6458740234375
1045.2603759765625
1029.66650390625
1014.8057250976562
1000.6240844726562
987.0734252929688
974.1083374023438
961.688720703125
949.7771606445312
938.33935546875
927.3441772460938
916.7628173828125
906.5689697265625
896.7380981445312
887.2479248046875
878.0780029296875
869.2091064453125
860.6237182617188
852.305419921875


230.830810546875
230.4270782470703
230.02468872070312
229.62353515625
229.2237548828125
228.8251495361328
228.4278106689453
228.03182983398438
227.63697814941406
227.2434844970703
226.85116577148438
226.4601593017578
226.07028198242188
225.68167114257812
225.29429626464844
224.9081268310547
224.5231475830078
224.13941955566406
223.75685119628906
223.3754425048828
222.99525451660156
222.6162567138672
222.2383575439453
221.8617401123047
221.48622131347656
221.11190795898438
220.7387237548828
220.36663818359375
219.99569702148438
219.62591552734375
219.25721740722656
218.88970947265625
218.52328491210938
218.15798950195312
217.79380798339844
217.4307098388672
217.06869506835938
216.70782470703125
216.34800720214844
215.98924255371094
215.63162231445312
215.2750244140625
214.91952514648438
214.56507873535156
214.21168518066406
213.8593292236328
213.50804138183594
213.15780639648438
212.8086395263672
212.46044921875
212.1133270263672
211.76719665527344
211.42210388183594
211.07801818847656


125.96070098876953
125.82643127441406
125.69247436523438
125.55876159667969
125.42535400390625
125.29219818115234
125.15929412841797
125.02669525146484
124.89433288574219
124.76225280761719
124.63041687011719
124.49887084960938
124.36760711669922
124.23658752441406
124.10584259033203
123.97534942626953
123.84514617919922
123.71517944335938
123.58549499511719
123.45606994628906
123.32688903808594
123.19798278808594
123.06929779052734
122.94092559814453
122.81279754638672
122.68490600585938
122.55730438232422
122.42994689941406
122.30282592773438
122.17597961425781
122.04938507080078
121.92301940917969
121.79695129394531
121.67110443115234
121.54552459716797
121.42018127441406
121.29512786865234
121.17027282714844
121.04573059082031
120.92139434814453
120.79730224609375
120.67345428466797
120.54987335205078
120.4265365600586
120.30341339111328
120.18055725097656
120.05794525146484
119.9355697631836
119.81346130371094
119.69157409667969
119.56997680664062
119.44857025146484
119.3274230957

GaussianLikelihood(
  (noise_covar): HomoskedasticNoise(
    (raw_noise_constraint): GreaterThan(1.000E-04)
  )
)