In [1]:
!pip install dm-haiku numpyro

Collecting dm-haiku
  Downloading dm_haiku-0.0.11-py3-none-any.whl (370 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m371.0/371.0 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpyro
  Downloading numpyro-0.13.2-py3-none-any.whl (312 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m312.7/312.7 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, numpyro, dm-haiku
Successfully installed dm-haiku-0.0.11 jmp-0.0.4 numpyro-0.13.2


In [8]:
from typing import Sequence, NamedTuple
import copy, time, json

import numpy as np
import plotly.express as px

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, HMC, NUTS

import jax
import jax.numpy as jnp
import jax.tree_util as jtree
import haiku as hk
import optax

In [3]:
class SGLDConfig(NamedTuple):
  epsilon: float
  gamma: float
  num_steps: int

def generate_rngkey_tree(key_or_seed, tree_or_treedef):
    rngseq = hk.PRNGSequence(key_or_seed)
    return jtree.tree_map(lambda _: next(rngseq), tree_or_treedef)

def optim_sgld(epsilon, rngkey_or_seed, preconditioning=None):
    @jax.jit
    def sgld_delta(g, rngkey, p=1):
        eta = jax.random.normal(rngkey, shape=g.shape) * jnp.sqrt(epsilon * p)
        return -epsilon * p * g / 2 + eta

    def init_fn(_):
        return rngkey_or_seed

    if preconditioning is None:
        update_map = lambda grads, rngkey_tree: jax.tree_map(sgld_delta, grads, rngkey_tree)
    else:
        update_map = lambda grads, rngkey_tree: jax.tree_map(sgld_delta, grads, rngkey_tree, preconditioning)

    @jax.jit
    def update_fn(grads, state):
        rngkey, new_rngkey = jax.random.split(state)
        rngkey_tree = generate_rngkey_tree(rngkey, grads)
        updates = update_map(grads, rngkey_tree)
        return updates, new_rngkey
    return optax.GradientTransformation(init_fn, update_fn)


def create_local_logposterior(avgnegloglikelihood_fn, num_training_data, w_init, gamma, itemp):
    def helper(x, y):
        return jnp.sum((x - y)**2)

    def _logprior_fn(w):
        sqnorm = jax.tree_util.tree_map(helper, w, w_init)
        return jax.tree_util.tree_reduce(lambda a,b: a + b, sqnorm)

    def logprob(w, x, y):
        loglike = -num_training_data * avgnegloglikelihood_fn(w, x, y)
        logprior = -gamma / 2 * _logprior_fn(w)
        return itemp * loglike + logprior
    return logprob


In [4]:
def mala_acceptance_probability(current_point, proposed_point, loss_and_grad_fn, step_size):
    """
    Calculate the acceptance probability for a MALA transition.

    Args:
    current_point: The current point in parameter space.
    proposed_point: The proposed point in parameter space.
    loss_and_grad_fn (function): Function to compute loss and loss gradient at a point.
    step_size (float): Step size parameter for MALA.

    Returns:
    float: Acceptance probability for the proposed transition.
    """
    # Compute the gradient of the loss at the current point
    current_loss, current_grad = loss_and_grad_fn(current_point)
    proposed_loss, proposed_grad = loss_and_grad_fn(proposed_point)

    # Compute the log of the proposal probabilities (using the Gaussian proposal distribution)
    log_q_proposed_to_current = -jnp.sum((current_point - proposed_point - (step_size * 0.5 * -proposed_grad)) ** 2) / (2 * step_size)
    log_q_current_to_proposed = -jnp.sum((proposed_point - current_point - (step_size * 0.5 * -current_grad)) ** 2) / (2 * step_size)

    # Compute the acceptance probability
    acceptance_log_prob = log_q_proposed_to_current - log_q_current_to_proposed + current_loss - proposed_loss
    return jnp.minimum(1.0, jnp.exp(acceptance_log_prob))

# Example usage:
# Define a simple log likelihood function for demonstration
def log_likelihood_example(x):
    return -0.5 * jnp.sum(x**2)

loss_and_grad_fn = jax.jit(jax.value_and_grad(lambda x: -log_likelihood_example(x), argnums=0))

# Current and proposed points for demonstration
current_point = jnp.array([1.0, 2.0])
proposed_point = jnp.array([1.5, 2.5])
step_size = 0.1

# Calculate the acceptance probability
acceptance_prob = mala_acceptance_probability(current_point, proposed_point, loss_and_grad_fn, step_size)
acceptance_prob


Array(0.95719343, dtype=float32)

In [5]:
def pack_params(params):
    params_flat, treedef = jax.tree_util.tree_flatten(params)
    shapes = [p.shape for p in params_flat]
    indices = np.cumsum([p.size for p in params_flat])
    params_packed = jnp.concatenate([jnp.ravel(p) for p in params_flat])
    pack_info = (treedef, shapes, indices)
    return params_packed, pack_info

def unpack_params(params_packed, pack_info):
    treedef, shapes, indices = pack_info
    params_split = jnp.split(params_packed, indices)
    params_flat = [jnp.reshape(p, shape) for p, shape in zip(params_split, shapes)]
    params = jax.tree_util.tree_unflatten(treedef, params_flat)
    return params

# MLPs

In [10]:
class ReluNetwork(hk.Module):
    def __init__(self, layer_widths):
        super().__init__()
        self.layer_widths = layer_widths

    def __call__(self, x):
        for width in self.layer_widths[:-1]:
            x = hk.Linear(width)(x)
            x = jax.nn.relu(x)
        x = hk.Linear(self.layer_widths[-1])(x)
        return x

# Function to initialize and apply the DLN model
def forward_fn(x, layer_widths):
    net = ReluNetwork(layer_widths)
    return net(x)

# Create a Haiku-transformed version of the model
def create_model(layer_widths):
    model = hk.without_apply_rng(hk.transform(lambda x: forward_fn(x, layer_widths)))
    return model

def generate_training_data(true_param, model, input_dim, num_samples):
    inputs = np.random.uniform(-10, 10, size=(num_samples, input_dim))

    # Apply the true model to generate outputs
    true_outputs = model.apply(true_param, inputs)

    return inputs, true_outputs

def mse_loss(param, model, inputs, targets):
    predictions = model.apply(param, inputs)
    return jnp.mean((predictions - targets) ** 2)

def create_minibatches(inputs, targets, batch_size, num_iter=None, shuffle=True):
    assert len(inputs) == len(targets)

    if num_iter is None:
        num_iter = len(inputs)

    if shuffle:
        indices = np.random.permutation(len(inputs))
    else:
        indices = np.arange(len(inputs))

    iter = 0
    while True:
        for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
            excerpt = indices[start_idx:start_idx + batch_size]
            yield inputs[excerpt], targets[excerpt]

            iter += 1
            if iter >= num_iter:
                return



In [12]:
input_dim = 10  # Dimension of the input
layer_widths = [20, 20, 10]

# Create the model
model = create_model(layer_widths)

# Initialize the "true" model parameters
rngkey = jax.random.PRNGKey(0)
dummy_input = jnp.zeros((1, input_dim), dtype=jnp.float32)
rngkey, subkey = jax.random.split(rngkey)
true_param = model.init(rngkey, dummy_input)

# Set biases to not be zero
for i in range(len(layer_widths)-1):
    rngkey, subkey = jax.random.split(rngkey)
    loc = f'relu_network/linear_{i}' if i > 0 else 'relu_network/linear'
    true_param[loc]['b'] = jax.random.normal(subkey, (layer_widths[i],))


In [23]:
step_size = 1e-5
gamma = 1.0
batch_size = 32
num_samples = 10000
num_runs = 5

# Sweep over different training data size
dataset_sizes = np.round(10**np.linspace(2, 5, 7)).astype(np.int64)
sgld_times = []
mala_times = []
memories = []
sgld_lambdas = []
mala_lambdas = []
for num_training_data in dataset_sizes:
    sgld_times.append([])
    mala_times.append([])
    memories.append([])
    sgld_lambdas.append([])
    mala_lambdas.append([])

    for run in range(num_runs):
        print(f"Training set size: {num_training_data}, run {run}")

        # Generate training data
        np.random.seed(0)
        x_train, y_train = generate_training_data(true_param, model, input_dim, num_training_data)
        print(f"x_train shape: {x_train.shape}, y_train shape: {y_train.shape}")
        jtree.tree_map_with_path(lambda path, x: x.shape, true_param)

        memory = (x_train.size * x_train.itemsize) + (y_train.size * y_train.itemsize)
        memories[-1].append(memory)
        print("Training data memory: ", memory)

        itemp = 1 / np.log(num_training_data)
        param_init = copy.deepcopy(true_param)

        loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
        local_logprob = create_local_logposterior(
                avgnegloglikelihood_fn=loss_fn,
                num_training_data=num_training_data,
                w_init=param_init,
                gamma=gamma,
                itemp=itemp,
            )

        # Run SGLD
        sgld_start_time = time.time()

        sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

        rngkey = jax.random.PRNGKey(run)
        sgldoptim = optim_sgld(step_size, rngkey)
        samples = []
        losses = []
        accept_probs = []
        opt_state = sgldoptim.init(param_init)
        param = param_init
        t = 0
        for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, num_iter=num_samples):
            old_param = param.copy()

            nll, grads = sgld_grad_fn(param, x_batch, y_batch)
            updates, opt_state = sgldoptim.update(grads, opt_state)
            param = optax.apply_updates(param, updates)
            samples.append(param)
            losses.append(loss_fn(param, x_batch, y_batch))

            t += 1

            if t % 20 == 0:
                old_param_packed, pack_info = pack_params(old_param)
                param_packed, _ = pack_params(param)
                def grad_fn_packed(w):
                    nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                    grad_packed, _ = pack_params(grad)
                    return nll, grad_packed
                accept_probs.append(mala_acceptance_probability(
                    old_param_packed, param_packed, grad_fn_packed, step_size))

        init_loss = loss_fn(param_init, x_train, y_train)
        lambdahat = ((np.mean(losses) - init_loss) * num_training_data * itemp).item()

        sgld_time = time.time() - sgld_start_time

        sgld_lambdas[-1].append(lambdahat)
        sgld_times[-1].append(sgld_time)

        print("SGLD lambda: ", lambdahat)
        print("SGLD accept prob: ", np.mean(accept_probs))
        print("SGLD execution time: ", sgld_time)

        # Run MALA
        mala_start_time = time.time()

        rngkey = jax.random.PRNGKey(run)

        param_init = copy.deepcopy(true_param)
        beta = itemp*num_training_data
        mala_step_size = np.sqrt(step_size)
        num_steps = 1

        loss_fn = jax.jit(lambda w: mse_loss(w, model, x_train, y_train))
        potential_fn = jax.jit(lambda w: -local_logprob(w, x_train, y_train))

        hmc_kernel = HMC(
            #potential_fn=lambda param: beta * loss_fn(param),
            potential_fn=potential_fn,
            step_size=mala_step_size,
            trajectory_length=mala_step_size*num_steps,
            adapt_step_size=False,
            adapt_mass_matrix=False
        )
        mcmc = MCMC(hmc_kernel, num_samples=10000, num_warmup=0)
        rngkey, subkey = jax.random.split(rngkey)
        mcmc.run(subkey, init_params=param_init)

        # Extract samples
        samples = mcmc.get_samples()

        init_loss = loss_fn(param_init)
        losses = (jax.lax.map(loss_fn, samples))
        lambdahat = ((np.mean(losses) - init_loss) * beta).item()

        mala_time = time.time() - mala_start_time

        mala_lambdas[-1].append(lambdahat)
        mala_times[-1].append(mala_time)

        print("MALA lambda: ", lambdahat)
        print("MALA execution time: ", mala_time)

# Save results to JSON
results = {
    "dataset_sizes": dataset_sizes.tolist(),
    "SGLD": {
        "lambdas": sgld_lambdas,
        "times": sgld_times
    },
    "MALA": {
        "lambdas": mala_lambdas,
        "times": mala_times
    },
}
json.dump(results, open("mala_v_sgld_results.json", "w"))


Training set size: 100, run 0
x_train shape: (100, 10), y_train shape: (100, 10)
Training data memory:  12000
SGLD lambda:  57.37425994873047
SGLD accept prob:  0.99507684
SGLD execution time:  33.13397693634033


sample: 100%|██████████| 10000/10000 [05:20<00:00, 31.23it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  50.44807434082031
MALA execution time:  322.0246331691742
Training set size: 100, run 1
x_train shape: (100, 10), y_train shape: (100, 10)
Training data memory:  12000
SGLD lambda:  57.60322952270508
SGLD accept prob:  0.9949707
SGLD execution time:  31.66882586479187


sample: 100%|██████████| 10000/10000 [05:16<00:00, 31.60it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  56.770263671875
MALA execution time:  317.98200130462646
Training set size: 100, run 2
x_train shape: (100, 10), y_train shape: (100, 10)
Training data memory:  12000
SGLD lambda:  55.59379577636719
SGLD accept prob:  0.99578637
SGLD execution time:  32.193074226379395


sample: 100%|██████████| 10000/10000 [05:20<00:00, 31.18it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  53.56113815307617
MALA execution time:  321.9629626274109
Training set size: 100, run 3
x_train shape: (100, 10), y_train shape: (100, 10)
Training data memory:  12000
SGLD lambda:  52.4981689453125
SGLD accept prob:  0.9961521
SGLD execution time:  33.36083197593689


sample: 100%|██████████| 10000/10000 [05:20<00:00, 31.21it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  58.9940185546875
MALA execution time:  321.6508927345276
Training set size: 100, run 4
x_train shape: (100, 10), y_train shape: (100, 10)
Training data memory:  12000
SGLD lambda:  51.65175247192383
SGLD accept prob:  0.99534637
SGLD execution time:  31.509000062942505


sample: 100%|██████████| 10000/10000 [05:21<00:00, 31.07it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  51.5749397277832
MALA execution time:  323.8936080932617
Training set size: 316, run 0
x_train shape: (316, 10), y_train shape: (316, 10)
Training data memory:  37920
SGLD lambda:  90.11341857910156
SGLD accept prob:  0.9918604
SGLD execution time:  32.32085371017456


sample: 100%|██████████| 10000/10000 [05:26<00:00, 30.61it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  82.34846496582031
MALA execution time:  328.3618640899658
Training set size: 316, run 1
x_train shape: (316, 10), y_train shape: (316, 10)
Training data memory:  37920
SGLD lambda:  88.00234985351562
SGLD accept prob:  0.992213
SGLD execution time:  31.303765296936035


sample: 100%|██████████| 10000/10000 [03:33<00:00, 46.86it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  88.04817199707031
MALA execution time:  214.82731652259827
Training set size: 316, run 2
x_train shape: (316, 10), y_train shape: (316, 10)
Training data memory:  37920
SGLD lambda:  87.2174301147461
SGLD accept prob:  0.99178475
SGLD execution time:  33.534610748291016


sample: 100%|██████████| 10000/10000 [01:33<00:00, 107.26it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  86.43769073486328
MALA execution time:  95.11934041976929
Training set size: 316, run 3
x_train shape: (316, 10), y_train shape: (316, 10)
Training data memory:  37920
SGLD lambda:  84.71837615966797
SGLD accept prob:  0.99199104
SGLD execution time:  31.8625328540802


sample: 100%|██████████| 10000/10000 [03:31<00:00, 47.18it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  91.58927917480469
MALA execution time:  213.74442982673645
Training set size: 316, run 4
x_train shape: (316, 10), y_train shape: (316, 10)
Training data memory:  37920
SGLD lambda:  83.54853057861328
SGLD accept prob:  0.9935224
SGLD execution time:  32.6684365272522


sample: 100%|██████████| 10000/10000 [03:32<00:00, 47.11it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  83.38666534423828
MALA execution time:  213.8702564239502
Training set size: 1000, run 0
x_train shape: (1000, 10), y_train shape: (1000, 10)
Training data memory:  120000
SGLD lambda:  140.54531860351562
SGLD accept prob:  0.9863273
SGLD execution time:  32.47716474533081


sample: 100%|██████████| 10000/10000 [03:33<00:00, 46.95it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  139.86952209472656
MALA execution time:  215.41528844833374
Training set size: 1000, run 1
x_train shape: (1000, 10), y_train shape: (1000, 10)
Training data memory:  120000
SGLD lambda:  135.913818359375
SGLD accept prob:  0.98634744
SGLD execution time:  31.702524423599243


sample: 100%|██████████| 10000/10000 [03:31<00:00, 47.21it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  141.77679443359375
MALA execution time:  214.84855699539185
Training set size: 1000, run 2
x_train shape: (1000, 10), y_train shape: (1000, 10)
Training data memory:  120000
SGLD lambda:  138.0520782470703
SGLD accept prob:  0.98419017
SGLD execution time:  31.822795629501343


sample: 100%|██████████| 10000/10000 [03:32<00:00, 47.13it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  141.62376403808594
MALA execution time:  214.1477770805359
Training set size: 1000, run 3
x_train shape: (1000, 10), y_train shape: (1000, 10)
Training data memory:  120000
SGLD lambda:  133.38619995117188
SGLD accept prob:  0.98663634
SGLD execution time:  32.38638710975647


sample: 100%|██████████| 10000/10000 [03:32<00:00, 47.13it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  141.63909912109375
MALA execution time:  215.22088623046875
Training set size: 1000, run 4
x_train shape: (1000, 10), y_train shape: (1000, 10)
Training data memory:  120000
SGLD lambda:  136.6833038330078
SGLD accept prob:  0.9862479
SGLD execution time:  32.265281200408936


sample: 100%|██████████| 10000/10000 [03:31<00:00, 47.36it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  135.15985107421875
MALA execution time:  213.16329979896545
Training set size: 3162, run 0
x_train shape: (3162, 10), y_train shape: (3162, 10)
Training data memory:  379440
SGLD lambda:  216.7156219482422
SGLD accept prob:  0.9768807
SGLD execution time:  32.4572389125824


sample: 100%|██████████| 10000/10000 [03:44<00:00, 44.48it/s, 1 steps of size 3.16e-03. acc. prob=0.99]


MALA lambda:  219.56593322753906
MALA execution time:  230.1003017425537
Training set size: 3162, run 1
x_train shape: (3162, 10), y_train shape: (3162, 10)
Training data memory:  379440
SGLD lambda:  210.0811767578125
SGLD accept prob:  0.9786686
SGLD execution time:  33.4709038734436


sample: 100%|██████████| 10000/10000 [03:43<00:00, 44.69it/s, 1 steps of size 3.16e-03. acc. prob=0.99]


MALA lambda:  215.56591796875
MALA execution time:  228.94415950775146
Training set size: 3162, run 2
x_train shape: (3162, 10), y_train shape: (3162, 10)
Training data memory:  379440
SGLD lambda:  211.4884490966797
SGLD accept prob:  0.9793658
SGLD execution time:  32.61261057853699


sample: 100%|██████████| 10000/10000 [03:44<00:00, 44.54it/s, 1 steps of size 3.16e-03. acc. prob=0.99]


MALA lambda:  217.02134704589844
MALA execution time:  230.35093355178833
Training set size: 3162, run 3
x_train shape: (3162, 10), y_train shape: (3162, 10)
Training data memory:  379440
SGLD lambda:  207.4788055419922
SGLD accept prob:  0.9762826
SGLD execution time:  32.9333279132843


sample: 100%|██████████| 10000/10000 [03:41<00:00, 45.07it/s, 1 steps of size 3.16e-03. acc. prob=0.99]


MALA lambda:  217.24588012695312
MALA execution time:  227.01857805252075
Training set size: 3162, run 4
x_train shape: (3162, 10), y_train shape: (3162, 10)
Training data memory:  379440
SGLD lambda:  210.2639923095703
SGLD accept prob:  0.98327565
SGLD execution time:  32.28909945487976


sample: 100%|██████████| 10000/10000 [03:42<00:00, 44.92it/s, 1 steps of size 3.16e-03. acc. prob=0.99]


MALA lambda:  207.61856079101562
MALA execution time:  227.46654200553894
Training set size: 10000, run 0
x_train shape: (10000, 10), y_train shape: (10000, 10)
Training data memory:  1200000
SGLD lambda:  296.7646179199219
SGLD accept prob:  0.9770875
SGLD execution time:  31.998055458068848


sample: 100%|██████████| 10000/10000 [04:19<00:00, 38.53it/s, 1 steps of size 3.16e-03. acc. prob=0.96]


MALA lambda:  300.5058288574219
MALA execution time:  273.6828966140747
Training set size: 10000, run 1
x_train shape: (10000, 10), y_train shape: (10000, 10)
Training data memory:  1200000
SGLD lambda:  295.031005859375
SGLD accept prob:  0.98025537
SGLD execution time:  32.08233976364136


sample: 100%|██████████| 10000/10000 [04:18<00:00, 38.62it/s, 1 steps of size 3.16e-03. acc. prob=0.95]


MALA lambda:  293.9537658691406
MALA execution time:  272.72760701179504
Training set size: 10000, run 2
x_train shape: (10000, 10), y_train shape: (10000, 10)
Training data memory:  1200000
SGLD lambda:  297.7831115722656
SGLD accept prob:  0.98122716
SGLD execution time:  33.12505316734314


sample: 100%|██████████| 10000/10000 [04:17<00:00, 38.91it/s, 1 steps of size 3.16e-03. acc. prob=0.96]


MALA lambda:  302.910400390625
MALA execution time:  269.8833963871002
Training set size: 10000, run 3
x_train shape: (10000, 10), y_train shape: (10000, 10)
Training data memory:  1200000
SGLD lambda:  286.37640380859375
SGLD accept prob:  0.97771543
SGLD execution time:  33.54761815071106


sample: 100%|██████████| 10000/10000 [04:20<00:00, 38.40it/s, 1 steps of size 3.16e-03. acc. prob=0.96]


MALA lambda:  308.1272277832031
MALA execution time:  273.35159492492676
Training set size: 10000, run 4
x_train shape: (10000, 10), y_train shape: (10000, 10)
Training data memory:  1200000
SGLD lambda:  298.7560119628906
SGLD accept prob:  0.9783385
SGLD execution time:  32.09901571273804


sample: 100%|██████████| 10000/10000 [04:19<00:00, 38.61it/s, 1 steps of size 3.16e-03. acc. prob=0.94]


MALA lambda:  275.41015625
MALA execution time:  271.9844448566437
Training set size: 31623, run 0
x_train shape: (31623, 10), y_train shape: (31623, 10)
Training data memory:  3794760
SGLD lambda:  344.8536071777344
SGLD accept prob:  0.98994964
SGLD execution time:  33.68674302101135


sample: 100%|██████████| 10000/10000 [05:46<00:00, 28.90it/s, 1 steps of size 3.16e-03. acc. prob=0.84]


MALA lambda:  355.89202880859375
MALA execution time:  382.2019579410553
Training set size: 31623, run 1
x_train shape: (31623, 10), y_train shape: (31623, 10)
Training data memory:  3794760
SGLD lambda:  350.6094055175781
SGLD accept prob:  0.9908456
SGLD execution time:  32.47212839126587


sample: 100%|██████████| 10000/10000 [05:48<00:00, 28.68it/s, 1 steps of size 3.16e-03. acc. prob=0.79]


MALA lambda:  358.1321105957031
MALA execution time:  382.63348960876465
Training set size: 31623, run 2
x_train shape: (31623, 10), y_train shape: (31623, 10)
Training data memory:  3794760
SGLD lambda:  355.47833251953125
SGLD accept prob:  0.99085
SGLD execution time:  32.135979890823364


sample: 100%|██████████| 10000/10000 [05:47<00:00, 28.81it/s, 1 steps of size 3.16e-03. acc. prob=0.80]


MALA lambda:  369.1020812988281
MALA execution time:  381.5964562892914
Training set size: 31623, run 3
x_train shape: (31623, 10), y_train shape: (31623, 10)
Training data memory:  3794760
SGLD lambda:  342.5532531738281
SGLD accept prob:  0.99009657
SGLD execution time:  31.803086280822754


sample: 100%|██████████| 10000/10000 [05:46<00:00, 28.89it/s, 1 steps of size 3.16e-03. acc. prob=0.82]


MALA lambda:  369.27911376953125
MALA execution time:  381.2685797214508
Training set size: 31623, run 4
x_train shape: (31623, 10), y_train shape: (31623, 10)
Training data memory:  3794760
SGLD lambda:  353.2110900878906
SGLD accept prob:  0.98845124
SGLD execution time:  32.564656019210815


sample: 100%|██████████| 10000/10000 [05:48<00:00, 28.69it/s, 1 steps of size 3.16e-03. acc. prob=0.73]


MALA lambda:  343.0865783691406
MALA execution time:  384.88666677474976
Training set size: 100000, run 0
x_train shape: (100000, 10), y_train shape: (100000, 10)
Training data memory:  12000000
SGLD lambda:  357.1268005371094
SGLD accept prob:  0.98993737
SGLD execution time:  33.648396015167236


sample: 100%|██████████| 10000/10000 [15:44<00:00, 10.59it/s, 1 steps of size 3.16e-03. acc. prob=0.33]


MALA lambda:  352.4219055175781
MALA execution time:  1054.4247555732727
Training set size: 100000, run 1
x_train shape: (100000, 10), y_train shape: (100000, 10)
Training data memory:  12000000
SGLD lambda:  359.12408447265625
SGLD accept prob:  0.9944546
SGLD execution time:  32.02788949012756


sample: 100%|██████████| 10000/10000 [15:41<00:00, 10.62it/s, 1 steps of size 3.16e-03. acc. prob=0.25]


MALA lambda:  345.0095520019531
MALA execution time:  1052.1668529510498
Training set size: 100000, run 2
x_train shape: (100000, 10), y_train shape: (100000, 10)
Training data memory:  12000000
SGLD lambda:  362.73468017578125
SGLD accept prob:  0.99624765
SGLD execution time:  32.43241786956787


sample: 100%|██████████| 10000/10000 [15:39<00:00, 10.64it/s, 1 steps of size 3.16e-03. acc. prob=0.26]


MALA lambda:  347.5590515136719
MALA execution time:  1049.0516211986542
Training set size: 100000, run 3
x_train shape: (100000, 10), y_train shape: (100000, 10)
Training data memory:  12000000
SGLD lambda:  350.6711120605469
SGLD accept prob:  0.9890411
SGLD execution time:  33.73149871826172


sample: 100%|██████████| 10000/10000 [16:15<00:00, 10.25it/s, 1 steps of size 3.16e-03. acc. prob=0.26]


MALA lambda:  355.29193115234375
MALA execution time:  1088.828677892685
Training set size: 100000, run 4
x_train shape: (100000, 10), y_train shape: (100000, 10)
Training data memory:  12000000
SGLD lambda:  349.2572021484375
SGLD accept prob:  0.9918579
SGLD execution time:  33.226011514663696


sample: 100%|██████████| 10000/10000 [15:46<00:00, 10.57it/s, 1 steps of size 3.16e-03. acc. prob=0.20]


MALA lambda:  345.3625183105469
MALA execution time:  1058.6680824756622
