In [1]:
import torch
import gpytorch
import numpy as np
from torch import nn
from torch import matmul as m
from gpytorch.kernels import RBFKernel, ScaleKernel
from torch.distributions import Normal, Poisson, MultivariateNormal


from torch.utils.tensorboard import SummaryWriter
from deep_fields.models.utils.basic_setups import create_dir_and_writer
from deep_fields.models.gaussian_processes.gaussian_processes import multivariate_normal, white_noise_kernel

In [2]:
locations_dimension = 2
prior_locations_mean =  0.
prior_locations_std = 6.69
number_of_realizations = 100
birth_intensity = 2.
kernel_sigma = 0.1
kernel_lenght_scales = [1., 2.]

locations_prior = Normal(torch.full((locations_dimension,), prior_locations_mean),
                         torch.full((locations_dimension,), prior_locations_std))
locations = locations_prior.sample(sample_shape=(100,))

In [3]:
class GPModel(nn.Module):
    
    def __init__(self,kernel_sigma,kernel_lenght_scales):
        nn.Module.__init__(self)        
        kernel = ScaleKernel(RBFKernel(ard_num_dims=locations_dimension, requires_grad=True),
                     requires_grad=True) + white_noise_kernel()

        kernel_hypers = {"raw_outputscale": torch.tensor(kernel_sigma),
                         "base_kernel.raw_lengthscale": torch.tensor(kernel_lenght_scales)}

        kernel.kernels[0].initialize(**kernel_hypers)    
        self.mean_module = gpytorch.means.ConstantMean()
        
        self.covar_module = kernel

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x).evaluate()
        return mean_x,covar_x
    
def log_likelihood(covar_x,f):
    covar_dim =  covar_x.shape[0]
    covar_x_inverse = covar_x.inverse()
    covar_det = (covar_x).det()
    
    log_inverse = m(f,m(covar_x_inverse,f))
    log_det = torch.log(covar_det)
    log_probability = -.5*(log_det + log_inverse)
    return log_probability

In [4]:
gp0 = GPModel(kernel_sigma = 3.,kernel_lenght_scales = [1., 1.])
mean_x, covar_x = gp0(locations)
f = MultivariateNormal(mean_x,covar_x).sample()
log_likelihood(covar_x,f)

tensor(-61.9137, grad_fn=<MulBackward0>)

In [5]:
gp1 = GPModel(kernel_sigma=1.,kernel_lenght_scales =[.1, .1])
mean_x, covar_x = gp1(locations)
log_likelihood(covar_x,f)

tensor(-92.0886, grad_fn=<MulBackward0>)

In [6]:
gp1 = GPModel(kernel_sigma=.1,kernel_lenght_scales=[.1, .1])
training_iter = 2000
gp1.train()
# Use the adam optimizer
optimizer = torch.optim.Adam(gp1.parameters(), lr=0.01)  # Includes GaussianLikelihood parameters

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    mean_x, covar_x = gp1(locations)
    # Calc loss and backprop gradients
    loss = -log_likelihood(covar_x,f)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (
        i + 1, training_iter, loss.item()
    ))
    print([a.data for a in gp1.covar_module.parameters()])
    optimizer.step()

Iter 1/2000 - Loss: 133.568
[tensor(0.1000), tensor([[0.1000, 0.1000]])]
Iter 2/2000 - Loss: 132.499
[tensor(0.1100), tensor([[0.1100, 0.1100]])]
Iter 3/2000 - Loss: 131.444
[tensor(0.1200), tensor([[0.1200, 0.1200]])]
Iter 4/2000 - Loss: 130.402
[tensor(0.1300), tensor([[0.1300, 0.1300]])]
Iter 5/2000 - Loss: 129.375
[tensor(0.1400), tensor([[0.1400, 0.1400]])]
Iter 6/2000 - Loss: 128.362
[tensor(0.1499), tensor([[0.1500, 0.1499]])]
Iter 7/2000 - Loss: 127.362
[tensor(0.1599), tensor([[0.1599, 0.1599]])]
Iter 8/2000 - Loss: 126.376
[tensor(0.1698), tensor([[0.1699, 0.1699]])]
Iter 9/2000 - Loss: 125.405
[tensor(0.1798), tensor([[0.1798, 0.1798]])]
Iter 10/2000 - Loss: 124.447
[tensor(0.1897), tensor([[0.1897, 0.1897]])]
Iter 11/2000 - Loss: 123.504
[tensor(0.1995), tensor([[0.1996, 0.1996]])]
Iter 12/2000 - Loss: 122.574
[tensor(0.2094), tensor([[0.2095, 0.2095]])]
Iter 13/2000 - Loss: 121.659
[tensor(0.2192), tensor([[0.2194, 0.2194]])]
Iter 14/2000 - Loss: 120.757
[tensor(0.2290), t

Iter 156/2000 - Loss: 73.179
[tensor(1.2466), tensor([[0.7587, 0.9786]])]
Iter 157/2000 - Loss: 73.076
[tensor(1.2518), tensor([[0.7595, 0.9799]])]
Iter 158/2000 - Loss: 72.973
[tensor(1.2569), tensor([[0.7603, 0.9812]])]
Iter 159/2000 - Loss: 72.872
[tensor(1.2619), tensor([[0.7611, 0.9825]])]
Iter 160/2000 - Loss: 72.772
[tensor(1.2670), tensor([[0.7619, 0.9838]])]
Iter 161/2000 - Loss: 72.672
[tensor(1.2721), tensor([[0.7627, 0.9851]])]
Iter 162/2000 - Loss: 72.574
[tensor(1.2771), tensor([[0.7635, 0.9864]])]
Iter 163/2000 - Loss: 72.477
[tensor(1.2821), tensor([[0.7643, 0.9876]])]
Iter 164/2000 - Loss: 72.380
[tensor(1.2871), tensor([[0.7651, 0.9889]])]
Iter 165/2000 - Loss: 72.285
[tensor(1.2920), tensor([[0.7659, 0.9902]])]
Iter 166/2000 - Loss: 72.191
[tensor(1.2970), tensor([[0.7666, 0.9915]])]
Iter 167/2000 - Loss: 72.098
[tensor(1.3019), tensor([[0.7674, 0.9927]])]
Iter 168/2000 - Loss: 72.005
[tensor(1.3068), tensor([[0.7682, 0.9940]])]
Iter 169/2000 - Loss: 71.914
[tensor(1

Iter 306/2000 - Loss: 64.778
[tensor(1.8443), tensor([[0.8428, 1.1135]])]
Iter 307/2000 - Loss: 64.749
[tensor(1.8474), tensor([[0.8432, 1.1141]])]
Iter 308/2000 - Loss: 64.721
[tensor(1.8506), tensor([[0.8436, 1.1148]])]
Iter 309/2000 - Loss: 64.693
[tensor(1.8537), tensor([[0.8440, 1.1154]])]
Iter 310/2000 - Loss: 64.665
[tensor(1.8568), tensor([[0.8444, 1.1160]])]
Iter 311/2000 - Loss: 64.639
[tensor(1.8598), tensor([[0.8448, 1.1166]])]
Iter 312/2000 - Loss: 64.610
[tensor(1.8629), tensor([[0.8452, 1.1172]])]
Iter 313/2000 - Loss: 64.583
[tensor(1.8660), tensor([[0.8456, 1.1178]])]
Iter 314/2000 - Loss: 64.557
[tensor(1.8690), tensor([[0.8459, 1.1183]])]
Iter 315/2000 - Loss: 64.529
[tensor(1.8721), tensor([[0.8463, 1.1189]])]
Iter 316/2000 - Loss: 64.503
[tensor(1.8751), tensor([[0.8467, 1.1195]])]
Iter 317/2000 - Loss: 64.477
[tensor(1.8782), tensor([[0.8471, 1.1201]])]
Iter 318/2000 - Loss: 64.450
[tensor(1.8812), tensor([[0.8475, 1.1207]])]
Iter 319/2000 - Loss: 64.424
[tensor(1

Iter 468/2000 - Loss: 61.963
[tensor(2.2579), tensor([[0.8920, 1.1880]])]
Iter 469/2000 - Loss: 61.954
[tensor(2.2600), tensor([[0.8923, 1.1884]])]
Iter 470/2000 - Loss: 61.944
[tensor(2.2621), tensor([[0.8925, 1.1887]])]
Iter 471/2000 - Loss: 61.934
[tensor(2.2641), tensor([[0.8927, 1.1891]])]
Iter 472/2000 - Loss: 61.924
[tensor(2.2662), tensor([[0.8930, 1.1894]])]
Iter 473/2000 - Loss: 61.914
[tensor(2.2683), tensor([[0.8932, 1.1898]])]
Iter 474/2000 - Loss: 61.904
[tensor(2.2703), tensor([[0.8934, 1.1901]])]
Iter 475/2000 - Loss: 61.895
[tensor(2.2724), tensor([[0.8936, 1.1904]])]
Iter 476/2000 - Loss: 61.885
[tensor(2.2744), tensor([[0.8939, 1.1908]])]
Iter 477/2000 - Loss: 61.876
[tensor(2.2765), tensor([[0.8941, 1.1911]])]
Iter 478/2000 - Loss: 61.866
[tensor(2.2785), tensor([[0.8943, 1.1915]])]
Iter 479/2000 - Loss: 61.856
[tensor(2.2806), tensor([[0.8946, 1.1918]])]
Iter 480/2000 - Loss: 61.847
[tensor(2.2826), tensor([[0.8948, 1.1921]])]
Iter 481/2000 - Loss: 61.837
[tensor(2

Iter 579/2000 - Loss: 61.129
[tensor(2.4636), tensor([[0.9142, 1.2209]])]
Iter 580/2000 - Loss: 61.125
[tensor(2.4652), tensor([[0.9144, 1.2212]])]
Iter 581/2000 - Loss: 61.119
[tensor(2.4668), tensor([[0.9145, 1.2214]])]
Iter 582/2000 - Loss: 61.114
[tensor(2.4685), tensor([[0.9147, 1.2217]])]
Iter 583/2000 - Loss: 61.109
[tensor(2.4701), tensor([[0.9149, 1.2219]])]
Iter 584/2000 - Loss: 61.103
[tensor(2.4717), tensor([[0.9150, 1.2222]])]
Iter 585/2000 - Loss: 61.098
[tensor(2.4733), tensor([[0.9152, 1.2224]])]
Iter 586/2000 - Loss: 61.093
[tensor(2.4750), tensor([[0.9154, 1.2227]])]
Iter 587/2000 - Loss: 61.087
[tensor(2.4766), tensor([[0.9155, 1.2229]])]
Iter 588/2000 - Loss: 61.082
[tensor(2.4782), tensor([[0.9157, 1.2232]])]
Iter 589/2000 - Loss: 61.077
[tensor(2.4798), tensor([[0.9159, 1.2234]])]
Iter 590/2000 - Loss: 61.072
[tensor(2.4814), tensor([[0.9160, 1.2237]])]
Iter 591/2000 - Loss: 61.066
[tensor(2.4830), tensor([[0.9162, 1.2239]])]
Iter 592/2000 - Loss: 61.062
[tensor(2

Iter 745/2000 - Loss: 60.517
[tensor(2.6947), tensor([[0.9374, 1.2552]])]
Iter 746/2000 - Loss: 60.516
[tensor(2.6959), tensor([[0.9375, 1.2553]])]
Iter 747/2000 - Loss: 60.513
[tensor(2.6971), tensor([[0.9376, 1.2555]])]
Iter 748/2000 - Loss: 60.510
[tensor(2.6982), tensor([[0.9378, 1.2557]])]
Iter 749/2000 - Loss: 60.508
[tensor(2.6994), tensor([[0.9379, 1.2558]])]
Iter 750/2000 - Loss: 60.505
[tensor(2.7006), tensor([[0.9380, 1.2560]])]
Iter 751/2000 - Loss: 60.504
[tensor(2.7017), tensor([[0.9381, 1.2562]])]
Iter 752/2000 - Loss: 60.501
[tensor(2.7029), tensor([[0.9382, 1.2563]])]
Iter 753/2000 - Loss: 60.500
[tensor(2.7040), tensor([[0.9383, 1.2565]])]
Iter 754/2000 - Loss: 60.497
[tensor(2.7052), tensor([[0.9384, 1.2567]])]
Iter 755/2000 - Loss: 60.495
[tensor(2.7064), tensor([[0.9385, 1.2568]])]
Iter 756/2000 - Loss: 60.493
[tensor(2.7075), tensor([[0.9387, 1.2570]])]
Iter 757/2000 - Loss: 60.490
[tensor(2.7087), tensor([[0.9388, 1.2572]])]
Iter 758/2000 - Loss: 60.487
[tensor(2

Iter 856/2000 - Loss: 60.314
[tensor(2.8118), tensor([[0.9486, 1.2715]])]
Iter 857/2000 - Loss: 60.313
[tensor(2.8127), tensor([[0.9487, 1.2716]])]
Iter 858/2000 - Loss: 60.311
[tensor(2.8137), tensor([[0.9487, 1.2717]])]
Iter 859/2000 - Loss: 60.311
[tensor(2.8146), tensor([[0.9488, 1.2719]])]
Iter 860/2000 - Loss: 60.309
[tensor(2.8156), tensor([[0.9489, 1.2720]])]
Iter 861/2000 - Loss: 60.308
[tensor(2.8165), tensor([[0.9490, 1.2721]])]
Iter 862/2000 - Loss: 60.306
[tensor(2.8174), tensor([[0.9491, 1.2723]])]
Iter 863/2000 - Loss: 60.304
[tensor(2.8184), tensor([[0.9492, 1.2724]])]
Iter 864/2000 - Loss: 60.304
[tensor(2.8193), tensor([[0.9493, 1.2725]])]
Iter 865/2000 - Loss: 60.302
[tensor(2.8202), tensor([[0.9494, 1.2726]])]
Iter 866/2000 - Loss: 60.301
[tensor(2.8211), tensor([[0.9494, 1.2728]])]
Iter 867/2000 - Loss: 60.299
[tensor(2.8221), tensor([[0.9495, 1.2729]])]
Iter 868/2000 - Loss: 60.299
[tensor(2.8230), tensor([[0.9496, 1.2730]])]
Iter 869/2000 - Loss: 60.296
[tensor(2

Iter 1010/2000 - Loss: 60.163
[tensor(2.9371), tensor([[0.9601, 1.2883]])]
Iter 1011/2000 - Loss: 60.162
[tensor(2.9378), tensor([[0.9601, 1.2884]])]
Iter 1012/2000 - Loss: 60.161
[tensor(2.9385), tensor([[0.9602, 1.2885]])]
Iter 1013/2000 - Loss: 60.161
[tensor(2.9392), tensor([[0.9602, 1.2886]])]
Iter 1014/2000 - Loss: 60.160
[tensor(2.9399), tensor([[0.9603, 1.2887]])]
Iter 1015/2000 - Loss: 60.160
[tensor(2.9406), tensor([[0.9604, 1.2887]])]
Iter 1016/2000 - Loss: 60.159
[tensor(2.9413), tensor([[0.9604, 1.2888]])]
Iter 1017/2000 - Loss: 60.158
[tensor(2.9420), tensor([[0.9605, 1.2889]])]
Iter 1018/2000 - Loss: 60.158
[tensor(2.9427), tensor([[0.9605, 1.2890]])]
Iter 1019/2000 - Loss: 60.157
[tensor(2.9433), tensor([[0.9606, 1.2891]])]
Iter 1020/2000 - Loss: 60.156
[tensor(2.9440), tensor([[0.9607, 1.2892]])]
Iter 1021/2000 - Loss: 60.157
[tensor(2.9447), tensor([[0.9607, 1.2893]])]
Iter 1022/2000 - Loss: 60.155
[tensor(2.9454), tensor([[0.9608, 1.2894]])]
Iter 1023/2000 - Loss: 60

Iter 1172/2000 - Loss: 60.089
[tensor(3.0331), tensor([[0.9686, 1.3007]])]
Iter 1173/2000 - Loss: 60.089
[tensor(3.0336), tensor([[0.9686, 1.3008]])]
Iter 1174/2000 - Loss: 60.089
[tensor(3.0341), tensor([[0.9687, 1.3008]])]
Iter 1175/2000 - Loss: 60.087
[tensor(3.0346), tensor([[0.9687, 1.3009]])]
Iter 1176/2000 - Loss: 60.088
[tensor(3.0351), tensor([[0.9687, 1.3010]])]
Iter 1177/2000 - Loss: 60.089
[tensor(3.0356), tensor([[0.9688, 1.3010]])]
Iter 1178/2000 - Loss: 60.086
[tensor(3.0361), tensor([[0.9688, 1.3011]])]
Iter 1179/2000 - Loss: 60.087
[tensor(3.0366), tensor([[0.9689, 1.3011]])]
Iter 1180/2000 - Loss: 60.085
[tensor(3.0371), tensor([[0.9689, 1.3012]])]
Iter 1181/2000 - Loss: 60.086
[tensor(3.0376), tensor([[0.9690, 1.3013]])]
Iter 1182/2000 - Loss: 60.086
[tensor(3.0380), tensor([[0.9690, 1.3013]])]
Iter 1183/2000 - Loss: 60.085
[tensor(3.0385), tensor([[0.9690, 1.3014]])]
Iter 1184/2000 - Loss: 60.085
[tensor(3.0390), tensor([[0.9691, 1.3015]])]
Iter 1185/2000 - Loss: 60

[tensor(3.1016), tensor([[0.9745, 1.3094]])]
Iter 1336/2000 - Loss: 60.054
[tensor(3.1020), tensor([[0.9745, 1.3094]])]
Iter 1337/2000 - Loss: 60.054
[tensor(3.1023), tensor([[0.9746, 1.3094]])]
Iter 1338/2000 - Loss: 60.054
[tensor(3.1027), tensor([[0.9746, 1.3095]])]
Iter 1339/2000 - Loss: 60.054
[tensor(3.1030), tensor([[0.9746, 1.3095]])]
Iter 1340/2000 - Loss: 60.055
[tensor(3.1034), tensor([[0.9746, 1.3096]])]
Iter 1341/2000 - Loss: 60.053
[tensor(3.1037), tensor([[0.9747, 1.3096]])]
Iter 1342/2000 - Loss: 60.055
[tensor(3.1040), tensor([[0.9747, 1.3097]])]
Iter 1343/2000 - Loss: 60.054
[tensor(3.1044), tensor([[0.9747, 1.3097]])]
Iter 1344/2000 - Loss: 60.053
[tensor(3.1047), tensor([[0.9748, 1.3097]])]
Iter 1345/2000 - Loss: 60.054
[tensor(3.1051), tensor([[0.9748, 1.3098]])]
Iter 1346/2000 - Loss: 60.054
[tensor(3.1054), tensor([[0.9748, 1.3098]])]
Iter 1347/2000 - Loss: 60.054
[tensor(3.1058), tensor([[0.9748, 1.3099]])]
Iter 1348/2000 - Loss: 60.054
[tensor(3.1061), tensor([

Iter 1487/2000 - Loss: 60.042
[tensor(3.1463), tensor([[0.9783, 1.3149]])]
Iter 1488/2000 - Loss: 60.041
[tensor(3.1465), tensor([[0.9783, 1.3149]])]
Iter 1489/2000 - Loss: 60.042
[tensor(3.1467), tensor([[0.9783, 1.3149]])]
Iter 1490/2000 - Loss: 60.041
[tensor(3.1470), tensor([[0.9784, 1.3150]])]
Iter 1491/2000 - Loss: 60.041
[tensor(3.1472), tensor([[0.9784, 1.3150]])]
Iter 1492/2000 - Loss: 60.041
[tensor(3.1475), tensor([[0.9784, 1.3150]])]
Iter 1493/2000 - Loss: 60.040
[tensor(3.1477), tensor([[0.9784, 1.3150]])]
Iter 1494/2000 - Loss: 60.041
[tensor(3.1479), tensor([[0.9784, 1.3151]])]
Iter 1495/2000 - Loss: 60.041
[tensor(3.1482), tensor([[0.9784, 1.3151]])]
Iter 1496/2000 - Loss: 60.041
[tensor(3.1484), tensor([[0.9785, 1.3151]])]
Iter 1497/2000 - Loss: 60.041
[tensor(3.1487), tensor([[0.9785, 1.3151]])]
Iter 1498/2000 - Loss: 60.040
[tensor(3.1489), tensor([[0.9785, 1.3152]])]
Iter 1499/2000 - Loss: 60.040
[tensor(3.1491), tensor([[0.9785, 1.3152]])]
Iter 1500/2000 - Loss: 60

Iter 1650/2000 - Loss: 60.035
[tensor(3.1787), tensor([[0.9810, 1.3189]])]
Iter 1651/2000 - Loss: 60.036
[tensor(3.1789), tensor([[0.9810, 1.3189]])]
Iter 1652/2000 - Loss: 60.036
[tensor(3.1790), tensor([[0.9811, 1.3189]])]
Iter 1653/2000 - Loss: 60.035
[tensor(3.1792), tensor([[0.9811, 1.3189]])]
Iter 1654/2000 - Loss: 60.036
[tensor(3.1794), tensor([[0.9811, 1.3189]])]
Iter 1655/2000 - Loss: 60.035
[tensor(3.1795), tensor([[0.9811, 1.3189]])]
Iter 1656/2000 - Loss: 60.035
[tensor(3.1797), tensor([[0.9811, 1.3190]])]
Iter 1657/2000 - Loss: 60.035
[tensor(3.1798), tensor([[0.9811, 1.3190]])]
Iter 1658/2000 - Loss: 60.035
[tensor(3.1800), tensor([[0.9811, 1.3190]])]
Iter 1659/2000 - Loss: 60.034
[tensor(3.1801), tensor([[0.9811, 1.3190]])]
Iter 1660/2000 - Loss: 60.035
[tensor(3.1803), tensor([[0.9812, 1.3190]])]
Iter 1661/2000 - Loss: 60.035
[tensor(3.1805), tensor([[0.9812, 1.3191]])]
Iter 1662/2000 - Loss: 60.035
[tensor(3.1806), tensor([[0.9812, 1.3191]])]
Iter 1663/2000 - Loss: 60

[tensor(3.1940), tensor([[0.9823, 1.3207]])]
Iter 1762/2000 - Loss: 60.033
[tensor(3.1941), tensor([[0.9823, 1.3207]])]
Iter 1763/2000 - Loss: 60.033
[tensor(3.1942), tensor([[0.9823, 1.3208]])]
Iter 1764/2000 - Loss: 60.033
[tensor(3.1943), tensor([[0.9823, 1.3208]])]
Iter 1765/2000 - Loss: 60.032
[tensor(3.1944), tensor([[0.9823, 1.3208]])]
Iter 1766/2000 - Loss: 60.033
[tensor(3.1945), tensor([[0.9823, 1.3208]])]
Iter 1767/2000 - Loss: 60.034
[tensor(3.1947), tensor([[0.9824, 1.3208]])]
Iter 1768/2000 - Loss: 60.033
[tensor(3.1948), tensor([[0.9824, 1.3208]])]
Iter 1769/2000 - Loss: 60.033
[tensor(3.1949), tensor([[0.9824, 1.3208]])]
Iter 1770/2000 - Loss: 60.034
[tensor(3.1950), tensor([[0.9824, 1.3208]])]
Iter 1771/2000 - Loss: 60.032
[tensor(3.1951), tensor([[0.9824, 1.3208]])]
Iter 1772/2000 - Loss: 60.033
[tensor(3.1952), tensor([[0.9824, 1.3208]])]
Iter 1773/2000 - Loss: 60.033
[tensor(3.1953), tensor([[0.9824, 1.3209]])]
Iter 1774/2000 - Loss: 60.034
[tensor(3.1955), tensor([

Iter 1875/2000 - Loss: 60.032
[tensor(3.2052), tensor([[0.9832, 1.3220]])]
Iter 1876/2000 - Loss: 60.032
[tensor(3.2053), tensor([[0.9832, 1.3221]])]
Iter 1877/2000 - Loss: 60.033
[tensor(3.2054), tensor([[0.9832, 1.3221]])]
Iter 1878/2000 - Loss: 60.033
[tensor(3.2055), tensor([[0.9833, 1.3221]])]
Iter 1879/2000 - Loss: 60.032
[tensor(3.2056), tensor([[0.9833, 1.3221]])]
Iter 1880/2000 - Loss: 60.032
[tensor(3.2057), tensor([[0.9833, 1.3221]])]
Iter 1881/2000 - Loss: 60.033
[tensor(3.2057), tensor([[0.9833, 1.3221]])]
Iter 1882/2000 - Loss: 60.033
[tensor(3.2058), tensor([[0.9833, 1.3221]])]
Iter 1883/2000 - Loss: 60.032
[tensor(3.2059), tensor([[0.9833, 1.3221]])]
Iter 1884/2000 - Loss: 60.033
[tensor(3.2060), tensor([[0.9833, 1.3221]])]
Iter 1885/2000 - Loss: 60.033
[tensor(3.2061), tensor([[0.9833, 1.3222]])]
Iter 1886/2000 - Loss: 60.032
[tensor(3.2061), tensor([[0.9833, 1.3222]])]
Iter 1887/2000 - Loss: 60.033
[tensor(3.2062), tensor([[0.9833, 1.3222]])]
Iter 1888/2000 - Loss: 60

In [None]:
class MertonBirthKernel(nn.Module):

    def __init__(self, **kwargs):
        nn.Module.__init__(self)
        self.locations_dimension = kwargs.get("locations_dimension")
        self.define_deep_parameters()
        
    def define_kernel(self):
        kernel = ScaleKernel(RBFKernel(ard_num_dims=self.locations_dimension, requires_grad=True),
                             requires_grad=True) + white_noise_kernel()

        kernel_hypers = {"raw_outputscale": torch.tensor(self.kernel_sigma),
                         "base_kernel.raw_lengthscale": torch.tensor(self.kernel_lenght_scales)}

        kernel.kernels[0].initialize(**kernel_hypers)
        kernel_eval = lambda locations: kernel(locations, locations).evaluate().type(torch.float64)

        return kernel, kernel_eval

    def forward(self, locations_history):
        kernel, kernel_eval = self.define_kernel(kernel_sigma,kernel_lenght_scales)
        covariance_diffusion_history = kernel_eval(locations_history)
        return covariance_diffusion_history

    def define_deep_parameters(self):
        self.kernel_sigma = nn.Parameter(torch.Tensor([1.]))
        self.kernel_lenght_scales = nn.Parameter(torch.Tensor([20.,30.]))