In [1]:
import os

# Both variables are set here, before `jax` is imported in the following cell.
os.environ.update(
    {
        # GPU ids exposed to this process.
        'CUDA_VISIBLE_DEVICES': '3, 4, 6',
        # Fraction of GPU memory the XLA client may preallocate.
        'XLA_PYTHON_CLIENT_MEM_FRACTION': '.87',
    }
)

In [2]:
import jax
import jax.numpy as jnp
import flax
import matplotlib.pyplot as plt
import flashbax as fbx
import chex
import warnings
import jumanji
from jumanji.wrappers import AutoResetWrapper
import rlax
from tqdm import tqdm
from IPython.display import clear_output

%matplotlib inline
jax.device_count(), jax.devices()

(3, [cuda(id=0), cuda(id=1), cuda(id=2)])

In [3]:
import utils
import anakin_v3 as anakin

In [4]:
# --- Training configuration -------------------------------------------------
SEED = 42                  # passed to utils.setup_experiment
LEARNING_RATE = 5e-4       # passed to utils.setup_experiment
N_ENVS = 100               # passed to utils.broadcast_to_pv_shape
GAMMA = 0.99               # `gamma` argument of anakin.get_learner_fn
N_ITERATIONS = 100         # iteration-count argument of anakin.get_learner_fn
COEFS = (1.0, 1.0, 0.01)   # `coefs` argument of the learner; presumably loss weights — TODO confirm

# --- Evaluation configuration -----------------------------------------------
TRAINING_EVAL_ITERS = 100  # number of learn-then-evaluate rounds in the training loop
NUM_EVAL_EPISODES = 50     # episodes averaged per call to `eval`

In [5]:
# Jumanji Tetris environment; used below both to build the learner and to roll
# out evaluation episodes.
env = jumanji.make("Tetris-v0")

In [6]:
# Build the networks, initial parameters, optimiser and PRNG key for `env`.
(
    n_devices,
    actor_fn,
    critic_fn,
    params,
    optim,
    opt_state,
    rng,
) = utils.setup_experiment(env, SEED, LEARNING_RATE)

# Reshape params / optimiser state / keys for the pmap-over-vmap learner below;
# presumably this adds leading (n_devices, N_ENVS) axes — TODO confirm in utils.
params, opt_state, rngs_pv, rng = utils.broadcast_to_pv_shape(
    n_devices, N_ENVS, params, opt_state, rng
)

# Learner mapped over the "envs" axis (vmap) and then over the "devices" axis
# (pmap). NOTE: the `n_iterarions` keyword is spelled exactly as in the original
# call; it presumably matches the signature in anakin_v3 — verify before renaming.
learn_fn = jax.pmap(
    jax.vmap(
        anakin.get_learner_fn(
            env,
            gamma=GAMMA,
            n_iterarions=N_ITERATIONS,
            coefs=COEFS,
            actor_fn=actor_fn,
            critic_fn=critic_fn,
            opt_update_fn=optim.update,
        ),
        axis_name="envs",
    ),
    axis_name="devices",
)

AttributeError: 'Observation' object has no attribute 'tetromino'

In [7]:
@jax.jit
def eval_one_episode(params, rng):
    """Roll out one episode with the greedy (argmax) policy and return its total reward.

    The signature follows the `jax.lax.scan` body convention
    `(carry, x) -> (carry, y)` so that `eval` can scan it over a batch of keys.

    Args:
        params: Parameters passed as the first argument to `actor_fn`.
        rng: PRNG key used to reset the environment. Action selection is
            deterministic (argmax), so no further randomness is consumed.

    Returns:
        A `(params, total_r)` tuple: the unchanged `params` and the sum of the
        rewards collected until the environment reports the last step.
    """
    state, timestep = env.reset(rng)

    def greedy_action(logits, action_mask):
        """Return the [row, column] index of the best action allowed by the mask."""
        masked_logits = utils.masked_fill(action_mask, logits, -jnp.inf)
        num_columns = masked_logits.shape[1]
        # argmax over the flattened logits, then unravel back to a 2-D index.
        flat_index = jnp.argmax(masked_logits.reshape(-1))
        return jnp.stack([flat_index // num_columns, flat_index % num_columns])

    def not_done(carry):
        """Keep looping until the most recent timestep is the last of the episode."""
        _, current_timestep, _ = carry
        return ~current_timestep.last()

    def step_fn(carry):
        """Pick the greedy action, step the environment and accumulate reward."""
        current_state, current_timestep, total_r = carry
        observation = current_timestep.observation

        # Axes 0 and 3 are added before calling the actor; assumes 2-D
        # observations becoming (1, H, W, 1) — TODO confirm against actor_fn.
        logits = actor_fn(
            params,
            jnp.expand_dims(observation.grid, (0, 3)),
            jnp.expand_dims(observation.tetromino, (0, 3)),
        )
        action = greedy_action(logits.squeeze(0), observation.action_mask)

        next_state, next_timestep = env.step(current_state, action)
        return next_state, next_timestep, total_r + next_timestep.reward

    # `params` is loop-invariant, so it is closed over rather than carried.
    # The accumulator takes the reward's own dtype/shape so the carry type is
    # stable without relying on weak-type promotion of a Python int.
    initial_carry = (state, timestep, jnp.zeros_like(timestep.reward))
    _, _, total_r = jax.lax.while_loop(not_done, step_fn, initial_carry)
    return params, total_r

@jax.jit
def eval(params, rng):
  """Evaluate the greedy policy over NUM_EVAL_EPISODES episodes.

  Args:
    params: Parameter pytree; index 0 along each of the first two axes of every
      leaf is taken as the single copy used for evaluation.
    rng: PRNG key, split into one key per evaluation episode.

  Returns:
    A `(mean, variance)` tuple of the per-episode total rewards.
  """
  episode_keys = jax.random.split(rng, NUM_EVAL_EPISODES)
  single_params = jax.tree.map(lambda leaf: leaf[0][0], params)
  # Episodes run one after another; `eval_one_episode` hands the params back
  # unchanged as the scan carry.
  _, episode_returns = jax.lax.scan(eval_one_episode, single_params, episode_keys)
  mean_return = jnp.mean(episode_returns)
  return_variance = jnp.var(episode_returns)
  return mean_return, return_variance

In [None]:
# Alternate between one learner call and one greedy evaluation, redrawing the
# learning curve after every round.
avg_reward = []
# `iteration` rather than `iter`, so the builtin is not shadowed in the kernel.
for iteration in tqdm(range(TRAINING_EVAL_ITERS)):
    params, opt_state, rngs_pv = learn_fn(params, opt_state, rngs_pv)

    rng, rng_eval = jax.random.split(rng, 2)
    total_r, total_r_var = eval(params, rng_eval)
    # Store a host-side float so the history does not hold on to device arrays.
    avg_reward.append(float(total_r))

    clear_output(True)
    print(f"Mean Reward at iteration {iteration}: {total_r}")
    # NOTE: `total_r_var` holds the across-episode variance for this round; it
    # is not plotted at the moment.

    plt.plot(avg_reward)
    plt.title("Average reward each iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Reward")
    plt.show()


In [None]:
# states = []
# rng, rng_gif = jax.random.split(rng, 2)
# state, timestep = env.reset(rng_gif)
# params = jax.tree.map(lambda x: jnp.mean(x, axis=(0, 1)), params_state.online)

# for i in tqdm(range(500)):
#     states.append(state)
#     obs = timestep.observation.board
#     action_mask = timestep.observation.action_mask
#     q_values = network.apply(params, jnp.expand_dims(obs, (0, 3))).squeeze(0)
#     q_values = utils.masked_fill(action_mask, q_values, -jnp.inf)
#     action = jnp.argmax(q_values)
#     state, timestep = env.step(state, action)
#     if timestep.last():
#         break

# env.animate(states, interval=150).save("./2048.gif")

In [None]:
# Every leaf of `params` carries two leading replica axes (devices, envs) — the
# same layout `eval` unpacks with x[0][0] — so divide them out to count the
# parameters of a single model copy rather than of all replicas.
param_count = sum(
    x.size // (x.shape[0] * x.shape[1]) for x in jax.tree.leaves(params)
)
param_count

In [None]:
# Parameter count in whole millions, derived from `param_count` instead of a
# pasted literal that goes stale when the model changes.
param_count // 10 ** 6