pseudo code

In [None]:
import jax as jx
import mctx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from optimizers import adamw
# Number of PRNG sub-keys init_fn generates for env.reset (one per batch element).
batchsize = 32
# Placeholder for the run configuration. Later code reads attributes from it
# (e.g. config.num_simulations, config.eval_frequency).
# NOTE(review): a plain string has no such attributes — presumably a loaded
# config object is intended here; confirm.
config = "some file"
class V_function(nn.Module):
    """State-value network: an MLP mapping one observation to a scalar.

    The layers are built once in ``__init__`` so their weights persist
    between calls and are exposed through ``parameters()``. The input width
    is inferred on the first forward pass (``nn.LazyLinear``) because the
    observation size is not known at construction time.

    NOTE: lazy parameters are uninitialized until the first forward pass;
    run one forward pass before handing ``parameters()`` to an optimizer.

    Args:
        config: Run configuration. Accepted for interface compatibility;
            currently unused.
    """

    def __init__(self, config):
        super().__init__()
        self.num_hidden_units = 200
        self.num_hidden_layers = 3
        self.relu = nn.ReLU()
        self.hidden = nn.ModuleList(
            [nn.LazyLinear(self.num_hidden_units)
             for _ in range(self.num_hidden_layers)]
        )
        self.head = nn.LazyLinear(1)

    def forward(self, obs):
        """Return the value estimate for a single observation.

        Args:
            obs: Array-like or tensor observation, e.g. shaped
                (board_h, board_w, action_num). All axes are flattened
                into one feature vector.

        Returns:
            A 0-dimensional tensor holding the value, e.g. 0.23.
        """
        if not torch.is_tensor(obs):
            obs = torch.as_tensor(np.asarray(obs))
        x = obs.to(torch.float32).reshape(-1)   # h*w*A features
        for layer in self.hidden:
            x = self.relu(layer(x))
        V = self.head(x)   # shape (1,), e.g. [0.23]
        return V[0]        # scalar, e.g. 0.23
    
class pi_function(nn.Module):
    """Policy network: an MLP mapping one observation to action logits.

    The layers are built once in ``__init__`` so their weights persist
    between calls and are exposed through ``parameters()``. The input width
    is inferred on the first forward pass (``nn.LazyLinear``) because the
    observation size is not known at construction time.

    NOTE: lazy parameters are uninitialized until the first forward pass;
    run one forward pass before handing ``parameters()`` to an optimizer.

    Args:
        config: Run configuration. Accepted for interface compatibility;
            currently unused.
        num_actions: Number of discrete actions, i.e. the number of logits.
    """

    def __init__(self, config, num_actions):
        super().__init__()
        self.num_hidden_units = 200
        self.num_hidden_layers = 3
        self.relu = nn.ReLU()
        self.num_actions = num_actions
        self.hidden = nn.ModuleList(
            [nn.LazyLinear(self.num_hidden_units)
             for _ in range(self.num_hidden_layers)]
        )
        self.head = nn.LazyLinear(self.num_actions)

    def forward(self, obs):
        """Return unnormalized action logits for a single observation.

        Args:
            obs: Array-like or tensor observation, e.g. shaped
                (board_h, board_w, action_num). All axes are flattened
                into one feature vector.

        Returns:
            A 1-D tensor of ``num_actions`` logits (not probabilities).
        """
        if not torch.is_tensor(obs):
            obs = torch.as_tensor(np.asarray(obs))
        x = obs.to(torch.float32).reshape(-1)   # h*w*A features
        for layer in self.hidden:
            x = self.relu(layer(x))
        pi_logit = self.head(x)   # shape (num_actions,)
        return pi_logit

In [None]:
def get_recurrent_fn(env, V_func, pi_func, discount=0.99):
    """Build the ``recurrent_fn`` that mctx uses to expand search nodes.

    The environment itself serves as the model: the search embedding is the
    environment state, and one simulated step is one call to ``env.step``.

    Args:
        env: Environment exposing ``step(actions, env_states)`` that returns
            ``(env_states, obs, rewards, terminals, _)``.
        V_func: Callable mapping an observation to a value estimate.
        pi_func: Callable mapping an observation to prior policy logits.
        discount: Discount applied to non-terminal transitions.

    Returns:
        A function with the signature mctx expects,
        ``recurrent_fn(params, rng_key, actions, env_states)``, returning
        ``(mctx.RecurrentFnOutput, next_env_states)``.
    """
    def recurrent_fn(params, rng_key, actions, env_states):
        # mctx always passes (params, rng_key, action, embedding). The first
        # two are unused here because the networks are closed over.
        del params, rng_key
        env_states, obs, rewards, terminals, _ = env.step(actions, env_states)   #batch_axes=(0,0)
        V = V_func(obs)     #batch_axes=(0)
        pi_logit = pi_func(obs)  #batch_axes=(0)
        recurrent_fn_output = mctx.RecurrentFnOutput(
            reward=rewards,
            # Zero discount at terminal states so no value is bootstrapped
            # past the end of an episode.
            discount=(1.0 - terminals) * discount,
            prior_logits=pi_logit,
            value=V,
        )
        return recurrent_fn_output, env_states
    return recurrent_fn

In [None]:
def init_fn(env, key):
    """Create the initial environment states, networks and optimizers.

    Args:
        env: Environment exposing ``reset(keys)`` and ``num_actions()``.
        key: JAX PRNG key; it is split into ``batchsize`` sub-keys that are
            passed to ``env.reset``.

    Returns:
        Tuple ``(env_states, V_func, pi_func, optimV, optimP)`` where
        ``optimV`` / ``optimP`` are whatever ``adamw`` returns for the value
        and policy parameters respectively.
    """
    #* environment
    # split(key, n) returns n keys; unpacking it into two names (as the
    # original did) fails for batchsize != 2, so keep the whole batch.
    subkeys_batch = jx.random.split(key, batchsize)
    env_states = env.reset(subkeys_batch)   #batched version state  env_states: [bs, ....]
    num_actions = env.num_actions()  #5

    #* v_net
    V_func = V_function(config)
    #* p_net
    pi_func = pi_function(config, num_actions)

    # NOTE(review): the original unpacked each adamw result into three values
    # (opt_state, opt_update, get_params) — presumably adamw returns such a
    # triple; confirm against the optimizers module.
    #* v_optim
    optimV = adamw(V_func.parameters())
    #* p_optim
    optimP = adamw(pi_func.parameters())

    return env_states, V_func, pi_func, optimV, optimP


In [None]:
import math.log as log
import tc.nn.softmax as softmax
  
class AC_loss():
    """Actor-critic loss against tree-search targets.

    The policy term is KL(pi_MCTS || softmax(pi_logits)); the value term is
    the squared error between the search value and the predicted value.
    Both terms are summed over every element, so the loss is a scalar for
    single and batched inputs alike.

    Args:
        pi_func: Callable mapping an observation to policy logits (tensor).
        V_func: Callable mapping an observation to a value estimate (tensor).
    """

    def __init__(self,pi_func, V_func,):
        self.pi_func = pi_func
        self.V_func = V_func

    @staticmethod
    def _as_tensor_like(x, like):
        """Convert array-like ``x`` to a tensor with ``like``'s dtype/device."""
        if not torch.is_tensor(x):
            x = torch.as_tensor(np.asarray(x))
        return x.to(dtype=like.dtype, device=like.device)

    def forward(self, pi_MCTS, V_MCTS, obs):
        """Compute the scalar loss for one observation (or batch).

        Args:
            pi_MCTS: Target action probabilities from the tree search.
            V_MCTS: Target value from the tree search.
            obs: Observation fed to both networks.

        Returns:
            Scalar tensor: policy KL divergence plus squared value error.
        """
        pi_logits = self.pi_func(obs)
        V = self.V_func(obs)

        pi_target = self._as_tensor_like(pi_MCTS, pi_logits)
        V_target = self._as_tensor_like(V_MCTS, V)

        # KL(y || y_hat) = sum y * (log y - log y_hat). log_softmax keeps it
        # numerically stable, and kl_div treats 0 * log 0 as 0, so actions
        # with zero search probability do not produce NaN.
        log_pi = F.log_softmax(pi_logits, dim=-1)
        pi_loss = F.kl_div(log_pi, pi_target, reduction="sum")
        V_loss = ((V_target - V) ** 2).sum()   # squared error

        return pi_loss + V_loss

    def __call__(self, pi_MCTS, V_MCTS, obs):
        """Alias for ``forward`` so the loss object can be called directly."""
        return self.forward(pi_MCTS, V_MCTS, obs)

In [None]:
# --- Experiment setup: environment, networks, optimizers, loss gradient ---
time_step = 0   # running count of environment steps, advanced by the evaluation loop
import jax_environments
key = jx.random.PRNGKey(0)

Environment = getattr(jax_environments, "ProcMaze")   #*ProcMaze   num_action == 5
# NOTE(review): `config` is defined above as a placeholder string, so this
# attribute access only works once a real config object is loaded.
env_config = config.env_config
env = Environment(grid_size=5, timeout=64)
num_actions = env.num_actions()   #5
key, subkey = jx.random.split(key)

env_states, V_func, pi_func, optim_V, optim_P = init_fn(env, subkey)

# NOTE(review): these attribute lookups assume the adamw result exposes
# `get_V_params` / `V_opt_state` (and the pi equivalents) by name — confirm
# against the optimizers module.
get_V_params = optim_V.get_V_params;   V_opt_state   = optim_V.V_opt_state
V_target_params = V_func.parameters()

get_pi_params = optim_P.get_pi_params;  pi_opt_state = optim_P.pi_opt_state


recurrent_fn = get_recurrent_fn(env, V_func, pi_func)
iterations = 100   # interaction steps per call of the loop function

# `grad` is a function inside the jax package, not a submodule, so it has to
# be imported with from-import (`import jax.grad as grad` cannot be resolved).
from jax import grad
import functools

Ac = AC_loss(pi_func, V_func)
# NOTE(review): jax.grad selects what to differentiate with `argnums`
# (positional indices of the wrapped function's arguments); it has no `arg`
# keyword, and the `.params` attributes are not defined in the visible
# classes. Left as written — TODO confirm the intended parameter plumbing.
AC_loss_model = grad(Ac,arg=(Ac.pi_func.params,Ac.V_func.params))    #*   AC_Loss(.. .,pi_MCTS, V_target, obs    #batch_axis  (...,0,0,0)

def agent_environment_interaction_loop_function(S):
    """Run `iterations` rounds of search, learning update and environment step.

    Each round: observe the current environment states, evaluate the policy
    and value functions, run `mctx.gumbel_muzero_policy` from that root, take
    a gradient step toward the search-derived targets, step the environment
    with the action chosen by the search, and update the return statistics.

    Args:
        S: Mutable run-state dict. Keys read and written here: "key",
            "env_states", "V_target_params", "pi_opt_state", "V_opt_state",
            "opt_t", "episode_return", "avg_return", "num_episodes".

    Returns:
        The same dict `S`, updated in place.
    """
    # NOTE(review): this subkey is overwritten by the split inside the loop
    # before it is ever used.
    S["key"], subkey = jx.random.split(S["key"])
    for _ in range(iterations):
        obs = env.get_observation(S["env_states"])   #*  get observation in the real world
        #**  turn env_states = (goal,wall_grid,pos,t)  to  np.array
        #**  next env_state = env.step(env_states)
        pi_logits = pi_func(obs)  #pi(obs)   batch_axes=(0,)
        V = V_func(obs)  #V(obs)  batch_axes=(0,)

        # Search root: network outputs at the current state; the environment
        # state itself is the embedding that recurrent_fn steps forward.
        root = mctx.RootFnOutput(
            prior_logits=pi_logits,
            value=V,
            embedding=S["env_states"]
        )

        S["key"], subkey = jx.random.split(S["key"])
        #** run the tree search and get a policy_output
        #** (original note said "32 times"; num_simulations below is annotated 10 — TODO confirm)
        policy_output = mctx.gumbel_muzero_policy(
            params={"V":S["V_target_params"], "pi":get_pi_params(S["pi_opt_state"])},
            rng_key=subkey,
            root=root,
            recurrent_fn=recurrent_fn,
            num_simulations=config.num_simulations,   #10
            max_num_considered_actions=num_actions,   #5
            qtransform=functools.partial(
                mctx.qtransform_completed_by_mix_value,
                use_mixed_value=config.use_mixed_value,  #true
                value_scale=config.value_scale       #0.1
            ),
        )

        # tree search derived targets for policy and value function
        search_policy = policy_output.action_weights    #policy: policy_output.action_weights 
                                                        #action: policy_output.action
        # Value target: the root Q-value of the action the search selected.
        # NOTE(review): `ROOT_INDEX=0` reads as pseudocode for "node indices
        # of the root"; mctx's Tree.qvalues presumably takes the indices
        # positionally — confirm against the mctx Tree API.
        search_value = policy_output.search_tree.qvalues(ROOT_INDEX=0)[policy_output.action]
        

        # compute loss gradient compared to tree search targets and update parameters
        #                               AC_Loss()        pi_params V_params, pi_target,V_target, obs
        pi_grads, V_grads = AC_loss_model(search_policy, search_value, obs)  #forward(self, pi_MCTS, V_MCTS, obs):
        # NOTE(review): optim_P / optim_V are the values init_fn got back from
        # adamw; here they are called like update functions — confirm that
        # the optimizers module supports this.
        S["pi_opt_state"] = optim_P(pi_grads,S["pi_opt_state"]); 
        S["V_opt_state"]  = optim_V(V_grads,S["V_opt_state"]); 

        # Update V target params after a particular number of parameter updates   
        S["opt_t"]+=1
        if S["opt_t"]%config.target_update_frequency == 0:
            S["V_target_params"] = get_V_params(S["V_opt_state"])

        # always take action recommended by tree search
        actions = policy_output.action

        # step the environment  #*  take real action in real environment
        S["env_states"], obs, reward, terminal, _ = env.step(actions, S["env_states"])   #batch_axes=(0,0)  env.step(action,env_state)

        # reset environment if terminated
        # NOTE(review): if `terminal` is a batched array (the env states are
        # created from a batch of keys), a plain `if terminal:` is ambiguous
        # and a per-element select would be needed — TODO confirm.
        S["key"], subkeys = jx.random.split(S["key"])
        if terminal: 
            S["env_states"] = env.reset(subkeys)                  

        #*  reward  and return  in the real environment
        # update statistics for computing average return
        S["episode_return"] += reward
        if terminal:
            # Exponential moving average of episode returns.
            # NOTE(review): the 0.9 / 0.1 weights are hard-coded here, while
            # the evaluation code debiases with config.avg_return_smoothing;
            # the two must agree.
            S["avg_return"] = S["avg_return"]*0.9+S["episode_return"]*0.1
            S["episode_return"] =  0
            S["num_episodes"] =  S["num_episodes"]+1
    return S


# Mutable state threaded through the interaction loop. The return statistics
# are kept as one entry per batch element (arrays rather than Python ints)
# because the evaluation below selects elements with a boolean mask.
run_state = {
    "env_states": env_states,
    "V_opt_state": V_opt_state,
    "V_target_params": V_target_params,
    "pi_opt_state": pi_opt_state,
    "opt_t": 0,
    "avg_return": np.zeros(batchsize),
    "episode_return": np.zeros(batchsize),
    "num_episodes": np.zeros(batchsize, dtype=int),
    "key": key,
}
avg_returns = []
times = []

for i in range(config.num_steps // config.eval_frequency):
    # perform a number of iterations of agent environment interaction including learning updates
    run_state = agent_environment_interaction_loop_function(run_state)

    # avg_return is debiased, and only includes batch elements with at least
    # one completed episode so that it is more meaningful in early episodes
    completed = np.asarray(run_state["num_episodes"]) > 0
    if completed.any():
        valid_avg_returns = np.asarray(run_state["avg_return"])[completed]
        valid_num_episodes = np.asarray(run_state["num_episodes"])[completed]
        avg_return = np.mean(
            valid_avg_returns / (1 - config.avg_return_smoothing ** valid_num_episodes)
        )
    else:
        # No episode has finished yet; report NaN explicitly instead of
        # taking the mean of an empty selection.
        avg_return = float("nan")
    print("Running Average Return: " + str(avg_return))
    avg_returns.append(avg_return)

    time_step += config.eval_frequency
    times.append(time_step)