In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import copy
# custom imports
import utils
import train
import mcts
from stochastic_mcts import StochasticPVMCTS
from rtfm import featurizer as X
import os
from torch import multiprocessing as mp
import random

Using device cpu
Using device cpu


In [2]:
# Smoke-test configuration: the values are deliberately small so that we only
# check that the main logic of the training loop works.

# --- MCTS search ---
ucb_C = 1.0
num_simulations = 100
discount = 0.99
temperature = 0.0

# --- Exploration noise mixed into the root prior ---
dir_noise = True
exploration_fraction = 0.25
dirichlet_alpha = 0.5  # no real reason to choose this value, except it's < 1

# --- Rollout / training-loop sizes ---
episode_length = 32
n_episodes = 100  # 4000 (presumably the full-run value; reduced here)
memory_size = 1024
batch_size = 4  # 32 (presumably the full-run value; reduced here)

# --- Target-network soft update ---
# new_trg_params = (1-tau)*old_trg_params + tau*value_net_params
tau = 0.1

# --- Loss options ---
full_cross_entropy = False
entropy_bonus = False
entropy_weight = 0.01

In [3]:
# Environment configuration selecting the "rtfm:groups_simple-v0" task.
flags = utils.Flags(env="rtfm:groups_simple-v0")
# Environment instance built from `flags`.
gym_env = utils.create_env(flags)
# Featurizer handed to the simulator (presumably controls how frames are
# rendered/encoded -- TODO confirm against rtfm.featurizer).
featurizer = X.Render()
# Simulator wrapping the environment; this is the object passed as `env`
# to the rollout function further down in the notebook.
game_simulator = mcts.FullTrueSimulator(gym_env, featurizer)

In [4]:
# Network passed as `pv_net` to the MCTS rollout below. It is constructed from
# the environment (presumably to size its inputs/outputs -- TODO confirm).
pv_net = mcts.DiscreteSupportPVNet_v3(gym_env)

In [8]:
def play_rollout_pv_net_stochastic(
    pv_net,
    env,
    episode_length,
    ucb_C,
    discount,
    num_simulations,
    dirichlet_alpha,
    exploration_fraction,
    temperature,
    render=False,
    debug_render=False,
    dir_noise=True,
):
    r"""
    Play `episode_length` environment steps, choosing every action with a
    policy-and-value MCTS (`StochasticPVMCTS`).

    At each step a tree is built from the current frame. The search is seeded
    with the sub-tree returned by `tree.get_subtree(action, frame)` at the
    previous step (which may be None); at the first step and right after an
    environment reset the tree starts from scratch.

    The action is obtained from `root.softmax_Q(temperature)`, which returns
    both the sampled action and the MCTS policy; both are collected as possible
    targets with which to train the policy network.

    Formula used for the MCTS policy (softmax of Q-values with temperature),
    as implemented by `root.softmax_Q`:

        p(a) = exp{Q(a)/T} / \sum_b exp{Q(b)/T}

    Note: the softmax function with T=0 is the argmax function.

    When `dir_noise` is True, the search mixes a sample from a Dirichlet
    distribution (with parameter `dirichlet_alpha` for each possible action)
    into the prior of the root node's children, in order to increase
    exploration at the base of the tree even when the policy is almost
    deterministic. The mixture coefficient is `exploration_fraction`:

        p(a) = (1-exploration_fraction) Prior(a) + exploration_fraction Dir(a)

    The rollout does not stop when an episode terminates: the environment is
    reset and stepping continues until `episode_length` steps have been taken,
    so the returned lists can span several episodes.

    Parameters
    ----------
    pv_net : policy/value network forwarded to `StochasticPVMCTS`.
    env : simulator exposing `reset() -> (frame, valid_actions)`,
        `step(action) -> (frame, valid_actions, reward, done)` and `render()`.
    episode_length : total number of environment steps to play.
    ucb_C, discount : forwarded to `StochasticPVMCTS`.
    num_simulations : number of MCTS simulations run at each step.
    dirichlet_alpha, exploration_fraction : root-noise settings forwarded to
        `tree.run`.
    temperature : temperature forwarded to `root.softmax_Q`.
    render : if True, render the environment and print progress information.
    debug_render : forwarded to `StochasticPVMCTS` as its `render` flag.
    dir_noise : whether to enable Dirichlet noise in `tree.run`
        (default True, matching the previous hard-coded behaviour).

    Returns
    -------
    (total_reward, frame_lst, reward_lst, done_lst, action_lst, probs_lst)
        `total_reward` is the sum of all step rewards. `frame_lst` holds the
        initial frame followed by the frame returned by every step (length
        `episode_length + 1`); the other lists have one entry per step.
    """
    # Human-readable action names, used only for the `render` printout.
    action_names = {
        0: "Stay",
        1: "Up",
        2: "Down",
        3: "Left",
        4: "Right",
    }

    frame, valid_actions = env.reset()
    if render:
        env.render()

    total_reward = 0
    new_root = None  # sub-tree carried over from the previous step, if any

    # Trajectory data returned to the caller for training.
    frame_lst = [frame]
    reward_lst = []
    done_lst = []
    action_lst = []
    probs_lst = []

    for _ in range(episode_length):
        tree = StochasticPVMCTS(
            frame,
            env,
            valid_actions,
            ucb_C,
            discount,
            pv_net,
            root=new_root,
            render=debug_render,
        )

        root, _info = tree.run(
            num_simulations,
            dir_noise=dir_noise,
            dirichlet_alpha=dirichlet_alpha,
            exploration_fraction=exploration_fraction,
        )

        action, probs = root.softmax_Q(temperature)
        action_lst.append(action)
        probs_lst.append(probs)

        if render:
            # Fall back to a placeholder so an unexpected action id cannot
            # crash the rollout just because of a debug printout.
            print(
                "Action selected from MCTS: ",
                action,
                "({})".format(action_names.get(action, "Unknown")),
            )

        frame, valid_actions, reward, done = env.step(action)

        frame_lst.append(frame)
        reward_lst.append(reward)
        done_lst.append(done)

        if render:
            env.render()
        total_reward += reward

        new_root = tree.get_subtree(action, frame)
        if new_root is not None and render:
            # amount of simulations that we are going to re-use in the next step:
            print("new_root.visit_count: ", new_root.visit_count)

        if done:
            # NOTE(review): the first frame of the new episode is not appended
            # to frame_lst, so frame_lst[i] after a reset is the terminal frame
            # of the previous episode rather than the frame the next action was
            # chosen from. Left unchanged because callers may rely on the
            # current list lengths -- confirm against the training code.
            frame, valid_actions = env.reset()
            if render:
                print("Final reward: ", reward)
                print("\nNew episode begins.")
                env.render()
            # The old tree refers to the finished episode: start from scratch.
            new_root = None

    return total_reward, frame_lst, reward_lst, done_lst, action_lst, probs_lst


In [9]:
# Play a single rollout with the settings from the configuration cell.
# Arguments are passed by keyword so the mapping to parameters is explicit.
# NOTE(review): the `dir_noise` setting from the configuration cell is not
# passed in this call.
results = play_rollout_pv_net_stochastic(
    pv_net=pv_net,
    env=game_simulator,
    episode_length=episode_length,
    ucb_C=ucb_C,
    discount=discount,
    num_simulations=num_simulations,
    dirichlet_alpha=dirichlet_alpha,
    exploration_fraction=exploration_fraction,
    temperature=temperature,
    render=False,
    debug_render=False,
)

In [10]:
# Split the rollout output into its six components, in the order returned
# by play_rollout_pv_net_stochastic.
(
    total_reward,
    frame_lst,
    reward_lst,
    done_lst,
    action_lst,
    probs_lst,
) = results

In [11]:
# Sum of the rewards collected over every step of the rollout. The rollout
# resets the environment when an episode ends and keeps stepping, so this
# value can include rewards from more than one episode.
total_reward

4