In [28]:
import numpy as np

class Environment:
    """Randomly generated episodic environment with two actions per state.

    For every state and each of the two actions, ``branching_factor``
    distinct successor states are drawn once at construction time. A step
    picks one of those successors uniformly at random, emits a reward drawn
    from a standard normal distribution, and ends the episode with
    probability ``eps``.

    Attributes set by the constructor:
        num_states, branching_factor, num_actions, eps -- configuration.
        transitions1 -- per-state successor arrays used when action == 0.
        transitions2 -- per-state successor arrays used for any other action.
        state -- the current state (0 after a reset).
    """

    def __init__(self, num_states=1000, branching_factor=10, eps=0.1):
        """Build the random transition tables and place the agent in state 0.

        Args:
            num_states: number of states; successors are drawn from
                ``range(num_states)``.
            branching_factor: number of distinct successors per
                (state, action) pair. ``np.random.choice`` raises
                ``ValueError`` if this exceeds ``num_states`` because the
                successors are sampled without replacement.
            eps: per-step probability that the episode terminates.
        """
        self.num_states = num_states
        self.branching_factor = branching_factor
        self.num_actions = 2
        self.eps = eps
        # Table order matters for reproducibility under a fixed seed:
        # transitions1 is generated first, then transitions2.
        self.transitions1 = self._random_transitions()
        self.transitions2 = self._random_transitions()
        self.reset_state()

    def _random_transitions(self):
        """Return one array of distinct successor states per state."""
        return [
            np.random.choice(self.num_states, self.branching_factor, replace=False)
            for _ in range(self.num_states)
        ]

    def reset_state(self):
        """Move the environment back to the start state (state 0)."""
        self.state = 0

    def step(self, action):
        """Advance the environment by one transition.

        Args:
            action: 0 selects ``transitions1``; any other value selects
                ``transitions2``.

        Returns:
            A tuple ``(state, idx, reward, done)`` where ``state`` is the new
            state, ``idx`` is the index of the sampled branch within the
            successor array, ``reward`` is a draw from N(0, 1), and ``done``
            is True when the episode terminated on this step.
        """
        idx = np.random.choice(self.branching_factor)
        table = self.transitions1 if action == 0 else self.transitions2
        self.state = table[self.state][idx]

        # Terminate with probability eps: the uniform draw on [0, 1) falls
        # below eps exactly that often. (The comparison was previously
        # reversed, which ended episodes with probability 1 - eps.)
        done = bool(np.random.uniform(0, 1) < self.eps)

        reward = np.random.normal(0, 1)
        return self.state, idx, reward, done


In [29]:
# Instantiate the environment with its default constructor arguments.
env = Environment()

In [30]:
class PlanningAgent:
    def __init__(self, num_states = 1000, branching_factor = 10):
        self.num_actions = 2
        self.model = np.zeros((num_states, self.num_actions, branching_factor, 2))
        self.Q = np.zeros((num_states, self.num_actions))
        self.branching_factor = branching_factor
        
    def get_action(self, state, eps=0.1):
        if eps <= np.random.uniform(0, 1):
            action = np.random.choice(self.num_actions)
        else:
            action = np.argmax(self.Q[state, ])
        return action
    
    def update_model(self, state, action, new_state_idx, reward):
        """
        Model keeps track of mean.
        """
        self.model[state, action, new_state_idx, 0] += 1
        k = self.model[state, action, new_state_idx, 0]
        self.model[state, action, new_state_idx, 1] += 1./k * (reward - self.model[state, action, new_state_idx, 1])
        
    def sample_model(self, state, action):
        new_state_idx = np.random.choice(self.branching_factor)
        mu = self.model[state, action, new_state_idx, 1]
        reward = np.random.normal(mu)
        return new_state_idx, reward
    
    def update_Q(self, state, action, new_state_idx, reward):
        self.Q[state, action] = 0