In [None]:
!pip install pomdp-py

Collecting pomdp-py
  Downloading pomdp_py-1.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pomdp-py
Successfully installed pomdp-py-1.3.3


In [None]:
# Colab-only: mount the user's Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## General setup

In [None]:
import pomdp_py
import numpy as np
import random

In [None]:
from collections import defaultdict

def _zero_entry():
    """Return the default table entry: a zero value paired with a zero count."""
    return (0.0, 0)


def tuple_default():
    """Create an empty mapping whose missing keys default to ``(0.0, 0)``.

    Returns:
        A ``defaultdict`` that yields the tuple ``(0.0, 0)`` (a float and an
        int) the first time an unknown key is read.
    """
    # defaultdict requires a zero-argument callable (or None) as its factory.
    # Passing the tuple ``(float, int)`` directly raises TypeError on every
    # call. A named function is used instead of a lambda so the resulting
    # mapping can still be pickled.
    return defaultdict(_zero_entry)

In [None]:
import random

class Q_Learning_Table:
  """Tabular Q-values keyed by ``(goal, observation)`` and then by action.

  Every stored entry is a ``(value, count)`` tuple: the current estimate and
  the number of updates that have been folded into it.
  """

  def __init__(self, gamma, n, n_actions) -> None:
    """Create an empty table.

    Args:
      gamma: Discount factor applied to the best next-step value.
      n: Stored as ``self.n``; not read by this class.
      n_actions: Stored as ``self.n_actions``; not read by this class.
    """
    self.q_tab = defaultdict(tuple_default)
    self.gamma = gamma
    self.n = n
    self.n_actions = n_actions

  def update(self, goal, o_t, a_t, r_t, new_o):
    """Fold one observed transition into the estimate for ``(goal, o_t, a_t)``.

    The target is ``r_t + gamma * max_a Q(goal, new_o, a)``; the stored value
    moves toward it with step size ``1 / count`` (a running average).

    Args:
      goal: Goal the transition was collected under (part of the row key).
      o_t: Observation before the action.
      a_t: Action taken.
      r_t: Reward received.
      new_o: Observation after the action.
    """
    # TODO: read about flattening policy
    # Read-only lookup: .get() avoids inserting an empty row for new_o.
    next_actions = self.q_tab.get((goal, new_o), {})
    # With nothing recorded for new_o, bootstrap from 0.0 rather than -inf,
    # which would otherwise poison the target (and every later average).
    best_next = max((value for value, _ in next_actions.values()), default=0.0)

    target = r_t + self.gamma * best_next

    row = self.q_tab[(goal, o_t)]
    # An unseen action starts from (0.0, 0), so its first update has
    # alpha == 1 and stores the target itself.
    value, count = row.get(a_t, (0.0, 0))
    count += 1
    alpha = 1 / count
    # Always store a (value, count) tuple so later reads can unpack it.
    row[a_t] = ((1 - alpha) * value + alpha * target, count)


class ExperienceMemory:
  """Ordered log of transitions that forwards each new one to a Q-table.

  Equality and hashing are defined purely by the recorded experiences.
  """

  def __init__(self, state_value_tab : "Q_Learning_Table") -> None:
    """Start with an empty log.

    Args:
      state_value_tab: Object whose ``update(goal, o_t, a_t, r_t, new_o)`` is
        called for every experience that is added.
    """
    self.state_value_tab = state_value_tab
    self.experiences = []

  def add_experience(self, goal, o_t, a_t, r_t, new_o):
    """Append the transition to the log and update the value table with it."""
    self.experiences.append((goal, o_t, a_t, r_t, new_o))
    self.state_value_tab.update(goal, o_t, a_t, r_t, new_o)

  def __hash__(self):
    # Lists are unhashable, so hash an immutable snapshot instead.
    # Caveat: the hash changes whenever an experience is added, so an
    # instance must not be mutated while it is a dict key or set member.
    return hash(tuple(self.experiences))

  def __eq__(self, other):
    # Compare against the same class; anything else is never equal.
    if isinstance(other, ExperienceMemory):
      return self.experiences == other.experiences
    return False

  def __str__(self):
    return f"""Experience memory: {self.experiences}\n"""

  def __repr__(self):
    return f"""Experience memory: {self.experiences}\n"""


In [None]:
class State(pomdp_py.State):
    """State that wraps an experience memory and delegates identity to it."""

    def __init__(self, exp_mem : ExperienceMemory):
        self.exp_mem = exp_mem

    def __hash__(self):
        # The hash is exactly the hash of the wrapped memory.
        wrapped = self.exp_mem
        return hash(wrapped)

    def __eq__(self, other):
        # Non-State operands compare False; otherwise compare the memories.
        return isinstance(other, State) and self.exp_mem == other.exp_mem

    def __str__(self):
        memory = self.exp_mem
        return f"Experience memory: {memory}\n"

    def __repr__(self):
        memory = self.exp_mem
        return f"Experience memory: {memory}\n"

In [None]:
class Action(pomdp_py.Action):
    """Simple named action.

    Two actions are equal when their names match. An action also compares
    equal to a plain string holding the same name.
    """
    def __init__(self, name):
        self.name = name
    def __hash__(self):
        # Same hash as the bare name, consistent with equality against str.
        return hash(self.name)
    def __eq__(self, other):
        if isinstance(other, Action):
            return self.name == other.name
        if isinstance(other, str):
            return self.name == other
        # Previously this fell through and returned None for any other type;
        # always answer with a real boolean.
        return False
    def __str__(self):
        return self.name
    def __repr__(self):
        return "Action(%s)" % self.name

In [None]:
class Observation(pomdp_py.Observation):
    """Observation identified by its ``screen`` string."""

    def __init__(self, screen: str):
        self.screen = screen

    def __hash__(self):
        return hash(self.screen)

    def __eq__(self, other):
        # Guard the attribute access: comparing against None, a str, or any
        # other object without ``.screen`` used to raise AttributeError.
        if isinstance(other, Observation):
            return self.screen == other.screen
        return False

    def __str__(self):
        return f"""Screen: {self.screen}\n"""

    def __repr__(self):
        return f"""Screen: {self.screen}\n"""

In [None]:
class ObservationModel(pomdp_py.ObservationModel):
  """Observation model that reads observations from the state's experience log."""

  def probability(self, observation, next_state, action):
    # We assume the observation (here, the screen description) is faithful,
    # so it is not a source of uncertainty.
    return 1.0

  def sample(self, next_state: State, action):
    """Return element 1 of the most recent experience stored in ``next_state``."""
    # NOTE(review): ExperienceMemory appends (goal, o_t, a_t, r_t, new_o), so
    # element 1 is o_t, the observation *before* the action. Confirm whether
    # new_o (element 4) was intended here.
    history = next_state.exp_mem.experiences
    return history[-1][1]

In [None]:
class TransitionModel(pomdp_py.TransitionModel):
  """Placeholder transition model; sampling is not implemented yet."""

  def probability(self, next_state, state, action):
    # Constant just below 1, independent of the arguments.
    almost_certain = 1.0 - 1e-9
    return almost_certain

  def sample(self, state, action):
    # TODO: here we will need the env to execute the action in the new state
    return None
