In [1]:
import os
# Expose only GPU 0 to the CUDA runtime. This must happen before JAX initialises its
# GPU backend, which is why it lives in the very first cell, ahead of `import jax`.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import sims

In [3]:
seed = 42
map_size = 8
max_interactions = 10_000
# floor(log2(map_size)), computed on the host with NumPy exactly as setup_simulation
# sizes its ConvNet (avoids a device dispatch and a float32 round-trip via jnp).
map_scale = int(np.log2(map_size))

# hyperparams
lr = 1e-3
rollout_length = 2 * map_size**2  # forwarded to sims.q_learning_step as rollout_length
# Integer division: the remainder of the interaction budget is dropped
# (10_000 // 128 == 78 rollouts, i.e. 9_984 interactions).
n_rollouts = max_interactions // rollout_length
train_iter = 8

# A budget smaller than one rollout would give n_rollouts == 0 and an empty scan later on.
assert n_rollouts > 0, "max_interactions must cover at least one rollout"

In [4]:
import utils
import qlearning
import replay

@partial(jax.jit, static_argnames=("map_size", "lr", "replay_capacity"))
def setup_simulation(rng_key, map_size, lr, replay_capacity):
    """Build the initial simulation state for DQL training on a random FrozenLake map.

    Args:
        rng_key: JAX PRNG key. It is split into one independent subkey per random
            consumer (map generation, network init, env reset, action sample), so no
            key is ever used twice.
        map_size: side length of the square map. Static under ``jax.jit``.
        lr: learning rate forwarded to ``qlearning.DQLTrainState.create``. Static.
        replay_capacity: capacity forwarded to ``replay.CircularBuffer.create``. Static.

    Returns:
        ``sims.SimulationState(env, dql_state, replay_memory)``.
    """
    # Give every random consumer its own subkey. Reusing ``rng_key`` after splitting
    # it would correlate the reset/action randomness with the map and network init.
    rng_env, rng_dql, rng_reset, rng_action = jax.random.split(rng_key, 4)

    # A fixed layout is also available via utils.FrozenLake.make_preset(rng_env, (map_size, map_size)).
    # NOTE(review): the meaning of 0.8 is defined by make_random (presumably the
    # fraction of frozen tiles) — confirm in utils.
    env = utils.FrozenLake.make_random(rng_env, (map_size, map_size), 0.8)

    env_state, obs = env.reset(rng_reset)
    action = env.action_space.sample(rng_action)
    # Placeholder transition handed to CircularBuffer.create, presumably as the
    # structure/dtype template for stored entries — confirm in replay.
    # NOTE(review): positional fields presumably (state, obs, action, reward,
    # next_obs, done, info) — verify against utils.Transition.
    sample_transition = utils.Transition(env_state, obs, action, 0.0, obs, False, {})

    # `hidden` has floor(log2(map_size)) entries, each of width 2 * map_size.
    # NOTE(review): out=4 presumably means one Q-value per FrozenLake action — confirm.
    qnet = utils.ConvNet(hidden=[2 * map_size] * int(np.log2(map_size)), out=4)
    dql_state = qlearning.DQLTrainState.create(rng_dql, qnet, obs, lr)
    replay_memory = replay.CircularBuffer.create(sample_transition, replay_capacity)
    return sims.SimulationState(env, dql_state, replay_memory)

In [5]:
rng_key = jax.random.PRNGKey(seed)
rng_init, rng_sim = jax.random.split(rng_key)
sim_state = setup_simulation(rng_key, map_size, lr, replay_capacity=1024)
%timeit setup_simulation(rng_key, map_size, lr, replay_capacity=1024)[0].frozen.block_until_ready()

q_learning_step =partial(sims.q_learning_step, rollout_length=rollout_length, train_iter=train_iter)
sim_state, results = q_learning_step(sim_state, rng_sim)
%timeit q_learning_step(setup_simulation(rng_key, map_size, lr, 1024), rng_sim)[0][0].frozen.block_until_ready()

rng_sim = jax.random.split(rng_sim, n_rollouts)
q_learning_step =partial(sims.q_learning_step, rollout_length=rollout_length, train_iter=train_iter)
scan = jax.jit(lambda rng: jax.lax.scan(q_learning_step, sim_state, rng))
%timeit scan(rng_sim)[0][0].frozen.block_until_ready()