In [1]:
import math
import numpy as np
import torch
import gpytorch
import tqdm
import random
import time
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import sys
sys.path.append("../")
sys.path.append("../directionalvi/utils")
sys.path.append("../directionalvi")
import traditional_vi
from RBFKernelDirectionalGrad import RBFKernelDirectionalGrad
#from DirectionalGradVariationalStrategy import DirectionalGradVariationalStrategy
from dfree_directional_vi import train_gp, eval_gp
from metrics import MSE
import testfun
from csv_dataset import csv_dataset

In [2]:
dataset = csv_dataset("../experiments/real_data/CASP.csv", gradients=False, rescale=True)

[0 1 2 3 4 5 6 7 8]
[9]


In [3]:
dataset.dim

9

In [4]:
# data parameters
n   = dataset.n
print("n is: ", n)
dim = dataset.dim
print("dims is: ", dim)

# training params
num_inducing = 500
num_directions = 1
minibatch_size = 200
num_epochs = 100

# seed
torch.random.manual_seed(0)
# use tqdm or just have print statements
tqdm = False
# use data to initialize inducing stuff
inducing_data_initialization = False
# use natural gradients and/or CIQ
use_ngd = False
use_ciq = False
num_contour_quadrature=15
# learning rate
learning_rate_hypers = 0.01
learning_rate_ngd    = 0.1
gamma  = 10.0
#levels = np.array([20,150,300])
#def lr_sched(epoch):
#  a = np.sum(levels > epoch)
#  return (1./gamma)**a
lr_sched = None

n is:  45730
dims is:  9


In [5]:
# train-test split
n_train = int(0.8*dataset.n)
n_test  = n - n_train
train_dataset,test_dataset = torch.utils.data.random_split(dataset,[n_train,n_test])

In [6]:
#loaders
train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=n_test, shuffle=False)

In [7]:
test_y = [item[1] for item in test_loader]
test_x = [item[0] for item in test_loader]

# D-Free Grad SVGP

In [8]:
# train
print("\n\n---DirectionalGradVGP---")
print(f"Start training with {n} trainig data of dim {dim}")
print(f"VI setups: {num_inducing} inducing points, {num_directions} inducing directions")
args={"verbose":True}
t1 = time.time()	
model,likelihood = train_gp(train_dataset,
                      num_inducing=num_inducing,
                      num_directions=num_directions,
                      minibatch_size = minibatch_size,
                      minibatch_dim = num_directions,
                      num_epochs =num_epochs, 
                      learning_rate_hypers=learning_rate_hypers,
                      learning_rate_ngd=learning_rate_ngd,
                      inducing_data_initialization=inducing_data_initialization,
                      use_ngd = use_ngd,
                      use_ciq = use_ciq,
                      lr_sched=lr_sched,
                      num_contour_quadrature=num_contour_quadrature,
                      tqdm=tqdm,**args
                      )
t2 = time.time()	

# save the model
# torch.save(model.state_dict(), "../data/test_dvi_basic.model")

# test
means, variances = eval_gp( test_dataset,model,likelihood,
                            num_directions=num_directions,
                            minibatch_size=n_test,
                            minibatch_dim=num_directions)
t3 = time.time()	




---DirectionalGradVGP---
Start training with 45730 trainig data of dim 9
VI setups: 500 inducing points, 1 inducing directions
All parameters to learn:
      variational_strategy.inducing_points
      torch.Size([500, 9])
      variational_strategy.inducing_directions
      torch.Size([500, 9])
      variational_strategy._variational_distribution.variational_mean
      torch.Size([1000])
      variational_strategy._variational_distribution.chol_variational_covar
      torch.Size([1000, 1000])
      mean_module.constant
      torch.Size([1])
      covar_module.raw_outputscale
      torch.Size([])
      covar_module.base_kernel.raw_lengthscale
      torch.Size([1, 1])
      noise_covar.raw_noise
      torch.Size([1])
Total number of parameters:  1010004.0
Epoch: 0; total_step: 0, loss: 2.5546049320523667, nll: 1.5186942597074
Epoch: 0; total_step: 50, loss: 1.8084012538245064, nll: 1.2835976646088023
Epoch: 0; total_step: 100, loss: 1.738893468901481, nll: 1.2476965661885733
Epoch: 0; 

Epoch: 26; total_step: 4800, loss: 1.6923812909118132, nll: 1.1830591796899028
Epoch: 26; total_step: 4850, loss: 1.7279505722410418, nll: 1.2938393524526213
Epoch: 26; total_step: 4900, loss: 1.6560486002179808, nll: 1.108975642179913
Epoch: 27; total_step: 4950, loss: 1.7134121522834267, nll: 1.1572707121263088
Epoch: 27; total_step: 5000, loss: 1.724345091635982, nll: 1.2428133554393332
Epoch: 27; total_step: 5050, loss: 1.5965660225373908, nll: 1.1202069650004405
Epoch: 27; total_step: 5100, loss: 1.6377579739756003, nll: 1.1053572042307842
Epoch: 28; total_step: 5150, loss: 1.733804238646575, nll: 1.2491318485294198
Epoch: 28; total_step: 5200, loss: 1.5890153847511699, nll: 1.0759993921752327
Epoch: 28; total_step: 5250, loss: 1.6236911549256186, nll: 1.1072626290627319
Epoch: 28; total_step: 5300, loss: 1.647873352489607, nll: 1.1629221651485795
Epoch: 29; total_step: 5350, loss: 1.6677832037970548, nll: 1.179505137547505
Epoch: 29; total_step: 5400, loss: 1.6269227297572149, nl

Epoch: 54; total_step: 10050, loss: 1.6706827444833907, nll: 1.1496090401682362
Epoch: 55; total_step: 10100, loss: 1.7015126660833542, nll: 1.1503766073143096
Epoch: 55; total_step: 10150, loss: 1.7066115151270884, nll: 1.2539013959640177
Epoch: 55; total_step: 10200, loss: 1.6827154900518873, nll: 1.187986188303058
Epoch: 56; total_step: 10250, loss: 1.6768109421786275, nll: 1.2028768930744276
Epoch: 56; total_step: 10300, loss: 1.7123309400587114, nll: 1.1878174568799214
Epoch: 56; total_step: 10350, loss: 1.6833170567212454, nll: 1.0950377601143624
Epoch: 56; total_step: 10400, loss: 1.6240748514684578, nll: 1.1179898183492873
Epoch: 57; total_step: 10450, loss: 1.658434227334724, nll: 1.1676971981896855
Epoch: 57; total_step: 10500, loss: 1.6558806747035442, nll: 1.1274239779728714
Epoch: 57; total_step: 10550, loss: 1.709263563917849, nll: 1.1446303812432084
Epoch: 57; total_step: 10600, loss: 1.6681375438930304, nll: 1.2503702576294424
Epoch: 58; total_step: 10650, loss: 1.67722

Epoch: 83; total_step: 15200, loss: 1.6839627063067715, nll: 1.2024427466861944
Epoch: 83; total_step: 15250, loss: 1.5951114637284272, nll: 1.0958603452108417
Epoch: 83; total_step: 15300, loss: 1.594004563265631, nll: 1.0991652003384216
Epoch: 83; total_step: 15350, loss: 1.7034791582116935, nll: 1.308489066748954
Epoch: 84; total_step: 15400, loss: 1.6462517459709358, nll: 1.1236176296845128
Epoch: 84; total_step: 15450, loss: 1.6515469276626118, nll: 1.082604549391296
Epoch: 84; total_step: 15500, loss: 1.7322496070668216, nll: 1.2874041766892328
Epoch: 84; total_step: 15550, loss: 1.6460587487235818, nll: 1.0920256993713675
Epoch: 85; total_step: 15600, loss: 1.7014415017840545, nll: 1.2074432349266124
Epoch: 85; total_step: 15650, loss: 1.749313536146043, nll: 1.2615194371492788
Epoch: 85; total_step: 15700, loss: 1.6894229305707764, nll: 1.199971279088559
Epoch: 86; total_step: 15750, loss: 1.6434336571989494, nll: 1.1923088678283116
Epoch: 86; total_step: 15800, loss: 1.5647919

In [9]:

# compute MSE
#test_y = test_y.cpu()
test_mse = MSE(test_y[0],means)
# compute mean negative predictive density
test_nll = -torch.distributions.Normal(means, variances.sqrt()).log_prob(test_y[0]).mean()
print(f"At {n_test} testing points, MSE: {test_mse:.4e}, nll: {test_nll:.4e}.")
print(f"Training time: {(t2-t1):.2f} sec, testing time: {(t3-t2):.2f} sec")

#plot=1
#if plot == 1:
#    from mpl_toolkits.mplot3d import axes3d
#    import matplotlib.pyplot as plt
#    fig = plt.figure(figsize=(12,6))
#    ax = fig.add_subplot(111, projection='3d')
#    ax.scatter(test_x[0][:,0],test_x[:,1],test_y, color='k')
#    ax.scatter(test_x[0][:,0],test_x[:,1],means, color='b')
#    plt.title("f(x,y) variational fit; actual curve is black, variational is blue")
#    plt.show()

At 9146 testing points, MSE: 5.9555e-01, nll: 1.1599e+00.
Training time: 8148.59 sec, testing time: 6.37 sec


In [10]:
# training params
#num_inducing = 50
#num_directions = 6
#minibatch_size = 200
#num_epochs = 100


# 2 directions
#At 104 testing points, MSE: 2.9133e+00, nll: 3.3945e+00. 
# 3 directions
#At 104 testing points, MSE: 2.9455e+00, nll: 3.3617e+00.
#Training time: 70.29 sec, testing time: 0.10 sec
# 4 directions
#At 104 testing points, MSE: 2.9810e+00, nll: 3.0743e+00.
#Training time: 57.68 sec, testing time: 0.08 sec
# 5 directions
#At 104 testing points, MSE: 2.9440e+00, nll: 3.6124e+00.
#Training time: 104.46 sec, testing time: 0.12 sec
# 6 directions
#At 104 testing points, MSE: 2.9795e+00, nll: 3.1092e+00.
#Training time: 127.73 sec, testing time: 0.10 sec
# 7 directions
#At 104 testing points, MSE: 2.9272e+00, nll: 3.6537e+00.
#Training time: 153.38 sec, testing time: 0.12 sec
# 8 directions
#At 104 testing points, MSE: 2.9503e+00, nll: 3.3300e+00.
#Training time: 173.86 sec, testing time: 0.15 sec
# 9 directions
# 10 directions

# Traditional SVGP

In [None]:
model_t,likelihood_t = traditional_vi.train_gp(train_dataset,dim,
                                                   num_inducing=num_inducing,
                                                   minibatch_size=minibatch_size,
                                                   num_epochs=num_epochs,
                                                   use_ngd=use_ngd, use_ciq=use_ciq,
                                                   learning_rate_hypers=learning_rate_hypers,
                                                   learning_rate_ngd=learning_rate_ngd,
                                                   lr_sched=lr_sched,
                                                   num_contour_quadrature=num_contour_quadrature,gamma=gamma, verbose=True)

All parameters to learn:
      variational_strategy.inducing_points
      torch.Size([500, 9])
      variational_strategy._variational_distribution.variational_mean
      torch.Size([500])
      variational_strategy._variational_distribution.chol_variational_covar
      torch.Size([500, 500])
      mean_module.constant
      torch.Size([1])
      covar_module.raw_outputscale
      torch.Size([])
      covar_module.base_kernel.raw_lengthscale
      torch.Size([1, 1])
      noise_covar.raw_noise
      torch.Size([1])
Total number of parameters:  255004.0
Using ELBO
Epoch: 0; total_step: 0, loss: 2.4442716960932573, nll: 1.4365865539417353
Epoch: 0; total_step: 50, loss: 1.7168978216076511, nll: 1.2087576877397188
Epoch: 0; total_step: 100, loss: 1.7202699924596345, nll: 1.2138562715051886
Epoch: 0; total_step: 150, loss: 1.8007331416387604, nll: 1.292580674154088
Epoch: 1; total_step: 200, loss: 1.7723763100538963, nll: 1.2660390302183935
Epoch: 1; total_step: 250, loss: 1.70648991986953

Epoch: 26; total_step: 4900, loss: 1.7084710510936332, nll: 1.1939964991634509
Epoch: 27; total_step: 4950, loss: 1.6579208110696018, nll: 1.1421964844336747
Epoch: 27; total_step: 5000, loss: 1.642699350379093, nll: 1.1263757967483705
Epoch: 27; total_step: 5050, loss: 1.7020007229416558, nll: 1.1867606771095718
Epoch: 27; total_step: 5100, loss: 1.698396390565223, nll: 1.176132911077666
Epoch: 28; total_step: 5150, loss: 1.6424828448703577, nll: 1.127448315729192
Epoch: 28; total_step: 5200, loss: 1.7217072611572355, nll: 1.2061330396904466
Epoch: 28; total_step: 5250, loss: 1.7334169749150419, nll: 1.218828464829425
Epoch: 28; total_step: 5300, loss: 1.5877762131213442, nll: 1.0748802074089843
Epoch: 29; total_step: 5350, loss: 1.789768189761249, nll: 1.273678032564406
Epoch: 29; total_step: 5400, loss: 1.6807936307250038, nll: 1.1652577370640365
Epoch: 29; total_step: 5450, loss: 1.659871177619687, nll: 1.1449087106979219
Epoch: 30; total_step: 5500, loss: 1.6562798060202502, nll: 

Epoch: 55; total_step: 10150, loss: 1.6620246935580198, nll: 1.1452536309040147
Epoch: 55; total_step: 10200, loss: 1.6637382857324485, nll: 1.1460438906403974
Epoch: 56; total_step: 10250, loss: 1.6438613689799308, nll: 1.125650629461636
Epoch: 56; total_step: 10300, loss: 1.684403183768155, nll: 1.1669278098806792
Epoch: 56; total_step: 10350, loss: 1.63494902458216, nll: 1.1188602060576018
Epoch: 56; total_step: 10400, loss: 1.7109726446789342, nll: 1.1924193313613336
Epoch: 57; total_step: 10450, loss: 1.6606755793972514, nll: 1.1426327701473746
Epoch: 57; total_step: 10500, loss: 1.6793099510861738, nll: 1.1584034498697735
Epoch: 57; total_step: 10550, loss: 1.64300425846966, nll: 1.1249023694782325
Epoch: 57; total_step: 10600, loss: 1.6989660245047702, nll: 1.1811847467103447
Epoch: 58; total_step: 10650, loss: 1.7946042352718155, nll: 1.2681708874363893
Epoch: 58; total_step: 10700, loss: 1.7341551243724902, nll: 1.2162171247054951
Epoch: 58; total_step: 10750, loss: 1.65202566

In [None]:
means_t, variances_t = traditional_vi.eval_gp(test_dataset, model_t, likelihood_t, minibatch_size=n_test)

In [None]:
# compute MSE
#test_y = test_y.cpu()
test_mse = MSE(test_y[0],means_t)
# compute mean negative predictive density
test_nll = -torch.distributions.Normal(means_t, variances_t.sqrt()).log_prob(test_y[0]).mean()
print(f"At {n_test} testing points, MSE: {test_mse:.4e}, nll: {test_nll:.4e}.")
print(f"Training time: {(t2-t1):.2f} sec, testing time: {(t3-t2):.2f} sec")