In [1]:
import math
import numpy as np
import torch
import gpytorch
import tqdm
import random
import time
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import sys
sys.path.append("../")
sys.path.append("../directionalvi/utils")
sys.path.append("../directionalvi")
import traditional_vi
from RBFKernelDirectionalGrad import RBFKernelDirectionalGrad
#from DirectionalGradVariationalStrategy import DirectionalGradVariationalStrategy
from dfree_directional_vi import train_gp, eval_gp
from metrics import MSE
import testfun
from csv_dataset import csv_dataset

In [2]:
dataset = csv_dataset("../experiments/real_data/WECs_DataSet/Perth_Data.csv", gradients=False, rescale=True)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31]
[48]


In [3]:
dataset.n

72000

In [4]:
# data parameters
n   = dataset.n
print("n is: ", n)
dim = dataset.dim
print("dims is: ", dim)

# training params
num_inducing = 500
num_directions = 1
minibatch_size = 500
num_epochs = 1000

# seed
#torch.random.manual_seed(0)
# use tqdm or just have print statements
tqdm = False
# use data to initialize inducing stuff
inducing_data_initialization = False
# use natural gradients and/or CIQ
use_ngd = False
use_ciq = False
num_contour_quadrature=15
# learning rate
learning_rate_hypers = 0.01
learning_rate_ngd    = 0.1
gamma  = 10.0
#levels = np.array([20,150,300])
#def lr_sched(epoch):
#  a = np.sum(levels > epoch)
#  return (1./gamma)**a
lr_sched = None

n is:  72000
dims is:  32


In [5]:
# train-test split
n_train = int(0.8*dataset.n)
n_test  = n - n_train
train_dataset,test_dataset = torch.utils.data.random_split(dataset,[n_train,n_test])

In [6]:
#loaders
train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=n_test, shuffle=False)

In [7]:
test_y = [item[1] for item in test_loader]
test_x = [item[0] for item in test_loader]

# D-Free Grad SVGP

In [None]:
# train
print("\n\n---DirectionalGradVGP---")
print(f"Start training with {n} trainig data of dim {dim}")
print(f"VI setups: {num_inducing} inducing points, {num_directions} inducing directions")
args={"verbose":True}
t1 = time.time()	
model,likelihood = train_gp(train_dataset,
                      num_inducing=num_inducing,
                      num_directions=num_directions,
                      minibatch_size = minibatch_size,
                      minibatch_dim = num_directions,
                      num_epochs =num_epochs, 
                      learning_rate_hypers=learning_rate_hypers,
                      learning_rate_ngd=learning_rate_ngd,
                      inducing_data_initialization=inducing_data_initialization,
                      use_ngd = use_ngd,
                      use_ciq = use_ciq,
                      lr_sched=lr_sched,
                      num_contour_quadrature=num_contour_quadrature,
                      tqdm=tqdm,**args
                      )
t2 = time.time()	

# save the model
# torch.save(model.state_dict(), "../data/test_dvi_basic.model")

# test
means, variances = eval_gp( test_dataset,model,likelihood,
                            num_directions=num_directions,
                            minibatch_size=n_test,
                            minibatch_dim=num_directions)
t3 = time.time()	




---DirectionalGradVGP---
Start training with 72000 trainig data of dim 32
VI setups: 500 inducing points, 1 inducing directions
All parameters to learn:
      variational_strategy.inducing_points
      torch.Size([500, 32])
      variational_strategy.inducing_directions
      torch.Size([500, 32])
      variational_strategy._variational_distribution.variational_mean
      torch.Size([1000])
      variational_strategy._variational_distribution.chol_variational_covar
      torch.Size([1000, 1000])
      mean_module.constant
      torch.Size([1])
      covar_module.raw_outputscale
      torch.Size([])
      covar_module.base_kernel.raw_lengthscale
      torch.Size([1, 1])
      noise_covar.raw_noise
      torch.Size([1])
Total number of parameters:  1033004.0
Epoch: 0; total_step: 0, loss: 2.4726170184510545, nll: 1.4002978832175332
Epoch: 0; total_step: 50, loss: 1.6684274296414636, nll: 1.1149669167810183
Epoch: 0; total_step: 100, loss: 1.519692141532555, nll: 0.969311184744317
Epoch

Epoch: 40; total_step: 4750, loss: 0.8516478631188985, nll: 0.3097374200872883
Epoch: 41; total_step: 4800, loss: 0.8677313477467139, nll: 0.2781649830288215
Epoch: 41; total_step: 4850, loss: 0.9447872040924618, nll: 0.3596252562954846
Epoch: 42; total_step: 4900, loss: 0.8088763859825773, nll: 0.2766123679133644
Epoch: 42; total_step: 4950, loss: 0.8837187102799768, nll: 0.2803226351912983
Epoch: 43; total_step: 5000, loss: 0.8433901753592635, nll: 0.3085149831603168
Epoch: 43; total_step: 5050, loss: 0.8712945137719028, nll: 0.3118054420102752
Epoch: 43; total_step: 5100, loss: 0.8455594809881213, nll: 0.27342867723215253
Epoch: 44; total_step: 5150, loss: 0.9103972407868193, nll: 0.3217839453970024
Epoch: 44; total_step: 5200, loss: 0.8638011398183396, nll: 0.27716025011179374
Epoch: 45; total_step: 5250, loss: 0.8382332114498079, nll: 0.28932614126557343
Epoch: 45; total_step: 5300, loss: 0.7572340909093768, nll: 0.23583008466728358
Epoch: 46; total_step: 5350, loss: 0.84075324855

Epoch: 85; total_step: 9950, loss: 0.8523967358979375, nll: 0.2536415667085113
Epoch: 86; total_step: 10000, loss: 0.853302373876237, nll: 0.3284592863857071
Epoch: 86; total_step: 10050, loss: 0.811946163628589, nll: 0.21461624231447973
Epoch: 87; total_step: 10100, loss: 0.7773121366464865, nll: 0.2236463775847393
Epoch: 87; total_step: 10150, loss: 0.8130405343034918, nll: 0.21906531091311968
Epoch: 87; total_step: 10200, loss: 0.7759241405364792, nll: 0.20947575661687656
Epoch: 88; total_step: 10250, loss: 0.7981430762255779, nll: 0.2536240875326203
Epoch: 88; total_step: 10300, loss: 0.8181125892419552, nll: 0.275349567955151
Epoch: 89; total_step: 10350, loss: 0.8302426739474054, nll: 0.26080929407524456
Epoch: 89; total_step: 10400, loss: 0.851077614999252, nll: 0.34253174493178296
Epoch: 90; total_step: 10450, loss: 0.8328980301112006, nll: 0.30087960203174374
Epoch: 90; total_step: 10500, loss: 0.7911886238206266, nll: 0.23800831793267233
Epoch: 90; total_step: 10550, loss: 0.

Epoch: 129; total_step: 15050, loss: 0.7641054196117827, nll: 0.21079819986759898
Epoch: 130; total_step: 15100, loss: 0.8292427252101013, nll: 0.24647499899812836
Epoch: 130; total_step: 15150, loss: 0.8100488353878404, nll: 0.19239013689975462
Epoch: 131; total_step: 15200, loss: 0.8630349022439685, nll: 0.3146752341561182
Epoch: 131; total_step: 15250, loss: 0.8424307309789354, nll: 0.2859826246151298
Epoch: 131; total_step: 15300, loss: 0.8438489311760863, nll: 0.23570624441565205
Epoch: 132; total_step: 15350, loss: 0.7797344257255401, nll: 0.19558110590714467
Epoch: 132; total_step: 15400, loss: 0.8653130906979016, nll: 0.29707713902511135
Epoch: 133; total_step: 15450, loss: 0.7652637882541641, nll: 0.20966841142259185
Epoch: 133; total_step: 15500, loss: 0.7764119318911562, nll: 0.2740288333989091
Epoch: 134; total_step: 15550, loss: 0.7944149182920607, nll: 0.1997265760535633
Epoch: 134; total_step: 15600, loss: 0.8293116938299597, nll: 0.22957455670746485
Epoch: 134; total_st

Epoch: 173; total_step: 20100, loss: 0.8335686425701087, nll: 0.2534799330913688
Epoch: 173; total_step: 20150, loss: 0.7912289062517456, nll: 0.1292317170072109
Epoch: 174; total_step: 20200, loss: 0.7985810945147466, nll: 0.16534455697294353
Epoch: 174; total_step: 20250, loss: 0.7505645452458573, nll: 0.18401425091642679
Epoch: 175; total_step: 20300, loss: 0.781302457012019, nll: 0.23019766001410624
Epoch: 175; total_step: 20350, loss: 0.8593973911819416, nll: 0.3322935820329917
Epoch: 175; total_step: 20400, loss: 0.7566075771151418, nll: 0.21302761545636786
Epoch: 176; total_step: 20450, loss: 0.8138316384668699, nll: 0.2668036459538772
Epoch: 176; total_step: 20500, loss: 0.7913485603067707, nll: 0.2369681869930946
Epoch: 177; total_step: 20550, loss: 0.8581391200937596, nll: 0.26485293410638405
Epoch: 177; total_step: 20600, loss: 0.7790153563196245, nll: 0.16830839019526575
Epoch: 178; total_step: 20650, loss: 0.7953798439825953, nll: 0.217942152487335
Epoch: 178; total_step: 

Epoch: 216; total_step: 25150, loss: 0.8351011117203295, nll: 0.27036733134408714
Epoch: 217; total_step: 25200, loss: 0.8204739738534583, nll: 0.1886130713436675
Epoch: 217; total_step: 25250, loss: 0.7842586472542015, nll: 0.21047827994965027
Epoch: 218; total_step: 25300, loss: 0.763090978978628, nll: 0.20062876328196372
Epoch: 218; total_step: 25350, loss: 0.7329407487691144, nll: 0.14374822774137072
Epoch: 218; total_step: 25400, loss: 0.8041741216967565, nll: 0.25402038954012235
Epoch: 219; total_step: 25450, loss: 0.8219229345710664, nll: 0.20210848229624068
Epoch: 219; total_step: 25500, loss: 0.7608196861424339, nll: 0.22553661524150181
Epoch: 220; total_step: 25550, loss: 0.8313568292676745, nll: 0.2661968645621374
Epoch: 220; total_step: 25600, loss: 0.7614450921957946, nll: 0.17373522704673902
Epoch: 221; total_step: 25650, loss: 0.7798109989914356, nll: 0.227714234615864
Epoch: 221; total_step: 25700, loss: 0.7745179106414825, nll: 0.22088701401796237
Epoch: 221; total_ste

Epoch: 260; total_step: 30200, loss: 0.7913666696462941, nll: 0.1990988253048561
Epoch: 260; total_step: 30250, loss: 0.7916528368378579, nll: 0.17987176307084768
Epoch: 261; total_step: 30300, loss: 0.7858231682606068, nll: 0.1909631893354521
Epoch: 261; total_step: 30350, loss: 0.7241586959635841, nll: 0.10405868292850039
Epoch: 262; total_step: 30400, loss: 0.7665880769457619, nll: 0.2136359267076993
Epoch: 262; total_step: 30450, loss: 0.7681760935417666, nll: 0.201904926909372
Epoch: 262; total_step: 30500, loss: 0.7796212307657799, nll: 0.17161043124884667
Epoch: 263; total_step: 30550, loss: 0.769162018742361, nll: 0.18951618293807687
Epoch: 263; total_step: 30600, loss: 0.8018289971104101, nll: 0.27339254093710613
Epoch: 264; total_step: 30650, loss: 0.7789864621614627, nll: 0.11398508952478631
Epoch: 264; total_step: 30700, loss: 0.8171002999076791, nll: 0.23127873864313103
Epoch: 265; total_step: 30750, loss: 0.8381379189666606, nll: 0.22952682099085348
Epoch: 265; total_step

Epoch: 303; total_step: 35250, loss: 0.7919394986854049, nll: 0.17635008304813915
Epoch: 304; total_step: 35300, loss: 0.7974521456918531, nll: 0.20865140135584975
Epoch: 304; total_step: 35350, loss: 0.7286572440269438, nll: 0.1169934238911244
Epoch: 305; total_step: 35400, loss: 0.7805475909642935, nll: 0.17722568438960076
Epoch: 305; total_step: 35450, loss: 0.7680309061203534, nll: 0.15483692965884058
Epoch: 306; total_step: 35500, loss: 0.757560435855816, nll: 0.21286346541415327
Epoch: 306; total_step: 35550, loss: 0.7946943682120498, nll: 0.21720856267123542
Epoch: 306; total_step: 35600, loss: 0.800461534732918, nll: 0.24905064003539693
Epoch: 307; total_step: 35650, loss: 0.6935439979966357, nll: 0.12370594021401154
Epoch: 307; total_step: 35700, loss: 0.6933494793182537, nll: 0.10925796346214116
Epoch: 308; total_step: 35750, loss: 0.785481998206278, nll: 0.19914759010837055
Epoch: 308; total_step: 35800, loss: 0.7303374155779361, nll: 0.12660589851296775
Epoch: 309; total_st

Epoch: 347; total_step: 40300, loss: 0.7917340697176378, nll: 0.16188202633748552
Epoch: 347; total_step: 40350, loss: 0.8034408167829109, nll: 0.1638733658398439
Epoch: 348; total_step: 40400, loss: 0.7612637016059284, nll: 0.2093173924317494
Epoch: 348; total_step: 40450, loss: 0.7848419492460531, nll: 0.22786288737706065
Epoch: 349; total_step: 40500, loss: 0.7429844080992628, nll: 0.17037810611141818
Epoch: 349; total_step: 40550, loss: 0.7693181563917078, nll: 0.23830616609940092
Epoch: 350; total_step: 40600, loss: 0.7480604391336936, nll: 0.147792348293132
Epoch: 350; total_step: 40650, loss: 0.7332046321409051, nll: 0.20918318816012446
Epoch: 350; total_step: 40700, loss: 0.7519358655983419, nll: 0.18938320706997638
Epoch: 351; total_step: 40750, loss: 0.7375911190193701, nll: 0.2096796926041183
Epoch: 351; total_step: 40800, loss: 0.7914132061051512, nll: 0.2127729870851644
Epoch: 352; total_step: 40850, loss: 0.7978984966941545, nll: 0.20912544287127116
Epoch: 352; total_step

Epoch: 390; total_step: 45350, loss: 0.7673265949255169, nll: 0.17000598047220103
Epoch: 391; total_step: 45400, loss: 0.7969833121483479, nll: 0.2305670714411481
Epoch: 391; total_step: 45450, loss: 0.7955405706349385, nll: 0.15990049554610264
Epoch: 392; total_step: 45500, loss: 0.763366239668214, nll: 0.16360747871691125
Epoch: 392; total_step: 45550, loss: 0.7842535578183037, nll: 0.1389094105717999
Epoch: 393; total_step: 45600, loss: 0.7463345271290952, nll: 0.12235987420023758
Epoch: 393; total_step: 45650, loss: 0.8527226055634574, nll: 0.292606707026308
Epoch: 393; total_step: 45700, loss: 0.7863383999382079, nll: 0.1607022735872988
Epoch: 394; total_step: 45750, loss: 0.7054366078069275, nll: 0.130876651934487
Epoch: 394; total_step: 45800, loss: 0.7833874873759408, nll: 0.23310406574412115
Epoch: 395; total_step: 45850, loss: 0.7668365481845866, nll: 0.16271345544402652
Epoch: 395; total_step: 45900, loss: 0.7792260099220121, nll: 0.21697057267710923
Epoch: 396; total_step: 

Epoch: 434; total_step: 50400, loss: 0.7468357987872629, nll: 0.21827068259377286
Epoch: 434; total_step: 50450, loss: 0.8367574174706403, nll: 0.2581881869342235
Epoch: 435; total_step: 50500, loss: 0.7412540944257208, nll: 0.15349133864216993
Epoch: 435; total_step: 50550, loss: 0.7997854016106272, nll: 0.19365688400271455
Epoch: 436; total_step: 50600, loss: 0.7745166149156447, nll: 0.22293288400889544
Epoch: 436; total_step: 50650, loss: 0.7402151229938357, nll: 0.17551124928574507
Epoch: 437; total_step: 50700, loss: 0.8148693623118886, nll: 0.2666599029133592
Epoch: 437; total_step: 50750, loss: 0.7790079726480731, nll: 0.1809829280277417
Epoch: 437; total_step: 50800, loss: 0.8175874642221712, nll: 0.22997365952561402
Epoch: 438; total_step: 50850, loss: 0.7782285211245992, nll: 0.24024074751688887
Epoch: 438; total_step: 50900, loss: 0.7476453744317783, nll: 0.16095543354535674
Epoch: 439; total_step: 50950, loss: 0.7415443487661882, nll: 0.08977498150453812
Epoch: 439; total_s

Epoch: 478; total_step: 55450, loss: 0.8152257326601239, nll: 0.19288078324118155
Epoch: 478; total_step: 55500, loss: 0.7817990477492187, nll: 0.23276420108490856
Epoch: 478; total_step: 55550, loss: 0.7449072011794615, nll: 0.1806894424024735
Epoch: 479; total_step: 55600, loss: 0.7306918229111827, nll: 0.13727529742982206
Epoch: 479; total_step: 55650, loss: 0.7531870200676526, nll: 0.17724399519672962
Epoch: 480; total_step: 55700, loss: 0.7603784918490574, nll: 0.19168836331227662
Epoch: 480; total_step: 55750, loss: 0.7765417018704973, nll: 0.15706750254506816
Epoch: 481; total_step: 55800, loss: 0.7552621668867812, nll: 0.16913471663164728
Epoch: 481; total_step: 55850, loss: 0.7365441616190749, nll: 0.1690150507458989
Epoch: 481; total_step: 55900, loss: 0.7871287835316212, nll: 0.18202077144888976
Epoch: 482; total_step: 55950, loss: 0.8466869936427147, nll: 0.30248800620180366
Epoch: 482; total_step: 56000, loss: 0.8318014854197011, nll: 0.20989364507216066
Epoch: 483; total_

Epoch: 521; total_step: 60500, loss: 0.82011974497255, nll: 0.2595373974693058
Epoch: 521; total_step: 60550, loss: 0.8154635288533065, nll: 0.21326629411190387
Epoch: 522; total_step: 60600, loss: 0.7017809481710031, nll: 0.11546609674638182
Epoch: 522; total_step: 60650, loss: 0.7862736054128352, nll: 0.21063400923607564
Epoch: 523; total_step: 60700, loss: 0.8653646020485791, nll: 0.2684121249477721
Epoch: 523; total_step: 60750, loss: 0.7493541262033432, nll: 0.22453599784945996
Epoch: 524; total_step: 60800, loss: 0.7774875326472229, nll: 0.24110091826645633
Epoch: 524; total_step: 60850, loss: 0.7674273017430512, nll: 0.1997258161203283
Epoch: 525; total_step: 60900, loss: 0.7866231053974229, nll: 0.1781902846220926
Epoch: 525; total_step: 60950, loss: 0.7858453098722845, nll: 0.13905880165713191
Epoch: 525; total_step: 61000, loss: 0.7781026870051808, nll: 0.12798920644175024
Epoch: 526; total_step: 61050, loss: 0.8187890227741224, nll: 0.1990306500081268
Epoch: 526; total_step:

Epoch: 565; total_step: 65550, loss: 0.7986476016070507, nll: 0.2264926992597003
Epoch: 565; total_step: 65600, loss: 0.7756854926255115, nll: 0.17652216203552748
Epoch: 565; total_step: 65650, loss: 0.7830054722614713, nll: 0.1987757428605396
Epoch: 566; total_step: 65700, loss: 0.776198424617823, nll: 0.1793396504578351
Epoch: 566; total_step: 65750, loss: 0.8243773580690361, nll: 0.2814188258236563
Epoch: 567; total_step: 65800, loss: 0.7977709082393126, nll: 0.1952084115352752
Epoch: 567; total_step: 65850, loss: 0.7871796877755334, nll: 0.22389434510420897
Epoch: 568; total_step: 65900, loss: 0.7961422744628409, nll: 0.19437924875749377
Epoch: 568; total_step: 65950, loss: 0.820612341010376, nll: 0.20062943678301626
Epoch: 568; total_step: 66000, loss: 0.7583350022562554, nll: 0.24665579779339558
Epoch: 569; total_step: 66050, loss: 0.8087660216670849, nll: 0.20401041997080255
Epoch: 569; total_step: 66100, loss: 0.8104918433573118, nll: 0.1908243714119755
Epoch: 570; total_step: 

Epoch: 608; total_step: 70600, loss: 0.7940787618171098, nll: 0.196259639948438
Epoch: 609; total_step: 70650, loss: 0.7620438832668959, nll: 0.12815463223573004
Epoch: 609; total_step: 70700, loss: 0.7836643517905368, nll: 0.14570438028813215
Epoch: 609; total_step: 70750, loss: 0.7883055027848678, nll: 0.16185657792377406
Epoch: 610; total_step: 70800, loss: 0.7540612837625505, nll: 0.12561236157027292
Epoch: 610; total_step: 70850, loss: 0.7989138649694153, nll: 0.1709554751914655
Epoch: 611; total_step: 70900, loss: 0.8328507113551255, nll: 0.22490845225334857
Epoch: 611; total_step: 70950, loss: 0.7935897761374243, nll: 0.19781050508030854
Epoch: 612; total_step: 71000, loss: 0.7817408308929644, nll: 0.20548298897704642
Epoch: 612; total_step: 71050, loss: 0.7064971499672932, nll: 0.08945103680225623
Epoch: 612; total_step: 71100, loss: 0.843700600326663, nll: 0.2682036972791908
Epoch: 613; total_step: 71150, loss: 0.7880701295606138, nll: 0.25476825655735713
Epoch: 613; total_ste

Epoch: 652; total_step: 75650, loss: 0.7313274239553522, nll: 0.09007130720328134
Epoch: 652; total_step: 75700, loss: 0.7549152759460018, nll: 0.16388787858734855
Epoch: 653; total_step: 75750, loss: 0.7219279129864083, nll: 0.11180051999487746
Epoch: 653; total_step: 75800, loss: 0.7282364740224864, nll: 0.16213885712042164
Epoch: 653; total_step: 75850, loss: 0.7465024047793397, nll: 0.17601115791355829
Epoch: 654; total_step: 75900, loss: 0.8230812135823399, nll: 0.21153703154271755
Epoch: 654; total_step: 75950, loss: 0.8135901283438882, nll: 0.17544900810367514
Epoch: 655; total_step: 76000, loss: 0.8148161436023381, nll: 0.22992349240037913
Epoch: 655; total_step: 76050, loss: 0.8412586770330909, nll: 0.3133894045099062
Epoch: 656; total_step: 76100, loss: 0.7333354168623646, nll: 0.1491347867770174
Epoch: 656; total_step: 76150, loss: 0.7416331742045188, nll: 0.16168423191538467
Epoch: 656; total_step: 76200, loss: 0.7624383830685272, nll: 0.16172652117071354
Epoch: 657; total_

Epoch: 695; total_step: 80700, loss: 0.7883329567260794, nll: 0.24439953708341478
Epoch: 696; total_step: 80750, loss: 0.7155076542799552, nll: 0.10557568734576483
Epoch: 696; total_step: 80800, loss: 0.742918881906474, nll: 0.15886199614912788
Epoch: 696; total_step: 80850, loss: 0.7821268800267829, nll: 0.21080377604379852
Epoch: 697; total_step: 80900, loss: 0.8045722464434346, nll: 0.2093963905959435
Epoch: 697; total_step: 80950, loss: 0.7390630260682524, nll: 0.21712011017097718
Epoch: 698; total_step: 81000, loss: 0.7314178409047004, nll: 0.20143166139341023
Epoch: 698; total_step: 81050, loss: 0.7375807006260243, nll: 0.1360146118343663
Epoch: 699; total_step: 81100, loss: 0.7302770999391133, nll: 0.12311816927305544
Epoch: 699; total_step: 81150, loss: 0.7284111431172334, nll: 0.1144514151843915
Epoch: 700; total_step: 81200, loss: 0.7601882601035184, nll: 0.14464031719079332
Epoch: 700; total_step: 81250, loss: 0.752580774329305, nll: 0.18569495596388033
Epoch: 700; total_ste

Epoch: 739; total_step: 85750, loss: 0.6833456352756577, nll: 0.14832144859535026
Epoch: 739; total_step: 85800, loss: 0.7627584175243921, nll: 0.17860717297863965
Epoch: 740; total_step: 85850, loss: 0.7909493882294754, nll: 0.14588194352752798
Epoch: 740; total_step: 85900, loss: 0.814197315726615, nll: 0.18624455245162597
Epoch: 740; total_step: 85950, loss: 0.8165268570035364, nll: 0.1623888772115007
Epoch: 741; total_step: 86000, loss: 0.7577874787490576, nll: 0.1697329422914508
Epoch: 741; total_step: 86050, loss: 0.8096565631094277, nll: 0.24033181674401324
Epoch: 742; total_step: 86100, loss: 0.733207362333343, nll: 0.1162722377677498
Epoch: 742; total_step: 86150, loss: 0.7256479430430605, nll: 0.14551187165707616
Epoch: 743; total_step: 86200, loss: 0.7217327188540364, nll: 0.11483685195936956
Epoch: 743; total_step: 86250, loss: 0.7907326227143107, nll: 0.19718798631764173
Epoch: 743; total_step: 86300, loss: 0.7549604347303391, nll: 0.1879366805640543
Epoch: 744; total_step

In [None]:

# compute MSE
#test_y = test_y.cpu()
test_mse = MSE(test_y[0],means)
# compute mean negative predictive density
test_nll = -torch.distributions.Normal(means, variances.sqrt()).log_prob(test_y[0]).mean()
print(f"At {n_test} testing points, MSE: {test_mse:.4e}, nll: {test_nll:.4e}.")
print(f"Training time: {(t2-t1):.2f} sec, testing time: {(t3-t2):.2f} sec")

#plot=1
#if plot == 1:
#    from mpl_toolkits.mplot3d import axes3d
#    import matplotlib.pyplot as plt
#    fig = plt.figure(figsize=(12,6))
#    ax = fig.add_subplot(111, projection='3d')
#    ax.scatter(test_x[0][:,0],test_x[:,1],test_y, color='k')
#    ax.scatter(test_x[0][:,0],test_x[:,1],means, color='b')
#    plt.title("f(x,y) variational fit; actual curve is black, variational is blue")
#    plt.show()

In [None]:
# training params
#num_inducing = 50
#num_directions = 6
#minibatch_size = 200
#num_epochs = 100


# 2 directions
#At 104 testing points, MSE: 2.9133e+00, nll: 3.3945e+00. 
# 3 directions
#At 104 testing points, MSE: 2.9455e+00, nll: 3.3617e+00.
#Training time: 70.29 sec, testing time: 0.10 sec
# 4 directions
#At 104 testing points, MSE: 2.9810e+00, nll: 3.0743e+00.
#Training time: 57.68 sec, testing time: 0.08 sec
# 5 directions
#At 104 testing points, MSE: 2.9440e+00, nll: 3.6124e+00.
#Training time: 104.46 sec, testing time: 0.12 sec
# 6 directions
#At 104 testing points, MSE: 2.9795e+00, nll: 3.1092e+00.
#Training time: 127.73 sec, testing time: 0.10 sec
# 7 directions
#At 104 testing points, MSE: 2.9272e+00, nll: 3.6537e+00.
#Training time: 153.38 sec, testing time: 0.12 sec
# 8 directions
#At 104 testing points, MSE: 2.9503e+00, nll: 3.3300e+00.
#Training time: 173.86 sec, testing time: 0.15 sec
# 9 directions
# 10 directions

# Traditional SVGP

In [None]:
model_t,likelihood_t = traditional_vi.train_gp(train_dataset,dim,
                                                   num_inducing=num_inducing,
                                                   minibatch_size=minibatch_size,
                                                   num_epochs=num_epochs,
                                                   use_ngd=use_ngd, use_ciq=use_ciq,
                                                   learning_rate_hypers=learning_rate_hypers,
                                                   learning_rate_ngd=learning_rate_ngd,
                                                   lr_sched=lr_sched,
                                                   num_contour_quadrature=num_contour_quadrature,gamma=gamma, verbose=True)

In [None]:
means_t, variances_t = traditional_vi.eval_gp(test_dataset, model_t, likelihood_t, minibatch_size=n_test)

In [None]:
# compute MSE
#test_y = test_y.cpu()
test_mse = MSE(test_y[0],means_t)
# compute mean negative predictive density
test_nll = -torch.distributions.Normal(means_t, variances_t.sqrt()).log_prob(test_y[0]).mean()
print(f"At {n_test} testing points, MSE: {test_mse:.4e}, nll: {test_nll:.4e}.")
print(f"Training time: {(t2-t1):.2f} sec, testing time: {(t3-t2):.2f} sec")

In [None]:
# protein
# dfree and svgp
# At 9146 testing points, MSE: 5.9555e-01, nll: 1.1599e+00.
# At 9146 testing points, MSE: 6.2736e-01, nll: 1.1856e+00.

In [None]:
#forest fire

# At 104 testing points, MSE: 3.0218e+00, nll: 4.2752e+00.
#Training time: 837.60 sec, testing time: 0.31 sec

#At 104 testing points, MSE: 2.8974e+00, nll: 3.5117e+00.
#Training time: 837.60 sec, testing time: 0.31 sec

In [None]:
# At 14400 testing points, MSE: 1.0995e-01, nll: 3.0791e-01.
# At 14400 testing points, MSE: 1.2291e-01, nll: 3.9036e-01.