# Experiment: performance of our MCMC approach

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib agg


# NOTE(review): global_settings is kept first in case it configures the
# environment before jax is imported -- confirm before regrouping.
import global_settings
from utils import experiments, results
import jax
import jax.numpy as jnp
import jax.scipy.special
import jax.scipy.stats
from multiprocessing import Pool
from tqdm import tqdm
import json
import os



In [2]:
# Collect the experiment configuration first, then hand it to the experiment.
sample_settings = experiments.settings.SettingsExperimentSample(
    seed=0,
    output_path=global_settings.PATH_PAPER_RESULTS,
    overwrite_chains=None,
    # data: first configured dataset, standardized
    dataset=global_settings.DATASET_NAMES[0],
    dataset_normalization="standardization",
    # network: one hidden layer with three tanh units, no output activation
    hidden_layers=1,
    hidden_neurons=3,
    activation="tanh",
    activation_last_layer="none",
    # remaining options (semantics are defined by SettingsExperimentSample)
    num_warmup=2**10,
    statistic="reduced",
    statistic_p=0.99,
    samples_per_chain=1,
    identifiable_modes=3,
    pool_size=1,  # has no effect in jupyter notebook, use script instead
)

experiment = experiments.ExperimentSampleStandard(settings=sample_settings)

## run experiment

In [3]:
# Run the experiment. The recorded output below reports 1274 chains and a
# wall time of roughly 21 minutes, so expect this cell to be slow.
experiment.run()

model transformation parameters 10
number of chains: 1274


100%|███████████████████████████████████████████████████████████████████████████████████████████| 1274/1274 [21:20<00:00,  1.01s/it]


## save experiment

In [10]:
# Fetch the result object of the finished experiment. The evaluation cell
# further down reads `result.indices_test`, so this cell must be executed
# before it (otherwise that cell fails with a NameError).
result = experiment.result()

## evaluation

In [4]:
def computed_lppd_mcmc(inputs, outputs, parameters_network, parameters_data_std, regression_model):
    """Pointwise log predictive density (lppd) for MCMC samples.

    For every (input, output) pair the network is evaluated once per
    parameter sample, the log-likelihood of the observed output is computed
    for each sample, and the log of the sample-averaged likelihood is kept.

    The average is taken in log space (log-sum-exp). Exponentiating first
    underflows to 0 for poorly explained points; such points used to be
    dropped silently, which overestimated the lppd.

    Args:
        inputs: iterable of test inputs, one entry per data point.
        outputs: iterable of test targets, aligned with ``inputs``.
        parameters_network: stacked network parameter vectors; axis 0 indexes
            the samples (it is the axis mapped over by ``jax.vmap``).
        parameters_data_std: standard deviation(s) passed to the likelihood
            (presumably one per sample -- TODO confirm against caller).
        regression_model: object providing
            ``_transformation.apply_from_vector`` and ``_outputs_likelihood``.

    Returns:
        ``jnp`` array with one log predictive density per retained data
        point. Points whose averaged likelihood is exactly zero
        (log value ``-inf``) or NaN are still skipped, as before.
    """
    # build the vectorised forward pass once instead of once per data point
    predict = jax.vmap(regression_model._transformation.apply_from_vector, in_axes=(None, 0))

    log_prob_means = []
    for xi, yi in zip(inputs, outputs):
        yi_preds = predict(xi, parameters_network)
        log_probs = regression_model._outputs_likelihood(
            yi_preds,
            parameters_data_std).log_prob(yi)
        # log(mean(exp(log_probs))) evaluated without leaving log space
        log_prob_mean = jax.scipy.special.logsumexp(log_probs, axis=0) - jnp.log(log_probs.shape[0])
        # NOTE(review): the comparison needs a single-element value, i.e. it
        # assumes one dependent variable -- same restriction as before.
        if log_prob_mean > -jnp.inf:
            log_prob_means.append(log_prob_mean)
    log_prob_means = jnp.array(log_prob_means)
    return log_prob_means

def computed_lppd_la(inputs, outputs, parameters_network, parameters_data_std, regression_model):
    """Pointwise log predictive density (lppd) for LA samples.

    ("LA" presumably stands for Laplace approximation -- TODO confirm.)

    For every (input, output) pair the network is evaluated once per
    parameter sample, the log-likelihood of the observed output is computed
    for each sample, and the log of the sample-averaged likelihood is kept.

    The average is taken in log space (log-sum-exp). Exponentiating first
    underflows to 0 for poorly explained points; such points used to be
    dropped silently, which overestimated the lppd.

    Args:
        inputs: iterable of test inputs, one entry per data point.
        outputs: iterable of test targets, aligned with ``inputs``.
        parameters_network: stacked network parameter vectors; axis 0 indexes
            the samples (it is the axis mapped over by ``jax.vmap``).
        parameters_data_std: standard deviation(s) passed to the likelihood.
        regression_model: object providing
            ``_transformation.apply_from_vector`` and ``_outputs_likelihood``.

    Returns:
        ``jnp`` array with one log predictive density per retained data
        point. Points whose averaged likelihood is exactly zero
        (log value ``-inf``) or NaN are still skipped, as before.
    """
    # build the vectorised forward pass once instead of once per data point
    predict = jax.vmap(regression_model._transformation.apply_from_vector, in_axes=(None, 0))

    log_prob_means = []
    for xi, yi in zip(inputs, outputs):
        yi_preds = predict(xi, parameters_network)
        log_probs = regression_model._outputs_likelihood(
            yi_preds,
            parameters_data_std).log_prob(yi)
        # log(mean(exp(log_probs))) evaluated without leaving log space
        log_prob_mean = jax.scipy.special.logsumexp(log_probs, axis=0) - jnp.log(log_probs.shape[0])
        # NOTE(review): the comparison needs a single-element value, i.e. it
        # assumes one dependent variable -- same restriction as before.
        if log_prob_mean > -jnp.inf:
            log_prob_means.append(log_prob_mean)
    log_prob_means = jnp.array(log_prob_means)
    return log_prob_means

def computed_lppd_de(inputs, outputs, parameters_network, parameters_data_std, regression_model):
    """Pointwise log predictive density (lppd) for DE members.

    ("DE" presumably stands for deep ensemble -- TODO confirm.)

    For every (input, output) pair the member predictions are collapsed into
    a single normal distribution by moment matching: its mean is the average
    prediction and its variance is ``mean(std**2 + pred**2) - mean**2``. The
    log density of the observed output under that normal is returned.

    The log density is evaluated directly. The earlier ``log(exp(log_prob))``
    round trip underflowed to ``-inf`` for poorly explained points, and the
    ``distributions`` module it relied on is not imported in this notebook.

    Args:
        inputs: iterable of test inputs, one entry per data point.
        outputs: iterable of test targets, aligned with ``inputs``.
        parameters_network: stacked network parameter vectors; axis 0 indexes
            the members (it is the axis mapped over by ``jax.vmap``).
        parameters_data_std: data standard deviation(s); must broadcast
            against the stacked predictions.
        regression_model: object providing
            ``_transformation.apply_from_vector``.

    Returns:
        ``jnp`` array with one log predictive density per data point.
    """
    # build the vectorised forward pass once instead of once per data point
    predict = jax.vmap(regression_model._transformation.apply_from_vector, in_axes=(None, 0))

    log_prob_means = []
    for xi, yi in zip(inputs, outputs):
        yi_preds = predict(xi, parameters_network)
        # moment-matched normal over the member predictions
        mean = yi_preds.mean(0)
        variance = (jnp.power(parameters_data_std, 2) + jnp.power(yi_preds, 2)).mean(0) - jnp.power(mean, 2)
        std = jnp.sqrt(variance)
        log_prob_means.append(jax.scipy.stats.norm.logpdf(yi, loc=mean, scale=std))
    log_prob_means = jnp.array(log_prob_means)
    return log_prob_means

def load_dataset(name):
    """Load a dataset by name together with its stored train/test split.

    The split indices are read from a JSON file below
    ``global_settings.PATH_DATASETS``; the file depends on whether ``name``
    is listed in ``DATASET_NAMES_TOY`` or ``DATASET_NAMES_BENCHMARK``. The
    stored "validate" indices are used as test data and no validation data
    is kept.

    Args:
        name: dataset name.

    Returns:
        The constructed dataset object, or ``None`` when ``name`` is in
        neither list or is a toy name without a matching dataset class.

    NOTE(review): ``datasets`` is not imported in the visible import cell;
    it has to be brought into scope before this function can run.
    """
    # pick the index file that belongs to the dataset family
    if name in global_settings.DATASET_NAMES_TOY:
        indices_path = os.path.join(global_settings.PATH_DATASETS, "toy_dataset_indices_0.2.json")
    elif name in global_settings.DATASET_NAMES_BENCHMARK:
        print(name)
        indices_path = os.path.join(global_settings.PATH_DATASETS, "benchmark_data", "dataset_indices_0.2.json")
    else:
        return None

    with open(indices_path, 'r') as f:
        indices = json.load(f)
    split = {
        "data_train": indices[name]["train"],
        "data_validate": [],
        "data_test": indices[name]["validate"]  # validate as test data, since we do not need validation data...
    }

    # load dataset
    if name == "izmailov":
        return datasets.Izmailov(split=split)
    if name == "sinusoidal":
        return datasets.Sinusoidal(split=split)
    if name == "regression2d":
        return datasets.Regression2d(split=split)
    if name in global_settings.DATASET_NAMES_BENCHMARK:
        return datasets.GenericBenchmark(dataset_name=name, split=split)
    return None

In [None]:
# TODO: provide evaluation methods for both sources: a live experiment object and results loaded from file

In [7]:
# retrieve relevant results
dataset = experiment._dataset
samples = experiment._samples
model = experiment._model

# `result` comes from the "save experiment" cell (result = experiment.result()).
# Check it before the lppd loop so that a missing `result` fails immediately
# instead of after the computation.
if len(result.indices_test) == 0:
    print("WARNING: trained on entire dataset?")

# compute lppd on the test split
inputs = dataset.data_test[:, dataset.conditional_indices]
outputs = dataset.data_test[:, dataset.dependent_indices]
log_prob_means = computed_lppd_mcmc(
    inputs=inputs,
    outputs=outputs,
    parameters_network=samples["parameters"],
    parameters_data_std=samples["std"],
    regression_model=model
)

# computed_lppd_mcmc may skip test points, so the standard error has to be
# based on the number of values it actually returned
num_evaluated = log_prob_means.shape[0]
if num_evaluated < inputs.shape[0]:
    print("WARNING: {} of {} test points were skipped in the lppd".format(
        inputs.shape[0] - num_evaluated, inputs.shape[0]))

log_probs_mean = jnp.mean(log_prob_means, axis=0)
log_probs_std = jnp.std(log_prob_means, axis=0)
log_probs_std_error = log_probs_std / jnp.sqrt(num_evaluated)

mcmc_str = "{:.2f}, {:.2f}".format(log_probs_mean.item(), log_probs_std_error.item())
print("mcmc (ours)", experiment._settings.dataset, mcmc_str)

NameError: name 'result' is not defined