In [36]:
import torch
import gpytorch
import numpy as np
from torch import nn
from torch import matmul as m
from gpytorch.kernels import RBFKernel, ScaleKernel
from torch.distributions import Normal, Poisson, MultivariateNormal


from torch.utils.tensorboard import SummaryWriter
from deep_fields.models.utils.basic_setups import create_dir_and_writer
from deep_fields.models.gaussian_processes.gaussian_processes import multivariate_normal, white_noise_kernel

In [80]:
locations_dimension = 2
prior_locations_mean =  0.
prior_locations_std = 6.69
number_of_realizations = 100
birth_intensity = 2.
kernel_sigma = 0.1
kernel_lenght_scales = [1., 2.]

locations_prior = Normal(torch.full((locations_dimension,), prior_locations_mean),
                         torch.full((locations_dimension,), prior_locations_std))
locations = locations_prior.sample(sample_shape=(100,))

In [242]:
class GPModel(nn.Module):
    
    def __init__(self,kernel_sigma,kernel_lenght_scales):
        nn.Module.__init__(self)        
        kernel = ScaleKernel(RBFKernel(ard_num_dims=locations_dimension, requires_grad=True),
                     requires_grad=True) + white_noise_kernel()

        kernel_hypers = {"raw_outputscale": torch.tensor(kernel_sigma),
                         "base_kernel.raw_lengthscale": torch.tensor(kernel_lenght_scales)}

        kernel.kernels[0].initialize(**kernel_hypers)    
        self.mean_module = gpytorch.means.ConstantMean()
        
        self.covar_module = kernel

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x).evaluate()
        return mean_x,covar_x
    
def log_likelihood(covar_x,f):
    covar_dim =  covar_x.shape[0]
    covar_x_inverse = covar_x.inverse()
    covar_det = (covar_x).det()
    
    log_inverse = m(f,m(covar_x_inverse,f))
    log_det = torch.log(covar_det)
    log_probability = -.5*(log_det + log_inverse)
    return log_probability

In [287]:
gp0 = GPModel(kernel_sigma = 3.,kernel_lenght_scales = [1., 1.])
mean_x, covar_x = gp0(locations)
f = MultivariateNormal(mean_x,covar_x).sample()
log_likelihood(covar_x,f)

tensor(-62.5735, grad_fn=<MulBackward0>)

In [288]:
gp1 = GPModel(kernel_sigma=1.,kernel_lenght_scales =[.1, .1])
mean_x, covar_x = gp1(locations)
log_likelihood(covar_x,f)

tensor(-100.9959, grad_fn=<MulBackward0>)

In [289]:
gp1 = GPModel(kernel_sigma=.1,kernel_lenght_scales=[.1, .1])
training_iter = 2000
gp1.train()
# Use the adam optimizer
optimizer = torch.optim.Adam(gp1.parameters(), lr=0.01)  # Includes GaussianLikelihood parameters

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    mean_x, covar_x = gp1(locations)
    # Calc loss and backprop gradients
    loss = -log_likelihood(covar_x,f)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (
        i + 1, training_iter, loss.item()
    ))
    print([a.data for a in gp1.covar_module.parameters()])
    optimizer.step()

Iter 1/2000 - Loss: 149.982
[tensor(0.1000), tensor([[0.1000, 0.1000]])]
Iter 2/2000 - Loss: 148.564
[tensor(0.1100), tensor([[0.1100, 0.1100]])]
Iter 3/2000 - Loss: 147.164
[tensor(0.1200), tensor([[0.1200, 0.1200]])]
Iter 4/2000 - Loss: 145.783
[tensor(0.1300), tensor([[0.1300, 0.1300]])]
Iter 5/2000 - Loss: 144.420
[tensor(0.1400), tensor([[0.1400, 0.1400]])]
Iter 6/2000 - Loss: 143.076
[tensor(0.1499), tensor([[0.1500, 0.1500]])]
Iter 7/2000 - Loss: 141.751
[tensor(0.1599), tensor([[0.1599, 0.1599]])]
Iter 8/2000 - Loss: 140.445
[tensor(0.1698), tensor([[0.1699, 0.1699]])]
Iter 9/2000 - Loss: 139.158
[tensor(0.1797), tensor([[0.1798, 0.1798]])]
Iter 10/2000 - Loss: 137.890
[tensor(0.1896), tensor([[0.1897, 0.1898]])]
Iter 11/2000 - Loss: 136.640
[tensor(0.1995), tensor([[0.1996, 0.1997]])]
Iter 12/2000 - Loss: 135.410
[tensor(0.2093), tensor([[0.2095, 0.2096]])]
Iter 13/2000 - Loss: 134.200
[tensor(0.2192), tensor([[0.2194, 0.2195]])]
Iter 14/2000 - Loss: 133.008
[tensor(0.2289), t

[tensor(1.0933), tensor([[0.7638, 0.8070]])]
Iter 132/2000 - Loss: 79.375
[tensor(1.0988), tensor([[0.7645, 0.8076]])]
Iter 133/2000 - Loss: 79.233
[tensor(1.1043), tensor([[0.7653, 0.8082]])]
Iter 134/2000 - Loss: 79.092
[tensor(1.1098), tensor([[0.7660, 0.8088]])]
Iter 135/2000 - Loss: 78.954
[tensor(1.1153), tensor([[0.7668, 0.8094]])]
Iter 136/2000 - Loss: 78.817
[tensor(1.1208), tensor([[0.7676, 0.8099]])]
Iter 137/2000 - Loss: 78.681
[tensor(1.1262), tensor([[0.7684, 0.8105]])]
Iter 138/2000 - Loss: 78.547
[tensor(1.1316), tensor([[0.7692, 0.8111]])]
Iter 139/2000 - Loss: 78.414
[tensor(1.1370), tensor([[0.7700, 0.8116]])]
Iter 140/2000 - Loss: 78.284
[tensor(1.1423), tensor([[0.7708, 0.8122]])]
Iter 141/2000 - Loss: 78.154
[tensor(1.1477), tensor([[0.7717, 0.8128]])]
Iter 142/2000 - Loss: 78.026
[tensor(1.1530), tensor([[0.7725, 0.8133]])]
Iter 143/2000 - Loss: 77.899
[tensor(1.1583), tensor([[0.7733, 0.8139]])]
Iter 144/2000 - Loss: 77.774
[tensor(1.1635), tensor([[0.7742, 0.81

[tensor(1.7353), tensor([[0.8667, 0.8778]])]
Iter 284/2000 - Loss: 68.220
[tensor(1.7386), tensor([[0.8672, 0.8781]])]
Iter 285/2000 - Loss: 68.185
[tensor(1.7418), tensor([[0.8677, 0.8785]])]
Iter 286/2000 - Loss: 68.150
[tensor(1.7450), tensor([[0.8682, 0.8788]])]
Iter 287/2000 - Loss: 68.115
[tensor(1.7483), tensor([[0.8687, 0.8791]])]
Iter 288/2000 - Loss: 68.080
[tensor(1.7515), tensor([[0.8691, 0.8795]])]
Iter 289/2000 - Loss: 68.046
[tensor(1.7547), tensor([[0.8696, 0.8798]])]
Iter 290/2000 - Loss: 68.013
[tensor(1.7579), tensor([[0.8701, 0.8801]])]
Iter 291/2000 - Loss: 67.979
[tensor(1.7611), tensor([[0.8706, 0.8805]])]
Iter 292/2000 - Loss: 67.946
[tensor(1.7642), tensor([[0.8710, 0.8808]])]
Iter 293/2000 - Loss: 67.913
[tensor(1.7674), tensor([[0.8715, 0.8811]])]
Iter 294/2000 - Loss: 67.879
[tensor(1.7705), tensor([[0.8720, 0.8815]])]
Iter 295/2000 - Loss: 67.847
[tensor(1.7737), tensor([[0.8724, 0.8818]])]
Iter 296/2000 - Loss: 67.814
[tensor(1.7768), tensor([[0.8729, 0.88

Iter 441/2000 - Loss: 64.844
[tensor(2.1542), tensor([[0.9265, 0.9197]])]
Iter 442/2000 - Loss: 64.831
[tensor(2.1563), tensor([[0.9268, 0.9199]])]
Iter 443/2000 - Loss: 64.819
[tensor(2.1585), tensor([[0.9271, 0.9201]])]
Iter 444/2000 - Loss: 64.807
[tensor(2.1606), tensor([[0.9274, 0.9203]])]
Iter 445/2000 - Loss: 64.795
[tensor(2.1628), tensor([[0.9277, 0.9205]])]
Iter 446/2000 - Loss: 64.783
[tensor(2.1649), tensor([[0.9280, 0.9207]])]
Iter 447/2000 - Loss: 64.770
[tensor(2.1671), tensor([[0.9283, 0.9209]])]
Iter 448/2000 - Loss: 64.759
[tensor(2.1692), tensor([[0.9286, 0.9211]])]
Iter 449/2000 - Loss: 64.746
[tensor(2.1713), tensor([[0.9288, 0.9213]])]
Iter 450/2000 - Loss: 64.735
[tensor(2.1734), tensor([[0.9291, 0.9215]])]
Iter 451/2000 - Loss: 64.723
[tensor(2.1756), tensor([[0.9294, 0.9217]])]
Iter 452/2000 - Loss: 64.711
[tensor(2.1777), tensor([[0.9297, 0.9219]])]
Iter 453/2000 - Loss: 64.700
[tensor(2.1798), tensor([[0.9300, 0.9221]])]
Iter 454/2000 - Loss: 64.688
[tensor(2

Iter 594/2000 - Loss: 63.575
[tensor(2.4344), tensor([[0.9632, 0.9454]])]
Iter 595/2000 - Loss: 63.569
[tensor(2.4360), tensor([[0.9634, 0.9455]])]
Iter 596/2000 - Loss: 63.563
[tensor(2.4375), tensor([[0.9636, 0.9456]])]
Iter 597/2000 - Loss: 63.559
[tensor(2.4391), tensor([[0.9638, 0.9458]])]
Iter 598/2000 - Loss: 63.554
[tensor(2.4406), tensor([[0.9640, 0.9459]])]
Iter 599/2000 - Loss: 63.548
[tensor(2.4421), tensor([[0.9642, 0.9461]])]
Iter 600/2000 - Loss: 63.543
[tensor(2.4437), tensor([[0.9644, 0.9462]])]
Iter 601/2000 - Loss: 63.538
[tensor(2.4452), tensor([[0.9646, 0.9463]])]
Iter 602/2000 - Loss: 63.532
[tensor(2.4467), tensor([[0.9648, 0.9465]])]
Iter 603/2000 - Loss: 63.528
[tensor(2.4482), tensor([[0.9650, 0.9466]])]
Iter 604/2000 - Loss: 63.523
[tensor(2.4498), tensor([[0.9652, 0.9467]])]
Iter 605/2000 - Loss: 63.518
[tensor(2.4513), tensor([[0.9653, 0.9469]])]
Iter 606/2000 - Loss: 63.513
[tensor(2.4528), tensor([[0.9655, 0.9470]])]
Iter 607/2000 - Loss: 63.507
[tensor(2

[tensor(2.6488), tensor([[0.9896, 0.9638]])]
Iter 758/2000 - Loss: 62.985
[tensor(2.6499), tensor([[0.9897, 0.9639]])]
Iter 759/2000 - Loss: 62.983
[tensor(2.6510), tensor([[0.9899, 0.9640]])]
Iter 760/2000 - Loss: 62.981
[tensor(2.6521), tensor([[0.9900, 0.9641]])]
Iter 761/2000 - Loss: 62.978
[tensor(2.6532), tensor([[0.9901, 0.9642]])]
Iter 762/2000 - Loss: 62.976
[tensor(2.6543), tensor([[0.9903, 0.9643]])]
Iter 763/2000 - Loss: 62.974
[tensor(2.6554), tensor([[0.9904, 0.9644]])]
Iter 764/2000 - Loss: 62.972
[tensor(2.6565), tensor([[0.9905, 0.9644]])]
Iter 765/2000 - Loss: 62.969
[tensor(2.6576), tensor([[0.9906, 0.9645]])]
Iter 766/2000 - Loss: 62.967
[tensor(2.6587), tensor([[0.9908, 0.9646]])]
Iter 767/2000 - Loss: 62.965
[tensor(2.6598), tensor([[0.9909, 0.9647]])]
Iter 768/2000 - Loss: 62.963
[tensor(2.6609), tensor([[0.9910, 0.9648]])]
Iter 769/2000 - Loss: 62.960
[tensor(2.6620), tensor([[0.9912, 0.9649]])]
Iter 770/2000 - Loss: 62.959
[tensor(2.6630), tensor([[0.9913, 0.96

Iter 927/2000 - Loss: 62.718
[tensor(2.8091), tensor([[1.0084, 0.9769]])]
Iter 928/2000 - Loss: 62.716
[tensor(2.8099), tensor([[1.0085, 0.9770]])]
Iter 929/2000 - Loss: 62.716
[tensor(2.8107), tensor([[1.0086, 0.9771]])]
Iter 930/2000 - Loss: 62.715
[tensor(2.8115), tensor([[1.0086, 0.9771]])]
Iter 931/2000 - Loss: 62.714
[tensor(2.8123), tensor([[1.0087, 0.9772]])]
Iter 932/2000 - Loss: 62.713
[tensor(2.8130), tensor([[1.0088, 0.9773]])]
Iter 933/2000 - Loss: 62.711
[tensor(2.8138), tensor([[1.0089, 0.9773]])]
Iter 934/2000 - Loss: 62.711
[tensor(2.8146), tensor([[1.0090, 0.9774]])]
Iter 935/2000 - Loss: 62.711
[tensor(2.8154), tensor([[1.0091, 0.9774]])]
Iter 936/2000 - Loss: 62.709
[tensor(2.8162), tensor([[1.0092, 0.9775]])]
Iter 937/2000 - Loss: 62.707
[tensor(2.8169), tensor([[1.0093, 0.9776]])]
Iter 938/2000 - Loss: 62.707
[tensor(2.8177), tensor([[1.0093, 0.9776]])]
Iter 939/2000 - Loss: 62.706
[tensor(2.8185), tensor([[1.0094, 0.9777]])]
Iter 940/2000 - Loss: 62.705
[tensor(2

[tensor(2.8886), tensor([[1.0174, 0.9833]])]
Iter 1040/2000 - Loss: 62.628
[tensor(2.8892), tensor([[1.0175, 0.9833]])]
Iter 1041/2000 - Loss: 62.629
[tensor(2.8898), tensor([[1.0175, 0.9834]])]
Iter 1042/2000 - Loss: 62.628
[tensor(2.8905), tensor([[1.0176, 0.9834]])]
Iter 1043/2000 - Loss: 62.628
[tensor(2.8911), tensor([[1.0177, 0.9835]])]
Iter 1044/2000 - Loss: 62.626
[tensor(2.8917), tensor([[1.0178, 0.9835]])]
Iter 1045/2000 - Loss: 62.626
[tensor(2.8923), tensor([[1.0178, 0.9836]])]
Iter 1046/2000 - Loss: 62.625
[tensor(2.8930), tensor([[1.0179, 0.9836]])]
Iter 1047/2000 - Loss: 62.625
[tensor(2.8936), tensor([[1.0180, 0.9837]])]
Iter 1048/2000 - Loss: 62.624
[tensor(2.8942), tensor([[1.0180, 0.9837]])]
Iter 1049/2000 - Loss: 62.623
[tensor(2.8948), tensor([[1.0181, 0.9838]])]
Iter 1050/2000 - Loss: 62.622
[tensor(2.8955), tensor([[1.0182, 0.9838]])]
Iter 1051/2000 - Loss: 62.622
[tensor(2.8961), tensor([[1.0183, 0.9839]])]
Iter 1052/2000 - Loss: 62.622
[tensor(2.8967), tensor([

Iter 1203/2000 - Loss: 62.562
[tensor(2.9764), tensor([[1.0272, 0.9901]])]
Iter 1204/2000 - Loss: 62.562
[tensor(2.9769), tensor([[1.0272, 0.9901]])]
Iter 1205/2000 - Loss: 62.561
[tensor(2.9773), tensor([[1.0273, 0.9902]])]
Iter 1206/2000 - Loss: 62.561
[tensor(2.9777), tensor([[1.0273, 0.9902]])]
Iter 1207/2000 - Loss: 62.562
[tensor(2.9782), tensor([[1.0274, 0.9902]])]
Iter 1208/2000 - Loss: 62.561
[tensor(2.9786), tensor([[1.0274, 0.9903]])]
Iter 1209/2000 - Loss: 62.560
[tensor(2.9791), tensor([[1.0275, 0.9903]])]
Iter 1210/2000 - Loss: 62.561
[tensor(2.9795), tensor([[1.0275, 0.9903]])]
Iter 1211/2000 - Loss: 62.560
[tensor(2.9800), tensor([[1.0276, 0.9904]])]
Iter 1212/2000 - Loss: 62.560
[tensor(2.9804), tensor([[1.0276, 0.9904]])]
Iter 1213/2000 - Loss: 62.559
[tensor(2.9808), tensor([[1.0277, 0.9905]])]
Iter 1214/2000 - Loss: 62.560
[tensor(2.9813), tensor([[1.0277, 0.9905]])]
Iter 1215/2000 - Loss: 62.559
[tensor(2.9817), tensor([[1.0277, 0.9905]])]
Iter 1216/2000 - Loss: 62

Iter 1315/2000 - Loss: 62.540
[tensor(3.0209), tensor([[1.0321, 0.9935]])]
Iter 1316/2000 - Loss: 62.540
[tensor(3.0212), tensor([[1.0321, 0.9935]])]
Iter 1317/2000 - Loss: 62.539
[tensor(3.0216), tensor([[1.0321, 0.9936]])]
Iter 1318/2000 - Loss: 62.539
[tensor(3.0219), tensor([[1.0322, 0.9936]])]
Iter 1319/2000 - Loss: 62.539
[tensor(3.0223), tensor([[1.0322, 0.9936]])]
Iter 1320/2000 - Loss: 62.538
[tensor(3.0226), tensor([[1.0322, 0.9936]])]
Iter 1321/2000 - Loss: 62.540
[tensor(3.0230), tensor([[1.0323, 0.9937]])]
Iter 1322/2000 - Loss: 62.539
[tensor(3.0233), tensor([[1.0323, 0.9937]])]
Iter 1323/2000 - Loss: 62.539
[tensor(3.0236), tensor([[1.0324, 0.9937]])]
Iter 1324/2000 - Loss: 62.539
[tensor(3.0240), tensor([[1.0324, 0.9937]])]
Iter 1325/2000 - Loss: 62.539
[tensor(3.0243), tensor([[1.0324, 0.9938]])]
Iter 1326/2000 - Loss: 62.538
[tensor(3.0247), tensor([[1.0325, 0.9938]])]
Iter 1327/2000 - Loss: 62.538
[tensor(3.0250), tensor([[1.0325, 0.9938]])]
Iter 1328/2000 - Loss: 62

Iter 1477/2000 - Loss: 62.524
[tensor(3.0679), tensor([[1.0371, 0.9971]])]
Iter 1478/2000 - Loss: 62.524
[tensor(3.0682), tensor([[1.0372, 0.9971]])]
Iter 1479/2000 - Loss: 62.524
[tensor(3.0684), tensor([[1.0372, 0.9971]])]
Iter 1480/2000 - Loss: 62.524
[tensor(3.0686), tensor([[1.0372, 0.9971]])]
Iter 1481/2000 - Loss: 62.524
[tensor(3.0689), tensor([[1.0372, 0.9972]])]
Iter 1482/2000 - Loss: 62.524
[tensor(3.0691), tensor([[1.0372, 0.9972]])]
Iter 1483/2000 - Loss: 62.524
[tensor(3.0693), tensor([[1.0373, 0.9972]])]
Iter 1484/2000 - Loss: 62.524
[tensor(3.0696), tensor([[1.0373, 0.9972]])]
Iter 1485/2000 - Loss: 62.523
[tensor(3.0698), tensor([[1.0373, 0.9972]])]
Iter 1486/2000 - Loss: 62.523
[tensor(3.0700), tensor([[1.0373, 0.9973]])]
Iter 1487/2000 - Loss: 62.523
[tensor(3.0703), tensor([[1.0374, 0.9973]])]
Iter 1488/2000 - Loss: 62.524
[tensor(3.0705), tensor([[1.0374, 0.9973]])]
Iter 1489/2000 - Loss: 62.523
[tensor(3.0707), tensor([[1.0374, 0.9973]])]
Iter 1490/2000 - Loss: 62

Iter 1592/2000 - Loss: 62.518
[tensor(3.0916), tensor([[1.0397, 0.9988]])]
Iter 1593/2000 - Loss: 62.519
[tensor(3.0918), tensor([[1.0397, 0.9989]])]
Iter 1594/2000 - Loss: 62.519
[tensor(3.0920), tensor([[1.0397, 0.9989]])]
Iter 1595/2000 - Loss: 62.518
[tensor(3.0921), tensor([[1.0397, 0.9989]])]
Iter 1596/2000 - Loss: 62.519
[tensor(3.0923), tensor([[1.0397, 0.9989]])]
Iter 1597/2000 - Loss: 62.519
[tensor(3.0925), tensor([[1.0398, 0.9989]])]
Iter 1598/2000 - Loss: 62.518
[tensor(3.0927), tensor([[1.0398, 0.9989]])]
Iter 1599/2000 - Loss: 62.519
[tensor(3.0928), tensor([[1.0398, 0.9989]])]
Iter 1600/2000 - Loss: 62.518
[tensor(3.0930), tensor([[1.0398, 0.9989]])]
Iter 1601/2000 - Loss: 62.519
[tensor(3.0932), tensor([[1.0398, 0.9990]])]
Iter 1602/2000 - Loss: 62.519
[tensor(3.0934), tensor([[1.0399, 0.9990]])]
Iter 1603/2000 - Loss: 62.518
[tensor(3.0935), tensor([[1.0399, 0.9990]])]
Iter 1604/2000 - Loss: 62.518
[tensor(3.0937), tensor([[1.0399, 0.9990]])]
Iter 1605/2000 - Loss: 62

[tensor(3.1093), tensor([[1.0416, 1.0002]])]
Iter 1710/2000 - Loss: 62.516
[tensor(3.1094), tensor([[1.0416, 1.0002]])]
Iter 1711/2000 - Loss: 62.516
[tensor(3.1096), tensor([[1.0416, 1.0002]])]
Iter 1712/2000 - Loss: 62.516
[tensor(3.1097), tensor([[1.0416, 1.0002]])]
Iter 1713/2000 - Loss: 62.516
[tensor(3.1098), tensor([[1.0416, 1.0002]])]
Iter 1714/2000 - Loss: 62.515
[tensor(3.1100), tensor([[1.0416, 1.0002]])]
Iter 1715/2000 - Loss: 62.516
[tensor(3.1101), tensor([[1.0416, 1.0002]])]
Iter 1716/2000 - Loss: 62.516
[tensor(3.1102), tensor([[1.0416, 1.0002]])]
Iter 1717/2000 - Loss: 62.517
[tensor(3.1103), tensor([[1.0417, 1.0002]])]
Iter 1718/2000 - Loss: 62.517
[tensor(3.1105), tensor([[1.0417, 1.0002]])]
Iter 1719/2000 - Loss: 62.516
[tensor(3.1106), tensor([[1.0417, 1.0003]])]
Iter 1720/2000 - Loss: 62.515
[tensor(3.1107), tensor([[1.0417, 1.0003]])]
Iter 1721/2000 - Loss: 62.516
[tensor(3.1108), tensor([[1.0417, 1.0003]])]
Iter 1722/2000 - Loss: 62.516
[tensor(3.1110), tensor([

Iter 1828/2000 - Loss: 62.515
[tensor(3.1222), tensor([[1.0429, 1.0011]])]
Iter 1829/2000 - Loss: 62.515
[tensor(3.1222), tensor([[1.0429, 1.0011]])]
Iter 1830/2000 - Loss: 62.515
[tensor(3.1223), tensor([[1.0429, 1.0012]])]
Iter 1831/2000 - Loss: 62.515
[tensor(3.1224), tensor([[1.0430, 1.0012]])]
Iter 1832/2000 - Loss: 62.515
[tensor(3.1225), tensor([[1.0430, 1.0012]])]
Iter 1833/2000 - Loss: 62.515
[tensor(3.1226), tensor([[1.0430, 1.0012]])]
Iter 1834/2000 - Loss: 62.515
[tensor(3.1227), tensor([[1.0430, 1.0012]])]
Iter 1835/2000 - Loss: 62.515
[tensor(3.1228), tensor([[1.0430, 1.0012]])]
Iter 1836/2000 - Loss: 62.515
[tensor(3.1229), tensor([[1.0430, 1.0012]])]
Iter 1837/2000 - Loss: 62.514
[tensor(3.1230), tensor([[1.0430, 1.0012]])]
Iter 1838/2000 - Loss: 62.515
[tensor(3.1230), tensor([[1.0430, 1.0012]])]
Iter 1839/2000 - Loss: 62.515
[tensor(3.1231), tensor([[1.0430, 1.0012]])]
Iter 1840/2000 - Loss: 62.515
[tensor(3.1232), tensor([[1.0430, 1.0012]])]
Iter 1841/2000 - Loss: 62

Iter 1987/2000 - Loss: 62.513
[tensor(3.1333), tensor([[1.0441, 1.0020]])]
Iter 1988/2000 - Loss: 62.514
[tensor(3.1334), tensor([[1.0441, 1.0020]])]
Iter 1989/2000 - Loss: 62.514
[tensor(3.1334), tensor([[1.0441, 1.0020]])]
Iter 1990/2000 - Loss: 62.514
[tensor(3.1335), tensor([[1.0441, 1.0020]])]
Iter 1991/2000 - Loss: 62.514
[tensor(3.1336), tensor([[1.0441, 1.0020]])]
Iter 1992/2000 - Loss: 62.514
[tensor(3.1336), tensor([[1.0441, 1.0020]])]
Iter 1993/2000 - Loss: 62.514
[tensor(3.1337), tensor([[1.0441, 1.0020]])]
Iter 1994/2000 - Loss: 62.514
[tensor(3.1337), tensor([[1.0441, 1.0020]])]
Iter 1995/2000 - Loss: 62.514
[tensor(3.1338), tensor([[1.0441, 1.0020]])]
Iter 1996/2000 - Loss: 62.515
[tensor(3.1338), tensor([[1.0441, 1.0020]])]
Iter 1997/2000 - Loss: 62.514
[tensor(3.1339), tensor([[1.0442, 1.0020]])]
Iter 1998/2000 - Loss: 62.514
[tensor(3.1339), tensor([[1.0442, 1.0020]])]
Iter 1999/2000 - Loss: 62.514
[tensor(3.1340), tensor([[1.0442, 1.0020]])]
Iter 2000/2000 - Loss: 62

In [None]:
class MertonBirthKernel(nn.Module):

    def __init__(self, **kwargs):
        nn.Module.__init__(self)
        self.locations_dimension = kwargs.get("locations_dimension")
        self.define_deep_parameters()
        
    def define_kernel(self):
        kernel = ScaleKernel(RBFKernel(ard_num_dims=self.locations_dimension, requires_grad=True),
                             requires_grad=True) + white_noise_kernel()

        kernel_hypers = {"raw_outputscale": torch.tensor(self.kernel_sigma),
                         "base_kernel.raw_lengthscale": torch.tensor(self.kernel_lenght_scales)}

        kernel.kernels[0].initialize(**kernel_hypers)
        kernel_eval = lambda locations: kernel(locations, locations).evaluate().type(torch.float64)

        return kernel, kernel_eval

    def forward(self, locations_history):
        kernel, kernel_eval = define_kernel(kernel_sigma,kernel_lenght_scales)
        covariance_diffusion_history = kernel_eval(locations_history)
        return covariance_diffusion_history

    def define_deep_parameters(self):
        self.kernel_sigma = nn.Parameter(torch.Tensor([1.]))
        self.kernel_lenght_scales = nn.Parameter(torch.Tensor([20.,30.]))