We will create a simplified Monkey-Banana environment on which to test classical reinforcement learning (RL) methods.

The environment is a "2D world" with discrete and finite states, actions, and rewards. It is fully observable to our agent, the Monkey.

### State

State consists of 5 values:

- agent position (x-axis)
- chair position (x-axis)
- banana position (x-axis)
- is_holding_chair (0 or 1)
- on_chair (0 or 1)

We do not explicitly model the y-axis, since all we care about is whether the Monkey is on the chair when he reaches for the banana.

### Actions

- move left one step
- move right one step
- climb on the chair
- climb down from the chair
- grab the chair
- drop the chair
- grab the banana

### Rewards

- -1 for each action
- +10 for grabbing the banana


In [None]:
# %load_ext autoreload
# %autoreload 2
import os  # noqa
import sys  # noqa

module_path = os.path.abspath(os.path.join("posts/monkey-banana-mdp/code"))
sys.path.insert(0, module_path)
from environment import LineWorldEnv  # noqa

In [None]:
import numpy as np  # noqa
import pygame  # noqa

import gymnasium as gym  # noqa
from gymnasium import spaces  # noqa
from gymnasium.envs.registration import register  # noqa

First, we test the setup by observing the Monkey take random actions.


In [None]:
# Sanity-check the environment by letting the Monkey act uniformly at random
# for a fixed number of steps, starting a new episode whenever one ends.
env = LineWorldEnv(render_mode="human", size=5)
obs, info = env.reset()

for _ in range(50):
    actions = env.get_possible_actions(env.flatten_obs(obs))
    # Choose random action
    action = np.random.choice(actions)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        # Episode over: start a fresh one. (The original called close() here
        # and unpacked its result as (obs, info); reset() is what returns
        # a new (obs, info) pair.)
        obs, info = env.reset()

# Release the environment (and its render window) once the demo is done.
env.close()

The blue circle is the Monkey, the green square is the chair, and the yellow square is the banana. The action currently being taken is shown in the top-left corner.


Next, we use Dynamic Programming with value iteration to find the optimal policy.


In [None]:
# Non-rendering environment used purely as a transition model for planning,
# plus a value table V(s) that starts at 0 for every enumerable state.
env = LineWorldEnv(size=10)
obs, info = env.reset()

all_states = env.get_all_states()
state_values = dict.fromkeys(all_states, 0)

def action_evaluation(state, action):
    """Return the one-step lookahead value of taking `action` in `state`.

    Computes ``r + V(s')`` from the current module-level ``state_values``
    table (undiscounted). When the transition terminates the episode, the
    successor contributes no future value, so only the immediate reward
    counts.

    Uses the module-level ``env``: its observation is overwritten with
    ``state`` via ``set_obs`` before stepping.
    """
    env.set_obs(state)
    next_state, reward, terminated, truncated, _ = env.step(action)
    if terminated:
        # Terminal successor has value 0, but the transition reward itself
        # (e.g. the banana bonus) must still be counted.
        return reward
    return reward + state_values[env.flatten_obs(next_state)]

# Value iteration (undiscounted, in-place sweeps): repeatedly apply the
# Bellman optimality backup V(s) <- max_a [r + V(s')] to every state until
# no state value changes by more than `theta` during a full sweep.
theta = 0.1
sweep_count = 0
biggest_change = np.inf
while biggest_change > theta:
    biggest_change = 0
    for s in all_states:
        possible_actions = env.get_possible_actions(s)
        if len(possible_actions) == 0:
            # No available actions: keep the current value instead of
            # overwriting it with -inf, which would make the convergence
            # check compare infinities.
            continue
        original_value = state_values[s]
        state_values[s] = max(action_evaluation(s, a) for a in possible_actions)
        biggest_change = max(biggest_change, abs(original_value - state_values[s]))
    sweep_count += 1
print("Number of sweeps: ", sweep_count)

# Create optimal policy pi from state values: act greedily w.r.t. V.
# max() returns the first action among ties, matching a strict ">" scan;
# states with no possible actions map to None.
policy = {}
for s in all_states:
    action_values = {a: action_evaluation(s, a) for a in env.get_possible_actions(s)}
    policy[s] = max(action_values, key=action_values.get, default=None)

# Policy extraction still steps `env` through action_evaluation, so the
# environment is closed only after it has finished.
env.close()


We can now watch the Monkey follow the optimal policy produced by dynamic programming:


In [None]:
# Replay the greedy policy from dynamic programming in a rendered environment.
env = LineWorldEnv(render_mode="human", size=10)
# Must bind `obs` (previously a typo, `cbs`): otherwise the first policy lookup
# used a stale observation left over from a different environment instance.
obs, info = env.reset()

terminated = truncated = False
while not (terminated or truncated):
    action = policy[env.flatten_obs(obs)]
    obs, reward, terminated, truncated, info = env.step(action)
env.close()