In [49]:
import jax
import jax.numpy as jnp
import random

In [50]:
class Markov:
    def __init__(self, transitions: dict, seed: int = 11):
        self.transitions = transitions
        self.state_space = len(self.transitions)
        self.transition_probs = {}
        self.rewards = {}
        
        key = random.PRNGKey(seed)
        for state in self.transitions.keys():
            self.rewards[state] = random.normal((key:=random.split(key)[0]), (1,))
            
            self.transition_probs[state] = {}
            for action in self.transitions[state]:
                next_states = self.transitions[state][action]
                self.transition_probs[state][action] = {}
                state_probs = random.randint((key:=random.split(key)[0]), (len(next_states),), 1, 2 * len(next_states))
                state_probs = state_probs / state_probs.sum()
                for (next_state, state_prob) in zip(next_states, state_probs):
                    self.transition_probs[state][action][next_state] = state_prob   
    
    def states(self):
        return list(self.transitions.keys())
    
    def state_space(self):
        return self.state_space
    
    def actions(self, state):
        return list(self.transition_probs[state].keys())
    
    def next_states(self, state):
        states = []
        for action in self.transition_probs[state].keys():
            states = states + list(self.transition_probs[state][action].keys())
        return list(set(states))
    
    def p(self, state, action, _, next_state):
        return self.transition_probs.get(state, {action: {next_state: 0}}).get(action, {next_state:0}).get(next_state, 0)
    
    def rewards(self, state):
        return self.rewards[state]
    
class Policy:
    def p(self, a, s):
        raise NotImplementedError()

class StochasticPolicy(Policy):
    def __init__(self, state_actions: dict, seed: int = 11):
        key = random.PRNGKey(seed)
        self.probs = {}
    
        for state in state_actions:
            self.probs[state] = {}
            actions = state_actions[state]
        
            action_probs = random.randint((key:=random.split(key)[0]), (len(actions),), 1, 2 * len(actions) + 1)
            action_probs = action_probs / action_probs.sum()
        
            for (action, prob) in zip(actions, action_probs):
                self.probs[state][action] = prob
    
    def p(self, a, s):
        return self.probs[s][a]
    
    def __call__(self, s):
        probs = self.probs[state]
        max_a = probs.keys()[0]
        for a in probs.keys():
            if probs[a] > probs[max_a]:
                max_a = a
        return max_a
    
class DeterministicPolicy(Policy):
    def __init__(self, transitions: dict):
        self.transitions = transitions
        
    def p(self, a, s):
        return 1. if self.transitions[a] == s else 0
    
    def __call__(self, s):
        return self.transitions[s]

In [72]:
import random as std_rand

def random_episode_generator(episode_len=10, n_states=10, n_actions=3):
    def f(env, policy):
        return [{'state': std_rand.randrange(n_states), 
                 'action': std_rand.randrange(n_actions), 
                 'reward': std_rand.gauss(0, 1)} for _ in range(episode_len)]
    return f

class FakeEnv:
    def __init__(self, state_space):
        self.state_space = state_space
        
    def states(self):
        return self.state_space

# Monte Carlo prediction

In [75]:
def monte_carlo_prediction(phi, policy, env, iterations, episode_generator, first_visit=True):
    v = jnp.ones((len(env.states()), ))
    returns = [[] for _ in env.states()]
    
    for i in range(iterations):
        episode = episode_generator(env, policy)
        g = 0
        used_s = set()
        for T in range(len(episode)):
            for t in range(T - 1, -1, -1):
                g = phi * g + episode[t]['reward']
                s = episode[t]['state']
                if not first_visit or s not in used_s:
                    returns[s].append(g)
                    v = v.at[s].set(sum(returns[s]) / len(returns[s]))
    return v

In [76]:
monte_carlo_prediction(0.99, None, FakeEnv(list(range(10))), 10, random_episode_generator(10))

DeviceArray([-1.7783886 , -3.656394  , -0.37476203, -2.3032498 ,
             -1.9882085 , -5.425548  ,  0.12746167, -4.5953507 ,
             -2.2459898 , -3.2673235 ], dtype=float32)