In [10]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import copy

# MCTS implementation

In [26]:
# --- MCTS Node Class ---
class MCTSNode:
    def __init__(self, state, parent=None, action=None, make_env=None, verbose=False):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.value = 0.0
        self.make_env = make_env
        self.verbose = verbose

    def is_leaf(self):
        return len(self.children) == 0

    def uct(self, exploration=1.41):
        if self.visits == 0 or self.parent is None:
            return float("inf")
        exploitation = self.value / self.visits
        exploration_bonus = exploration * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration_bonus

    def best_uct_child(self):
        return max(self.children, key=lambda child: child.uct())

    def best_child(self):
        return max(self.children, key=lambda child: child.value)

    def selection(self, env):
        """Traverse the tree to select a promising node to expand.
        
        Returns:
            - node  (MCTSNode): The selected node to expand.
            - is_goal (bool): True if the node is a goal state.
        """
        node = self
        while not node.is_leaf():
            node = node.best_uct_child()
            if self.verbose:
                print(f"Selected node {node.state} with visits {node.visits} and value {node.value}")
            
        if node.visits == 0:
            return node, None
        goal_node = node.expand(env)
        return (goal_node if goal_node else random.choice(node.children), goal_node is not None)

    def expand(self, env):
        """Expand the node by trying all possible actions."""
        for action in range(env.action_space.n):
            env_copy = self.make_env()
            env_copy.reset()
            env_copy.unwrapped.s = self.state
            obs, reward, terminated, truncated, _ = env_copy.step(action)

            # Only add child if it is not a (non successful) terminal state or if it is not the same state
            # if reward == 0 and terminated or self.state == obs:
            #     continue

            child = MCTSNode(obs, parent=self, action=action, make_env=self.make_env, verbose=self.verbose)
            self.children.append(child)

            if reward == 1:
                if self.verbose:
                    print(f"Goal found from state {self.state} with action {action} → {obs}")
                return child

        if self.verbose:
            print(f"Expanded node {self.state} with children: {[c.state for c in self.children]}")
        return None

    def simulation(self, max_steps=10):
        """Perform a rollout from the current node using random actions."""
        env_copy = self.make_env()
        env_copy.reset()
        env_copy.unwrapped.s = self.state
        obs = self.state

        for _ in range(max_steps):
            action = env_copy.action_space.sample()
            obs, reward, terminated, truncated, _ = env_copy.step(action)

            if self.verbose:
                print(f"Simulating from {self.state} → {obs} with action {action} reward {reward}")

            if reward == 1 or terminated or truncated:
                return reward
        return 0

    def backpropagation(self, reward):
        """Propagate the simulation result back up the tree."""
        node = self
        while node:
            node.visits += 1
            node.value += reward
            if self.verbose:
                print(f"Backprop node {node.state}, visits={node.visits}, value={node.value}")
            node = node.parent


# --- MCTS Search Class ---
class MCTS:
    def __init__(self, make_env, num_iterations=100, num_simulations=10, exploration=1.41, verbose=False):
        self.make_env = make_env
        self.env = make_env()
        self.root = MCTSNode(self.env.reset()[0], make_env=self.make_env, verbose=verbose)
        self.num_iterations = num_iterations
        self.num_simulations = num_simulations
        self.exploration = exploration
        self.verbose = verbose
        self.root.expand(self.env)

    def run(self):
        for _ in range(self.num_iterations):
            node, is_goal = self.root.selection(self.env)
            if is_goal:
                node.backpropagation(1)
                if self.verbose:
                    print(f"Goal reached at state {node.state}")
                break
            reward = node.simulation(max_steps=self.num_simulations)
            node.backpropagation(reward)
        if self.verbose:
            print(f"Finished {self.num_iterations} iterations.")


# --- Main Execution ---
def make_env():
    return gym.make("FrozenLake-v1", is_slippery=False, render_mode="ansi")

# Run MCTS
mcts = MCTS(make_env=make_env, num_iterations=1000, num_simulations=100, exploration=1.41, verbose=True)
mcts.run()

# Visualise best path
env = make_env()
env.reset()
print("Initial state:")
print(env.render())

trajectory = []
node = mcts.root
trajectory.append((node.state, node.action, node.value))
while not node.is_leaf():
    prev_node = node
    node = prev_node.best_child()
    trajectory.append((node.state, node.action, node.value))
    print(f"Best action from state {prev_node.state} to state {node.state} with value {node.value}")
    env.step(node.action)
    print(env.render())

Expanded node 0 with children: [0, 4, 1, 0]
Selected node 0 with visits 0 and value 0.0
Simulating from 0 → 1 with action 2 reward 0.0
Simulating from 0 → 1 with action 3 reward 0.0
Simulating from 0 → 5 with action 1 reward 0.0
Backprop node 0, visits=1, value=0.0
Backprop node 0, visits=1, value=0.0
Selected node 4 with visits 0 and value 0.0
Simulating from 4 → 5 with action 2 reward 0.0
Backprop node 4, visits=1, value=0.0
Backprop node 0, visits=2, value=0.0
Selected node 1 with visits 0 and value 0.0
Simulating from 1 → 0 with action 0 reward 0.0
Simulating from 1 → 1 with action 2 reward 0.0
Simulating from 1 → 2 with action 2 reward 0.0
Simulating from 1 → 6 with action 1 reward 0.0
Simulating from 1 → 2 with action 3 reward 0.0
Simulating from 1 → 2 with action 3 reward 0.0
Simulating from 1 → 1 with action 0 reward 0.0
Simulating from 1 → 2 with action 2 reward 0.0
Simulating from 1 → 6 with action 1 reward 0.0
Simulating from 1 → 5 with action 0 reward 0.0
Backprop node 1, v

Expanded node 8 with children: [8, 12, 9, 4]
Simulating from 8 → 8 with action 0 reward 0.0
Simulating from 8 → 8 with action 0 reward 0.0
Simulating from 8 → 8 with action 0 reward 0.0
Simulating from 8 → 12 with action 1 reward 0.0
Backprop node 8, visits=1, value=0.0
Backprop node 8, visits=2, value=0.0
Backprop node 8, visits=6, value=0.0
Backprop node 4, visits=23, value=0.0
Backprop node 0, visits=94, value=1.0
Selected node 0 with visits 22 and value 0.0
Selected node 4 with visits 5 and value 0.0
Selected node 4 with visits 1 and value 0.0
Expanded node 4 with children: [4, 8, 5, 0]
Simulating from 8 → 4 with action 3 reward 0.0
Simulating from 8 → 5 with action 2 reward 0.0
Backprop node 8, visits=1, value=0.0
Backprop node 4, visits=2, value=0.0
Backprop node 4, visits=6, value=0.0
Backprop node 0, visits=23, value=0.0
Backprop node 0, visits=95, value=1.0
Selected node 1 with visits 26 and value 1.0
Selected node 2 with visits 7 and value 1.0
Selected node 3 with visits 1 an

In [27]:
print(trajectory)

[(0, None, 12.0), (4, 1, 6.0), (8, 1, 4.0), (9, 2, 3.0), (13, 1, 3.0), (14, 2, 2.0), (15, 2, 1.0)]
