In [1]:
import torch
import gpytorch
import numpy as np
from torch import nn
from torch import matmul as m
from gpytorch.kernels import RBFKernel, ScaleKernel
from torch.distributions import Normal, Poisson, MultivariateNormal


from torch.utils.tensorboard import SummaryWriter
from deep_fields.models.utils.basic_setups import create_dir_and_writer
from deep_fields.models.gaussian_processes.gaussian_processes import multivariate_normal, white_noise_kernel

In [10]:
locations_dimension = 2
prior_locations_mean =  0.
prior_locations_std = 6.69
number_of_realizations = 100
birth_intensity = 2.
kernel_sigma = 0.1
kernel_lenght_scales = [1., 2.]

#locations
locations_prior = Normal(torch.full((locations_dimension,), prior_locations_mean),
                         torch.full((locations_dimension,), prior_locations_std))
locations = locations_prior.sample(sample_shape=(100,))

In [3]:
class GPModel(nn.Module):
    
    def __init__(self,kernel_sigma,kernel_lenght_scales):
        nn.Module.__init__(self)        
        kernel = ScaleKernel(RBFKernel(ard_num_dims=locations_dimension, requires_grad=True),
                     requires_grad=True) + white_noise_kernel()
        kernel_hypers = {"raw_outputscale": torch.tensor(kernel_sigma),
                         "base_kernel.raw_lengthscale": torch.tensor(kernel_lenght_scales)}
        kernel.kernels[0].initialize(**kernel_hypers)
        
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = kernel

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x).evaluate()
        return mean_x,covar_x
    
def log_likelihood(covar_x,f):
    covar_dim =  covar_x.shape[0]
    covar_x_inverse = covar_x.inverse()
    covar_det = (covar_x).det()
    
    log_inverse = m(f,m(covar_x_inverse,f))
    log_det = torch.log(covar_det)
    log_probability = -.5*(log_det + log_inverse)
    return log_probability

In [6]:
gp0 = GPModel(kernel_sigma = 3.,kernel_lenght_scales = [1., 1.])
mean_x, covar_x = gp0(locations)
f = MultivariateNormal(mean_x,covar_x).sample()
log_likelihood(covar_x,f)

tensor(-73.6307, grad_fn=<MulBackward0>)

In [7]:
gp1 = GPModel(kernel_sigma=1.,kernel_lenght_scales =[.1, .1])
mean_x, covar_x = gp1(locations)
log_likelihood(covar_x,f)

tensor(-106.9959, grad_fn=<MulBackward0>)

In [8]:
gp1 = GPModel(kernel_sigma=.1,kernel_lenght_scales=[.1, .1])
training_iter = 2000
gp1.train()
# Use the adam optimizer
optimizer = torch.optim.Adam(gp1.parameters(), lr=0.01)  # Includes GaussianLikelihood parameters

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    mean_x, covar_x = gp1(locations)
    # Calc loss and backprop gradients
    loss = -log_likelihood(covar_x,f)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (
        i + 1, training_iter, loss.item()
    ))
    print([a.data for a in gp1.covar_module.parameters()])
    optimizer.step()

Iter 1/2000 - Loss: 155.804
[tensor(0.1000), tensor([[0.1000, 0.1000]])]
Iter 2/2000 - Loss: 154.558
[tensor(0.1100), tensor([[0.1100, 0.1100]])]
Iter 3/2000 - Loss: 153.325
[tensor(0.1200), tensor([[0.1200, 0.1200]])]
Iter 4/2000 - Loss: 152.106
[tensor(0.1300), tensor([[0.1300, 0.1300]])]
Iter 5/2000 - Loss: 150.900
[tensor(0.1400), tensor([[0.1400, 0.1400]])]
Iter 6/2000 - Loss: 149.709
[tensor(0.1499), tensor([[0.1500, 0.1500]])]
Iter 7/2000 - Loss: 148.532
[tensor(0.1599), tensor([[0.1600, 0.1600]])]
Iter 8/2000 - Loss: 147.368
[tensor(0.1698), tensor([[0.1700, 0.1700]])]
Iter 9/2000 - Loss: 146.219
[tensor(0.1798), tensor([[0.1800, 0.1799]])]
Iter 10/2000 - Loss: 145.084
[tensor(0.1897), tensor([[0.1900, 0.1899]])]
Iter 11/2000 - Loss: 143.964
[tensor(0.1995), tensor([[0.1999, 0.1999]])]
Iter 12/2000 - Loss: 142.858
[tensor(0.2094), tensor([[0.2099, 0.2098]])]
Iter 13/2000 - Loss: 141.767
[tensor(0.2192), tensor([[0.2199, 0.2198]])]
Iter 14/2000 - Loss: 140.691
[tensor(0.2290), t

[tensor(1.0219), tensor([[0.8638, 0.8079]])]
Iter 120/2000 - Loss: 90.997
[tensor(1.0277), tensor([[0.8647, 0.8086]])]
Iter 121/2000 - Loss: 90.837
[tensor(1.0335), tensor([[0.8656, 0.8093]])]
Iter 122/2000 - Loss: 90.678
[tensor(1.0392), tensor([[0.8665, 0.8100]])]
Iter 123/2000 - Loss: 90.522
[tensor(1.0450), tensor([[0.8674, 0.8108]])]
Iter 124/2000 - Loss: 90.367
[tensor(1.0507), tensor([[0.8682, 0.8115]])]
Iter 125/2000 - Loss: 90.214
[tensor(1.0564), tensor([[0.8690, 0.8121]])]
Iter 126/2000 - Loss: 90.063
[tensor(1.0620), tensor([[0.8698, 0.8128]])]
Iter 127/2000 - Loss: 89.914
[tensor(1.0676), tensor([[0.8706, 0.8135]])]
Iter 128/2000 - Loss: 89.766
[tensor(1.0732), tensor([[0.8714, 0.8142]])]
Iter 129/2000 - Loss: 89.620
[tensor(1.0788), tensor([[0.8721, 0.8149]])]
Iter 130/2000 - Loss: 89.476
[tensor(1.0844), tensor([[0.8729, 0.8156]])]
Iter 131/2000 - Loss: 89.334
[tensor(1.0899), tensor([[0.8736, 0.8163]])]
Iter 132/2000 - Loss: 89.193
[tensor(1.0954), tensor([[0.8743, 0.81

Iter 230/2000 - Loss: 80.648
[tensor(1.5391), tensor([[0.9314, 0.8956]])]
Iter 231/2000 - Loss: 80.596
[tensor(1.5429), tensor([[0.9319, 0.8963]])]
Iter 232/2000 - Loss: 80.544
[tensor(1.5466), tensor([[0.9323, 0.8971]])]
Iter 233/2000 - Loss: 80.492
[tensor(1.5504), tensor([[0.9327, 0.8978]])]
Iter 234/2000 - Loss: 80.441
[tensor(1.5541), tensor([[0.9332, 0.8985]])]
Iter 235/2000 - Loss: 80.390
[tensor(1.5578), tensor([[0.9336, 0.8992]])]
Iter 236/2000 - Loss: 80.340
[tensor(1.5615), tensor([[0.9340, 0.8999]])]
Iter 237/2000 - Loss: 80.290
[tensor(1.5652), tensor([[0.9344, 0.9006]])]
Iter 238/2000 - Loss: 80.241
[tensor(1.5689), tensor([[0.9348, 0.9013]])]
Iter 239/2000 - Loss: 80.192
[tensor(1.5725), tensor([[0.9353, 0.9020]])]
Iter 240/2000 - Loss: 80.143
[tensor(1.5762), tensor([[0.9357, 0.9027]])]
Iter 241/2000 - Loss: 80.095
[tensor(1.5798), tensor([[0.9361, 0.9034]])]
Iter 242/2000 - Loss: 80.047
[tensor(1.5834), tensor([[0.9365, 0.9041]])]
Iter 243/2000 - Loss: 80.000
[tensor(1

Iter 348/2000 - Loss: 76.632
[tensor(1.9115), tensor([[0.9717, 0.9667]])]
Iter 349/2000 - Loss: 76.611
[tensor(1.9141), tensor([[0.9720, 0.9672]])]
Iter 350/2000 - Loss: 76.590
[tensor(1.9168), tensor([[0.9722, 0.9677]])]
Iter 351/2000 - Loss: 76.569
[tensor(1.9194), tensor([[0.9725, 0.9682]])]
Iter 352/2000 - Loss: 76.549
[tensor(1.9220), tensor([[0.9728, 0.9687]])]
Iter 353/2000 - Loss: 76.528
[tensor(1.9246), tensor([[0.9730, 0.9692]])]
Iter 354/2000 - Loss: 76.507
[tensor(1.9272), tensor([[0.9733, 0.9697]])]
Iter 355/2000 - Loss: 76.487
[tensor(1.9298), tensor([[0.9736, 0.9702]])]
Iter 356/2000 - Loss: 76.467
[tensor(1.9324), tensor([[0.9738, 0.9706]])]
Iter 357/2000 - Loss: 76.447
[tensor(1.9350), tensor([[0.9741, 0.9711]])]
Iter 358/2000 - Loss: 76.427
[tensor(1.9376), tensor([[0.9743, 0.9716]])]
Iter 359/2000 - Loss: 76.407
[tensor(1.9402), tensor([[0.9746, 0.9721]])]
Iter 360/2000 - Loss: 76.388
[tensor(1.9428), tensor([[0.9749, 0.9726]])]
Iter 361/2000 - Loss: 76.368
[tensor(1

Iter 475/2000 - Loss: 74.801
[tensor(2.1988), tensor([[0.9995, 1.0198]])]
Iter 476/2000 - Loss: 74.792
[tensor(2.2008), tensor([[0.9996, 1.0202]])]
Iter 477/2000 - Loss: 74.782
[tensor(2.2027), tensor([[0.9998, 1.0205]])]
Iter 478/2000 - Loss: 74.773
[tensor(2.2046), tensor([[1.0000, 1.0209]])]
Iter 479/2000 - Loss: 74.764
[tensor(2.2065), tensor([[1.0002, 1.0212]])]
Iter 480/2000 - Loss: 74.754
[tensor(2.2084), tensor([[1.0003, 1.0216]])]
Iter 481/2000 - Loss: 74.745
[tensor(2.2103), tensor([[1.0005, 1.0219]])]
Iter 482/2000 - Loss: 74.736
[tensor(2.2122), tensor([[1.0007, 1.0222]])]
Iter 483/2000 - Loss: 74.727
[tensor(2.2141), tensor([[1.0009, 1.0226]])]
Iter 484/2000 - Loss: 74.718
[tensor(2.2160), tensor([[1.0010, 1.0229]])]
Iter 485/2000 - Loss: 74.709
[tensor(2.2179), tensor([[1.0012, 1.0233]])]
Iter 486/2000 - Loss: 74.700
[tensor(2.2198), tensor([[1.0014, 1.0236]])]
Iter 487/2000 - Loss: 74.691
[tensor(2.2217), tensor([[1.0016, 1.0239]])]
Iter 488/2000 - Loss: 74.683
[tensor(2

Iter 599/2000 - Loss: 73.962
[tensor(2.4073), tensor([[1.0180, 1.0570]])]
Iter 600/2000 - Loss: 73.958
[tensor(2.4087), tensor([[1.0182, 1.0572]])]
Iter 601/2000 - Loss: 73.953
[tensor(2.4102), tensor([[1.0183, 1.0575]])]
Iter 602/2000 - Loss: 73.948
[tensor(2.4117), tensor([[1.0184, 1.0577]])]
Iter 603/2000 - Loss: 73.944
[tensor(2.4131), tensor([[1.0185, 1.0580]])]
Iter 604/2000 - Loss: 73.939
[tensor(2.4145), tensor([[1.0187, 1.0582]])]
Iter 605/2000 - Loss: 73.935
[tensor(2.4160), tensor([[1.0188, 1.0585]])]
Iter 606/2000 - Loss: 73.930
[tensor(2.4174), tensor([[1.0189, 1.0587]])]
Iter 607/2000 - Loss: 73.925
[tensor(2.4189), tensor([[1.0190, 1.0590]])]
Iter 608/2000 - Loss: 73.921
[tensor(2.4203), tensor([[1.0192, 1.0592]])]
Iter 609/2000 - Loss: 73.916
[tensor(2.4217), tensor([[1.0193, 1.0595]])]
Iter 610/2000 - Loss: 73.912
[tensor(2.4232), tensor([[1.0194, 1.0597]])]
Iter 611/2000 - Loss: 73.908
[tensor(2.4246), tensor([[1.0195, 1.0600]])]
Iter 612/2000 - Loss: 73.903
[tensor(2

[tensor(2.5549), tensor([[1.0305, 1.0825]])]
Iter 714/2000 - Loss: 73.559
[tensor(2.5560), tensor([[1.0306, 1.0827]])]
Iter 715/2000 - Loss: 73.556
[tensor(2.5572), tensor([[1.0307, 1.0829]])]
Iter 716/2000 - Loss: 73.554
[tensor(2.5583), tensor([[1.0308, 1.0831]])]
Iter 717/2000 - Loss: 73.551
[tensor(2.5595), tensor([[1.0309, 1.0833]])]
Iter 718/2000 - Loss: 73.549
[tensor(2.5606), tensor([[1.0309, 1.0834]])]
Iter 719/2000 - Loss: 73.546
[tensor(2.5617), tensor([[1.0310, 1.0836]])]
Iter 720/2000 - Loss: 73.544
[tensor(2.5628), tensor([[1.0311, 1.0838]])]
Iter 721/2000 - Loss: 73.541
[tensor(2.5640), tensor([[1.0312, 1.0840]])]
Iter 722/2000 - Loss: 73.539
[tensor(2.5651), tensor([[1.0313, 1.0842]])]
Iter 723/2000 - Loss: 73.536
[tensor(2.5662), tensor([[1.0314, 1.0844]])]
Iter 724/2000 - Loss: 73.534
[tensor(2.5673), tensor([[1.0315, 1.0846]])]
Iter 725/2000 - Loss: 73.532
[tensor(2.5685), tensor([[1.0316, 1.0848]])]
Iter 726/2000 - Loss: 73.529
[tensor(2.5696), tensor([[1.0317, 1.08

Iter 829/2000 - Loss: 73.337
[tensor(2.6727), tensor([[1.0400, 1.1023]])]
Iter 830/2000 - Loss: 73.335
[tensor(2.6735), tensor([[1.0401, 1.1025]])]
Iter 831/2000 - Loss: 73.334
[tensor(2.6744), tensor([[1.0402, 1.1026]])]
Iter 832/2000 - Loss: 73.333
[tensor(2.6753), tensor([[1.0402, 1.1028]])]
Iter 833/2000 - Loss: 73.332
[tensor(2.6762), tensor([[1.0403, 1.1029]])]
Iter 834/2000 - Loss: 73.330
[tensor(2.6771), tensor([[1.0404, 1.1031]])]
Iter 835/2000 - Loss: 73.329
[tensor(2.6780), tensor([[1.0405, 1.1032]])]
Iter 836/2000 - Loss: 73.327
[tensor(2.6789), tensor([[1.0405, 1.1034]])]
Iter 837/2000 - Loss: 73.326
[tensor(2.6798), tensor([[1.0406, 1.1035]])]
Iter 838/2000 - Loss: 73.325
[tensor(2.6806), tensor([[1.0407, 1.1037]])]
Iter 839/2000 - Loss: 73.323
[tensor(2.6815), tensor([[1.0407, 1.1038]])]
Iter 840/2000 - Loss: 73.322
[tensor(2.6824), tensor([[1.0408, 1.1039]])]
Iter 841/2000 - Loss: 73.320
[tensor(2.6833), tensor([[1.0409, 1.1041]])]
Iter 842/2000 - Loss: 73.319
[tensor(2

[tensor(2.7690), tensor([[1.0476, 1.1182]])]
Iter 952/2000 - Loss: 73.207
[tensor(2.7697), tensor([[1.0477, 1.1183]])]
Iter 953/2000 - Loss: 73.207
[tensor(2.7704), tensor([[1.0477, 1.1184]])]
Iter 954/2000 - Loss: 73.206
[tensor(2.7711), tensor([[1.0478, 1.1186]])]
Iter 955/2000 - Loss: 73.205
[tensor(2.7718), tensor([[1.0478, 1.1187]])]
Iter 956/2000 - Loss: 73.205
[tensor(2.7725), tensor([[1.0479, 1.1188]])]
Iter 957/2000 - Loss: 73.204
[tensor(2.7732), tensor([[1.0479, 1.1189]])]
Iter 958/2000 - Loss: 73.203
[tensor(2.7739), tensor([[1.0480, 1.1190]])]
Iter 959/2000 - Loss: 73.202
[tensor(2.7745), tensor([[1.0480, 1.1191]])]
Iter 960/2000 - Loss: 73.201
[tensor(2.7752), tensor([[1.0481, 1.1192]])]
Iter 961/2000 - Loss: 73.201
[tensor(2.7759), tensor([[1.0481, 1.1193]])]
Iter 962/2000 - Loss: 73.200
[tensor(2.7766), tensor([[1.0482, 1.1194]])]
Iter 963/2000 - Loss: 73.199
[tensor(2.7772), tensor([[1.0482, 1.1196]])]
Iter 964/2000 - Loss: 73.199
[tensor(2.7779), tensor([[1.0483, 1.11

Iter 1070/2000 - Loss: 73.140
[tensor(2.8416), tensor([[1.0532, 1.1300]])]
Iter 1071/2000 - Loss: 73.140
[tensor(2.8422), tensor([[1.0532, 1.1301]])]
Iter 1072/2000 - Loss: 73.139
[tensor(2.8427), tensor([[1.0533, 1.1302]])]
Iter 1073/2000 - Loss: 73.139
[tensor(2.8432), tensor([[1.0533, 1.1302]])]
Iter 1074/2000 - Loss: 73.139
[tensor(2.8438), tensor([[1.0533, 1.1303]])]
Iter 1075/2000 - Loss: 73.138
[tensor(2.8443), tensor([[1.0534, 1.1304]])]
Iter 1076/2000 - Loss: 73.138
[tensor(2.8448), tensor([[1.0534, 1.1305]])]
Iter 1077/2000 - Loss: 73.137
[tensor(2.8453), tensor([[1.0535, 1.1306]])]
Iter 1078/2000 - Loss: 73.137
[tensor(2.8459), tensor([[1.0535, 1.1307]])]
Iter 1079/2000 - Loss: 73.137
[tensor(2.8464), tensor([[1.0535, 1.1307]])]
Iter 1080/2000 - Loss: 73.136
[tensor(2.8469), tensor([[1.0536, 1.1308]])]
Iter 1081/2000 - Loss: 73.136
[tensor(2.8474), tensor([[1.0536, 1.1309]])]
Iter 1082/2000 - Loss: 73.136
[tensor(2.8479), tensor([[1.0537, 1.1310]])]
Iter 1083/2000 - Loss: 73

Iter 1190/2000 - Loss: 73.104
[tensor(2.8977), tensor([[1.0574, 1.1389]])]
Iter 1191/2000 - Loss: 73.103
[tensor(2.8981), tensor([[1.0575, 1.1390]])]
Iter 1192/2000 - Loss: 73.103
[tensor(2.8985), tensor([[1.0575, 1.1391]])]
Iter 1193/2000 - Loss: 73.103
[tensor(2.8989), tensor([[1.0575, 1.1391]])]
Iter 1194/2000 - Loss: 73.103
[tensor(2.8993), tensor([[1.0575, 1.1392]])]
Iter 1195/2000 - Loss: 73.102
[tensor(2.8997), tensor([[1.0576, 1.1393]])]
Iter 1196/2000 - Loss: 73.102
[tensor(2.9002), tensor([[1.0576, 1.1393]])]
Iter 1197/2000 - Loss: 73.102
[tensor(2.9006), tensor([[1.0576, 1.1394]])]
Iter 1198/2000 - Loss: 73.102
[tensor(2.9009), tensor([[1.0577, 1.1395]])]
Iter 1199/2000 - Loss: 73.102
[tensor(2.9013), tensor([[1.0577, 1.1395]])]
Iter 1200/2000 - Loss: 73.101
[tensor(2.9017), tensor([[1.0577, 1.1396]])]
Iter 1201/2000 - Loss: 73.101
[tensor(2.9021), tensor([[1.0578, 1.1396]])]
Iter 1202/2000 - Loss: 73.101
[tensor(2.9025), tensor([[1.0578, 1.1397]])]
Iter 1203/2000 - Loss: 73

Iter 1304/2000 - Loss: 73.085
[tensor(2.9383), tensor([[1.0604, 1.1454]])]
Iter 1305/2000 - Loss: 73.085
[tensor(2.9386), tensor([[1.0605, 1.1454]])]
Iter 1306/2000 - Loss: 73.085
[tensor(2.9389), tensor([[1.0605, 1.1454]])]
Iter 1307/2000 - Loss: 73.085
[tensor(2.9392), tensor([[1.0605, 1.1455]])]
Iter 1308/2000 - Loss: 73.085
[tensor(2.9395), tensor([[1.0605, 1.1455]])]
Iter 1309/2000 - Loss: 73.084
[tensor(2.9398), tensor([[1.0606, 1.1456]])]
Iter 1310/2000 - Loss: 73.085
[tensor(2.9401), tensor([[1.0606, 1.1456]])]
Iter 1311/2000 - Loss: 73.084
[tensor(2.9404), tensor([[1.0606, 1.1457]])]
Iter 1312/2000 - Loss: 73.084
[tensor(2.9407), tensor([[1.0606, 1.1457]])]
Iter 1313/2000 - Loss: 73.084
[tensor(2.9410), tensor([[1.0606, 1.1458]])]
Iter 1314/2000 - Loss: 73.084
[tensor(2.9413), tensor([[1.0607, 1.1458]])]
Iter 1315/2000 - Loss: 73.084
[tensor(2.9416), tensor([[1.0607, 1.1459]])]
Iter 1316/2000 - Loss: 73.084
[tensor(2.9419), tensor([[1.0607, 1.1459]])]
Iter 1317/2000 - Loss: 73

Iter 1424/2000 - Loss: 73.075
[tensor(2.9702), tensor([[1.0628, 1.1504]])]
Iter 1425/2000 - Loss: 73.075
[tensor(2.9705), tensor([[1.0628, 1.1504]])]
Iter 1426/2000 - Loss: 73.075
[tensor(2.9707), tensor([[1.0628, 1.1504]])]
Iter 1427/2000 - Loss: 73.075
[tensor(2.9709), tensor([[1.0628, 1.1505]])]
Iter 1428/2000 - Loss: 73.075
[tensor(2.9711), tensor([[1.0629, 1.1505]])]
Iter 1429/2000 - Loss: 73.075
[tensor(2.9714), tensor([[1.0629, 1.1505]])]
Iter 1430/2000 - Loss: 73.075
[tensor(2.9716), tensor([[1.0629, 1.1506]])]
Iter 1431/2000 - Loss: 73.075
[tensor(2.9718), tensor([[1.0629, 1.1506]])]
Iter 1432/2000 - Loss: 73.074
[tensor(2.9720), tensor([[1.0629, 1.1506]])]
Iter 1433/2000 - Loss: 73.075
[tensor(2.9722), tensor([[1.0629, 1.1507]])]
Iter 1434/2000 - Loss: 73.075
[tensor(2.9725), tensor([[1.0630, 1.1507]])]
Iter 1435/2000 - Loss: 73.075
[tensor(2.9727), tensor([[1.0630, 1.1508]])]
Iter 1436/2000 - Loss: 73.074
[tensor(2.9729), tensor([[1.0630, 1.1508]])]
Iter 1437/2000 - Loss: 73

[tensor(2.9941), tensor([[1.0645, 1.1541]])]
Iter 1549/2000 - Loss: 73.070
[tensor(2.9943), tensor([[1.0645, 1.1541]])]
Iter 1550/2000 - Loss: 73.070
[tensor(2.9945), tensor([[1.0646, 1.1541]])]
Iter 1551/2000 - Loss: 73.070
[tensor(2.9946), tensor([[1.0646, 1.1542]])]
Iter 1552/2000 - Loss: 73.070
[tensor(2.9948), tensor([[1.0646, 1.1542]])]
Iter 1553/2000 - Loss: 73.070
[tensor(2.9949), tensor([[1.0646, 1.1542]])]
Iter 1554/2000 - Loss: 73.070
[tensor(2.9951), tensor([[1.0646, 1.1542]])]
Iter 1555/2000 - Loss: 73.070
[tensor(2.9953), tensor([[1.0646, 1.1543]])]
Iter 1556/2000 - Loss: 73.070
[tensor(2.9954), tensor([[1.0646, 1.1543]])]
Iter 1557/2000 - Loss: 73.070
[tensor(2.9956), tensor([[1.0646, 1.1543]])]
Iter 1558/2000 - Loss: 73.070
[tensor(2.9957), tensor([[1.0647, 1.1543]])]
Iter 1559/2000 - Loss: 73.070
[tensor(2.9959), tensor([[1.0647, 1.1544]])]
Iter 1560/2000 - Loss: 73.070
[tensor(2.9960), tensor([[1.0647, 1.1544]])]
Iter 1561/2000 - Loss: 73.070
[tensor(2.9962), tensor([

Iter 1672/2000 - Loss: 73.068
[tensor(3.0109), tensor([[1.0658, 1.1567]])]
Iter 1673/2000 - Loss: 73.068
[tensor(3.0110), tensor([[1.0658, 1.1567]])]
Iter 1674/2000 - Loss: 73.068
[tensor(3.0111), tensor([[1.0658, 1.1567]])]
Iter 1675/2000 - Loss: 73.068
[tensor(3.0113), tensor([[1.0658, 1.1567]])]
Iter 1676/2000 - Loss: 73.068
[tensor(3.0114), tensor([[1.0658, 1.1568]])]
Iter 1677/2000 - Loss: 73.067
[tensor(3.0115), tensor([[1.0658, 1.1568]])]
Iter 1678/2000 - Loss: 73.067
[tensor(3.0116), tensor([[1.0658, 1.1568]])]
Iter 1679/2000 - Loss: 73.068
[tensor(3.0117), tensor([[1.0658, 1.1568]])]
Iter 1680/2000 - Loss: 73.068
[tensor(3.0118), tensor([[1.0658, 1.1568]])]
Iter 1681/2000 - Loss: 73.068
[tensor(3.0119), tensor([[1.0658, 1.1568]])]
Iter 1682/2000 - Loss: 73.068
[tensor(3.0120), tensor([[1.0658, 1.1569]])]
Iter 1683/2000 - Loss: 73.068
[tensor(3.0121), tensor([[1.0659, 1.1569]])]
Iter 1684/2000 - Loss: 73.068
[tensor(3.0122), tensor([[1.0659, 1.1569]])]
Iter 1685/2000 - Loss: 73

[tensor(3.0223), tensor([[1.0666, 1.1585]])]
Iter 1796/2000 - Loss: 73.067
[tensor(3.0224), tensor([[1.0666, 1.1585]])]
Iter 1797/2000 - Loss: 73.067
[tensor(3.0224), tensor([[1.0666, 1.1585]])]
Iter 1798/2000 - Loss: 73.067
[tensor(3.0225), tensor([[1.0666, 1.1585]])]
Iter 1799/2000 - Loss: 73.067
[tensor(3.0226), tensor([[1.0666, 1.1585]])]
Iter 1800/2000 - Loss: 73.067
[tensor(3.0227), tensor([[1.0666, 1.1585]])]
Iter 1801/2000 - Loss: 73.067
[tensor(3.0227), tensor([[1.0666, 1.1585]])]
Iter 1802/2000 - Loss: 73.067
[tensor(3.0228), tensor([[1.0666, 1.1585]])]
Iter 1803/2000 - Loss: 73.067
[tensor(3.0229), tensor([[1.0666, 1.1585]])]
Iter 1804/2000 - Loss: 73.067
[tensor(3.0230), tensor([[1.0666, 1.1586]])]
Iter 1805/2000 - Loss: 73.067
[tensor(3.0230), tensor([[1.0666, 1.1586]])]
Iter 1806/2000 - Loss: 73.067
[tensor(3.0231), tensor([[1.0666, 1.1586]])]
Iter 1807/2000 - Loss: 73.067
[tensor(3.0232), tensor([[1.0666, 1.1586]])]
Iter 1808/2000 - Loss: 73.067
[tensor(3.0233), tensor([

[tensor(3.0299), tensor([[1.0671, 1.1596]])]
Iter 1921/2000 - Loss: 73.066
[tensor(3.0300), tensor([[1.0672, 1.1596]])]
Iter 1922/2000 - Loss: 73.066
[tensor(3.0300), tensor([[1.0672, 1.1596]])]
Iter 1923/2000 - Loss: 73.066
[tensor(3.0301), tensor([[1.0672, 1.1596]])]
Iter 1924/2000 - Loss: 73.066
[tensor(3.0301), tensor([[1.0672, 1.1596]])]
Iter 1925/2000 - Loss: 73.066
[tensor(3.0302), tensor([[1.0672, 1.1597]])]
Iter 1926/2000 - Loss: 73.066
[tensor(3.0302), tensor([[1.0672, 1.1597]])]
Iter 1927/2000 - Loss: 73.066
[tensor(3.0303), tensor([[1.0672, 1.1597]])]
Iter 1928/2000 - Loss: 73.066
[tensor(3.0303), tensor([[1.0672, 1.1597]])]
Iter 1929/2000 - Loss: 73.066
[tensor(3.0304), tensor([[1.0672, 1.1597]])]
Iter 1930/2000 - Loss: 73.066
[tensor(3.0304), tensor([[1.0672, 1.1597]])]
Iter 1931/2000 - Loss: 73.066
[tensor(3.0305), tensor([[1.0672, 1.1597]])]
Iter 1932/2000 - Loss: 73.066
[tensor(3.0305), tensor([[1.0672, 1.1597]])]
Iter 1933/2000 - Loss: 73.066
[tensor(3.0305), tensor([

In [None]:
class MertonBirthKernel(nn.Module):

    def __init__(self, **kwargs):
        nn.Module.__init__(self)
        self.locations_dimension = kwargs.get("locations_dimension")
        self.define_deep_parameters()
        
    def define_kernel(self):
        kernel = ScaleKernel(RBFKernel(ard_num_dims=self.locations_dimension, requires_grad=True),
                             requires_grad=True) + white_noise_kernel()

        kernel_hypers = {"raw_outputscale": torch.tensor(self.kernel_sigma),
                         "base_kernel.raw_lengthscale": torch.tensor(self.kernel_lenght_scales)}

        kernel.kernels[0].initialize(**kernel_hypers)
        kernel_eval = lambda locations: kernel(locations, locations).evaluate().type(torch.float64)

        return kernel, kernel_eval

    def forward(self, locations_history):
        kernel, kernel_eval = self.define_kernel(kernel_sigma,kernel_lenght_scales)
        covariance_diffusion_history = kernel_eval(locations_history)
        return covariance_diffusion_history

    def define_deep_parameters(self):
        self.kernel_sigma = nn.Parameter(torch.Tensor([1.]))
        self.kernel_lenght_scales = nn.Parameter(torch.Tensor([20.,30.]))