In [None]:
import os

# Restrict JAX to physical GPUs 3, 4 and 6. This must be set before JAX
# initialises its GPU backend (jax is imported in the next cell).
# CUDA expects a plain comma-separated list of ids, so no spaces after commas.
os.environ['CUDA_VISIBLE_DEVICES'] = '3,4,6'

In [None]:
import jax
import jax.numpy as jnp
import flax
import matplotlib.pyplot as plt
import flashbax as fbx
import chex
import warnings
import jumanji
from jumanji.wrappers import AutoResetWrapper
import rlax
from tqdm import tqdm
from IPython.display import clear_output

%matplotlib inline
# Sanity check: display how many devices JAX can see, and which ones, after
# the CUDA_VISIBLE_DEVICES mask set in the first cell.
jax.device_count(), jax.devices()

In [None]:
import utils
import anakin
import dqn_v2

In [None]:
# --- Outer loop ---------------------------------------------------------------
TRAINING_EVAL_ITERS = 25  # number of (train, evaluate) rounds in the main loop

# --- Training hyperparameters -------------------------------------------------
SEED = 42
N_ENVS = 8  # environments per device (batch is n_devices * N_ENVS)
BATCH_SIZE = 32
LEARNING_RATE = 6.25e-5
BUFFER_SIZE = 10_000
ROLLOUT_LEN = 512
OPTIM_UPDATE_LEN = 64
N_ITERATIONS = 20
UPDATE_PERIOD = 10
GAMMA = 0.99  # discount factor

# Epsilon-greedy exploration settings, handed to utils.setup_experiment.
START_EPSILON, END_EPSILON, STEPS_EPSILON = 1.0, 0.1, 10_000

# --- Evaluation hyperparameters -----------------------------------------------
NUM_EVAL_EPISODES = 50
MAX_EVAL_ITERS = 1000

In [None]:
# Raw 2048 environment: used as-is for evaluation and GIF rendering.
env = jumanji.make("Game2048-v1")
# Training copy wrapped with Jumanji's AutoResetWrapper so finished episodes
# are reset automatically inside rollouts.
training_env = AutoResetWrapper(env=env)

In [None]:
# Build the network, optimiser, replay buffer and epsilon schedule in one call.
# The returned `rng` is the key stream used by the following cells.
(
    n_devices,
    network,
    params,
    optim,
    opt_state,
    buffer,
    buffer_state,
    epsilon_schedule_fn,
    rng,
) = utils.setup_experiment(
    env=training_env,
    seed=SEED,
    batch_size=BATCH_SIZE,
    buffer_size=BUFFER_SIZE,
    learning_rate=LEARNING_RATE,
    start_epsilon=START_EPSILON,
    end_epsilon=END_EPSILON,
    steps_epsilon=STEPS_EPSILON,
)

In [None]:
# Split off one reset key per (device, env) slot; `rng` keeps the remaining stream.
rng, *rngs_env = jax.random.split(rng, 1 + n_devices * N_ENVS)
states, timesteps = jax.vmap(env.reset)(jnp.stack(rngs_env))


def reshape(x):
    """Fold the flat environment batch axis into (n_devices, N_ENVS) for pmap."""
    return x.reshape((n_devices, N_ENVS, *x.shape[1:]))


states, timesteps = jax.tree.map(reshape, (states, timesteps))

# Give params / optimiser / buffer state the same per-device, per-env layout.
params_state, opt_state, buffer_state, rngs_pv, rng = utils.broadcast_to_pv_shape(
    n_devices, N_ENVS, params, opt_state, buffer_state, rng
)

# Anakin-style learner: one call runs rollouts and optimiser updates on every
# device in parallel; 'i' is the pmap axis name.
learn_fn = jax.pmap(
    anakin.get_learner_fn(
        env=training_env,
        forward_fn=network.apply,
        opt_update_fn=optim.update,
        buffer=buffer,
        epsilon_schedule_fn=epsilon_schedule_fn,
        gamma=GAMMA,
        rollout_len=ROLLOUT_LEN,
        n_iterations=N_ITERATIONS,
        update_period=UPDATE_PERIOD,
        optim_update_len=OPTIM_UPDATE_LEN,
    ),
    axis_name='i',
)

In [None]:
@jax.jit
def eval_one_episode(params, rng):
  """Play one greedy episode of `env` and return its undiscounted return.

  Shaped as a `jax.lax.scan` body: `params` is the carry (returned unchanged)
  and `rng` is the per-episode key used only to reset the environment.
  Actions are argmax over the masked q-values, so the rollout is deterministic
  given the reset state.

  The loop stops at the first terminal timestep or after MAX_EVAL_ITERS steps,
  whichever comes first. The jitted while_loop is therefore guaranteed to
  terminate even if the episode never reaches a terminal state. Episodes
  longer than MAX_EVAL_ITERS are truncated, which under-reports their return.

  Returns:
    (params, total_r): the unchanged params carry and the float32 sum of rewards.
  """
  state, timestep = env.reset(rng)

  def cond_fn(val):
    _, _, _, _, done, step = val
    # Continue while the episode is live and the step budget is not exhausted.
    return jnp.logical_and(jnp.logical_not(done), step < MAX_EVAL_ITERS)

  def step_fn(val):
    params, state, timestep, total_r, done, step = val
    obs = timestep.observation.board
    action_mask = timestep.observation.action_mask
    # Add a leading batch axis and a trailing (presumably channel) axis for the
    # network, then drop the batch axis. Assumes `board` is 2-D — TODO confirm.
    q_values = network.apply(params, jnp.expand_dims(obs, (0, 3))).squeeze(0)
    # Illegal moves get -inf so argmax only picks among allowed actions.
    q_values = utils.masked_fill(action_mask, q_values, -jnp.inf)
    action = jnp.argmax(q_values)
    next_state, next_timestep = env.step(state, action)
    # Accumulate in float32 so the carry dtype never changes between iterations.
    total_r = total_r + next_timestep.reward.astype(jnp.float32)
    return (params, next_state, next_timestep, total_r, next_timestep.last(), step + 1)

  init = (
      params,
      state,
      timestep,
      jnp.float32(0.0),      # running episode return
      timestep.last(),       # done flag (False right after reset)
      jnp.int32(0),          # steps taken so far
  )
  params, _, _, total_r, _, _ = jax.lax.while_loop(cond_fn, step_fn, init)
  return params, total_r

@jax.jit
def eval(params, rng):
  """Greedy-evaluate one copy of the online params over NUM_EVAL_EPISODES episodes.

  `params` is expected to carry leading (device, env) axes, as produced by
  utils.broadcast_to_pv_shape; only slot [0, 0] is evaluated.

  Returns:
    (mean, variance) of the per-episode returns. The variance is the
    population variance (ddof=0).
  """
  single_params = jax.tree.map(lambda leaf: leaf[0, 0], params)
  episode_keys = jax.random.split(rng, NUM_EVAL_EPISODES)
  # Episodes run sequentially; the params carry passes through unchanged.
  _, returns = jax.lax.scan(eval_one_episode, single_params, episode_keys)
  return returns.mean(), returns.var()

In [None]:
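# Optional: parameter count of ONE copy of the online network. Leaves of
# params_state.online carry leading (device, env) axes (the same axes eval()
# strips with x[0][0]), so index [0, 0] to avoid counting every replica.
# param_count = sum(x[0, 0].size for x in jax.tree.leaves(params_state.online))
# param_count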

In [None]:
# Alternate training (pmapped learn_fn) with greedy evaluation, re-plotting the
# evaluation curve after every round.
avg_reward = []          # mean evaluation return per round
bounds_avg_reward = []   # variance of evaluation returns per round (plotted as ±1 std)

for it in tqdm(range(TRAINING_EVAL_ITERS)):
    # Train
    params_state, opt_state, buffer_state, states, timesteps, rngs_pv = learn_fn(
        params_state, opt_state, buffer_state, states, timesteps, rngs_pv
    )

    # Eval
    rng, rng_eval = jax.random.split(rng, 2)
    total_r, total_r_var = eval(params_state.online, rng_eval)
    avg_reward.append(total_r)
    bounds_avg_reward.append(total_r_var)

    clear_output(True)
    # eval() returns the variance across episodes, so its square root is the
    # standard deviation of the episode return, not the standard error of the mean.
    print(f"Mean Reward at iteration {it}: {total_r}")
    print(f"Std of reward at iteration {it}: {jnp.sqrt(total_r_var)}")
    print(f"Epsilon at iteration {it}: {epsilon_schedule_fn(params_state.update_count)[0][0]}")

    mean_curve = jnp.array(avg_reward)
    std_curve = jnp.sqrt(jnp.array(bounds_avg_reward))
    fig, ax = plt.subplots()
    ax.plot(mean_curve, label="mean")
    ax.plot(mean_curve + std_curve, linestyle='dashed', color='red', label="±1 std")
    ax.plot(mean_curve - std_curve, linestyle='dashed', color='red')
    ax.set(title="Average reward each iteration", xlabel="Iteration", ylabel="Reward")
    ax.legend()
    plt.show()
    plt.close(fig)

In [None]:
# Roll out one greedy episode with the trained online params and save it as a GIF.
# Frames go into `episode_states` and params into `gif_params` on purpose:
# `states` holds the batched training env states consumed by learn_fn, and
# `params` holds the initial params broadcast in the setup cell. Overwriting
# either would break a re-run of those cells.
GIF_MAX_STEPS = 500
GIF_PATH = "./2048.gif"

rng, rng_gif = jax.random.split(rng, 2)
state, timestep = env.reset(rng_gif)
# Single copy of the online params: drop the leading (device, env) axes.
gif_params = jax.tree.map(lambda x: x[0][0], params_state.online)
env_step = jax.jit(env.step)  # compile once instead of tracing eagerly every step

episode_states = [state]
for _ in tqdm(range(GIF_MAX_STEPS)):
    obs = timestep.observation.board
    action_mask = timestep.observation.action_mask
    q_values = network.apply(gif_params, jnp.expand_dims(obs, (0, 3))).squeeze(0)
    q_values = utils.masked_fill(action_mask, q_values, -jnp.inf)
    action = jnp.argmax(q_values)
    state, timestep = env_step(state, action)
    # Record every post-step state so the final (game-over) board is included.
    episode_states.append(state)
    if timestep.last():
        break

env.animate(episode_states, interval=150).save(GIF_PATH)