In [None]:
!pip install pomdp-py

Collecting pomdp-py
  Downloading pomdp_py-1.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pomdp-py
Successfully installed pomdp-py-1.3.3


In [None]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): no visible cell reads from /content/drive — drop if unused.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## General setup

In [None]:
import pomdp_py
import numpy as np
import random

In [None]:
from collections import defaultdict

def tuple_default():
    """Return a fresh per-action table whose missing entries default to ``(0.0, 0)``.

    Intended as the ``default_factory`` of a nested Q-table shaped
    ``{key: {action: (value, visit_count)}}``.

    ``defaultdict`` requires a callable (or None) factory; the original passed
    the tuple ``(float, int)`` and raised ``TypeError`` on every call.
    """
    return defaultdict(lambda: (0.0, 0))

defaultdict(<function __main__.tuple_default()>, {})

In [None]:
import random

class Q_Learning_Table:
  """Tabular, goal-conditioned Q-value estimates learned from experience tuples.

  ``q_tab`` maps ``(goal, observation)`` to ``{action: (value, visit_count)}``.
  Estimates are updated with a sample-average step size (alpha = 1 / visits).

  Args:
    gamma: Discount factor applied to the bootstrapped next-observation value.
    n: Stored as-is; not used by ``update``.
    n_actions: Number of actions; stored as-is, not used by ``update``.
  """

  def __init__(self, gamma, n, n_actions) -> None:
    # Inner tables are plain dicts: ``update`` always checks membership before
    # reading an entry, so no inner default factory is needed.
    self.q_tab = defaultdict(dict)
    self.gamma = gamma
    self.n = n
    self.n_actions = n_actions

  def update(self, goal, o_t, a_t, r_t, new_o):
    """Apply one Q-learning backup for the transition (o_t, a_t, r_t, new_o) under ``goal``.

    target = r_t + gamma * max_a Q(goal, new_o, a), with the max over an unseen
    next observation taken as 0.0. The first visit to (goal, o_t, a_t) stores
    ``(target, 1)``; later visits move the estimate toward the target with
    alpha = 1 / visit_count and increment the count.
    """
    # TODO: read about flattening policies.
    # .get avoids inserting an empty entry for the next observation.
    next_entries = self.q_tab.get((goal, new_o), {})
    # Unseen next observations bootstrap from 0.0; the original used -inf,
    # which made the target -inf and poisoned every estimate it touched.
    max_next = max((value for value, _ in next_entries.values()), default=0.0)
    target = r_t + self.gamma * max_next

    entry = self.q_tab[(goal, o_t)]
    if a_t not in entry:
      entry[a_t] = (target, 1)
    else:
      value, visits = entry[a_t]
      visits += 1
      alpha = 1 / visits
      entry[a_t] = ((1 - alpha) * value + alpha * target, visits)


class ExperienceMemory:
  """Append-only log of (goal, o_t, a_t, r_t, new_o) experiences.

  Every experience added is also forwarded to ``state_value_tab.update`` so the
  value table is updated alongside the log.
  """

  # String annotation: resolves regardless of definition order.
  def __init__(self, state_value_tab: "Q_Learning_Table") -> None:
    self.state_value_tab = state_value_tab
    self.experiences = []

  def add_experience(self, goal, o_t, a_t, r_t, new_o):
    """Record one experience and feed it to the value table."""
    self.experiences.append((goal, o_t, a_t, r_t, new_o))
    self.state_value_tab.update(goal, o_t, a_t, r_t, new_o)

  def __hash__(self):
    # Lists are unhashable, so hash an immutable snapshot of the log. The
    # original referenced the undefined name ``experiences`` (NameError).
    # NOTE(review): the hash changes whenever an experience is added, so an
    # instance must not be mutated while used as a dict/set key.
    return hash(tuple(self.experiences))

  def __eq__(self, other):
    # Compare with other memories; the original tested ``isinstance(other,
    # State)``, so two memories never compared equal.
    if isinstance(other, ExperienceMemory):
      return self.experiences == other.experiences
    return False

  def __str__(self):
    return f"""Experience memory: {self.experiences}\n"""

  def __repr__(self):
    return f"""Experience memory: {self.experiences}\n"""


In [None]:
class State(pomdp_py.State):
    """POMDP state that wraps an ExperienceMemory.

    Hashing and equality are delegated to the wrapped memory.

    NOTE(review): the wrapped memory is mutable, so a State's hash changes if
    its memory gains experiences after the State is used as a dict key (e.g.
    in a Histogram belief). Also, later cells call ``State(...)`` with six
    collision-avoidance fields, which this constructor does not accept.
    """

    def __init__(self, exp_mem : ExperienceMemory):
        self.exp_mem = exp_mem

    def __hash__(self):
        return hash(self.exp_mem)

    def __eq__(self, other):
        # Non-State operands compare unequal; otherwise defer to the memories.
        return isinstance(other, State) and self.exp_mem == other.exp_mem

    def __str__(self):
        return f"""Experience memory: {self.exp_mem}\n"""

    # repr and str render identically.
    __repr__ = __str__

In [None]:
class Action(pomdp_py.Action):
    """Simple named action.

    An Action also compares equal to a plain ``str`` holding its name, and it
    hashes its name. Together these let it be looked up in dicts keyed by
    action-name strings.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, Action):
            return self.name == other.name
        if isinstance(other, str):
            return self.name == other
        # Explicit bool; the original fell off the end and returned None.
        return False

    def __str__(self):
        return self.name

    def __repr__(self):
        return "Action(%s)" % self.name

In [None]:
class Observation(pomdp_py.Observation):
    """Observation holding a single ``screen`` string.

    NOTE(review): ObservationModel.sample and test_planner build Observation
    from six collision-avoidance fields, which this constructor does not accept.
    """

    def __init__(self, screen: str):
        self.screen = screen

    def __hash__(self):
        return hash(self.screen)

    def __eq__(self, other):
        # Guard the attribute access; comparing with a non-Observation
        # previously raised AttributeError.
        if isinstance(other, Observation):
            return self.screen == other.screen
        return False

    def __str__(self):
        return f"""Screen: {self.screen}\n"""

    def __repr__(self):
        return f"""Screen: {self.screen}\n"""

## WebShop Case Study

### Installation and setup

In [None]:
#Util files
!git clone https://github.com/monilouise/IN1087.git

fatal: destination path 'IN1087' already exists and is not an empty directory.


In [None]:
!git clone https://github.com/princeton-nlp/webshop.git webshop

fatal: destination path 'webshop' already exists and is not an empty directory.


In [None]:
# Overwrite WebShop's requirements.txt with the IN1087 version before setup.sh runs.
!cp -f IN1087/webshop/requirements.txt webshop

In [None]:
# Bootstrap conda inside Colab (presumably needed for the `conda install` cell below).
# NOTE(review): condacolab.install() is expected to restart the runtime, which
# would discard every class defined in earlier cells — verify, and if so move
# the setup cells to the top of the notebook.
!pip install -q condacolab
import condacolab
condacolab.install()

[0m✨🍰✨ Everything looks OK!


In [None]:
# Enter the WebShop checkout (the working directory stays changed for all later
# cells) and run its setup script; `-d small` presumably selects the small
# dataset — see setup.sh.
%cd webshop
!./setup.sh -d small

/content/webshop
Collecting beautifulsoup4==4.11.1
  Downloading beautifulsoup4-4.11.1-py3-none-any.whl (128 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m128.2/128.2 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cleantext==1.1.4
  Downloading cleantext-1.1.4-py3-none-any.whl (4.9 kB)
Collecting env==0.1.0
  Downloading env-0.1.0.tar.gz (1.8 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting Flask==2.1.2
  Downloading Flask-2.1.2-py3-none-any.whl (95 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.2/95.2 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gdown
  Downloading gdown-4.7.1-py3-none-any.whl (15 kB)
Collecting gradio
  Downloading gradio-4.7.1-py3-none-any.whl (16.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m83.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gym==0.24.0
  Downloading gym-0.24.0.tar.gz (694 kB)
[2K     [90m━━━━━━━━━━━━━

In [None]:
!conda install mkl=2021

\ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | 

In [None]:
import gym
# The name WebAgentTextEnv is not used below; presumably the import registers
# 'WebAgentTextEnv-v0' with gym as a side effect — TODO confirm.
from web_agent_site.envs import WebAgentTextEnv

# Text-observation WebShop environment over 1000 products.
# NOTE(review): `env` is not used by any later visible cell; the
# collision-avoidance cells build their own pomdp_py environment.
env = gym.make('WebAgentTextEnv-v0', observation_mode='text', num_products=1000)

100%|██████████| 1000/1000 [00:00<00:00, 40110.01it/s]


Loaded 6910 goals.


You can set `disable_env_checker=True` to disable this check.[0m
  logger.warn(


In [None]:
class ObservationModel(pomdp_py.ObservationModel):
    """Observation model for the pilot-response flag of a collision-avoidance state.

    ``sample`` copies the flight fields of ``next_state`` into the observation.
    It draws ``is_pilot_responding`` as True with probability
    ``_response_threshold(next_state, action)``. ``probability`` returns the
    probability of exactly that draw, so both methods describe one distribution.

    NOTE(review): ``sample`` builds ``Observation`` from six collision-avoidance
    fields, but the ``Observation`` class defined in this notebook takes a single
    ``screen`` argument; this model needs a matching Observation class.
    """

    def _response_threshold(self, state, action):
        """Return P(pilot observed responding | state, action); shared by sample/probability."""
        if action.name == "COC":
            return 1.0
        if state.is_pilot_responding and state.current_alert == action.name:
            return 1.0
        if state.current_alert == "COC" and action.name != "COC":
            return 1 / 6
        if self.is_opposite_sense(state, action):
            return 1 / 6
        if self.is_same_sense(state, action):
            return 1 / 4
        return 1.0

    def probability(self, observation, next_state, action):
        """Return P(observation | next_state, action) for the pilot-response outcome.

        The original returned the threshold for both the responding and the
        not-responding outcome, and evaluated it on the observation rather than
        ``next_state``. As a result it neither summed to 1 nor matched ``sample``.
        """
        thresh = self._response_threshold(next_state, action)
        return thresh if observation.is_pilot_responding else 1.0 - thresh

    def is_opposite_sense(self, observation, action):
        return observation.vertical_speed < 0 and "C" in action.name or \
                observation.vertical_speed > 0 and "C" not in action.name

    def is_same_sense(self, observation, action):
        return observation.vertical_speed < 0 and "C" not in action.name or \
                observation.vertical_speed > 0 and "C" in action.name

    def sample(self, next_state, action):
        """Draw an observation mirroring ``next_state`` with a sampled pilot-response flag."""
        is_responding = bool(np.random.uniform(0, 1) < self._response_threshold(next_state, action))
        return Observation(next_state.relative_altitude,
                           next_state.vertical_speed,
                           next_state.intruder_vertical_speed,
                           next_state.collision_time,
                           next_state.current_alert,
                           is_pilot_responding=is_responding)

In [None]:
class TransitionModel(pomdp_py.TransitionModel):
    """Stochastic vertical-encounter dynamics.

    When ``collision_time`` reaches 0, a fresh encounter is drawn uniformly from
    fixed grids. Otherwise the pilot's and the intruder's vertical speeds are
    perturbed by Gaussian accelerations, and ``collision_time`` counts down by one.

    NOTE(review): this model builds ``State`` objects from collision-avoidance
    fields, but the ``State`` class defined in this notebook only accepts
    ``exp_mem``; it needs a matching State class.
    """

    # Sampling grids for a new encounter (built once instead of on every reset).
    _ALTITUDE_GRID = list(range(-40000, 40001, 33))
    _SPEED_GRID = list(range(-10000, 10001, 25))
    _TIME_GRID = list(range(0, 41))

    def probability(self, next_state, state, action):
        # Placeholder: every transition gets the same near-1 probability,
        # independent of the arguments.
        return 1.0 - 1e-9

    def _sample_new_encounter(self):
        """Draw a new encounter with alert "COC" and a responding pilot."""
        init_relative_altitude = np.random.choice(self._ALTITUDE_GRID)
        init_vertical_speed = np.random.choice(self._SPEED_GRID)
        init_intruder_vertical_speed = np.random.choice(self._SPEED_GRID)
        init_collision_time = np.random.choice(self._TIME_GRID)
        return State(init_relative_altitude,
                     init_vertical_speed,
                     init_intruder_vertical_speed,
                     init_collision_time,
                     "COC",
                     is_pilot_responding=True)

    def sample(self, state, action):
        """Sample the next state given ``state`` and ``action``.

        At ``collision_time == 0`` a new encounter is drawn. Otherwise one step
        advances relative altitude, both vertical speeds, and the remaining
        time, and the chosen action becomes the new current alert.
        """
        if state.collision_time == 0:
            return self._sample_new_encounter()

        # Pilot acceleration: N(0, 1) while responding, N(0, 3) otherwise.
        # The original had two identical N(0, 1) branches (opposite/same sense).
        # It left `acceleration` unbound (UnboundLocalError) when
        # vertical_speed == 0 matched neither branch.
        acceleration = np.random.normal(loc=0, scale=1 if state.is_pilot_responding else 3)
        intruder_acceleration = np.random.normal(loc=0, scale=3)

        next_relative_altitude = (state.relative_altitude
                                  + state.intruder_vertical_speed
                                  + (1/2) * intruder_acceleration
                                  - state.vertical_speed
                                  - (1/2) * acceleration)
        return State(
            relative_altitude=next_relative_altitude,
            vertical_speed=state.vertical_speed + acceleration,
            intruder_vertical_speed=state.intruder_vertical_speed + intruder_acceleration,
            collision_time=state.collision_time - 1,
            current_alert=action.name,
            is_pilot_responding=state.is_pilot_responding
        )

    def is_opposite_sense(self, state, action):
        return state.vertical_speed < 0 and "C" in action.name or \
                state.vertical_speed > 0 and "C" not in action.name

    def is_same_sense(self, state, action):
        return state.vertical_speed < 0 and "C" not in action.name or \
                state.vertical_speed > 0 and "C" in action.name

In [None]:
class PolicyModel(pomdp_py.RolloutPolicy):
    """A simple policy model with uniform prior over a
       small, finite action space"""
    ACTIONS = {
        Action(s) for s in {"COC", "DNC2000", "DND2000", "DNC1000", "DND1000",
                            "DNC500", "DND500", "DNC", "DND", "MDES", "MCL",
                            "DES1500", "CL1500", "SDES1500", "SCL1500",
                            "SDES2500", "SCL2500"
        }
    }

    def sample(self, state):
        """Return an action drawn uniformly at random.

        ``ACTIONS`` is a set of string-named actions. Its iteration order depends
        on str hashing, which Python randomises per interpreter run. Sorting by
        name first makes the draw reproducible under ``np.random.seed``.
        """
        actions = sorted(self.get_all_actions(), key=lambda a: a.name)
        return actions[np.random.randint(len(actions))]

    def rollout(self, state, *args):
        """Treating this PolicyModel as a rollout policy"""
        return self.sample(state)

    def get_all_actions(self, state=None, history=None):
        return PolicyModel.ACTIONS

In [None]:
class RewardModel(pomdp_py.RewardModel):
    """Deterministic reward for collision-avoidance advisories.

    ``_reward_func`` adds up the penalty terms it lists, plus a tiny bonus
    for COC.

    Advisory names are parsed as <letters><optional rate>, e.g. 'DNC2000' ->
    ('DNC', 2000). Names whose letters end in 'C'/'CL' count as up-sense, and
    those ending in 'D'/'DES' as down-sense. Note that 'COC' and 'DNC' therefore
    count as up-sense.
    """

    def _reward_func(self, state, action):
        """Return the reward for taking ``action`` in ``state``."""
        reward = 0
        closure = state.intruder_vertical_speed - state.vertical_speed
        preventive = self.is_preventive(state, action)
        crossing = self.is_crossing(state, action)
        # NOTE(review): no advisory name contains both "DND" and "DNC", so this
        # is always True; it was probably meant to use `and` ("not a DNC/DND
        # advisory"). Kept as-is to preserve the current reward values.
        not_do_not_advisory = "DND" not in action.name or "DNC" not in action.name

        if state.relative_altitude <= 175 and state.collision_time <= 0:
            reward += -1
        if state.current_alert == action.name and state.vertical_speed < 1500:
            reward += -1
        if preventive and crossing:
            reward += -1
        if state.relative_altitude > 650 and closure < 2000 and not preventive:
            reward += -0.1
        if state.relative_altitude > 1000 and closure < 4000 and not preventive:
            reward += -3e-2
        if state.relative_altitude > 650 and closure < 2000 and preventive:
            reward += -1e-2
        if state.relative_altitude > 500 and crossing:
            reward += -1e-2
        if self.is_reversal(state, action):
            reward += -8e-3
        if self.is_strengthening(state, action):
            reward += -5e-3
        if self.is_weakening(state, action):
            reward += -1e-3
        if not_do_not_advisory and closure > 3000:
            reward += -1.5e-3
        if action.name != "COC" and closure < 3000:
            reward += -2.3e-3
        # NOTE(review): same condition as the -1.5e-3 term above.
        if not_do_not_advisory and closure > 3000:
            reward += -5e-4
        if action.name == "COC":
            reward += 1e-9
        return reward

    @staticmethod
    def _split_advisory(name):
        """Split an advisory name into (letters, rate); rate is 0 when no digits, e.g. 'DNC2000' -> ('DNC', 2000)."""
        letters = ''.join(ch for ch in name if not ch.isdigit())
        digits = ''.join(ch for ch in name if ch.isdigit())
        return letters, (int(digits) if digits else 0)

    @staticmethod
    def _is_up_alert(letters):
        """Up-sense advisory letters end in 'C' or 'CL'."""
        return letters.endswith('C') or letters.endswith('CL')

    @staticmethod
    def _is_down_alert(letters):
        """Down-sense advisory letters end in 'D' or 'DES'."""
        return letters.endswith('D') or letters.endswith('DES')

    def is_up_action(self, action):
        return self._is_up_alert(self._split_advisory(action.name)[0])

    def is_down_action(self, action):
        return self._is_down_alert(self._split_advisory(action.name)[0])

    def is_crossing(self, state, action):
        """True when the advisory's sense points toward the intruder's side of zero relative altitude."""
        return (state.relative_altitude < 0 and self.is_down_action(action)) or \
               (state.relative_altitude > 0 and self.is_up_action(action))

    def is_preventive(self, state, action):
        # NOTE(review): rated down-sense advisories (e.g. DND2000, DES1500) are
        # compared against +rate here; RewardModel2 compares against -rate.
        if any(ch.isdigit() for ch in action.name):
            _, rate = self._split_advisory(action.name)
            if rate and self.is_up_action(action):
                return state.vertical_speed < rate
            return state.vertical_speed > rate
        if action.name == 'DNC':
            return state.vertical_speed < 0
        if action.name == 'DND':
            return state.vertical_speed > 0
        return True

    def is_reversal(self, state, action):
        """True when the advisory switches sense relative to the current alert.

        Parenthesised explicitly: the original `A or B and up` parsed as
        `A or (B and up)`, so almost every alert counted as a reversal.
        """
        state_letters, _ = self._split_advisory(state.current_alert)
        return (self._is_down_alert(state_letters) and self.is_up_action(action)) or \
               (self._is_up_alert(state_letters) and self.is_down_action(action))

    def _same_sense(self, state, action):
        """True when the current alert and the advisory have the same sense."""
        state_letters, _ = self._split_advisory(state.current_alert)
        return (self._is_down_alert(state_letters) and self.is_down_action(action)) or \
               (self._is_up_alert(state_letters) and self.is_up_action(action))

    def is_strengthening(self, state, action):
        """Same sense with a higher rate. Rates are compared as integers; the
        original compared digit strings, so '500' > '1500'."""
        _, state_rate = self._split_advisory(state.current_alert)
        _, action_rate = self._split_advisory(action.name)
        return self._same_sense(state, action) and state_rate < action_rate

    def is_weakening(self, state, action):
        """Same sense with a lower rate (integer comparison)."""
        _, state_rate = self._split_advisory(state.current_alert)
        _, action_rate = self._split_advisory(action.name)
        return self._same_sense(state, action) and state_rate > action_rate

    def sample(self, state, action, next_state):
        # deterministic
        return self._reward_func(state, action)

In [None]:
# Advisory name -> (min, max) vertical speed bounds, read by RewardModel2's
# delta_h and is_in_range (which compare state.vertical_speed against them).
# Unbounded sides use +/-inf. MDES and MCL have no entry: delta_h derives
# their bound from the current vertical speed.
ranges = {"COC":(float('-inf'), float('inf')),
          "DNC2000": (float('-inf'), 2000),
          "DND2000": (-2000, float('inf')),
          "DNC1000": (float('-inf'), 1000),
          "DND1000": (-1000, float('inf')),
          "DNC500": (float('-inf'), 500),
          "DND500": (-500, float('inf')),
          "DNC": (float('-inf'), 0),
          "DND": (0, float('inf')),
          "DES1500": (float('-inf'), -1500),
          "CL1500": (1500, float('inf')),
          "SDES1500": (float('-inf'), -1500),
          "SCL1500": (1500, float('inf')),
          "SDES2500": (float('-inf'), -2500),
          "SCL2500": (2500, float('inf'))}

# Experimental reward model (originally marked "TESTE MONIQUE", i.e. Monique's test).
class RewardModel2(pomdp_py.RewardModel):
    """Experimental reward model: RewardModel's penalty terms plus rate-change costs.

    Adds the following to RewardModel's terms:
    - costs proportional to ``delta_h``, the distance from the current vertical
      speed to the advisory's bound in the module-level ``ranges``;
    - a penalty for MDES/MCL;
    - a small penalty for non-preventive advisories.

    ``is_preventive`` compares rated down-sense advisories against ``-rate``.
    """

    def _reward_func(self, state, action):
        """Return the reward for taking ``action`` in ``state``."""
        reward = 0
        closure = state.intruder_vertical_speed - state.vertical_speed
        preventive = self.is_preventive(state, action)
        crossing = self.is_crossing(state, action)
        up = self.is_up_action(action)
        down = self.is_down_action(action)
        # NOTE(review): always True (no name contains both substrings); probably
        # meant `and`. Kept as-is to preserve the current reward values.
        not_do_not_advisory = "DND" not in action.name or "DNC" not in action.name

        if state.relative_altitude <= 175 and state.collision_time <= 0:
            reward += -1
        if state.current_alert == action.name and state.vertical_speed < 1500:
            reward += -1
        # TODO: prohibited advisory transitions are not penalised yet.
        if preventive and crossing:
            reward += -1
        if state.relative_altitude > 650 and closure < 2000 and not preventive:
            reward += -0.1
        if state.relative_altitude > 1000 and closure < 4000 and not preventive:
            reward += -3e-2
        if state.relative_altitude > 650 and closure < 2000 and preventive:
            reward += -1e-2
        if state.relative_altitude > 500 and crossing:
            reward += -1e-2
        if self.is_reversal(state, action):
            reward += -8e-3
        if self.is_strengthening(state, action):
            reward += -5e-3
        if self.is_weakening(state, action):
            reward += -1e-3
        if not_do_not_advisory and closure > 3000:
            reward += -1.5e-3
        if action.name != "COC" and closure < 3000:
            reward += -2.3e-3
        if not_do_not_advisory and closure > 3000:
            reward += -5e-4

        # delta_h cost for crossing advisories that oppose a fast (> 500)
        # current vertical speed.
        if crossing and abs(state.vertical_speed) > 500 and \
                ((up and state.vertical_speed < 0) or (down and state.vertical_speed > 0)):
            reward += (-4e-4) * self.delta_h(state, action)
        if action.name in ('MDES', 'MCL'):
            reward += -4e-4
        if not_do_not_advisory:
            reward += -1e-4
        if action.name != "COC":
            reward += (-3e-5) * self.delta_h(state, action)
        if not preventive:
            reward += -1e-5
        if action.name == "COC":
            reward += 1e-9
        return reward

    def delta_h(self, state, action):
        """Distance from ``state.vertical_speed`` to the advisory's nearest finite bound.

        Bounds come from ``ranges``. MDES/MCL use the current vertical speed as
        their upper/lower bound, so their delta_h is 0. When neither bound is
        finite (e.g. COC) the result is 0.0. The original returned inf here,
        which turned the crossing penalty into a -inf reward.
        """
        hmin, hmax = float('-inf'), float('inf')
        if action.name in ranges:
            hmin, hmax = ranges[action.name]
        elif action.name == 'MDES':
            hmax = state.vertical_speed
        elif action.name == 'MCL':
            hmin = state.vertical_speed
        distances = [abs(bound - state.vertical_speed)
                     for bound in (hmin, hmax) if np.isfinite(bound)]
        return min(distances, default=0.0)

    def is_in_range(self, state, action):
        """True when vertical_speed lies within the advisory's ``ranges`` bounds (inclusive).

        Accepts an Action or a plain advisory name; advisories without an entry
        are always in range.
        """
        name = getattr(action, 'name', action)
        if name in ranges:
            low, high = ranges[name]
            return low <= state.vertical_speed <= high
        return True

    @staticmethod
    def _split_advisory(name):
        """Split an advisory name into (letters, rate); rate is 0 when no digits, e.g. 'DNC2000' -> ('DNC', 2000)."""
        letters = ''.join(ch for ch in name if not ch.isdigit())
        digits = ''.join(ch for ch in name if ch.isdigit())
        return letters, (int(digits) if digits else 0)

    @staticmethod
    def _is_up_alert(letters):
        """Up-sense advisory letters end in 'C' or 'CL' (so COC and DNC count as up)."""
        return letters.endswith('C') or letters.endswith('CL')

    @staticmethod
    def _is_down_alert(letters):
        """Down-sense advisory letters end in 'D' or 'DES'."""
        return letters.endswith('D') or letters.endswith('DES')

    def is_up_action(self, action):
        return self._is_up_alert(self._split_advisory(action.name)[0])

    def is_down_action(self, action):
        return self._is_down_alert(self._split_advisory(action.name)[0])

    def is_crossing(self, state, action):
        """True when the advisory's sense points toward the intruder's side of zero relative altitude."""
        return (state.relative_altitude < 0 and self.is_down_action(action)) or \
               (state.relative_altitude > 0 and self.is_up_action(action))

    def is_preventive(self, state, action):
        """True when the current vertical speed already satisfies the advisory (see ``ranges``)."""
        if any(ch.isdigit() for ch in action.name):
            _, rate = self._split_advisory(action.name)
            up = bool(rate) and self.is_up_action(action)
            if action.name.startswith('DN'):
                # Do-not advisories: DNC caps the rate at +rate, DND floors it at -rate.
                return state.vertical_speed < rate if up else state.vertical_speed > -rate
            # Climb/descend advisories (CL/SCL, DES/SDES): above +rate or below -rate.
            # The original tested startswith('D'), which sent DES1500 through
            # the do-not branch and inverted it (vs > -1500 instead of < -1500).
            return state.vertical_speed > rate if up else state.vertical_speed < -rate
        if action.name == 'DNC':
            return state.vertical_speed < 0
        if action.name == 'DND':
            return state.vertical_speed > 0
        return True

    def is_reversal(self, state, action):
        """True when the advisory switches sense relative to the current alert.

        Parenthesised explicitly: the original `A or B and up` parsed as
        `A or (B and up)`, so almost every alert counted as a reversal.
        """
        state_letters, _ = self._split_advisory(state.current_alert)
        return (self._is_down_alert(state_letters) and self.is_up_action(action)) or \
               (self._is_up_alert(state_letters) and self.is_down_action(action))

    def _same_sense(self, state, action):
        """True when the current alert and the advisory have the same sense."""
        state_letters, _ = self._split_advisory(state.current_alert)
        return (self._is_down_alert(state_letters) and self.is_down_action(action)) or \
               (self._is_up_alert(state_letters) and self.is_up_action(action))

    def is_strengthening(self, state, action):
        """Same sense with a higher rate. Rates are compared as integers; the
        original compared digit strings, so '500' > '1500'."""
        _, state_rate = self._split_advisory(state.current_alert)
        _, action_rate = self._split_advisory(action.name)
        return self._same_sense(state, action) and state_rate < action_rate

    def is_weakening(self, state, action):
        """Same sense with a lower rate (integer comparison)."""
        _, state_rate = self._split_advisory(state.current_alert)
        _, action_rate = self._split_advisory(action.name)
        return self._same_sense(state, action) and state_rate > action_rate

    def sample(self, state, action, next_state):
        # deterministic
        return self._reward_func(state, action)

In [None]:
class CollisionAvoidanceProblem(pomdp_py.POMDP):
    """Collision-avoidance POMDP built from this notebook's policy, transition,
    observation and reward models.

    Args:
        init_true_state: Initial environment state.
        init_belief: The agent's initial belief (a pomdp_py Distribution).
        agent_reward_model: Reward model the agent plans with; defaults to
            ``RewardModel2()``.
        env_reward_model: Reward model the environment reports; defaults to
            ``RewardModel()``.
    """

    def __init__(self, init_true_state, init_belief,
                 agent_reward_model=None, env_reward_model=None):
        # NOTE(review): by default the agent plans with RewardModel2 while the
        # environment scores with RewardModel, so printed rewards come from a
        # different function than the one being optimised.
        if agent_reward_model is None:
            agent_reward_model = RewardModel2()
        if env_reward_model is None:
            env_reward_model = RewardModel()
        agent = pomdp_py.Agent(init_belief,
                               PolicyModel(),
                               TransitionModel(),
                               ObservationModel(),
                               agent_reward_model)
        env = pomdp_py.Environment(init_true_state,
                                   TransitionModel(),
                                   env_reward_model)
        super().__init__(agent, env, name="AvoidanceCollisionProblem")

In [None]:
# Draw a random initial encounter on fixed grids, then build the initial true
# state and a point-mass belief (probability 1.0) on a state with the same values.
# NOTE(review): the `State` class defined earlier in this notebook takes a
# single `exp_mem` argument, but this cell passes six collision-avoidance
# fields. It only runs against a different State definition, so Restart & Run
# All fails here with a TypeError.
# NOTE(review): no random seed is set, so each run draws a different encounter.
init_relative_altitude = np.random.choice([i for i in range(-40000, 40001, 33)])
init_vertical_speed = np.random.choice([i for i in range(-10000, 10001, 25)])
init_intruder_vertical_speed = np.random.choice([i for i in range(-10000, 10001, 25)])
init_collision_time = np.random.choice([i for i in range(0, 41)])
# A random advisory name is used as the encounter's current alert.
init_action = np.random.choice([action.name for action in PolicyModel.ACTIONS])
init_is_pilot_responding = np.random.choice([True, False])
init_true_state = State(init_relative_altitude, init_vertical_speed,
                        init_intruder_vertical_speed, init_collision_time,
                        init_action, is_pilot_responding=init_is_pilot_responding)
init_belief = pomdp_py.Histogram(
    {
        State(
            init_relative_altitude,
            init_vertical_speed,
            init_intruder_vertical_speed,
            init_collision_time,
            init_action,
            init_is_pilot_responding
            ) : 1.0
    }
)

In [None]:
# Build the POMDP (agent + environment) from the initial true state and belief.
collision_avoidance_problem = CollisionAvoidanceProblem(init_true_state, init_belief)

In [None]:
# Step 1; in main()
# creating planners
# NOTE(review): only `pouct` is used below; `vi` and `pomcp` are never run.
# pomdp_py's POMCP is designed for particle beliefs — verify before using it
# with the Histogram belief built above.
vi = pomdp_py.ValueIteration(horizon=3, discount_factor=0.95)
# Both tree-search planners use depth 3, 0.5 s of planning per step, and an
# exploration constant of 110 (TODO: document how 110 was chosen).
pouct = pomdp_py.POUCT(max_depth=3, discount_factor=0.95,
                       planning_time=.5, exploration_const=110,
                       rollout_policy=collision_avoidance_problem.agent.policy_model)
pomcp = pomdp_py.POMCP(max_depth=3, discount_factor=0.95,
                       planning_time=.5, exploration_const=110,
                       rollout_policy=collision_avoidance_problem.agent.policy_model)
...  # call test_planner() for steps 2-6.

# Steps 2-6; called in main()
def test_planner(collision_avoidance_problem, planner, nsteps=3):
    """Run ``nsteps`` plan -> act -> observe -> update iterations, printing each step.

    Each iteration does the following:
    - plans an action and executes it in the environment, printing the reward;
    - builds the observation from the resulting true state;
    - updates the agent's history and the planner;
    - for Histogram beliefs, recomputes the belief with
      ``pomdp_py.update_histogram_belief``.

    NOTE(review): the observation is built from six collision-avoidance fields,
    which the ``Observation`` class defined in this notebook does not accept.
    """
    agent = collision_avoidance_problem.agent
    env = collision_avoidance_problem.env
    for step in range(nsteps):  # Step 6: repeat
        # Step 2: plan
        action = planner.plan(agent)

        print("==== Step %d ====" % (step + 1))
        print("True state:", env.state)
        print("Belief:", agent.cur_belief)
        print("Action:", action)

        # Step 3: execute the action in the environment
        reward = env.state_transition(action, execute=True)
        print("Reward:", reward)

        # Step 4: observe (read the state only after the transition)
        true_state = env.state
        real_observation = Observation(true_state.relative_altitude,
                                       true_state.vertical_speed,
                                       true_state.intruder_vertical_speed,
                                       true_state.collision_time,
                                       true_state.current_alert,
                                       true_state.is_pilot_responding)
        print(">> Observation: \n%s" % real_observation)

        # Step 5: update history, planner and belief. For POMCP,
        # planner.update also updates the agent belief itself.
        agent.update_history(action, real_observation)
        planner.update(agent, action, real_observation)
        if isinstance(planner, pomdp_py.POUCT):
            print("Num sims: %d" % planner.last_num_sims)

        belief = agent.cur_belief
        if isinstance(belief, pomdp_py.Histogram):
            agent.set_belief(pomdp_py.update_histogram_belief(
                belief, action, real_observation,
                agent.observation_model, agent.transition_model))

In [None]:
# Run POUCT for as many steps as the encounter's initial time to collision
# (no steps run if init_collision_time is 0).
# NOTE(review): in the recorded output below, step 1's true state and belief
# differ (e.g. collision time 0 vs 4), although the init cell builds both from
# the same values. The output therefore appears to come from a different
# kernel state — re-run from a fresh kernel.
test_planner(collision_avoidance_problem=collision_avoidance_problem, planner=pouct, nsteps=init_collision_time)

==== Step 1 ====
True state: Relative altitude: 32621.64995894356,
        Vertical speed: 6123.722177184522,
        Intruder vertical speed: 6021.19146336659,
        Collision time: 0,
        Current alert: MDES,
        Pilot responding: False

Belief: {Relative altitude: 33029,
        Vertical speed: 6125,
        Intruder vertical speed: 6025,
        Collision time: 4,
        Current alert: SDES1500,
        Pilot responding: False
: 1.0}
Action: MDES
Reward: -0.012299999594688416
>> Observation: 
Relative altitude: 35801,
        Vertical speed: 5975,
        Intruder vertical speed: -6700,
        Collision time: 8,
        Current alert: COC,
        Pilot responding: True

Num sims: 856
==== Step 2 ====
True state: Relative altitude: 35801,
        Vertical speed: 5975,
        Intruder vertical speed: -6700,
        Collision time: 8,
        Current alert: COC,
        Pilot responding: True

Belief: {Relative altitude: 33029,
        Vertical speed: 6125,
        Intru