In [6]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import jumanji
from jumanji.wrappers import AutoResetWrapper
import dm_env
from tqdm import tqdm

In [7]:
def get_observation_element(observation_spec, observation, key):
    """
    Helper to return a specific observation element based on the observation_spec structure.

    Args:
        observation_spec (dict): The dictionary specifying the observation structure.
        observation (tuple or list): The actual observation values in the same order as observation_spec.
        key (str): The name of the element to extract.

    Returns:
        The corresponding element from observation.
    """
    keys = list(observation_spec.keys())  # Get the ordered keys from observation_spec
    index = keys.index(key)  # Find the index of the requested key
    return observation[index]  # Return the corresponding value from observation


def plot_learning_curve(list_of_episode_returns):
    """Plot the learning curve."""
    plt.figure(figsize=(7, 5))

    def moving_average(x, w):
        return np.convolve(x, np.ones(w), "valid") / w

    smoothed_returns = moving_average(list_of_episode_returns, 30)
    plt.plot(smoothed_returns)

    plt.xlabel("Average episode returns")
    plt.xlabel("Number of episodes")

    ax = plt.gca()
    ax.spines["left"].set_visible(True)
    ax.spines["bottom"].set_visible(True)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("left")

In [8]:
@jax.jit
def softmax_policy(parameters, key, obs):
    """Sample action from a softmax policy."""
    _, p = network(parameters, obs)
    return jax.random.categorical(key, p)


def network(params, observation):
    # Implement forward pass here
    w = params["w"]
    b = params["b"]
    # these are theta (vector) -- policy fn params
    w_p = params["w_p"]
    b_p = params["b_p"]
    # these are w (vector) -- value fn params (just overloaded notation bc they're linear params so hence the w name for weights)
    w_v = params["w_v"]
    b_v = params["b_v"]

    flat_observation = observation.flatten()
    h = jax.nn.relu(flat_observation.dot(w) + b)  # hidden representation
    # print(flat_observation.shape,w.shape,h.shape)
    p = h.dot(w_p) + b_p
    # print(h.shape,w_p.shape,p.shape)
    v = h.dot(w_v) + b_v
    return v, p


def create_parameters(rng_key, observation):
    # Returns a dictionary with the desired parameters for the network
    params = {}
    rng_key, param_key = jax.random.split(rng_key)
    params["w"] = jax.random.truncated_normal(
        param_key, -1, 1, (jnp.size(observation), jnp.size(observation))
    )
    rng_key, param_key = jax.random.split(rng_key)
    params["b"] = jax.random.truncated_normal(param_key, -1, 1, (jnp.size(observation),))
    rng_key, param_key = jax.random.split(rng_key)
    params["w_p"] = jax.random.truncated_normal(param_key, -1, 1, (jnp.size(observation), 3))
    rng_key, param_key = jax.random.split(rng_key)
    params["b_p"] = jax.random.truncated_normal(param_key, -1, 1, (3,))
    rng_key, param_key = jax.random.split(rng_key)
    params["w_v"] = jax.random.truncated_normal(param_key, -1, 1, (jnp.size(observation), 1))
    rng_key, param_key = jax.random.split(rng_key)
    params["b_v"] = jax.random.truncated_normal(param_key, -1, 1, (1,))

    return params

In [9]:
def value_TD(parameters, obs_tm1, r_t, discount_t, obs_t):
    v_t, _ = network(parameters, obs_t)
    v_tm1, p_tm1 = network(parameters, obs_tm1)

    # Calculate the TD error
    q_pi = r_t + discount_t * v_t
    TD_err = q_pi - v_tm1
    grad = jax.lax.stop_gradient(TD_err[0]) * v_tm1[0]
    return grad


def value_update(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    value_grads = jax.grad(value_TD)(parameters, obs_tm1, r_t, discount_t, obs_t)

    return value_grads


def policy_TD(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    # Get the current and previous state value and policy logits
    v_t, p_t = network(parameters, obs_t)
    v_tm1, p_tm1 = network(parameters, obs_tm1)
    # Compute the log probability of the action a_tm1
    log_pi = jax.nn.log_softmax(p_tm1)
    log_pi_a_tm1 = log_pi[a_tm1]

    # Calculate the TD error
    q_pi = r_t + discount_t * v_t
    TD_err = q_pi - v_tm1

    return log_pi_a_tm1 * jax.lax.stop_gradient(TD_err[0])


def policy_gradient(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    policy_grads = jax.grad(policy_TD)(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t)

    return policy_grads


def opt_init(parameters):
    mu = jax.tree_map(jnp.zeros_like, parameters)
    nu = jax.tree_map(jnp.ones_like, parameters)
    b1 = 0.9
    b2 = 0.999
    epsilon = 1e-4
    alpha = 0.003
    # alpha = 5e-4
    opt_state = (alpha, b1, b2, epsilon, mu, nu)
    return opt_state


def opt_update(grads, opt_state):
    alpha, b1, b2, epsilon, mu, nu = opt_state
    mu = jax.tree_map(lambda m, g: (1 - b1) * g + b1 * m, mu, grads)
    nu = jax.tree_map(lambda n, g: (1 - b2) * (g**2) + b2 * n, nu, grads)
    updates = jax.tree_map(lambda m, n: alpha * (m / (epsilon + jnp.sqrt(n))), mu, nu)
    opt_state = (alpha, b1, b2, epsilon, mu, nu)
    return updates, opt_state

In [10]:
@jax.jit
def compute_gradient(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    pgrads = policy_gradient(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t)
    td_update = value_update(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t)
    return jax.tree_map(lambda pg, td: pg + td, pgrads, td_update)

In [11]:
env = jumanji.make("Snake-v1")  # Create a Snake environment
env = AutoResetWrapper(env)  # Automatically reset the environment when an episode terminates
env = jumanji.wrappers.JumanjiToDMEnvWrapper(env)

# Experiment configs.
train_episodes = 2500
discount_factor = 0.99


# Build and initialize network.
rng = jax.random.PRNGKey(44)
rng, init_rng = jax.random.split(rng)
sample_input = env.observation_spec()["grid"].generate_value()
parameters = create_parameters(init_rng, sample_input)

# Initialize optimizer state.
opt_state = opt_init(parameters)


# Apply updates
def apply_updates(params, updates):
    return jax.tree_map(lambda p, u: p + u, params, updates)


# Jit.
opt_update = jax.jit(opt_update)
apply_updates = jax.jit(apply_updates)

print(f"Training agent for {train_episodes} episodes...")
all_episode_returns = []

for _ in tqdm(range(train_episodes)):
    episode_return = 0.0
    timestep = env.reset()
    obs_tm1 = timestep.observation[0]

    # Sample initial action.
    rng, policy_rng = jax.random.split(rng)
    a_tm1 = softmax_policy(parameters, policy_rng, obs_tm1)

    while not timestep.last():
        # Step environment.
        new_timestep = env.step(int(a_tm1))

        # Sample action from agent policy.
        rng, policy_rng = jax.random.split(rng)
        a_t = softmax_policy(parameters, policy_rng, new_timestep.observation[0])

        # Update params.
        r_t = new_timestep.reward
        discount_t = discount_factor * new_timestep.discount
        dJ_dtheta = compute_gradient(
            parameters, obs_tm1, a_tm1, r_t, discount_t, new_timestep.observation[0]
        )
        updates, opt_state = opt_update(dJ_dtheta, opt_state)
        parameters = apply_updates(parameters, updates)

        # Within episode book-keeping.
        episode_return += new_timestep.reward
        timestep = new_timestep
        obs_tm1 = new_timestep.observation[0]
        a_tm1 = a_t

    # Experiment results tracking.
    all_episode_returns.append(episode_return)

# Plot learning curve.
plot_learning_curve(all_episode_returns)

HELLO
dict
bounded
discrete
bounded


  mu = jax.tree_map(jnp.zeros_like, parameters)


Training agent for 2500 episodes...


  2%|▏         | 50/2500 [00:02<01:44, 23.52it/s]


KeyboardInterrupt: 