In [31]:
import math
import numpy as np
import torch
import gpytorch
import tqdm
import random
import time
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import sys
sys.path.append("../")
sys.path.append("../directionalvi/utils")
sys.path.append("../directionalvi")
from RBFKernelDirectionalGrad import RBFKernelDirectionalGrad
#from DirectionalGradVariationalStrategy import DirectionalGradVariationalStrategy
from dfree_directional_vi import train_gp, eval_gp
from metrics import MSE
import testfun
from csv_dataset import csv_dataset

In [17]:
dataset = csv_dataset("../experiments/real_data/CASP.csv", gradients=False, rescale=True)

[0 1 2 3 4 5 6 7 8]
[9]


In [18]:
dataset.dim

9

In [28]:
# data parameters
n   = dataset.n
print("n is: ", n)
dim = dataset.dim
print("dims is: ", dim)


# training params
num_inducing = 20
num_directions = 2
minibatch_size = 200
num_epochs = 400

# seed
torch.random.manual_seed(0)
# use tqdm or just have print statements
tqdm = False
# use data to initialize inducing stuff
inducing_data_initialization = False
# use natural gradients and/or CIQ
use_ngd = False
use_ciq = False
num_contour_quadrature=15
# learning rate
learning_rate_hypers = 0.01
learning_rate_ngd    = 0.1
gamma  = 10.0
#levels = np.array([20,150,300])
#def lr_sched(epoch):
#  a = np.sum(levels > epoch)
#  return (1./gamma)**a
lr_sched = None

n is:  45730
dims is:  9


In [29]:
# train-test split
n_train = int(0.8*dataset.n)
n_test  = n - n_train
train_dataset,test_dataset = torch.utils.data.random_split(dataset,[n_train,n_test])

In [30]:
#loaders
train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=n_test, shuffle=False)

In [32]:
# train
print("\n\n---DirectionalGradVGP---")
print(f"Start training with {n} trainig data of dim {dim}")
print(f"VI setups: {num_inducing} inducing points, {num_directions} inducing directions")
args={"verbose":True}
t1 = time.time()	
model,likelihood = train_gp(train_dataset,
                      num_inducing=num_inducing,
                      num_directions=num_directions,
                      minibatch_size = minibatch_size,
                      minibatch_dim = num_directions,
                      num_epochs =num_epochs, 
                      learning_rate_hypers=learning_rate_hypers,
                      learning_rate_ngd=learning_rate_ngd,
                      inducing_data_initialization=inducing_data_initialization,
                      use_ngd = use_ngd,
                      use_ciq = use_ciq,
                      lr_sched=lr_sched,
                      num_contour_quadrature=num_contour_quadrature,
                      tqdm=tqdm,**args
                      )
t2 = time.time()	

# save the model
# torch.save(model.state_dict(), "../data/test_dvi_basic.model")

# test
means, variances = eval_gp( test_dataset,model,likelihood,
                            num_directions=num_directions,
                            minibatch_size=n_test,
                            minibatch_dim=num_directions)
t3 = time.time()	

# compute MSE
test_y = test_y.cpu()
test_mse = MSE(test_y,means)
# compute mean negative predictive density
test_nll = -torch.distributions.Normal(means, variances.sqrt()).log_prob(test_y).mean()
print(f"At {n_test} testing points, MSE: {test_mse:.4e}, nll: {test_nll:.4e}.")
print(f"Training time: {(t2-t1):.2f} sec, testing time: {(t3-t2):.2f} sec")

plot=1
if plot == 1:
    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(12,6))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(test_x[:,0],test_x[:,1],test_y, color='k')
    ax.scatter(test_x[:,0],test_x[:,1],means, color='b')
    plt.title("f(x,y) variational fit; actual curve is black, variational is blue")
    plt.show()



---DirectionalGradVGP---
Start training with 45730 trainig data of dim 9
VI setups: 20 inducing points, 2 inducing directions
All parameters to learn:
      variational_strategy.inducing_points
      torch.Size([20, 9])
      variational_strategy.inducing_directions
      torch.Size([40, 9])
      variational_strategy._variational_distribution.variational_mean
      torch.Size([60])
      variational_strategy._variational_distribution.chol_variational_covar
      torch.Size([60, 60])
      mean_module.constant
      torch.Size([1])
      covar_module.raw_outputscale
      torch.Size([])
      covar_module.base_kernel.raw_lengthscale
      torch.Size([1, 1])
      noise_covar.raw_noise
      torch.Size([1])
Total number of parameters:  4204.0
Epoch: 0; total_step: 0, loss: 2.48255889340812, nll: 1.466237968519302
Epoch: 0; total_step: 50, loss: 1.8813469458247452, nll: 1.3391872372182174
Epoch: 0; total_step: 100, loss: 1.7913530470135421, nll: 1.2427956916297787
Epoch: 0; total_step:

Epoch: 26; total_step: 4800, loss: 1.7448775070806206, nll: 1.228376428952464
Epoch: 26; total_step: 4850, loss: 1.7295344077222117, nll: 1.1685200588841367
Epoch: 26; total_step: 4900, loss: 1.811709760842075, nll: 1.382430054376861
Epoch: 27; total_step: 4950, loss: 1.692392587123844, nll: 1.1861558559440724
Epoch: 27; total_step: 5000, loss: 1.7427730605154255, nll: 1.2575125157075808
Epoch: 27; total_step: 5050, loss: 1.6830897841881893, nll: 1.1871210224906055
Epoch: 27; total_step: 5100, loss: 1.666366889159417, nll: 1.2764668212755736
Epoch: 28; total_step: 5150, loss: 1.6941104279969128, nll: 1.138501688200133
Epoch: 28; total_step: 5200, loss: 1.7348677357795794, nll: 1.0439591575144858
Epoch: 28; total_step: 5250, loss: 1.748910437134861, nll: 1.2215529890786316
Epoch: 28; total_step: 5300, loss: 1.7103513295423136, nll: 1.3135744118448225
Epoch: 29; total_step: 5350, loss: 1.7611931711175912, nll: 1.3148821875204073
Epoch: 29; total_step: 5400, loss: 1.6237685446817047, nll:

Epoch: 54; total_step: 10050, loss: 1.683944287462075, nll: 1.1455930719718024
Epoch: 55; total_step: 10100, loss: 1.7241085559416638, nll: 1.2253638806913039
Epoch: 55; total_step: 10150, loss: 1.6913757986355749, nll: 1.1715421089006646
Epoch: 55; total_step: 10200, loss: 1.7375651658250524, nll: 1.2774907954040842
Epoch: 56; total_step: 10250, loss: 1.755276766986123, nll: 1.2822625603634665
Epoch: 56; total_step: 10300, loss: 1.7447717550050812, nll: 1.2701551705130676
Epoch: 56; total_step: 10350, loss: 1.6392187509639167, nll: 1.1063815864500415
Epoch: 56; total_step: 10400, loss: 1.7343976360619784, nll: 1.239116003956521
Epoch: 57; total_step: 10450, loss: 1.6621883878759558, nll: 1.0972228048560044
Epoch: 57; total_step: 10500, loss: 1.7868958575066693, nll: 1.3436551498300513
Epoch: 57; total_step: 10550, loss: 1.626368565383674, nll: 1.0973712976199022
Epoch: 57; total_step: 10600, loss: 1.6324383978841366, nll: 1.0990986913284568
Epoch: 58; total_step: 10650, loss: 1.683897

Epoch: 83; total_step: 15250, loss: 1.642103653343142, nll: 1.1113121257106724
Epoch: 83; total_step: 15300, loss: 1.7322053234608799, nll: 1.2253664063789735
Epoch: 83; total_step: 15350, loss: 1.6233499505501832, nll: 1.0823552269298509
Epoch: 84; total_step: 15400, loss: 1.6259340153113597, nll: 1.119461632570883
Epoch: 84; total_step: 15450, loss: 1.7138841837811463, nll: 1.1334451983895553
Epoch: 84; total_step: 15500, loss: 1.7489625400313826, nll: 1.0899744846229567
Epoch: 84; total_step: 15550, loss: 1.7258607077697001, nll: 1.1745224667347356
Epoch: 85; total_step: 15600, loss: 1.713451340484097, nll: 1.215503264052656
Epoch: 85; total_step: 15650, loss: 1.6894471738946453, nll: 1.2157102720478359
Epoch: 85; total_step: 15700, loss: 1.6481538128073072, nll: 1.1193411432148705
Epoch: 86; total_step: 15750, loss: 1.7048552161969206, nll: 1.2918868079802652
Epoch: 86; total_step: 15800, loss: 1.686216654788056, nll: 1.1615281163865658
Epoch: 86; total_step: 15850, loss: 1.7380448

Epoch: 111; total_step: 20400, loss: 1.6526005489857227, nll: 1.16984054445232
Epoch: 111; total_step: 20450, loss: 1.6718787571053317, nll: 1.1861078752555883
Epoch: 112; total_step: 20500, loss: 1.6403043225541574, nll: 1.08286728651369
Epoch: 112; total_step: 20550, loss: 1.7473723680390862, nll: 1.3263652871100426
Epoch: 112; total_step: 20600, loss: 1.6658656773212916, nll: 1.1074644433451533
Epoch: 112; total_step: 20650, loss: 1.6390595373601426, nll: 1.0669749518612999
Epoch: 113; total_step: 20700, loss: 1.6756352322305827, nll: 1.217395595878732
Epoch: 113; total_step: 20750, loss: 1.6874334797938417, nll: 1.180280770517902
Epoch: 113; total_step: 20800, loss: 1.6339642435977704, nll: 1.0903632158203156
Epoch: 113; total_step: 20850, loss: 1.6520517608821939, nll: 1.07320693886611
Epoch: 114; total_step: 20900, loss: 1.6760921897707726, nll: 1.1495840587679034
Epoch: 114; total_step: 20950, loss: 1.6690577292133713, nll: 1.146618734834777
Epoch: 114; total_step: 21000, loss: 

Epoch: 139; total_step: 25500, loss: 1.6202665464037507, nll: 1.0951962006239213
Epoch: 139; total_step: 25550, loss: 1.6996850537069441, nll: 1.1633291253266993
Epoch: 139; total_step: 25600, loss: 1.710704337025918, nll: 1.2570150092358443
Epoch: 140; total_step: 25650, loss: 1.7129344365035122, nll: 1.2505761874982422
Epoch: 140; total_step: 25700, loss: 1.6887521455398686, nll: 1.1578702375331484
Epoch: 140; total_step: 25750, loss: 1.6857118872642243, nll: 1.0390673596321942
Epoch: 140; total_step: 25800, loss: 1.6954566823983424, nll: 1.1709477064619231
Epoch: 141; total_step: 25850, loss: 1.6690170242458644, nll: 1.0583071371230581
Epoch: 141; total_step: 25900, loss: 1.7276863037201666, nll: 1.2069689646169652
Epoch: 141; total_step: 25950, loss: 1.7142917408678724, nll: 1.1788582862331456
Epoch: 142; total_step: 26000, loss: 1.643791453429493, nll: 1.0950508691689702
Epoch: 142; total_step: 26050, loss: 1.779791659769582, nll: 1.2315456353885899
Epoch: 142; total_step: 26100, 

Epoch: 167; total_step: 30600, loss: 1.729941457502878, nll: 1.0838789638561768
Epoch: 167; total_step: 30650, loss: 1.7066573615224012, nll: 1.2296167852327387
Epoch: 167; total_step: 30700, loss: 1.6113060193368602, nll: 1.0749211312713616
Epoch: 168; total_step: 30750, loss: 1.7329608093636202, nll: 1.2969797915657884
Epoch: 168; total_step: 30800, loss: 1.7927162943315278, nll: 1.2506158385329014
Epoch: 168; total_step: 30850, loss: 1.669983888855653, nll: 1.1808135599137437
Epoch: 168; total_step: 30900, loss: 1.7005586096230008, nll: 1.2032623681879424
Epoch: 169; total_step: 30950, loss: 1.7003238004868695, nll: 1.2287916297340176
Epoch: 169; total_step: 31000, loss: 1.6529349662314445, nll: 1.1715013922230146
Epoch: 169; total_step: 31050, loss: 1.6622546936074338, nll: 1.1827281371703895
Epoch: 169; total_step: 31100, loss: 1.6739839254338238, nll: 1.2562542183994936
Epoch: 170; total_step: 31150, loss: 1.6317408002262948, nll: 1.162229046515433
Epoch: 170; total_step: 31200, 

Epoch: 195; total_step: 35700, loss: 1.6200633169904044, nll: 1.0632510616441582
Epoch: 195; total_step: 35750, loss: 1.624881794592427, nll: 1.1386802575470114
Epoch: 195; total_step: 35800, loss: 1.6650514987187508, nll: 1.1566098459756988
Epoch: 195; total_step: 35850, loss: 1.6336828830505683, nll: 1.1978369176881825
Epoch: 196; total_step: 35900, loss: 1.6381849153777148, nll: 1.1184879973330144
Epoch: 196; total_step: 35950, loss: 1.699993522790126, nll: 1.1778572495774005
Epoch: 196; total_step: 36000, loss: 1.6268112535203636, nll: 1.1301978223010334
Epoch: 196; total_step: 36050, loss: 1.6591517032013972, nll: 1.07795961152099
Epoch: 197; total_step: 36100, loss: 1.6979165789613133, nll: 1.240214304642135
Epoch: 197; total_step: 36150, loss: 1.678551266696708, nll: 1.2613757621117718
Epoch: 197; total_step: 36200, loss: 1.7332826926074558, nll: 1.1802260193405782
Epoch: 198; total_step: 36250, loss: 1.7277112193675654, nll: 1.1767648801211907
Epoch: 198; total_step: 36300, los

Epoch: 222; total_step: 40800, loss: 1.6641923990791418, nll: 1.115259701166449
Epoch: 223; total_step: 40850, loss: 1.6419105712031214, nll: 0.9875460546177072
Epoch: 223; total_step: 40900, loss: 1.6402272362131414, nll: 1.1721792837273504
Epoch: 223; total_step: 40950, loss: 1.6493758095228546, nll: 1.1805136043134095
Epoch: 224; total_step: 41000, loss: 1.6683375185441616, nll: 1.2065496385333232
Epoch: 224; total_step: 41050, loss: 1.6461127280947345, nll: 1.1461511760983847
Epoch: 224; total_step: 41100, loss: 1.6603967932611345, nll: 1.0825522980179034
Epoch: 224; total_step: 41150, loss: 1.6569560301753592, nll: 1.181336282187857
Epoch: 225; total_step: 41200, loss: 1.615645214677267, nll: 1.0728982844701933
Epoch: 225; total_step: 41250, loss: 1.7048809667224032, nll: 1.1663605041585638
Epoch: 225; total_step: 41300, loss: 1.6800201019391239, nll: 1.1662099436605284
Epoch: 225; total_step: 41350, loss: 1.6954083832735625, nll: 1.2088111636130343
Epoch: 226; total_step: 41400, 

Epoch: 250; total_step: 45900, loss: 1.721640723438419, nll: 1.1753713434955582
Epoch: 251; total_step: 45950, loss: 1.6225360184257038, nll: 1.0856017667742441
Epoch: 251; total_step: 46000, loss: 1.6963952834479046, nll: 1.1713685755219043
Epoch: 251; total_step: 46050, loss: 1.7293751129895316, nll: 1.2852873329585304
Epoch: 251; total_step: 46100, loss: 1.7364429459250414, nll: 1.1434098526731942
Epoch: 252; total_step: 46150, loss: 1.6852466305435798, nll: 1.1757587144283124
Epoch: 252; total_step: 46200, loss: 1.6678807667559754, nll: 1.145337152336791
Epoch: 252; total_step: 46250, loss: 1.7254897438823105, nll: 1.3078519051909727
Epoch: 253; total_step: 46300, loss: 1.685372044451119, nll: 1.296275362241938
Epoch: 253; total_step: 46350, loss: 1.7504393421222104, nll: 1.3264213138975844
Epoch: 253; total_step: 46400, loss: 1.6587629356157974, nll: 1.087351806666651
Epoch: 253; total_step: 46450, loss: 1.5979667683337306, nll: 1.0635800643698872
Epoch: 254; total_step: 46500, lo

Epoch: 278; total_step: 51000, loss: 1.6738545239773808, nll: 1.0912250134732502
Epoch: 278; total_step: 51050, loss: 1.575992178018825, nll: 1.114507883492572
Epoch: 279; total_step: 51100, loss: 1.5638198801344592, nll: 1.191394273945553
Epoch: 279; total_step: 51150, loss: 1.6215687240051961, nll: 1.0540806457436318
Epoch: 279; total_step: 51200, loss: 1.752329491767794, nll: 1.2522910573991186
Epoch: 280; total_step: 51250, loss: 1.7027336971143046, nll: 1.1273297293653233
Epoch: 280; total_step: 51300, loss: 1.6712247977259, nll: 1.2819874378588831
Epoch: 280; total_step: 51350, loss: 1.6404729596409293, nll: 1.2392431507951467
Epoch: 280; total_step: 51400, loss: 1.6933786418184327, nll: 1.4155894870325436
Epoch: 281; total_step: 51450, loss: 1.687870901621742, nll: 1.0717050648497668
Epoch: 281; total_step: 51500, loss: 1.647246605883624, nll: 1.0270774513781564
Epoch: 281; total_step: 51550, loss: 1.6388420799400336, nll: 1.057934184692943
Epoch: 281; total_step: 51600, loss: 1

Epoch: 306; total_step: 56100, loss: 1.720756542237445, nll: 1.289038496791467
Epoch: 306; total_step: 56150, loss: 1.6459479880257089, nll: 1.1746836884474925
Epoch: 307; total_step: 56200, loss: 1.614653100250991, nll: 1.1259246811142736
Epoch: 307; total_step: 56250, loss: 1.6244587231219334, nll: 1.117871642601186
Epoch: 307; total_step: 56300, loss: 1.7255389996844122, nll: 1.1259222649138507
Epoch: 307; total_step: 56350, loss: 1.6439072254954954, nll: 1.0368041043725102
Epoch: 308; total_step: 56400, loss: 1.712695632931019, nll: 1.1956018325778697
Epoch: 308; total_step: 56450, loss: 1.6400820483036682, nll: 1.0597889570034438
Epoch: 308; total_step: 56500, loss: 1.5787198934054487, nll: 1.064285073671678
Epoch: 309; total_step: 56550, loss: 1.62207834218671, nll: 1.0847425019881054
Epoch: 309; total_step: 56600, loss: 1.7367155086738422, nll: 1.0653908658015203
Epoch: 309; total_step: 56650, loss: 1.651451666777605, nll: 1.261682569184644
Epoch: 309; total_step: 56700, loss: 1

Epoch: 334; total_step: 61200, loss: 1.7238270907335453, nll: 1.2384418760716758
Epoch: 334; total_step: 61250, loss: 1.6617716355344032, nll: 1.0506151709261802
Epoch: 334; total_step: 61300, loss: 1.6547344041304075, nll: 1.0829338987615287
Epoch: 335; total_step: 61350, loss: 1.7283068475584507, nll: 1.3079322409558387
Epoch: 335; total_step: 61400, loss: 1.6494100344379772, nll: 1.1399314026439835
Epoch: 335; total_step: 61450, loss: 1.6615315388118332, nll: 1.0855957175794244
Epoch: 336; total_step: 61500, loss: 1.696749364895376, nll: 1.2103296128411998
Epoch: 336; total_step: 61550, loss: 1.6358651741932624, nll: 1.0454304208163037
Epoch: 336; total_step: 61600, loss: 1.7323481805305248, nll: 1.1716733921178137
Epoch: 336; total_step: 61650, loss: 1.7619328964641414, nll: 1.223404837225651
Epoch: 337; total_step: 61700, loss: 1.7426365553564165, nll: 1.1719169903093374
Epoch: 337; total_step: 61750, loss: 1.6354015387123662, nll: 1.0378479855944556
Epoch: 337; total_step: 61800,

Epoch: 362; total_step: 66300, loss: 1.7065070021668478, nll: 1.1104754885867585
Epoch: 362; total_step: 66350, loss: 1.732520479085917, nll: 1.1652223071325176
Epoch: 362; total_step: 66400, loss: 1.672301438381088, nll: 1.0906285721626474
Epoch: 363; total_step: 66450, loss: 1.64466011441246, nll: 1.1900258710668559
Epoch: 363; total_step: 66500, loss: 1.7039197359248843, nll: 1.1709838241746644
Epoch: 363; total_step: 66550, loss: 1.699154272593745, nll: 1.078924841454866
Epoch: 363; total_step: 66600, loss: 1.6420039940469926, nll: 1.089630928287124
Epoch: 364; total_step: 66650, loss: 1.6351752894213711, nll: 1.1903652932239324
Epoch: 364; total_step: 66700, loss: 1.751857472784077, nll: 1.3219817259085311
Epoch: 364; total_step: 66750, loss: 1.620276738641535, nll: 1.1718197649947584
Epoch: 365; total_step: 66800, loss: 1.7657421063289294, nll: 1.1857430934338595
Epoch: 365; total_step: 66850, loss: 1.7305765034556315, nll: 1.2827506986530834
Epoch: 365; total_step: 66900, loss: 

Epoch: 390; total_step: 71400, loss: 1.6711201283801531, nll: 1.158033140237762
Epoch: 390; total_step: 71450, loss: 1.6592879368080662, nll: 1.1169881023327914
Epoch: 390; total_step: 71500, loss: 1.7089538899630303, nll: 1.1141892102799933
Epoch: 390; total_step: 71550, loss: 1.7174888427351467, nll: 1.144198989103378
Epoch: 391; total_step: 71600, loss: 1.7262397548876927, nll: 1.3483994029586568
Epoch: 391; total_step: 71650, loss: 1.684174752231155, nll: 1.0923819408702984
Epoch: 391; total_step: 71700, loss: 1.8059428172763319, nll: 1.3915274806890985
Epoch: 392; total_step: 71750, loss: 1.6683444362387316, nll: 1.1046929811055914
Epoch: 392; total_step: 71800, loss: 1.697708924949026, nll: 1.2358139927166014
Epoch: 392; total_step: 71850, loss: 1.6286581078321927, nll: 1.1535250183274053
Epoch: 392; total_step: 71900, loss: 1.6443795475882832, nll: 1.2092535821010648
Epoch: 393; total_step: 71950, loss: 1.6864186419919793, nll: 1.3268319095414705
Epoch: 393; total_step: 72000, l

RuntimeError: The size of tensor a (1000) must match the size of tensor b (9146) at non-singleton dimension 0