<a href="https://colab.research.google.com/github/ibenatar-96/tiger-pomdp-mplr/blob/main/tiger_pomdp_mplr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install numpyro

Collecting numpyro
  Downloading numpyro-0.14.0-py3-none-any.whl (330 kB)
Installing collected packages: numpyro
Successfully installed numpyro-0.14.0


In [None]:
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import sys
import os
import time
import copy
import random
from itertools import product

numpyro.set_host_device_count(4)

# Tiger POMDP (Partially Observable Markov Decision Process)

The Tiger POMDP is a classical problem in the field of artificial intelligence and decision-making under uncertainty. It's used to illustrate the challenges of decision-making when there's uncertainty about the state of the environment.

In the Tiger POMDP scenario, an agent is placed in a room with two doors. Behind one door is the agent's freedom, and behind the other is a tiger. The agent doesn't know which door leads to which outcome. It can "listen" to hear a sound indicating the location of the tiger, or "open" a door to reveal what is behind it. However, listening is imperfect: the sound sometimes points to the wrong door, so the agent never observes the tiger's location directly.

Solving the Tiger POMDP involves finding a policy that maximizes the expected cumulative reward over time, taking into account the uncertainty and partial observability. Various algorithms, such as belief state planning or particle filtering, can be used to approximate or solve POMDPs.

Formally, the Tiger problem is a POMDP defined by the tuple $(S, A, T, R, Ω, O, γ)$, where:
$S$: The finite set of states has two elements, "tiger-left" and "tiger-right", giving the location of the tiger (freedom is behind the other door).

$A$: The finite set of actions available to the agent: "listen", which gathers information, and "open-left" / "open-right", which commit to a decision.

$T$: The state transition function $T: S \times A \times S \mapsto [0,1]$ gives the probability $T(s,a,s')$ of moving from state $s$ to state $s'$ under action $a$. For example, "listen" leaves the tiger where it is, while opening a door resets the problem and places the tiger behind either door with equal probability.

$R$: The reward function $R: S \times A \mapsto \mathbb{R}$ provides immediate rewards for actions in specific states. For instance, opening the door that leads to freedom yields a positive reward, while opening the door with the tiger results in a negative reward.

$Ω$: The finite set of observations consists of two elements, "tiger-left" and "tiger-right", which the agent hears when it chooses to "listen" (after an "open" action the observation carries no information).

$O$: The observation function $O: S \times A \times Ω \mapsto [0,1]$ gives the conditional probability $O(s',a,o) = P(o \mid s', a)$ of each observation given the state reached and the action taken.

$γ$: The discount factor $γ \in [0,1]$ accounts for the importance of future rewards relative to immediate rewards in the agent's decision-making process.
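To make the tuple concrete, here is a minimal sketch that stores each component as a NumPy array. The container `TigerPOMDP`, the helper `make_tiger_pomdp`, the index encoding (0 = tiger-left, 1 = tiger-right; actions in the order listen, open-left, open-right) and the default `gamma=0.95` are assumptions for illustration. The reward values are the classic ones from the original Tiger problem (Kaelbling, Littman and Cassandra) and are not used anywhere in this notebook.

```python
from typing import NamedTuple

import numpy as np


class TigerPOMDP(NamedTuple):
    """Hypothetical container for the tuple (S, A, T, R, Omega, O, gamma)."""
    states: tuple          # S
    actions: tuple         # A
    T: np.ndarray          # T[a, s, s'] = P(s' | s, a)
    R: np.ndarray          # R[s, a]
    observations: tuple    # Omega
    O: np.ndarray          # O[a, s', o] = P(o | s', a)
    gamma: float           # discount factor


def make_tiger_pomdp(accuracy=0.85, gamma=0.95):
    states = ("tiger-left", "tiger-right")
    actions = ("listen", "open-left", "open-right")
    observations = ("tiger-left", "tiger-right")

    T = np.empty((3, 2, 2))
    T[0] = np.eye(2)       # listen: the tiger stays where it is
    T[1:] = 0.5            # open-left / open-right: the problem resets uniformly

    O = np.empty((3, 2, 2))
    O[0] = [[accuracy, 1 - accuracy],    # s' = tiger-left
            [1 - accuracy, accuracy]]    # s' = tiger-right
    O[1:] = 0.5            # opening a door yields an uninformative observation

    # Assumed classic rewards: listen = -1, open the tiger's door = -100, open the other door = +10.
    R = np.array([[-1.0, -100.0, 10.0],   # tiger-left
                  [-1.0, 10.0, -100.0]])  # tiger-right
    return TigerPOMDP(states, actions, T, R, observations, O, gamma)


pomdp = make_tiger_pomdp()
print(pomdp.T.shape, pomdp.O.shape, pomdp.R.shape)
```

Indexing every table by action first makes it easy to pick out the matrix for the action just taken, for example `pomdp.T[0]` for "listen".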

## Generating "Synthetic" Observations.

Now let's create synthetic observations / data for our model to learn from.

This synthetic data follows the rules of the original Tiger POMDP, in which a "listen" action returns the correct observation with probability 0.85.

Our observations will be a list of episodes, where each episode is a sequence of steps:

$(b_{0},a_{0},o_{0},b_{1}),(b_{1},a_{1},o_{1},b_{2}),...,(b_{n-1},a_{n-1},o_{n-1},b_{n})$

where each step $(b_{t},a_{t},o_{t},b_{t+1})$ consists of:

$b_{t}$ - the belief state at time step $t$

$a_{t}$ - the action taken at time step $t$

$o_{t}$ - the observation received after taking $a_{t}$

$b_{t+1}$ - the updated belief state at time step $t+1$
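Because consecutive steps share a belief ($b_{t+1}$ ends one step and starts the next), a logged episode can be sanity-checked mechanically. The sketch below assumes beliefs are stored as dictionaries keyed by state name, as in the notebook's generated data; the aliases `Belief`, `Step` and the function `is_consistent_episode` are hypothetical names introduced here.

```python
from typing import Dict, List, Tuple

Belief = Dict[str, float]
Step = Tuple[Belief, str, str, Belief]   # (b_t, a_t, o_t, b_{t+1})


def is_consistent_episode(episode: List[Step], tol: float = 1e-9) -> bool:
    """Check that every belief sums to 1 and that each step starts from the
    belief the previous step ended in."""
    for t, (b, _action, _obs, b_next) in enumerate(episode):
        for belief in (b, b_next):
            if abs(sum(belief.values()) - 1.0) > tol:
                return False
        if t > 0:
            prev_next = episode[t - 1][3]
            if any(abs(prev_next[s] - b[s]) > tol for s in b):
                return False
    return True


episode = [
    ({"tiger-left": 0.5, "tiger-right": 0.5}, "listen", "tiger-left",
     {"tiger-left": 0.85, "tiger-right": 0.15}),
    ({"tiger-left": 0.85, "tiger-right": 0.15}, "open-right", "tiger-left",
     {"tiger-left": 0.5, "tiger-right": 0.5}),
]
print(is_consistent_episode(episode))  # True
```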

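We will use the Bayes filter update $b^{a}_{o}(s')=P(s' \mid o,a,b)=\frac{O(s',a,o)\sum_{s\in S}T(s,a,s')\,b(s)}{P(o \mid a,b)}$ to update our belief state, where the normalizer is $P(o \mid a,b)=\sum_{s'\in S}O(s',a,o)\sum_{s\in S}T(s,a,s')\,b(s)$.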
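The update rule translates directly into array code. This is a minimal NumPy sketch, assuming states and observations are encoded as 0 = tiger-left and 1 = tiger-right, and that the transition matrix is indexed as `T[s, s']`; the function name `belief_update` is hypothetical. Listening twice and hearing the tiger on the left both times reproduces the beliefs that appear in the generated episodes.

```python
import numpy as np

# T[s, s'] and O[s', o] for a = listen (0 = tiger-left, 1 = tiger-right).
T_LISTEN = np.eye(2)
O_LISTEN = np.array([[0.85, 0.15],
                     [0.15, 0.85]])


def belief_update(b, T, O, o):
    """b^a_o(s') = O(s', a, o) * sum_s T(s, a, s') b(s) / P(o | a, b)."""
    predicted = b @ T                   # sum_s T(s, a, s') * b(s), for every s'
    unnormalized = O[:, o] * predicted  # weight each s' by O(s', a, o)
    return unnormalized / unnormalized.sum()  # the sum equals P(o | a, b)


b0 = np.array([0.5, 0.5])
b1 = belief_update(b0, T_LISTEN, O_LISTEN, o=0)  # heard the tiger on the left
b2 = belief_update(b1, T_LISTEN, O_LISTEN, o=0)  # heard it on the left again
print(b1)  # ≈ [0.85, 0.15]
print(b2)  # ≈ [0.9698, 0.0302]
```

The second update is $\frac{0.85 \cdot 0.85}{0.85 \cdot 0.85 + 0.15 \cdot 0.15} = \frac{0.7225}{0.745} \approx 0.9698$, so two agreeing observations already make the agent quite confident.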

$\Omega = \{o_{tl}, o_{tr}\} $

where the observation model for the "listen" action is

$$O(s',a_{listen},o) = P(o \mid s',a_{listen}) =
\begin{cases}
0.85 & \text{if } o \text{ matches } s' \\
0.15 & \text{otherwise}
\end{cases}$$

$O(s'_{tl},a_{listen},o_{tl}) = 0.85$

$O(s'_{tl},a_{listen},o_{tr}) = 0.15$

$O(s'_{tr},a_{listen},o_{tl}) = 0.15$

$O(s'_{tr},a_{listen},o_{tr}) = 0.85$


Here $s'$ is the state that the action leads to. For example, $a_{listen}$ keeps the tiger where it is, so if the tiger is behind the left door, listening results in $s'$ = "tiger-left"; the agent then observes $o_{tl}$ with probability 0.85 and $o_{tr}$ with probability 0.15.

* When the action $a$ is "open-left" or "open-right", the observations $o_{tl}$ and $o_{tr}$ are equally likely (probability 0.5 each).

Explanation - The correct state is observed with probability 0.85. For example, when the state is $s_{tl}$ (the tiger is behind the left door), the agent receives the correct observation $o_{tl}$ with probability 0.85 and the incorrect observation $o_{tr}$ with probability 0.15.
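A quick way to convince yourself that sampling from this table behaves as described is to draw many "listen" observations and measure how often they are correct. The sketch assumes the 0 = tiger-left, 1 = tiger-right encoding and uses NumPy's `Generator.choice`; `sample_listen_observations` is a hypothetical helper, and the seed is arbitrary.

```python
import numpy as np

# O[s', o] for a = listen (0 = tiger-left, 1 = tiger-right).
O_LISTEN = np.array([[0.85, 0.15],
                     [0.15, 0.85]])


def sample_listen_observations(true_state, n, rng):
    """Draw n observations o ~ O(s', listen, .) with s' = true_state.

    Listening does not move the tiger, so s' is simply the true tiger location."""
    return rng.choice(2, size=n, p=O_LISTEN[true_state])


rng = np.random.default_rng(0)
obs_left = sample_listen_observations(0, 100_000, rng)   # tiger behind the left door
obs_right = sample_listen_observations(1, 100_000, rng)  # tiger behind the right door
print(np.mean(obs_left == 0), np.mean(obs_right == 1))   # both close to 0.85
```

With 100,000 draws the standard error of each frequency is about 0.001, so both printed values should sit very near 0.85.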

$\text{Transition Model T(s,a,s')}$

$$T(s,a,s') =
\begin{cases}
1 & \text{if } s' = s \text{ and } a = \text{listen} \\
0 & \text{if } s' \neq s \text{ and } a = \text{listen} \\
0.5 & \text{if } a \in \{\text{open-left}, \text{open-right}\}
\end{cases}$$

$T(s_{tl},a_{listen},s'_{tl}) = T(s_{tr},a_{listen},s'_{tr}) = 1$

$T(s_{tl},a_{listen},s'_{tr}) = T(s_{tr},a_{listen},s'_{tl}) = 0$

$T(s_{tl},a_{open-left},s'_{tl}) = T(s_{tl},a_{open-left},s'_{tr}) = 0.5$

$T(s_{tr},a_{open-left},s'_{tl}) = T(s_{tr},a_{open-left},s'_{tr}) = 0.5$

$T(s_{tl},a_{open-right},s'_{tl}) = T(s_{tl},a_{open-right},s'_{tr}) = 0.5$

$T(s_{tr},a_{open-right},s'_{tl}) = T(s_{tr},a_{open-right},s'_{tr}) = 0.5$


Explanation - A "listen" action leaves the state unchanged, while an "open" action (open-left or open-right) resets the world: the next state is $s_{tl}$ or $s_{tr}$ with equal probability.
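The prediction step $\sum_{s} T(s,a,s')\,b(s)$ shows the effect of this transition model on the belief. In this sketch (0 = tiger-left, 1 = tiger-right, matrices indexed `T[s, s']`; the dictionary `T` and function `predict` are hypothetical names), listening preserves whatever the agent believed, while opening a door wipes it out.

```python
import numpy as np

# T[s, s'] for each action (0 = tiger-left, 1 = tiger-right), following the table above.
T = {
    "listen": np.eye(2),
    "open-left": np.full((2, 2), 0.5),
    "open-right": np.full((2, 2), 0.5),
}


def predict(b, action):
    """Prediction step: b'(s') = sum_s T(s, a, s') * b(s)."""
    return b @ T[action]


b = np.array([0.97, 0.03])        # the agent is fairly sure the tiger is on the left
print(predict(b, "listen"))       # unchanged: ≈ [0.97, 0.03]
print(predict(b, "open-right"))   # reset: ≈ [0.5, 0.5]
```

Because the observation after an "open" action is also uniform, the normalized belief stays at $[0.5, 0.5]$, which is what the generated episodes show after every "open" action.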

In [None]:
Actions = ["listen", "open-left", "open-right"]
States = ["tiger-left", "tiger-right"]
Observations = ["tiger-left", "tiger-right"]
Inital_Belief_State = {"tiger-left": 0.5, "tiger-right": 0.5}
Terminate_Actions = ["open-left", "open-right"]

# Observation Model: {(state-prime, action): {observation1: probability1, observation2: probability2}}
Observation_Model = {("tiger-left","listen"): {"tiger-left": 0.85,
                                               "tiger-right": 0.15},
                     ("tiger-right","listen"): {"tiger-left": 0.15,
                                                "tiger-right":0.85},
                     ("tiger-left","open-left"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-right","open-left"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-left","open-right"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-right","open-right"): {"tiger-right":0.5,
                                                   "tiger-left":0.5}}

# Transition Model: {(state, action): {state-prime1: probability1, state-prime2: probability2}}
Transition_Model = {("tiger-left","listen"): {"tiger-left": 1.0,
                                               "tiger-right": 0.0},
                     ("tiger-right","listen"): {"tiger-left": 0.0,
                                                "tiger-right":1.0},
                     ("tiger-left","open-left"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-right","open-left"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-left","open-right"): {"tiger-right":0.5,
                                                   "tiger-left":0.5},
                     ("tiger-right","open-right"): {"tiger-right":0.5,
                                                   "tiger-left":0.5}}

In [None]:
def calc_next_bs(observation, action, state, belief_state):
    # Bayes filter update: b'(s') ∝ O(s', a, o) * sum_s T(s, a, s') * b(s)
    # `state` (the true tiger location) is not needed for the update; it is kept so existing calls still work.
    b_s = {}
    for state_prime in States:
        obs_prob = Observation_Model[(state_prime, action)][observation]
        prob_unfactored = 0.0
        for s, prob in belief_state.items():  # `s` avoids shadowing the `state` argument
            prob_unfactored += Transition_Model[(s, action)][state_prime] * prob
        b_s[state_prime] = obs_prob * prob_unfactored
    norm_factor = sum(b_s.values())  # equals P(o | a, b)
    norm_b_s = {key: value / norm_factor for key, value in b_s.items()}  # Normalizing the Belief State
    return norm_b_s

def gen_obs():
    """Simulate 15 Tiger episodes and log (belief, action, observation, next belief) for every step."""
    episodes_obs = []
    for _ in range(15): # create 15 episodes
        episode_log = []
        action = None
        tiger_state = np.random.choice(States, size=None)  # hidden true state for this episode
        belief_state = Inital_Belief_State
        # Listening reports the true tiger location with probability 0.85
        if tiger_state == "tiger-left":
            obs_prob = [0.85, 0.15]
        else:
            obs_prob = [0.15, 0.85]
        state = tiger_state
        while action not in Terminate_Actions:  # an "open" action ends the episode
            # Random behaviour policy: listen with probability 0.5, open each door with probability 0.25
            action = np.random.choice(Actions, size=None, p=[0.5,0.25,0.25])
            if action == "listen":
                obs = np.random.choice(Observations, size=None, p=obs_prob)
            else:
                obs = np.random.choice(Observations, size=None)  # uninformative observation after opening
            # next_state = np.random.choice(list(Transition_Model[(state,action)].keys()), size=None, p=list(Transition_Model[(state,action)].values()))
            next_belief_state = calc_next_bs(obs, action, state, belief_state)
            episode_log.append((belief_state, action, obs, next_belief_state))
            belief_state = next_belief_state
        episodes_obs.append(episode_log)
    return episodes_obs

In [None]:
observations = gen_obs()
for i,obs in enumerate(observations):
    print(f"obs[{i}] = {obs}")

obs[0] = [({'tiger-left': 0.5, 'tiger-right': 0.5}, 'listen', 'tiger-right', {'tiger-left': 0.15, 'tiger-right': 0.85}), ({'tiger-left': 0.15, 'tiger-right': 0.85}, 'open-left', 'tiger-right', {'tiger-left': 0.5, 'tiger-right': 0.5})]
obs[1] = [({'tiger-left': 0.5, 'tiger-right': 0.5}, 'open-right', 'tiger-right', {'tiger-left': 0.5, 'tiger-right': 0.5})]
obs[2] = [({'tiger-left': 0.5, 'tiger-right': 0.5}, 'listen', 'tiger-left', {'tiger-left': 0.85, 'tiger-right': 0.15}), ({'tiger-left': 0.85, 'tiger-right': 0.15}, 'listen', 'tiger-left', {'tiger-left': 0.9697986577181208, 'tiger-right': 0.0302013422818792}), ({'tiger-left': 0.9697986577181208, 'tiger-right': 0.0302013422818792}, 'open-right', 'tiger-left', {'tiger-left': 0.5, 'tiger-right': 0.5})]
obs[3] = [({'tiger-left': 0.5, 'tiger-right': 0.5}, 'listen', 'tiger-right', {'tiger-left': 0.15, 'tiger-right': 0.85}), ({'tiger-left': 0.15, 'tiger-right': 0.85}, 'listen', 'tiger-left', {'tiger-left': 0.5, 'tiger-right': 0.5}), ({'tiger-

In [None]:
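def calc_belief_state_general(noise, obs_prob, transition_prob, belief_state):
    # Assumed layouts: obs_prob[s'] = P(o | s', a) and transition_prob[s][s'] = T(s, a, s').
    # Indexing by state fixes the original version, which multiplied every s' by the same
    # scalar and therefore always returned a uniform belief.
    # `noise` is not used; it is kept so existing calls keep working.
    b_s = {}
    for state_prime in States:
        prob_unfactored = 0.0
        for state, prob in belief_state.items():
            prob_unfactored += transition_prob[state][state_prime] * prob
        b_s[state_prime] = obs_prob[state_prime] * prob_unfactored
    norm_factor = sum(b_s.values())
    norm_b_s = {key: value / norm_factor for key, value in b_s.items()} # Normalizing the Belief State
    return norm_b_s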

In [None]:
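def calc_belief_state_transition(action, p_observation, transition_prob, belief_state, belief_state_prime):
    # transition_prob[s, s'] = T(s, a, s'), so the prediction sum_s T(s, a, s') * b(s) is b @ T
    # (jnp.dot(transition_prob, belief_state) would sum over s' instead of s).
    bs_prime = p_observation * jnp.dot(belief_state, transition_prob)
    normalized_bs_prime = bs_prime / jnp.sum(bs_prime)
    # True if the predicted belief is within 0.1 of the logged belief in every state
    accuracy = jnp.all(jnp.abs(normalized_bs_prime - belief_state_prime) <= 0.1) & (jnp.shape(normalized_bs_prime)[0] == jnp.shape(belief_state_prime)[0])
    return normalized_bs_prime, accuracy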

In [None]:
def tiger_model_transition(obs=None):
    """
    Args:
    obs: observations - logs of episodes.
    obs is a list of lists (list of episode logs), [[...], [...], [...]],
    each list in the obs list contains: [(belief-state_0, action_0, observation_0, belief-state_1),
                                         (belief-state_1, action_1, observation_1, belief-state_2),
                                         ...,
                                         (belief-state_n-1, action_n-1, observation_n-1, belief-state_n)]
    """
    p_transitions = {}
    p_observations = {}
    noise = numpyro.sample("noise", dist.Beta(1, 1))
    for state, action in product(States, Actions):
        p_transitions[(state, action)] = {}
        for state_prime in States:
            # p_transitions = {('tiger-left','listen'): {'tiger-left': sample1, 'tiger-right': sample2}, ...}
            p_transitions[(state, action)][state_prime] = numpyro.sample(
                f"T({state},{action},{state_prime})", dist.Beta(1, 1))
    if obs is not None:
        n_obs = sum(len(o) for o in obs)
        # transitions_arr[i][s][s'] = sampled T(s, a_i, s') for the i-th logged step
        transitions_arr = jnp.array([[list(p_transitions[state, action].values()) for state in bs] for o in obs for bs, action, _, _ in o])
        # observations_arr[i][s'] = O(s', a_i, o_i). The action must be unpacked here too; otherwise the
        # comprehension silently reuses the last `action` from the loop above ("open-right").
        observations_arr = jnp.array([[Observation_Model[(sp, action)][observation] for sp in belief_state_prime] for o in obs for _, action, observation, belief_state_prime in o])
        actions = [action for o in obs for _, action, _, _ in o]
        observations = [observation for o in obs for _, _, observation, _ in o]
        belief_states = jnp.array([list(bs.values()) for o in obs for bs, _, _, _ in o])
        belief_states_prime = jnp.array([list(belief_state_prime.values()) for o in obs for _, _, _, belief_state_prime in o])

        calculated_bs_primes, obs_success = zip(*[
            calc_belief_state_transition(action, p_observation, p_transition, belief_state, belief_state_prime)
            for action, p_observation, p_transition, belief_state, belief_state_prime in zip(
                actions, observations_arr, transitions_arr, belief_states, belief_states_prime)])

        calculated_bs_primes = jnp.array(calculated_bs_primes)
        obs_success = jnp.array(obs_success)

        # with numpyro.plate("obs", size=n_obs):
        #     belief_state_prime = numpyro.sample("belief_state_prime", dist.Dirichlet(concentration=jnp.ones((len(States),len(States))) * transitions_arr))
        #     obs_bs = jnp.where(jnp.all(jnp.abs(calculated_bs_primes - belief_state_prime) <= 0.1), 1, 0)
        #     d = observations_arr * numpyro.sample("d", dist.Dirichlet(concentration=jnp.ones((2,2)) * transitions_arr)) * belief_states
        #     weighted = observations_arr * transitions_arr * belief_states
        #     numpyro.sample("o", dist.Dirichlet(weighted), obs=belief_states_prime)

        # Likelihood. The original per-step loop was unfinished and relied on discrete Bernoulli
        # latents, which NUTS cannot sample without enumeration. As an assumption, each logged next
        # belief is modelled as a noisy measurement of the belief predicted from the sampled
        # transition probabilities, with `noise` as the measurement scale.
        with numpyro.plate("obs", n_obs):
            numpyro.sample("belief_state_prime",
                           dist.Normal(calculated_bs_primes, noise).to_event(1),
                           obs=belief_states_prime)

In [None]:
numpyro.render_model(tiger_model_transition, model_args=(observations,), render_distributions=True, render_params=True,)

In [None]:
def inference(ai_model, obs):
    nuts_kernel = numpyro.infer.NUTS(ai_model)
    mcmc = numpyro.infer.MCMC(
        nuts_kernel,
        num_warmup=500,
        num_chains=4,
        num_samples=5000)
    mcmc.run(jax.random.PRNGKey(int(time.time() * 1E6)), obs=obs)
    mcmc.print_summary()
    return mcmc

In [None]:
inference(tiger_model_transition, observations)