In [1]:
import os
# JAX reads these variables when its GPU backend initializes, so they must be
# set before `import jax` below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only GPU index 1 to this process
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'  # XLA preallocates 90% of that GPU's memory
import jax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from functools import partial
# project modules (presumably local .py files next to this notebook)
import frozen_lake
import qlearning
import actions
import mangoenv
import utils
import nets
import optuna  # hyperparameter search (third-party)

In [2]:
@partial(jax.jit, static_argnames=("map_scale", "p"))
@partial(jax.vmap, in_axes=(0, None, None))
def setup_env(rng_key, map_scale, p):
    """Create one random FrozenLake environment per PRNG key.

    Vectorized over the leading axis of ``rng_key``; ``map_scale`` and ``p``
    are shared by every environment in the batch. Both are static under
    ``jax.jit``, so each distinct ``(map_scale, p)`` pair compiles once.
    """
    return frozen_lake.FrozenLake.make_random(rng_key, map_scale, p)


@partial(jax.jit, static_argnames=("map_scale", "cell_scale", "lr", "tau"))
@partial(jax.vmap, in_axes=(0, 0, None, None, None, None))
def setup_dql_states(rng_key, env, map_scale, cell_scale, lr, tau):
    """Initialise one multi-task DQL train state per (PRNG key, environment) pair.

    Vectorized over the leading axis of ``rng_key`` and ``env``. The scalar
    hyper-parameters are static under ``jax.jit``; because ``lr`` and ``tau``
    are static floats, every new (lr, tau) pair (i.e. essentially every Optuna
    trial) triggers a fresh compilation.

    Args:
        rng_key: batch of PRNG keys, one per simulation.
        env: batch of environments, one per simulation.
        map_scale: the map is ``2**map_scale`` cells on a side.
        cell_scale: a cell is ``2**cell_scale`` on a side; also selects the
            reward and beta functions from ``actions``.
        lr: learning rate passed to the train state.
        tau: soft target-update rate (``soft_update_rate``).
    """
    reward_fn, beta_fn = actions.get_reward_fn(cell_scale), actions.get_beta_fn(cell_scale)

    map_side = 2**map_scale
    cell_side = 2**cell_scale
    qnet = nets.MultiTaskQnet(
        n_actions=env.action_space.n,
        # NOTE(review): 5 commands, presumably one per movement direction plus
        # one extra command; confirm against the `actions` module.
        n_comands=5,
        map_shape=(map_side, map_side),
        cell_shape=(cell_side, cell_side),
    )
    return qlearning.MultiDQLTrainState.create(
        rng_key,
        qnet,
        env,
        reward_fn=reward_fn,
        beta_fn=beta_fn,
        lr=lr,
        soft_update_rate=tau,
    )

In [3]:
@partial(jax.jit, static_argnames=("rollout_steps", "n_train_iter", "batch_size", "eval_steps"))
@partial(jax.vmap, in_axes=(0, 0, 0, None, None, None, None))
def train_stage(rng_key, env, dql_state, rollout_steps, n_train_iter, batch_size, eval_steps):
    """Run one training cycle of a stage and evaluate the resulting greedy policy.

    Vectorized over the leading axis of ``rng_key``, ``env`` and ``dql_state``
    (one independent simulation per entry). The four integer arguments are
    static under ``jax.jit``, so each distinct combination compiles once.

    Per simulation:
      1. collect ``n_train_iter * batch_size // rollout_steps`` random rollouts
         of ``rollout_steps`` steps each and store them in a replay buffer;
      2. apply ``n_train_iter`` parameter updates (via ``lax.scan``), each on
         a ``batch_size`` sample drawn from that buffer;
      3. run 32 greedy rollouts of ``eval_steps`` steps. Each is scored as its
         total reward divided by its number of finished episodes (counted as
         at least 1).

    Returns:
        ``(dql_state, accuracy)``: the updated train state and the mean score
        over the 32 evaluation rollouts. NOTE(review): "accuracy" is reward per
        finished episode. It equals a success rate only if the env pays a
        single 0/1 reward per episode; confirm against the env.
    """
    rng_rollout, rng_train, rng_eval = jax.random.split(rng_key, 3)
    n_rollouts = n_train_iter * batch_size // rollout_steps
    random_episodes = utils.multi_random_rollout(env, rng_rollout, rollout_steps, n_rollouts)
    replay_buffer = utils.CircularBuffer.store_episodes(random_episodes)

    @partial(jax.vmap, in_axes=(0, None, None))
    def eval_policy(rng, state, steps):
        """Score of one greedy rollout: total reward per finished episode."""
        transitions = state.greedy_rollout(env, rng, steps)
        # Count at least one episode so a rollout that never terminates
        # divides by 1 instead of 0.
        n_episodes = jnp.maximum(transitions.done.sum(), 1)
        return transitions.reward.sum() / n_episodes

    def train_step(state, rng):
        """One gradient update on a minibatch sampled from the replay buffer."""
        transitions = replay_buffer.sample(rng, batch_size)
        return state.update_params(transitions), None

    rng_steps = jax.random.split(rng_train, n_train_iter)
    rng_evals = jax.random.split(rng_eval, 32)  # 32 independent evaluation rollouts
    dql_state, _ = jax.lax.scan(train_step, dql_state, rng_steps)
    accuracy = eval_policy(rng_evals, dql_state, eval_steps)
    return dql_state, accuracy.mean()

# Optuna objective function

In [4]:
def get_objective_fn(
    map_scale,
    p,
    cell_scales,
    max_steps,
    train_minicycles=32,
    n_sims=16,
    batch_size=256,
    rollout_steps=4,
):
    def objective(trial: optuna.Trial) -> jnp.float_:
        # global hyperparameters
        total_train_iter = max_steps // batch_size
        train_iter = total_train_iter // train_minicycles

        seed = 42
        trial.set_user_attr("max_steps", max_steps)
        trial.set_user_attr("cycle_train_iter", train_iter)
        trial.set_user_attr("batch_size", batch_size)
        trial.set_user_attr("rollout_steps", rollout_steps)

        # setup
        rng_key = jax.random.PRNGKey(seed)
        rng_env, rng_stages = jax.random.split(rng_key)
        rng_env = jax.random.split(rng_env, n_sims)
        rng_stages = jax.random.split(rng_stages, len(cell_scales))

        env = setup_env(rng_env, map_scale, p)

        cycles_accuracy = []
        burned_cycles = 0
        stages_lr = [
            trial.suggest_float(f"lr_{cell_scale}", 3e-5, 3e-3, log=True)
            for cell_scale in cell_scales
        ]
        stages_tau = [
            trial.suggest_float(f"tau_{cell_scale}", 1e-3, 1e-1, log=True)
            for cell_scale in cell_scales
        ]
        stages_train_cycles = [
            train_minicycles // 2 ** (len(cell_scales) - i - int(i == 0))
            for i in range(len(cell_scales))
        ]
        stages_eval_steps = [4 ** cell_scales[0]] + [
            4 ** (j - i) for i, j in zip(cell_scales[:-1], cell_scales[1:])
        ]

        for i, (rng_stage, cell_scale, lr, tau, train_cycles, eval_steps) in enumerate(
            zip(rng_stages, cell_scales, stages_lr, stages_tau, stages_train_cycles, stages_eval_steps)
        ):
            rng_init, rng_train = jax.random.split(rng_stage)
            rng_init = jax.random.split(rng_init, n_sims)
            dql_state = setup_dql_states(rng_init, env, map_scale, cell_scale, lr, tau)

            rng_cycles = jax.random.split(rng_train, train_cycles)
            for j, rng_cycle in enumerate(rng_cycles):
                rng_cycle = jax.random.split(rng_cycle, n_sims)
                dql_state, accuracy = train_stage(
                    rng_cycle, env, dql_state, rollout_steps, train_iter, batch_size, eval_steps
                )
                cycles_accuracy.append(accuracy)
                trial.report(accuracy.min(), step=(1 + j + sum(stages_train_cycles[:i])))
                if trial.should_prune():
                    raise optuna.TrialPruned()
            env = mangoenv.MangoEnv(env, dql_state, max_steps=eval_steps)

        # store and plot results
        accuracy_evol = jnp.stack(cycles_accuracy, axis=-1)
        stages_duration = jnp.array(stages_train_cycles)
        trial.set_user_attr(f"stages_duration", np.asarray(stages_duration).tolist())
        trial.set_user_attr(f"accuracy_evol_mean", np.asarray(accuracy_evol.mean(axis=0)).tolist())
        trial.set_user_attr(f"accuracy_evol_min", np.asarray(accuracy_evol.min(axis=0)).tolist())
        trial.set_user_attr(f"accuracy_evol_max", np.asarray(accuracy_evol.max(axis=0)).tolist())
        trial.set_user_attr(f"accuracy_evol_std", np.asarray(accuracy_evol.std(axis=0)).tolist())

        # evaluate objective
        final_accuracy = cycles_accuracy[-1].min()
        burned_cycles = jnp.array([acc.min() == 1 for acc in cycles_accuracy]).mean()
        return final_accuracy + (burned_cycles if final_accuracy == 1 else 0)

    return objective

# 4x4 maps

In [5]:
map_scale = 2
p = 0.8
max_steps = 1024 * 256
study_dir = "optuna_studies"
# SQLite creates the database file but not its parent directory, so make sure
# the directory exists before Optuna opens the storage (fresh checkouts).
os.makedirs(study_dir, exist_ok=True)
storage_path = f"sqlite:///{study_dir}/{2**map_scale}x{2**map_scale}_p_{p}.db"

# Compare a two-stage curriculum (cell scales 1 then 2) with a single stage.
for cell_scales in [(1, 2), (2,)]:
    study = optuna.create_study(
        study_name=f"mango_stages_{list(cell_scales)}",
        storage=storage_path,
        sampler=optuna.samplers.CmaEsSampler(consider_pruned_trials=True),
        pruner=optuna.pruners.HyperbandPruner(),
        load_if_exists=True,  # resume an existing study stored in the same DB
        direction="maximize",
    )
    study.optimize(
        get_objective_fn(map_scale, p, cell_scales, max_steps),
        n_trials=256,
        n_jobs=2,
        show_progress_bar=True,
    )

[I 2024-03-14 23:22:41,341] A new study created in RDB with name: mango_stages_[1, 2]


  0%|          | 0/256 [00:00<?, ?it/s]

[I 2024-03-14 23:22:58,951] Trial 1 finished with value: 0.375 and parameters: {'lr_1': 0.00013539218025660025, 'lr_2': 0.0009373437223709245, 'tau_1': 0.01653357285179791, 'tau_2': 0.09462215703959587}. Best is trial 1 with value: 0.375.
[I 2024-03-14 23:22:58,958] Trial 0 finished with value: 0.3723958432674408 and parameters: {'lr_1': 0.0014263723472408313, 'lr_2': 5.637496279203831e-05, 'tau_1': 0.003659190186892462, 'tau_2': 0.0038347944168220503}. Best is trial 1 with value: 0.375.
[I 2024-03-14 23:23:02,322] Trial 2 pruned. 
[I 2024-03-14 23:23:05,318] Trial 4 pruned. 
[I 2024-03-14 23:23:09,099] Trial 3 finished with value: 0.65625 and parameters: {'lr_1': 0.0007753356148339908, 'lr_2': 0.0006125688387701214, 'tau_1': 0.009169886030406317, 'tau_2': 0.014648981476217111}. Best is trial 3 with value: 0.65625.
[I 2024-03-14 23:23:11,839] Trial 6 pruned. 
[I 2024-03-14 23:23:14,357] Trial 5 pruned. 
[I 2024-03-14 23:23:17,310] Trial 8 pruned. 
[I 2024-03-14 23:23:21,161] Trial 7 fi

[I 2024-03-14 23:34:48,272] A new study created in RDB with name: mango_stages_[2]


[I 2024-03-14 23:34:48,219] Trial 255 pruned. 


  0%|          | 0/256 [00:00<?, ?it/s]

[I 2024-03-14 23:34:54,698] Trial 0 finished with value: 0.75 and parameters: {'lr_2': 0.00029018219861029514, 'tau_2': 0.03790088386099555}. Best is trial 0 with value: 0.75.
[I 2024-03-14 23:34:55,045] Trial 1 finished with value: 0.875 and parameters: {'lr_2': 0.00048739445089215536, 'tau_2': 0.011841232680132388}. Best is trial 1 with value: 0.875.
[I 2024-03-14 23:35:00,679] Trial 2 finished with value: 0.84375 and parameters: {'lr_2': 0.0004157262233071495, 'tau_2': 0.01844407473006411}. Best is trial 1 with value: 0.875.
[I 2024-03-14 23:35:01,105] Trial 3 finished with value: 0.46875 and parameters: {'lr_2': 0.00011464878458950197, 'tau_2': 0.008051088257337491}. Best is trial 1 with value: 0.875.
[I 2024-03-14 23:35:03,986] Trial 5 pruned. 
[I 2024-03-14 23:35:06,479] Trial 6 pruned. 
[I 2024-03-14 23:35:06,936] Trial 4 finished with value: 0.78125 and parameters: {'lr_2': 0.00037130063120588907, 'tau_2': 0.006950146301021325}. Best is trial 1 with value: 0.875.
[I 2024-03-14 

# 8x8 maps

In [5]:
map_scale = 3
p = 0.8
max_steps = 1024 * 1024 * 8
study_dir = "optuna_studies"
# SQLite creates the database file but not its parent directory, so make sure
# the directory exists before Optuna opens the storage (fresh checkouts).
os.makedirs(study_dir, exist_ok=True)
storage_path = f"sqlite:///{study_dir}/{2**map_scale}x{2**map_scale}_p_{p}.db"

for cell_scales in [(1, 2, 3)]:
    study = optuna.create_study(
        study_name=f"mango_stages_{list(cell_scales)}_long",
        storage=storage_path,
        # Pruning is disabled for this long run so every trial trains to
        # completion. To re-enable it, use
        # CmaEsSampler(consider_pruned_trials=True) with HyperbandPruner().
        sampler=optuna.samplers.CmaEsSampler(),
        pruner=optuna.pruners.NopPruner(),
        load_if_exists=True,  # resume an existing study stored in the same DB
        direction="maximize",
    )
    study.optimize(
        get_objective_fn(map_scale, p, cell_scales, max_steps),
        n_trials=128,
        show_progress_bar=True,
    )

[I 2024-03-16 01:32:44,953] A new study created in RDB with name: mango_stages_[1, 2, 3]_long


  0%|          | 0/128 [00:00<?, ?it/s]

# 16x16 maps

In [None]:
map_scale = 4
p = 0.8
max_steps = 1024 * 1024 * 64
study_dir = "optuna_studies"
# SQLite creates the database file but not its parent directory, so make sure
# the directory exists before Optuna opens the storage (fresh checkouts).
os.makedirs(study_dir, exist_ok=True)
storage_path = f"sqlite:///{study_dir}/{2**map_scale}x{2**map_scale}_p_{p}.db"

for cell_scales in [(2, 4)]:
    study = optuna.create_study(
        study_name=f"mango_stages_{list(cell_scales)}",
        storage=storage_path,
        sampler=optuna.samplers.CmaEsSampler(consider_pruned_trials=True),
        pruner=optuna.pruners.HyperbandPruner(),
        load_if_exists=True,  # resume an existing study stored in the same DB
        direction="maximize",
    )
    study.optimize(
        # More (shorter) cycles than the default 32 for this larger budget.
        get_objective_fn(map_scale, p, cell_scales, max_steps, train_minicycles=64),
        n_trials=16,
        n_jobs=1,
        show_progress_bar=True,
    )