# Concrete

In [1]:
# Make project code importable from this notebook: append the parent and the
# grandparent of the current working directory to sys.path.
from pathlib import Path
import sys
import os
path_root = str(Path(os.getcwd()).parents[0])  # one level above the working directory
package_root = str(Path(os.getcwd()).parents[1])  # two levels above the working directory
sys.path.append(path_root)
sys.path.append(package_root)

# auto reload: IPython autoreload in mode 2, so edits to imported modules are
# picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# hush
import warnings
warnings.simplefilter('always', category=FutureWarning)

In [2]:
# Switch on JAX's "jax_enable_x64" flag before the main jax / jax.numpy imports
# in the next cell (the recorded metrics below are float64).
from jax import config
config.update("jax_enable_x64", True)

In [3]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from copy import deepcopy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import jax
import jax.numpy as jnp
from jax import vmap, jit
from tensorflow_probability.substrates.jax import distributions as tfd
from jax.lib import xla_bridge

In [4]:
from steinRF import GP, LowRankGP, MixGP
from steinRF.gp.kernels import RFF
from steinRF.gp.transforms import Transform, ARD
from steinRF.stein.targets import TFTarget

from steinRF.stein.srfr import srfr
# from steinRF.mar_srfr import mar_srfr
from steinRF.utils import gp_cross_val, metric_model, run_hyperopt, mse, mae
from steinRF.baselines import build_svgp, build_train_svgp, svgp_predict, svgp_cross_val
from steinRF.gp.models import *

In [5]:
# Drop cached compilations, then report the platform JAX will run on.
# jax.default_backend() is the public API for the platform name, so the cell no
# longer depends on the internal jax.lib.xla_bridge module.
jax.clear_caches()
print(f"device: {jax.default_backend()}")

device: gpu


## Data Preparation

In [6]:
# Base PRNG key: its first element seeds the train/validation split, and the key
# itself is passed to every hyperparameter search below.
key = jax.random.PRNGKey(0)

In [7]:
# Load the raw table, then rescale it in place: every feature column goes
# through a MinMaxScaler, the target through a StandardScaler.
# NOTE(review): both scalers are fit on the full dataset, i.e. before the
# train/validation split below -- confirm this is intended.
concrete = pd.read_csv("concrete.csv")
X_cols = [col for col in concrete.columns if col != "compressive_strength"]

X_scaler, y_scaler = MinMaxScaler(), StandardScaler()
concrete[X_cols] = X_scaler.fit_transform(concrete[X_cols])
concrete["compressive_strength"] = y_scaler.fit_transform(concrete[["compressive_strength"]])

# Feature matrix and target vector as JAX arrays.
X = jnp.array(concrete.drop(columns=['compressive_strength']))
y = jnp.array(concrete['compressive_strength'])

# 80/20 hold-out split seeded from the first element of `key`; every part is
# converted back to a JAX array.
X_train, X_val, y_train, y_val = map(
    jnp.array,
    train_test_split(X, y, test_size=0.20, random_state=int(key[0])),
)

# Number of input features.
d = X_train.shape[1]

## Hyperparameter Optimization

In [8]:
# Search-space bounds shared by the hyperopt objectives defined below.
diag_min = 1e-4
diag_max = 1e-1

epoch_min = 500
epoch_max = 4000

lr_min = 1e-3
lr_max = 4e-1

alpha_min = 0.
alpha_max = 3.

q_min = 2
q_max = 10

# Categorical choices offered for `s`.
s_vals = [0, 0.5, 1, 2]


# --------------------------------- Sparse Variational GP -------------------------------- #
def _svgp_cross_val(trial, key, X, y, **params):
    """Hyperopt objective for the sparse variational GP baseline.

    ``epochs`` and ``lr`` come from ``params`` when given (not ``None``);
    otherwise they are sampled from ``trial``. ``R`` and ``diag`` are read from
    ``params`` as-is (``None`` when absent). All remaining keywords are
    forwarded unchanged.

    Returns the value produced by ``svgp_cross_val`` for the assembled settings.
    """
    n_epochs = params.pop('epochs', None)
    if n_epochs is None:
        n_epochs = trial.suggest_int("epochs", epoch_min, epoch_max, step=500)

    step_size = params.pop('lr', None)
    if step_size is None:
        # The SVGP learning rate uses its own range rather than lr_min / lr_max.
        step_size = trial.suggest_float("lr", 5e-4, 0.1, log=True)

    settings = {
        'R': params.get('R'),
        'epochs': n_epochs,
        'lr': step_size,
        'diag': params.get('diag', None),
    }
    # Leftover keywords (including R / diag themselves) are layered on top.
    settings.update(params)

    return svgp_cross_val(key, X, y, settings)


# ---------------------------------- RFF GP - RBF Kernel --------------------------------- #
def rff_rbf_cross_val(trial, key, X, y, **params):
    """Hyperopt objective for the model built by ``build_train_rff_rbf``.

    ``R`` must be present in ``params`` (``KeyError`` otherwise). ``diag``,
    ``epochs`` and ``lr`` are taken from ``params`` when given (not ``None``)
    and sampled from ``trial`` otherwise. Every other keyword is forwarded.

    Returns the value produced by ``gp_cross_val`` for the assembled settings.
    """
    n_features = params.pop('R')

    # Evaluated lazily and in this order, only for values the caller left unset.
    samplers = (
        ('diag', lambda: trial.suggest_float("diag", diag_min, diag_max, log=True)),
        ('epochs', lambda: trial.suggest_int("epochs", epoch_min, epoch_max, step=500)),
        ('lr', lambda: trial.suggest_float("lr", lr_min, lr_max, log=True)),
    )
    chosen = {}
    for name, sample in samplers:
        preset = params.pop(name, None)
        chosen[name] = sample() if preset is None else preset

    cv_params = {
        'R': n_features,
        'epochs': chosen['epochs'],
        'lr': chosen['lr'],
        'diag': chosen['diag'],
        **params,
    }
    return gp_cross_val(build_train_rff_rbf, key, X, y, cv_params)


# --------------------------------------- RFF BASIC -------------------------------------- #
def rff_cross_val(trial, key, X, y, **params):
    """Hyperopt objective for the model built by ``build_train_rff``.

    ``R`` must be present in ``params`` (``KeyError`` otherwise). ``diag``,
    ``epochs`` and ``lr`` are taken from ``params`` when given (not ``None``)
    and sampled from ``trial`` otherwise. Every other keyword is forwarded.

    Returns the value produced by ``gp_cross_val`` for the assembled settings.
    """
    n_features = params.pop('R')

    noise = params.pop('diag', None)
    noise = trial.suggest_float("diag", diag_min, diag_max, log=True) if noise is None else noise

    n_epochs = params.pop('epochs', None)
    n_epochs = trial.suggest_int("epochs", epoch_min, epoch_max, step=500) if n_epochs is None else n_epochs

    step_size = params.pop('lr', None)
    step_size = trial.suggest_float("lr", lr_min, lr_max, log=True) if step_size is None else step_size

    # Fixed entries first, then whatever keywords are left over.
    settings = dict(R=n_features, epochs=n_epochs, lr=step_size, diag=noise)
    settings.update(params)

    return gp_cross_val(build_train_rff, key, X, y, settings)

# --------------------------------------- STEIN RFF -------------------------------------- #
def srf_cross_val(trial, key, X, y, **params):
    """Hyperopt objective for the model built by ``build_train_srf``.

    ``R`` must be present in ``params`` (``KeyError`` otherwise). ``diag``,
    ``epochs``, ``lr``, ``lr_gd``, ``alpha`` and ``s`` are taken from ``params``
    when given (not ``None``) and sampled from ``trial`` otherwise, in that
    order. Every other keyword is forwarded.

    Returns the value produced by ``gp_cross_val`` for the assembled settings.
    """
    n_features = params.pop('R')

    def resolve(name, sample):
        # A caller-supplied value wins; otherwise draw one from the trial.
        given = params.pop(name, None)
        return sample() if given is None else given

    diag = resolve('diag', lambda: trial.suggest_float("diag", diag_min, diag_max, log=True))
    epochs = resolve('epochs', lambda: trial.suggest_int("epochs", epoch_min, epoch_max, step=500))
    lr = resolve('lr', lambda: trial.suggest_float("lr", lr_min, lr_max, log=True))
    lr_gd = resolve('lr_gd', lambda: trial.suggest_float("lr_gd", lr_min, lr_max, log=True))
    alpha = resolve('alpha', lambda: trial.suggest_float("alpha", alpha_min, alpha_max, step=0.2))
    s = resolve('s', lambda: trial.suggest_categorical("s", s_vals))

    settings = dict(R=n_features, epochs=epochs, lr=lr, lr_gd=lr_gd, diag=diag, alpha=alpha, s=s)
    settings.update(params)

    return gp_cross_val(build_train_srf, key, X, y, settings)


# ----------------------------------- NONSTATIONARY RFF ---------------------------------- #
def nrff_cross_val(trial, key, X, y, **params):
    """Hyperopt objective for the model built by ``build_train_nrff``.

    ``diag``, ``epochs`` and ``lr`` are taken from ``params`` when given (not
    ``None``) and sampled from ``trial`` otherwise. ``R`` is read from
    ``params`` (``None`` when absent). Every other keyword is forwarded.

    The tuned keys are popped from ``params`` before the leftovers are merged
    in, so an explicit ``None`` (e.g. ``lr=None``) can no longer overwrite the
    freshly sampled value.

    Returns the value produced by ``gp_cross_val`` for the assembled settings.
    """
    # hparams
    R = params.pop('R', None)
    diag = params.pop('diag', None)
    if diag is None:
        diag = trial.suggest_float("diag", diag_min, diag_max, log=True)
    epochs = params.pop('epochs', None)
    if epochs is None:
        epochs = trial.suggest_int("epochs", epoch_min, epoch_max, step=500)
    lr = params.pop('lr', None)
    if lr is None:
        lr = trial.suggest_float("lr", lr_min, lr_max, log=True)

    cv_params = {'R': R, 'epochs': epochs, 'lr': lr, 'diag': diag, **params}

    # run cross val
    cross_val_acc = gp_cross_val(build_train_nrff, key, X, y, cv_params)
    return cross_val_acc


# ------------------------------------ MARGINAL KERNEL ----------------------------------- #
def mix_rff_cross_val(trial, key, X, y, **params):
    """Hyperopt objective for the model built by ``build_train_mix_rff``.

    ``R`` must be present in ``params`` (``KeyError`` otherwise). ``diag``,
    ``epochs``, ``lr``, ``lr_gd``, ``alpha``, ``q`` and ``s`` are taken from
    ``params`` when given (not ``None``) and sampled from ``trial`` otherwise,
    in that order. Every other keyword is forwarded.

    Returns the value produced by ``gp_cross_val`` for the assembled settings.
    """
    n_features = params.pop('R')

    # Evaluated lazily and in this order, only for values the caller left unset.
    samplers = (
        ('diag', lambda: trial.suggest_float("diag", diag_min, diag_max, log=True)),
        ('epochs', lambda: trial.suggest_int("epochs", epoch_min, epoch_max, step=500)),
        ('lr', lambda: trial.suggest_float("lr", lr_min, lr_max, log=True)),
        ('lr_gd', lambda: trial.suggest_float("lr_gd", lr_min, lr_max, log=True)),
        ('alpha', lambda: trial.suggest_float("alpha", alpha_min, alpha_max, step=0.2)),
        ('q', lambda: trial.suggest_int("q", q_min, q_max, step=2)),
        ('s', lambda: trial.suggest_categorical("s", s_vals)),
    )
    chosen = {}
    for name, sample in samplers:
        preset = params.pop(name, None)
        chosen[name] = sample() if preset is None else preset

    cv_params = {'R': n_features}
    for name in ('epochs', 'lr', 'diag', 'lr_gd', 'alpha', 's', 'q'):
        cv_params[name] = chosen[name]
    cv_params.update(params)

    return gp_cross_val(build_train_mix_rff, key, X, y, cv_params)


# ------------------------------ NONSTATIONARY STEIN MIXTURE ----------------------------- #
def nmix_cross_val(trial, key, X, y, **params):
    """Hyperopt objective for the model built by ``build_train_nmix_rff``.

    ``diag``, ``epochs``, ``lr``, ``lr_gd``, ``alpha``, ``s`` and ``q`` are
    taken from ``params`` when given (not ``None``) and sampled from ``trial``
    otherwise, in that order. ``R`` is read from ``params`` (``None`` when
    absent). Every other keyword is forwarded.

    Every tuned key is popped from ``params`` before the leftovers are merged
    in, so an explicit ``None`` (e.g. ``alpha=None``) can no longer overwrite
    the freshly sampled value.

    Returns the value produced by ``gp_cross_val`` for the assembled settings.
    """
    # hparams
    R = params.pop('R', None)
    diag = params.pop('diag', None)
    if diag is None:
        diag = trial.suggest_float("diag", diag_min, diag_max, log=True)
    epochs = params.pop('epochs', None)
    if epochs is None:
        epochs = trial.suggest_int("epochs", epoch_min, epoch_max, step=500)
    lr = params.pop('lr', None)
    if lr is None:
        lr = trial.suggest_float("lr", lr_min, lr_max, log=True)
    lr_gd = params.pop('lr_gd', None)
    if lr_gd is None:
        lr_gd = trial.suggest_float("lr_gd", lr_min, lr_max, log=True)
    alpha = params.pop('alpha', None)
    if alpha is None:
        alpha = trial.suggest_float("alpha", alpha_min, alpha_max, step=0.2)
    s = params.pop('s', None)
    if s is None:
        s = trial.suggest_categorical("s", s_vals)
    q = params.pop('q', None)
    if q is None:
        q = trial.suggest_int("q", q_min, q_max, step=2)
    cv_params = {
        'R': R, 'epochs': epochs, 'lr': lr, 'lr_gd': lr_gd, 'diag': diag, 'alpha': alpha,
        's': s, 'q': q, **params
    }

    # run cross val
    cross_val_acc = gp_cross_val(build_train_nmix_rff, key, X, y, cv_params)
    return cross_val_acc

### Run Optimization

In [9]:
# Settings shared by the hyperparameter searches below.
n_trials = 30  # trial budget for the studies that do not pass their own number
R = 100  # passed as `R` to every study
q = 10  # NOTE(review): no study call in this notebook passes `q` -- confirm it is still needed
hparams = {}  # model name -> best parameters reported by its study

#### Sparse-Variational GP

In [10]:
# Tune the sparse variational GP. `epochs` and `diag` are pinned in the call, so
# the objective only samples `lr`.
# svgp = build_train_svgp(key, X, y, R=R, diag=1e-2, epochs=1000, lr=0.01, from_data=True)[0]
svgp_study = run_hyperopt(
    _svgp_cross_val, key, X_train, y_train, n_trials=n_trials, R=R, diag=1e-2, from_data=False, epochs=1000
)
hparams["svgp"] = svgp_study.best_params

[I 2024-02-10 01:34:48,562] A new study created in memory with name: no-name-60b00659-9b5a-4843-8911-372d520e2288


  0%|          | 0/30 [00:00<?, ?it/s]

[I 2024-02-10 01:35:25,090] Trial 0 finished with value: 0.3684043744747548 and parameters: {'lr': 0.0006728018524728593}. Best is trial 0 with value: 0.3684043744747548.
[I 2024-02-10 01:35:56,565] Trial 1 finished with value: 0.19670388557321522 and parameters: {'lr': 0.07920197843656476}. Best is trial 1 with value: 0.19670388557321522.
[I 2024-02-10 01:36:27,302] Trial 2 finished with value: 0.23411662457014418 and parameters: {'lr': 0.001567476042144662}. Best is trial 1 with value: 0.19670388557321522.
[I 2024-02-10 01:36:58,235] Trial 3 finished with value: 0.20444573559830231 and parameters: {'lr': 0.004180130757767301}. Best is trial 1 with value: 0.19670388557321522.
[I 2024-02-10 01:37:29,357] Trial 4 finished with value: 0.4024862814342782 and parameters: {'lr': 0.0005805878779183577}. Best is trial 1 with value: 0.19670388557321522.
[I 2024-02-10 01:38:00,581] Trial 5 finished with value: 0.16709677968953224 and parameters: {'lr': 0.018994658441233388}. Best is trial 5 wit

#### RFF with RBF

In [None]:
# Tune the RFF-RBF model. `epochs` is pinned, so the objective samples `diag`
# and `lr`. Unlike the other studies, `from_data` is not passed here.
# rff_rbf = build_train_rff_rbf(key, X_train, R=100, diag=1e-2, epochs=1000, lr=0.01, from_data=False)[0]
rff_rbf_study = run_hyperopt(
    rff_rbf_cross_val, key, X_train, y_train, n_trials=n_trials, R=R, epochs=1000, init_ls=False
)
hparams["rff_rbf"] = rff_rbf_study.best_params

#### Basic RFF GP

In [11]:
# Tune the basic RFF GP. `epochs` is pinned, so the objective samples `diag`
# and `lr`.
# rff = build_train_rff(key, X_train, y_train, R=R, diag=0.09, epochs=2000, lr=.013, from_data=False, init_ls=False)[0]
rff_study = run_hyperopt(
    rff_cross_val, key, X_train, y_train, n_trials=n_trials, R=R, epochs=1000, from_data=False, init_ls=False
)
hparams["rff"] = rff_study.best_params

[I 2024-02-10 02:01:14,493] A new study created in memory with name: no-name-3f7c95e5-686b-460f-999e-8539b79b48e6


  0%|          | 0/30 [00:00<?, ?it/s]

[I 2024-02-10 02:01:31,059] Trial 0 finished with value: 0.3130512038124511 and parameters: {'diag': 0.00011899057245427316, 'lr': 0.17639032540647825}. Best is trial 0 with value: 0.3130512038124511.
[I 2024-02-10 02:01:45,733] Trial 1 finished with value: 0.1276076855572909 and parameters: {'diag': 0.030822676516862647, 'lr': 0.00987201891888893}. Best is trial 1 with value: 0.1276076855572909.
[I 2024-02-10 02:02:00,503] Trial 2 finished with value: 1.009601344360198 and parameters: {'diag': 0.00015793581531171045, 'lr': 0.0117777584502081}. Best is trial 1 with value: 0.1276076855572909.
[I 2024-02-10 02:02:15,286] Trial 3 finished with value: 0.11582558453854805 and parameters: {'diag': 0.05854497803653283, 'lr': 0.009455215905141085}. Best is trial 3 with value: 0.11582558453854805.
[I 2024-02-10 02:02:30,044] Trial 4 finished with value: 0.4859430112313782 and parameters: {'diag': 0.0006317727820432532, 'lr': 0.026329004009463197}. Best is trial 3 with value: 0.11582558453854805

#### Stein Random Features

In [12]:
# Tune the Stein random-feature model. `epochs` is pinned, so the objective
# samples `diag`, `lr`, `lr_gd`, `alpha` and `s`. This study uses its own budget
# of 50 trials instead of `n_trials`.
# `gd_params` selects `kernel.transform.scale` -- presumably the parameters
# trained with the `lr_gd` step size; confirm against build_train_srf.
# srf = build_train_srf(key, X_train, y_train, R=R, diag=1e-3, epochs=1000, lr=0.1, alpha=0.5, s=0.5, from_data=False, init_ls=False)[0]
srf_study = run_hyperopt(
    srf_cross_val, key, X_train, y_train, n_trials=50, R=R, epochs=1000, from_data=False, init_ls=False,
    gd_params=lambda t: [t.kernel.transform.scale]
)
hparams["srf"] = srf_study.best_params

[I 2024-02-10 02:08:53,523] A new study created in memory with name: no-name-495170fe-e384-444c-bf0b-c0e534ce3801


  0%|          | 0/50 [00:00<?, ?it/s]

[I 2024-02-10 02:09:12,337] Trial 0 finished with value: 0.5865822960073148 and parameters: {'diag': 0.00024282177812108373, 'lr': 0.00644087337919752, 'lr_gd': 0.052595922878610715, 'alpha': 0.0, 's': 2}. Best is trial 0 with value: 0.5865822960073148.
[I 2024-02-10 02:09:29,950] Trial 1 finished with value: 0.1701304637574811 and parameters: {'diag': 0.014154012553906711, 'lr': 0.0022194712558587758, 'lr_gd': 0.0028390487894859625, 'alpha': 2.6, 's': 0}. Best is trial 1 with value: 0.1701304637574811.
[I 2024-02-10 02:09:48,220] Trial 2 finished with value: 0.1952352102886128 and parameters: {'diag': 0.00755509061274433, 'lr': 0.08952762005592099, 'lr_gd': 0.0038011188576371198, 'alpha': 0.0, 's': 0}. Best is trial 1 with value: 0.1701304637574811.
[I 2024-02-10 02:10:05,888] Trial 3 finished with value: 0.6386056788632002 and parameters: {'diag': 0.00013332915515950507, 'lr': 0.0015787554554618452, 'lr_gd': 0.02002046909385376, 'alpha': 3.0, 's': 0.5}. Best is trial 1 with value: 0.

#### Mixture SRFR

In [26]:
# Tune the Stein mixture model. `epochs` is pinned, so the objective samples
# `diag`, `lr`, `lr_gd`, `alpha`, `q` and `s`. This study uses its own budget of
# 75 trials instead of `n_trials`.
# mix_rff = build_train_mix_rff(
#     key, X_train, y_train, diag=1e-2, q=5, R=100, alpha=1.4, lr=0.1, from_data=False, epochs=1000, init_ls=False
# )[0]
# Prior target: a Normal with loc 0 and scale 3 in each of the d input dimensions.
prior = TFTarget(tfd.Normal(jnp.zeros(d), jnp.ones(d) * 3))
mix_rff_study = run_hyperopt(
    mix_rff_cross_val, key, X_train, y_train, n_trials=75, R=R, epochs=1000, from_data=False, prior=prior,
    gd_params=lambda t: [t.kernel.transform.scale], init_ls=False
)
# Record the best parameters like every other study does; "mix_rff" is the key
# the experiment code reads.
hparams["mix_rff"] = mix_rff_study.best_params

[I 2024-02-08 22:11:43,016] A new study created in memory with name: no-name-7bc311cd-4955-44db-80ff-c99da633a8bc


  0%|          | 0/75 [00:00<?, ?it/s]

[I 2024-02-08 22:13:10,194] Trial 0 finished with value: 0.29708814141020995 and parameters: {'diag': 0.00015019678272494203, 'lr': 0.03756491583306538, 'lr_gd': 0.015119729641671504, 'alpha': 0.2, 'q': 6, 's': 0.5}. Best is trial 0 with value: 0.29708814141020995.
[I 2024-02-08 22:14:10,353] Trial 1 finished with value: 0.105772653093206 and parameters: {'diag': 0.04731404978278705, 'lr': 0.05480265116019067, 'lr_gd': 0.022885965528919867, 'alpha': 0.0, 'q': 4, 's': 2}. Best is trial 1 with value: 0.105772653093206.
[I 2024-02-08 22:15:08,263] Trial 2 finished with value: 0.10921101058790646 and parameters: {'diag': 0.0431121824389674, 'lr': 0.0022317891736973704, 'lr_gd': 0.02858769325631335, 'alpha': 0.6000000000000001, 'q': 4, 's': 1}. Best is trial 1 with value: 0.105772653093206.
[I 2024-02-08 22:16:06,057] Trial 3 finished with value: 0.17987110966248598 and parameters: {'diag': 0.0001244545364431794, 'lr': 0.3604268753918517, 'lr_gd': 0.07422693852055715, 'alpha': 0.8, 'q': 4, 

#### Nonstationary GP

In [None]:
# Tune the nonstationary RFF GP. `epochs` is pinned, so the objective samples
# `diag` and `lr`.
# build_train_nrff(key, X_train, y_train, R=100, diag=1e-4, epochs=1000, lr=0.01, from_data=False, init_ls=False)
nrff_study = run_hyperopt(
    nrff_cross_val, key, X_train, y_train, n_trials=n_trials, R=R, from_data=False, epochs=1000, init_ls=False
)
hparams["nrff"] = nrff_study.best_params

#### Nonstationary Mixture SRFR

In [None]:
# Tune the nonstationary Stein mixture. `epochs` is pinned, so the objective
# samples `diag`, `lr`, `lr_gd`, `alpha`, `s` and `q`. This study uses its own
# budget of 75 trials instead of `n_trials`.
# Note that the study object is bound to `nmix_rff`, without the `_study` suffix
# the other cells use, and that no `prior` / `gd_params` are passed here.
# build_train_nmix_rff(
#     key, X_train, y_train, q=5, R=100, diag=1e-4, alpha=0.5, epochs=10, lr=0.01, from_data=False, init_ls=False
# )
nmix_rff = run_hyperopt(
    nmix_cross_val, key, X_train, y_train, n_trials=75, R=R, from_data=False, epochs=1000, init_ls=False
)
hparams["nmix_rff"] = nmix_rff.best_params

## Define and Run Experiment

### Definition

In [9]:
def _equivalent_feature_count(q, R):
    """Return the largest integer ``r`` with ``r**3 <= q * R**3``.

    A bare ``int(x ** (1/3))`` truncates perfect cubes because of floating-point
    error (q=8, R=100 gives 199 instead of 200), so the float estimate is
    corrected with exact comparisons.
    """
    target = q * R**3
    r = int(round(target ** (1 / 3)))
    while r**3 > target:
        r -= 1
    while (r + 1) ** 3 <= target:
        r += 1
    return r


def experiment_run(exp_key, X, y, scaler, params, R, restarts=1):
    """Train every model on one random 80/20 split and score it on the held-out part.

    Args:
        exp_key: PRNG key for this run. Its first element seeds the split, its
            second element is stored under ``"seed"``, and the key itself is
            handed to every builder.
        X, y: full feature matrix and target vector; the split happens here.
        scaler: forwarded to ``metric_model`` as ``scaler``.
        params: one keyword dict per model under the keys ``"svgp"``,
            ``"rff_rbf"``, ``"rff"``, ``"srf"`` and ``"mix_rff"``;
            ``params["mix_rff"]["q"]`` must be present.
        R: feature count passed as ``R`` to every builder.
        restarts: forwarded to every builder except the SVGP one.

    Returns:
        dict with ``"seed"`` plus the ``metric_model`` result for each of
        ``"svgp"``, ``"rff_rbf"``, ``"rff"``, ``"qrff"``, ``"srf"`` and ``"mix"``.
    """
    # split data
    X_tr, X_test, y_tr, y_test = train_test_split(X, y, test_size=0.2, random_state=int(exp_key[0]))
    X_tr, X_test, y_tr, y_test = jnp.array(X_tr), jnp.array(X_test), jnp.array(y_tr), jnp.array(y_test)

    ############ ORF ############
    # orf_gp, _, _ = build_train_orf(exp_key, X_tr, y_tr, R=R, **params["orf"])
    # orf_preds, orf_sd = orf_gp.condition(y_tr, X_test)
    # orf_metrics = metric_model(y_test, orf_preds, orf_sd, scaler=scaler)

    ############ SVGP ############
    svgp, _ = build_train_svgp(exp_key, X_tr, y_tr, R=R, **params["svgp"])
    svgp_preds, svgp_sd = svgp_predict(svgp, X_test)
    svgp_metrics = metric_model(y_test, svgp_preds, svgp_sd, scaler=scaler)

    ############ RFF RBF ############
    rff_rbf_gp, _ = build_train_rff_rbf(exp_key, X_tr, y_tr, R=R, restarts=restarts, **params["rff_rbf"])
    rff_rbf_preds, rff_rbf_sd = rff_rbf_gp.condition(y_tr, X_test)
    rff_rbf_metrics = metric_model(y_test, rff_rbf_preds, rff_rbf_sd, scaler=scaler)

    ############ RFF ############
    rff_gp, _ = build_train_rff(exp_key, X_tr, y_tr, R=R, restarts=restarts, **params["rff"])
    rff_preds, rff_sd = rff_gp.condition(y_tr, X_test)
    rff_metrics = metric_model(y_test, rff_preds, rff_sd, scaler=scaler)

    ############ RFF-Q ############
    # RFF with the number of features R whose cubic cost matches the
    # computational complexity of the q-component mixture: R_equiv**3 ~= q * R**3.
    q = params["mix_rff"]["q"]
    R_equiv_q = _equivalent_feature_count(q, R)
    qrff_gp, _ = build_train_rff(exp_key, X_tr, y_tr, R=R_equiv_q, restarts=restarts, **params["rff"])
    qrff_preds, qrff_sd = qrff_gp.condition(y_tr, X_test)
    qrff_metrics = metric_model(y_test, qrff_preds, qrff_sd, scaler=scaler)

    ############ SRF ############
    srf_gp, _ = build_train_srf(exp_key, X_tr, y_tr, R=R, restarts=restarts, **params["srf"])
    srf_preds, srf_sd = srf_gp.condition(y_tr, X_test)
    srf_metrics = metric_model(y_test, srf_preds, srf_sd, scaler=scaler)

    ############ NRFF ############
    # nrff_gp, _ = build_train_nrff(exp_key, X_tr, y_tr, R=R, restarts=restarts, **params["nrff"])
    # nrff_preds, nrff_sd = nrff_gp.condition(y_tr, X_test)
    # nrff_metrics = metric_model(y_test, nrff_preds, nrff_sd, scaler=scaler)

    ############ MIX ############
    # The mixture's condition() also takes the PRNG key, unlike the other models.
    mix_gp, _ = build_train_mix_rff(exp_key, X_tr, y_tr, R=R, restarts=restarts, **params["mix_rff"])
    mix_preds, mix_sd = mix_gp.condition(exp_key, y_tr, X_test)
    mix_metrics = metric_model(y_test, mix_preds, mix_sd, scaler=scaler)

    # # ############ NMIX ############
    # nmix_rff_gp, _ = build_train_nsrf(exp_key, X_tr, y_tr, R=R, **params["nmix_rff"])
    # nmix_rff_preds, nmix_rff_sd = nmix_rff_gp.condition(exp_key, y_tr, X_test)
    # nmix_rff_metrics = metric_model(y_test, nmix_rff_preds, nmix_rff_sd, scaler=scaler)

    metrics = {
        "seed": exp_key[1],
        "svgp": svgp_metrics,
        "rff_rbf": rff_rbf_metrics,
        "rff": rff_metrics,
        "qrff": qrff_metrics,
        "srf": srf_metrics,
        # "nrff": nrff_metrics,
        "mix": mix_metrics,
        # "nmix": nmix_rff_metrics
    }

    return metrics

In [10]:
def multi_run(multi_key, X, y, scaler, params, R, n_runs=10, restarts=1):
    """Repeat ``experiment_run`` ``n_runs`` times with independent PRNG keys.

    Args:
        multi_key: PRNG key that is split into one key per run.
        X, y, scaler, params, R: forwarded unchanged to ``experiment_run``.
        n_runs: number of repetitions.
        restarts: forwarded to ``experiment_run``; the default of 1 matches
            ``experiment_run``'s own default, so existing calls are unaffected.

    Returns:
        list with the metrics dict of each run, in run order. Progress and each
        run's metrics are printed as they complete.
    """
    metrics = []
    run_keys = jax.random.split(multi_key, n_runs)

    for i in range(n_runs):
        print(f"Running experiment {i+1} of {n_runs}")
        run_res = experiment_run(run_keys[i], X, y, scaler, params, R, restarts=restarts)
        metrics.append(run_res)
        print(run_res)

    return metrics

### Run

In [11]:
# parameters
KEY, subkey = jax.random.split(jax.random.PRNGKey(2024))
params = {
    "svgp": {'epochs': 1000, 'lr':  0.03943006189843416, 'diag': 0.01, 'from_data': False},  # not done
    "rff_rbf": {'diag': 0.09557277712120715, 'lr': 0.3673756353254991, 'epochs': 1000, 'init_ls': False}, # done
    "rff": {'diag': 0.08941600525463213, 'lr': 0.013028342062540154, 'epochs': 1000, 'from_data': False, 'init_ls': False},  # done
    "srf": {
        'diag': 0.09605792010056687, 'lr': 0.01821058468763105, 'lr_gd': 0.013285678053308094, 'alpha': 2.4, 's': 1,  # not done
        'epochs': 1000, 'from_data': False, 'gd_params': lambda t: [t.kernel.transform.scale], 'init_ls': False
    },
    "mix_rff": {
        'diag': 0.04158413731186824, 'lr': 0.342489358101857, 'lr_gd': 0.15121439497743513, 'alpha': 0.4, 'q': 6, 's': 1,  # done
        'epochs': 1000, 'from_data': False, 'gd_params': lambda t: [t.kernel.transform.scale], 'init_ls': False,
        'prior': TFTarget(tfd.Normal(jnp.zeros(d), jnp.ones(d) * 3))
    },
}

In [12]:
# Run the full experiment: 10 repetitions on the whole scaled dataset, each one
# drawing its own 80/20 split inside experiment_run. R is given literally (100)
# rather than through the notebook-level `R`.
# experiment_run(subkey, X_val, y_val, y_scaler, params, R)
res = multi_run(KEY, X, y, y_scaler, params, R=100, n_runs=10)

Running experiment 1 of 10
{'seed': Array(3007921430, dtype=uint32), 'svgp': Array([ 4.06681940e+01,  4.93923481e+00,  9.66019452e-01, -3.70460195e-03],      dtype=float64), 'rff_rbf': Array([ 3.27686074e+01,  4.16276147e+00,  9.41747606e-01, -2.14048630e-03],      dtype=float64), 'rff': Array([1.74414914e+01, 3.00176125e+00, 9.56310689e-01, 5.36184370e-03],      dtype=float64), 'qrff': Array([1.71822282e+01, 2.85929268e+00, 9.66019452e-01, 7.47114638e-03],      dtype=float64), 'srf': Array([2.98831819e+01, 3.73473810e+00, 9.56310689e-01, 5.28508430e-03],      dtype=float64), 'mix': Array([1.47601628e+01, 2.81100518e+00, 9.70873773e-01, 1.71924147e-03],      dtype=float64)}
Running experiment 2 of 10
{'seed': Array(2158071644, dtype=uint32), 'svgp': Array([4.64786264e+01, 5.09141418e+00, 9.41747606e-01, 3.81034956e-03],      dtype=float64), 'rff_rbf': Array([ 3.52274302e+01,  4.49195611e+00,  9.32038844e-01, -4.87785340e-03],      dtype=float64), 'rff': Array([3.27948508e+01, 3.7379579

In [14]:
# Flatten the per-run metric dicts into one row per (run, model) -- the "seed"
# entry is skipped -- and persist the table for the evaluation section.
res_df = pd.DataFrame(
    [
        [model, *values.tolist()]
        for run in res
        for model, values in run.items()
        if model != "seed"
    ],
    columns=["model", "mse", "mae", "cal", "z"],
)
res_df.to_csv("results.csv", index=False)

## Evaluate

In [17]:
# Display names for the model keys stored in results.csv.
names = dict(
    rff="RFF",
    nrff="Nonstationary RFF",
    rff_rbf="RFF-RBF",
    svgp="Sparse VGP",
    srf="Stein RFF",
    qrff="RFF Equiv Mix Big-O",
    mix="Stein Mixture RFF",
    nmix="Nonstationary Stein Mixture RFF",
)

# Reload the saved results, swap in the display names (an unknown key raises
# KeyError) and add the derived error columns.
res_df = pd.read_csv("results.csv").assign(
    model=lambda df: df["model"].apply(lambda key: names[key]),
    rmse=lambda df: np.sqrt(df["mse"]),
    log_mse=lambda df: np.log(df["mse"]),
)

In [18]:
# Mean and standard deviation of every metric per model, with the two-level
# column index flattened to "<metric>_<stat>" and rows ordered by mean MSE.
sum_df = res_df.groupby("model").agg(["mean", "std"])
sum_df = (
    sum_df
    .set_axis(
        ['_'.join(col) if isinstance(col, tuple) else col for col in sum_df.columns],
        axis=1,
    )
    .sort_values("mse_mean")
    .reset_index()
)
sum_df

Unnamed: 0,model,mse_mean,mse_std,mae_mean,mae_std,cal_mean,cal_std,z_mean,z_std,rmse_mean,rmse_std,log_mse_mean,log_mse_std
0,Stein Mixture RFF,18.452997,4.112192,2.912778,0.194475,0.95,0.011902,0.002242,0.009804,4.272389,0.471039,2.893613,0.217378
1,RFF Equiv Mix Big-O,23.588992,5.772869,3.313453,0.307066,0.953398,0.015554,0.0031,0.010663,4.822807,0.605094,3.132085,0.257168
2,RFF,25.823744,5.810565,3.478778,0.310246,0.946117,0.021884,0.001174,0.010437,5.051535,0.582842,3.227043,0.235986
3,Stein RFF,30.832076,5.985655,3.863761,0.265421,0.951456,0.009435,0.00121,0.012956,5.529272,0.536689,3.411635,0.19419
4,RFF-RBF,37.344351,6.410289,4.392094,0.360595,0.930583,0.016667,0.000491,0.011349,6.091209,0.518029,3.607265,0.168555
5,Sparse VGP,43.761791,5.906898,4.990825,0.237856,0.95534,0.013106,-0.002742,0.014371,6.60175,0.445582,3.770571,0.13497


In [None]:
# One point per model: log-MSE across runs with a confidence-interval error bar.
# One "hls" colour is generated per distinct model name.
unique_models = res_df["model"].unique()
colors = sns.color_palette("hls", len(unique_models))

# NOTE(review): `color=".5"` is passed together with `hue` / `palette`; presumably
# the palette takes precedence -- confirm against the seaborn version in use.
pointplot = sns.pointplot(
    data=res_df, x="log_mse", y="model", hue="model",
    errorbar="ci", capsize=.4,
    palette=colors, legend=False,
    color=".5", linestyle="none", marker="D",
)
pointplot.set_title("Concrete UCI Dataset Error")