In [31]:
import math
import numpy as np
import torch
import gpytorch
import tqdm
import random
import time
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import sys
sys.path.append("../")
sys.path.append("../directionalvi/utils")
sys.path.append("../directionalvi")
from RBFKernelDirectionalGrad import RBFKernelDirectionalGrad
#from DirectionalGradVariationalStrategy import DirectionalGradVariationalStrategy
from dfree_directional_vi import train_gp, eval_gp
from metrics import MSE
import testfun
from csv_dataset import csv_dataset

In [17]:
dataset = csv_dataset("../experiments/real_data/CASP.csv", gradients=False, rescale=True)

[0 1 2 3 4 5 6 7 8]
[9]


In [18]:
dataset.dim

9

In [28]:
# data parameters
n   = dataset.n
print("n is: ", n)
dim = dataset.dim
print("dims is: ", dim)


# training params
num_inducing = 20
num_directions = 2
minibatch_size = 200
num_epochs = 400

# seed
torch.random.manual_seed(0)
# use tqdm or just have print statements
tqdm = False
# use data to initialize inducing stuff
inducing_data_initialization = False
# use natural gradients and/or CIQ
use_ngd = False
use_ciq = False
num_contour_quadrature=15
# learning rate
learning_rate_hypers = 0.01
learning_rate_ngd    = 0.1
gamma  = 10.0
#levels = np.array([20,150,300])
#def lr_sched(epoch):
#  a = np.sum(levels > epoch)
#  return (1./gamma)**a
lr_sched = None

n is:  45730
dims is:  9


In [29]:
# train-test split
n_train = int(0.8*dataset.n)
n_test  = n - n_train
train_dataset,test_dataset = torch.utils.data.random_split(dataset,[n_train,n_test])

In [30]:
#loaders
train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=n_test, shuffle=False)

In [None]:
# train
print("\n\n---DirectionalGradVGP---")
print(f"Start training with {n} trainig data of dim {dim}")
print(f"VI setups: {num_inducing} inducing points, {num_directions} inducing directions")
args={"verbose":True}
t1 = time.time()	
model,likelihood = train_gp(train_dataset,
                      num_inducing=num_inducing,
                      num_directions=num_directions,
                      minibatch_size = minibatch_size,
                      minibatch_dim = num_directions,
                      num_epochs =num_epochs, 
                      learning_rate_hypers=learning_rate_hypers,
                      learning_rate_ngd=learning_rate_ngd,
                      inducing_data_initialization=inducing_data_initialization,
                      use_ngd = use_ngd,
                      use_ciq = use_ciq,
                      lr_sched=lr_sched,
                      num_contour_quadrature=num_contour_quadrature,
                      tqdm=tqdm,**args
                      )
t2 = time.time()	

# save the model
# torch.save(model.state_dict(), "../data/test_dvi_basic.model")

# test
means, variances = eval_gp( test_dataset,model,likelihood,
                            num_directions=num_directions,
                            minibatch_size=n_test,
                            minibatch_dim=num_directions)
t3 = time.time()	

# compute MSE
test_y = test_y.cpu()
test_mse = MSE(test_y,means)
# compute mean negative predictive density
test_nll = -torch.distributions.Normal(means, variances.sqrt()).log_prob(test_y).mean()
print(f"At {n_test} testing points, MSE: {test_mse:.4e}, nll: {test_nll:.4e}.")
print(f"Training time: {(t2-t1):.2f} sec, testing time: {(t3-t2):.2f} sec")

plot=1
if plot == 1:
    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(12,6))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(test_x[:,0],test_x[:,1],test_y, color='k')
    ax.scatter(test_x[:,0],test_x[:,1],means, color='b')
    plt.title("f(x,y) variational fit; actual curve is black, variational is blue")
    plt.show()



---DirectionalGradVGP---
Start training with 45730 trainig data of dim 9
VI setups: 20 inducing points, 2 inducing directions
All parameters to learn:
      variational_strategy.inducing_points
      torch.Size([20, 9])
      variational_strategy.inducing_directions
      torch.Size([40, 9])
      variational_strategy._variational_distribution.variational_mean
      torch.Size([60])
      variational_strategy._variational_distribution.chol_variational_covar
      torch.Size([60, 60])
      mean_module.constant
      torch.Size([1])
      covar_module.raw_outputscale
      torch.Size([])
      covar_module.base_kernel.raw_lengthscale
      torch.Size([1, 1])
      noise_covar.raw_noise
      torch.Size([1])
Total number of parameters:  4204.0
Epoch: 0; total_step: 0, loss: 2.48255889340812, nll: 1.466237968519302
Epoch: 0; total_step: 50, loss: 1.8813469458247452, nll: 1.3391872372182174
Epoch: 0; total_step: 100, loss: 1.7913530470135421, nll: 1.2427956916297787
Epoch: 0; total_step:

Epoch: 26; total_step: 4800, loss: 1.7448775070806206, nll: 1.228376428952464
Epoch: 26; total_step: 4850, loss: 1.7295344077222117, nll: 1.1685200588841367
Epoch: 26; total_step: 4900, loss: 1.811709760842075, nll: 1.382430054376861
Epoch: 27; total_step: 4950, loss: 1.692392587123844, nll: 1.1861558559440724
Epoch: 27; total_step: 5000, loss: 1.7427730605154255, nll: 1.2575125157075808
Epoch: 27; total_step: 5050, loss: 1.6830897841881893, nll: 1.1871210224906055
Epoch: 27; total_step: 5100, loss: 1.666366889159417, nll: 1.2764668212755736
Epoch: 28; total_step: 5150, loss: 1.6941104279969128, nll: 1.138501688200133
Epoch: 28; total_step: 5200, loss: 1.7348677357795794, nll: 1.0439591575144858
Epoch: 28; total_step: 5250, loss: 1.748910437134861, nll: 1.2215529890786316
Epoch: 28; total_step: 5300, loss: 1.7103513295423136, nll: 1.3135744118448225
Epoch: 29; total_step: 5350, loss: 1.7611931711175912, nll: 1.3148821875204073
Epoch: 29; total_step: 5400, loss: 1.6237685446817047, nll: