In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'
import jax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from functools import partial
import frozen_lake
import qlearning
import actions
import mangoenv
import utils
import nets
import optuna

In [2]:
@partial(jax.jit, static_argnames=("map_scale", "p"))
@partial(jax.vmap, in_axes=(0, None, None))
def setup_env(rng_key, map_scale, p):
    """Create a batch of random FrozenLake environments, one per PRNG key.

    Only ``rng_key`` is mapped by ``jax.vmap``; ``map_scale`` and ``p`` are
    shared by every environment and are static arguments of the jitted
    function (a new value triggers a new compilation).

    Args:
        rng_key: batch of PRNG keys; the leading axis is the vmapped axis.
        map_scale: forwarded to ``FrozenLake.make_random``.
        p: forwarded to ``FrozenLake.make_random``
            (presumably a tile probability -- TODO confirm against frozen_lake).

    Returns:
        The batched result of ``frozen_lake.FrozenLake.make_random``.
    """
    return frozen_lake.FrozenLake.make_random(rng_key, map_scale, p)


@partial(jax.jit, static_argnames=("map_scale", "cell_scale", "lr", "tau"))
@partial(jax.vmap, in_axes=(0, 0, None, None, None, None))
def setup_dql_states(rng_key, env, map_scale, cell_scale, lr, tau):
    """Initialise one multi-task DQL train state per (key, environment) pair.

    ``rng_key`` and ``env`` are mapped along their leading axis; the remaining
    arguments are shared across the batch and are static for ``jax.jit``, so
    every distinct ``lr``/``tau`` value compiles a new version of this function.

    Args:
        rng_key: batch of PRNG keys used to create the train states.
        env: batch of environments; ``env.action_space.n`` sizes the Q-network.
        map_scale: the Q-network map shape is ``(2**map_scale, 2**map_scale)``.
        cell_scale: the Q-network cell shape is ``(2**cell_scale, 2**cell_scale)``;
            also selects the reward and beta functions from ``actions``.
        lr: passed to ``MultiDQLTrainState.create`` as ``lr``.
        tau: passed to ``MultiDQLTrainState.create`` as ``soft_update_rate``.

    Returns:
        The batched ``qlearning.MultiDQLTrainState``.
    """
    task_fns = dict(
        reward_fn=actions.get_reward_fn(cell_scale),
        beta_fn=actions.get_beta_fn(cell_scale),
    )
    map_side, cell_side = 2**map_scale, 2**cell_scale
    qnet = nets.MultiTaskQnet(
        n_actions=env.action_space.n,
        n_comands=5,
        map_shape=(map_side, map_side),
        cell_shape=(cell_side, cell_side),
    )
    return qlearning.MultiDQLTrainState.create(
        rng_key, qnet, env, lr=lr, soft_update_rate=tau, **task_fns
    )

In [3]:
@partial(jax.jit, static_argnames=("rollout_steps", "n_train_iter", "batch_size", "eval_steps"))
@partial(jax.vmap, in_axes=(0, 0, 0, None, None, None, None))
def train_stage(rng_key, env, dql_state, rollout_steps, n_train_iter, batch_size, eval_steps):
    """Run one training mini-cycle and evaluate the resulting greedy policy.

    The function is vmapped over ``rng_key``, ``env`` and ``dql_state`` (one
    entry per simulation); the four integer arguments are shared and static.

    Steps performed for each simulation:
      1. collect ``n_train_iter * batch_size // rollout_steps`` random rollouts
         of ``rollout_steps`` steps and store them in a replay buffer;
      2. apply ``n_train_iter`` parameter updates, each on a batch of
         ``batch_size`` transitions sampled from that buffer;
      3. run 32 greedy rollouts of ``eval_steps`` steps and average their score.

    Args:
        rng_key: PRNG key, split into rollout / training / evaluation keys.
        env: environment used for both data collection and evaluation.
        dql_state: train state to update.
        rollout_steps: length of each random rollout.
        n_train_iter: number of parameter updates.
        batch_size: number of transitions sampled per update.
        eval_steps: length of each greedy evaluation rollout.

    Returns:
        ``(dql_state, accuracy)`` where ``accuracy`` is the mean over the
        evaluation rollouts of ``total reward / max(#done transitions, 1)``.
    """
    n_eval_rollouts = 32

    rng_rollout, rng_train, rng_eval = jax.random.split(rng_key, 3)
    n_rollouts = n_train_iter * batch_size // rollout_steps
    episodes = utils.multi_random_rollout(env, rng_rollout, rollout_steps, n_rollouts)
    replay_buffer = utils.CircularBuffer.store_episodes(episodes)

    @partial(jax.vmap, in_axes=(0, None, None))
    def eval_policy(rng_key, dql_state, steps):
        transitions = dql_state.greedy_rollout(env, rng_key, steps)
        # Lower-bound the episode count at 1 so a rollout with no finished
        # episode does not divide by zero.  jnp.maximum is used instead of
        # jnp.clip(..., a_min=1) because the `a_min` keyword is deprecated
        # since JAX 0.4.27.
        n_episodes = jnp.maximum(transitions.done.sum(), 1)
        return transitions.reward.sum() / n_episodes

    def train_step(dql_state, rng_key):
        transitions = replay_buffer.sample(rng_key, batch_size)
        return dql_state.update_params(transitions), None

    rng_steps = jax.random.split(rng_train, n_train_iter)
    dql_state, _ = jax.lax.scan(train_step, dql_state, rng_steps)

    rng_eval = jax.random.split(rng_eval, n_eval_rollouts)
    accuracy = eval_policy(rng_eval, dql_state, eval_steps)
    return dql_state, accuracy.mean()

# Objective

In [4]:
def get_objective_fn(
    map_scale,
    p,
    cell_scales,
    max_steps,
    train_minicycles=32,
    n_sims=16,
    batch_size=256,
    rollout_steps=4,
    seed=42,
):
    """Build an Optuna objective that trains a stack of DQL stages.

    One stage is trained per entry of ``cell_scales``; after each stage the
    environment is wrapped in a ``mangoenv.MangoEnv`` driven by the trained
    state, and the next stage is trained on that wrapped environment.

    Args:
        map_scale: passed to ``setup_env`` and ``setup_dql_states``.
        p: passed to ``setup_env``.
        cell_scales: sequence of cell scales, one per stage.  A learning rate
            ``lr_<scale>`` and soft-update rate ``tau_<scale>`` are tuned for
            each of them.
        max_steps: total budget; ``max_steps // batch_size // train_minicycles``
            parameter updates are run per mini-cycle.
        train_minicycles: total number of mini-cycles, divided among stages.
        n_sims: number of independent simulations run in parallel (vmapped).
        batch_size: transitions per parameter update.
        rollout_steps: length of the random rollouts that fill the replay buffer.
        seed: PRNG seed shared by every trial (defaults to the previously
            hard-coded 42).

    Returns:
        ``objective(trial) -> float``.  The value is the worst-case (minimum
        over simulations) accuracy of the last mini-cycle; when that equals 1,
        the fraction of mini-cycles whose worst-case accuracy was 1 is added as
        a bonus.  The objective raises ``optuna.TrialPruned`` when the pruner
        asks to stop.
    """

    def objective(trial: optuna.Trial) -> float:
        # global hyperparameters
        total_train_iter = max_steps // batch_size
        train_iter = total_train_iter // train_minicycles

        trial.set_user_attr("max_steps", max_steps)
        trial.set_user_attr("cycle_train_iter", train_iter)
        trial.set_user_attr("batch_size", batch_size)
        trial.set_user_attr("rollout_steps", rollout_steps)

        # setup
        n_stages = len(cell_scales)
        rng_env, rng_stages = jax.random.split(jax.random.PRNGKey(seed))
        rng_env = jax.random.split(rng_env, n_sims)
        rng_stages = jax.random.split(rng_stages, n_stages)

        env = setup_env(rng_env, map_scale, p)

        # Per-stage hyperparameters: all learning rates are suggested first,
        # then all soft-update rates.
        stages_lr = [
            trial.suggest_float(f"lr_{cell_scale}", 3e-5, 3e-3, log=True)
            for cell_scale in cell_scales
        ]
        stages_tau = [
            trial.suggest_float(f"tau_{cell_scale}", 1e-3, 1e-1, log=True)
            for cell_scale in cell_scales
        ]
        # The last stage gets half of the mini-cycles, the one before a
        # quarter, and so on; the first stage gets the same share as the second.
        stages_train_cycles = [
            train_minicycles // 2 ** (n_stages - i - int(i == 0))
            for i in range(n_stages)
        ]
        # Evaluation horizon of a stage: 4 ** (increase in cell scale with
        # respect to the previous stage), starting from scale 0.
        stages_eval_steps = [4 ** cell_scales[0]] + [
            4 ** (j - i) for i, j in zip(cell_scales[:-1], cell_scales[1:])
        ]

        cycles_accuracy = []
        stage_plan = zip(
            rng_stages, cell_scales, stages_lr, stages_tau, stages_train_cycles, stages_eval_steps
        )
        for rng_stage, cell_scale, lr, tau, train_cycles, eval_steps in stage_plan:
            rng_init, rng_train = jax.random.split(rng_stage)
            rng_init = jax.random.split(rng_init, n_sims)
            dql_state = setup_dql_states(rng_init, env, map_scale, cell_scale, lr, tau)

            for rng_cycle in jax.random.split(rng_train, train_cycles):
                rng_cycle = jax.random.split(rng_cycle, n_sims)
                dql_state, accuracy = train_stage(
                    rng_cycle, env, dql_state, rollout_steps, train_iter, batch_size, eval_steps
                )
                cycles_accuracy.append(accuracy)
                # The number of completed mini-cycles is the 1-based global step
                # (identical to 1 + j + cycles of all previous stages).
                trial.report(float(accuracy.min()), step=len(cycles_accuracy))
                if trial.should_prune():
                    raise optuna.TrialPruned()
            # The next stage acts on top of the policy that was just trained.
            env = mangoenv.MangoEnv(env, dql_state, max_steps=eval_steps)

        # store results: accuracy_evol has shape (n_sims, n_cycles)
        accuracy_evol = jnp.stack(cycles_accuracy, axis=-1)
        trial.set_user_attr("stages_duration", list(stages_train_cycles))
        for stat in ("mean", "min", "max", "std"):
            per_cycle = getattr(accuracy_evol, stat)(axis=0)
            trial.set_user_attr(f"accuracy_evol_{stat}", np.asarray(per_cycle).tolist())

        # evaluate objective
        final_accuracy = cycles_accuracy[-1].min()
        burned_cycles = (accuracy_evol.min(axis=0) == 1).mean()
        bonus = burned_cycles if final_accuracy == 1 else 0
        # Convert once at the end so Optuna receives a plain Python float.
        return float(final_accuracy + bonus)

    return objective

# 4x4

In [5]:
# 4x4 map: tune a two-stage agent (cell scales 1 -> 2) and a single-stage
# agent (cell scale 2), each in its own study stored in the same database.
map_scale = 2
p = 0.5
max_steps = 1024 * 256

# SQLite does not create missing parent directories, so make sure the study
# folder exists before Optuna opens the database.
os.makedirs("optuna_studies", exist_ok=True)
storage_path = f"sqlite:///optuna_studies/{2**map_scale}x{2**map_scale}_p_{p}.db"

for cell_scales in [(1, 2), (2,)]:
    # load_if_exists=True resumes a previous study of the same name instead of
    # failing, so re-running this cell adds trials to the existing study.
    study = optuna.create_study(
        study_name=f"mango_stages_{list(cell_scales)}",
        storage=storage_path,
        sampler=optuna.samplers.CmaEsSampler(consider_pruned_trials=True),
        pruner=optuna.pruners.HyperbandPruner(),
        load_if_exists=True,
        direction="maximize",
    )
    study.optimize(
        get_objective_fn(map_scale, p, cell_scales, max_steps),
        n_trials=256,
        n_jobs=2,
        show_progress_bar=True,
    )

[I 2024-03-14 23:22:58,904] A new study created in RDB with name: mango_stages_[1, 2]


  0%|          | 0/256 [00:00<?, ?it/s]

[I 2024-03-14 23:23:16,270] Trial 0 finished with value: 0.4114583432674408 and parameters: {'lr_1': 4.4834723444644425e-05, 'lr_2': 6.0479638452289986e-05, 'tau_1': 0.09352430102254537, 'tau_2': 0.010748709706708635}. Best is trial 0 with value: 0.4114583432674408.
[I 2024-03-14 23:23:16,293] Trial 1 finished with value: 0.296875 and parameters: {'lr_1': 4.0176189709593835e-05, 'lr_2': 0.0006178112367403856, 'tau_1': 0.0026876517832835605, 'tau_2': 0.030644391868586317}. Best is trial 0 with value: 0.4114583432674408.
[I 2024-03-14 23:23:26,808] Trial 3 pruned. 
[I 2024-03-14 23:23:28,194] Trial 2 finished with value: 0.3671875 and parameters: {'lr_1': 0.0007723061890731571, 'lr_2': 0.0002248239109464564, 'tau_1': 0.03787842496401685, 'tau_2': 0.009294555614485243}. Best is trial 0 with value: 0.4114583432674408.
[I 2024-03-14 23:23:29,402] Trial 4 pruned. 
[I 2024-03-14 23:23:32,781] Trial 6 pruned. 
[I 2024-03-14 23:23:36,910] Trial 5 finished with value: 0.546875 and parameters: {'

[I 2024-03-14 23:35:44,935] A new study created in RDB with name: mango_stages_[2]


[I 2024-03-14 23:35:44,879] Trial 255 finished with value: 0.84375 and parameters: {'lr_1': 0.002059274248560172, 'lr_2': 0.0018439742241434442, 'tau_1': 0.011857193623639106, 'tau_2': 0.004743617773240374}. Best is trial 36 with value: 1.0625.


  0%|          | 0/256 [00:00<?, ?it/s]

[I 2024-03-14 23:35:51,415] Trial 0 finished with value: 0.53125 and parameters: {'lr_2': 0.0011123012871886349, 'tau_2': 0.0013385239576933952}. Best is trial 0 with value: 0.53125.
[I 2024-03-14 23:35:51,932] Trial 1 finished with value: 0.65625 and parameters: {'lr_2': 0.0002762934101483573, 'tau_2': 0.006927619875594647}. Best is trial 1 with value: 0.65625.
[I 2024-03-14 23:35:57,728] Trial 3 finished with value: 0.65625 and parameters: {'lr_2': 0.00046273247575892584, 'tau_2': 0.003222298581419539}. Best is trial 1 with value: 0.65625.
[I 2024-03-14 23:35:57,928] Trial 2 finished with value: 0.5 and parameters: {'lr_2': 0.00014144831939111584, 'tau_2': 0.01138447785070632}. Best is trial 1 with value: 0.65625.
[I 2024-03-14 23:36:03,060] Trial 4 finished with value: 0.40625 and parameters: {'lr_2': 0.0001294071880255717, 'tau_2': 0.003858788888215129}. Best is trial 1 with value: 0.65625.
[I 2024-03-14 23:36:04,695] Trial 5 finished with value: 0.78125 and parameters: {'lr_2': 0.

# 8x8

In [5]:
# 8x8 map: tune a three-stage agent (cell scales 1 -> 2 -> 3).
map_scale = 3
p = 0.8
max_steps = 1024 * 1024 * 4

# SQLite does not create missing parent directories, so make sure the study
# folder exists before Optuna opens the database.
os.makedirs("optuna_studies", exist_ok=True)
storage_path = f"sqlite:///optuna_studies/{2**map_scale}x{2**map_scale}_p_{p}.db"

for cell_scales in [(1, 2, 3)]:
    # load_if_exists=True resumes a previous study of the same name instead of
    # failing, so re-running this cell adds trials to the existing study.
    study = optuna.create_study(
        study_name=f"mango_stages_{list(cell_scales)}_mask",
        storage=storage_path,
        sampler=optuna.samplers.CmaEsSampler(consider_pruned_trials=True),
        pruner=optuna.pruners.HyperbandPruner(),
        load_if_exists=True,
        direction="maximize",
    )
    study.optimize(
        get_objective_fn(map_scale, p, cell_scales, max_steps),
        n_trials=256,
        n_jobs=2,
        show_progress_bar=True,
    )

[I 2024-03-15 20:24:32,451] Using an existing study with name 'mango_stages_[1, 2, 3]_mask' instead of creating a new one.


  0%|          | 0/256 [00:00<?, ?it/s]

[I 2024-03-15 20:25:03,557] Trial 99 pruned. 
[I 2024-03-15 20:28:22,683] Trial 100 finished with value: 0.7083333730697632 and parameters: {'lr_1': 0.0003381929475151006, 'lr_2': 0.001751388737563094, 'lr_3': 0.0002500690108722546, 'tau_1': 0.03038991414218805, 'tau_2': 0.0055058788889844395, 'tau_3': 0.0022865385912625887}. Best is trial 23 with value: 0.7447916865348816.
[I 2024-03-15 20:28:35,471] Trial 101 pruned. 
[I 2024-03-15 20:28:57,828] Trial 103 pruned. 
[I 2024-03-15 20:29:39,725] Trial 104 pruned. 
[I 2024-03-15 20:31:08,982] Trial 102 finished with value: 0.6927083730697632 and parameters: {'lr_1': 0.00040264044363855065, 'lr_2': 0.001681841000510362, 'lr_3': 0.00034739216324982536, 'tau_1': 0.029927274555204598, 'tau_2': 0.007563711941159954, 'tau_3': 0.004023871808191394}. Best is trial 23 with value: 0.7447916865348816.
[I 2024-03-15 20:32:59,043] Trial 105 pruned. 
[I 2024-03-15 20:34:26,451] Trial 107 pruned. 
[I 2024-03-15 20:34:48,186] Trial 106 finished with valu

# 16x16

In [5]:
# 16x16 map: tune a two-stage agent (cell scales 2 -> 4).
# NOTE(review): the recorded run of this cell failed inside train_stage with
# "RESOURCE_EXHAUSTED: Out of memory while trying to allocate 42966450176
# bytes".  Presumably the per-cycle replay data implied by
# max_steps / train_minicycles is too large for the GPU -- TODO confirm and
# pick a smaller budget per mini-cycle before re-running.
map_scale = 4
p = 0.8
max_steps = 1024 * 1024 * 64

# SQLite does not create missing parent directories, so make sure the study
# folder exists before Optuna opens the database.
os.makedirs("optuna_studies", exist_ok=True)
storage_path = f"sqlite:///optuna_studies/{2**map_scale}x{2**map_scale}_p_{p}.db"

for cell_scales in [(2, 4)]:
    # load_if_exists=True resumes a previous study of the same name instead of
    # failing, so re-running this cell adds trials to the existing study.
    study = optuna.create_study(
        study_name=f"mango_stages_{list(cell_scales)}",
        storage=storage_path,
        sampler=optuna.samplers.CmaEsSampler(consider_pruned_trials=True),
        pruner=optuna.pruners.HyperbandPruner(),
        load_if_exists=True,
        direction="maximize",
    )
    study.optimize(
        get_objective_fn(map_scale, p, cell_scales, max_steps, train_minicycles=64),
        n_trials=32,
        show_progress_bar=True,
    )

[I 2024-03-15 17:22:01,266] Using an existing study with name 'mango_stages_[2, 4]' instead of creating a new one.


  0%|          | 0/32 [00:00<?, ?it/s]

2024-03-15 17:29:11.818926: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 40.02GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


[W 2024-03-15 17:29:11,865] Trial 1 failed with parameters: {'lr_2': 7.315342877434472e-05, 'lr_4': 0.0029840027187507993, 'tau_2': 0.014156262722951923, 'tau_4': 0.004121465421498776} because of the following error: XlaRuntimeError('RESOURCE_EXHAUSTED: Out of memory while trying to allocate 42966450176 bytes.').
Traceback (most recent call last):
  File "/home/davide_sartor/.conda/envs/dl_env/lib/python3.10/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_3216847/2119128647.py", line 58, in objective
    dql_state, accuracy = train_stage(
  File "/home/davide_sartor/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/davide_sartor/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/davide_sartor/.loca

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 42966450176 bytes.