# Experiment: Predictive Performance (Test lppd) of MCMC

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib agg


import global_settings
from utils import experiments, results, evaluation
import jax
from multiprocessing import Pool
from tqdm import tqdm
import jax.numpy as jnp
import os



In [2]:
experiment = experiments.ExperimentSampleStandard(
    settings=experiments.settings.SettingsExperimentSample(
        output_path=global_settings.PATH_RESULTS,
        dataset=global_settings.DATASET_NAMES[1],
        dataset_normalization="standardization",
        hidden_layers=1,
        hidden_neurons=3,
        activation="tanh",
        activation_last_layer="none",
        num_warmup=2**10,
        statistic="reduced",
        statistic_p=0.99,
        samples_per_chain=1,
        identifiable_modes=3,
        pool_size=1, # has no effect in jupyter notebook, use script instead
        seed=0,
        overwrite_chains=None
    )
)

# for single chain approach, set for example:
# * statistic_p=0.0
# * identifiable_modes=1
# * samples_per_chain=1274

## Run Experiment

In [3]:
# Run the experiment configured above; progress is shown as a tqdm bar in the cell output.
experiment.run()

model transformation parameters 10
number of chains: 1


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:19<00:00, 19.84s/it]


## Save Experiment

## Load Experiment

## Evaluation

In [5]:
# Evaluate the posterior samples on the held-out test split: report the mean
# per-point log predictive density (lppd) and its standard error.
dataset = experiment._dataset
samples = experiment.result.samples
model = experiment._model

# Warn up front (before the evaluation runs) so the message is not buried
# after the result: with no test indices the number below is not a
# held-out estimate.
if len(experiment.result.indices_test) == 0:
    print("WARNING: trained on entire dataset?")

# compute lppd on the test split
inputs = dataset.data_test[:, dataset.conditional_indices]
outputs = dataset.data_test[:, dataset.dependent_indices]
log_prob_means = evaluation.computed_lppd_mcmc(
    inputs=inputs,
    outputs=outputs,
    parameters_network=samples["parameters"],
    parameters_data_std=samples["std"],
    regression_model=model
)

# Number of values averaged over: taken from the same axis the mean is
# computed over, so the standard-error denominator always matches it.
num_points = log_prob_means.shape[0]
log_probs_mean = jnp.mean(log_prob_means, axis=0)
# Standard error of the mean uses the sample standard deviation (Bessel's
# correction, ddof=1); jnp.std defaults to ddof=0 (population std).
log_probs_std = jnp.std(log_prob_means, axis=0, ddof=1)
log_probs_std_error = log_probs_std / jnp.sqrt(num_points)

mcmc_str = "lppd: {:.2f}, std_error: {:.2f}".format(log_probs_mean.item(), log_probs_std_error.item())
print("mcmc", experiment._settings.dataset, mcmc_str)

mcmc izmailov lppd: 0.65, std_error: 0.07
