In [13]:
import random

from functools import partial
from gym.envs.toy_text import FrozenLakeEnv

from mctspy.tree import (
    DecisionNode, 
    MCTSSimulatorInterface, 
    MCTS, 
    ucb_action,
)

In [21]:
class FrozenLakeMCTS(MCTSSimulatorInterface):

    def __init__(self, env):
        self.env = env

    def step(self, state, action):
        self.env.s = state
        next_state, reward, *_ = self.env.step(action)
        
        return next_state, reward

    def state_is_terminal(self, state):
        return self.env.desc.flat[state] in (b"G", b"H")

    def enumerate_actions(self, state):
        return set(range(self.env.action_space.n))

    def get_initial_state(self):
        return self.env.reset()

In [22]:
def random_rollout_value(state, env: FrozenLakeMCTS):
    """ Rollout the environment till terminal state with random actions.
    """
    cummulative_reward = 0
    while not env.state_is_terminal(state):
        state, reward = env.step(
            state, 
            random.choice(tuple(env.enumerate_actions(state)))
        )
        cummulative_reward += reward

    return cummulative_reward

In [26]:
n_tests = 100
n_positive = 0

for _ in range(n_tests):
    env = FrozenLakeEnv(is_slippery=False, map_name="4x4")
    env = FrozenLakeMCTS(env)

    mcts = MCTS(env, ucb_action, partial(random_rollout_value, env=env), 1000)
    mcts_root = DecisionNode(env.get_initial_state(), 0, 0, {})

    # Build tree
    mcts.build_tree(mcts_root)

    # Get the best score in root Node
    best_score = max(chance_node.value / chance_node.visits for chance_node in mcts_root.children.values())

    # Compute scores for two optimal actions
    a1_score = mcts_root.children[1].value / mcts_root.children[1].visits
    a2_score = mcts_root.children[2].value / mcts_root.children[2].visits

    if a1_score == best_score or a2_score == best_score:
        n_positive += 1

n_positive / n_tests

0.71

In [27]:
n_tests = 100
n_positive = 0

for _ in range(n_tests):

    env = FrozenLakeEnv(is_slippery=False, map_name="4x4")
    env = FrozenLakeMCTS(env)

    state = env.get_initial_state()
    trajectory = [state]

    mcts = MCTS(env, ucb_action, partial(random_rollout_value, env=env), 1000)
    mcts_root = DecisionNode(state, 0, 0, {})
    
    current = mcts_root

    while not env.state_is_terminal(state):
        mcts.build_tree(current)

        action = max(
            (chance_node for chance_node in current.children.values()), 
            key=lambda chance_node: chance_node.value / chance_node.visits
        ).action

        state, reward = env.step(state, action)
        current = current.children[action].children[state]
        
        trajectory.append(state)

    if len(trajectory) == 7:
        n_positive += 1

n_positive / n_tests

0.66