# MDP

In [None]:
import popgym_arcade
import matplotlib.pyplot as plt
import jax
import imageio
import equinox as eqx
from popgym_arcade.baselines.model.builder import QNetwork
# Fully observable rollout: build the environment with partial_obs=False,
# load QNetwork weights from disk, act greedily for 500 steps, save a GIF.
env_name = "TetrisEasy"
env, env_params = popgym_arcade.make(env_name, obs_size=128, partial_obs=False)

key = jax.random.PRNGKey(0)
key, re_key = jax.random.split(key)


def vmap_reset(n_envs):
    """Return a function that resets `n_envs` environments from a single key."""

    def reset(rng):
        # One sub-key per environment; env_params is shared (not mapped).
        rngs = jax.random.split(rng, n_envs)
        return jax.vmap(env.reset, in_axes=(0, None))(rngs, env_params)

    return reset


def vmap_step(n_envs):
    """Return a function that steps `n_envs` environments from a single key."""

    def step(rng, env_state, action):
        # Keys, states and actions are mapped over axis 0; env_params is shared.
        rngs = jax.random.split(rng, n_envs)
        batched_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
        return batched_step(rngs, env_state, action, env_params)

    return step


obs, state = vmap_reset(1)(re_key)

# The freshly built QNetwork only serves as the template whose leaves are
# replaced by the values stored in the checkpoint file.
model_name = "PQN_{}_model_Partial=False_SEED=0.pkl".format(env_name)
key, model_key = jax.random.split(key)
qnet = QNetwork(model_key, 128)
model = eqx.tree_deserialise_leaves(model_name, qnet)

step_env = vmap_step(1)
frames = []
done = False
for _ in range(500):
    # Greedy action: index of the largest Q-value along the last axis.
    q_val = model(obs)
    action = q_val.argmax(axis=-1)
    key, step_key = jax.random.split(key)
    obs, state, reward, done, info = step_env(step_key, state, action)
    # Drop the size-1 axes, scale by 255 and store as uint8 for the GIF writer.
    frame = (obs.squeeze() * 255).astype(jax.numpy.uint8)
    frames.append(frame)

imageio.mimwrite("MLP_MDP_{}.gif".format(env_name), frames, fps=4)

# POMDP

In [None]:
import popgym_arcade
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np
import imageio
import equinox as eqx
from popgym_arcade.baselines.model import QNetworkRNN, add_batch_dim

# Partially observable rollout setup: environment, batched reset/step helpers,
# recurrent Q-network loaded from disk, and the initial loop carry.
env_name = "BreakoutEasy"
env, env_params = popgym_arcade.make(env_name, obs_size=128, partial_obs=True)


def vmap_reset(n_envs):
    """Return a function that resets `n_envs` environments from a single key."""

    def reset(rng):
        # One sub-key per environment; env_params is shared (not mapped).
        rngs = jax.random.split(rng, n_envs)
        return jax.vmap(env.reset, in_axes=(0, None))(rngs, env_params)

    return reset


def vmap_step(n_envs):
    """Return a function that steps `n_envs` environments from a single key."""

    def step(rng, env_state, action):
        # Keys, states and actions are mapped over axis 0; env_params is shared.
        rngs = jax.random.split(rng, n_envs)
        batched_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
        return batched_step(rngs, env_state, action, env_params)

    return step


seed = jax.random.PRNGKey(0)
seed, _rng = jax.random.split(seed)

# The freshly built QNetworkRNN only serves as the template whose leaves are
# replaced by the values stored in the checkpoint file.
model_name = "PQN_RNN_lru_{}_model_Partial=True_SEED=0.pkl".format(env_name)
seed, model_key = jax.random.split(seed)
qnetrnn = QNetworkRNN(model_key, 128)
model = eqx.tree_deserialise_leaves(model_name, qnetrnn)

# NOTE(review): `_rng` is consumed three times below (reset, initialize_carry,
# and as the rollout key stored in the carry) -- confirm this reuse is intended.
obs, state = vmap_reset(2)(_rng)
init_done = jnp.zeros((2,), dtype=bool)
init_action = jnp.zeros((2,), dtype=int)
init_hs = model.initialize_carry(key=_rng)
hs = add_batch_dim(init_hs, 2)

# Buffer with one slot per rollout step, shaped like the first env's observation.
frame_shape = obs[0].shape
frames = jnp.zeros((500,) + tuple(frame_shape), dtype=jnp.float32)

carry = (hs, obs, init_done, init_action, state, frames, _rng)
def evaluate_step(carry, i):
    """Advance the batched rollout by one step and record one frame.

    Args:
        carry: Tuple ``(hs, obs, done, action, state, frames, _rng)`` holding
            the recurrent state, the latest observations, done flags and
            actions, the environment state, the frame buffer and the PRNG key.
        i: Index of the slot in ``frames`` that receives this step's frame.

    Returns:
        ``(carry, reward)``: the updated carry in the same layout as the
        input, and the reward returned by the environment step.
    """
    hs, obs, done, action, state, frames, _rng = carry
    _rng, rng_step = jax.random.split(_rng, 2)

    # The model is called with an extra leading axis of length 1 on each
    # input; that axis is squeezed out of the Q-values again below.
    obs_batch = obs[jnp.newaxis, :]
    done_batch = done[jnp.newaxis, :]
    action_batch = action[jnp.newaxis, :]
    hs, q_val = model(hs, obs_batch, done_batch, action_batch)
    q_val = jax.lax.stop_gradient(q_val)
    q_val = q_val.squeeze(axis=0)

    # Greedy action: index of the largest Q-value along the last axis.
    action = jnp.argmax(q_val, axis=-1)
    obs, state, reward, done, info = vmap_step(2)(rng_step, state, action)

    # jax.debug.print requires a format string as its first argument; the
    # array to display is passed as a formatting argument.
    jax.debug.print("{obs}", obs=obs)

    # Only the first environment's observation is recorded, scaled by 255.
    frame = jnp.asarray(obs[0]) * 255
    frames = frames.at[i].set(frame)

    carry = (hs, obs, done, action, state, frames, _rng)
    return carry, reward

def body_fun(i, carry):
    """Run `evaluate_step` for index `i` and return only the updated carry.

    `evaluate_step` takes ``(carry, i)`` and returns ``(carry, reward)``;
    this adapter swaps the argument order and discards the reward.
    """
    next_carry, _reward = evaluate_step(carry, i)
    return next_carry

# Run 500 iterations of body_fun, then take the frame buffer and the PRNG key
# from the last two slots of the final carry.
carry = jax.lax.fori_loop(0, 500, body_fun, carry)
frames = carry[5]
_rng = carry[6]

# Move the frames to a NumPy array and cast to uint8 before writing the GIF.
frames = np.asarray(frames).astype(np.uint8)
imageio.mimwrite("LRU_POMDP_{}.gif".format(env_name), frames, fps=30)

step: 1
step: 1
step: 2
step: 2
step: 3
step: 3
step: 4
step: 4
step: 5
step: 5
step: 6
step: 6
step: 7
step: 7
step: 8
step: 8
step: 9
step: 9
step: 10
step: 10
step: 11
step: 11
step: 12
step: 12
step: 13
step: 13
step: 14
step: 14
step: 15
step: 15
step: 16
step: 16
step: 17
step: 17
step: 18
step: 18
step: 19
step: 19
step: 20
step: 1
step: 21
step: 2
step: 22
step: 3
step: 23
step: 4
step: 24
step: 5
step: 25
step: 6
step: 26
step: 7
step: 8
step: 27
step: 9
step: 28
step: 10
step: 29
step: 11
step: 30
step: 31
step: 12
step: 32
step: 13
step: 33
step: 14
step: 15
step: 34
step: 35
step: 16
step: 17
step: 36
step: 37
step: 18
step: 19
step: 38
step: 1
step: 39
step: 40
step: 2
step: 41
step: 3
step: 42
step: 4
step: 5
step: 43
step: 1
step: 6
step: 7
step: 2
step: 8
step: 3
step: 4
step: 9
step: 10
step: 5
step: 11
step: 6
step: 7
step: 12
step: 13
step: 8
step: 14
step: 9
step: 10
step: 15
step: 16
step: 11
step: 12
step: 17
step: 13
step: 18
step: 14
step: 19
step: 20
step: 15
s