In [1]:
import torch
import numpy as np
from src.datasets.uci_loader import UCIDataset
from src.model_builder import build_model, sghmc_sampling, print_sample_performance, collect_samples, predict_y
from torchviz import make_dot
%load_ext autoreload
%autoreload 2

# setting PyTorch

from src.misc.settings import settings
device = settings.device
if device.type == 'cuda':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

 # Load regression UCI dataset

data_uci = UCIDataset(dataset_path='data/uci/boston.pth', static_split=True, seed=0)
N, D = data_uci.X_train.shape
print(f'X-train: {N, D}')

INFO:root:Loading dataset from data/uci/boston.pth


X-train: (406, 13)


How the model works in BGSP:
1. There is a `run_regression.py` file to setup the model classes such as `RegressionModel` and the cmd line arguments
2. The `RegressionModel` in `models.py` builds the `DGP` model and launches a training loop with the hmc sampling inside
3. The `DGP` class is defined throug different `Layers` and instantiates the TensorFlow session and does the initialization 

In [13]:
#  Instantiate a model
class ARGS():
    num_inducing = 100
    n_layers = 1
    minibatch_size = 100
    window_size = 64
    output_dim= 1
    adam_lr = 0.01
    prior_inducing_type = "uniform"
    full_cov = False
    epsilon = 0.01
    mdecay = 0.05
    iterations = 1024
    num_posterior_samples = 10
    posterior_sample_spacing = 32
args = ARGS()

bsgp_model = build_model(data_uci.X_train, data_uci.Y_train, args)
#make_dot(nll, params={'nll': nll}).render("computation_graph", format="png", cleanup=True)

In [15]:
bsgp_model.reset(data_uci.X_train, data_uci.Y_train)
global_step = 0
for _ in range(args.iterations):
    global_step += 1
    sample = sghmc_sampling(bsgp_model)
    #print(sample[3])
    """
    if args.prior_type == "determinantal":
        bsgp_model.reset_Lm()
        #model.train_hypers() if hasattr(self.model, 'hyper_train_op') else None
    """
    if _ % 50 == 1:
        marginal_ll = print_sample_performance(bsgp_model)
        print('TRAIN | iter = %6d      sample marginal LL = %5.2f' % (_, marginal_ll))
    
collect_samples(bsgp_model, args.num_posterior_samples, args.posterior_sample_spacing)

ms, vs  = predict_y(bsgp_model, data_uci.X_test, args.num_posterior_samples)
m = np.average(ms, 0)
v = np.average(vs + ms**2, 0) - m**2
import seaborn as sns
sns.scatterplot(x=m.reshape(-1), y=data_uci.Y_test.reshape(-1))

TRAIN | iter =      1      sample marginal LL = -12.55
TRAIN | iter =     51      sample marginal LL = -11.34
TRAIN | iter =    101      sample marginal LL = -8.10
TRAIN | iter =    151      sample marginal LL = -8.10
TRAIN | iter =    201      sample marginal LL = -11.64
TRAIN | iter =    251      sample marginal LL = -8.44
TRAIN | iter =    301      sample marginal LL = -8.39
TRAIN | iter =    351      sample marginal LL = -8.78
TRAIN | iter =    401      sample marginal LL = -8.78
TRAIN | iter =    451      sample marginal LL = -8.78
TRAIN | iter =    501      sample marginal LL = -7.46
TRAIN | iter =    551      sample marginal LL = -8.45
TRAIN | iter =    601      sample marginal LL = -8.27
TRAIN | iter =    651      sample marginal LL = -8.06
TRAIN | iter =    701      sample marginal LL = -8.71
TRAIN | iter =    751      sample marginal LL = -8.03
TRAIN | iter =    801      sample marginal LL = -8.61
TRAIN | iter =    851      sample marginal LL = -9.02
TRAIN | iter =    901    

In [12]:
data_uci.Y_test

array([[ 1.2766553 ],
       [-0.40887564],
       [-0.2701926 ],
       [-0.9742752 ],
       [-0.9422715 ],
       [-1.0276147 ],
       [ 0.20986366],
       [ 0.41255406],
       [-0.16351342],
       [ 0.24186733],
       [ 0.06051267],
       [-0.11017384],
       [-0.14217772],
       [ 0.22053142],
       [ 0.01784104],
       [-0.05683426],
       [ 0.63658035],
       [ 0.60457647],
       [-0.366204  ],
       [-0.43021134],
       [-0.45154727],
       [-0.16351342],
       [-0.4195436 ],
       [-0.75024897],
       [-0.50488687],
       [-0.46221521],
       [-0.4942189 ],
       [-0.8889319 ],
       [-0.13150975],
       [-1.0276147 ],
       [ 0.45522568],
       [ 0.24186733],
       [ 0.09251654],
       [ 0.7645952 ],
       [ 1.8207189 ],
       [ 0.67925197],
       [ 2.087417  ],
       [ 0.24186733],
       [ 0.93528193],
       [ 0.95661783],
       [ 0.25253528],
       [ 0.93528193],
       [ 0.1031845 ],
       [-0.07817017],
       [ 0.16719183],
       [ 0