In [1]:
import torch
from gpytorch.means import ConstantMean
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.variational import WhitenedVariationalStrategy, CholeskyVariationalDistribution
from gpytorch.distributions import MultivariateNormal
from gpytorch.models import AbstractVariationalGP
from gpytorch.mlls import VariationalELBO, AddedLossTerm
from gpytorch.likelihoods import GaussianLikelihood

In [2]:
from deep_gp import DeepGP, DeepGaussianLikelihood

In [3]:
import urllib.request
import os.path
from scipy.io import loadmat
from math import floor
import numpy as np

data = torch.Tensor(loadmat('../../uci/protein/protein.mat')['data'])
X = data[:, :-1]
y = data[:, -1]

N = data.shape[0]
np.random.seed(0)
data = data[np.random.permutation(np.arange(N)),:]

train_n = int(floor(0.8*len(X)))

train_x = X[:train_n, :].contiguous().cuda()
train_y = y[:train_n].contiguous().cuda()

test_x = X[train_n:, :].contiguous().cuda()
test_y = y[train_n:].contiguous().cuda()

mean = train_x.mean(dim=-2, keepdim=True)
std = train_x.std(dim=-2, keepdim=True) + 1e-6
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

mean,std = train_y.mean(),train_y.std()
train_y = (train_y - mean) / std
test_y = (test_y - mean) / std

In [4]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)

In [5]:
num_samples = 5

model = DeepGP(input_dims=train_x.size(-1), hidden_dims=15, output_dims=1, num_inducing=200, num_samples=num_samples).cuda()
likelihood = DeepGaussianLikelihood(num_samples=num_samples).cuda()
mll = VariationalELBO(likelihood, model, train_x.size(-2))

In [6]:
num_epochs = 60

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.01)

import time

for i in range(num_epochs):
    # Within each iteration, we will go over each minibatch of data
    for minibatch_i, (x_batch, y_batch) in enumerate(train_loader):
        start_time = time.time()
        optimizer.zero_grad()
        output = model(x_batch)
        print(output.mean.shape)
        y_batch = y_batch.unsqueeze(0).unsqueeze(0).expand(model.output_dims, model.num_samples, y_batch.size(-1))
        loss = -mll(output, y_batch, num_samples=num_samples).sum()
        
        print('Epoch %d [%d/%d] - Loss: %.3f - - Time: %.3f' % (i + 1, minibatch_i, len(train_loader), loss.item(), time.time() - start_time))

        loss.backward()
        optimizer.step()

torch.Size([1, 5120])
Epoch 1 [0/36] - Loss: 1.907 - - Time: 0.468
torch.Size([1, 5120])
Epoch 1 [1/36] - Loss: 1.950 - - Time: 0.144
torch.Size([1, 5120])
Epoch 1 [2/36] - Loss: 1.948 - - Time: 0.149
torch.Size([1, 5120])
Epoch 1 [3/36] - Loss: 1.952 - - Time: 0.142
torch.Size([1, 5120])
Epoch 1 [4/36] - Loss: 1.908 - - Time: 0.150
torch.Size([1, 5120])
Epoch 1 [5/36] - Loss: 1.914 - - Time: 0.143
torch.Size([1, 5120])
Epoch 1 [6/36] - Loss: 1.889 - - Time: 0.149
torch.Size([1, 5120])
Epoch 1 [7/36] - Loss: 1.897 - - Time: 0.141
torch.Size([1, 5120])
Epoch 1 [8/36] - Loss: 1.895 - - Time: 0.146
torch.Size([1, 5120])
Epoch 1 [9/36] - Loss: 1.901 - - Time: 0.187
torch.Size([1, 5120])
Epoch 1 [10/36] - Loss: 1.870 - - Time: 0.195
torch.Size([1, 5120])
Epoch 1 [11/36] - Loss: 1.879 - - Time: 0.164
torch.Size([1, 5120])
Epoch 1 [12/36] - Loss: 1.886 - - Time: 0.203
torch.Size([1, 5120])
Epoch 1 [13/36] - Loss: 1.842 - - Time: 0.140
torch.Size([1, 5120])
Epoch 1 [14/36] - Loss: 1.853 - - Ti

torch.Size([1, 5120])
Epoch 4 [14/36] - Loss: 1.577 - - Time: 0.143
torch.Size([1, 5120])
Epoch 4 [15/36] - Loss: 1.556 - - Time: 0.140
torch.Size([1, 5120])
Epoch 4 [16/36] - Loss: 1.582 - - Time: 0.142
torch.Size([1, 5120])
Epoch 4 [17/36] - Loss: 1.549 - - Time: 0.140
torch.Size([1, 5120])
Epoch 4 [18/36] - Loss: 1.555 - - Time: 0.142
torch.Size([1, 5120])
Epoch 4 [19/36] - Loss: 1.552 - - Time: 0.140
torch.Size([1, 5120])
Epoch 4 [20/36] - Loss: 1.558 - - Time: 0.142
torch.Size([1, 5120])
Epoch 4 [21/36] - Loss: 1.562 - - Time: 0.140
torch.Size([1, 5120])
Epoch 4 [22/36] - Loss: 1.551 - - Time: 0.142
torch.Size([1, 5120])
Epoch 4 [23/36] - Loss: 1.558 - - Time: 0.154
torch.Size([1, 5120])
Epoch 4 [24/36] - Loss: 1.558 - - Time: 0.143
torch.Size([1, 5120])
Epoch 4 [25/36] - Loss: 1.566 - - Time: 0.140
torch.Size([1, 5120])
Epoch 4 [26/36] - Loss: 1.563 - - Time: 0.141
torch.Size([1, 5120])
Epoch 4 [27/36] - Loss: 1.568 - - Time: 0.140
torch.Size([1, 5120])
Epoch 4 [28/36] - Loss: 1.

torch.Size([1, 5120])
Epoch 7 [28/36] - Loss: 1.455 - - Time: 0.144
torch.Size([1, 5120])
Epoch 7 [29/36] - Loss: 1.444 - - Time: 0.140
torch.Size([1, 5120])
Epoch 7 [30/36] - Loss: 1.464 - - Time: 0.144
torch.Size([1, 5120])
Epoch 7 [31/36] - Loss: 1.456 - - Time: 0.140
torch.Size([1, 5120])
Epoch 7 [32/36] - Loss: 1.442 - - Time: 0.142
torch.Size([1, 5120])
Epoch 7 [33/36] - Loss: 1.452 - - Time: 0.140
torch.Size([1, 5120])
Epoch 7 [34/36] - Loss: 1.457 - - Time: 0.144
torch.Size([1, 3720])
Epoch 7 [35/36] - Loss: 1.438 - - Time: 0.109
torch.Size([1, 5120])
Epoch 8 [0/36] - Loss: 1.440 - - Time: 0.144
torch.Size([1, 5120])
Epoch 8 [1/36] - Loss: 1.434 - - Time: 0.140
torch.Size([1, 5120])
Epoch 8 [2/36] - Loss: 1.430 - - Time: 0.144
torch.Size([1, 5120])
Epoch 8 [3/36] - Loss: 1.462 - - Time: 0.140
torch.Size([1, 5120])
Epoch 8 [4/36] - Loss: 1.439 - - Time: 0.144
torch.Size([1, 5120])
Epoch 8 [5/36] - Loss: 1.442 - - Time: 0.141
torch.Size([1, 5120])
Epoch 8 [6/36] - Loss: 1.436 - -

torch.Size([1, 5120])
Epoch 11 [6/36] - Loss: 1.114 - - Time: 0.145
torch.Size([1, 5120])
Epoch 11 [7/36] - Loss: 1.092 - - Time: 0.141
torch.Size([1, 5120])
Epoch 11 [8/36] - Loss: 1.100 - - Time: 0.146
torch.Size([1, 5120])
Epoch 11 [9/36] - Loss: 1.100 - - Time: 0.141
torch.Size([1, 5120])
Epoch 11 [10/36] - Loss: 1.137 - - Time: 0.145
torch.Size([1, 5120])
Epoch 11 [11/36] - Loss: 1.106 - - Time: 0.141
torch.Size([1, 5120])
Epoch 11 [12/36] - Loss: 1.130 - - Time: 0.145
torch.Size([1, 5120])
Epoch 11 [13/36] - Loss: 1.103 - - Time: 0.141
torch.Size([1, 5120])
Epoch 11 [14/36] - Loss: 1.113 - - Time: 0.148
torch.Size([1, 5120])
Epoch 11 [15/36] - Loss: 1.102 - - Time: 0.141
torch.Size([1, 5120])
Epoch 11 [16/36] - Loss: 1.134 - - Time: 0.146
torch.Size([1, 5120])
Epoch 11 [17/36] - Loss: 1.093 - - Time: 0.142
torch.Size([1, 5120])
Epoch 11 [18/36] - Loss: 1.078 - - Time: 0.151
torch.Size([1, 5120])
Epoch 11 [19/36] - Loss: 1.128 - - Time: 0.141
torch.Size([1, 5120])
Epoch 11 [20/36]

torch.Size([1, 5120])
Epoch 14 [18/36] - Loss: 1.022 - - Time: 0.143
torch.Size([1, 5120])
Epoch 14 [19/36] - Loss: 0.990 - - Time: 0.141
torch.Size([1, 5120])
Epoch 14 [20/36] - Loss: 1.013 - - Time: 0.142
torch.Size([1, 5120])
Epoch 14 [21/36] - Loss: 1.045 - - Time: 0.141
torch.Size([1, 5120])
Epoch 14 [22/36] - Loss: 1.009 - - Time: 0.143
torch.Size([1, 5120])
Epoch 14 [23/36] - Loss: 0.994 - - Time: 0.141
torch.Size([1, 5120])
Epoch 14 [24/36] - Loss: 1.036 - - Time: 0.144
torch.Size([1, 5120])
Epoch 14 [25/36] - Loss: 0.989 - - Time: 0.141
torch.Size([1, 5120])
Epoch 14 [26/36] - Loss: 0.960 - - Time: 0.143
torch.Size([1, 5120])
Epoch 14 [27/36] - Loss: 0.998 - - Time: 0.141
torch.Size([1, 5120])
Epoch 14 [28/36] - Loss: 0.957 - - Time: 0.143
torch.Size([1, 5120])
Epoch 14 [29/36] - Loss: 0.946 - - Time: 0.141
torch.Size([1, 5120])
Epoch 14 [30/36] - Loss: 0.977 - - Time: 0.143
torch.Size([1, 5120])
Epoch 14 [31/36] - Loss: 0.945 - - Time: 0.141
torch.Size([1, 5120])
Epoch 14 [32

torch.Size([1, 5120])
Epoch 17 [30/36] - Loss: 0.924 - - Time: 0.147
torch.Size([1, 5120])
Epoch 17 [31/36] - Loss: 0.939 - - Time: 0.141
torch.Size([1, 5120])
Epoch 17 [32/36] - Loss: 0.916 - - Time: 0.148
torch.Size([1, 5120])
Epoch 17 [33/36] - Loss: 0.920 - - Time: 0.141
torch.Size([1, 5120])
Epoch 17 [34/36] - Loss: 0.938 - - Time: 0.147
torch.Size([1, 3720])
Epoch 17 [35/36] - Loss: 0.956 - - Time: 0.110
torch.Size([1, 5120])
Epoch 18 [0/36] - Loss: 0.899 - - Time: 0.147
torch.Size([1, 5120])
Epoch 18 [1/36] - Loss: 0.960 - - Time: 0.142
torch.Size([1, 5120])
Epoch 18 [2/36] - Loss: 0.953 - - Time: 0.148
torch.Size([1, 5120])
Epoch 18 [3/36] - Loss: 0.889 - - Time: 0.141
torch.Size([1, 5120])
Epoch 18 [4/36] - Loss: 0.927 - - Time: 0.147
torch.Size([1, 5120])
Epoch 18 [5/36] - Loss: 0.859 - - Time: 0.142
torch.Size([1, 5120])
Epoch 18 [6/36] - Loss: 0.895 - - Time: 0.148
torch.Size([1, 5120])
Epoch 18 [7/36] - Loss: 0.921 - - Time: 0.142
torch.Size([1, 5120])
Epoch 18 [8/36] - Lo

torch.Size([1, 5120])
Epoch 21 [6/36] - Loss: 0.896 - - Time: 0.143
torch.Size([1, 5120])
Epoch 21 [7/36] - Loss: 0.841 - - Time: 0.141
torch.Size([1, 5120])
Epoch 21 [8/36] - Loss: 0.868 - - Time: 0.143
torch.Size([1, 5120])
Epoch 21 [9/36] - Loss: 0.875 - - Time: 0.141
torch.Size([1, 5120])
Epoch 21 [10/36] - Loss: 0.850 - - Time: 0.143
torch.Size([1, 5120])
Epoch 21 [11/36] - Loss: 0.878 - - Time: 0.141
torch.Size([1, 5120])
Epoch 21 [12/36] - Loss: 0.896 - - Time: 0.143
torch.Size([1, 5120])
Epoch 21 [13/36] - Loss: 0.803 - - Time: 0.141
torch.Size([1, 5120])
Epoch 21 [14/36] - Loss: 0.828 - - Time: 0.143
torch.Size([1, 5120])
Epoch 21 [15/36] - Loss: 0.913 - - Time: 0.141
torch.Size([1, 5120])
Epoch 21 [16/36] - Loss: 0.886 - - Time: 0.143
torch.Size([1, 5120])
Epoch 21 [17/36] - Loss: 0.838 - - Time: 0.141
torch.Size([1, 5120])
Epoch 21 [18/36] - Loss: 0.839 - - Time: 0.143
torch.Size([1, 5120])
Epoch 21 [19/36] - Loss: 0.835 - - Time: 0.141
torch.Size([1, 5120])
Epoch 21 [20/36]

torch.Size([1, 5120])
Epoch 24 [18/36] - Loss: 0.832 - - Time: 0.143
torch.Size([1, 5120])
Epoch 24 [19/36] - Loss: 0.769 - - Time: 0.141
torch.Size([1, 5120])
Epoch 24 [20/36] - Loss: 0.799 - - Time: 0.143
torch.Size([1, 5120])
Epoch 24 [21/36] - Loss: 0.791 - - Time: 0.141
torch.Size([1, 5120])
Epoch 24 [22/36] - Loss: 0.827 - - Time: 0.143
torch.Size([1, 5120])
Epoch 24 [23/36] - Loss: 0.758 - - Time: 0.142
torch.Size([1, 5120])
Epoch 24 [24/36] - Loss: 0.725 - - Time: 0.157
torch.Size([1, 5120])
Epoch 24 [25/36] - Loss: 0.878 - - Time: 0.141
torch.Size([1, 5120])
Epoch 24 [26/36] - Loss: 0.895 - - Time: 0.144
torch.Size([1, 5120])
Epoch 24 [27/36] - Loss: 0.870 - - Time: 0.141
torch.Size([1, 5120])
Epoch 24 [28/36] - Loss: 0.848 - - Time: 0.143
torch.Size([1, 5120])
Epoch 24 [29/36] - Loss: 0.857 - - Time: 0.141
torch.Size([1, 5120])
Epoch 24 [30/36] - Loss: 0.776 - - Time: 0.143
torch.Size([1, 5120])
Epoch 24 [31/36] - Loss: 0.791 - - Time: 0.141
torch.Size([1, 5120])
Epoch 24 [32

torch.Size([1, 5120])
Epoch 27 [30/36] - Loss: 0.758 - - Time: 0.143
torch.Size([1, 5120])
Epoch 27 [31/36] - Loss: 0.730 - - Time: 0.141
torch.Size([1, 5120])
Epoch 27 [32/36] - Loss: 0.812 - - Time: 0.142
torch.Size([1, 5120])
Epoch 27 [33/36] - Loss: 0.783 - - Time: 0.141
torch.Size([1, 5120])
Epoch 27 [34/36] - Loss: 0.742 - - Time: 0.143
torch.Size([1, 3720])
Epoch 27 [35/36] - Loss: 0.788 - - Time: 0.110
torch.Size([1, 5120])
Epoch 28 [0/36] - Loss: 0.751 - - Time: 0.145
torch.Size([1, 5120])
Epoch 28 [1/36] - Loss: 0.758 - - Time: 0.143
torch.Size([1, 5120])
Epoch 28 [2/36] - Loss: 0.683 - - Time: 0.144
torch.Size([1, 5120])
Epoch 28 [3/36] - Loss: 0.757 - - Time: 0.148
torch.Size([1, 5120])
Epoch 28 [4/36] - Loss: 0.742 - - Time: 0.146
torch.Size([1, 5120])
Epoch 28 [5/36] - Loss: 0.768 - - Time: 0.143
torch.Size([1, 5120])
Epoch 28 [6/36] - Loss: 0.741 - - Time: 0.147
torch.Size([1, 5120])
Epoch 28 [7/36] - Loss: 0.721 - - Time: 0.144
torch.Size([1, 5120])
Epoch 28 [8/36] - Lo

torch.Size([1, 5120])
Epoch 31 [6/36] - Loss: 0.690 - - Time: 0.143
torch.Size([1, 5120])
Epoch 31 [7/36] - Loss: 0.751 - - Time: 0.141
torch.Size([1, 5120])
Epoch 31 [8/36] - Loss: 0.794 - - Time: 0.143
torch.Size([1, 5120])
Epoch 31 [9/36] - Loss: 0.699 - - Time: 0.141
torch.Size([1, 5120])
Epoch 31 [10/36] - Loss: 0.698 - - Time: 0.147
torch.Size([1, 5120])
Epoch 31 [11/36] - Loss: 0.714 - - Time: 0.141
torch.Size([1, 5120])
Epoch 31 [12/36] - Loss: 0.736 - - Time: 0.143
torch.Size([1, 5120])
Epoch 31 [13/36] - Loss: 0.727 - - Time: 0.141
torch.Size([1, 5120])
Epoch 31 [14/36] - Loss: 0.727 - - Time: 0.143
torch.Size([1, 5120])
Epoch 31 [15/36] - Loss: 0.712 - - Time: 0.141
torch.Size([1, 5120])
Epoch 31 [16/36] - Loss: 0.773 - - Time: 0.143
torch.Size([1, 5120])
Epoch 31 [17/36] - Loss: 0.688 - - Time: 0.141
torch.Size([1, 5120])
Epoch 31 [18/36] - Loss: 0.674 - - Time: 0.142
torch.Size([1, 5120])
Epoch 31 [19/36] - Loss: 0.717 - - Time: 0.141
torch.Size([1, 5120])
Epoch 31 [20/36]

torch.Size([1, 5120])
Epoch 34 [18/36] - Loss: 0.713 - - Time: 0.143
torch.Size([1, 5120])
Epoch 34 [19/36] - Loss: 0.673 - - Time: 0.141
torch.Size([1, 5120])
Epoch 34 [20/36] - Loss: 0.636 - - Time: 0.143
torch.Size([1, 5120])
Epoch 34 [21/36] - Loss: 0.651 - - Time: 0.141
torch.Size([1, 5120])
Epoch 34 [22/36] - Loss: 0.702 - - Time: 0.142
torch.Size([1, 5120])
Epoch 34 [23/36] - Loss: 0.685 - - Time: 0.141
torch.Size([1, 5120])
Epoch 34 [24/36] - Loss: 0.663 - - Time: 0.144
torch.Size([1, 5120])
Epoch 34 [25/36] - Loss: 0.682 - - Time: 0.141
torch.Size([1, 5120])
Epoch 34 [26/36] - Loss: 0.620 - - Time: 0.143
torch.Size([1, 5120])
Epoch 34 [27/36] - Loss: 0.649 - - Time: 0.141
torch.Size([1, 5120])
Epoch 34 [28/36] - Loss: 0.651 - - Time: 0.143
torch.Size([1, 5120])
Epoch 34 [29/36] - Loss: 0.695 - - Time: 0.141
torch.Size([1, 5120])
Epoch 34 [30/36] - Loss: 0.659 - - Time: 0.143
torch.Size([1, 5120])
Epoch 34 [31/36] - Loss: 0.729 - - Time: 0.141
torch.Size([1, 5120])
Epoch 34 [32

torch.Size([1, 5120])
Epoch 37 [30/36] - Loss: 0.705 - - Time: 0.143
torch.Size([1, 5120])
Epoch 37 [31/36] - Loss: 0.645 - - Time: 0.141
torch.Size([1, 5120])
Epoch 37 [32/36] - Loss: 0.712 - - Time: 0.144
torch.Size([1, 5120])
Epoch 37 [33/36] - Loss: 0.720 - - Time: 0.141
torch.Size([1, 5120])
Epoch 37 [34/36] - Loss: 0.672 - - Time: 0.144
torch.Size([1, 3720])
Epoch 37 [35/36] - Loss: 0.753 - - Time: 0.110
torch.Size([1, 5120])
Epoch 38 [0/36] - Loss: 0.638 - - Time: 0.144
torch.Size([1, 5120])
Epoch 38 [1/36] - Loss: 0.669 - - Time: 0.142
torch.Size([1, 5120])
Epoch 38 [2/36] - Loss: 0.667 - - Time: 0.143
torch.Size([1, 5120])
Epoch 38 [3/36] - Loss: 0.670 - - Time: 0.141
torch.Size([1, 5120])
Epoch 38 [4/36] - Loss: 0.656 - - Time: 0.144
torch.Size([1, 5120])
Epoch 38 [5/36] - Loss: 0.643 - - Time: 0.141
torch.Size([1, 5120])
Epoch 38 [6/36] - Loss: 0.595 - - Time: 0.143
torch.Size([1, 5120])
Epoch 38 [7/36] - Loss: 0.667 - - Time: 0.141
torch.Size([1, 5120])
Epoch 38 [8/36] - Lo

torch.Size([1, 5120])
Epoch 41 [6/36] - Loss: 0.620 - - Time: 0.143
torch.Size([1, 5120])
Epoch 41 [7/36] - Loss: 0.580 - - Time: 0.141
torch.Size([1, 5120])
Epoch 41 [8/36] - Loss: 0.707 - - Time: 0.143
torch.Size([1, 5120])
Epoch 41 [9/36] - Loss: 0.664 - - Time: 0.141
torch.Size([1, 5120])
Epoch 41 [10/36] - Loss: 0.642 - - Time: 0.143
torch.Size([1, 5120])
Epoch 41 [11/36] - Loss: 0.629 - - Time: 0.141
torch.Size([1, 5120])
Epoch 41 [12/36] - Loss: 0.733 - - Time: 0.143
torch.Size([1, 5120])
Epoch 41 [13/36] - Loss: 0.592 - - Time: 0.141
torch.Size([1, 5120])
Epoch 41 [14/36] - Loss: 0.693 - - Time: 0.143
torch.Size([1, 5120])
Epoch 41 [15/36] - Loss: 0.618 - - Time: 0.141
torch.Size([1, 5120])
Epoch 41 [16/36] - Loss: 0.560 - - Time: 0.143
torch.Size([1, 5120])
Epoch 41 [17/36] - Loss: 0.638 - - Time: 0.141
torch.Size([1, 5120])
Epoch 41 [18/36] - Loss: 0.639 - - Time: 0.144
torch.Size([1, 5120])
Epoch 41 [19/36] - Loss: 0.647 - - Time: 0.141
torch.Size([1, 5120])
Epoch 41 [20/36]

torch.Size([1, 5120])
Epoch 44 [18/36] - Loss: 0.574 - - Time: 0.143
torch.Size([1, 5120])
Epoch 44 [19/36] - Loss: 0.637 - - Time: 0.141
torch.Size([1, 5120])
Epoch 44 [20/36] - Loss: 0.660 - - Time: 0.142
torch.Size([1, 5120])
Epoch 44 [21/36] - Loss: 0.562 - - Time: 0.141
torch.Size([1, 5120])
Epoch 44 [22/36] - Loss: 0.575 - - Time: 0.143
torch.Size([1, 5120])
Epoch 44 [23/36] - Loss: 0.673 - - Time: 0.141
torch.Size([1, 5120])
Epoch 44 [24/36] - Loss: 0.613 - - Time: 0.143
torch.Size([1, 5120])
Epoch 44 [25/36] - Loss: 0.689 - - Time: 0.141
torch.Size([1, 5120])
Epoch 44 [26/36] - Loss: 0.677 - - Time: 0.143
torch.Size([1, 5120])
Epoch 44 [27/36] - Loss: 0.675 - - Time: 0.141
torch.Size([1, 5120])
Epoch 44 [28/36] - Loss: 0.636 - - Time: 0.143
torch.Size([1, 5120])
Epoch 44 [29/36] - Loss: 0.706 - - Time: 0.141
torch.Size([1, 5120])
Epoch 44 [30/36] - Loss: 0.678 - - Time: 0.143
torch.Size([1, 5120])
Epoch 44 [31/36] - Loss: 0.691 - - Time: 0.141
torch.Size([1, 5120])
Epoch 44 [32

torch.Size([1, 5120])
Epoch 47 [30/36] - Loss: 0.586 - - Time: 0.143
torch.Size([1, 5120])
Epoch 47 [31/36] - Loss: 0.615 - - Time: 0.141
torch.Size([1, 5120])
Epoch 47 [32/36] - Loss: 0.595 - - Time: 0.143
torch.Size([1, 5120])
Epoch 47 [33/36] - Loss: 0.579 - - Time: 0.141
torch.Size([1, 5120])
Epoch 47 [34/36] - Loss: 0.615 - - Time: 0.142
torch.Size([1, 3720])
Epoch 47 [35/36] - Loss: 0.686 - - Time: 0.110
torch.Size([1, 5120])
Epoch 48 [0/36] - Loss: 0.560 - - Time: 0.144
torch.Size([1, 5120])
Epoch 48 [1/36] - Loss: 0.538 - - Time: 0.141
torch.Size([1, 5120])
Epoch 48 [2/36] - Loss: 0.560 - - Time: 0.144
torch.Size([1, 5120])
Epoch 48 [3/36] - Loss: 0.583 - - Time: 0.141
torch.Size([1, 5120])
Epoch 48 [4/36] - Loss: 0.627 - - Time: 0.143
torch.Size([1, 5120])
Epoch 48 [5/36] - Loss: 0.571 - - Time: 0.141
torch.Size([1, 5120])
Epoch 48 [6/36] - Loss: 0.580 - - Time: 0.143
torch.Size([1, 5120])
Epoch 48 [7/36] - Loss: 0.548 - - Time: 0.141
torch.Size([1, 5120])
Epoch 48 [8/36] - Lo

torch.Size([1, 5120])
Epoch 51 [6/36] - Loss: 0.558 - - Time: 0.147
torch.Size([1, 5120])
Epoch 51 [7/36] - Loss: 0.489 - - Time: 0.141
torch.Size([1, 5120])
Epoch 51 [8/36] - Loss: 0.576 - - Time: 0.144
torch.Size([1, 5120])
Epoch 51 [9/36] - Loss: 0.585 - - Time: 0.141
torch.Size([1, 5120])
Epoch 51 [10/36] - Loss: 0.525 - - Time: 0.145
torch.Size([1, 5120])
Epoch 51 [11/36] - Loss: 0.550 - - Time: 0.141
torch.Size([1, 5120])
Epoch 51 [12/36] - Loss: 0.515 - - Time: 0.145
torch.Size([1, 5120])
Epoch 51 [13/36] - Loss: 0.510 - - Time: 0.141
torch.Size([1, 5120])
Epoch 51 [14/36] - Loss: 0.575 - - Time: 0.144
torch.Size([1, 5120])
Epoch 51 [15/36] - Loss: 0.573 - - Time: 0.141
torch.Size([1, 5120])
Epoch 51 [16/36] - Loss: 0.554 - - Time: 0.144
torch.Size([1, 5120])
Epoch 51 [17/36] - Loss: 0.586 - - Time: 0.141
torch.Size([1, 5120])
Epoch 51 [18/36] - Loss: 0.614 - - Time: 0.143
torch.Size([1, 5120])
Epoch 51 [19/36] - Loss: 0.534 - - Time: 0.141
torch.Size([1, 5120])
Epoch 51 [20/36]

torch.Size([1, 5120])
Epoch 54 [18/36] - Loss: 0.537 - - Time: 0.152
torch.Size([1, 5120])
Epoch 54 [19/36] - Loss: 0.542 - - Time: 0.166
torch.Size([1, 5120])
Epoch 54 [20/36] - Loss: 0.506 - - Time: 0.153
torch.Size([1, 5120])
Epoch 54 [21/36] - Loss: 0.498 - - Time: 0.161
torch.Size([1, 5120])
Epoch 54 [22/36] - Loss: 0.521 - - Time: 0.152
torch.Size([1, 5120])
Epoch 54 [23/36] - Loss: 0.525 - - Time: 0.157
torch.Size([1, 5120])
Epoch 54 [24/36] - Loss: 0.576 - - Time: 0.161
torch.Size([1, 5120])
Epoch 54 [25/36] - Loss: 0.469 - - Time: 0.142
torch.Size([1, 5120])
Epoch 54 [26/36] - Loss: 0.575 - - Time: 0.148
torch.Size([1, 5120])
Epoch 54 [27/36] - Loss: 0.612 - - Time: 0.143
torch.Size([1, 5120])
Epoch 54 [28/36] - Loss: 0.580 - - Time: 0.149
torch.Size([1, 5120])
Epoch 54 [29/36] - Loss: 0.545 - - Time: 0.151
torch.Size([1, 5120])
Epoch 54 [30/36] - Loss: 0.569 - - Time: 0.146
torch.Size([1, 5120])
Epoch 54 [31/36] - Loss: 0.543 - - Time: 0.142
torch.Size([1, 5120])
Epoch 54 [32

torch.Size([1, 5120])
Epoch 57 [31/36] - Loss: 0.565 - - Time: 0.147
torch.Size([1, 5120])
Epoch 57 [32/36] - Loss: 0.564 - - Time: 0.141
torch.Size([1, 5120])
Epoch 57 [33/36] - Loss: 0.511 - - Time: 0.147
torch.Size([1, 5120])
Epoch 57 [34/36] - Loss: 0.493 - - Time: 0.142
torch.Size([1, 3720])
Epoch 57 [35/36] - Loss: 0.636 - - Time: 0.117
torch.Size([1, 5120])
Epoch 58 [0/36] - Loss: 0.484 - - Time: 0.141
torch.Size([1, 5120])
Epoch 58 [1/36] - Loss: 0.519 - - Time: 0.147
torch.Size([1, 5120])
Epoch 58 [2/36] - Loss: 0.439 - - Time: 0.141
torch.Size([1, 5120])
Epoch 58 [3/36] - Loss: 0.475 - - Time: 0.148
torch.Size([1, 5120])
Epoch 58 [4/36] - Loss: 0.469 - - Time: 0.141
torch.Size([1, 5120])
Epoch 58 [5/36] - Loss: 0.523 - - Time: 0.147
torch.Size([1, 5120])
Epoch 58 [6/36] - Loss: 0.505 - - Time: 0.141
torch.Size([1, 5120])
Epoch 58 [7/36] - Loss: 0.547 - - Time: 0.148
torch.Size([1, 5120])
Epoch 58 [8/36] - Loss: 0.556 - - Time: 0.142
torch.Size([1, 5120])
Epoch 58 [9/36] - Los

In [7]:
preds = likelihood(model(test_x))

In [8]:
torch.mean(torch.pow(preds.mean.reshape(model.num_samples, -1).mean(0) - test_y, 2))

tensor(0.2788, device='cuda:0', grad_fn=<MeanBackward1>)

In [27]:
test_y

tensor([-0.3685, -0.1655,  1.0216,  ...,  2.6740, -0.8084, -0.1655],
       device='cuda:0')

In [29]:
preds.mean.reshape(model.num_samples, -1)

torch.Size([3, 3320])

In [30]:
test_y.shape

torch.Size([3320])