In [None]:
import os

# Configure the process environment BEFORE importing any numerical library, so
# that the settings are already visible when those libraries initialise their
# thread pools / devices.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # only expose GPU 0 to CUDA-based libraries
os.environ['OMP_NUM_THREADS'] = '4'  # cap the number of OpenMP threads
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']= 'platform'
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import numpy as np

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

import torch
from functools import partial
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
import blackjax
from flax.training import train_state
import optax
import jaxopt
# sns.set(font_scale=2, style='whitegrid')

from IPython.display import set_matplotlib_formats

# Request both PDF and PNG representations for inline figures.
set_matplotlib_formats("pdf", "png")

# Global matplotlib defaults for every figure in this notebook, applied in the
# same order as individual assignments would be.
plt.rcParams.update({
    "savefig.dpi": 150,
    "figure.autolayout": True,
    "figure.figsize": (6, 4),
    "axes.labelsize": 18,
    "axes.titlesize": 20,
    "font.size": 16,
    "lines.linewidth": 2.0,
    "lines.markersize": 8,
    "legend.fontsize": 14,
    "grid.linestyle": "-",
    "grid.linewidth": 1.0,
    "legend.facecolor": "white",
    # "grid.color": "grey",
    # Text is rendered through LaTeX, using the preamble given below.
    "text.usetex": True,
    # alternatives tried for font.family: "normal", "sans-serif"
    "font.family": "serif",
    "mathtext.fontset": "cm",
    "text.latex.preamble": "\\usepackage{subdepth} \\usepackage{amsfonts} \\usepackage{type1cm}",
})

In [None]:
def snelson(x_train, y_train, n_test=500, x_test_lim=6, standardize_x=False, standardize_y=False, scale_x=1, holdout=False):
    """Prepare a 1-D regression data set (sorted, optionally standardized).

    Args:
        x_train: 1-D array of inputs.
        y_train: 1-D array of targets, aligned with ``x_train``.
        n_test: number of equally spaced test inputs to generate.
        x_test_lim: test inputs span ``[-x_test_lim, x_test_lim]``.
        standardize_x: if True, centre/scale the inputs and multiply by ``scale_x``.
        standardize_y: if True, centre/scale the targets.
        scale_x: multiplier applied to the standardized inputs.
        holdout: if True, drop every training point with ``1.5 <= x <= 3``.

    Returns:
        ``(x_train, y_train, x_test)``, each with a trailing axis of size 1.
    """
    if holdout:
        # keep only the points outside the held-out gap
        keep = ((x_train < 1.5) | (x_train > 3)).flatten()
        x_train, y_train = x_train[keep], y_train[keep]

    # sort both arrays by increasing input value
    order = np.argsort(x_train)
    x_sorted, y_sorted = x_train[order], y_train[order]

    if standardize_x:
        x_sorted = (x_sorted - x_sorted.mean(0)) / x_sorted.std(0) * scale_x
    if standardize_y:
        y_sorted = (y_sorted - y_sorted.mean(0)) / y_sorted.std(0)

    test_grid = np.linspace(-x_test_lim, x_test_lim, n_test)
    return x_sorted[:, None], y_sorted[:, None], test_grid[:, None]

def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    """Generate a synthetic 1-D regression problem (polynomial plus sinusoid).

    The global NumPy random state is re-seeded with 0 on every call, so the
    output is deterministic for fixed arguments.

    Args:
        N: number of training inputs, equally spaced in [-1, 1].
        D_X: number of polynomial features (powers 0 .. D_X - 1) of the input.
        sigma_obs: standard deviation of the Gaussian noise added to the
            training targets.
        N_test: number of test inputs, equally spaced in [-1.3, 1.3].

    Returns:
        ``(X, Y, X_test, Y_test)``, each of shape ``(n, 1)``. Both target arrays
        are standardized with the mean / std of the noisy training targets; the
        test targets are noise free.
    """
    D_Y = 1  # create 1d outputs
    np.random.seed(0)
    exponents = np.arange(D_X)

    def noiseless_target(features, weights):
        # random polynomial in the features plus a fixed non-linear term in x
        x_lin = features[:, 1]
        return np.dot(features, weights) + 0.5 * np.power(0.5 + x_lin, 2.0) * np.sin(4.0 * x_lin)

    X = np.power(np.linspace(-1, 1, N)[:, np.newaxis], exponents)
    W = 0.5 * np.random.randn(D_X)
    Y = noiseless_target(X, W)
    Y += sigma_obs * np.random.randn(N)
    Y = Y[:, np.newaxis]
    mu, std = np.mean(Y), np.std(Y)
    Y -= mu
    Y /= std

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = np.power(np.linspace(-1.3, 1.3, N_test)[:, np.newaxis], exponents)
    Y_test = noiseless_target(X_test, W)[:, np.newaxis]
    Y_test -= mu
    Y_test /= std

    # only the linear feature (the raw input) is handed back to the caller
    return X[:, 1][:, None], Y, X_test[:, 1][:, None], Y_test

In [None]:
# Read the raw data from disk and copy its `x` / `y` columns into NumPy arrays.
_raw_data = pd.read_csv('snelson.csv')
x_train = np.array(_raw_data['x'].values)
y_train = np.array(_raw_data['y'].values)

# Sort and standardize the training data; build a test grid on [-6, 6].
x_train, y_train, x_test = snelson(
    x_train,
    y_train,
    standardize_x=True,
    standardize_y=True,
    scale_x=1,
    x_test_lim=6,
)
# Placeholder test targets: all zeros, same shape as x_test.
y_test = np.zeros_like(x_test)

# x_train, y_train, x_test, y_test = get_data(sigma_obs=0.1)
# # plot
# plt.scatter(x_train, y_train, label='trian', s=10)
# plt.scatter(x_test, y_test, label='test', s=10)
# plt.legend()

In [None]:
from flax import linen as nn
from typing import Callable

class MLP(nn.Module):
    """Fully connected network built from ``nn.Dense`` layers.

    The network always has at least one hidden layer of width ``H``; the
    activation ``act`` is applied between consecutive dense layers.
    """
    out_size: int  # width of the final dense layer
    H: int = 64  # width of every hidden dense layer
    hidden_layers: int = 1  # number of hidden layers (values below 1 behave like 1)
    act: Callable = nn.relu  # activation inserted between dense layers

    @nn.compact
    def __call__(self, x):
        # one width per dense layer, in the order the layers are created
        widths = [self.H] * max(self.hidden_layers, 1) + [self.out_size]
        stack = []
        for position, width in enumerate(widths):
            if position > 0:
                stack.append(self.act)
            stack.append(nn.Dense(width))
        network = nn.Sequential(stack)
        return network(x)

def reparam_initializer(initializer, f):
    """Wrap ``initializer`` so that its output is passed through ``f``.

    Args:
        initializer: callable ``(key, shape, dtype) -> value``.
        f: callable applied to the value produced by ``initializer``.

    Returns:
        A callable ``init(key, shape, dtype=jnp.float32)`` returning
        ``f(initializer(key, shape, dtype))``.
    """
    def init(key, shape, dtype=jnp.float32):
        # draw in the original parametrization first, then map it through f
        sampled = initializer(key, shape, dtype)
        return f(sampled)
    return init

class ReparamDense(nn.Module):
    """Dense layer whose stored parameters are passed through ``reparam``.

    The layer stores ``reparam_kernel`` / ``reparam_bias`` and computes
    ``inputs @ reparam(reparam_kernel) + reparam(reparam_bias)``. Initial values
    are drawn by a standard initializer and then mapped through ``reparam_inv``
    before being stored.

    Args:
        features: output dimension of the layer.
        reparam: map from stored parameters to effective weights, ``w = reparam(x)``.
        reparam_inv: map from effective weights to stored parameters,
            ``x = reparam_inv(w)``.
        init_scale: if None, the kernel uses ``lecun_normal`` and the bias a
            normal with stddev 1e-4; otherwise both use a normal with this stddev.
    """

    def __init__(self, features, reparam, reparam_inv, init_scale=None):
        super().__init__()
        self.features = features
        self.reparam = reparam
        self.reparam_inv = reparam_inv
        if init_scale is None:
            # near-zero (not exactly zero) bias, LeCun-normal kernel
            self.bias_init = reparam_initializer(nn.initializers.normal(stddev=1e-4), f=reparam_inv)
            self.kernel_init = reparam_initializer(nn.initializers.lecun_normal(), f=reparam_inv)
        else:
            self.bias_init = reparam_initializer(nn.initializers.normal(stddev=init_scale), f=reparam_inv)
            self.kernel_init = reparam_initializer(nn.initializers.normal(stddev=init_scale), f=reparam_inv)

    @nn.compact
    def __call__(self, inputs):
        reparam_kernel = self.param('reparam_kernel', self.kernel_init, (inputs.shape[-1], self.features))
        reparam_bias = self.param('reparam_bias', self.bias_init, (1, self.features))
        # map the stored parameters to the effective weights
        kernel = jax.tree_util.tree_map(self.reparam, reparam_kernel)
        bias = jax.tree_util.tree_map(self.reparam, reparam_bias)
        # clamp to avoid numerical issues; the bounds are passed positionally
        # because the `a_min` / `a_max` keywords of jnp.clip are deprecated
        kernel = jnp.clip(kernel, -1e6, 1e6)
        bias = jnp.clip(bias, -1e6, 1e6)
        return jnp.dot(inputs, kernel) + bias

class ReparamMLP(nn.Module):
    """Fully connected network built from ``ReparamDense`` layers.

    Every layer shares the same ``reparam`` / ``reparam_inv`` pair and
    ``init_scale``. The network always has at least one hidden layer.
    """
    out_size: int  # width of the final layer
    H: int = 64  # width of every hidden layer
    hidden_layers: int = 1  # number of hidden layers (values below 1 behave like 1)
    reparam: Callable = lambda x: x # w = reparam(x)
    reparam_inv: Callable = lambda x: x # x = reparam_inv(w)
    act: Callable = nn.tanh  # activation inserted between layers
    init_scale: float = None  # forwarded to every ReparamDense

    @nn.compact
    def __call__(self, x):
        def make_dense(width):
            # all layers use identical reparametrization settings
            return ReparamDense(
                features=width,
                reparam=self.reparam,
                reparam_inv=self.reparam_inv,
                init_scale=self.init_scale,
            )

        # one width per layer, in the order the layers are created
        widths = [self.H] * max(self.hidden_layers, 1) + [self.out_size]
        stack = []
        for position, width in enumerate(widths):
            if position > 0:
                stack.append(self.act)
            stack.append(make_dense(width))
        return nn.Sequential(stack)(x)

In [None]:
def jacobian_sigular_values(model, p, x):
    """Singular values of the Jacobian of the model outputs w.r.t. the parameters.

    Args:
        model: object exposing ``apply(params, inputs)``.
        p: parameter pytree.
        x: batch of inputs; the leading axis indexes the evaluation points.

    Returns:
        1-D array of singular values of the flattened Jacobian ``J``, whose rows
        correspond to the entries of ``x`` along its leading axis.
    """
    # Imported explicitly: a plain `import jax` does not guarantee that the
    # `jax.flatten_util` submodule is available as an attribute.
    from jax.flatten_util import ravel_pytree

    # Jacobian w.r.t. the parameters, evaluated one input at a time
    per_example_jac = jax.vmap(
        jax.jacrev(lambda params, xi: model.apply(params, xi)),
        in_axes=(None, 0),
    )
    j = per_example_jac(p, x)
    # move the batch axis to last so every leaf flattens as (..., N)
    j = jax.tree_util.tree_map(lambda leaf: jnp.einsum('b...->...b', leaf), j)
    # flatten all leaves into a single matrix
    J, _ = ravel_pytree(j)
    J = J.reshape(-1, x.shape[0]).T # (N, P)
    # singular values of J
    _, S, _ = jnp.linalg.svd(J, full_matrices=False)
    return S

def effdim(eigs, cutoff):
    """Effective dimension: sum of ``e / (e + cutoff)`` over strictly positive ``e``."""
    positive = eigs[eigs > 0]  # non-positive eigenvalues are discarded
    ratios = positive / (positive + cutoff)
    return jnp.sum(ratios)

def fspace_effdim(model, p, x, cutoff, jitter=0):
    """Effective dimension of the function-space Hessian at the points ``x``.

    Computed as ``sum(e / (e + cutoff))`` over the eigenvalues returned by
    ``fspace_hessian_eigenvalues``. ``jitter`` is accepted but not used.
    """
    spectrum = fspace_hessian_eigenvalues(model, p, x)
    shrinkage = spectrum / (spectrum + cutoff)
    return jnp.sum(shrinkage)

def fspace_hessian_eigenvalues(model, p, x):
    """Squared Jacobian singular values divided by the number of points in ``x``."""
    n_points = x.shape[0]
    singular_values = jacobian_sigular_values(model, p, x)
    return (singular_values ** 2) / n_points

def fspace_hessian(model, p, x):
    """Matrix ``K = J^T J / N`` built from the Jacobian of the model outputs.

    Args:
        model: object exposing ``apply(params, inputs)``.
        p: parameter pytree.
        x: batch of inputs; the leading axis (size ``N``) indexes the points.

    Returns:
        Square matrix ``K``, where ``J`` is the flattened Jacobian of the model
        outputs w.r.t. the parameters, with one row per point.
    """
    # Imported explicitly: a plain `import jax` does not guarantee that the
    # `jax.flatten_util` submodule is available as an attribute.
    from jax.flatten_util import ravel_pytree

    # Jacobian w.r.t. the parameters, evaluated one input at a time
    per_example_jac = jax.vmap(
        jax.jacrev(lambda params, xi: model.apply(params, xi)),
        in_axes=(None, 0),
    )
    j = per_example_jac(p, x)
    # move the batch axis to last so every leaf flattens as (..., N)
    j = jax.tree_util.tree_map(lambda leaf: jnp.einsum('b...->...b', leaf), j)
    # flatten all leaves into a single matrix
    J, _ = ravel_pytree(j)
    J = J.reshape(-1, x.shape[0]).T # (N, P)
    N = x.shape[0]
    K = J.T @ J / N # (P, P)
    return K

def log_det_K_svd(model, p, x, jitter=1e-6, scale=1.0):
    """Log-determinant of ``J^T J / N`` computed from the singular values of ``J``.

    Uses ``log det = sum(log s_i^2)`` with ``s_i`` the singular values of
    ``J / sqrt(N)``. Careful: check this for more general cases when ``J`` is
    not injective. ``jitter`` and ``scale`` are accepted but not used.
    """
    n_points = x.shape[0]
    scaled_singular_values = jacobian_sigular_values(model, p, x) / (n_points ** 0.5)
    return 2 * jnp.sum(jnp.log(scaled_singular_values))

def log_det_K(model, p, x, jitter=1e-6, scale=1.0):
    """Log-determinant of the jittered function-space Hessian.

    Args:
        model: object exposing ``apply(params, inputs)``.
        p: parameter pytree.
        x: batch of evaluation inputs.
        jitter: value added to the diagonal of ``K`` before the determinant.
        scale: accepted for signature compatibility; not used.

    Returns:
        ``log |det(K + jitter * I)|`` with ``K = fspace_hessian(model, p, x)``.
        The sign returned by ``slogdet`` is discarded.
    """
    K = fspace_hessian(model, p, x)
    # add jitter to the diagonal
    K = K + jitter * jnp.eye(K.shape[0])
    _sign, log_det = jnp.linalg.slogdet(K)
    return log_det


def log_det_diagonal_approx(model, p, x, jitter=1e-6):
    """Diagonal approximation of the log-determinant of ``J^T J / N``.

    Args:
        model: object exposing ``apply(params, inputs)``.
        p: parameter pytree.
        x: batch of inputs; the leading axis indexes the evaluation points.
        jitter: value added to every diagonal entry before taking the log.

    Returns:
        ``sum_i log(mean_n J[n, i]^2 + jitter)``, where ``J`` is the flattened
        Jacobian of the model outputs w.r.t. the parameters.
    """
    # Imported explicitly: a plain `import jax` does not guarantee that the
    # `jax.flatten_util` submodule is available as an attribute.
    from jax.flatten_util import ravel_pytree

    # Jacobian w.r.t. the parameters, evaluated one input at a time
    per_example_jac = jax.vmap(
        jax.jacrev(lambda params, xi: model.apply(params, xi)),
        in_axes=(None, 0),
    )
    j = per_example_jac(p, x)
    # move the batch axis to last so every leaf flattens as (..., N)
    j = jax.tree_util.tree_map(lambda leaf: jnp.einsum('b...->...b', leaf), j)
    # flatten all leaves into a single matrix
    J, _ = ravel_pytree(j)
    J = J.reshape(-1, x.shape[0]).T # (N, P)
    avg_j_sq = jnp.mean(J ** 2, axis=0) # (P,)
    logdet_diag = jnp.sum(jnp.log(avg_j_sq + jitter))
    return logdet_diag

In [None]:
def optimize(lr, weight_decay, n_step, rng_key, loss_fn, model, x_train, y_train, x_eval_generator, optimizer, fsmap, jitter, diag=False):
    """Train ``model`` on ``loss_fn`` plus weight decay and an optional F-MAP term.

    Args:
        lr: learning rate.
        weight_decay: coefficient of the squared parameter norm.
        n_step: number of optimisation steps.
        rng_key: JAX PRNG key; split for initialisation and for every step.
        loss_fn: callable ``params -> scalar`` data-fit loss.
        model: module exposing ``init`` and ``apply``.
        x_train: training inputs; only its shape is used here (input dimension
            for initialisation, ``N = x_train.shape[0]`` for normalisation).
        y_train: not used inside this function.
        x_eval_generator: callable ``key -> x_eval`` giving the evaluation points
            for the function-space term.
        optimizer: ``'adam'`` or ``'sgd'`` (SGD uses momentum 0.9).
        fsmap: if True, add ``1/2 * logdet / N`` to the loss.
        jitter: jitter forwarded to the log-determinant routine.
        diag: if True, use the diagonal log-determinant approximation.

    Returns:
        ``(params, losses)``: the final parameters and a NumPy array holding the
        loss at every step.

    Raises:
        NotImplementedError: if ``optimizer`` is neither ``'adam'`` nor ``'sgd'``.
    """
    # Imported explicitly: a plain `import jax` does not guarantee that the
    # `jax.flatten_util` submodule is available as an attribute.
    from jax.flatten_util import ravel_pytree

    rng_key, init_params_key = jax.random.split(rng_key)
    init_params = jax.jit(model.init)(init_params_key, jnp.ones((1, x_train.shape[1])))
    if optimizer == 'adam':
        tx = optax.adam(learning_rate=lr)
    elif optimizer == 'sgd':
        tx = optax.sgd(learning_rate=lr, momentum=0.9)
    else:
        raise NotImplementedError(f"Unsupported optimizer {optimizer!r}; expected 'adam' or 'sgd'.")
    ts = train_state.TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

    n_train = x_train.shape[0]

    def augmented_loss_fn(p, x_eval):
        # loss = likelihood / N = 1 / (2 * sigma^2) * ||y - f(x)||^2 / N
        # it contains a factor 1 / N, where N = x_train.shape[0]
        # all other terms are divided by N as well
        loss = loss_fn(p)
        if fsmap:
            log_det_fn = log_det_diagonal_approx if diag else log_det_K
            fs_loss = 1 / 2 * log_det_fn(model, p, x_eval, jitter) / n_train
        else:
            fs_loss = 0
        params_flat, _ = ravel_pytree(p)
        wd_loss = weight_decay * jnp.sum(params_flat ** 2) / n_train
        return loss + wd_loss + fs_loss

    grad_fn = jax.jit(jax.value_and_grad(augmented_loss_fn))

    @jax.jit
    def train_step(ts, rng_key):
        # fresh evaluation points are requested at every step
        rng_key, x_eval_key = jax.random.split(rng_key)
        x_eval = x_eval_generator(x_eval_key)
        loss, grads = grad_fn(ts.params, x_eval)
        ts = ts.apply_gradients(grads=grads)
        return ts, loss, rng_key

    losses = []
    for _ in tqdm(range(n_step)):
        ts, loss, rng_key = train_step(ts, rng_key)
        losses.append(loss.item())
    losses = np.array(losses)
    return ts.params, losses

def make_flat_function(f, unravel):
    """Adapt ``f(params, ...)`` so it accepts flattened parameters instead.

    Args:
        f: callable whose first argument is a structured parameter object.
        unravel: callable mapping flat parameters back to that structure.

    Returns:
        A callable ``f_flat(p_flat, *args, **kwargs)`` equal to
        ``f(unravel(p_flat), *args, **kwargs)``.
    """
    def f_flat(p_flat, *args, **kwargs):
        structured = unravel(p_flat)
        return f(structured, *args, **kwargs)
    return f_flat

def optimize_func_decay(lr, func_decay, n_step, rng_key, loss_fn, model, x_train, y_train, x_eval_generator, optimizer):
    """Train ``model`` on ``loss_fn`` plus a penalty on the squared function values.

    The penalty is ``func_decay * sum(f(x_eval)^2) / N`` with
    ``N = x_train.shape[0]`` and ``x_eval`` produced by ``x_eval_generator`` at
    every step. ``y_train`` is not used inside this function.

    Returns:
        ``(params, losses)``: the final parameters and a NumPy array holding the
        loss at every step.

    Raises:
        NotImplementedError: if ``optimizer`` is neither ``'adam'`` nor ``'sgd'``.
    """
    rng_key, init_params_key = jax.random.split(rng_key)
    dummy_input = jnp.ones((1, x_train.shape[1]))
    init_params = jax.jit(model.init)(init_params_key, dummy_input)

    if optimizer != 'adam' and optimizer != 'sgd':
        raise NotImplementedError
    tx = optax.adam(learning_rate=lr) if optimizer == 'adam' else optax.sgd(learning_rate=lr, momentum=0.9)
    ts = train_state.TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

    def augmented_loss_fn(p, x_eval):
        # loss = likelihood / N = 1 / (2 * sigma^2) * ||y - f(x)||^2 / N,
        # so it already carries a factor 1 / N with N = x_train.shape[0];
        # the function-space penalty is divided by N as well
        data_fit = loss_fn(p)
        fs_loss = func_decay * (model.apply(p, x_eval) ** 2).sum() / x_train.shape[0]
        return data_fit + fs_loss

    grad_fn = jax.jit(jax.value_and_grad(augmented_loss_fn))

    @jax.jit
    def train_step(ts, rng_key):
        # one optimiser update on freshly generated evaluation points
        rng_key, x_eval_key = jax.random.split(rng_key)
        loss, grads = grad_fn(ts.params, x_eval_generator(x_eval_key))
        return ts.apply_gradients(grads=grads), loss, rng_key

    losses = []
    for _ in tqdm(range(n_step)):
        ts, step_loss, rng_key = train_step(ts, rng_key)
        losses.append(step_loss.item())
    return ts.params, np.array(losses)


In [None]:
def run_experiment(arch, noise_scale, x_train, y_train, x_test, y_test, x_eval_generator, n_step, lr, weight_decay, optimizer, output_dir, seed, task='regression', fsmap=False, jitter=1e-6):
    """Train one model, plot its loss curve and predictions, and save the results.

    Args:
        arch: callable building the model, called as ``arch(out_size=1)``.
        noise_scale: observation noise scale; the regression training loss is
            divided by ``noise_scale ** 2``.
        x_train, y_train: training inputs and targets.
        x_test, y_test: test inputs and targets.
        x_eval_generator: callable ``key -> x_eval`` providing evaluation points.
        n_step: number of optimisation steps.
        lr: learning rate.
        weight_decay: weight-decay coefficient.
        optimizer: optimizer name forwarded to ``optimize``.
        output_dir: directory receiving the parameter and result files.
        seed: integer seed for the training PRNG key.
        task: ``'regression'`` or ``'classification'``.
        fsmap: forwarded to ``optimize``.
        jitter: forwarded to ``optimize``.

    Returns:
        dict with keys ``train_loss``, ``test_loss``, ``n_params``,
        ``weight_decay``, ``p_norm`` and, for classification, ``test_error``.

    Raises:
        NotImplementedError: if ``task`` is not supported.
    """
    # Imported explicitly: a plain `import jax` does not guarantee that the
    # `jax.flatten_util` submodule is available as an attribute.
    from jax.flatten_util import ravel_pytree

    # model
    model = arch(out_size=1)
    x_eval_sample = x_eval_generator(jax.random.PRNGKey(0))
    # count parameters
    init_params = model.init(jax.random.PRNGKey(0), jnp.ones((1, x_train.shape[1])))
    leaves, _ = jax.tree_util.tree_flatten(init_params)
    n_params = sum([np.prod(p.shape) for p in leaves])
    print(f"Number of parameters: {n_params}")


    if task == 'regression':
        train_loss_fn = lambda p: 0.5 * jnp.mean((model.apply(p, x_train) - y_train) ** 2) / noise_scale ** 2
        test_loss_fn = lambda p: jnp.mean((model.apply(p, x_test) - y_test) ** 2)
    elif task == 'classification':
        train_loss_fn = lambda p: jnp.mean(jax.nn.sigmoid(model.apply(p, x_train)) * (1 - y_train) + (1 - jax.nn.sigmoid(model.apply(p, x_train))) * y_train)
        test_loss_fn = lambda p: jnp.mean(jnp.round(jax.nn.sigmoid(model.apply(p, x_test))) != y_test)
    else:
        raise NotImplementedError

    rng_key = jax.random.PRNGKey(seed)
    params, losses = optimize(lr, weight_decay, n_step, rng_key, train_loss_fn, model, x_train, y_train, x_eval_generator, optimizer, fsmap, jitter)
    torch.save(params, f'{output_dir}/{task}_wd{weight_decay}_{optimizer}_{seed}_params.pt')
    print('Saved ps parameters at ', f'{output_dir}/{task}_wd{weight_decay}_{optimizer}_{seed}_params.pt')
    # plot the losses
    plt.figure()
    # clip the y-axis at the 95% quantile so early spikes do not dominate
    q1, q2 = np.quantile(losses, [0., 0.95])
    # the curve is labelled so that plt.legend() has an artist to show
    plt.plot(losses, label='Loss')
    plt.ylim(q1, q2)
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


    fig, ax = plt.subplots(figsize=(9, 6), dpi=120)
    prediction = model.apply(params, x_test)
    ax.plot(x_test[..., 0], prediction, c='green', label='Prediction')
    # plot X_eval on the x-axis
    ax.scatter(x_eval_sample[..., 0], np.zeros(x_eval_sample.shape[0]), c='black', label='Eval Points Sample', linestyle='None', s=3)
    ax.set(xlabel='$x$', ylabel='$y$', ylim=[-3, 3], xlim=[2 * x_train.min(), 2 * x_train.max()])
    ax.scatter(x_train[..., 0], y_train, c='r', label='Train Data', s=10)
    ax.plot(x_test[..., 0], y_test, c='b', label='Test Data')
    ax.grid(True)

    # show the weight decay on the plot
    ax.text(0.05, 0.95, f'wd: {weight_decay}', transform=ax.transAxes, fontsize=16, verticalalignment='top')
    ax.legend()
    fig.tight_layout()
    fig.show()

    train_loss = train_loss_fn(params)
    test_loss = test_loss_fn(params)
    result = {'train_loss': train_loss, 'test_loss': test_loss, 'n_params': n_params, 'weight_decay': weight_decay}
    if task == 'regression':
        result['test_loss'] = jnp.mean((model.apply(params, x_test) - y_test) ** 2)
    elif task == 'classification':
        # test_loss is the soft loss; the rounded 0/1 error is reported separately
        result['test_loss'] = jnp.mean(jax.nn.sigmoid(model.apply(params, x_test)) * (1 - y_test) + (1 - jax.nn.sigmoid(model.apply(params, x_test))) * y_test)
        result['test_error'] = jnp.mean(jnp.round(jax.nn.sigmoid(model.apply(params, x_test))) != y_test)
    # Hessian of training loss (flat-parameter helpers feed the disabled diagnostics below)
    params_flat, unravel = ravel_pytree(params)
    model_apply_flat = make_flat_function(model.apply, unravel)
    if task == 'regression':
        loss_fn_flat = lambda p: jnp.mean((model_apply_flat(p, x_train) - y_train) ** 2)
    elif task == 'classification':
        loss_fn_flat = lambda p: jnp.mean(jax.nn.sigmoid(model_apply_flat(p, x_train)) * (1 - y_train) + (1 - jax.nn.sigmoid(model_apply_flat(p, x_train))) * y_train)
    # h_train_loss = jax.hessian(loss_fn_flat)(params_flat)
    # h_train_loss_eigenvals = jnp.linalg.eigvalsh(h_train_loss)
    # print('Fraction of eigenvalues smaller than 0: ', np.sum(h_train_loss_eigenvals < 0) / h_train_loss_eigenvals.shape[0])
    # print('Fraction of eigenvalues smaller than 0.01: ', np.sum(h_train_loss_eigenvals < 0.01) / h_train_loss_eigenvals.shape[0])
    # print('Fraction of eigenvalues smaller than 0.1: ', np.sum(h_train_loss_eigenvals < 0.1) / h_train_loss_eigenvals.shape[0])
    # result['h_train_loss_eigenvals'] = h_train_loss_eigenvals
    # # Hessian function space
    # h_fspace_eigenvals = fspace_hessian_eigenvalues(model, params, x_eval_sample)
    # result['h_fspace_eigenvals'] = h_fspace_eigenvals
    result['p_norm'] = jnp.linalg.norm(params_flat)
    torch.save(result, f'{output_dir}/{task}_wd{weight_decay}_{optimizer}_{seed}_result.pt')
    return result


In [None]:
def _plot_loss_curve(losses):
    """Plot a loss curve with the y-axis clipped at the 95% quantile of the losses."""
    plt.figure()
    q1, q2 = np.quantile(losses, [0., 0.95])
    # the curve is labelled so that plt.legend() has an artist to show
    plt.plot(losses, label='Loss')
    plt.ylim(q1, q2)
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def _load_or_train(param_path, load, train_fn):
    """Return cached parameters from ``param_path`` or train, save and plot.

    ``train_fn`` is a zero-argument callable returning ``(params, losses)``.
    """
    if load and os.path.exists(param_path):
        print('Loading parameters from ', param_path)
        # NOTE(review): torch.load unpickles the file; only load files written
        # by this notebook.
        return torch.load(param_path)
    params, losses = train_fn()
    torch.save(params, param_path)
    print('Saved parameters at ', param_path)
    _plot_loss_curve(losses)
    return params


def run_compare(arch, noise_scale, x_train, y_train, x_test, y_test, x_eval_generator, n_step, lr, weight_decay, func_decay, optimizer, output_dir, seed, task='regression', jitter=1e-6, load=True):
    """Train (or load) four variants of the same model and plot their predictions.

    The variants are: weight decay only (P-MAP), weight decay plus the exact
    log-determinant term, weight decay plus the diagonal log-determinant
    approximation, and the function-value penalty (``optimize_func_decay``).
    All four start from the same PRNG key.

    Args:
        arch: callable building the model, called as ``arch(out_size=1)``.
        noise_scale: observation noise scale; the regression training loss is
            divided by ``noise_scale ** 2``.
        x_train, y_train: training inputs and targets.
        x_test: test inputs on which the predictions are plotted.
        y_test: accepted for signature compatibility; not used.
        x_eval_generator: callable ``key -> x_eval`` providing evaluation points.
        n_step: number of optimisation steps; either a single value used for
            every variant or a sequence ``[pmap, fmap, fdecay]`` (the diagonal
            variant uses the ``fmap`` entry).
        lr: learning rate.
        weight_decay: weight-decay coefficient.
        func_decay: coefficient of the function-value penalty.
        optimizer: optimizer name forwarded to the training routines.
        output_dir: directory for the cached parameters and the figure.
        seed: integer seed for the training PRNG key.
        task: ``'regression'`` or ``'classification'``.
        jitter: jitter forwarded to the log-determinant routines.
        load: if True, reuse parameter files that already exist on disk.

    Returns:
        None. The outputs are the saved parameter files and the saved figure.

    Raises:
        NotImplementedError: if ``task`` is not supported.
    """
    if not isinstance(n_step, (list, tuple)):
        n_step = [n_step] * 3
    # model
    model = arch(out_size=1)
    x_eval_sample = x_eval_generator(jax.random.PRNGKey(0))
    # count parameters
    init_params = model.init(jax.random.PRNGKey(0), jnp.ones((1, x_train.shape[1])))
    leaves, _ = jax.tree_util.tree_flatten(init_params)
    n_params = sum([np.prod(p.shape) for p in leaves])
    print(f"Number of parameters: {n_params}")

    if task == 'regression':
        train_loss_fn = lambda p: 0.5 * jnp.mean((model.apply(p, x_train) - y_train) ** 2) / noise_scale ** 2
    elif task == 'classification':
        train_loss_fn = lambda p: jnp.mean(jax.nn.sigmoid(model.apply(p, x_train)) * (1 - y_train) + (1 - jax.nn.sigmoid(model.apply(p, x_train))) * y_train)
    else:
        raise NotImplementedError

    rng_key = jax.random.PRNGKey(seed)
    wd_prefix = f'{output_dir}/{task}_wd{weight_decay}_{optimizer}_{seed}'

    def train_with_weight_decay(steps, fsmap, diag=False):
        return optimize(lr, weight_decay, steps, rng_key, train_loss_fn, model, x_train, y_train, x_eval_generator, optimizer, fsmap, jitter, diag=diag)

    # pmap
    pmap_params = _load_or_train(
        f'{wd_prefix}_pmap_params.pt', load,
        lambda: train_with_weight_decay(n_step[0], False),
    )

    # fmap
    fmap_params = _load_or_train(
        f'{wd_prefix}_fmap_params.pt', load,
        lambda: train_with_weight_decay(n_step[1], True),
    )

    # fmap diagonal approx
    diag_fmap_params = _load_or_train(
        f'{wd_prefix}_diag_fmap_params.pt', load,
        lambda: train_with_weight_decay(n_step[1], True, diag=True),
    )

    # fmap white noise (this file name depends on func_decay, not on weight_decay)
    fdecay_params = _load_or_train(
        f'{output_dir}/{task}_fd{func_decay}_{optimizer}_{seed}_fdecay_params.pt', load,
        lambda: optimize_func_decay(lr, func_decay, n_step[2], rng_key, train_loss_fn, model, x_train, y_train, x_eval_generator, optimizer),
    )

    fig, ax = plt.subplots(figsize=(9, 6), dpi=120)
    ax.plot(x_test[..., 0], model.apply(fmap_params, x_test), c='purple', label=r'Exact F-MAP')
    ax.plot(x_test[..., 0], model.apply(diag_fmap_params, x_test), c='Magenta', label=r'Diag F-MAP', linestyle='--')
    ax.plot(x_test[..., 0], model.apply(fdecay_params, x_test), c='green', label=r'Linearized F-MAP')
    ax.plot(x_test[..., 0], model.apply(pmap_params, x_test), c='blue', label='P-MAP', linestyle='--')
    # plot X_eval on the x-axis
    ax.scatter(x_eval_sample[..., 0], np.zeros(x_eval_sample.shape[0]), c='black', label='Eval Points Sample', linestyle='None', s=3)
    # ax.set(xlabel='$x$', ylabel='$y$', ylim=[-3, 3], xlim=[2 * x_train.min(), 2 * x_train.max()])
    ax.set(xlabel='$x$', ylabel='$y$', ylim=[-3, 3], xlim=[-3, 3])
    ax.scatter(x_train[..., 0], y_train, c='r', label='Train Data', s=10)
    ax.grid(True)

    # show both regularisation strengths on the plot, each with 2 significant digits
    ax.text(0.05, 0.95, rf'$\lambda = {weight_decay:.2g}, \tau = {func_decay:.2g}$', transform=ax.transAxes, fontsize=16, verticalalignment='top')
    ax.legend()
    fig.tight_layout()
    fig.show()
    fig.savefig(os.path.join(output_dir, f'noise{noise_scale}_wd{weight_decay}_fd{func_decay}.png'))


In [None]:
# Activation functions selectable by name in the experiment cells.
activations = dict(tanh=nn.tanh, elu=nn.elu)


### Loss definitions ###

$\mathcal{L}_\mathrm{P-MAP}(\sigma, \lambda) = (2\sigma^2)^{-1} ||Y - f(X)||^2 + \lambda ||w||^2$

$\mathcal{L}_\mathrm{Exact.F-MAP}(\sigma, \lambda) = (2\sigma^2)^{-1} ||Y - f(X)||^2 + \lambda ||w||^2 + 1/2 \log\det(E_{x' \in X'}[\nabla f(x') \nabla f(x')^T])$

$\mathcal{L}_\mathrm{Diagonal.F-MAP}(\sigma, \lambda) = (2\sigma^2)^{-1} ||Y - f(X)||^2 + \lambda ||w||^2 + 1/2 \sum_{i=1}^{p}\log E_{x' \in X'}[(\partial f(x') / \partial w_i)^2]$

$\mathcal{L}_\mathrm{Linearized.F-MAP}(\sigma, \tau) = (2\sigma^2)^{-1} ||Y - f(X)||^2 + \tau ||f(X')||^2$


$f$ has 1153 parameters (2 hidden layers, 32 units each). $X'$ consists of 10,000 samples in $[-3, 3]$, either drawn uniformly at random or linearly spaced; i.e., the function is defined to have domain $[-3, 3]$.

### Random uniform evaluation points ###

In [None]:
task = 'regression'
# identity reparametrization: stored parameters equal the effective weights
reparam = lambda x: x
reparam_inv = lambda w: w

widths = [32]
depths = [2]
acts = ['tanh']

# Optimizer hyperparameters
optimizer = 'adam'
lr = 1e-2 #3e-3
n_step = [int(1e4), int(1e4), int(1e4)]
weight_decays = [0.1, 1, 10, 100]
func_decays =  [0.001, 0.01, 0.1]
noise_scale = 0.1

n_eval = 1e4
def x_eval_generator(rng_key):
    """Draw ``n_eval`` evaluation points uniformly at random from [-3, 3]."""
    X_eval = jax.random.uniform(rng_key, (int(n_eval), 1), minval=-3, maxval=3)
    return X_eval

load = True
results = []
for width in widths:
    for depth in depths:
        for act in acts:
            output_dir = f'comparison_rand_uniform_3/{width}_{depth}_{act}_noise{noise_scale}_lr{lr}_{optimizer}'
            # create the output directory if needed (no check-then-create race)
            os.makedirs(output_dir, exist_ok=True)
            for weight_decay in weight_decays:
                # additionally try func_decay equal to the current weight_decay
                for func_decay in func_decays + [weight_decay]:
                    for seed in range(0, 1):
                        arch = partial(ReparamMLP, reparam=reparam, reparam_inv=reparam_inv, H=width, hidden_layers=depth, act=activations[act])
                        results.append(run_compare(arch, noise_scale, x_train, y_train, x_test, y_test, x_eval_generator, n_step, lr, weight_decay, func_decay, optimizer, output_dir, seed, task=task, load=load))

### Linspace evaluation points ###

In [None]:
task = 'regression'
# identity reparametrization: stored parameters equal the effective weights
reparam = lambda x: x
reparam_inv = lambda w: w

widths = [32]
depths = [2]
acts = ['tanh']

# Optimizer hyperparameters
optimizer = 'adam'
lr = 1e-2 #3e-3
n_step = [int(1e4), int(1e4), int(1e4)]
weight_decays = [0.1, 1, 10, 100]
func_decays =  [0.001, 0.01, 0.1]
noise_scale = 0.1

n_eval = 1e4
def x_eval_generator(rng_key):
    """Return ``n_eval`` equally spaced evaluation points in [-3, 3]; ``rng_key`` is ignored."""
    X_eval = jnp.linspace(-3, 3, int(n_eval)).reshape(-1, 1)
    return X_eval

load = True
results = []
for width in widths:
    for depth in depths:
        for act in acts:
            output_dir = f'comparison_linspace_3/{width}_{depth}_{act}_noise{noise_scale}_lr{lr}_{optimizer}'
            # create the output directory if needed (no check-then-create race)
            os.makedirs(output_dir, exist_ok=True)
            for weight_decay in weight_decays:
                # additionally try func_decay equal to the current weight_decay
                for func_decay in func_decays + [weight_decay]:
                    for seed in range(0, 1):
                        arch = partial(ReparamMLP, reparam=reparam, reparam_inv=reparam_inv, H=width, hidden_layers=depth, act=activations[act])
                        results.append(run_compare(arch, noise_scale, x_train, y_train, x_test, y_test, x_eval_generator, n_step, lr, weight_decay, func_decay, optimizer, output_dir, seed, task=task, load=load))

### Changing Adam to SGD leaves the results qualitatively unchanged ###

In [None]:
task = 'regression'
# identity reparametrization: stored parameters equal the effective weights
reparam = lambda x: x
reparam_inv = lambda w: w

widths = [32]
depths = [2]
acts = ['tanh']

# SGD hyperparameters
optimizer = 'sgd'
lr = 1e-3
n_step = [int(1e6), int(1e6), int(1e6)]
weight_decays = [0.1]
func_decays =  [0.001, 0.01, 0.1]
noise_scale = 0.1

n_eval = 1e4
def x_eval_generator(rng_key):
    """Draw ``n_eval`` evaluation points uniformly at random from [-3, 3]."""
    X_eval = jax.random.uniform(rng_key, (int(n_eval), 1), minval=-3, maxval=3)
    return X_eval

load = True
results = []
for width in widths:
    for depth in depths:
        for act in acts:
            output_dir = f'comparison_rand_uniform_3_sgd_1e-3_1e6/{width}_{depth}_{act}_noise{noise_scale}_lr{lr}_{optimizer}'
            # create the output directory if needed (no check-then-create race)
            os.makedirs(output_dir, exist_ok=True)
            for weight_decay in weight_decays:
                # additionally try func_decay equal to the current weight_decay
                for func_decay in func_decays + [weight_decay]:
                    for seed in range(0, 1):
                        arch = partial(ReparamMLP, reparam=reparam, reparam_inv=reparam_inv, H=width, hidden_layers=depth, act=activations[act])
                        results.append(run_compare(arch, noise_scale, x_train, y_train, x_test, y_test, x_eval_generator, n_step, lr, weight_decay, func_decay, optimizer, output_dir, seed, task=task, load=load))