In [1]:
!pip3 install rl_util



In [15]:
import jax
import jax.numpy as jnp
import jax.random as random
from rl_util.environment import MarkovEnv
from rl_util.generator import simple_circle
from rl_util.test import test_policy
from rl_util.policy import StochasticPolicy, DeterministicPolicy

## Policy evaluation

In [3]:
def policy_evaluation(policy, markov, theta: float, phi: float):
    """Iteratively evaluate a deterministic policy.

    Repeatedly sweeps over every state and overwrites its value with the
    expected one-step return under ``policy``. Values are updated in place, so
    states visited later in a sweep already see the earlier updates. Sweeping
    stops once the largest change during a sweep is no greater than ``theta``.

    Args:
        policy: callable mapping a state to the single action taken there.
        markov: model exposing ``state_space()``, ``states()``,
            ``next_states(s)``, ``rewards(s_dot)`` and ``p(s, a, r, s_dot)``.
        theta: convergence threshold on the per-sweep maximum value change.
        phi: discount factor applied to the successor state's value.

    Returns:
        A ``jnp`` array of state values, indexed by state.
    """
    values = jnp.zeros(markov.state_space())
    largest_change = float('inf')
    while largest_change > theta:
        largest_change = 0
        for state in markov.states():
            previous = values[state]
            backup = 0.
            action = policy(state)
            # Bellman expectation backup for the action chosen by the policy.
            for nxt in markov.next_states(state):
                for reward in markov.rewards(nxt):
                    backup += markov.p(state, action, reward, nxt) * (reward + phi * values[nxt])
            values = values.at[state].set(backup)
            largest_change = max(largest_change, abs(backup - previous))
    return values

## Policy improvement

In [8]:
def policy_improvement(markov, v, phi: float, policy=None):
    """Make a deterministic policy greedy with respect to the state values ``v``.

    For each state, the expected one-step return of every available action is
    computed. The policy switches to the best action only when that action is
    *strictly* better than the action the policy already takes. On ties the
    current action is kept. Without this rule, policy iteration can switch
    forever between equally good policies and never report stability.

    Args:
        markov: model exposing ``states()``, ``actions(s)``, ``next_states(s)``,
            ``rewards(s_dot)`` and ``p(s, a, r, s_dot)`` (plus ``state_space()``
            and ``action_space()`` when ``policy`` is None).
        v: state values, indexable by state.
        phi: discount factor applied to successor-state values.
        policy: policy to improve. It is called as ``policy(s)`` and mutated via
            ``policy.update(s, a)``. A fresh ``DeterministicPolicy`` is created
            when None.

    Returns:
        ``(policy, policy_stable)``, where ``policy_stable`` is True when no
        state's action was changed.
    """
    if policy is None:
        policy = DeterministicPolicy(state_space=markov.state_space(), action_space=markov.action_space())

    def action_value(s, a):
        # Expected one-step return of taking `a` in `s` and then valuing the
        # successor with `v`: sum of markov.p(s, a, r, s') * (r + phi * v[s']).
        total = 0
        for s_dot in markov.next_states(s):
            for r in markov.rewards(s_dot):
                total += markov.p(s, a, r, s_dot) * (r + phi * v[s_dot])
        return total

    policy_stable = True
    for s in markov.states():
        old_action = policy(s)
        actions = markov.actions(s)
        best_action, best_value = actions[0], float('-inf')
        old_value = None  # stays None if the current action is not offered in `s`
        for a in actions:
            value = action_value(s, a)
            if old_value is None and a == old_action:
                old_value = value
            if value > best_value:
                best_action, best_value = a, value
        # Keep the current action unless some other action is strictly better;
        # this tie-breaking is what guarantees policy iteration terminates.
        if old_value is not None and old_value >= best_value:
            continue
        policy.update(s, best_action)
        policy_stable = False
    return policy, policy_stable

## Policy iteration

In [9]:
def policy_iteration(markov, theta: float, phi: float):
    """Alternate policy evaluation and greedy improvement until the policy is stable.

    Starts from a fresh ``DeterministicPolicy`` sized from ``markov``.

    Args:
        markov: model passed through to the evaluation and improvement steps.
        theta: convergence threshold forwarded to ``policy_evaluation``.
        phi: discount factor forwarded to both steps.

    Returns:
        ``(policy, v)``: the final policy and the values from its last evaluation.
    """
    policy = DeterministicPolicy(state_space=markov.state_space(), action_space=markov.action_space())
    stable = False
    while not stable:
        v = policy_evaluation(policy, markov, theta, phi)
        policy, stable = policy_improvement(markov, v, phi, policy)
    return policy, v

In [10]:
# Convergence threshold: policy evaluation stops once a full sweep changes no
# state value by more than this. NOTE(review): 0.9 is coarse compared with the
# 1e-3 used for value iteration further down; consider tightening it.
theta = 0.9
# Discount factor applied to successor-state values in every Bellman backup.
phi = 0.99
# Number of states and actions passed to `simple_circle`.
state_space, action_space = 4, 2

markov = simple_circle(state_space=state_space, action_space=action_space)

In [11]:
# Solve `markov` with policy iteration using the settings from the config cell.
policy, value = policy_iteration(markov, theta=theta, phi=phi)

In [13]:
# State values returned by policy_iteration for the final policy.
value

DeviceArray([-1.,  0.,  0.,  0.], dtype=float32)

In [17]:
# rl_util helper; presumably rolls `policy` out in `markov`. The saved output
# shows it prints the number of steps taken and the reward obtained.
test_policy(markov, policy)

Finished in 2 steps, reward: -4.0


## Value iteration

In [30]:
def value_iteration(markov, theta: float, phi: float = 0.99, return_policy: bool = False):
    """Compute optimal state values by value iteration.

    Each sweep replaces every state's value with the best expected one-step
    return over its actions. Values are updated in place, so states visited
    later in a sweep already see the earlier updates. Sweeping stops once the
    largest change during a sweep is no greater than ``theta``.

    Args:
        markov: model exposing ``state_space()``, ``states()``, ``actions(s)``,
            ``next_states(s)``, ``rewards(s_dot)`` and ``p(s, a, r, s_dot)``.
        theta: convergence threshold on the per-sweep maximum value change.
        phi: discount factor applied to successor-state values. Previously this
            was read from the notebook-global ``phi``; the default matches the
            notebook's configured 0.99, so existing calls behave the same.
        return_policy: when True, also return the greedy action chosen for each
            state during the final sweep.

    Returns:
        A ``jnp`` array of state values, or ``(values, greedy_actions)`` when
        ``return_policy`` is True, where ``greedy_actions`` maps state -> action.
    """
    v = jnp.zeros(markov.state_space())
    delta = float('inf')

    greedy_actions = {s: None for s in markov.states()}
    while delta > theta:
        delta = 0
        for s in markov.states():
            old_v = v[s]
            actions = markov.actions(s)
            max_a, max_value = actions[0], float('-inf')
            for a in actions:
                # Expected one-step return of action `a` in state `s`.
                cur_value = 0
                for s_dot in markov.next_states(s):
                    for r in markov.rewards(s_dot):
                        cur_value += markov.p(s, a, r, s_dot) * (r + phi * v[s_dot])
                if cur_value > max_value:
                    max_a, max_value = a, cur_value
            greedy_actions[s] = max_a
            v = v.at[s].set(max_value)
            delta = max(delta, abs(old_v - v[s]))
    if return_policy:
        return v, greedy_actions
    return v

In [32]:
# Value iteration with a much tighter threshold than the policy-iteration run;
# the saved output matches the policy-iteration values shown earlier.
value_iteration(markov, theta=1e-3)

DeviceArray([-1.,  0.,  0.,  0.], dtype=float32)