# Versione per ogni grid size con UCT


In [1]:
import uuid
import random
from math import sqrt, log
import gym
import numpy as np

class Node:
    def __init__(self, state, action, action_space, reward, terminal):
        self.identifier = str(uuid.uuid1())
        self.parent_identifier = None
        self.children_identifiers = []
        self.untried_actions = list(range(action_space))
        self.state = state
        self.total_simulation_reward = 0
        self.num_visits = 0
        self.performance = 0
        self.action = action
        self.reward = reward
        self.terminal = terminal

    def __str__(self):
        return "{}: (action={}, visits={}, reward={:d}, ratio={:0.4f})".format(
            self.state,
            self.action,
            self.num_visits,
            int(self.total_simulation_reward),
            self.performance
        )

    def untried_action(self):
        action = random.choice(self.untried_actions)
        self.untried_actions.remove(action)
        return action

def vertical_lines(last_node_flags, node_spacing=3):
    vertical_lines = []
    vertical_line = '\u2502'
    spacing = ' ' * node_spacing

    for last_node_flag in last_node_flags[0:-1]:
        if last_node_flag:
            vertical_lines.append(' ' * (node_spacing + 1))
        else:
            vertical_lines.append(vertical_line + spacing)

    vertical_lines.append(vertical_line if not last_node_flags[-1] else ' ' * (node_spacing + 1))

    return ''.join(vertical_lines)

def horizontal_line(last_node_flags, node_spacing=3):
    horizontal_line = '\u251c' + '\u2500' * node_spacing
    horizontal_line_end = '\u2514' + '\u2500' * node_spacing

    if last_node_flags[-1]:
        return horizontal_line_end
    else:
        return horizontal_line

class Tree:
    def __init__(self):
        self.nodes = {}
        self.root = None
        self.env = None

    def is_expandable(self, node):
        if node.terminal:
            return False
        return len(node.untried_actions) > 0

    def iter(self, identifier, depth, last_node_flags):
        if identifier is None:
            node = self.root
        else:
            node = self.nodes[identifier]

        if depth == 0:
            yield "", node
        else:
            yield vertical_lines(last_node_flags) + horizontal_line(last_node_flags), node

        children = [self.nodes[identifier] for identifier in node.children_identifiers]
        last_index = len(children) - 1

        depth += 1
        for index, child in enumerate(children):
            last_node_flags.append(index == last_index)
            for edge, node in self.iter(child.identifier, depth, last_node_flags):
                yield edge, node
            last_node_flags.pop()

    def add_node(self, node, parent=None):
        self.nodes[node.identifier] = node

        if parent is None:
            self.root = node
            node.parent_identifier = None
        else:
            self.nodes[parent.identifier].children_identifiers.append(node.identifier)
            node.parent_identifier = parent.identifier

    def children(self, node):
        return [self.nodes[identifier] for identifier in node.children_identifiers]

    def parent(self, node):
        parent_identifier = node.parent_identifier
        return self.nodes[parent_identifier] if parent_identifier else None

    def show(self):
        lines = ""
        for edge, node in self.iter(identifier=None, depth=0, last_node_flags=[]):
            lines += "{}{}\n".format(edge, node)
        print(lines)

    def render_policy(self, map_size):
        node = self.root
        path = []
        directions = {0: 'Left', 1: 'Down', 2: 'Right', 3: 'Up'}

        print("Rendering final policy...\n")

        if self.env:
            self.env.close()
        self.env = gym.make('FrozenLake-v1', is_slippery=False, render_mode='human', map_name=map_size)
        self.env.reset()

        while node:
            self.env.render()
            path.append(node.state)

            if node.terminal:
                print("Reached terminal state.")
                break

            children = self.children(node)

            if not children:
                print("No children found. Ending policy rendering.")
                break
            
            best_child = max(children, key=lambda n: n.num_visits, default=None)
            
            if best_child is None:
                print("Best child is None. Ending policy rendering.")
                break
            
            if best_child.action is not None:
                print(f"Action: {directions[best_child.action]} -> State: {best_child.state}")
                state, reward, done, _ = self.env.step(best_child.action)
                if done:
                    print("Environment signaled done.")
                    break
            else:
                print("Best child has no action. Ending policy rendering.")
                break
            
            node = best_child

        self.env.render()
        self.env.close()
        print(f"\nFinal policy path (states): {path}")

        path2 = [item[0] if isinstance(item, tuple) else item for item in path]
        print(path2)
        return path2

class MonteCarloTreeSearch:
    def __init__(self, env, tree):
        self.env = env
        self.tree = tree
        self.action_space = self.env.action_space.n
        state = self.env.reset()
        self.tree.add_node(Node(state=state, action=None, action_space=self.action_space, reward=0, terminal=False))

    def expand(self, node):
        action = node.untried_action()
        state, reward, done, _, _ = self.env.step(action)
        new_node = Node(state=state, action=action, action_space=self.action_space, reward=reward, terminal=done)
        self.tree.add_node(new_node, node)
        return new_node

    def default_policy(self, node):
        if node.terminal:
            return node.reward

        state = node.state
        total_reward = 0
        while True:
            action = random.randint(0, self.action_space - 1)
            state, reward, done, _, _ = self.env.step(action)
            total_reward += reward
            if done:
                return total_reward

    def compute_value(self, parent, child, exploration_constant):
        exploitation_term = child.total_simulation_reward / child.num_visits if child.num_visits > 0 else 0
        exploration_term = exploration_constant * sqrt(2 * log(parent.num_visits) / child.num_visits) if child.num_visits > 0 else 0
        return exploitation_term + exploration_term

    def best_child(self, node, exploration_constant):
        children = self.tree.children(node)
        if not children:
            return None

        best_child = children[0]
        best_value = self.compute_value(node, best_child, exploration_constant)
        for child in children[1:]:
            value = self.compute_value(node, child, exploration_constant)
            if value > best_value:
                best_child = child
                best_value = value
        return best_child

    def tree_policy(self):
        node = self.tree.root
        while not node.terminal:
            if self.tree.is_expandable(node):
                return self.expand(node)
            else:
                node = self.best_child(node, exploration_constant=1.0 / sqrt(2.0))
                state, reward, done, _, _ = self.env.step(node.action)
                if done:
                    node.terminal = True  # Mark as terminal if done
                    return node
                assert node.state == state
        return node

    def backward(self, node, value):
        while node:
            node.num_visits += 1
            node.total_simulation_reward += value
            node.performance = node.total_simulation_reward / node.num_visits if node.num_visits > 0 else 0
            node = self.tree.parent(node)

    def forward(self):
        self._forward(self.tree.root)

    def _forward(self, node):
        best_child = self.best_child(node, exploration_constant=1.4)
        if best_child is None:
            return
        
        print("****** {} ******".format(best_child.state))
        for child in self.tree.children(best_child):
            print("{}: {:0.4f}".format(child.state, child.performance))
        if len(self.tree.children(best_child)) > 0:
            self._forward(best_child)

    def render_policy(self, map_size):
        node = self.tree.root
        path = []
        directions = {0: 'Left', 1: 'Down', 2: 'Right', 3: 'Up'}

        print("Rendering final policy...\n")

        if self.env:
            self.env.close()
        self.env = gym.make('FrozenLake-v1', is_slippery=False, render_mode='human', map_name=map_size)
        self.env.reset()

        while node:
            self.env.render()
            path.append(node.state)

            if node.terminal:
                print("Reached terminal state.")
                break

            children = self.tree.children(node)

            if not children:
                print("No children found. Ending policy rendering.")
                break
            
            best_child = max(children, key=lambda n: n.performance, default=None)
            
            if best_child is None:
                print("Best child is None. Ending policy rendering.")
                break
            
            if best_child.action is not None:
                print(f"Action: {directions[best_child.action]} -> State: {best_child.state}")
                state, reward, done, _, _ = self.env.step(best_child.action)
                if done:
                    print("Environment signaled done.")
                    break
            else:
                print("Best child has no action. Ending policy rendering.")
                break
            
            node = best_child

        self.env.render()
        self.env.close()
        print(f"\nFinal policy path (states): {path}")

        path2 = [item[0] if isinstance(item, tuple) else item for item in path]
        print(path2)
        return path2

def main():
    map_name = '8x8'
    env = gym.make('FrozenLake-v1', is_slippery=False, map_name=map_name)
    tree = Tree()
    monteCarloTreeSearch = MonteCarloTreeSearch(env=env, tree=tree)
    steps = 100000  # Increase the number of iterations

    for _ in range(steps):
        env.reset()
        node = monteCarloTreeSearch.tree_policy()
        reward = monteCarloTreeSearch.default_policy(node)
        monteCarloTreeSearch.backward(node, reward)

    tree.show()
    monteCarloTreeSearch.forward()
    monteCarloTreeSearch.render_policy(map_name)


    
if __name__ == "__main__":
    main()


(0, {'prob': 1}): (action=None, visits=100000, reward=183, ratio=0.0018)
│├───8: (action=1, visits=24930, reward=45, ratio=0.0018)
│   │├───8: (action=0, visits=6312, reward=13, ratio=0.0021)
│   │   │├───8: (action=0, visits=1672, reward=7, ratio=0.0042)
│   │   │   │├───0: (action=3, visits=422, reward=2, ratio=0.0047)
│   │   │   │   │├───1: (action=2, visits=101, reward=0, ratio=0.0000)
│   │   │   │   │   │├───1: (action=3, visits=25, reward=0, ratio=0.0000)
│   │   │   │   │   │   │├───2: (action=2, visits=6, reward=0, ratio=0.0000)
│   │   │   │   │   │   │   │├───3: (action=2, visits=2, reward=0, ratio=0.0000)
│   │   │   │   │   │   │   │       └───2: (action=0, visits=1, reward=0, ratio=0.0000)
│   │   │   │   │   │   │   │├───10: (action=1, visits=1, reward=0, ratio=0.0000)
│   │   │   │   │   │   │   │├───2: (action=3, visits=1, reward=0, ratio=0.0000)
│   │   │   │   │   │   │       └───1: (action=0, visits=1, reward=0, ratio=0.0000)
│   │   │   │   │   │   │├───0: (action