In [4]:
import gymnasium as gym
import gym_bart
import numpy as np

import ipywidgets as widgets
from IPython.display import display
from functools import partial

%run ../env/gym_bart/envs/bart_env.py
%run ../env/gym_bart/envs/bart_meta_env.py

# Display

In [8]:
class GameDisplay:
    """Button-driven widget UI for stepping a BART environment by hand.

    Three buttons are created; clicking the i-th button calls
    ``env.step(i)`` and prints the resulting state into an ``Output`` widget.
    When an episode ends the environment is reset so the next click starts a
    fresh episode.

    Parameters
    ----------
    env : environment exposing ``reset()``, a five-value ``step(action)`` and
        the attributes ``toggle_task``, ``current_color``, ``current_size``
        and ``passive_trial``.
    """

    def __init__(self, env):
        self.env = env
        env.reset()
        self.output = widgets.Output()

        # Button captions depend on which action scheme the env uses.
        if self._base_env().toggle_task:
            labels = ['Wait', 'Start/Stop', 'N/A']
        else:
            labels = ['Stop', 'Inflate', 'N/A']
        self.buttons = []
        for i, label in enumerate(labels):
            button = widgets.Button(description=label)
            button.on_click(self.generate_button_callback(i))
            self.buttons.append(button)
        self.widgets = [*self.buttons, self.output]

    def _base_env(self):
        """Return the innermost env so custom attributes resolve even when
        ``self.env`` is a wrapper that does not forward attribute access."""
        return getattr(self.env, 'unwrapped', self.env)

    def _reset_obs(self):
        """Reset the env and return only the observation.

        Gymnasium-style ``reset()`` returns ``(obs, info)``; a bare
        observation is also accepted and passed through unchanged.
        """
        result = self.env.reset()
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result[0]
        return result

    def update(self, output):
        """Replace the contents of the output widget with ``output``."""
        self.output.clear_output()
        with self.output:
            print(output)

    def step(self, action):
        """Apply ``action``, show the outcome, and reset on episode end.

        Returns
        -------
        tuple
            ``(obs, reward, terminated, info)``. If the episode ended
            (terminated or truncated), ``obs`` is the first observation of
            the new episode.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        base = self._base_env()
        self.output.clear_output()
        with self.output:
            print(f'Color: {base.current_color}, Size: {base.current_size:.2f}, Passive: {base.passive_trial}')
            print(f'Observation: {obs}')
            print(f'Reward: {reward}')

            if terminated:
                if reward > 0:
                    print(f'Banked {reward}')
                else:
                    print(f'Popped! Reward {reward}')
            elif truncated:
                print('Truncated')

            # A truncated episode is over just like a terminated one; without
            # a reset the next click would step a finished environment.
            if terminated or truncated:
                print('Env Reset')
                obs = self._reset_obs()
        return obs, reward, terminated, info

    def generate_button_callback(self, action):
        """Build the click handler that steps the env with ``action``."""
        def on_click(b):
            # ``b`` is the clicked button supplied by ipywidgets; unused.
            return self.step(action)
        return on_click

    def display(self):
        """Render the buttons followed by the output area."""
        display(*self.widgets)

In [22]:
# Alternative setup:
# env = BartEnv(passive_trial_prob=0, fixed_reward_prob=0, random_start_wait=True)
env = BartEnv(
    passive_trial_prob=0,
    fixed_reward_prob=0,
    random_start_wait=False,
    # A single fixed condition; a 'size' entry (e.g. 0.2) is left out here.
    fix_conditions=[dict(color='orange', delay=3, passive=True)],
    punish_passive=-0.1,
)
out = GameDisplay(env)
out.display()

Button(description='Wait', style=ButtonStyle())

Button(description='Start/Stop', style=ButtonStyle())

Button(description='N/A', style=ButtonStyle())

Output()

In [11]:
# Alternative setup:
# env = BartEnv(passive_trial_prob=0, fixed_reward_prob=0, random_start_wait=True)
env = BartEnv(
    num_balloons=2,
    max_steps=30,
    punish_passive=-0.1,
)
out = GameDisplay(env)
out.display()

Button(description='Wait', style=ButtonStyle())

Button(description='Start/Stop', style=ButtonStyle())

Button(description='N/A', style=ButtonStyle())

Output()

# BartMetaEnv

In [14]:
class GameDisplay:
    """Button-driven widget UI for stepping a BartMetaEnv by hand.

    Redefines ``GameDisplay`` with extra diagnostics read from the
    environment and from ``info``. Clicking the i-th button calls
    ``env.step(i)`` and prints the result into an ``Output`` widget.

    The environment is not reset here when it reports termination or
    truncation; only a notice is printed.

    Parameters
    ----------
    env : environment (possibly wrapped, e.g. from ``gym.make``) whose
        underlying env exposes ``toggle_task``, ``current_color``,
        ``current_size``, ``balloon_mean_sizes``, ``color_to_idx`` and
        ``current_balloon_limit``.
    """

    def __init__(self, env):
        self.env = env
        env.reset()
        self.output = widgets.Output()

        # Button captions depend on which action scheme the env uses.
        if self._base_env().toggle_task:
            labels = ['Wait', 'Start/Stop', 'N/A']
        else:
            labels = ['Stop', 'Inflate', 'N/A']
        self.buttons = []
        for i, label in enumerate(labels):
            button = widgets.Button(description=label)
            button.on_click(self.generate_button_callback(i))
            self.buttons.append(button)
        self.widgets = [*self.buttons, self.output]

    def _base_env(self):
        """Return the innermost env.

        Gymnasium wrappers are not guaranteed to forward custom attributes,
        so env-specific fields are read from ``unwrapped`` when available.
        """
        return getattr(self.env, 'unwrapped', self.env)

    def update(self, output):
        """Replace the contents of the output widget with ``output``."""
        self.output.clear_output()
        with self.output:
            print(output)

    def step(self, action):
        """Apply ``action`` and print the resulting state and diagnostics.

        Returns
        -------
        tuple
            ``(obs, reward, terminated, info)`` as produced by ``env.step``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        base = self._base_env()
        self.output.clear_output()
        with self.output:
            print(f'Color: {base.current_color}, Size: {base.current_size}')
            print(f'Observation: {obs}')
            print(f'Color mean: {base.balloon_mean_sizes[base.color_to_idx[base.current_color]]}')
            print(f'True max: {base.current_balloon_limit}')
            print(f'Reward: {reward}')
            print(f'Inflate delay: {info["inflate_delay"]}')

            # 'bart_finished' may be absent from info; treat that as False.
            if info.get('bart_finished'):
                if info['popped']:
                    print(f'Popped! Reward {reward}')
                else:
                    print(f'Banked {reward}')
                print('Env Reset')
            if terminated:
                print('TERMINATED')
            if truncated:
                # Surface truncation too, otherwise the end of the episode
                # goes unnoticed in the UI.
                print('TRUNCATED')
        return obs, reward, terminated, info

    def generate_button_callback(self, action):
        """Build the click handler that steps the env with ``action``."""
        def on_click(b):
            # ``b`` is the clicked button supplied by ipywidgets; unused.
            return self.step(action)
        return on_click

    def display(self):
        """Render the buttons followed by the output area."""
        display(*self.widgets)

In [15]:
env = gym.make(
    'BartMetaEnv',
    meta_setup=1,
    colors_used=1,
    inflate_noise=0,
    num_balloons=2,
    rew_structure=6,
    rew_p=1.5,
    rew_on_pop=-0.1,
)
# Alternative: construct the class directly instead of going through gym.make.
# env = BartMetaEnv(meta_setup=1, fix_sizes={1: 0.8}, colors_used=1, num_balloons=2, give_rew=True,
#                   rew_structure=6, rew_p=1.5, rew_on_pop=-0.1)
out = GameDisplay(env)
out.display()

Button(description='Wait', style=ButtonStyle())

Button(description='Start/Stop', style=ButtonStyle())

Button(description='N/A', style=ButtonStyle())

Output()

In [8]:
# `env` comes from gym.make and is therefore wrapped; read the custom
# attribute from the underlying env rather than relying on the wrapper
# forwarding it.
env.unwrapped.inflate_delay

0

In [14]:
# Shape of a single observation vector.
env.observation_space.shape

(9,)

In [None]:
import sys

# Put the parent directory on the import path so the `ppo` package resolves.
# Guarded so re-running this cell does not keep appending duplicate entries.
if '../' not in sys.path:
    sys.path.append('../')
from ppo.envs import make_vec_env


In [None]:

# Keyword arguments passed to each BartMetaEnv instance.
env_kwargs = dict(
    meta_setup=1,
    colors_used=1,
    inflate_noise=0,
    pop_noise=0,
    rew_structure=0,
    max_steps=5,
    num_balloons=5,
)
envs = make_vec_env('BartMetaEnv', env_kwargs=env_kwargs, n_envs=4)
# Last expression: the cell displays whatever reset() returns.
envs.reset()

array([[0.        , 0.00353516, 0.        , 0.        , 0.        ,
        0.        , 0.00353516, 0.        ],
       [0.        , 0.00353516, 0.        , 0.        , 0.        ,
        0.        , 0.00353516, 0.        ],
       [0.        , 0.00353516, 0.        , 0.        , 0.        ,
        0.        , 0.00353516, 0.        ],
       [0.        , 0.00353516, 0.        , 0.        , 0.        ,
        0.        , 0.00353516, 0.        ]], dtype=float32)