In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from functools import partial
import frozen_lake
import plotting
import qlearning
import actions
import mangoenv
import utils
import nets
import optuna

In [2]:
def plot_accuracy_evolution(accuracy_evol, stages_duration):
    """Plot smoothed per-run accuracy curves, their mean, and stage boundaries.

    Args:
        accuracy_evol: 2-D array; each row is plotted as one curve and the
            row-wise mean (``mean(axis=0)``) is drawn as a thick black line.
            In this notebook rows are independent runs and columns are
            training iterations.
        stages_duration: number of iterations spent in each stage; a red
            vertical line is drawn at the cumulative end of every entry.
    """
    def smooth(x, w=0.01):
        """Causal (trailing) moving average over ``int(1 + w * len(x))`` points.

        For the first ``window - 1`` points only the samples seen so far are
        averaged. Dividing those partial sums by the full window would pull
        the start of every curve toward zero.
        """
        window = int(1 + w * len(x))
        partial_sums = jnp.convolve(x, jnp.ones(window), mode="full")[: len(x)]
        counts = jnp.minimum(jnp.arange(1, len(x) + 1), window)
        return partial_sums / counts

    _, ax = plt.subplots()
    for run_accuracy in accuracy_evol:
        ax.plot(smooth(run_accuracy))

    ax.plot(smooth(accuracy_evol.mean(axis=0)), label="mean", color="black", linewidth=3)
    # Stage boundaries: cumulative end of each stage.
    for stage_end in jnp.cumsum(jnp.array(stages_duration)):
        ax.axvline(float(stage_end), color="red")
    ax.set(xlabel="training iteration", ylabel="accuracy", ylim=(0, 1))
    ax.grid()
    ax.legend()
    plt.show()

In [3]:
def setup_env(map_scale, p, rng_key):
    """Create one random FrozenLake map.

    The argument order (map_scale, p, rng_key) lets the caller vmap over
    ``rng_key`` only. ``map_scale`` and ``p`` are forwarded unchanged to
    ``FrozenLake.make_random``.
    """
    return frozen_lake.FrozenLake.make_random(rng_key, map_scale, p)


def setup_replay_buffer(env, rng_key, n_rollouts, rollout_steps):
    """Fill a circular replay buffer with random-policy rollouts of ``env``.

    Collects ``n_rollouts`` rollouts of ``rollout_steps`` steps each via
    ``utils.multi_random_rollout`` and stores them with
    ``utils.CircularBuffer.store_episodes``.
    """
    # multi_random_rollout takes (steps, count), the reverse of this signature.
    return utils.CircularBuffer.store_episodes(
        utils.multi_random_rollout(env, rng_key, rollout_steps, n_rollouts)
    )


def setup_dql_state(env, rng_key, lr, map_scale, cell_scale):
    """Build a fresh multi-task DQL train state for ``env``.

    The Q-network sees a square map of side ``2**map_scale`` split into
    square cells of side ``2**cell_scale``. The reward and beta functions
    come from ``actions`` for the same ``cell_scale``.
    """
    cell_reward_fn = actions.get_reward_fn(cell_scale)
    cell_beta_fn = actions.get_beta_fn(cell_scale)

    map_side = 2**map_scale
    cell_side = 2**cell_scale
    # NOTE(review): hard-coded command count; confirm it matches the command set in `actions`.
    network = nets.MultiTaskQnet(
        n_actions=env.action_space.n,
        n_comands=5,
        map_shape=(map_side, map_side),
        cell_shape=(cell_side, cell_side),
    )
    return qlearning.MultiDQLTrainState.create(
        rng_key,
        network,
        env,
        reward_fn=cell_reward_fn,
        beta_fn=cell_beta_fn,
        lr=lr,
    )


def eval_policy(env, dql_state, rng_key, episodes, steps):
    """Estimate the greedy policy's reward per episode over parallel rollouts.

    Runs ``episodes`` independent greedy rollouts of ``steps`` steps each
    (vectorised with ``jax.vmap``). Returns the total reward divided by the
    total number of terminated episodes. A rollout in which ``done`` never
    fires still counts as one episode, so the denominator is never zero.

    Args:
        env: environment passed to ``dql_state.greedy_rollout``.
        dql_state: object exposing ``greedy_rollout(env, rng_key, steps)``,
            which returns transitions with ``done`` and ``reward`` arrays.
        rng_key: PRNG key, split into one key per rollout.
        episodes: number of parallel rollouts.
        steps: length of each rollout.

    Returns:
        Scalar array: reward per episode (treated as accuracy by callers).
    """
    def eval_single(single_key):
        transitions = dql_state.greedy_rollout(env, single_key, steps)
        # jnp.maximum instead of jnp.clip(..., a_min=1): same result, and
        # unaffected by the a_min/a_max -> min/max keyword deprecation in JAX.
        n_finished = jnp.maximum(transitions.done.sum(), 1)
        total_reward = transitions.reward.sum()
        return total_reward, n_finished

    rollout_keys = jax.random.split(rng_key, episodes)
    rewards, n_episodes = jax.vmap(eval_single)(rollout_keys)
    return rewards.sum() / n_episodes.sum()


def train_stage(
    rng_key, env, dql_state, replay_buffer, n_train_iter, batch_size, eval_steps, eval_episodes=8
):
    """Run ``n_train_iter`` DQL updates, evaluating the greedy policy after each.

    Each iteration samples a batch of ``batch_size`` from ``replay_buffer``,
    applies ``dql_state.update_params``, then calls ``eval_policy`` with
    ``eval_episodes`` rollouts of ``eval_steps`` steps.

    Args:
        rng_key: PRNG key, split into one key per iteration.
        env: environment used for evaluation rollouts.
        dql_state: initial train state (``jax.lax.scan`` carry).
        replay_buffer: object exposing ``sample(rng_key, batch_size)``.
        n_train_iter: number of update/eval iterations.
        batch_size: replay batch size per update.
        eval_steps: rollout length used for evaluation.
        eval_episodes: number of evaluation rollouts per iteration (default 8).

    Returns:
        ``(dql_state, accuracy_evolution)``: the final state and the
        per-iteration accuracies, stacked by ``lax.scan`` along a leading axis
        of length ``n_train_iter``.
    """
    def train_step(state, step_key):
        rng_train, rng_eval = jax.random.split(step_key)
        batch = replay_buffer.sample(rng_train, batch_size)
        state = state.update_params(batch)
        accuracy = eval_policy(env, state, rng_eval, episodes=eval_episodes, steps=eval_steps)
        return state, accuracy

    step_keys = jax.random.split(rng_key, n_train_iter)
    return jax.lax.scan(train_step, dql_state, step_keys)

In [4]:
def get_objective_fn(
    map_scale,
    cell_scales,
    max_steps,
    p,
    train_minicycles=16,
    n_sims=16,
    batch_size=512,
    rollout_steps=16,
    plot_results=True,
    convergence_threshold=0.9,
):
    """Build an Optuna objective that trains staged DQL agents on random FrozenLake maps.

    Each trial trains ``n_sims`` independent agents in parallel (``jax.vmap``),
    one stage per entry of ``cell_scales``. After each stage the environments
    are wrapped in ``mangoenv.MangoEnv`` with the freshly trained states, so the
    next stage trains on top of them.

    The budget of ``max_steps`` sampled transitions (updates x batch size) is
    split into ``train_minicycles`` cycles of
    ``train_iter = max_steps // batch_size // train_minicycles`` updates. Every
    cycle refills the replay buffers with fresh random rollouts.

    Args:
        map_scale: log2 of the square map's side length.
        cell_scales: log2 cell side length for each training stage, in order.
        max_steps: total number of replay samples budgeted for training.
        p: forwarded to ``FrozenLake.make_random`` (presumably a tile
            probability — TODO confirm).
        train_minicycles: total replay-refill cycles across all stages.
        n_sims: independent runs trained in parallel per trial.
        batch_size: replay batch size per update.
        rollout_steps: length of each random rollout stored in the buffer.
        plot_results: plot the accuracy curves at the end of each trial.
        convergence_threshold: mean cycle accuracy considered "converged".

    Returns:
        ``objective(trial) -> (final_accuracy, convergence_time)``, intended for
        a study with directions ``["maximize", "minimize"]``.

    Raises:
        ValueError: if there are no stages, fewer cycles than stages, or the
            budget leaves no updates per cycle or no rollouts per buffer.
    """
    if not cell_scales:
        raise ValueError("cell_scales must contain at least one stage")
    if train_minicycles < len(cell_scales):
        raise ValueError(
            f"train_minicycles ({train_minicycles}) must be >= number of stages ({len(cell_scales)})"
        )

    # The per-cycle budget depends only on the factory arguments, so compute it once.
    train_iter = max_steps // batch_size // train_minicycles
    n_rollouts = train_iter * batch_size // rollout_steps
    if train_iter < 1 or n_rollouts < 1:
        raise ValueError(
            f"budget too small: train_iter={train_iter}, n_rollouts={n_rollouts} "
            f"(max_steps={max_steps}, batch_size={batch_size}, "
            f"train_minicycles={train_minicycles}, rollout_steps={rollout_steps})"
        )

    multi_env_setup = jax.jit(
        jax.vmap(setup_env, in_axes=(None, None, 0)),
        static_argnames=("map_scale", "p"),
    )
    multi_replay_buffer_setup = jax.jit(
        jax.vmap(setup_replay_buffer, in_axes=(0, 0, None, None)),
        static_argnames=("n_rollouts", "rollout_steps"),
    )
    multi_dql_state_setup = jax.jit(
        jax.vmap(setup_dql_state, in_axes=(0, 0, None, None, None)),
        static_argnames=("lr", "map_scale", "cell_scale"),
    )
    multi_train_stage = jax.jit(
        jax.vmap(train_stage, in_axes=(0, 0, 0, 0, None, None, None)),
        static_argnames=("n_train_iter", "batch_size", "eval_steps"),
    )

    def objective(trial: optuna.Trial) -> tuple[float, int]:
        # attributes
        seed = trial.number
        trial.set_user_attr("max_steps", max_steps)
        trial.set_user_attr("cycle_train_iter", train_iter)
        trial.set_user_attr("batch_size", batch_size)
        trial.set_user_attr("rollout_steps", rollout_steps)

        # setup rng
        rng_key = jax.random.PRNGKey(seed)
        rng_env, rng_stages = jax.random.split(rng_key, 2)
        rng_env = jax.random.split(rng_env, n_sims)
        rng_stages = jax.random.split(rng_stages, len(cell_scales))

        envs = multi_env_setup(map_scale, p, rng_env)

        def run_stage(envs, rng_key, cell_scale, lr, n_cycles, eval_steps):
            """Train fresh agents at ``cell_scale`` for ``n_cycles`` refill cycles."""
            rng_init, rng_cycles = jax.random.split(rng_key)
            rng_init = jax.random.split(rng_init, n_sims)
            rng_cycles = jax.random.split(rng_cycles, n_cycles)

            dql_states = multi_dql_state_setup(envs, rng_init, lr, map_scale, cell_scale)
            stage_accuracy = []
            for rng_cycle in tqdm(rng_cycles, desc=f"Training {cell_scale}"):
                rng_rollout, rng_train = jax.random.split(rng_cycle)
                rng_rollout = jax.random.split(rng_rollout, n_sims)
                rng_train = jax.random.split(rng_train, n_sims)

                # Each cycle trains on freshly collected random-rollout buffers.
                replay_buffers = multi_replay_buffer_setup(
                    envs, rng_rollout, n_rollouts, rollout_steps
                )
                dql_states, accuracy_evol = multi_train_stage(
                    rng_train, envs, dql_states, replay_buffers, train_iter, batch_size, eval_steps
                )
                stage_accuracy.append(accuracy_evol)
            return dql_states, stage_accuracy

        cycles_accuracy = []  # one (n_sims, train_iter) array per cycle
        stage_cycles = []     # cycles spent in each stage
        cycles_remaining = train_minicycles
        prev_scale = 0  # the first stage is measured against single tiles
        for i, (cell_scale, rng_stage) in enumerate(zip(cell_scales, rng_stages)):
            # (side ratio)**2 = previous-scale cells covered by one current-scale cell.
            eval_steps = (2 ** (cell_scale - prev_scale)) ** 2
            lr = trial.suggest_float(f"lr_{i}", 3e-5, 3e-3, log=True)
            if i == len(cell_scales) - 1:
                # The last stage uses whatever cycles are left.
                train_cycles = cycles_remaining
            else:
                train_cycles = trial.suggest_int(
                    f"train_cycles_{i}", 1, train_minicycles // len(cell_scales)
                )
            cycles_remaining -= train_cycles
            stage_cycles.append(train_cycles)

            dql_states, stage_accuracy = run_stage(
                envs, rng_stage, cell_scale, lr, train_cycles, eval_steps
            )

            # setup next stage and store results
            cycles_accuracy.extend(stage_accuracy)
            envs = mangoenv.MangoEnv(envs, dql_states, max_steps=eval_steps)
            prev_scale = cell_scale

        # store and plot results
        accuracy_evol = jnp.concatenate(cycles_accuracy, axis=-1)
        stages_duration = [train_iter * n for n in stage_cycles[:-1]]
        trial.set_user_attr("accuracy_evol_mean", np.asarray(accuracy_evol.mean(axis=0)).tolist())
        trial.set_user_attr("accuracy_evol_min", np.asarray(accuracy_evol.min(axis=0)).tolist())
        trial.set_user_attr("accuracy_evol_max", np.asarray(accuracy_evol.max(axis=0)).tolist())
        trial.set_user_attr("accuracy_evol_std", np.asarray(accuracy_evol.std(axis=0)).tolist())
        if plot_results:
            plot_accuracy_evolution(accuracy_evol, stages_duration)

        # Objectives: mean accuracy of the last cycle, and convergence time.
        # convergence_time is the number of cycles up to and including the last
        # one whose mean accuracy is below the threshold (0 if none was below).
        # It equals the total cycle count if the final cycle has not converged.
        cycle_means = [acc.mean() for acc in cycles_accuracy]
        final_accuracy = cycle_means[-1]
        n_cycles = len(cycle_means)
        if final_accuracy > convergence_threshold:
            convergence_time = next(
                (
                    n_cycles - i
                    for i, acc in enumerate(reversed(cycle_means))
                    if acc < convergence_threshold
                ),
                0,
            )
        else:
            convergence_time = n_cycles
        return float(final_accuracy), int(convergence_time)

    return objective

# Run: 4x4 map, vanilla baseline (single stage, `cell_scales=(2,)`)

In [5]:
map_scale = 2
cell_scales = (2,)
p = 0.8
max_steps = 1024*512

# SQLite does not create missing parent directories ("unable to open database
# file"), so make sure the study directory exists on a fresh checkout.
study_dir = "optuna_studies"
os.makedirs(study_dir, exist_ok=True)
storage_path = f"sqlite:///{study_dir}/{2**map_scale}x{2**map_scale}_p_{p}.db"
study = optuna.create_study(
    study_name=f"mango_stages_{list(cell_scales)}",
    storage=storage_path,
    load_if_exists=True,
    directions=["maximize", "minimize"],  # (final accuracy, convergence time)
)
study.optimize(
    get_objective_fn(map_scale, cell_scales, max_steps, p, train_minicycles=16, plot_results=False),
    n_trials=64,
    show_progress_bar=True,
)

[I 2024-03-03 09:21:12,351] A new study created in RDB with name: mango_stages_[2]


  0%|          | 0/32 [00:00<?, ?it/s]

Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:21:23,572] Trial 0 finished with values: [0.9986236095428467, 6.0] and parameters: {'lr_0': 0.0011903521204309074}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:21:30,544] Trial 1 finished with values: [0.8564881682395935, 16.0] and parameters: {'lr_0': 0.00012887562945618482}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:21:37,889] Trial 2 finished with values: [0.9706429243087769, 10.0] and parameters: {'lr_0': 0.00025754717973078186}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:21:45,482] Trial 3 finished with values: [0.997687041759491, 6.0] and parameters: {'lr_0': 0.00206465928513767}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:21:52,905] Trial 4 finished with values: [0.6089180707931519, 16.0] and parameters: {'lr_0': 3.154265501995509e-05}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:21:59,970] Trial 5 finished with values: [0.9845476746559143, 9.0] and parameters: {'lr_0': 0.0003632574786868855}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:22:07,416] Trial 6 finished with values: [0.9992600679397583, 6.0] and parameters: {'lr_0': 0.0009720073943732756}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:22:14,859] Trial 7 finished with values: [0.8670086860656738, 16.0] and parameters: {'lr_0': 9.295453428121506e-05}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:22:22,166] Trial 8 finished with values: [0.8000866174697876, 16.0] and parameters: {'lr_0': 5.719644118817717e-05}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:22:29,479] Trial 9 finished with values: [0.7803810834884644, 16.0] and parameters: {'lr_0': 5.0189490986677066e-05}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:22:36,771] Trial 10 finished with values: [0.9978077411651611, 6.0] and parameters: {'lr_0': 0.0021919084702805656}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:22:44,334] Trial 11 finished with values: [0.976447343826294, 10.0] and parameters: {'lr_0': 0.00030559586463286225}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:22:51,689] Trial 12 finished with values: [0.6938637495040894, 16.0] and parameters: {'lr_0': 4.045118584580183e-05}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:22:58,998] Trial 13 finished with values: [0.6848653554916382, 16.0] and parameters: {'lr_0': 3.826493843832486e-05}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:23:06,294] Trial 14 finished with values: [0.8816117644309998, 16.0] and parameters: {'lr_0': 0.00015053157920727513}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:23:13,586] Trial 15 finished with values: [0.998837411403656, 6.0] and parameters: {'lr_0': 0.0011588730533008496}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:23:21,098] Trial 16 finished with values: [0.9583842754364014, 11.0] and parameters: {'lr_0': 0.00024034765953453538}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:23:28,397] Trial 17 finished with values: [0.9545530080795288, 11.0] and parameters: {'lr_0': 0.0002468357870782538}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:23:35,735] Trial 18 finished with values: [0.9122848510742188, 14.0] and parameters: {'lr_0': 0.00017955290372898114}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:23:43,203] Trial 19 finished with values: [0.9834376573562622, 9.0] and parameters: {'lr_0': 0.00038054784387343745}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:23:50,553] Trial 20 finished with values: [0.9525682926177979, 11.0] and parameters: {'lr_0': 0.00027032073689527516}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:23:58,211] Trial 21 finished with values: [0.9988300800323486, 6.0] and parameters: {'lr_0': 0.0016149210894557493}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:24:05,658] Trial 22 finished with values: [0.9775094389915466, 9.0] and parameters: {'lr_0': 0.00044808608181186706}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:24:12,900] Trial 23 finished with values: [0.8657180070877075, 16.0] and parameters: {'lr_0': 0.0001541193961439427}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:24:20,344] Trial 24 finished with values: [0.857600212097168, 16.0] and parameters: {'lr_0': 0.00011100365739740085}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:24:27,735] Trial 25 finished with values: [0.849923849105835, 16.0] and parameters: {'lr_0': 8.575387597958821e-05}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:24:35,391] Trial 26 finished with values: [0.9993906021118164, 5.0] and parameters: {'lr_0': 0.0018342223305941097}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:24:42,710] Trial 27 finished with values: [0.9846211075782776, 9.0] and parameters: {'lr_0': 0.0003798407556312993}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:24:50,047] Trial 28 finished with values: [0.9982700943946838, 6.0] and parameters: {'lr_0': 0.0014490789370145058}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:24:57,359] Trial 29 finished with values: [0.7683514356613159, 16.0] and parameters: {'lr_0': 6.265452002017598e-05}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:25:04,399] Trial 30 finished with values: [0.9803440570831299, 10.0] and parameters: {'lr_0': 0.0004662758520392543}. 


Training 2:   0%|          | 0/16 [00:00<?, ?it/s]

[I 2024-03-03 09:25:11,747] Trial 31 finished with values: [0.9996241331100464, 6.0] and parameters: {'lr_0': 0.0010420844066937485}. 


# Run: 4x4 map, mango (staged training, `cell_scales=(1, 2)`)

In [6]:
map_scale = 2
cell_scales = (1,2,)
p = 0.8
max_steps = 1024*512

# SQLite does not create missing parent directories ("unable to open database
# file"), so make sure the study directory exists on a fresh checkout.
study_dir = "optuna_studies"
os.makedirs(study_dir, exist_ok=True)
storage_path = f"sqlite:///{study_dir}/{2**map_scale}x{2**map_scale}_p_{p}.db"
study = optuna.create_study(
    study_name=f"mango_stages_{list(cell_scales)}",
    storage=storage_path,
    load_if_exists=True,
    directions=["maximize", "minimize"],  # (final accuracy, convergence time)
)
study.optimize(
    get_objective_fn(map_scale, cell_scales, max_steps, p, train_minicycles=16, plot_results=False),
    n_trials=128,
    show_progress_bar=True,
)

[I 2024-03-03 09:25:12,469] A new study created in RDB with name: mango_stages_[1, 2]


  0%|          | 0/64 [00:00<?, ?it/s]

Training 1:   0%|          | 0/7 [00:00<?, ?it/s]

Training 2:   0%|          | 0/9 [00:00<?, ?it/s]

[I 2024-03-03 09:25:27,672] Trial 0 finished with values: [1.0, 10.0] and parameters: {'lr_0': 0.00012923550033092608, 'train_cycles_0': 7, 'lr_1': 0.0021889071633619693}. 


Training 1:   0%|          | 0/2 [00:00<?, ?it/s]

Training 2:   0%|          | 0/14 [00:00<?, ?it/s]

[I 2024-03-03 09:25:42,229] Trial 1 finished with values: [0.7007726430892944, 16.0] and parameters: {'lr_0': 0.0002230177188573658, 'train_cycles_0': 2, 'lr_1': 0.00018578440574212648}. 


Training 1:   0%|          | 0/5 [00:00<?, ?it/s]

Training 2:   0%|          | 0/11 [00:00<?, ?it/s]

[I 2024-03-03 09:25:55,735] Trial 2 finished with values: [0.828105628490448, 16.0] and parameters: {'lr_0': 0.00011981713051120457, 'train_cycles_0': 5, 'lr_1': 0.00014374307103657825}. 


Training 1:   0%|          | 0/2 [00:00<?, ?it/s]

Training 2:   0%|          | 0/14 [00:00<?, ?it/s]

[I 2024-03-03 09:26:09,782] Trial 3 finished with values: [0.8213696479797363, 16.0] and parameters: {'lr_0': 0.00010337543891221303, 'train_cycles_0': 2, 'lr_1': 0.0005904249018779899}. 


Training 1:   0%|          | 0/6 [00:00<?, ?it/s]

Training 2:   0%|          | 0/10 [00:00<?, ?it/s]

[I 2024-03-03 09:26:22,573] Trial 4 finished with values: [0.9425994753837585, 11.0] and parameters: {'lr_0': 4.3208695517972565e-05, 'train_cycles_0': 6, 'lr_1': 0.0013804001434584055}. 


Training 1:   0%|          | 0/7 [00:00<?, ?it/s]

Training 2:   0%|          | 0/9 [00:00<?, ?it/s]

[I 2024-03-03 09:26:35,203] Trial 5 finished with values: [0.7724594473838806, 16.0] and parameters: {'lr_0': 8.508721568446896e-05, 'train_cycles_0': 7, 'lr_1': 0.0001028460358632276}. 


Training 1:   0%|          | 0/2 [00:00<?, ?it/s]

Training 2:   0%|          | 0/14 [00:00<?, ?it/s]

# Run: 8x8 map, vanilla baseline (single stage, `cell_scales=(3,)`)

In [None]:
map_scale = 3
cell_scales = (3,)
p = 0.8
max_steps = 1024*1024*4

# SQLite does not create missing parent directories ("unable to open database
# file"), so make sure the study directory exists on a fresh checkout.
study_dir = "optuna_studies"
os.makedirs(study_dir, exist_ok=True)
storage_path = f"sqlite:///{study_dir}/{2**map_scale}x{2**map_scale}_p_{p}.db"
study = optuna.create_study(
    study_name=f"mango_stages_{list(cell_scales)}",
    storage=storage_path,
    load_if_exists=True,
    directions=["maximize", "minimize"],  # (final accuracy, convergence time)
)
study.optimize(
    get_objective_fn(map_scale, cell_scales, max_steps, p, train_minicycles=16),
    n_trials=32,
    show_progress_bar=True,
)

[I 2024-03-03 09:18:50,743] A new study created in RDB with name: mango_stages_[3]


  0%|          | 0/32 [00:00<?, ?it/s]

Training 3:   0%|          | 0/16 [00:00<?, ?it/s]

# Run: 8x8 map, mango (staged training, `cell_scales=(1, 2, 3)`)

In [None]:
map_scale = 3
cell_scales = (1,2,3)
p = 0.8
max_steps = 1024*1024*4

# SQLite does not create missing parent directories ("unable to open database
# file"), so make sure the study directory exists on a fresh checkout.
study_dir = "optuna_studies"
os.makedirs(study_dir, exist_ok=True)
storage_path = f"sqlite:///{study_dir}/{2**map_scale}x{2**map_scale}_p_{p}.db"
study = optuna.create_study(
    study_name=f"mango_stages_{list(cell_scales)}",
    storage=storage_path,
    load_if_exists=True,
    directions=["maximize", "minimize"],  # (final accuracy, convergence time)
)
study.optimize(
    get_objective_fn(map_scale, cell_scales, max_steps, p, train_minicycles=16),
    n_trials=128,
    show_progress_bar=True,
)

[I 2024-03-03 09:09:55,975] A new study created in RDB with name: mango_stages_[1, 2, 3]


  0%|          | 0/128 [00:00<?, ?it/s]

Training 1:   0%|          | 0/5 [00:00<?, ?it/s]