In [1]:
import jax
import jax.numpy as jnp
import jax.random as random

## Markov environment

In [2]:
class Markov:
    def __init__(self, transitions: dict, seed: int = 11):
        self.transitions = transitions
        self.state_space = len(self.transitions)
        self.transition_probs = {}
        self.rewards = {}
        
        key = random.PRNGKey(seed)
        for state in self.transitions.keys():
            self.rewards[state] = random.normal((key:=random.split(key)[0]), (1,))
            
            self.transition_probs[state] = {}
            for action in self.transitions[state]:
                next_states = self.transitions[state][action]
                self.transition_probs[state][action] = {}
                state_probs = random.randint((key:=random.split(key)[0]), (len(next_states),), 1, 2 * len(next_states))
                state_probs = state_probs / state_probs.sum()
                for (next_state, state_prob) in zip(next_states, state_probs):
                    self.transition_probs[state][action][next_state] = state_prob   
    
    def states(self):
        return list(self.transitions.keys())
    
    def state_space(self):
        return self.state_space
    
    def actions(self, state):
        return list(self.transition_probs[state].keys())
    
    def next_states(self, state):
        states = []
        for action in self.transition_probs[state].keys():
            states = states + list(self.transition_probs[state][action].keys())
        return list(set(states))
    
    def p(self, state, action, _, next_state):
        return self.transition_probs.get(state, {action: {next_state: 0}}).get(action, {next_state:0}).get(next_state, 0)
    
    def rewards(self, state):
        return self.rewards[state]

## Policy functions

In [3]:
class Policy:
    def p(self, a, s):
        raise NotImplementedError()

class StochasticPolicy(Policy):
    def __init__(self, state_actions: dict, seed: int = 11):
        key = random.PRNGKey(seed)
        self.probs = {}
    
        for state in state_actions:
            self.probs[state] = {}
            actions = state_actions[state]
        
            action_probs = random.randint((key:=random.split(key)[0]), (len(actions),), 1, 2 * len(actions) + 1)
            action_probs = action_probs / action_probs.sum()
        
            for (action, prob) in zip(actions, action_probs):
                self.probs[state][action] = prob
    
    def p(self, a, s):
        return self.probs[s][a]
    
    def __call__(self, s):
        probs = self.probs[state]
        max_a = probs.keys()[0]
        for a in probs.keys():
            if probs[a] > probs[max_a]:
                max_a = a
        return max_a
    
class DeterministicPolicy(Policy):
    def __init__(self, transitions: dict):
        self.transitions = transitions
        
    def p(self, a, s):
        return 1. if self.transitions[a] == s else 0
    
    def __call__(self, s):
        return self.transitions[s]

## Policy evaluation

In [4]:
def policy_evaluation(policy, markov, theta: float, phi: float):
    v = jnp.zeros(markov.state_space)
    delta = float('inf')
    while delta > theta:
        delta = 0
        for s in markov.states():
            old_v = v[s]
            v_s = 0.
            a = policy(s)
            for s_dot in markov.next_states(s):
                for r in markov.rewards[s_dot]:
                    v_s += markov.p(s, a, r, s_dot) * (r + phi * v[s_dot])
            v = v.at[s].set(v_s)
            delta = max(delta, abs(v_s - old_v))
    return v

## Policy improvement

In [5]:
def policy_improvement(markov, v, phi: float, policy=None):
    if policy is None:
        policy = DeterministicPolicy({state: markov.actions(state)[0] for state in markov.states()})
    
    policy_stable = True
    for s in markov.states():
        old_action = policy(s)
        max_a, max_value = markov.actions(s)[0], float('-inf')
        for a in markov.actions(s):
            cur_value = 0
            for s_dot in markov.next_states(s):
                    for r in markov.rewards[s_dot]:
                        cur_value += markov.p(s, a, r, s_dot) * (r + phi * v[s_dot])
            if cur_value > max_value:
                max_a = a
                max_value = cur_value
        if old_action == max_a:
            continue
        policy.transitions[s] = max_a
        policy_stable = False
    return policy, policy_stable

## Policy iteration

In [6]:
def policy_iteration(markov, theta: float, phi: float):
    policy = DeterministicPolicy({state: markov.actions(state)[0] for state in markov.states()})
    
    while True:
        v = policy_evaluation(policy, markov, theta, phi)
        policy, policy_stable = policy_improvement(markov, v, phi, policy)
        if policy_stable:
            return policy, v

In [7]:
# s -> a -> s_dot
# -- 0 -> 1, 3 --
# 0 - > 1 -> 1
# 0 -> 3 -> 3
# -- 1 -> 0, 4 --
# 1 -> 1 -> 0
# 1 -> 2 -> 4
# -- 2 -> 0, 4 --
# 2 -> 3 -> 0
# 2 -> 2 -> 4
# -- 3 -> 2
# 3 -> 1 -> 2
# -- 4 -> 1, 3, 5(terminal)
# 4 -> 1 -> 1
# 4 -> 2 -> 5(terminal)

transitions = {
    0 : {1 : [1], 3: [3]},
    1 : {1: [0], 2: [1]}, 
    2 : {3: [0], 2: [4]}, 
    3 : {1: [2]}, 
    4 : {1: [1], 2: [5]}
}

markov = Markov(transitions)
markov.rewards = {0: [-1], 1: [-1], 2: [-2], 3: [-1], 4: [-1], 5: [0]}

markov.transition_probs



{0: {1: {1: DeviceArray(1., dtype=float32)},
  3: {3: DeviceArray(1., dtype=float32)}},
 1: {1: {0: DeviceArray(1., dtype=float32)},
  2: {1: DeviceArray(1., dtype=float32)}},
 2: {3: {0: DeviceArray(1., dtype=float32)},
  2: {4: DeviceArray(1., dtype=float32)}},
 3: {1: {2: DeviceArray(1., dtype=float32)}},
 4: {1: {1: DeviceArray(1., dtype=float32)},
  2: {5: DeviceArray(1., dtype=float32)}}}

In [8]:
theta = 0.9
phi = 0.99

policy, value = policy_iteration(markov, theta, phi)

In [9]:
policy.transitions

{0: 3, 1: 1, 2: 2, 3: 1, 4: 2}

In [10]:
value

DeviceArray([-3.9601  , -4.920499, -1.      , -2.99    ,  0.      ], dtype=float32)

In [16]:
terminal = 5
trajectories = {}

for start in markov.states():
    s = start
    trajectory = []
    while s != terminal:
        s = markov.transitions[s][policy(s)][0]
        trajectory.append(s)
    trajectories[start] = trajectory

trajectories

{0: [3, 2, 4, 5], 1: [0, 3, 2, 4, 5], 2: [4, 5], 3: [2, 4, 5], 4: [5]}

## Value iteration

In [12]:
def value_iteration(markov, theta: float):
    v = jnp.zeros(markov.state_space)
    delta = float('inf')
    
    transition = {s: None for s in markov.states()}
    while delta > theta:
        delta = 0
        for s in markov.states():
            old_v = v[s]
            max_a, max_val = markov.actions(s)[0], float('-inf')
            for a in markov.actions(s):
                cur_value = 0
                for s_dot in markov_state.next_states(s):
                        for r in markov.rewards(next_s):
                            cur_value += markov.p(s, a, r, s_dot) * (r + phi * v[s_dot])
                if cur_value > max_value:
                    max_a = a
                    max_value = cur_value
            transitions[s] = max_a
            v = v.at[s].set(max_value)
            delta = max(delta, abs(old_v - v[s]))
    return v