In [None]:
# Install the project requirements and pin dm-acme[jax] to 0.4.0 (-U upgrades pip and the listed packages).
%pip install --quiet -U pip -r requirements.txt dm-acme[jax]==0.4.0

In [None]:
import warnings
warnings.filterwarnings("ignore")

import pickle
from pathlib import Path

import jax
import numpy as np

from discounting_chain.bmg_a2c import create_bmg_a2c_agent
from discounting_chain.meta_a2c import create_meta_a2c_agent
from discounting_chain.train_utils import run
from discounting_chain.list_logger import ListLogger
from discounting_chain.nets import create_linear_forward_fn
from discounting_chain.envs.gymnax_dc_wrapper import create_dc_gmnax

In [None]:
# Fixed parameters shared by every run of the main experiment.

# Agent config: discounts, GAE lambdas, entropy costs and learning rates.
init_discount, outer_discount = 0.95, 1.0
inner_lambda_gae = outer_lambda_gae = 0.0  # lambda = 0 -> 1-step TD error
inner_entropy_cost = outer_entropy_cost = 0.005
lr, meta_lr = 0.5, 0.1

# Update / optimiser config.
normalise = False  # advantage normalisation off (the appendix re-runs with it on)
batch_size = 128
sgd_optimizer, sgd_meta_optimizer = True, False

# Environment, plus the `true_value_fn` that is handed to the agents.
mapping_seed = 3
env, true_value_fn = create_dc_gmnax(mapping_seed=mapping_seed)

# Run config.
num_iterations = 1000
n_updates_per_iter = 50
n_step = 100

In [4]:
def run_experiment(use_bmg, meta_lr, meta_value_head, num_iterations, seed,
                   normalise_advantages=None):
    """Train one agent on the discounting chain and return its logged history.

    Builds a linear network for the notebook-level ``env`` and trains either the
    meta-gradient A2C agent (``use_bmg=False``) or the BMG A2C agent
    (``use_bmg=True``). Every hyperparameter not passed as an argument
    (``lr``, ``init_discount``, ``batch_size``, ``n_step``, entropy costs, GAE
    lambdas, optimiser flags, ``n_updates_per_iter``, ``env``,
    ``true_value_fn``) is read from the notebook's globals at call time.

    Args:
        use_bmg: Build the BMG agent instead of the meta-gradient agent.
        meta_lr: Meta learning rate. The A2C baseline in this notebook is the
            meta-gradient agent run with ``meta_lr=0.0``.
        meta_value_head: Passed to the network as ``double_value_head`` and to
            the agent as ``meta_value_head``.
        num_iterations: Iteration count passed to ``run``.
        seed: Seed passed to ``run``.
        normalise_advantages: Value for the agents' ``normalise`` flag.
            ``None`` (default) uses the global ``normalise`` setting, which is
            the original behaviour. Pass ``True``/``False`` to override it
            without redefining this function.

    Returns:
        The ``history`` of the ``ListLogger`` used during training.
    """
    if normalise_advantages is None:
        # Fall back to the notebook-level config flag, resolved at call time.
        normalise_advantages = normalise

    forward_fn = create_linear_forward_fn(
        environment_spec=env.spec,
        double_value_head=meta_value_head,
    )

    # Keyword arguments with identical names and values for both agent factories.
    shared_kwargs = dict(
        env=env,
        forward_fn=forward_fn,
        true_value_fn=true_value_fn,
        meta_value_head=meta_value_head,
        lr=lr,
        meta_lr=meta_lr,
        init_discount=init_discount,
        batch_size_per_device=batch_size,
        outer_discount=outer_discount,
        n_step=n_step,
        outer_lambda_gae=outer_lambda_gae,
        normalise=normalise_advantages,
        sgd_optimizer=sgd_optimizer,
        sgd_meta_optimizer=sgd_meta_optimizer,
        outer_entropy_cost=outer_entropy_cost,
        outer_policy_grad_cost=1.0,
        outer_critic_cost=0.0,
    )

    if not use_bmg:
        print(f"running MG with meta value head {meta_value_head}")
        agent = create_meta_a2c_agent(
            **shared_kwargs,
            entropy_cost=inner_entropy_cost,
            lambda_gae=inner_lambda_gae,  # 1 step TD error
            policy_grad_cost=1.0,
            critic_cost=0.0,
        )
    else:
        print(f"running BMG with meta value head {meta_value_head}")
        agent = create_bmg_a2c_agent(
            **shared_kwargs,
            policy_grad_cost_inner=1.0,
            critic_cost_inner=0.0,
            entropy_cost_inner=inner_entropy_cost,
            lambda_gae_inner=inner_lambda_gae,  # 1 step TD error
            n_bootstrap_target_updates=1,  # 1 target update
            only_bmg_updates=False,
            kl_over_full_batch=False,
            use_outer_optimizer=False,
        )

    logger = ListLogger()
    run(num_iterations, n_updates_per_iter, agent, logger, seed)
    return logger.history

## Main experiment: 10 seeds per agent

In [None]:
num_seeds = 10

# `histories` holds one list of per-seed histories per agent, in this order:
#   0: A2C baseline (meta-gradient agent with meta_lr=0.0, no meta value head)
#   1: MG without meta value head     2: BMG without meta value head
#   3: MG with meta value head        4: BMG with meta value head
histories = [
    [run_experiment(False, 0.0, False, num_iterations, seed) for seed in range(num_seeds)]
]

for meta_value_head in [False, True]:
    for use_bmg in [False, True]:
        histories.append([
            run_experiment(use_bmg, meta_lr, meta_value_head, num_iterations, seed)
            for seed in range(num_seeds)
        ])


In [None]:
def _stack_over_seeds(per_seed_histories):
    """Merge a list of per-seed history dicts into one dict of stacked arrays.

    Every logged series is converted to a NumPy array, then the seeds are
    stacked along a new leading axis. `jax.tree_util.tree_map` requires every
    seed's dict to share the same keys, and `np.stack` requires equal shapes.
    """
    converted = [
        {name: np.array(series) for name, series in seed_history.items()}
        for seed_history in per_seed_histories
    ]
    return jax.tree_util.tree_map(lambda *leaves: np.stack(leaves), *converted)


# One stacked dict per agent, in the same order as `histories`.
histories_array = list(map(_stack_over_seeds, histories))

In [None]:
# Persist the per-agent, seed-stacked histories for later analysis.
# The parent directory is created if it is missing, so a long run is not lost to a
# FileNotFoundError at the very end. The path is relative to the current working directory.
histories_path = Path("discounting_chain/data/discounting_chain_histories_array.pickle")
histories_path.parent.mkdir(parents=True, exist_ok=True)
with histories_path.open("wb") as f:
    pickle.dump(histories_array, f)

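## Appendix: advantage normalisation

The same five agents are re-run with advantage normalisation enabled (`normalise = True`), using a single seed (0) per agent.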

In [None]:
# Fixed parameters for every appendix run. They match the main experiment except for `normalise`.

# Agent config: discounts, GAE lambdas, entropy costs and learning rates.
init_discount, outer_discount = 0.95, 1.0
inner_lambda_gae = outer_lambda_gae = 0.0  # lambda = 0 -> 1-step TD error
inner_entropy_cost = outer_entropy_cost = 0.005
lr, meta_lr = 0.5, 0.1

# Update / optimiser config.
normalise = True  # This differs from the main paper (advantage normalisation on)
batch_size = 128
sgd_optimizer, sgd_meta_optimizer = True, False

# Environment, plus the `true_value_fn` that is handed to the agents.
mapping_seed = 3
env, true_value_fn = create_dc_gmnax(mapping_seed=mapping_seed)

# Run config.
num_iterations = 1000
n_updates_per_iter = 50
n_step = 100

In [6]:
def run_experiment(use_bmg, meta_lr, meta_value_head, num_iterations, seed,
                   normalise_advantages=None):
    """Train one agent on the discounting chain and return its logged history.

    Builds a linear network for the notebook-level ``env`` and trains either the
    meta-gradient A2C agent (``use_bmg=False``) or the BMG A2C agent
    (``use_bmg=True``). Every hyperparameter not passed as an argument
    (``lr``, ``init_discount``, ``batch_size``, ``n_step``, entropy costs, GAE
    lambdas, optimiser flags, ``n_updates_per_iter``, ``env``,
    ``true_value_fn``) is read from the notebook's globals at call time.

    NOTE(review): this cell redefines ``run_experiment`` for the appendix.
    Because ``normalise`` is read at call time, the appendix config cell's
    ``normalise = True`` would be picked up without the redefinition.

    Args:
        use_bmg: Build the BMG agent instead of the meta-gradient agent.
        meta_lr: Meta learning rate. The A2C baseline in this notebook is the
            meta-gradient agent run with ``meta_lr=0.0``.
        meta_value_head: Passed to the network as ``double_value_head`` and to
            the agent as ``meta_value_head``.
        num_iterations: Iteration count passed to ``run``.
        seed: Seed passed to ``run``.
        normalise_advantages: Value for the agents' ``normalise`` flag.
            ``None`` (default) uses the global ``normalise`` setting, which is
            the original behaviour. Pass ``True``/``False`` to override it.

    Returns:
        The ``history`` of the ``ListLogger`` used during training.
    """
    if normalise_advantages is None:
        # Fall back to the notebook-level config flag, resolved at call time.
        normalise_advantages = normalise

    forward_fn = create_linear_forward_fn(
        environment_spec=env.spec,
        double_value_head=meta_value_head,
    )

    # Keyword arguments with identical names and values for both agent factories.
    shared_kwargs = dict(
        env=env,
        forward_fn=forward_fn,
        true_value_fn=true_value_fn,
        meta_value_head=meta_value_head,
        lr=lr,
        meta_lr=meta_lr,
        init_discount=init_discount,
        batch_size_per_device=batch_size,
        outer_discount=outer_discount,
        n_step=n_step,
        outer_lambda_gae=outer_lambda_gae,
        normalise=normalise_advantages,
        sgd_optimizer=sgd_optimizer,
        sgd_meta_optimizer=sgd_meta_optimizer,
        outer_entropy_cost=outer_entropy_cost,
        outer_policy_grad_cost=1.0,
        outer_critic_cost=0.0,
    )

    if not use_bmg:
        print(f"running MG with meta value head {meta_value_head}")
        agent = create_meta_a2c_agent(
            **shared_kwargs,
            entropy_cost=inner_entropy_cost,
            lambda_gae=inner_lambda_gae,  # 1 step TD error
            policy_grad_cost=1.0,
            critic_cost=0.0,
        )
    else:
        print(f"running BMG with meta value head {meta_value_head}")
        agent = create_bmg_a2c_agent(
            **shared_kwargs,
            policy_grad_cost_inner=1.0,
            critic_cost_inner=0.0,
            entropy_cost_inner=inner_entropy_cost,
            lambda_gae_inner=inner_lambda_gae,  # 1 step TD error
            n_bootstrap_target_updates=1,  # 1 target update
            only_bmg_updates=False,
            kl_over_full_batch=False,
            use_outer_optimizer=False,
        )

    logger = ListLogger()
    run(num_iterations, n_updates_per_iter, agent, logger, seed)
    return logger.history

In [None]:
# One history per agent, all with seed 0, in this order:
#   A2C baseline (meta-gradient agent with meta_lr=0.0),
#   MG / BMG without meta value head, then MG / BMG with meta value head.
appendix_histories = [run_experiment(False, 0.0, False, num_iterations, seed=0)]

for meta_value_head in [False, True]:
    for use_bmg in [False, True]:
        appendix_histories.append(
            run_experiment(use_bmg, meta_lr, meta_value_head, num_iterations, seed=0)
        )


In [9]:
# One dict per agent (same order as `appendix_histories`), with every logged
# series converted to a NumPy array. There is no seed axis because only seed 0 was run.
appendix_histories_array = [
    dict(zip(agent_history.keys(), map(np.array, agent_history.values())))
    for agent_history in appendix_histories
]

In [10]:
# Persist the per-agent appendix histories (single seed) for later analysis.
# The parent directory is created if it is missing, so the runs are not lost to a
# FileNotFoundError at the very end. The path is relative to the current working directory.
appendix_histories_path = Path(
    "discounting_chain/data/discounting_chain_appendix_histories_array.pickle"
)
appendix_histories_path.parent.mkdir(parents=True, exist_ok=True)
with appendix_histories_path.open("wb") as f:
    pickle.dump(appendix_histories_array, f)