In [1]:
from data import load, split, ScaleAbsOne
from gpytorch.constraints import Positive
import torch
import math
import gpytorch
from matplotlib import pyplot as plt

In [2]:
df = load()
train, dev, test = split(df, "new_dose-response_matrices", inner_fold=1, outer_fold=1)

100%|██████████| 10/10 [00:06<00:00,  1.45it/s]
100%|██████████| 10/10 [00:06<00:00,  1.49it/s]


In [3]:
train_target = train.pop("PercentageGrowth")
dev_target = dev.pop("PercentageGrowth")
test_target = test.pop("PercentageGrowth")
scaler = ScaleAbsOne()
train = scaler.fit_transform(train.values)
dev = scaler.transform(dev.values)
test = scaler.transform(test.values)

In [4]:
sub = torch.tensor(train[:150]).float()
subt = torch.tensor(train_target[:150].values).float()

In [6]:
# sub = torch.randn(150, 400)
# subt = torch.randn()

In [15]:
def train(model, likelihood, training_iter, x, y):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for i in range(training_iter):
        optimizer.zero_grad()
        output = model(x)
        loss = -mll(output, y)
        print(loss.item())
        loss.backward()
        optimizer.step()

In [22]:
class MultiLinearKernel(gpytorch.kernels.Kernel):
    is_stationary = False
    def __init__(self, length_prior=None, length_constraint=None, **kwargs):
        super().__init__(**kwargs)
        
        self.register_parameter(
            name='raw_length', 
            parameter=torch.nn.Parameter(torch.ones(*self.batch_shape, 1, 1, 400) * 100)
        )
        if length_constraint is None:
            length_constraint = Positive()
        self.register_constraint("raw_length", length_constraint)

        if length_prior is not None:
            self.register_prior(
                "length_prior",
                length_prior,
                lambda m: m.length,
                lambda m, v : m._set_length(v),
            )

    @property
    def length(self):
        return self.raw_length_constraint.transform(self.raw_length)

    @length.setter
    def length(self, value):
        return self._set_length(value)

    def _set_length(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_length)
        self.initialize(raw_length=self.raw_length_constraint.inverse_transform(value))

    def forward(self, x1, x2, **params):
        prod = torch.einsum("nd, md -> nmd", x1, x2)
        frac = (self.length**2 + prod) / (1 + self.length**2)
        return frac.prod(-1)
    
    
class MultiLinearGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = MultiLinearKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [25]:
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = MultiLinearGPModel(sub, subt, likelihood)

model.train()
likelihood.train()
train(model, likelihood, 1000, sub, subt)

model.eval()
likelihood.eval()

1286.380126953125
1199.601318359375
1120.7431640625
1049.15966796875
984.2476196289062
925.4046020507812
872.0755004882812
823.7339477539062
779.900146484375
740.1283569335938
704.010009765625
671.1754760742188
641.2890014648438
614.0504150390625
589.186767578125
566.4559936523438
545.6395874023438
526.545166015625
508.9981384277344
492.8446960449219
477.947021484375
464.1829833984375
451.4430847167969
439.6289978027344
428.65374755859375
418.43994140625
408.9180908203125
400.0251159667969
391.7052307128906
383.90850830078125
376.5900573730469
369.7088928222656
363.2284240722656
357.1153869628906
351.34063720703125
345.8764953613281
340.69891357421875
335.7857666015625
331.116455078125
326.6732482910156
322.4388732910156
318.3988342285156
314.5384521484375
310.84564208984375
307.3087158203125
303.9167785644531
300.6604309082031
297.5303955078125
294.5188903808594
291.6182556152344
288.8211669921875
286.1217956542969
283.5140380859375
280.9920959472656
278.5514831542969
276.187255859375

101.27933502197266
101.1592788696289
101.03963470458984
100.9203872680664
100.80144500732422
100.68296813964844
100.56484985351562
100.44710540771484
100.32974243164062
100.21273803710938
100.09611511230469
99.97984313964844
99.86396026611328
99.74844360351562
99.63328552246094
99.51850128173828
99.404052734375
99.28999328613281
99.17626953125
99.06291961669922
98.94989776611328
98.83726501464844
98.72493743896484
98.61299133300781
98.50135803222656
98.39010620117188
98.27919006347656
98.1686019897461
98.05836486816406
97.94847106933594
97.8388900756836
97.72966003417969
97.62080383300781
97.51221466064453
97.40399932861328
97.29609680175781
97.18849182128906
97.0812759399414
96.97431945800781
96.86774444580078
96.76145935058594
96.65548706054688
96.54984283447266
96.44452667236328
96.3395004272461
96.23481750488281
96.13043212890625
96.0263442993164
95.92256164550781
95.81912994384766
95.71597290039062
95.61312866210938
95.5105972290039
95.40835571289062
95.3064193725586
95.2047805786

69.13056945800781
69.0898666381836
69.04924774169922
69.00870513916016
68.96822357177734
68.92780303955078
68.88745880126953
68.84716796875
68.80696105957031
68.76683044433594
68.72673797607422
68.68673706054688
68.64681243896484
68.6069107055664
68.56712341308594
68.52737426757812
68.48771667480469
68.44811248779297
68.40856170654297
68.36909484863281
68.32970428466797
68.29034423828125
68.2510757446289
68.21186828613281
68.17269897460938
68.13363647460938
68.09461975097656
68.05567169189453
68.01678466796875
67.97795867919922
67.939208984375
67.90050506591797
67.86186218261719
67.82330322265625
67.78479766845703
67.7463607788086
67.7079849243164
67.66966247558594
67.63141632080078
67.59321594238281
67.55509948730469
67.51702880859375
67.47903442382812
67.44107055664062
67.4031982421875
67.36536407470703
67.3276138305664
67.2899169921875
67.25227355957031
67.2146987915039
67.17717742919922
67.13972473144531
67.10233306884766
67.06499481201172
67.02771759033203
66.9905014038086
66.9533

GaussianLikelihood(
  (noise_covar): HomoskedasticNoise(
    (raw_noise_constraint): GreaterThan(1.000E-04)
  )
)

In [27]:
model.covar_module.length

tensor([[[100.6320, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 128.0623,  95.6115, 129.4643, 117.3705,
          102.3819, 129.8388, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623,
          128.0623, 131.0598, 128.2163, 130.0615, 129.1414, 128.6237, 131.7781,
          131.1822, 116.7076, 112.7965, 105.9786, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 100.6947, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623, 128.0623,
          128.0623, 128.0623, 128.0623, 