In [9]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import copy

In [29]:
# --- MCTS Node Class ---
class MCTSNode:
    def __init__(self, state, parent=None, action=None, make_env=None, verbose=False):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.value = 0.0
        self.make_env = make_env
        self.verbose = verbose

    def is_leaf(self):
        return len(self.children) == 0

    def uct(self, exploration=1.41):
        if self.visits == 0 or self.parent is None:
            return float("inf")
        exploitation = self.value / self.visits
        exploration_bonus = exploration * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration_bonus

    def best_uct_child(self):
        return max(self.children, key=lambda child: child.uct())

    def best_child(self):
        return max(self.children, key=lambda child: child.value)

    def selection(self, env):
        """Traverse the tree to select a promising node to expand."""
        node = self
        while not node.is_leaf():
            node = node.best_uct_child()
        if node.visits == 0:
            return node, None
        goal_node = node.expand(env)
        return (goal_node if goal_node else random.choice(node.children), goal_node is not None)

    def expand(self, env):
        """Expand the node by trying all possible actions."""
        for action in range(env.action_space.n):
            env_copy = self.make_env()
            env_copy.reset()
            env_copy.unwrapped.s = self.state
            obs, reward, terminated, truncated, _ = env_copy.step(action)

            if reward == 0 and terminated or self.state == obs:
                continue

            child = MCTSNode(obs, parent=self, action=action, make_env=self.make_env, verbose=self.verbose)
            self.children.append(child)

            if reward == 1:
                if self.verbose:
                    print(f"Goal found from state {self.state} with action {action} → {obs}")
                return child

        if self.verbose:
            print(f"Expanded node {self.state} with children: {[c.state for c in self.children]}")
        return None

    def simulation(self, max_steps=10):
        """Perform a rollout from the current node using random actions."""
        env_copy = self.make_env()
        env_copy.reset()
        env_copy.unwrapped.s = self.state
        obs = self.state

        for _ in range(max_steps):
            action = env_copy.action_space.sample()
            obs, reward, terminated, truncated, _ = env_copy.step(action)

            if self.verbose:
                print(f"Simulating from {self.state} → {obs} with action {action} reward {reward}")
                print(f"Simulation state {obs}, reward {reward}, terminated {terminated}, truncated {truncated}")

            if reward > 0 or terminated:
                return reward
        return 

    def backpropagation(self, reward):
        """Propagate the simulation result back up the tree."""
        node = self
        while node:
            node.visits += 1
            node.value += reward
            if self.verbose:
                print(f"Backprop node {node.state}, visits={node.visits}, value={node.value}")
            node = node.parent


# --- MCTS Search Class ---
class MCTS:
    def __init__(self, make_env, num_iterations=100, num_simulations=10, exploration=1.41, verbose=False):
        self.make_env = make_env
        self.env = make_env()
        self.root = MCTSNode(self.env.reset()[0], make_env=self.make_env, verbose=verbose)
        self.num_iterations = num_iterations
        self.num_simulations = num_simulations
        self.exploration = exploration
        self.verbose = verbose
        self.root.expand(self.env)

    def run(self):
        for _ in range(self.num_iterations):
            node, goal = self.root.selection(self.env)
            if goal:
                node.backpropagation(1)
                if self.verbose:
                    print(f"Goal reached at state {node.state}")
                break
            reward = node.simulation(max_steps=self.num_simulations)
            node.backpropagation(reward)
        if self.verbose:
            print(f"Finished {self.num_iterations} iterations.")

In [30]:
class ManipulationFrozenLakeWithMacros(gym.Env):
    def __init__(self, grid_size=4, agent_start=0, box_start=1, goal_agent=13, goal_box=15, holes=None):
        super().__init__()
        self.grid_size = grid_size
        self.n_states = grid_size * grid_size
        self.agent_start = agent_start
        self.box_start = box_start
        self.goal_agent = goal_agent
        self.goal_box = goal_box
        self.agent_pos = agent_start
        self.box_pos = box_start
        self.holding = False
        self.holes = holes if holes is not None else {5, 7, 11, 12}

        # Actions: 0=up, 1=down, 2=left, 3=right, 4=grab, 5=release, 6+=macros
        self.n_primitive_actions = 6
        self.n_macro_actions = 8  # You can add more macros here
        self.action_space = spaces.Discrete(self.n_primitive_actions + self.n_macro_actions)
        self.observation_space = spaces.MultiDiscrete([self.n_states, self.n_states, 2])

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self.agent_pos = self.agent_start
        self.box_pos = self.box_start
        self.holding = False
        return self._get_obs(), {}

    def _get_obs(self):
        return (self.agent_pos, self.box_pos, int(self.holding))

    def _to_coord(self, pos):
        return divmod(pos, self.grid_size)

    def _to_index(self, row, col):
        return row * self.grid_size + col

    def _move(self, pos, action):
        row, col = self._to_coord(pos)
        if action == 0 and row > 0: row -= 1        # up
        elif action == 1 and row < self.grid_size - 1: row += 1  # down
        elif action == 2 and col > 0: col -= 1      # left
        elif action == 3 and col < self.grid_size - 1: col += 1  # right
        return self._to_index(row, col)

    def _apply_action(self, action):
        """Apply a primitive action and update the environment."""
        if action in [0, 1, 2, 3]:
            new_pos = self._move(self.agent_pos, action)
            if new_pos in self.holes:
                return self._get_obs(), 0.0, True, False, {}
            if self.holding:
                self.agent_pos = new_pos
                self.box_pos = new_pos
            else:
                self.agent_pos = new_pos
        elif action == 4:  # grab
            if self.agent_pos == self.box_pos:
                self.holding = True
        elif action == 5:  # release
            self.holding = False

        # Check for holes
        if self.agent_pos in self.holes or self.box_pos in self.holes:
            return self._get_obs(), 0.0, True, False, {}

        # Check for goal
        success = (
            self.agent_pos == self.goal_agent and
            self.box_pos == self.goal_box and
            not self.holding
        )
        
        if self.holding and self.box_pos != self.goal_box:
            reward = 0.1
        else:
            reward = 1.0 if success else 0.0
        return self._get_obs(), reward, success, False, {}

    def step(self, action):
        """Apply either a primitive or macro action."""
        if action < self.n_primitive_actions:
            return self._apply_action(action)

        # Macro actions
        macro_index = action - self.n_primitive_actions
        macro_actions = {
            0: [4, 3],      # grab, right
            1: [3, 5],      # right, release
            2: [0, 5],      # up, release
            3: [4, 1],      # grab, down
            4: [0, 0],      # up, up
            5: [3, 3],      # right, right
            6: [5, 2, 2],    # release, left, left
            7: [3, 4],  # right, grab
        }
        sequence = macro_actions.get(macro_index, [])
        for a in sequence:
            obs, reward, done, truncated, info = self._apply_action(a)
            if done:
                break
        return obs, reward, done, truncated, {}


In [32]:

# --- Main Execution ---
def make_env():
    return ManipulationFrozenLakeWithMacros()

# Run MCTS
mcts = MCTS(make_env=make_env, num_iterations=5000, num_simulations=1000, exploration=1.41, verbose=True)
mcts.run()

# NOTE: A solution is not found because the simulation does not yell rewards
# NOTE: We added small rewards to grabbing the box, but it results in a loop from the initial to the box grabbing state

# Visualise best path
env = make_env()
env.reset()
print("Initial state:")

trajectory = []
node = mcts.root
trajectory.append((node.state, node.action, node.value))
while not node.is_leaf():
    prev_node = node
    node = prev_node.best_child()
    trajectory.append((node.state, node.action, node.value))
    print(f"Best action from state {prev_node.state} to state {node.state} with value {node.value}")
    env.step(node.action)

Expanded node (0, 1, 0) with children: [(4, 1, 0), (1, 1, 0), (1, 1, 0), (1, 1, 0), (4, 1, 0), (2, 1, 0), (1, 1, 1)]
Simulating from (4, 1, 0) → (0, 1, 0) with action 8 reward 0.0
Simulation state (0, 1, 0), reward 0.0, terminated False, truncated False
Simulating from (4, 1, 0) → (0, 1, 0) with action 8 reward 0.0
Simulation state (0, 1, 0), reward 0.0, terminated False, truncated False
Simulating from (4, 1, 0) → (0, 1, 0) with action 12 reward 0.0
Simulation state (0, 1, 0), reward 0.0, terminated False, truncated False
Simulating from (4, 1, 0) → (0, 1, 0) with action 0 reward 0.0
Simulation state (0, 1, 0), reward 0.0, terminated False, truncated False
Simulating from (4, 1, 0) → (4, 1, 0) with action 1 reward 0.0
Simulation state (4, 1, 0), reward 0.0, terminated False, truncated False
Simulating from (4, 1, 0) → (4, 1, 0) with action 13 reward 0.0
Simulation state (4, 1, 0), reward 0.0, terminated True, truncated False
Backprop node (4, 1, 0), visits=1, value=0.0
Backprop node (

In [5]:
env = ManipulationEnv(goal_agent=2, goal_box=2)
obs, _ = env.reset()
done = False

while not done:
    action = env.action_space.sample()
    prev_obs = obs
    obs, reward, done, truncated, _ = env.step(action)
    
    print(f"Moved from {prev_obs} to {obs} with action {action} and reward {reward}")


Moved from (0, 1, 0) to (0, 1, 0) with action 5 and reward 0.0
Moved from (0, 1, 0) to (0, 1, 0) with action 5 and reward 0.0
Moved from (0, 1, 0) to (0, 1, 0) with action 0 and reward 0.0
Moved from (0, 1, 0) to (0, 1, 0) with action 4 and reward 0.0
Moved from (0, 1, 0) to (0, 1, 0) with action 0 and reward 0.0
Moved from (0, 1, 0) to (4, 1, 0) with action 1 and reward 0.0
Moved from (4, 1, 0) to (4, 1, 0) with action 2 and reward 0.0
Moved from (4, 1, 0) to (4, 1, 0) with action 2 and reward 0.0
Moved from (4, 1, 0) to (4, 1, 0) with action 5 and reward 0.0
Moved from (4, 1, 0) to (4, 1, 0) with action 4 and reward 0.0
Moved from (4, 1, 0) to (4, 1, 0) with action 2 and reward 0.0
Moved from (4, 1, 0) to (4, 1, 0) with action 5 and reward 0.0
Moved from (4, 1, 0) to (4, 1, 0) with action 3 and reward 0.0
