In [91]:
import jax
import jax.numpy as jnp

import jumanji
from jumanji.wrappers import AutoResetWrapper
import dm_env

In [130]:
def get_observation_element(observation_spec, observation, key):
    """
    Extracts a specific observation element based on the observation_spec structure.

    Args:
        observation_spec (dict): The dictionary specifying the observation structure.
        observation (tuple or list): The actual observation values in the same order as observation_spec.
        key (str): The name of the element to extract.

    Returns:
        The corresponding element from observation.
    """
    keys = list(observation_spec.keys())  # Get the ordered keys from observation_spec
    index = keys.index(key)  # Find the index of the requested key
    return observation[index]  # Return the corresponding value from observation

In [71]:
@jax.jit
def softmax_policy(parameters, key, obs):
    """Sample action from a softmax policy."""
    _, p = network(parameters, obs)
    return jax.random.categorical(key, p)


def network(params, observation):
    # Implement forward pass here
    w = params["w"]
    b = params["b"]
    # these are theta (vector) -- policy fn params
    w_p = params["w_p"]
    b_p = params["b_p"]
    # these are w (vector) -- value fn params (just overloaded notation bc they're linear params so hence the w name for weights)
    w_v = params["w_v"]
    b_v = params["b_v"]

    flat_observation = observation.flatten()
    h = jax.nn.relu(flat_observation.dot(w) + b)  # hidden representation
    # print(flat_observation.shape,w.shape,h.shape)
    p = h.dot(w_p) + b_p
    # print(h.shape,w_p.shape,p.shape)
    v = h.dot(w_v) + b_v
    return v, p


def create_parameters(rng_key, observation):
    # Returns a dictionary with the desired parameters for the network
    params = {}
    rng_key, param_key = jax.random.split(rng_key)
    params["w"] = jax.random.truncated_normal(
        param_key, -1, 1, (jnp.size(observation), jnp.size(observation))
    )
    rng_key, param_key = jax.random.split(rng_key)
    params["b"] = jax.random.truncated_normal(param_key, -1, 1, (jnp.size(observation),))
    rng_key, param_key = jax.random.split(rng_key)
    params["w_p"] = jax.random.truncated_normal(param_key, -1, 1, (jnp.size(observation), 3))
    rng_key, param_key = jax.random.split(rng_key)
    params["b_p"] = jax.random.truncated_normal(param_key, -1, 1, (3,))
    rng_key, param_key = jax.random.split(rng_key)
    params["w_v"] = jax.random.truncated_normal(param_key, -1, 1, (jnp.size(observation), 1))
    rng_key, param_key = jax.random.split(rng_key)
    params["b_v"] = jax.random.truncated_normal(param_key, -1, 1, (1,))

    return params

In [151]:
env = jumanji.make("Snake-v1")  # Create a Snake environment
env = AutoResetWrapper(env)  # Automatically reset the environment when an episode terminates
# env = jumanji.wrappers.JumanjiToDMEnvWrapper(env)
batch_size = 7
rollout_length = 5
num_actions = env.action_spec.num_values

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
obs_spec = env.observation_spec
sample_input = obs_spec.generate_value()
observation = [obs for obs in sample_input]
parameters = create_parameters(init_rng, observation[0])


def step_fn(state, key):
    key, obs_tm1 = key_and_obs_tm1
    # Sample action based on policy
    a_tm1 = softmax_policy(parameters, key, state)
    # Step the environment

    state, timestep = env.step(a_tm1)
    return state, timestep


def run_n_steps(state, key, n):
    random_keys = jax.random.split(key, n)
    state, rollout = jax.lax.scan(step_fn, state, random_keys)

    return rollout


# Instantiate a batch of environment states
keys = jax.random.split(rng, batch_size)
state, timestep = jax.vmap(env.reset)(keys)
# Collect rollouts
keys = jax.random.split(init_rng, batch_size)
rollout = jax.vmap(run_n_steps, in_axes=(0, 0, 0, None))(state, timestep, keys, rollout_length)

# Shape and type of given rollout:
# TimeStep(step_type=(7, 5), reward=(7, 5), discount=(7, 5), observation=(7, 5, 6, 6, 5), extras=None)

TimeStep(step_type=Array([0, 0, 0, 0, 0, 0, 0], dtype=int8), reward=Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32), discount=Array([1., 1., 1., 1., 1., 1., 1.], dtype=float32), observation=Observation(grid=Array([[[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        ...,

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
 

ValueError: scan got values with different leading axis sizes: 5, 12.

In [152]:
# def v(parameters, observation):
#   v,_ = network(parameters, observation)
#   return v[0]


def value_TD(parameters, obs_tm1, r_t, discount_t, obs_t):
    v_t, _ = network(parameters, obs_t)
    v_tm1, p_tm1 = network(parameters, obs_tm1)

    # Calculate the TD error
    q_pi = r_t + discount_t * v_t
    TD_err = q_pi - v_tm1
    grad = jax.lax.stop_gradient(TD_err[0]) * v_tm1[0]
    return grad


def value_update(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    value_grads = jax.grad(value_TD)(parameters, obs_tm1, r_t, discount_t, obs_t)

    return value_grads

In [153]:
def policy_TD(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    # Get the current and previous state value and policy logits
    v_t, p_t = network(parameters, obs_t)
    v_tm1, p_tm1 = network(parameters, obs_tm1)
    # Compute the log probability of the action a_tm1
    log_pi = jax.nn.log_softmax(p_tm1)
    log_pi_a_tm1 = log_pi[a_tm1]

    # Calculate the TD error
    q_pi = r_t + discount_t * v_t
    TD_err = q_pi - v_tm1

    return log_pi_a_tm1 * jax.lax.stop_gradient(TD_err[0])


def policy_gradient(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    policy_grads = jax.grad(policy_TD)(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t)

    return policy_grads

In [154]:
@jax.jit
def compute_gradient(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    pgrads = policy_gradient(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t)
    td_update = value_update(parameters, obs_tm1, a_tm1, r_t, discount_t, obs_t)
    return jax.tree_map(lambda pg, td: pg + td, pgrads, td_update)

In [155]:
def opt_init(parameters):
    mu = jax.tree_map(jnp.zeros_like, parameters)
    nu = jax.tree_map(jnp.ones_like, parameters)
    b1 = 0.9
    b2 = 0.999
    epsilon = 1e-4
    alpha = 0.003
    # alpha = 5e-4
    opt_state = (alpha, b1, b2, epsilon, mu, nu)
    return opt_state


def opt_update(grads, opt_state):
    alpha, b1, b2, epsilon, mu, nu = opt_state
    mu = jax.tree_map(lambda m, g: (1 - b1) * g + b1 * m, mu, grads)
    nu = jax.tree_map(lambda n, g: (1 - b2) * (g**2) + b2 * n, nu, grads)
    updates = jax.tree_map(lambda m, n: alpha * (m / (epsilon + jnp.sqrt(n))), mu, nu)
    opt_state = (alpha, b1, b2, epsilon, mu, nu)
    return updates, opt_state

In [156]:
def plot_learning_curve(list_of_episode_returns):
    """Plot the learning curve."""
    plt.figure(figsize=(7, 5))

    def moving_average(x, w):
        return np.convolve(x, np.ones(w), "valid") / w

    smoothed_returns = moving_average(list_of_episode_returns, 30)
    plt.plot(smoothed_returns)

    plt.xlabel("Average episode returns")
    plt.xlabel("Number of episodes")

    ax = plt.gca()
    ax.spines["left"].set_visible(True)
    ax.spines["bottom"].set_visible(True)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("left")

In [168]:
from tqdm import tqdm

In [167]:
# DO NOT CHANGE THIS CELL
import matplotlib.pyplot as plt

env = jumanji.make("Snake-v1")  # Create a Snake environment
env = AutoResetWrapper(env)  # Automatically reset the environment when an episode terminates
env = jumanji.wrappers.JumanjiToDMEnvWrapper(env)

# Experiment configs.
train_episodes = 2500
discount_factor = 0.99


# Build and initialize network.
rng = jax.random.PRNGKey(44)
rng, init_rng = jax.random.split(rng)
sample_input = env.observation_spec()["grid"].generate_value()
parameters = create_parameters(init_rng, sample_input)

# Initialize optimizer state.
opt_state = opt_init(parameters)


# Apply updates
def apply_updates(params, updates):
    return jax.tree_map(lambda p, u: p + u, params, updates)


# Jit.
opt_update = jax.jit(opt_update)
apply_updates = jax.jit(apply_updates)

print(f"Training agent for {train_episodes} episodes...")
all_episode_returns = []

for _ in range(train_episodes):
    episode_return = 0.0
    timestep = env.reset()
    obs_tm1 = timestep.observation[0]

    # Sample initial action.
    rng, policy_rng = jax.random.split(rng)
    a_tm1 = softmax_policy(parameters, policy_rng, obs_tm1)

    while not timestep.last():
        # Step environment.
        new_timestep = env.step(int(a_tm1))

        # Sample action from agent policy.
        rng, policy_rng = jax.random.split(rng)
        a_t = softmax_policy(parameters, policy_rng, new_timestep.observation[0])

        # Update params.
        r_t = new_timestep.reward
        discount_t = discount_factor * new_timestep.discount
        dJ_dtheta = compute_gradient(
            parameters, obs_tm1, a_tm1, r_t, discount_t, new_timestep.observation[0]
        )
        updates, opt_state = opt_update(dJ_dtheta, opt_state)
        parameters = apply_updates(parameters, updates)

        # Within episode book-keeping.
        episode_return += new_timestep.reward
        timestep = new_timestep
        obs_tm1 = new_timestep.observation[0]
        a_tm1 = a_t

    # Experiment results tracking.
    all_episode_returns.append(episode_return)

# Plot learning curve.
plot_learning_curve(all_episode_returns)

Training agent for 2500 episodes...
[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [1. 1. 1. 0. 1.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 

KeyboardInterrupt: 

In [94]:
# pylint: disable=g-bad-file-header
# Copyright 2019 The dm_env Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Catch reinforcement learning environment."""

import dm_env
from dm_env import specs
import numpy as np

_ACTIONS = (-1, 0, 1)  # Left, no-op, right.


class Catch(dm_env.Environment):
    """A Catch environment built on the `dm_env.Environment` class.

    The agent must move a paddle to intercept falling balls. Falling balls only
    move downwards on the column they are in.

    The observation is an array shape (rows, columns), with binary values:
    zero if a space is empty; 1 if it contains the paddle or a ball.

    The actions are discrete, and by default there are three available:
    stay, move left, and move right.

    The episode terminates when the ball reaches the bottom of the screen.
    """

    def __init__(self, rows: int = 10, columns: int = 5, seed: int = 1):
        """Initializes a new Catch environment.

        Args:
          rows: number of rows.
          columns: number of columns.
          seed: random seed for the RNG.
        """
        self._rows = rows
        self._columns = columns
        self._rng = np.random.RandomState(seed)
        self._board = np.zeros((rows, columns), dtype=np.float32)
        self._ball_x = None
        self._ball_y = None
        self._paddle_x = None
        self._paddle_y = self._rows - 1
        self._reset_next_step = True

    def reset(self) -> dm_env.TimeStep:
        """Returns the first `TimeStep` of a new episode."""
        self._reset_next_step = False
        self._ball_x = self._rng.randint(self._columns)
        self._ball_y = 0
        self._paddle_x = self._columns // 2
        return dm_env.restart(self._observation())

    def step(self, action: int) -> dm_env.TimeStep:
        """Updates the environment according to the action."""
        if self._reset_next_step:
            return self.reset()

        # Move the paddle.
        dx = _ACTIONS[action]
        self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1)

        # Drop the ball.
        self._ball_y += 1

        # Check for termination.
        if self._ball_y == self._paddle_y:
            reward = 1.0 if self._paddle_x == self._ball_x else -1.0
            self._reset_next_step = True
            return dm_env.termination(reward=reward, observation=self._observation())
        else:
            return dm_env.transition(reward=0.0, observation=self._observation())

    def observation_spec(self) -> specs.BoundedArray:
        """Returns the observation spec."""
        return specs.BoundedArray(
            shape=self._board.shape,
            dtype=self._board.dtype,
            name="board",
            minimum=0,
            maximum=1,
        )

    def action_spec(self) -> specs.DiscreteArray:
        """Returns the action spec."""
        return specs.DiscreteArray(dtype=int, num_values=len(_ACTIONS), name="action")

    def _observation(self) -> np.ndarray:
        self._board.fill(0.0)
        self._board[self._ball_y, self._ball_x] = 1.0
        self._board[self._paddle_y, self._paddle_x] = 1.0
        return self._board.copy()

In [102]:
Catch().observation_spec().generate_value()
print(Catch())

<__main__.Catch object at 0x3103c4b50>


In [132]:
env = jumanji.make("Snake-v1")  # Create a Snake environment
env = AutoResetWrapper(env)  # Automatically reset the environment when an episode terminates

batch_size = 7
rollout_length = 5
num_actions = env.action_spec.num_values

random_key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(random_key)


def step_fn(state, key):
    action = jax.random.randint(key=key, minval=0, maxval=num_actions, shape=())
    new_state, timestep = env.step(state, action)
    return new_state, timestep


def run_n_steps(state, key, n):
    random_keys = jax.random.split(key, n)
    state, rollout = jax.lax.scan(step_fn, state, random_keys)
    return rollout


# Instantiate a batch of environment states
keys = jax.random.split(key1, batch_size)
state, timestep = jax.vmap(env.reset)(keys)

# Collect a batch of rollouts
keys = jax.random.split(key2, batch_size)
rollout = jax.vmap(run_n_steps, in_axes=(0, 0, None))(state, keys, rollout_length)

# Shape and type of given rollout:
# TimeStep(step_type=(7, 5), reward=(7, 5), discount=(7, 5), observation=(7, 5, 6, 6, 5), extras=None)

State(body=Array([[[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        ...,
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False]],

       [[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        ...,
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False]],

       [[False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        [False, False, False, ..., False, False, False],
        ...,
        [False, False, False, ..., False, False, False],
        [False, False, False, ...,

In [82]:
rollout

TimeStep(step_type=Array([[1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 2, 1],
       [1, 2, 2, 1, 1],
       [1, 1, 1, 1, 1]], dtype=int8), reward=Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32), discount=Array([[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 0., 1.],
       [1., 0., 0., 1., 1.],
       [1., 1., 1., 1., 1.]], dtype=float32), observation=Observation(grid=Array([[[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          ...,
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 1., 0.],
          [0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
         