In [12]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import copy

# --- Node class definition ---
class Node:
    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.value = 0.0

    def is_leaf(self):
        return len(self.children) == 0

    def expand(self, env):
        for action in range(env.action_space.n):
            env_copy = copy_env(env)
            env_copy.reset(seed=None)
            obs, reward, terminated, truncated, _ = env_copy.step(action)
            child = Node(state=obs, parent=self, action=action)
            self.children.append(child)

    def uct_score(self, exploration=1.41):
        if self.visits == 0:
            return float("inf")
        avg_value = self.value / self.visits
        return avg_value + exploration * math.sqrt(math.log(self.parent.visits) / self.visits)


# --- Helper functions ---
def copy_env(env):
    # NOTE: Gymnasium does not support true deep copies.
    # For FrozenLake, we simulate a "copy" by re-initializing.
    return gym.make("FrozenLake-v1", is_slippery=False)

def unravel_state(state_index):
    return (state_index % 4, state_index // 4)

def build_visit_map(root):
    visit_map = np.zeros((4, 4))
    def fill(node):
        x, y = unravel_state(node.state)
        visit_map[y][x] += node.visits
        for child in node.children:
            fill(child)
    fill(root)
    return visit_map

def plot_heatmap(visit_map):
    plt.imshow(visit_map, cmap='viridis', origin='lower')
    plt.colorbar(label='Visits')
    plt.title("MCTS State Visit Heatmap")
    plt.xticks(range(4))
    plt.yticks(range(4))
    plt.show()

def print_tree(node, depth=0):
    print("  " * depth + f"State: {node.state}, Visits: {node.visits}, Value: {node.value:.2f}")
    for child in node.children:
        print_tree(child, depth + 1)


# --- MCTS Core Logic ---
def mcts(env, root_state, num_simulations):
    root = Node(state=root_state)
    for _ in range(num_simulations):
        node = tree_policy(root, env)
        reward = default_policy(node.state, env)
        backup(node, reward)
    return root, best_action(root)

def tree_policy(node, env):
    while True:
        if node.is_leaf():
            node.expand(env)
            return random.choice(node.children)
        else:
            node = best_uct_child(node)

def best_uct_child(node):
    return max(node.children, key=lambda child: child.uct_score())

def default_policy(state, env):
    env_copy = copy_env(env)
    env_copy.reset(seed=None)
    env_copy.unwrapped.s = state  # manually set state
    total_reward = 0
    for _ in range(10):
        action = env_copy.action_space.sample()
        obs, reward, terminated, truncated, _ = env_copy.step(action)
        total_reward += reward
        if terminated or truncated:
            break
    return total_reward

def backup(node, reward):
    while node is not None:
        node.visits += 1
        node.value += reward
        node = node.parent

def best_action(root):
    return max(root.children, key=lambda c: c.visits).action


# --- Run and Visualize ---
env = gym.make("FrozenLake-v1", render_mode="ansi", is_slippery=False)
initial_state = env.reset(seed=42)[0]

root, action = mcts(env, root_state=initial_state, num_simulations=100)

print(f"\nBest Action Chosen: {action}\n")
print("Tree:")
print_tree(root)
# visit_map = build_visit_map(root)
# plot_heatmap(visit_map)
frame = env.render()
print(frame)


Best Action Chosen: 1

Tree:
State: 0, Visits: 100, Value: 2.00
  State: 0, Visits: 24, Value: 0.00
    State: 0, Visits: 6, Value: 0.00
      State: 0, Visits: 2, Value: 0.00
        State: 0, Visits: 1, Value: 0.00
          State: 0, Visits: 0, Value: 0.00
          State: 4, Visits: 0, Value: 0.00
          State: 1, Visits: 0, Value: 0.00
          State: 0, Visits: 1, Value: 0.00
        State: 4, Visits: 0, Value: 0.00
        State: 1, Visits: 1, Value: 0.00
        State: 0, Visits: 0, Value: 0.00
      State: 4, Visits: 1, Value: 0.00
        State: 0, Visits: 0, Value: 0.00
        State: 4, Visits: 1, Value: 0.00
        State: 1, Visits: 0, Value: 0.00
        State: 0, Visits: 0, Value: 0.00
      State: 1, Visits: 1, Value: 0.00
      State: 0, Visits: 1, Value: 0.00
        State: 0, Visits: 1, Value: 0.00
        State: 4, Visits: 0, Value: 0.00
        State: 1, Visits: 0, Value: 0.00
        State: 0, Visits: 0, Value: 0.00
    State: 4, Visits: 6, Value: 0.00
     