In [1]:
import numpy as np
import torch

import plotly.express as px
import plotly.graph_objects as go

from experiment import Experiment
from backpack import backpack, extend

In [2]:
def plot_data(X_train, y_train, X_test, y_test, title, line_title, scatter_title):
    fig = go.Figure([
        go.Scatter(
            name='Training data',
            x=X_train,
            y=y_train,
            mode='markers'   
        ),
        go.Scatter(
            name='Original function',
            x=X_test,
            y=y_test,
            mode='lines',
            line=dict(color='rgb(31, 119, 180)'),
    )])

    fig.update_layout(
        yaxis_title='y',
        xaxis_title='x',
        title=title,
        hovermode="x"
    )
    fig.update_layout(
    autosize=False,
    width=1000,
    height=1000,)

    fig.show()

In [3]:
def plot_regression( X_train, y_train, X_test, f_test, y_std, title="LA"):
    fig = go.Figure([
    go.Scatter(
        name='Training data',
        x=X_train,
        y=y_train,
        mode='markers'   
    ),
    go.Scatter(
        name='MAP',
        x=X_test,
        y=f_test,
        mode='lines',
        line=dict(color='rgb(31, 119, 180)'),
    ),
    go.Scatter(
        name='Upper Bound',
        x=X_test,
        y=f_test + y_std,
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        showlegend=False
    ),
    go.Scatter(
        name='Lower Bound',
        x=X_test,
        y=f_test - y_std,
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        fillcolor='rgba(68, 68, 68, 0.3)',
        fill='tonexty',
        showlegend=False
    )
    ])
    fig.update_layout(
        yaxis_title='y',
        xaxis_title='x',
        title=title,
        hovermode="x"
    )
    fig.update_layout(
    autosize=False,
    width=1000,
    height=1000)

    
    fig.show()

    



In [4]:
experiment_btwn = Experiment(dataset="sinusoidal", split='between', deterministic=True,)

In [5]:
X_train, y_train  = experiment_btwn.train_dataloader.dataset.get_arrays()
X_test, y_test  = experiment_btwn.test_dataloader.dataset.get_arrays()

In [6]:
plot_data(X_train, y_train, X_test, y_test, title="Synthetic data", line_title="Original function", scatter_title="Training data")

In [7]:
model, loss = experiment_btwn.experiment_mlp(epochs=50, lr=1e-2, check_point="mlp_btwn.pt")

Epoch 0 loss: 0.17042691599239002
Epoch 10 loss: 0.04317905144257979
Epoch 20 loss: 0.027038794180208988
Epoch 30 loss: 0.016333842480724507
Epoch 40 loss: 0.01001601471480998
Test loss: 0.5842509454426666


In [8]:
def experiment_la(la, X_tr, y_tr, X_te, title):
    print(X_te.shape)
    X = X_te.reshape(X_te.shape[0])
    X_te = torch.from_numpy(X_te.astype(np.float32))
    f_mu, f_var = la(x=X_te)
    f_mu = f_mu.squeeze().detach().cpu().numpy()
    f_sigma = f_var.squeeze().sqrt().cpu().numpy()
    pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2)
    plot_regression(X_tr, y_tr, X, f_mu, pred_std, title=title)

In [15]:
X_train, y_train = experiment_btwn.train_dataloader.dataset.get_arrays()
X = np.linspace(-2, 2, 1000).reshape(1000, 1)

In [16]:
model = experiment_btwn.load_model("mlp_btwn.pt")
la = experiment_btwn.train_la(model=model, dataloader=experiment_btwn.train_dataloader, subset_of_weights="all", hessian_structure="full", sigma_noise=0.075)
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation (all/full)")


(1000, 1)


In [17]:
la = experiment_btwn.train_la(model=model, dataloader=experiment_btwn.train_dataloader, subset_of_weights="all", hessian_structure="kron", sigma_noise=0.075)
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation (all/kron)")


(1000, 1)


In [19]:
la = experiment_btwn.train_la(model=model, dataloader=experiment_btwn.train_dataloader, subset_of_weights="all", hessian_structure="diag", sigma_noise=0.075)
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation (all/diag)")


(1000, 1)


In [20]:
la = experiment_btwn.train_la(model=model, dataloader=experiment_btwn.train_dataloader, subset_of_weights="all", hessian_structure="lowrank", sigma_noise=0.075)
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation (all/lowrank)")


(1000, 1)


In [21]:
la = experiment_btwn.train_la(model=model, dataloader=experiment_btwn.train_dataloader, subset_of_weights="last_layer", hessian_structure="kron", sigma_noise=0.075)
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation (ll/kron)")


(1000, 1)


In [23]:
la = experiment_btwn.train_la(model=model, dataloader=experiment_btwn.train_dataloader, subset_of_weights="last_layer", hessian_structure="diag", sigma_noise=0.075)
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation (ll/diag)")


(1000, 1)


In [24]:
la = experiment_btwn.train_la(model=model, dataloader=experiment_btwn.train_dataloader, subset_of_weights="last_layer", hessian_structure="full", sigma_noise=0.075)
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation (ll/full)")


(1000, 1)


In [25]:
la = experiment_btwn.train_la_marglik(dataloader=experiment_btwn.train_dataloader, hessian_structure="full")
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation - marglik(all/full)")


(1000, 1)


In [27]:
la = experiment_btwn.train_la_marglik(dataloader=experiment_btwn.train_dataloader, hessian_structure="diag")
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation - (marglik/diag)")


(1000, 1)


In [28]:
la = experiment_btwn.train_la_marglik(dataloader=experiment_btwn.train_dataloader, hessian_structure="kron")
experiment_la(la, X_tr=X_train, y_tr=y_train, X_te=X_test, title="Bayesian regression with laplace approximation - (marglik/kron)")


(1000, 1)
