
In [None]:
!pip install numpyro

Collecting numpyro
  Downloading numpyro-0.14.0-py3-none-any.whl (330 kB)
Installing collected packages: numpyro
Successfully installed numpyro-0.14.0


In [None]:
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import sys
import os
import time
import copy
import random
from itertools import product

numpyro.set_host_device_count(4)

# Tiger POMDP (Partially Observable Markov Decision Process)

The Tiger POMDP is a classical problem in the field of artificial intelligence and decision-making under uncertainty. It's used to illustrate the challenges of decision-making when there's uncertainty about the state of the environment.

In the Tiger POMDP scenario, an agent is placed in a room with two doors. Behind one door is freedom, and behind the other is a tiger. The agent doesn't know which door leads to which outcome. It can "listen" to hear a sound indicating the tiger's location, or "open" a door to reveal what is behind it. However, listening is imperfect: the sound sometimes points to the wrong door, so the agent can never be fully certain where the tiger is.

Solving the Tiger POMDP means finding a policy that maximizes the expected cumulative (discounted) reward while accounting for the uncertainty and partial observability. Various approaches, such as planning over belief states or approximating the belief with a particle filter, can be used to solve POMDPs exactly or approximately.

Formally, the Tiger problem is a POMDP defined by the tuple $(S, A, T, R, Ω, O, γ)$, where:
$S$: The finite set of states has two elements, $S = \{s_{tl}, s_{tr}\}$: the tiger is behind the left door ("tiger-left") or behind the right door ("tiger-right"), with freedom behind the other door.

$A$: The finite set of actions is $A = \{a_{listen}, a_{open\text{-}left}, a_{open\text{-}right}\}$: "listen" gathers information, while "open-left" and "open-right" commit to a decision.

$T$: The state transition function $T: S \times A \times S \to [0,1]$, $T(s, a, s') = P(s' \mid s, a)$, gives the probability of moving from state $s$ to state $s'$ under action $a$. "Listen" leaves the state unchanged, while opening a door resets the problem: the tiger is placed behind either door with probability 0.5.

$R$: The reward function $R: S \times A \to \mathbb{R}$ gives the immediate reward for taking an action in a given state. For instance, opening the door to freedom yields a positive reward, opening the door with the tiger yields a negative reward, and listening usually carries a small cost.

$Ω$: The finite set of observations has two elements, "tiger-left" and "tiger-right", the sounds the agent can hear when it chooses to "listen".

$O$: The observation function $O(s', a, o) = P(o \mid s', a)$ gives the probability of receiving observation $o$ after taking action $a$ and arriving in state $s'$.

$γ$: The discount factor $γ \in [0,1]$ accounts for the importance of future rewards relative to immediate rewards in the agent's decision-making process.
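Because the agent never sees the state directly, it acts on a belief $b(s)$, a probability distribution over $S$. After taking action $a$ and receiving observation $o$, the standard Bayesian belief update is $b'(s') \propto O(s', a, o) \sum_{s} T(s, a, s')\, b(s)$. Below is a minimal numpy sketch of that update, assuming states are indexed 0 = tiger-left and 1 = tiger-right and using the 0.85 listening accuracy from the next section; `belief_update` and the array names are illustrative and not part of the original notebook.

```python
import numpy as np

def belief_update(b, T_a, O_a, o):
    """One Bayesian belief update for a fixed action a.

    b   : current belief over states, shape (|S|,)
    T_a : transition matrix for action a, T_a[s, s'] = T(s, a, s')
    O_a : observation matrix for action a, O_a[s', o] = O(s', a, o)
    o   : index of the received observation
    """
    predicted = b @ T_a                   # sum_s T(s, a, s') * b(s)
    unnormalized = O_a[:, o] * predicted  # weight by the observation likelihood
    return unnormalized / unnormalized.sum()

# "listen": the state does not change, and the observation is correct w.p. 0.85
T_listen = np.eye(2)
O_listen = np.array([[0.85, 0.15],
                     [0.15, 0.85]])

b0 = np.array([0.5, 0.5])
b1 = belief_update(b0, T_listen, O_listen, o=0)  # heard "tiger-left"
b2 = belief_update(b1, T_listen, O_listen, o=0)  # heard "tiger-left" again
print(b1, b2)
```

Two consecutive "tiger-left" observations move the belief in tiger-left from 0.5 to 0.85 and then to $0.7225 / 0.745 \approx 0.9698$.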

## Generating "Synthetic" Observations.

Now let's create synthetic observations (data) for our model to learn from.

This synthetic data follows the rules of the original Tiger POMDP problem,
where the "listen" action returns the correct observation with probability 0.85.

Our observations will be a list of episodes, each of the form:

$b_{0}, (a_{0}, o_{0}), (a_{1}, o_{1}), \ldots, (a_{n}, o_{n})$

where $b_{0}$ is the initial belief state and each pair $(a_{t}, o_{t})$ consists of:

$a_{t}$ - the action taken at time step $t$

$o_{t}$ - the observation received at time step $t$

$\Omega = \{o_{tl}, o_{tr}\} $

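Where the observation function for the "listen" action is given below; "$o$ matches $s'$" means $o = o_{tl}$ when $s' = s_{tl}$, and $o = o_{tr}$ when $s' = s_{tr}$:

$$
O(s', a_{listen}, o) = P(o \mid s', a_{listen}) =
\begin{cases}
0.85 & \text{if } o \text{ matches } s' \\
0.15 & \text{otherwise}
\end{cases}
$$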

$O(s'_{tl},a_{listen},o_{tl}) = 0.85$

$O(s'_{tl},a_{listen},o_{tr}) = 0.15$

$O(s'_{tr},a_{listen},o_{tl}) = 0.15$

$O(s'_{tr},a_{listen},o_{tr}) = 0.85$


Here $s'$ is the state reached after taking the action.
For example, $a_{listen}$ keeps the agent in the same state, so if the tiger is really behind the left door, taking $a_{listen}$ results in $s'$ = "tiger-left"; the probability of observing $o_{tl}$ is then 0.85, and the probability of observing $o_{tr}$ is 0.15.

* When action $a$ is "open-left" or "open-right", the observations $o_{tl}$ and $o_{tr}$ are uniformly distributed, each with probability 0.5.

In short, a "listen" observation is correct with probability 0.85: when the state is $s_{tl}$ (the tiger is behind the left door), the agent receives the correct observation $o_{tl}$ with probability 0.85 and the incorrect observation $o_{tr}$ with probability 0.15.
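The 0.85 accuracy can be checked empirically by sampling many "listen" observations for a fixed tiger location. A small numpy sketch, assuming the tiger is behind the left door and a fixed random seed; `sample_listen_obs` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

Observations = ["tiger-left", "tiger-right"]
# O(s' = tiger-left, listen, o), copied from the table above
listen_probs_tiger_left = [0.85, 0.15]

def sample_listen_obs(rng, probs, n):
    """Draw n independent 'listen' observations for a fixed tiger location."""
    return rng.choice(Observations, size=n, p=probs)

rng = np.random.default_rng(0)
samples = sample_listen_obs(rng, listen_probs_tiger_left, n=10_000)
accuracy = np.mean(samples == "tiger-left")   # fraction of correct observations
print(f"empirical accuracy: {accuracy:.3f}")  # should be close to 0.85
```

With 10,000 draws the standard error of the estimate is about $\sqrt{0.85 \cdot 0.15 / 10000} \approx 0.0036$, so the empirical accuracy lands very near 0.85.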

In [None]:
Actions = ["listen", "open-left", "open-right"]
States = ["tiger-left", "tiger-right"]
Observations = ["tiger-left", "tiger-right"]
Inital_Belief_State = {"tiger-left": 0.5, "tiger-right": 0.5}
Terminate_Actions = ["open-left", "open-right"]

# Observation Model: {(state-prime, action): {observation1: probability1, observation2: probability2}}
Observation_Model = {("tiger-left","listen"): {"tiger-left": 0.85,
                                               "tiger-right": 0.15},
                     ("tiger-right","listen"): {"tiger-left": 0.15,
                                                "tiger-right":0.85},
                     ("tiger-left","open-left"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-right","open-left"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-left","open-right"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-right","open-right"): {"tiger-right":0.5,
                                                   "tiger-left":0.5}}

# Transition Model: {(state, action): {state-prime1: probability1, state-prime2: probability2}}
Transition_Model = {("tiger-left", "listen"): {"tiger-left": 1.0, "tiger-right": 0.0},
                    ("tiger-right", "listen"): {"tiger-left": 0.0, "tiger-right": 1.0},
                    ("tiger-left", "open-left"): {"tiger-right": 0.5, "tiger-left": 0.5},
                    ("tiger-right", "open-left"): {"tiger-right": 0.5, "tiger-left": 0.5},
                    ("tiger-left", "open-right"): {"tiger-right": 0.5, "tiger-left": 0.5},
                    ("tiger-right", "open-right"): {"tiger-right": 0.5, "tiger-left": 0.5}}

In [None]:
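def gen_obs(mode='random', n_episodes=15):
    """Generate synthetic episode logs: [b0, (a0, o0), (a1, o1), ..., (an, on)].

    mode='random': listen w.p. 0.5, open-left / open-right w.p. 0.25 each, until a door is opened.
    mode='simple': listen once, then open a uniformly chosen door.
    """
    episodes_obs = []
    for _ in range(n_episodes):  # create n_episodes episodes (15 by default)
        episode_log = [np.array([0.5, 0.5])]  # initial belief b0 (uniform)
        action = None
        tiger_state = np.random.choice(States)  # hidden true tiger location
        # "listen" reports the true location with probability 0.85 (see Observation_Model)
        if tiger_state == "tiger-left":
            obs_prob = [0.85, 0.15]
        else:
            obs_prob = [0.15, 0.85]
        action_cnt = 0
        while action not in Terminate_Actions:  # an episode ends at the first "open" action
            if mode == 'random':
                action = np.random.choice(Actions, p=[0.5, 0.25, 0.25])
            elif mode == 'simple':
                if action_cnt == 0:
                    action = 'listen'
                else:
                    action = np.random.choice([a for a in Actions if a.startswith("open")])
                action_cnt += 1
            else:
                raise ValueError(f"unknown mode: {mode!r}")
            if action == "listen":
                obs = np.random.choice(Observations, p=obs_prob)
            else:
                # Note: opening a door here reveals the tiger's true location,
                # unlike the uniform "open" entries of Observation_Model.
                obs = np.random.choice(Observations, p=[1 if state == tiger_state else 0 for state in States])
            episode_log.append((str(action), str(obs)))  # plain str instead of numpy.str_
        episodes_obs.append(episode_log)
    return episodes_obs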

In [None]:
observations = gen_obs()
for i,obs in enumerate(observations):
    print(f"obs[{i}] = {obs}")

obs[0] = [({'tiger-left': 0.5, 'tiger-right': 0.5}, 'listen', 'tiger-right', {'tiger-left': 0.15, 'tiger-right': 0.85}), ({'tiger-left': 0.15, 'tiger-right': 0.85}, 'open-left', 'tiger-right', {'tiger-left': 0.5, 'tiger-right': 0.5})]
obs[1] = [({'tiger-left': 0.5, 'tiger-right': 0.5}, 'open-right', 'tiger-right', {'tiger-left': 0.5, 'tiger-right': 0.5})]
obs[2] = [({'tiger-left': 0.5, 'tiger-right': 0.5}, 'listen', 'tiger-left', {'tiger-left': 0.85, 'tiger-right': 0.15}), ({'tiger-left': 0.85, 'tiger-right': 0.15}, 'listen', 'tiger-left', {'tiger-left': 0.9697986577181208, 'tiger-right': 0.0302013422818792}), ({'tiger-left': 0.9697986577181208, 'tiger-right': 0.0302013422818792}, 'open-right', 'tiger-left', {'tiger-left': 0.5, 'tiger-right': 0.5})]
obs[3] = [({'tiger-left': 0.5, 'tiger-right': 0.5}, 'listen', 'tiger-right', {'tiger-left': 0.15, 'tiger-right': 0.85}), ({'tiger-left': 0.15, 'tiger-right': 0.85}, 'listen', 'tiger-left', {'tiger-left': 0.5, 'tiger-right': 0.5}), ({'tiger-
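The printed logs above show each step as a tuple (belief, action, observation, next belief), whereas `gen_obs` as written stores only the initial belief followed by (action, observation) pairs. A sketch of how such tuples could be derived from a `gen_obs` episode with the Bayesian belief update; `annotate_episode` is a hypothetical helper, and a compact rebuild of the notebook's two models (as `T_model` and `O_model`) is included so the block runs on its own.

```python
States = ["tiger-left", "tiger-right"]
Observations = ["tiger-left", "tiger-right"]

# Rebuild of the notebook's models: "listen" keeps the state and is 85% accurate;
# opening a door resets the tiger uniformly and gives an uninformative observation.
T_model, O_model = {}, {}
for s in States:
    T_model[(s, "listen")] = {sp: 1.0 if sp == s else 0.0 for sp in States}
    O_model[(s, "listen")] = {o: 0.85 if o == s else 0.15 for o in Observations}
    for a in ["open-left", "open-right"]:
        T_model[(s, a)] = {sp: 0.5 for sp in States}
        O_model[(s, a)] = {o: 0.5 for o in Observations}

def annotate_episode(episode):
    """Turn [b0, (a0, o0), (a1, o1), ...] into [(b, a, o, b'), ...] using Bayes' rule."""
    belief = dict(zip(States, episode[0]))
    steps = []
    for action, obs in episode[1:]:
        # b'(s') is proportional to O(s', a, o) * sum_s T(s, a, s') * b(s)
        unnorm = {sp: O_model[(sp, action)][obs]
                      * sum(T_model[(s, action)][sp] * belief[s] for s in States)
                  for sp in States}
        total = sum(unnorm.values())
        next_belief = {sp: p / total for sp, p in unnorm.items()}
        steps.append((belief, action, obs, next_belief))
        belief = next_belief
    return steps

# Hypothetical episode with the same action/observation sequence as obs[2]
episode = [[0.5, 0.5], ("listen", "tiger-left"), ("listen", "tiger-left"), ("open-right", "tiger-left")]
for step in annotate_episode(episode):
    print(step)
```

For this sequence the belief in tiger-left goes from 0.5 to 0.85, then to about 0.9698, and returns to 0.5 after the door is opened, which is the pattern visible in `obs[2]` above.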

In [None]:
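from jax.scipy.special import logsumexp

def tiger_model(obs=None):
    """
    Args:
    obs: observations - logs of episodes, as produced by gen_obs().
    obs is a list of lists (list of episode logs), [[...], [...], [...]],
    each list in the obs list contains: [belief-state_0, (action_0, observation_0),
                                                         (action_1, observation_1),
                                         ...,
                                                         (action_n, observation_n)]
    """
    # Prior: for every (s', a), a Dirichlet over Observations, so O(s', a, .) is a valid
    # probability vector, i.e. P(o | a, s') indexed in the order of Observations.
    p_observations = {}
    for state_prime, action in product(States, Actions):
        p_observations[(state_prime, action)] = numpyro.sample(
            f"O({state_prime},{action})", dist.Dirichlet(jnp.ones(len(Observations))))

    if obs is not None:
        for i, episode in enumerate(obs):
            # The tiger stays put within an episode ("listen" keeps the state and the episode
            # ends at the first "open"), so accumulate log b_0(s) + sum_t log O(s, a_t, o_t).
            log_lik = jnp.log(jnp.asarray(episode[0]))
            for action, observation in episode[1:]:
                o_idx = Observations.index(observation)
                log_lik = log_lik + jnp.log(jnp.stack([p_observations[(s, action)][o_idx] for s in States]))
            # Marginalize the hidden tiger location and add the episode log-likelihood.
            # The likelihood is symmetric under swapping the two state labels,
            # so chains may settle in mirrored modes.
            numpyro.factor(f"episode_{i}", logsumexp(log_lik))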
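The quantity the model needs for each episode is the marginal likelihood of the observed sequence, summed over the unknown tiger location: $P(o_{0:n} \mid a_{0:n}) = \sum_{s} b_0(s) \prod_{t} O(s, a_t, o_t)$. This form assumes the tiger stays put for the whole episode, which holds here because "listen" does not change the state and every episode ends at its first "open" action. A numpy sketch with fixed probabilities; `episode_log_likelihood` is a hypothetical helper, and the array layout `O[action][s, o]` is an assumption rather than the notebook's dictionary format.

```python
import numpy as np

Observations = ["tiger-left", "tiger-right"]

def episode_log_likelihood(episode, O):
    """log P(observations | actions) for one episode [b0, (a0, o0), ...].

    O maps each action name to an array with O[action][s, o] = P(o | s, action).
    """
    log_terms = np.log(np.asarray(episode[0], dtype=float))  # log b0(s)
    for action, obs in episode[1:]:
        log_terms = log_terms + np.log(O[action][:, Observations.index(obs)])
    # Numerically stable log-sum-exp over the hidden tiger location
    m = log_terms.max()
    return m + np.log(np.exp(log_terms - m).sum())

O = {"listen": np.array([[0.85, 0.15], [0.15, 0.85]]),
     "open-left": np.full((2, 2), 0.5),
     "open-right": np.full((2, 2), 0.5)}
episode = [[0.5, 0.5], ("listen", "tiger-left"), ("listen", "tiger-left"), ("open-right", "tiger-left")]
print(np.exp(episode_log_likelihood(episode, O)))
```

Here the likelihood is $0.5 \cdot 0.85^2 \cdot 0.5 + 0.5 \cdot 0.15^2 \cdot 0.5 = 0.18625$. Inside a NumPyro model, a term like this can be added to the joint log density with `numpyro.factor`, with the observation probabilities drawn from priors instead of being fixed.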


In [None]:
numpyro.render_model(tiger_model, model_args=(observations,), render_distributions=True, render_params=True,)

In [None]:
def inference(ai_model, obs):
    nuts_kernel = numpyro.infer.NUTS(ai_model)
    mcmc = numpyro.infer.MCMC(
        nuts_kernel,
        num_warmup=500,
        num_chains=4,
        num_samples=5000)
    mcmc.run(jax.random.PRNGKey(int(time.time() * 1E6)), obs=obs)
    mcmc.print_summary()
    return mcmc

In [None]:
inference(tiger_model, observations)