# In [15]:
import random
from math import sqrt, log
import uuid
import gym
from gym.envs.registration import register
import json

class Node:
    """One vertex of the search tree.

    Links to other vertices are kept as identifier strings
    (``parent_identifier`` / ``children_identifiers``) that are resolved
    through the owning tree's ``nodes`` dictionary.
    """

    def __init__(self, state, action, action_space, reward, terminal):
        """Create a fresh, unvisited node.

        Args:
            state: environment state this node stands for.
            action: action that produced this node (``None`` for a root).
            action_space: number of discrete actions; all of them start untried.
            reward: initial value stored in ``self.reward``.
            terminal: whether ``state`` ends the episode.
        """
        self.identifier = str(uuid.uuid1())  # unique key used by the tree
        self.parent_identifier = None
        self.children_identifiers = []
        self.untried_actions = [*range(action_space)]
        self.state = state
        # Search statistics, updated by the search's backpropagation step.
        self.total_simulation_reward = 0
        self.num_visits = 0
        self.performance = 0
        self.action = action
        self.reward = reward
        self.terminal = terminal

    def __str__(self):
        # The accumulated reward is shown truncated to an integer.
        shown_total = int(self.total_simulation_reward)
        return (
            f"{self.state}: (action={self.action}, visits={self.num_visits}, "
            f"reward={shown_total:d}, ratio={self.performance:0.4f})"
        )

    def untried_action(self):
        """Remove and return a uniformly random action not yet tried.

        Raises:
            IndexError: when no untried actions remain.
        """
        picked = random.choice(self.untried_actions)
        self.untried_actions.remove(picked)
        return picked


def vertical_lines(last_node_flags):
    """Return the indentation prefix drawn before a node's connector in a tree dump.

    One 4-character segment is produced for every flag except the last one.
    The last flag belongs to the node itself and is rendered by
    ``horizontal_line``. An ancestor level whose flag is false (that node
    still has later siblings) gets a vertical bar so its branch keeps
    running down the page. A level whose flag is true (it was the last
    child) gets blank space.

    Args:
        last_node_flags: sequence of booleans, one per depth level, true
            when the node on the current path at that level is its parent's
            last child.

    Returns:
        The concatenated prefix; empty when there is at most one flag.
    """
    bar_segment = '\u2502' + ' ' * 3
    blank_segment = ' ' * 4
    return ''.join(
        blank_segment if is_last else bar_segment
        for is_last in last_node_flags[:-1]
    )

def horizontal_line(last_node_flags):
    """Return the branch connector drawn immediately before a node's label.

    ``last_node_flags[-1]`` tells whether this node is its parent's last
    child. The last child gets an elbow connector; earlier siblings get a
    tee connector.
    """
    is_last_child = last_node_flags[-1]
    return '\u2514\u2500\u2500 ' if is_last_child else '\u251c\u2500\u2500 '

class Tree:
    """Flat store of search nodes plus a reference to the root.

    Nodes live in ``self.nodes`` keyed by ``node.identifier``. The tree
    structure is recorded on the nodes themselves through
    ``parent_identifier`` and ``children_identifiers``.
    """

    def __init__(self):
        self.nodes = {}  # identifier -> node
        self.root = None

    def size(self):
        """Return the number of nodes in the tree."""
        return len(self.nodes)

    def depth_info(self):
        """Return ``(max_depth, avg_depth)`` over all nodes reachable from the root.

        The root is at depth 0. An empty tree yields ``(0, 0)``.
        """
        if self.root is None:
            return 0, 0
        depths = []
        # An explicit stack instead of recursion, so very deep trees cannot
        # exceed Python's recursion limit.
        pending = [(self.root, 0)]
        while pending:
            node, depth = pending.pop()
            depths.append(depth)
            pending.extend(
                (self.nodes[child_id], depth + 1)
                for child_id in node.children_identifiers
            )
        return max(depths), sum(depths) / len(depths)

    def is_expandable(self, node):
        """Return True when ``node`` is non-terminal and still has untried actions."""
        return not node.terminal and len(node.untried_actions) > 0

    def iter(self, identifier, depth, last_node_flags):
        """Yield ``(prefix, node)`` pairs for a pre-order walk of a subtree.

        Args:
            identifier: identifier of the subtree's top node; ``None`` starts
                at the root. An empty tree yields nothing.
            depth: depth of that node. Depth 0 gets an empty prefix; deeper
                nodes get box-drawing connectors.
            last_node_flags: mutable list used as a stack of "is last child"
                flags, one per level below the starting node. Entries are
                pushed and popped while walking.
        """
        node = self.root if identifier is None else self.nodes[identifier]
        if node is None:
            # Empty tree: nothing to draw.
            return

        if depth == 0:
            yield "", node
        else:
            yield vertical_lines(last_node_flags) + horizontal_line(last_node_flags), node

        children = [self.nodes[child_id] for child_id in node.children_identifiers]
        last_index = len(children) - 1
        for index, child in enumerate(children):
            last_node_flags.append(index == last_index)
            yield from self.iter(child.identifier, depth + 1, last_node_flags)
            last_node_flags.pop()

    def add_node(self, node, parent=None):
        """Register ``node``; with no ``parent`` it becomes the root.

        Raises:
            KeyError: if ``parent`` is given but was never added to this tree.
        """
        self.nodes[node.identifier] = node
        if parent is None:
            self.root = node
            # Clear the link the rest of the class actually reads
            # (parent_identifier), so a re-rooted node has no stale parent.
            node.parent_identifier = None
        else:
            self.nodes[parent.identifier].children_identifiers.append(node.identifier)
            node.parent_identifier = parent.identifier

    def children(self, node):
        """Return the child node objects of ``node``, in insertion order."""
        return [
            self.nodes[child_id]
            for child_id in self.nodes[node.identifier].children_identifiers
        ]

    def parent(self, node):
        """Return the parent of ``node``, or ``None`` for the root."""
        parent_identifier = self.nodes[node.identifier].parent_identifier
        if parent_identifier is None:
            return None
        return self.nodes[parent_identifier]

    def show(self):
        """Print the tree, one node per line, with box-drawing connectors."""
        lines = "".join(
            f"{edge}{node}\n"
            for edge, node in self.iter(identifier=None, depth=0, last_node_flags=[])
        )
        print(lines)


# In [16]:
class MonteCarloTreeSearch:
    """Exhaustive two-level lookahead over a discrete gym environment.

    ``build_depth_n_tree`` tries every action from the root and every action
    from each non-terminal child. Nodes are scored with the supplied value
    table (``V[state]``) rather than with random rollouts.

    NOTE(review): the search moves the environment by assigning
    ``env.unwrapped.s`` (possible for toy-text envs such as FrozenLake);
    TODO confirm for any other env this is used with.
    """

    def __init__(self, env, V, initial_state=None, max_steps=2):
        """Create the search and build a fresh tree rooted at the start state.

        Args:
            env: gym environment with a discrete ``action_space``.
            V: value table indexed by state; ``V[state]`` scores a node.
            initial_state: state to search from; ``None`` uses the observation
                returned by ``env.reset()``.
            max_steps: stored on the instance; ``build_depth_n_tree`` does not
                currently read it and always builds two levels.
        """
        self.env = env
        self.V = V  # value table: V[state] scores a node
        self.action_space = self.env.action_space.n
        self.max_steps = max_steps
        self.initial_state = initial_state
        self.reset_tree()

    def reset_tree(self):
        """Reset the env and replace the tree with a single root node at the start state."""
        if self.initial_state is None:
            state = self.env.reset()[0]
        else:
            self.env.reset()  # Reset the environment to the start.
            self.env.unwrapped.s = self.initial_state  # Manually set the state to initial_state.
            state = self.initial_state

        root_node = Node(state=state, action=None, action_space=self.action_space, reward=0, terminal=False)
        self.tree = Tree()
        self.tree.add_node(root_node)
        print("Root node:", root_node)

    def expand(self, node, action):
        """Add and return the child reached by taking ``action`` from ``node.state``.

        The environment is moved to ``node.state`` before stepping, so the
        child is a real successor of ``node`` and not of whatever state the
        env was last left in. The env's previous state is restored afterwards,
        even if ``step`` raises. The child's ``reward`` is initialised to
        ``V[child_state]``.
        """
        previous_state = self.env.unwrapped.s
        self.env.unwrapped.s = node.state
        try:
            state, _reward, done, _, _ = self.env.step(action)
        finally:
            self.env.unwrapped.s = previous_state

        new_node = Node(state=state, action=action, action_space=self.action_space, reward=self.V[state], terminal=done)
        self.tree.add_node(new_node, node)
        return new_node

    def simulate(self, node):
        """Score ``node`` by looking up ``V[node.state]``; no rollout is performed."""
        return self.V[node.state]

    def backpropagate(self, node, reward):
        """Add one visit worth ``reward`` to ``node`` and every ancestor up to the root.

        Updates ``num_visits``, ``total_simulation_reward`` and ``performance``
        (their ratio) on each node.

        NOTE(review): ``node.reward`` is updated as
        ``(node.reward + 0.99 * reward) / num_visits``. That is not a running
        mean, and the same undiscounted ``reward`` is passed to every ancestor
        despite the "discount" wording. ``forward`` ranks children by this
        value, so confirm the intended formula.
        """
        while node is not None:
            node.num_visits += 1
            node.total_simulation_reward += reward
            node.performance = node.total_simulation_reward / node.num_visits
            node.reward = (node.reward + 0.99 * reward) / node.num_visits
            node = self.tree.parent(node)

    def build_depth_n_tree(self):
        """Expand the root two levels deep, score the children, and print the tree.

        Every action is tried from the root. For each non-terminal child,
        every action is tried again to create grandchildren. The child is
        scored once per grandchild via ``simulate`` and ``backpropagate``.

        NOTE(review): grandchildren are created but never scored themselves,
        so each child's statistics depend only on ``V[child.state]``. In
        addition, ``max_steps`` is ignored (the depth is always 2). Confirm
        both are intended.
        """
        root_node = self.tree.root

        for action in range(self.action_space):
            child_node = self.expand(root_node, action)
            print(f"Expanded root node with action {action}, resulting in state {child_node.state}")

            if not child_node.terminal:
                for second_action in range(self.action_space):
                    grandchild_node = self.expand(child_node, second_action)
                    print(f"  Expanded child node with action {second_action}, resulting in state {grandchild_node.state}")

                    simulation_reward = self.simulate(child_node)
                    print(f"  Simulated from state {child_node.state}, got reward {simulation_reward}")

                    self.backpropagate(child_node, simulation_reward)

            self.env.unwrapped.s = root_node.state  # leave the env at the root state
        self.tree.show()  # Visualize the tree after building

    def forward(self):
        """Build the tree, then pick the root child with the highest ``reward``.

        Returns:
            ``(action, state)`` of that child, or ``(None, None)`` when the
            root has no children.
        """
        self.build_depth_n_tree()
        root_node = self.tree.root
        best_child = max(self.tree.children(root_node), key=lambda n: n.reward, default=None)
        if best_child is None:
            return None, None
        return best_child.action, best_child.state

    def run(self):
        """Build the tree once and return ``forward()``'s ``(action, state)`` pair."""
        return self.forward()

    def choose_best_action(self, goal_state=15):
        """Return ``(action, state)`` of the root child to act on.

        A child whose state equals ``goal_state`` wins outright. The default
        of 15 is presumably the goal cell of a 4x4 FrozenLake map; TODO
        confirm. Otherwise the child with the highest ``performance`` is
        chosen, with ties going to the first such child.

        Returns:
            ``(action, state)``, or ``(None, None)`` when the root has no children.
        """
        children = self.tree.children(self.tree.root)

        for child in children:
            if child.state == goal_state:
                return child.action, child.state

        best_child = max(children, key=lambda n: n.performance, default=None)
        if best_child is None:
            return None, None
        return best_child.action, best_child.state



# In [18]:
class MonteCarloTreeSearchUCT:
    """UCB1-guided tree search over a discrete gym environment.

    Each iteration selects a leaf with UCB1, expands it with every action,
    and scores new nodes with the supplied value table (``V[state]``) rather
    than with random rollouts.

    NOTE(review): the search moves the environment by assigning
    ``env.unwrapped.s`` (possible for toy-text envs such as FrozenLake);
    TODO confirm for any other env this is used with.
    """

    def __init__(self, env, V, initial_state=None, max_steps=2, c=1.41):
        """Create the search and build a fresh tree rooted at the start state.

        Args:
            env: gym environment with a discrete ``action_space``.
            V: value table indexed by state; ``V[state]`` scores a node.
            initial_state: state to search from; ``None`` uses the observation
                returned by ``env.reset()``.
            max_steps: number of select/expand iterations run by
                ``build_depth_n_tree``.
            c: UCB1 exploration constant.
        """
        self.env = env
        self.V = V  # value table: V[state] scores a node
        self.action_space = self.env.action_space.n
        self.max_steps = max_steps
        self.initial_state = initial_state
        self.c = c  # UCB exploration parameter
        self.reset_tree()

    def reset_tree(self):
        """Reset the env and replace the tree with a single root node at the start state."""
        if self.initial_state is None:
            state = self.env.reset()[0]
        else:
            self.env.reset()  # Reset the environment to the start.
            self.env.unwrapped.s = self.initial_state  # Manually set the state to initial_state.
            state = self.initial_state

        root_node = Node(state=state, action=None, action_space=self.action_space, reward=0, terminal=False)
        self.tree = Tree()
        self.tree.add_node(root_node)
        print("Root node:", root_node)

    def ucb1(self, node):
        """Return the UCB1 score of ``node``.

        Returns ``inf`` for an unvisited node, so that it is tried first.
        Otherwise returns ``performance + c * sqrt(ln(parent_visits) / visits)``.
        The exploration term is 0 when ``node`` has no parent.
        """
        if node.num_visits == 0:
            return float('inf')  # Encourage exploration of unvisited nodes
        exploitation = node.performance
        parent_node = self.tree.parent(node)
        if parent_node is not None:
            # Uses the module-level ``sqrt``/``log`` imports; ``math`` itself is
            # not imported in this file.
            exploration = self.c * sqrt(log(parent_node.num_visits) / node.num_visits)
        else:
            exploration = 0  # If there is no parent, exploration is not needed
        return exploitation + exploration

    def select(self, node):
        """Return the child of ``node`` with the highest UCB1 score (``None`` if it has none)."""
        children = self.tree.children(node)
        return max(children, key=self.ucb1, default=None)

    def expand(self, node, action):
        """Add and return the child reached by taking ``action`` from ``node.state``.

        The environment is moved to ``node.state`` before stepping, so the
        child is a real successor of ``node`` and not of whatever state the
        env was last left in. The env's previous state is restored afterwards,
        even if ``step`` raises. The child's ``reward`` is initialised to
        ``V[child_state]``.
        """
        previous_state = self.env.unwrapped.s
        self.env.unwrapped.s = node.state
        try:
            state, _reward, done, _, _ = self.env.step(action)
        finally:
            self.env.unwrapped.s = previous_state

        new_node = Node(state=state, action=action, action_space=self.action_space, reward=self.V[state], terminal=done)
        self.tree.add_node(new_node, node)
        return new_node

    def simulate(self, node):
        """Score ``node`` by looking up ``V[node.state]``; no rollout is performed."""
        return self.V[node.state]

    def backpropagate(self, node, reward):
        """Add one visit worth ``reward`` to ``node`` and every ancestor up to the root.

        Updates ``num_visits``, ``total_simulation_reward`` and ``performance``
        (their ratio) on each node.

        NOTE(review): ``node.reward`` is updated as
        ``(node.reward + 0.99 * reward) / num_visits``. That is not a running
        mean, and the same undiscounted ``reward`` is passed to every ancestor
        despite the "discount" wording. ``forward`` ranks children by this
        value, so confirm the intended formula.
        """
        while node is not None:
            node.num_visits += 1
            node.total_simulation_reward += reward
            node.performance = node.total_simulation_reward / node.num_visits
            node.reward = (node.reward + 0.99 * reward) / node.num_visits
            node = self.tree.parent(node)

    def build_depth_n_tree(self):
        """Run ``max_steps`` select/expand/score iterations starting at the root.

        Each iteration descends from the root with ``select`` (UCB1) until it
        reaches a leaf.

        - A non-terminal leaf is expanded with every action, and each new
          non-terminal child is scored with ``simulate`` and backpropagated.
        - A terminal leaf cannot be expanded, so it is scored and
          backpropagated itself. Otherwise an unvisited terminal child keeps a
          UCB1 of ``inf``, is re-selected on every later iteration, and stalls
          the search.
        """
        root_node = self.tree.root

        for _ in range(self.max_steps):
            selected_node = root_node
            # Select the best child node using UCB1 until we reach a leaf
            while selected_node is not None and self.tree.children(selected_node):
                selected_node = self.select(selected_node)

            if selected_node is None:
                continue

            if selected_node.terminal:
                leaf_value = self.simulate(selected_node)
                self.backpropagate(selected_node, leaf_value)
                continue

            for action in range(self.action_space):
                child_node = self.expand(selected_node, action)
                print(f"Expanded node with action {action}, resulting in state {child_node.state}")

                if not child_node.terminal:
                    simulation_reward = self.simulate(child_node)
                    print(f"Simulated from state {child_node.state}, got reward {simulation_reward}")
                    self.backpropagate(child_node, simulation_reward)

                self.env.unwrapped.s = root_node.state  # leave the env at the root state

    def forward(self):
        """Build the tree, then pick the root child with the highest ``reward``.

        Returns:
            ``(action, state)`` of that child, or ``(None, None)`` when the
            root has no children.
        """
        self.build_depth_n_tree()
        root_node = self.tree.root
        best_child = max(self.tree.children(root_node), key=lambda n: n.reward, default=None)
        if best_child is None:
            return None, None
        return best_child.action, best_child.state

    def run(self):
        """Build the tree once and return ``forward()``'s ``(action, state)`` pair."""
        return self.forward()

    def choose_best_action(self, goal_state=15):
        """Return ``(action, state)`` of the root child to act on.

        A child whose state equals ``goal_state`` wins outright. The default
        of 15 is presumably the goal cell of a 4x4 FrozenLake map; TODO
        confirm. Otherwise the child with the highest ``performance`` is
        chosen, with ties going to the first such child.

        Returns:
            ``(action, state)``, or ``(None, None)`` when the root has no children.
        """
        children = self.tree.children(self.tree.root)

        for child in children:
            if child.state == goal_state:
                return child.action, child.state

        best_child = max(children, key=lambda n: n.performance, default=None)
        if best_child is None:
            return None, None
        return best_child.action, best_child.state