In [3]:
import pickle

import numpy as np
import pandas as pd
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from linear_operator.operators import DiagLinearOperator, LowRankRootLinearOperator, CholLinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky
def safe_inverse(psd):
    """Return the dense inverse of a positive semi-definite matrix.

    The matrix is factorised with ``psd_safe_cholesky`` (called with
    ``jitter = 1e-4``), the factor is wrapped in a ``CholLinearOperator`` and
    that operator's inverse is materialised as a dense tensor.

    Args:
        psd: matrix to invert; it is handed straight to ``psd_safe_cholesky``,
            so it must be something that function accepts
            (presumably a square PSD ``torch.Tensor`` -- TODO confirm).

    Returns:
        The result of ``CholLinearOperator(...).inverse().to_dense()``.
    """
    # `psd_safe_cholesky` comes from `linear_operator.utils.cholesky`; the
    # name was previously used without being imported, so any call raised
    # NameError.
    chol_factor = psd_safe_cholesky(psd, jitter = 1e-4)
    return CholLinearOperator(chol_factor).inverse().to_dense()

def nll(Ytest, mu, var):
    """Mean Gaussian negative log-likelihood of ``Ytest`` under N(mu, var).

    Args:
        Ytest: torch tensor of observed targets; it is moved to the CPU and
            converted to a NumPy array before use.
        mu: predictive means, used in NumPy arithmetic against ``Ytest``.
        var: predictive variances, used in NumPy arithmetic against ``Ytest``.

    Returns:
        The average over all elements of
        ``0.5*log(2*pi*var) + (Ytest - mu)**2 / (2*var)``.
    """
    targets = Ytest.cpu().numpy()
    log_term = .5*np.log(2*np.pi*var)
    quad_term = np.square(targets - mu)/(2*var)
    per_point = log_term + quad_term
    return per_point.mean()


def gen_data(ls, ntrain = 10000, ntest = 2000, device = None):
    """Sample a synthetic 2-D GP regression problem and score the exact GP on it.

    Inputs are drawn uniformly from the unit square, targets are sampled from
    a zero-mean GP with an RBF kernel of lengthscale ``ls`` (plus a 1e-3
    diagonal term) and then perturbed with Gaussian noise of standard
    deviation .0225.  The first ``ntrain`` points form the training set and
    the remaining ``ntest`` points the test set.  The exact GP posterior on
    the test points is computed with the same kernel and summarised by RMSE
    and mean negative log-likelihood.

    Args:
        ls: RBF lengthscale used both for sampling and for prediction.
        ntrain: number of training points.
        ntest: number of held-out test points (previously fixed at 2000).
        device: torch device (or device string) to compute on.  ``None``
            selects CUDA when available and falls back to the CPU otherwise;
            on a CUDA machine this matches the former hard-coded ``.cuda()``.

    Returns:
        ``(10*Xtrain, Ytrain, 10*Xtest, Ytest, (rmse, tnll, pred_mu, var))``.
        Inputs are returned multiplied by 10 while the targets were generated
        on the unscaled inputs, so in the returned coordinates the generating
        RBF lengthscale is ``10*ls``.  ``rmse``, ``pred_mu`` and ``var`` are
        tensors on ``device``; ``tnll`` is the value returned by ``nll``.
        Everything is computed under ``torch.no_grad()``, so the returned
        tensors do not require grad.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    noise_std = .0225   # std of the observation noise added to the GP draw
    jitter = 1e-3       # diagonal term added to the kernel matrices
    input_scale = 10    # factor applied to the inputs that are returned
    ntotal = ntrain + ntest

    # No gradients are needed for data generation; skipping autograd avoids
    # keeping a graph through the (ntrain x ntrain) kernel computations.
    with torch.no_grad():
        # Random draws are made on the CPU and then moved, exactly as before,
        # so the random stream consumed is unchanged.
        X = torch.rand((ntotal, 2)).to(device)

        kernel = RBFKernel().to(device)
        kernel.lengthscale = torch.tensor([[ls]], device=device)

        covar = kernel(X, X) + DiagLinearOperator(torch.ones(ntotal, device=device))*jitter
        dist = MultivariateNormal(torch.zeros(ntotal, device=device), covar)
        Y = dist.sample()
        Y = Y + torch.randn(Y.shape).to(device)*noise_std

        Xtrain, Ytrain = X[:ntrain], Y[:ntrain]
        Xtest, Ytest = X[ntrain:], Y[ntrain:]
        # Drop the references to the full joint covariance before building
        # the training-only factorisation.
        del covar
        del dist

        # NOTE(review): the training covariance carries only the 1e-3 jitter
        # (not noise_std**2), and the predictive variance below adds
        # noise_std**2 + 1e-4 rather than + 1e-3 -- kept numerically as in the
        # original; confirm this is the intended "oracle" model.
        train_covar = kernel(Xtrain, Xtrain) + DiagLinearOperator(torch.ones(Xtrain.shape[0], device=device))*jitter
        train_covar = CholLinearOperator(train_covar.cholesky())
        pred_covar = kernel(Xtest, Xtrain).to_dense()

        # Posterior mean: K(test, train) @ K_train^{-1} @ y_train.
        pred_mu = (pred_covar @ train_covar.solve(Ytrain[:, None])).squeeze()

        # Posterior variance: prior diagonal minus the per-test-point
        # quadratic form k^T K_train^{-1} k, plus the noise terms.
        Ktt = kernel(Xtest, Xtest).diagonal()
        explained = train_covar.inv_quad_logdet(pred_covar.T, reduce_inv_quad = False)[0]
        var = (Ktt - explained).squeeze() + noise_std**2 + 1e-4

        rmse = (Ytest - pred_mu).square().mean().sqrt()
        tnll = nll(Ytest, pred_mu.detach().cpu().numpy(), var.detach().cpu().numpy())

    return input_scale*Xtrain, Ytrain, input_scale*Xtest, Ytest, (rmse, tnll, pred_mu, var)



In [8]:
ntrain = 10000
# Five repetitions (i) for each of three lengthscales (l); every draw is
# written to its own pickle named "<lengthscale>_<repetition>_<ntrain>.pickle".
for i in range(5):
    for l in [.05, .025, .01]:
        # gen_data returns a 5-tuple; keep it as a list for pickling and also
        # unpack it into the notebook-level names.
        store_list = list(gen_data(l, ntrain))
        # NB: despite its name, `covar` is gen_data's fifth return value,
        # the tuple (rmse, tnll, pred_mu, var).
        Xtrain, Ytrain, Xtest, Ytest, covar = store_list

        out_name = '{}_{}_{}.pickle'.format(l, i, ntrain)
        with open(out_name, 'wb') as handle:
            pickle.dump(store_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
        


In [9]:
import pickle
# Re-open one saved dataset (lengthscale .01, repetition 0, 10000 training
# points) and print element 1 of its last entry -- for files written by the
# generation loop that is `tnll`, the second item of gen_data's metrics tuple.
with open("{}_{}_{}.pickle".format(.01, 0, 10000), "rb") as openfile:
    loaded = pickle.load(openfile)
print(loaded[-1][1])


-1.0138246
