In [56]:
import numpy as np
import time

In [45]:
def step(env, action):
    """Stand-in for a real environment transition.

    Neither ``env`` nor ``action`` is read: the function simply fabricates a
    random result so the rest of the pipeline has something to consume.

    Returns:
        A ``(real_state, reward)`` pair, where ``real_state`` is a
        ``(10, 5, 5)`` array of standard-normal samples and ``reward`` is a
        float drawn uniformly from ``[0, 1)``.
    """
    # The observation is drawn before the reward, so the sequence of RNG
    # calls (and therefore seeded results) is fixed.
    observation = np.random.randn(10, 5, 5)
    return observation, np.random.uniform(0, 1)

In [46]:
def representation(Observations):
    """Build a state embedding by averaging ``Observations`` along axis 1.

    Args:
        Observations: array-like with at least two dimensions.

    Returns:
        The input with axis 1 collapsed to its arithmetic mean.
    """
    embedding = np.mean(Observations, axis=1)
    return embedding

In [52]:
def predict_P_and_V(state0):
    """Placeholder prediction head: returns a fake policy ``P`` and value ``V``.

    Args:
        state0: array-like state; it must broadcast against a length-5 vector.

    Returns:
        ``(P, V)`` where ``P`` is ``state0`` scaled element-wise by five
        standard-normal samples, and ``V`` is the largest entry of ``state0``
        scaled by a uniform draw from ``[0, 1)``.
    """
    # The policy noise is drawn first and the value noise second, which fixes
    # the order in which the global RNG is consumed.
    policy = np.random.randn(5) * state0
    peak = np.max(state0)
    value = peak * np.random.uniform(0, 1)
    return policy, value

In [48]:
def dynamics(state0, action):
    """Placeholder dynamics model for a ``(state0, action)`` pair.

    Predicts the next hidden state representation and the immediate reward.

    Args:
        state0: current hidden state (array-like or scalar).
        action: numeric action; it enters only through ``log(action + 2)``.

    Returns:
        ``(state1, reward1)`` where ``state1`` is ``state0`` scaled by a
        uniform draw from ``[0, 1)`` and by ``log(action + 2)``, and
        ``reward1`` is a length-1 array holding one standard-normal sample.
    """
    # The uniform draw is consumed before the normal draw.
    shrink = np.random.uniform(0, 1)
    action_gain = np.log(action + 2)
    next_state = shrink * state0 * action_gain
    predicted_reward = np.random.randn(1)
    return next_state, predicted_reward

In [51]:
def training(t, observations, rewards, future_rewards, true_policies, actions, k):
    """Accumulate policy, value and reward losses over a k-step unroll from time t.

    The observations before index ``t`` are embedded once with
    ``representation``. Then, for each of the ``k`` unroll steps ``i``:

    1. predict the policy and value from the current hidden state,
    2. compare them with ``true_policies[t + i]`` and ``future_rewards[t + i]``,
    3. advance the hidden state with the real action ``actions[t + i]`` and
       compare the predicted reward with ``rewards[t + i]``.

    Args:
        t: time index at which the unroll starts.
        observations: sequence of observations; ``observations[:t]`` is embedded.
        rewards: per-step real rewards, indexable at ``t .. t + k - 1``.
        future_rewards: per-step value targets, indexable at ``t .. t + k - 1``.
        true_policies: per-step policy targets, indexable at ``t .. t + k - 1``.
        actions: per-step real actions, indexable at ``t .. t + k - 1``.
        k: number of unroll steps.

    Returns:
        ``(loss_p, loss_v, loss_r)``: the summed policy cross-entropy, value
        squared error and reward squared error over the ``k`` steps.
    """
    # Loss functions used below.
    def mse(a, b):
        return (a - b) ** 2

    def centropy(target, predicted):
        # Cross-entropy of the prediction against the target distribution.
        # NOTE: `predicted` must be strictly positive for the log to be finite.
        return -np.sum(target * np.log(predicted))

    o = observations[:t]  # all observations up to time t
    s = representation(o)  # first representation of state at time t
    loss_p = 0
    loss_v = 0
    loss_r = 0
    for i in range(k):
        # Each unroll step is matched with its own time index t + i, rather
        # than re-using the single index t + k for every step.
        step_index = t + i
        true_p = true_policies[step_index]
        true_fr = future_rewards[step_index]
        p, v = predict_P_and_V(s)
        loss_p += centropy(true_p, p)
        loss_v += mse(true_fr, v)  # accumulate across steps, do not overwrite
        s, r = dynamics(s, actions[step_index])
        loss_r += mse(r, rewards[step_index])
    return loss_p, loss_v, loss_r

In [61]:
class Node:
    """A single node of the search tree.

    Attributes:
        parent: whatever the creator passed as the parent link.
        state: the state stored at this node.
        children: list of child nodes, empty on creation.
        Qsa, Nsa, Psa, Rsa, Ssa: numeric per-node statistics, all starting at
            0. Elsewhere in this file ``Nsa`` is read as a visit count,
            ``Psa`` as a prior probability and ``Qsa`` as a value estimate;
            ``Rsa`` and ``Ssa`` are presumably the reward and state of the
            incoming edge (not established by the visible code).
    """

    def __init__(self, parent, state):
        # Tree structure.
        self.parent = parent
        self.state = state
        self.children = []
        # Search statistics: every counter and estimate starts at zero.
        self.Qsa = self.Nsa = self.Psa = self.Rsa = self.Ssa = 0
        
class MCTS:
    """Monte-Carlo tree search driven by the learned model functions.

    Each simulation descends from the root by repeatedly following the child
    with the highest UCB score until it reaches a node without children,
    expands that node with ``predict_P_and_V`` / ``dynamics``, and propagates
    the predicted value back up to the root.
    """

    def __init__(self, params: dict):
        """Read the search constants.

        Args:
            params: dict providing ``"c1"``, ``"c2"`` and ``"gamma"`` (any
                other keys are ignored). For backward compatibility a
                non-dict iterable of exactly three values is unpacked as
                ``(c1, c2, gamma)``.
        """
        if isinstance(params, dict):
            # Tuple-unpacking a dict would bind its *keys*, not its values.
            self.c1 = params["c1"]
            self.c2 = params["c2"]
            self.gamma = params["gamma"]
        else:
            self.c1, self.c2, self.gamma = params

    def one_turn(self, root_node, time_limit=100):
        """Search from ``root_node`` for ``time_limit`` seconds, then pick an action.

        At least one simulation is always run so the root has children to
        sample from, even when ``time_limit`` is zero.

        Returns:
            ``(policy, chosen_action)`` as produced by
            ``randomly_sample_action``.
        """
        deadline = time.monotonic() + time_limit
        self.nodes = []
        self.root_node = root_node
        while True:
            self.mcts_go(root_node)
            if time.monotonic() >= deadline:
                break
        return self.randomly_sample_action(self.root_node)

    def mcts_go(self, node):
        """Run one simulation: descend to a childless node and expand it."""
        while node.children:
            node = self.pick_node_to_expand(node)
        self.expand(node)

    def expand(self, node):
        """Expand a leaf node and back-propagate its predicted value.

        One child is created per entry of the predicted policy; the entry's
        index is used as the action fed to ``dynamics``.
        """
        prob_action, V = predict_P_and_V(node.state)
        self.back_prop_rewards(node, V)
        for action, prior in enumerate(prob_action):
            state, reward = dynamics(node.state, action)
            child = Node(parent=node, state=state)
            child.Psa = prior  # prior taken from the parent's predicted policy
            child.Rsa = reward  # predicted reward for the edge into the child
            # Attach to the parent, otherwise the tree never grows past the root.
            node.children.append(child)
            self.nodes.append(child)

    def pick_node_to_expand(self, node):
        """Return the child of ``node`` with the highest UCB score."""
        scores = [self.UCB_calc(child) for child in node.children]
        return node.children[int(np.argmax(scores))]

    def UCB_calc(self, node):
        """Score ``node`` with the MuZero pUCT rule.

        ``Q + P * sqrt(N_parent) / (1 + N) * (c1 + log((N_parent + c2 + 1) / c2))``
        """
        Q = node.Qsa
        policy_and_novelty_coef = node.Psa * np.sqrt(node.parent.Nsa) / (1 + node.Nsa)
        muZeroModerator = self.c1 + np.log((node.parent.Nsa + self.c2 + 1) / self.c2)
        return Q + policy_and_novelty_coef * muZeroModerator

    def back_prop_rewards(self, node, V):
        """Send the value ``V`` up the chain from ``node`` to the root.

        Every node on the path gets its visit count incremented and its
        ``Qsa`` updated as a running mean. Before moving to the parent, the
        value is discounted by ``gamma`` and the node's ``Rsa`` is added.
        """
        value = V
        # The walk stops at the root, whose parent is None (or a string
        # sentinel such as 'null').
        while node is not None and not isinstance(node, str):
            node.Qsa = (node.Nsa * node.Qsa + value) / (node.Nsa + 1)
            node.Nsa += 1
            value = node.Rsa + self.gamma * value
            node = node.parent

    def randomly_sample_action(self, root_node):
        """Sample an action index in proportion to the root children's visits.

        Returns:
            ``(policy, action)`` where ``policy`` is the normalised visit
            distribution over the root's children (uniform if nothing has
            been visited yet) and ``action`` is the sampled child index.

        Raises:
            ValueError: if ``root_node`` has no children.
        """
        visits = np.array([x.Nsa for x in root_node.children], dtype=float)
        if visits.size == 0:
            raise ValueError("root node has no children to sample from")
        total = visits.sum()
        if total > 0:
            policy = visits / total
        else:
            policy = np.full(visits.size, 1.0 / visits.size)
        action = int(np.random.choice(visits.size, p=policy))
        return policy, action

class Episode:
    """Plays one episode with MCTS and records the per-turn training data.

    Attributes:
        params: search/episode settings (c1, c2, gamma, turn_limit, ...),
            either a dict or an object exposing them as attributes.
        metrics: dict of lists keyed by 'policy', 'action', 'obs' and
            'reward'. 'obs' also holds the initial observation, so it is one
            entry longer than the others. After
            ``calculate_V_from_rewards`` it also has 'future_reward'.
    """

    _METRIC_KEYS = ('policy', 'action', 'obs', 'reward')

    def __init__(self, params):
        self.params = params  # c1, c2, gamma, max turns etc.
        self.metrics = {key: [] for key in self._METRIC_KEYS}

    def _param(self, name):
        """Look up a setting whether ``params`` is a dict or an attribute object."""
        if isinstance(self.params, dict):
            return self.params[name]
        return getattr(self.params, name)

    def play_episode(self, env):
        """Play until the environment reports done or ``turn_limit`` turns pass.

        Args:
            env: environment whose ``reset()`` returns ``(obs, info)``.

        Returns:
            The ``metrics`` dict collected during the episode.
        """
        self.metrics = {key: [] for key in self._METRIC_KEYS}
        turn_limit = self._param('turn_limit')
        obs, _ = env.reset()
        self.metrics['obs'].append(obs)
        turn = 0  # number of completed turns
        while True:
            # variable length array here, of t x OBS dimension
            state = representation(np.array(self.metrics['obs']))
            root_node = Node(parent=None, state=state)
            mcts = MCTS(self.params)
            policy, action = mcts.one_turn(root_node)
            # step() may return (obs, reward) or (obs, reward, done, ...);
            # a missing done flag is treated as "not done".
            obs, reward, *rest = step(env, action)
            done = bool(rest[0]) if rest else False
            self.store_metrics(policy, action, reward, obs)
            turn += 1
            if done or turn >= turn_limit:
                break  # params for ending episode
        self.calculate_V_from_rewards()
        return self.metrics

    def store_metrics(self, policy, action, reward, obs):
        """Append one turn's results to the episode metrics."""
        self.metrics['obs'].append(obs)
        self.metrics['policy'].append(policy)
        self.metrics['action'].append(action)
        self.metrics['reward'].append(reward)

    def calculate_V_from_rewards(self):
        """Compute the discounted return-to-go for every recorded reward.

        Uses the full discounted sum ``G_t = r_t + gamma * G_{t+1}`` (not a
        truncated n-step return). The result is stored under
        ``metrics['future_reward']`` and also returned.
        """
        gamma = self._param('gamma')
        returns = []
        running = 0.0
        # Walk backwards so each return re-uses the one that follows it.
        for reward in reversed(self.metrics['reward']):
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()
        self.metrics['future_reward'] = returns
        return returns
    
    

In [54]:
# Smoke test of the pipeline: fake environment step -> embedding -> prediction.
env = 'k'  # placeholder value; step() does not read its env argument
# `receive_state_and_reward` is not defined anywhere in this file; `step` is
# the function with this signature and return shape.
Observation0, real_reward = step(env, action=0)
state0 = representation(Observation0)
predict_P_and_V(state0)

(array([[-0.02382524,  0.12484224, -0.51097075, -0.0836743 , -0.22833026],
        [-0.09649986,  0.02579445, -0.20978821, -0.14017272,  0.32218067],
        [-0.35151722, -0.12334596, -0.4805248 ,  0.03702563,  0.84166641],
        [-1.45884413,  0.18921395, -0.05712818, -0.47003721, -0.20508899],
        [-0.29056937,  0.20483202, -0.36071345, -0.10936777,  0.91639969],
        [ 0.05280136,  0.18237211, -0.40537285, -0.11181813, -0.05391194],
        [ 0.66377449, -0.0242699 ,  0.6102202 , -0.0192993 ,  0.70139135],
        [-0.31561705,  0.19632571, -0.60327542, -0.04454043, -0.76430236],
        [ 0.21008685,  0.12271352,  0.49260207, -0.27185586,  0.5766538 ],
        [ 0.11965702,  0.12359691, -0.03796755,  0.0525345 ,  0.5754395 ]]),
 0.7276996333402753)