<a href="https://colab.research.google.com/github/nathanwispinski/meta-rl/blob/main/play_env_as_human.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# play_env_as_human.ipynb

This is a Google Colab notebook to demo the meta-rl multi-armed bandit environment.

For more details, see the GitHub repository (https://github.com/nathanwispinski/meta-rl).

# Colab setup

In [1]:
#@title Clone GitHub repository.
# Shell escape: clones the repository into a new `meta-rl` directory under the
# current working directory. Later cells `cd` into that directory.
!git clone https://github.com/nathanwispinski/meta-rl

Cloning into 'meta-rl'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (60/60), done.[K
remote: Compressing objects: 100% (47/47), done.[K
remote: Total 60 (delta 23), reused 38 (delta 12), pack-reused 0[K
Unpacking objects: 100% (60/60), 202.89 KiB | 3.22 MiB/s, done.


In [2]:
#@title Change working directory to cloned repository (i.e., /content/meta-rl/).
# The path is relative, so this must run from the directory the clone cell ran in.
# The cells below rely on this location (`requirements.txt`, `modules.envs`).
%cd meta-rl

/content/meta-rl


In [3]:
# @title Install dependencies from `requirements.txt`.
!pip install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting absl_py==1.3.0
  Downloading absl_py-1.3.0-py3-none-any.whl (124 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.6/124.6 KB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting chex==0.1.5
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.3/85.3 KB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dm_haiku==0.0.9
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 KB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jax==0.3.25
  Downloading jax-0.3.25.tar.gz (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting jaxlib==0.3.25
  Downloading jaxl

# Import dependencies

In [4]:
#@title Import dependencies after install.
import numpy as np
import time
from IPython.display import display, clear_output
import ipywidgets as widgets

import modules.envs as envs

# Create an environment to play.

In [5]:
#@title Create a config for the environment.
#@markdown Available `reward_structure` configs are: `independent`, and `correlated`.

# Colab form fields; edit them in the form or directly here.
steps_per_episode = 10 #@param {type:"integer"}
num_arms = 2 #@param {type:"integer"}
reward_structure = "correlated" #@param {type:"string"}
total_episodes = 10 #@param {type:"integer"}

# Settings handed to `envs.create_env`. `total_episodes` is kept as a
# standalone variable and is not part of this config.
env_config = dict(
    steps_per_episode=steps_per_episode,
    num_arms=num_arms,
    reward_structure=reward_structure,
)

In [6]:
#@title Create environment.
# Build the bandit environment from the config defined in the previous cell.
env = envs.create_env(env_config=env_config)
# Integer action ids 0 .. env.num_actions - 1.
# NOTE(review): presumably one action per arm (num_actions == num_arms) -- confirm in modules/envs.
valid_action_range = np.arange(env.num_actions)
# Start a first episode and keep the observation returned by the reset.
observation = env.reset()

# Play

In [7]:
#@title Play environment as a human.

# Output widget displayed beneath the buttons. The callbacks below print
# directly (not inside `with output:`), so their text goes to the cell's
# output area rather than into this widget.
output = widgets.Output()

# Mutable state shared by the button callbacks; updated on every arm pull.
interactive_info = {
    'step': 0,
    'observation': None,
    'next_observation': None,
    'reward': None,
    'done': None,
    'info': None,
    'action': None
    }

# Make one button per bandit arm
num_arms = env_config['num_arms']
buttons = []
click_fns = []
for i in range(num_arms):
    arm_name = "Arm " + str(i + 1)
    button = widgets.Button(description=arm_name)

    def button_click(a, action=i):
        """Pull arm `action`, print the outcome, and update `interactive_info`.

        Args:
            a: The clicked button, passed by `Button.on_click` (unused).
            action: Index of the arm to pull. Bound as a default argument so
                each button keeps its own index instead of the loop's last `i`.
        """
        next_observation, reward, done, info = env.step(action)
        step = interactive_info['step']
        print(f'Step: {step}; Action: {action}; Reward: {reward}; {next_observation}')

        # Record the transition in the fields declared for it.
        interactive_info['action'] = action
        interactive_info['reward'] = reward
        interactive_info['done'] = done
        interactive_info['info'] = info
        interactive_info['next_observation'] = next_observation

        if done:
            # Report the arm probabilities before the reset starts a new episode.
            print(f'Episode done. Arm win probs were: {env._arm_probs}')
            interactive_info['step'] = 0
            # Keep the new episode's initial observation rather than the
            # terminal observation of the episode that just ended.
            interactive_info['observation'] = env.reset()
        else:
            interactive_info['step'] += 1
            interactive_info['observation'] = next_observation

    button.on_click(button_click)
    buttons.append(button)
    click_fns.append(button_click)

# Make a clear history button
button_0 = widgets.Button(description="Clear history")
def button_0_click(b):
    """Clear the printed history and show the widgets again.

    Args:
        b: The clicked button, passed by `Button.on_click` (unused).
    """
    clear_output()
    display(*buttons)
button_0.on_click(button_0_click)

buttons.append(button_0)
buttons.append(output)

display(*buttons)

# Start the first episode and show its initial observation.
interactive_info['observation'] = env.reset()
print(interactive_info['observation'])


Button(description='Arm 1', style=ButtonStyle())

Button(description='Arm 2', style=ButtonStyle())

Button(description='Clear history', style=ButtonStyle())

Output()

{'vector_input': array([0., 0., 0.])}
Step: 0; Action: 0; Reward: 1.0; {'vector_input': array([1., 0., 1.])}
Step: 1; Action: 0; Reward: 0.0; {'vector_input': array([1., 0., 0.])}
Step: 2; Action: 0; Reward: 1.0; {'vector_input': array([1., 0., 1.])}
Step: 3; Action: 1; Reward: 0.0; {'vector_input': array([0., 1., 0.])}
Step: 4; Action: 1; Reward: 0.0; {'vector_input': array([0., 1., 0.])}
Step: 5; Action: 1; Reward: 1.0; {'vector_input': array([0., 1., 1.])}
Step: 6; Action: 0; Reward: 1.0; {'vector_input': array([1., 0., 1.])}
Step: 7; Action: 0; Reward: 1.0; {'vector_input': array([1., 0., 1.])}
Step: 8; Action: 0; Reward: 1.0; {'vector_input': array([1., 0., 1.])}
Step: 9; Action: 0; Reward: 1.0; {'vector_input': array([1., 0., 1.])}
Episode done. Arm win probs were: [0.90476954 0.09523046]
Step: 0; Action: 0; Reward: 0.0; {'vector_input': array([1., 0., 0.])}
Step: 1; Action: 0; Reward: 1.0; {'vector_input': array([1., 0., 1.])}
Step: 2; Action: 0; Reward: 1.0; {'vector_input': ar