In [1]:
import uuid
import random
import gym
import numpy as np
from math import sqrt, log

class Node:
    def __init__(self, state, action, action_space, reward, terminal):
        self.identifier = str(uuid.uuid1())
        self.parent_identifier = None
        self.children_identifiers = []
        self.untried_actions = list(range(action_space))
        self.state = state
        self.total_simulation_reward = 0
        self.num_visits = 0
        self.performance = 0
        self.action = action
        self.reward = reward
        self.terminal = terminal

    def __str__(self):
        return "{}: (action={}, visits={}, reward={:d}, ratio={:0.4f})".format(
                                                  self.state,
                                                  self.action,
                                                  self.num_visits,
                                                  int(self.total_simulation_reward),
                                                  self.performance)

    def untried_action(self):
        action = random.choice(self.untried_actions)
        self.untried_actions.remove(action)
        return action

def vertical_lines(last_node_flags):
    vertical_lines = []
    vertical_line = '\u2502'
    for last_node_flag in last_node_flags[0:-1]:
        if last_node_flag == False:
            vertical_lines.append(vertical_line + ' ' * 3)
        else:
            vertical_lines.append(' ' * 4)
    return ''.join(vertical_lines)

def horizontal_line(last_node_flags):
    horizontal_line = '\u251c\u2500\u2500 '
    horizontal_line_end = '\u2514\u2500\u2500 '
    if last_node_flags[-1]:
        return horizontal_line_end
    else:
        return horizontal_line

class Tree:
    def __init__(self):
        self.nodes = {}
        self.root = None

    def is_expandable(self, node):
        if node.terminal:
            return False
        if len(node.untried_actions) > 0:
            return True
        return False

    def iter(self, identifier, depth, last_node_flags):
        if identifier is None:
            node = self.root
        else:
            node = self.nodes[identifier]

        if depth == 0:
            yield "", node
        else:
            yield vertical_lines(last_node_flags) + horizontal_line(last_node_flags), node

        children = [self.nodes[identifier] for identifier in node.children_identifiers]
        last_index = len(children) - 1

        depth += 1
        for index, child in enumerate(children):
            last_node_flags.append(index == last_index)
            for edge, node in self.iter(child.identifier, depth, last_node_flags):
                yield edge, node
            last_node_flags.pop()

    def add_node(self, node, parent=None):
        self.nodes.update({node.identifier: node})

        if parent is None:
            self.root = node
            self.nodes[node.identifier].parent = None
        else:
            self.nodes[parent.identifier].children_identifiers.append(node.identifier)
            self.nodes[node.identifier].parent_identifier = parent.identifier

    def children(self, node):
        children = []
        for identifier in self.nodes[node.identifier].children_identifiers:
            children.append(self.nodes[identifier])
        return children

    def parent(self, node):
        parent_identifier = self.nodes[node.identifier].parent_identifier
        if parent_identifier is None:
            return None
        else:
            return self.nodes[parent_identifier]

    def show(self):
        lines = ""
        for edge, node in self.iter(identifier=None, depth=0, last_node_flags=[]):
            lines += "{}{}\n".format(edge, node)
        print(lines)

    def render_policy(self):
        node = self.tree.root
        path = []
        print("Rendering final policy...\n")

        while node and not node.terminal:
            print(node)
            path.append(node.state)
            node = max(self.tree.children(node), key=lambda n: n.num_visits)
        
        if node:
            print(node)
            path.append(node.state)
        
        print("\nFinal policy path (states):", path)

class MonteCarloTreeSearch:
    def __init__(self, env, V, initial_state=None, max_steps=2):
        self.env = env
        self.V = V  # Value matrix
        self.action_space = self.env.action_space.n
        self.max_steps = max_steps  # Limit search depth to 2 steps
        self.initial_state = initial_state
        self.reset_tree()

    def reset_tree(self):
        """ Reset the tree with the given state as the root. """
        if self.initial_state is None:
            state = self.env.reset()[0]
        else:
            state = self.initial_state
        
        root_node = Node(state=state, action=None, action_space=self.action_space, reward=0, terminal=False)
        self.tree = Tree()
        self.tree.add_node(root_node)

    def expand(self, node):
        action = node.untried_action()
        state, _, done, _, _ = self.env.step(action)
        new_node = Node(state=state, action=action, action_space=self.action_space, reward=self.V[state], terminal=done)
        self.tree.add_node(new_node, node)
        return new_node

    def default_policy(self, node):
        return self.V[node.state]

    def tree_policy(self):
        node = self.tree.root
        depth = 0

        while not node.terminal and depth < self.max_steps:
            if self.tree.is_expandable(node):
                return self.expand(node)
            else:
                action = random.choice(self.tree.children(node)).action
                state, reward, done, _, _ = self.env.step(action)
                node.reward = self.V[state]
                node.state = state
                node.terminal = done
            depth += 1

        return node

    def backward(self, node, value):
        while node:
            node.num_visits += 1
            node.total_simulation_reward += value
            node.performance = node.total_simulation_reward / node.num_visits
            node = self.tree.parent(node)

    def forward(self):
        """ Perform a single iteration of MCTS. """
        leaf_node = self.tree_policy()
        simulation_result = self.default_policy(leaf_node)
        self.backward(leaf_node, simulation_result)
        

    def run(self):
        for i in range(2):
            self.forward()

    def choose_best_action(self):
        """ After running the MCTS, choose the best action based on performance. """
        best_child = max(self.tree.children(self.tree.root), key=lambda n: self.V[n.state], default=None)
        if best_child:
            return best_child.action, best_child.state
        else:
            return None, None

    def execute(self, max_iters_per_tree=2):
        path = []
        print("Starting MCTS with 2-step lookahead...\n")

        while True:
            self.run(max_iters_per_tree=1)
            action, next_state = self.choose_best_action()

            if action is None:
                print("No valid actions available, terminating.")
                break

            print(f"Chosen action: {action}, leads to state: {next_state}")
            path.append(next_state)

            # Apply the action in the environment and get the new state
            _, _, done, _, _ = self.env.step(action)
            # Rebuild the tree from the new state
            self.reset_tree()

            if done:
                print("Reached a terminal state, terminating.")
                break

        print("\nFinal path (states):", path)
        return path


In [2]:
import numpy as np

def compute_v_table_from_q_table(q_table_file):

    # Carica la Q-table dal file
    q_table = np.load(q_table_file)
    q_table = q_table[0]
    v_table = np.max(q_table, axis=1)
    np.save("V.npy", v_table)
    return v_table



In [3]:
import gym
import numpy as np

def main():
    env = gym.make('FrozenLake-v1', is_slippery=False)
    env2 = gym.make('FrozenLake-v1', is_slippery=False)

    # Load and compute the V-table from the Q-table
    q_table_file = 'Q.npy'  # Replace with the correct path to the Q-table file
    v_table = compute_v_table_from_q_table(q_table_file)
    
    print("V-table:", v_table)
    
    # Define the target state
    target_state = 15
    
    # Initialize starting state
    stateMCTS = env.reset()[0]
    print(f"Initial state: {stateMCTS}")
    state2 = env2.reset()[0]
    
    path = [stateMCTS]
    terminal_state_reached = False
    max_iterations_without_convergence = 10000
    iteration_count = 0

    while not terminal_state_reached and iteration_count < max_iterations_without_convergence:
      
        # Initialize Monte Carlo Tree Search with the current state
        monteCarloTreeSearch = MonteCarloTreeSearch(env=env, V=v_table, initial_state=stateMCTS)
        monteCarloTreeSearch.run()
        monteCarloTreeSearch.tree.show()
        print(f"\nBuilding tree\n")

        # Get the best action from the current state
        action, next_state = monteCarloTreeSearch.choose_best_action()

        if action is None:
            print("No valid actions found, stopping.")
            break

        print(f"Chosen action: {action}, leads to state: {next_state}")
        path.append(next_state)

        # Apply the action in the environment and get the new state
        state, reward, done, _, _ = env2.step(action)
        print(f"New state after action {action}: {state}, reward: {reward}, done: {done}")

        # Check if the new state is the terminal state (goal) or if it is a terminal state (falling into the lake)
        if state == target_state:
            print("Target state reached.")
            terminal_state_reached = True
        elif done:
            print("Fell into the lake, retrying.")
            stateMCTS = env.reset()[0] 
            print(f"New state after falling into the lake: {stateMCTS}")
            env2.reset( ) # Reset the environment and get the new initial state
            path = [stateMCTS]  # Reset the path for new attempt
            iteration_count = 0  # Reset the iteration count for new attempt
        else:
            iteration_count += 1

    # Clear path if target state is not reached
    if not terminal_state_reached:
        path = []

    print("\nFinal path (states):", path)

if __name__ == "__main__":
    main()


V-table: [0.95099005 0.96059601 0.97029794 0.96059601 0.96059601 0.
 0.9801     0.         0.970299   0.9801     0.99       0.
 0.         0.99       1.         0.        ]
Initial state: 0
0: (action=None, visits=2, reward=1, ratio=0.9510)
├── 0: (action=0, visits=1, reward=0, ratio=0.9510)
└── 0: (action=3, visits=1, reward=0, ratio=0.9510)


Building tree

Chosen action: 0, leads to state: 0
New state after action 0: 0, reward: 0.0, done: False
0: (action=None, visits=2, reward=1, ratio=0.9558)
├── 0: (action=3, visits=1, reward=0, ratio=0.9510)
└── 1: (action=2, visits=1, reward=0, ratio=0.9606)


Building tree

Chosen action: 2, leads to state: 1
New state after action 2: 1, reward: 0.0, done: False
0: (action=None, visits=2, reward=0, ratio=0.0000)
├── 5: (action=1, visits=1, reward=0, ratio=0.0000)
└── 5: (action=0, visits=1, reward=0, ratio=0.0000)


Building tree

Chosen action: 1, leads to state: 5
New state after action 1: 5, reward: 0.0, done: True
Fell into the lake, retry