In [1]:
import math
import numpy as np
import torch
import gpytorch
import tqdm
import random
import time
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import sys
sys.path.append("../")
sys.path.append("../directionalvi/utils")
sys.path.append("../directionalvi")
import traditional_vi
from RBFKernelDirectionalGrad import RBFKernelDirectionalGrad
#from DirectionalGradVariationalStrategy import DirectionalGradVariationalStrategy
from dfree_directional_vi import train_gp, eval_gp
from metrics import MSE
import testfun
from csv_dataset import csv_dataset

In [2]:
dataset = csv_dataset("../experiments/real_data/kin40k.csv", gradients=False, rescale=True)

[0 1 2 3 4 5 6 7]
[8]


In [3]:
dataset.dim

8

In [4]:
# data parameters
n   = dataset.n
print("n is: ", n)
dim = dataset.dim
print("dims is: ", dim)

# training params
num_inducing = 500
num_directions = 1
minibatch_size = 200
num_epochs = 200

# seed
#torch.random.manual_seed(0)
# use tqdm or just have print statements
tqdm = False
# use data to initialize inducing stuff
inducing_data_initialization = False
# use natural gradients and/or CIQ
use_ngd = False
use_ciq = False
num_contour_quadrature=15
# learning rate
learning_rate_hypers = 0.01
learning_rate_ngd    = 0.1
gamma  = 10.0
#levels = np.array([20,150,300])
#def lr_sched(epoch):
#  a = np.sum(levels > epoch)
#  return (1./gamma)**a
lr_sched = None

n is:  40000
dims is:  8


In [5]:
# train-test split
n_train = int(0.8*dataset.n)
n_test  = n - n_train
train_dataset,test_dataset = torch.utils.data.random_split(dataset,[n_train,n_test])

In [6]:
#loaders
train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=n_test, shuffle=False)

In [7]:
test_y = [item[1] for item in test_loader]
test_x = [item[0] for item in test_loader]

# D-Free Grad SVGP

In [8]:
# train
print("\n\n---DirectionalGradVGP---")
print(f"Start training with {n} trainig data of dim {dim}")
print(f"VI setups: {num_inducing} inducing points, {num_directions} inducing directions")
args={"verbose":True}
t1 = time.time()	
model,likelihood = train_gp(train_dataset,
                      num_inducing=num_inducing,
                      num_directions=num_directions,
                      minibatch_size = minibatch_size,
                      minibatch_dim = num_directions,
                      num_epochs =num_epochs, 
                      learning_rate_hypers=learning_rate_hypers,
                      learning_rate_ngd=learning_rate_ngd,
                      inducing_data_initialization=inducing_data_initialization,
                      use_ngd = use_ngd,
                      use_ciq = use_ciq,
                      mll_type="PLL",
                      lr_sched=lr_sched,
                      num_contour_quadrature=num_contour_quadrature,
                      tqdm=tqdm,**args
                      )
t2 = time.time()	

# save the model
# torch.save(model.state_dict(), "../data/test_dvi_basic.model")

# test
means, variances = eval_gp( test_dataset,model,likelihood,
                            num_directions=num_directions,
                            minibatch_size=n_test,
                            minibatch_dim=num_directions)
t3 = time.time()	




---DirectionalGradVGP---
Start training with 40000 trainig data of dim 8
VI setups: 500 inducing points, 1 inducing directions
All parameters to learn:
      variational_strategy.inducing_points
      torch.Size([500, 8])
      variational_strategy.inducing_directions
      torch.Size([500, 8])
      variational_strategy._variational_distribution.variational_mean
      torch.Size([1000])
      variational_strategy._variational_distribution.chol_variational_covar
      torch.Size([1000, 1000])
      mean_module.constant
      torch.Size([1])
      covar_module.raw_outputscale
      torch.Size([])
      covar_module.base_kernel.raw_lengthscale
      torch.Size([1, 1])
      noise_covar.raw_noise
      torch.Size([1])
Total number of parameters:  1009004.0
Epoch: 0; total_step: 0, loss: 1.4937641155393457, nll: 1.3879609849590495
Epoch: 0; total_step: 50, loss: 1.146034675818015, nll: 1.0122134678864874
Epoch: 0; total_step: 100, loss: 0.8841277832947706, nll: 0.7472111924479976
Epoch: 

Epoch: 28; total_step: 4600, loss: -0.4084235052841079, nll: -0.45692333944777275
Epoch: 29; total_step: 4650, loss: -0.4868173287058397, nll: -0.5105758356137842
Epoch: 29; total_step: 4700, loss: -0.36521306824051664, nll: -0.4350045354511733
Epoch: 29; total_step: 4750, loss: -0.3676301780937219, nll: -0.4708047826794263
Epoch: 30; total_step: 4800, loss: -0.44542184498265086, nll: -0.4885635549731722
Epoch: 30; total_step: 4850, loss: -0.380440308423909, nll: -0.47828205447565353
Epoch: 30; total_step: 4900, loss: -0.4944713524784389, nll: -0.5931832450262882
Epoch: 30; total_step: 4950, loss: -0.4411327370007445, nll: -0.4865207210783665
Epoch: 31; total_step: 5000, loss: -0.4344551964810575, nll: -0.4752987097855476
Epoch: 31; total_step: 5050, loss: -0.3763806388879524, nll: -0.4412685799474025
Epoch: 31; total_step: 5100, loss: -0.5250862575697238, nll: -0.55174239703253
Epoch: 32; total_step: 5150, loss: -0.4275291102913739, nll: -0.4636718112276258
Epoch: 32; total_step: 5200

Epoch: 60; total_step: 9650, loss: -0.5654755696253148, nll: -0.5446416916422145
Epoch: 60; total_step: 9700, loss: -0.5640641903931349, nll: -0.6260706267720311
Epoch: 60; total_step: 9750, loss: -0.44538078826948835, nll: -0.5470982342158239
Epoch: 61; total_step: 9800, loss: -0.4134069331058521, nll: -0.39899907363999904
Epoch: 61; total_step: 9850, loss: -0.5788842962876725, nll: -0.6188076329846502
Epoch: 61; total_step: 9900, loss: -0.5027422089544268, nll: -0.5991511810406834
Epoch: 62; total_step: 9950, loss: -0.566616749438346, nll: -0.596072199980356
Epoch: 62; total_step: 10000, loss: -0.44261893631617555, nll: -0.6095299132278771
Epoch: 62; total_step: 10050, loss: -0.6151193760403204, nll: -0.689834718109649
Epoch: 63; total_step: 10100, loss: -0.5140510851691569, nll: -0.597732698512701
Epoch: 63; total_step: 10150, loss: -0.524332052929249, nll: -0.5813292268421238
Epoch: 63; total_step: 10200, loss: -0.3840585705500631, nll: -0.432563421197799
Epoch: 64; total_step: 102

Epoch: 91; total_step: 14700, loss: -0.5751105745438118, nll: -0.627194747480551
Epoch: 92; total_step: 14750, loss: -0.5886620708586077, nll: -0.618731468721844
Epoch: 92; total_step: 14800, loss: -0.5696022992318486, nll: -0.643434154246446
Epoch: 92; total_step: 14850, loss: -0.4236751362731256, nll: -0.4498032643109754
Epoch: 93; total_step: 14900, loss: -0.5847459060403492, nll: -0.6523676095738783
Epoch: 93; total_step: 14950, loss: -0.6342318195171321, nll: -0.6888296948343512
Epoch: 93; total_step: 15000, loss: -0.48688023043854867, nll: -0.5052936672867493
Epoch: 94; total_step: 15050, loss: -0.539677014843163, nll: -0.5299082385913658
Epoch: 94; total_step: 15100, loss: -0.5360640636604674, nll: -0.649718271611817
Epoch: 94; total_step: 15150, loss: -0.6748946216484524, nll: -0.7813500983400581
Epoch: 95; total_step: 15200, loss: -0.5476766849434886, nll: -0.6748731393614801
Epoch: 95; total_step: 15250, loss: -0.3990413945884265, nll: -0.4587870340710352
Epoch: 95; total_ste

Epoch: 123; total_step: 19700, loss: -0.6745524233383022, nll: -0.6778381237001448
Epoch: 123; total_step: 19750, loss: -0.5872114106395179, nll: -0.6258329335076905
Epoch: 123; total_step: 19800, loss: -0.6591337217102294, nll: -0.7505307305033526
Epoch: 124; total_step: 19850, loss: -0.671149161599478, nll: -0.7696285417973233
Epoch: 124; total_step: 19900, loss: -0.5374264757396842, nll: -0.5103418736276152
Epoch: 124; total_step: 19950, loss: -0.5939926456415354, nll: -0.6468479097776852
Epoch: 125; total_step: 20000, loss: -0.5984078006776054, nll: -0.6774681675525548
Epoch: 125; total_step: 20050, loss: -0.6320286731960394, nll: -0.6907380835599162
Epoch: 125; total_step: 20100, loss: -0.5004854142834769, nll: -0.6271776480014335
Epoch: 125; total_step: 20150, loss: -0.438923684400071, nll: -0.603852547092814
Epoch: 126; total_step: 20200, loss: -0.6278761307212539, nll: -0.7110532934778864
Epoch: 126; total_step: 20250, loss: -0.5746170694044741, nll: -0.661725882511394
Epoch: 1

Epoch: 154; total_step: 24650, loss: -0.601682794512223, nll: -0.7780470207761035
Epoch: 154; total_step: 24700, loss: -0.5895118230092108, nll: -0.6716892254286533
Epoch: 154; total_step: 24750, loss: -0.38125723257475996, nll: -0.40089707317255474
Epoch: 155; total_step: 24800, loss: -0.6058570564714045, nll: -0.7188390327105703
Epoch: 155; total_step: 24850, loss: -0.568592915646357, nll: -0.7387262509710404
Epoch: 155; total_step: 24900, loss: -0.6191742693460535, nll: -0.7547086801073972
Epoch: 155; total_step: 24950, loss: -0.7201803538430172, nll: -0.8504727925582622
Epoch: 156; total_step: 25000, loss: -0.6038008383521454, nll: -0.6438741336452165
Epoch: 156; total_step: 25050, loss: -0.6380658918252601, nll: -0.7074048034911299
Epoch: 156; total_step: 25100, loss: -0.5560044867218142, nll: -0.6919696640938905
Epoch: 157; total_step: 25150, loss: -0.570206845085784, nll: -0.7647542310905615
Epoch: 157; total_step: 25200, loss: -0.6220611074950869, nll: -0.7976033929958163
Epoch

Epoch: 185; total_step: 29600, loss: -0.6017807160501042, nll: -0.7529807693102533
Epoch: 185; total_step: 29650, loss: -0.5098247841873167, nll: -0.5597056423734185
Epoch: 185; total_step: 29700, loss: -0.5283883980253953, nll: -0.5249657149813255
Epoch: 185; total_step: 29750, loss: -0.5775649305029864, nll: -0.587044603681556
Epoch: 186; total_step: 29800, loss: -0.49590190449943305, nll: -0.6321459581812544
Epoch: 186; total_step: 29850, loss: -0.6104715280091942, nll: -0.7417684139927403
Epoch: 186; total_step: 29900, loss: -0.5298285234381578, nll: -0.6367453449503284
Epoch: 187; total_step: 29950, loss: -0.557247354885088, nll: -0.7062687645444652
Epoch: 187; total_step: 30000, loss: -0.5688101543573884, nll: -0.7575597191133642
Epoch: 187; total_step: 30050, loss: -0.5872779158815091, nll: -0.6250809292947564
Epoch: 188; total_step: 30100, loss: -0.603461143841553, nll: -0.6861148195272321
Epoch: 188; total_step: 30150, loss: -0.6298638893919118, nll: -0.7834587443389444
Epoch:

In [9]:

# compute MSE
#test_y = test_y.cpu()
test_mse = MSE(test_y[0],means)
# compute mean negative predictive density
test_nll = -torch.distributions.Normal(means, variances.sqrt()).log_prob(test_y[0]).mean()
print(f"At {n_test} testing points, MSE: {test_mse:.4e}, nll: {test_nll:.4e}.")
print(f"Training time: {(t2-t1):.2f} sec, testing time: {(t3-t2):.2f} sec")

#plot=1
#if plot == 1:
#    from mpl_toolkits.mplot3d import axes3d
#    import matplotlib.pyplot as plt
#    fig = plt.figure(figsize=(12,6))
#    ax = fig.add_subplot(111, projection='3d')
#    ax.scatter(test_x[0][:,0],test_x[:,1],test_y, color='k')
#    ax.scatter(test_x[0][:,0],test_x[:,1],means, color='b')
#    plt.title("f(x,y) variational fit; actual curve is black, variational is blue")
#    plt.show()

At 8000 testing points, MSE: 6.0187e-02, nll: -7.0207e-01.
Training time: 15444.24 sec, testing time: 7.04 sec


In [10]:
# training params
#num_inducing = 50
#num_directions = 6
#minibatch_size = 200
#num_epochs = 100


# 2 directions
#At 104 testing points, MSE: 2.9133e+00, nll: 3.3945e+00. 
# 3 directions
#At 104 testing points, MSE: 2.9455e+00, nll: 3.3617e+00.
#Training time: 70.29 sec, testing time: 0.10 sec
# 4 directions
#At 104 testing points, MSE: 2.9810e+00, nll: 3.0743e+00.
#Training time: 57.68 sec, testing time: 0.08 sec
# 5 directions
#At 104 testing points, MSE: 2.9440e+00, nll: 3.6124e+00.
#Training time: 104.46 sec, testing time: 0.12 sec
# 6 directions
#At 104 testing points, MSE: 2.9795e+00, nll: 3.1092e+00.
#Training time: 127.73 sec, testing time: 0.10 sec
# 7 directions
#At 104 testing points, MSE: 2.9272e+00, nll: 3.6537e+00.
#Training time: 153.38 sec, testing time: 0.12 sec
# 8 directions
#At 104 testing points, MSE: 2.9503e+00, nll: 3.3300e+00.
#Training time: 173.86 sec, testing time: 0.15 sec
# 9 directions
# 10 directions

# Traditional SVGP

In [11]:
model_t,likelihood_t = traditional_vi.train_gp(train_dataset,dim,
                                                   num_inducing=2*num_inducing,
                                                   minibatch_size=minibatch_size,
                                                   num_epochs=num_epochs,
                                                   mll_type="PLL",
                                                   use_ngd=use_ngd, use_ciq=use_ciq,
                                                   learning_rate_hypers=learning_rate_hypers,
                                                   learning_rate_ngd=learning_rate_ngd,
                                                   lr_sched=lr_sched,
                                                   num_contour_quadrature=num_contour_quadrature,gamma=gamma, verbose=True)

All parameters to learn:
      variational_strategy.inducing_points
      torch.Size([1000, 8])
      variational_strategy._variational_distribution.variational_mean
      torch.Size([1000])
      variational_strategy._variational_distribution.chol_variational_covar
      torch.Size([1000, 1000])
      mean_module.constant
      torch.Size([1])
      covar_module.raw_outputscale
      torch.Size([])
      covar_module.base_kernel.raw_lengthscale
      torch.Size([1, 1])
      noise_covar.raw_noise
      torch.Size([1])
Total number of parameters:  1009004.0
Using PLL
Epoch: 0; total_step: 0, loss: 1.5019231833305076, nll: 1.40762204300697
Epoch: 0; total_step: 50, loss: 1.1449864120182502, nll: 0.9683286665793658
Epoch: 0; total_step: 100, loss: 0.9352204341762527, nll: 0.7588101250144059
Epoch: 0; total_step: 150, loss: 0.7240399202380806, nll: 0.5472806152254704
Epoch: 1; total_step: 200, loss: 0.5095031497091773, nll: 0.33015579104268356
Epoch: 1; total_step: 250, loss: 0.4543547592

Epoch: 29; total_step: 4700, loss: -0.2862052268465981, nll: -0.4390834565948192
Epoch: 29; total_step: 4750, loss: -0.22905732452644828, nll: -0.3809189584483803
Epoch: 30; total_step: 4800, loss: -0.30730179713830097, nll: -0.4605158369528219
Epoch: 30; total_step: 4850, loss: -0.05073220191716271, nll: -0.20360984324162967
Epoch: 30; total_step: 4900, loss: -0.188563724777414, nll: -0.3419599694490268
Epoch: 30; total_step: 4950, loss: -0.3461752115576148, nll: -0.5001883630729745
Epoch: 31; total_step: 5000, loss: -0.13305015904013084, nll: -0.2872292348280071
Epoch: 31; total_step: 5050, loss: -0.1663833005129018, nll: -0.3200550863239626
Epoch: 31; total_step: 5100, loss: -0.23705321637692706, nll: -0.39191368311682956
Epoch: 32; total_step: 5150, loss: -0.15753604404694388, nll: -0.3121562821575694
Epoch: 32; total_step: 5200, loss: -0.1982974671876867, nll: -0.35384078086305293
Epoch: 32; total_step: 5250, loss: -0.29983212574614254, nll: -0.4550510156078414
Epoch: 33; total_st

Epoch: 60; total_step: 9700, loss: -0.2979448898599959, nll: -0.464230583431302
Epoch: 60; total_step: 9750, loss: -0.27572236154908836, nll: -0.44264607898183633
Epoch: 61; total_step: 9800, loss: -0.24605156053702687, nll: -0.41323720342456943
Epoch: 61; total_step: 9850, loss: -0.31510348175083025, nll: -0.4819528277544633
Epoch: 61; total_step: 9900, loss: -0.24306938926964022, nll: -0.4101973625020115
Epoch: 62; total_step: 9950, loss: -0.2951945795010646, nll: -0.4627842506352764
Epoch: 62; total_step: 10000, loss: -0.1689184203480299, nll: -0.3362806204535187
Epoch: 62; total_step: 10050, loss: -0.20587363520056073, nll: -0.3729093269515934
Epoch: 63; total_step: 10100, loss: -0.2645864529687218, nll: -0.43162057009495086
Epoch: 63; total_step: 10150, loss: -0.3109831754404695, nll: -0.47850068543809443
Epoch: 63; total_step: 10200, loss: -0.2735407582013616, nll: -0.44098957238192926
Epoch: 64; total_step: 10250, loss: -0.2116790090075631, nll: -0.3792665776108916
Epoch: 64; to

Epoch: 91; total_step: 14650, loss: -0.20859510175887186, nll: -0.3790873388973588
Epoch: 91; total_step: 14700, loss: -0.1250944253478885, nll: -0.2945468166983193
Epoch: 92; total_step: 14750, loss: -0.2500346828521039, nll: -0.41956295718257325
Epoch: 92; total_step: 14800, loss: -0.18752447381984266, nll: -0.35726963969230313
Epoch: 92; total_step: 14850, loss: -0.25096808987745844, nll: -0.420797705691897
Epoch: 93; total_step: 14900, loss: -0.2509137115363535, nll: -0.4209925327730673
Epoch: 93; total_step: 14950, loss: -0.2414312918276657, nll: -0.41282764526994953
Epoch: 93; total_step: 15000, loss: -0.25234373320142367, nll: -0.42309557520495406
Epoch: 94; total_step: 15050, loss: -0.2948926612207818, nll: -0.4649751684336467
Epoch: 94; total_step: 15100, loss: -0.10209624389798727, nll: -0.2718182187567604
Epoch: 94; total_step: 15150, loss: -0.346484689434545, nll: -0.5165826472996737
Epoch: 95; total_step: 15200, loss: -0.3101962592815029, nll: -0.48031659794801307
Epoch: 9

Epoch: 122; total_step: 19600, loss: -0.3479023972976737, nll: -0.520464135242398
Epoch: 122; total_step: 19650, loss: -0.18099382234883138, nll: -0.3536115141092539
Epoch: 123; total_step: 19700, loss: -0.1828601255345994, nll: -0.3549065800025965
Epoch: 123; total_step: 19750, loss: -0.29186881204257065, nll: -0.4641164510539142
Epoch: 123; total_step: 19800, loss: -0.2315898244880335, nll: -0.40339693013896477
Epoch: 124; total_step: 19850, loss: -0.2772117171710331, nll: -0.4488012145484108
Epoch: 124; total_step: 19900, loss: -0.2489886388053268, nll: -0.4205119698451765
Epoch: 124; total_step: 19950, loss: -0.31800302378830736, nll: -0.490063421463316
Epoch: 125; total_step: 20000, loss: -0.25282933362455423, nll: -0.42411307348016064
Epoch: 125; total_step: 20050, loss: -0.17813552073397476, nll: -0.34948874537614627
Epoch: 125; total_step: 20100, loss: -0.26340998098013785, nll: -0.43441665113577527
Epoch: 125; total_step: 20150, loss: -0.25788753106312023, nll: -0.429775482266

Epoch: 153; total_step: 24500, loss: -0.2124617799594069, nll: -0.3854356725581313
Epoch: 153; total_step: 24550, loss: -0.3452807470400656, nll: -0.5186390572253707
Epoch: 153; total_step: 24600, loss: -0.25982314762017805, nll: -0.4335523099627646
Epoch: 154; total_step: 24650, loss: -0.1195830213578879, nll: -0.2927634129402502
Epoch: 154; total_step: 24700, loss: -0.2268255048996175, nll: -0.40048798254959694
Epoch: 154; total_step: 24750, loss: -0.29180517847270265, nll: -0.4648133932380161
Epoch: 155; total_step: 24800, loss: -0.33230440991669996, nll: -0.5053943991315621
Epoch: 155; total_step: 24850, loss: -0.22242982846826864, nll: -0.3953079197691908
Epoch: 155; total_step: 24900, loss: -0.29624060913239264, nll: -0.4694391850033972
Epoch: 155; total_step: 24950, loss: -0.2511396078146639, nll: -0.4240161280555418
Epoch: 156; total_step: 25000, loss: -0.21511108299539922, nll: -0.3881168520402729
Epoch: 156; total_step: 25050, loss: -0.13338443623462726, nll: -0.3060882706360

Epoch: 183; total_step: 29400, loss: -0.24561835347848715, nll: -0.4189015014952933
Epoch: 184; total_step: 29450, loss: -0.16406774019055453, nll: -0.3373960372305364
Epoch: 184; total_step: 29500, loss: -0.311621269083484, nll: -0.4854999930515717
Epoch: 184; total_step: 29550, loss: -0.33011952415203316, nll: -0.5035287351813637
Epoch: 185; total_step: 29600, loss: -0.3209167396653334, nll: -0.49443289216854786
Epoch: 185; total_step: 29650, loss: -0.24623374183752406, nll: -0.41970706944232594
Epoch: 185; total_step: 29700, loss: -0.2670853039093626, nll: -0.43981467711299077
Epoch: 185; total_step: 29750, loss: -0.3190315168420275, nll: -0.49220379502631617
Epoch: 186; total_step: 29800, loss: -0.29507291240269157, nll: -0.4688979164420487
Epoch: 186; total_step: 29850, loss: -0.2332880137106325, nll: -0.40771868483771406
Epoch: 186; total_step: 29900, loss: -0.2667445461626785, nll: -0.44017384535069554
Epoch: 187; total_step: 29950, loss: -0.2865079577377195, nll: -0.46034214019

In [12]:
means_t, variances_t = traditional_vi.eval_gp(test_dataset, model_t, likelihood_t, minibatch_size=n_test)

In [13]:
# compute MSE
#test_y = test_y.cpu()
test_mse = MSE(test_y[0],means_t)
# compute mean negative predictive density
test_nll = -torch.distributions.Normal(means_t, variances_t.sqrt()).log_prob(test_y[0]).mean()
print(f"At {n_test} testing points, MSE: {test_mse:.4e}, nll: {test_nll:.4e}.")
print(f"Training time: {(t2-t1):.2f} sec, testing time: {(t3-t2):.2f} sec")

At 8000 testing points, MSE: 8.2751e-02, nll: -4.3636e-01.
Training time: 15444.24 sec, testing time: 7.04 sec


In [14]:
#
# At 8000 testing points, MSE: 3.8732e-02, nll: -1.8686e-01.
# At 8000 testing points, MSE: 6.8093e-02, nll: 1.2315e-01.