In [1]:
import json
from functools import partial
from itertools import chain, repeat
from pathlib import Path

import numpy as np
from toolz import pipe
from tqdm.auto import trange

import pyrlmala.envs
from pyrlmala.utils import Toolbox

In [None]:
model_name = "banana"
stan_code_path = f"{model_name}.stan"
stan_data_path = f"{model_name}.json"

# Only the JSON parsing needs the file handle; close it before building the
# target functions. JSON is UTF-8 by spec, so don't depend on the locale encoding.
with open(stan_data_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# Target log-density and its gradient, built by pyrlmala's Toolbox from the
# Stan program and its data.
log_target_pdf = Toolbox.make_log_target_pdf(stan_code_path, data)
grad_log_target_pdf = Toolbox.make_grad_log_target_pdf(stan_code_path, data)

In [3]:
# Run configuration.
total_timesteps = 500_000
sample_dim = 2
initial_sample = np.zeros(sample_dim)
initial_covariance = np.eye(sample_dim)

step_size_config = [0.1, 0.5, 1.0, 2.0]
# NOTE(review): not referenced by any later cell in this notebook; candidate for removal.
env_func_list = []

# Every step size is run by two samplers (one Barker env and one MALA env), and
# each env's action is a row of `action_dim` values (see the reshape below).
samplers_per_step_size = 2
action_dim = 2

# Keyword arguments shared by both environment factories; only `env_id` differs.
common_env_kwargs = dict(
    log_target_pdf=log_target_pdf,
    grad_log_target_pdf=grad_log_target_pdf,
    initial_sample=initial_sample,
    initial_covariance=initial_covariance,
    total_timesteps=total_timesteps,
)

make_mala_env_func = partial(
    Toolbox.make_env,
    env_id="MALAEnv-v1.0",
    **common_env_kwargs,
)

make_barker_env_func = partial(
    Toolbox.make_env,
    env_id="BarkerEnv-v1.0",
    **common_env_kwargs,
)

# One action row per env, in the same order as `envs` (barker, mala for each
# step size): each step size is repeated to fill `samplers_per_step_size` rows
# of `action_dim` entries, then mapped through inverse_softplus (presumably the
# envs apply softplus to recover a positive step size - TODO confirm).
actions = pipe(
    step_size_config,
    lambda sizes: (
        repeat(size, samplers_per_step_size * action_dim) for size in sizes
    ),
    chain.from_iterable,
    list,
    np.array,
    lambda flat: flat.reshape(-1, action_dim),
    Toolbox.inverse_softplus,
)

In [4]:
def _instantiate_env(env_factory, step_size: float):
    """Create one environment whose initial step size is ``step_size``.

    ``env_factory`` is one of the ``partial(Toolbox.make_env, ...)`` factories.
    Calling it with ``initial_step_size`` returns a callable, which is then
    invoked to create the environment. The step size is passed through
    ``Toolbox.inverse_softplus`` as a length-1 array, matching the encoding used
    for ``actions``.
    """
    return env_factory(
        initial_step_size=Toolbox.inverse_softplus(np.array([step_size]))
    )()


# Each barker_* env must come from the Barker factory; building it with the
# MALA factory would silently turn the "barker" runs into duplicate MALA runs.
# Construction order (barker, mala per step size) matches the rows of `actions`.
barker_env_01 = _instantiate_env(make_barker_env_func, 0.1)
mala_env_01 = _instantiate_env(make_mala_env_func, 0.1)
barker_env_05 = _instantiate_env(make_barker_env_func, 0.5)
mala_env_05 = _instantiate_env(make_mala_env_func, 0.5)
barker_env_10 = _instantiate_env(make_barker_env_func, 1.0)
mala_env_10 = _instantiate_env(make_mala_env_func, 1.0)
barker_env_20 = _instantiate_env(make_barker_env_func, 2.0)
mala_env_20 = _instantiate_env(make_mala_env_func, 2.0)

# Insertion order matters: the sampling loop pairs envs with `actions` by position.
envs = {
    "barker_env_01": barker_env_01,
    "mala_env_01": mala_env_01,
    "barker_env_05": barker_env_05,
    "mala_env_05": mala_env_05,
    "barker_env_10": barker_env_10,
    "mala_env_10": mala_env_10,
    "barker_env_20": barker_env_20,
    "mala_env_20": mala_env_20,
}

In [5]:
# Put every environment in its initial state before sampling. The reset return
# value is not needed; don't bind it to `_`, which IPython reserves for the
# last cell output (a user-assigned `_` stops IPython from updating it).
for env in envs.values():
    env.reset()


In [None]:
# Advance all environments in lock-step, each with its own fixed action row.
# `envs` and `actions` are paired by position, so their lengths must agree.
assert len(envs) == len(actions), "each env needs exactly one action row"

# The env/action pairing is loop-invariant; build it once, not per timestep.
env_action_pairs = list(zip(envs.values(), actions))

# Use a named loop variable and discard step() results without binding IPython's `_`.
for _step in trange(total_timesteps):
    for env, action in env_action_pairs:
        env.step(action)

In [7]:
# Persist each env's recorded attribute, keyed by env name, as one compressed
# .npz per attribute (Data/store_accepted_sample.npz, Data/store_reward.npz).
# Create the output directory first; savez_compressed does not create it.
output_dir = Path("Data")
output_dir.mkdir(parents=True, exist_ok=True)

for attr_name in ("store_accepted_sample", "store_reward"):
    np.savez_compressed(
        output_dir / f"{attr_name}.npz",
        **{key: env.get_wrapper_attr(attr_name) for key, env in envs.items()},
    )