# Multi-Objective Grid Environment Experiments

This notebook demonstrates how to create a custom environment using the `gym-simplegrid` package with a 1000x1000 grid and 10 objectives. The agent's task is to visit all objectives. The environment records the seed, the sequence of actions, the objective locations (the grid itself has no obstacles), and the agent's start and end positions. We then run the experiments for each method (RE:BT Espresso, BT-Factor, and ours).

### Installation

In [None]:
%pip install gym-simplegrid

### Environment Definition

In [None]:

import gymnasium as gym
import numpy as np
import json
import time
import os
import csv
import matplotlib.pyplot as plt

from gym_simplegrid.envs import SimpleGridEnv

class MultiObjectiveGridEnv(SimpleGridEnv):
    """Obstacle-free grid in which the agent must visit every objective.

    On each ``reset`` the agent and ``num_objectives`` distinct objectives are
    placed uniformly at random. An objective is removed (reward 1) as soon as
    the agent steps onto it, and the episode terminates once none remain.
    Every action taken is recorded so the episode can be exported with
    ``get_trajectory_data``.

    Args:
        size: Side length of the square grid.
        num_objectives: Number of objectives placed on each reset.
        render_mode: ``"human"`` enables the matplotlib view in ``render``.
    """

    def __init__(self, size=1000, num_objectives=10, render_mode=None):
        self.size = size
        self.num_objectives = num_objectives
        self.initial_objectives = []  # objective positions as placed by reset()
        self.objectives = []  # objectives not yet visited
        self.trajectory = []  # actions taken since the last reset()
        self.seed_value = None
        self.initial_agent_pos = None

        # Create an empty grid (all cells are walkable)
        obstacle_map = ["0" * size for _ in range(size)]

        # Initialize the parent class
        super().__init__(obstacle_map=obstacle_map, render_mode=render_mode)

        self.action_space = gym.spaces.Discrete(4)  # UP, DOWN, LEFT, RIGHT

        self.observation_space = gym.spaces.Dict({
            "agent": gym.spaces.Box(0, size - 1, shape=(2,), dtype=int),
            "objectives": gym.spaces.Box(0, size - 1, shape=(num_objectives, 2), dtype=int)
        })

    def reset(self, seed=None, options=None):
        """Start a new episode with a random agent position and objectives.

        Positions are drawn from a private ``np.random.RandomState(seed)``
        instead of reseeding numpy's global RNG, so other code using
        ``np.random`` is unaffected. numpy's legacy global functions are backed
        by a ``RandomState``, so a given seed yields exactly the same layout as
        ``np.random.seed(seed)`` followed by ``np.random.randint`` did, and
        previously recorded trajectories still match.

        Args:
            seed: Seed for the layout; ``None`` draws fresh entropy.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` from ``_get_obs`` and ``_get_info``.

        Raises:
            ValueError: If more objectives are requested than free cells exist
                (the placement loop below could never finish).
        """
        if self.num_objectives > self.size * self.size - 1:
            raise ValueError(
                f"Cannot place {self.num_objectives} objectives on a "
                f"{self.size}x{self.size} grid"
            )

        self.seed_value = seed
        rng = np.random.RandomState(seed)

        self.trajectory = []

        # Place agent at a random position
        self.agent_xy = (rng.randint(0, self.size), rng.randint(0, self.size))
        self.initial_agent_pos = self.agent_xy

        # Set the required attributes for rendering
        self.start_xy = self.agent_xy

        # Rejection-sample distinct objectives that do not overlap the agent
        self.objectives = []
        while len(self.objectives) < self.num_objectives:
            obj_pos = (rng.randint(0, self.size), rng.randint(0, self.size))
            if obj_pos != self.agent_xy and obj_pos not in self.objectives:
                self.objectives.append(obj_pos)

        # Set goal_xy to the first objective
        self.goal_xy = self.objectives[0] if self.objectives else (0, 0)

        # Initialize iteration counter
        self.n_iter = 0

        # Store initial objectives
        self.initial_objectives = self.objectives.copy()

        return self._get_obs(), self._get_info()

    def _get_obs(self):
        """Return the agent position and the remaining objectives as int arrays."""
        return {
            "agent": np.array(self.agent_xy),
            # reshape keeps a (k, 2) layout even when k == 0; a bare
            # np.array([]) would have shape (0,) instead.
            "objectives": np.array(self.objectives, dtype=int).reshape(-1, 2),
        }

    def _get_info(self):
        """Return auxiliary info: the number of objectives not yet visited."""
        return {
            "remaining": len(self.objectives)
        }

    def step(self, action):
        """Move the agent one cell and collect an objective if it lands on one.

        Args:
            action: 0=UP (+y), 1=DOWN (-y), 2=LEFT (-x), 3=RIGHT (+x). Moves
                that would leave the grid are ignored (the agent stays put),
                but the action is still appended to the trajectory.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``reward``
            is 1 when an objective is collected on this step and 0 otherwise.
            ``terminated`` becomes True once no objectives remain.
            ``info['objective_reached']`` mirrors the reward.
        """
        self.trajectory.append(action)

        direction_map = {
            0: (0, 1),   # UP
            1: (0, -1),  # DOWN
            2: (-1, 0),  # LEFT
            3: (1, 0)    # RIGHT
        }
        dx, dy = direction_map[action]

        new_x = self.agent_xy[0] + dx
        new_y = self.agent_xy[1] + dy
        if 0 <= new_x < self.size and 0 <= new_y < self.size:
            self.agent_xy = (new_x, new_y)

        # Objectives are distinct, so at most one can be collected per step.
        objective_reached = self.agent_xy in self.objectives
        if objective_reached:
            self.objectives.remove(self.agent_xy)
        reward = 1 if objective_reached else 0

        terminated = len(self.objectives) == 0
        truncated = False

        # Update rendering attributes
        self.n_iter += 1
        self.goal_xy = self.objectives[0] if self.objectives else None

        observation = self._get_obs()
        info = self._get_info()
        info['objective_reached'] = objective_reached

        return observation, reward, terminated, truncated, info

    def get_trajectory_data(self):
        """Summarise the current episode for logging.

        Goals are labelled ``goal_01``..``goal_NN`` in placement order. The
        presence flags are 1 for every goal at the start; "end" flags are 1
        while a goal is still unvisited and 0 once visited. The action list is
        returned as a copy, so later steps cannot mutate a record that has
        already been handed out.
        """
        labeled_objectives = {f"goal_{i+1:02d}": pos for i, pos in enumerate(self.initial_objectives)}
        remaining = set(self.objectives)
        presence_start = {label: 1 for label in labeled_objectives}
        presence_end = {label: int(pos in remaining) for label, pos in labeled_objectives.items()}
        return {
            "seed": self.seed_value,
            "agent_start": self.initial_agent_pos,
            "agent_end": self.agent_xy,
            "objectives_start": presence_start,
            "objectives_end": presence_end,
            "labeled_objectives": labeled_objectives,
            "trajectory": list(self.trajectory),
        }

    def render(self):
        """Draw the agent and remaining objectives in a window centred on the agent.

        Only active for ``render_mode == "human"``; returns the matplotlib
        figure in that case and None otherwise.
        """
        if self.render_mode == "human":
            if not hasattr(self, 'fig') or self.fig is None:
                self.fig, self.ax = plt.subplots(figsize=(8, 8))
                plt.ion()

            self.ax.clear()

            # Plot agent
            self.ax.plot(self.agent_xy[0], self.agent_xy[1], 'ro', markersize=10, label='Agent')

            # Plot objectives
            for obj in self.objectives:
                self.ax.plot(obj[0], obj[1], 'go', markersize=8)

            # Set plot limits
            display_size = min(100, self.size)  # Show a smaller window for large grids
            agent_x, agent_y = self.agent_xy
            x_min = max(0, agent_x - display_size//2)
            x_max = min(self.size, agent_x + display_size//2)
            y_min = max(0, agent_y - display_size//2)
            y_max = min(self.size, agent_y + display_size//2)

            self.ax.set_xlim(x_min, x_max)
            self.ax.set_ylim(y_min, y_max)

            # Add labels and title
            self.ax.set_xlabel('X')
            self.ax.set_ylabel('Y')
            self.ax.set_title(f'Step: {self.n_iter}, Objectives remaining: {len(self.objectives)}')

            # Add grid
            self.ax.grid(True)

            # Add legend
            self.ax.legend()

            plt.draw()
            plt.pause(0.01)

            return self.fig

In [None]:
class TrajectoryRecorder:
    """Write one CSV row per objective reached, describing that trajectory segment.

    Each row contains:

    * the seed;
    * an ``action_label`` of the form ``actions{seed}a{k}``, where k is the
      number of objectives visited so far;
    * the agent position before and after the segment;
    * the JSON list of actions in the segment;
    * per-goal presence flags (1 = still present, 0 = visited) before
      (``goal_XX_start``) and after (``goal_XX_end``) the segment.

    The CSV file is truncated and a fresh header written on construction.

    Args:
        save_dir: Directory for ``trajectories.csv`` (created if missing).
        num_objectives: Number of ``goal_XX`` column pairs in the header.
    """

    def __init__(self, save_dir="trajectories", num_objectives=10):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        self.successful_trajectories = []
        self.csv_file = os.path.join(save_dir, "trajectories.csv")
        self.last_trajectory_length = 0
        self.num_objectives = num_objectives
        self.last_objectives = None  # goal presence after the last recorded segment
        self.last_agent_pos = None
        self.last_seed = None
        self.labeled_objectives = {}  # label -> position, taken from the first record

        goal_labels = [f"goal_{i+1:02d}" for i in range(num_objectives)]
        # Stored so every row is written in exactly the header's column order.
        self.fieldnames = (
            ['seed', 'action_label', 'agent_pos_start', 'trajectory', 'agent_pos_end']
            + [f"{label}_start" for label in goal_labels]
            + [f"{label}_end" for label in goal_labels]
        )

        # Create CSV file with header
        with open(self.csv_file, mode='w', newline='') as csvfile:
            csv.DictWriter(csvfile, fieldnames=self.fieldnames).writeheader()

    def record_goal_reached(self, env):
        """Append a row for the segment that ended at the objective just reached.

        The first call after ``reset_for_new_episode`` covers the whole
        trajectory so far, starting from the episode's initial state. Later
        calls cover only the actions added since the previous call.

        Args:
            env: Object exposing ``get_trajectory_data()`` returning the keys
                read below (as ``MultiObjectiveGridEnv`` does).

        Raises:
            ValueError: If the env reports goal labels that are not in the
                header (e.g. more objectives than ``num_objectives``).
        """
        trajectory_data = env.get_trajectory_data()
        seed = trajectory_data["seed"]
        agent_end = trajectory_data["agent_end"]
        trajectory = trajectory_data["trajectory"]
        presence_end = trajectory_data["objectives_end"]

        # Store labeled objectives mapping if first time
        if not self.labeled_objectives and "labeled_objectives" in trajectory_data:
            self.labeled_objectives = trajectory_data["labeled_objectives"]

        if self.last_agent_pos is None:
            # First goal of the episode: start from the initial state
            agent_start = trajectory_data["agent_start"]
            presence_start = trajectory_data["objectives_start"]
            new_trajectory = trajectory
        else:
            # Subsequent goals: start where the previous segment ended
            agent_start = self.last_agent_pos
            presence_start = self.last_objectives
            new_trajectory = trajectory[self.last_trajectory_length:]

        # Objectives reached so far = goals flagged 0 (visited) after this segment
        num_objectives_reached = sum(1 for v in presence_end.values() if v == 0)

        row = {
            'seed': seed,
            'action_label': f"actions{seed}a{num_objectives_reached}",
            'agent_pos_start': f"{agent_start[0]}_{agent_start[1]}",
            'trajectory': json.dumps(new_trajectory),
            'agent_pos_end': f"{agent_end[0]}_{agent_end[1]}",
        }
        for label in sorted(presence_start.keys()):
            row[f'{label}_start'] = presence_start[label]
        for label in sorted(presence_end.keys()):
            row[f'{label}_end'] = presence_end[label]

        with open(self.csv_file, mode='a', newline='') as csvfile:
            # Use the header's fieldnames rather than row.keys(). Goals absent
            # from ``row`` are then left empty instead of shifting later values
            # into the wrong columns.
            writer = csv.DictWriter(csvfile, fieldnames=self.fieldnames)
            writer.writerow(row)

        # Update state for next goal
        self.last_objectives = presence_end
        self.last_agent_pos = agent_end
        self.last_seed = seed
        self.last_trajectory_length = len(trajectory)

    def reset_for_new_episode(self):
        """Reset the recorder state for a new episode"""
        self.last_objectives = None
        self.last_agent_pos = None
        self.last_seed = None
        self.last_trajectory_length = 0

    def record_if_successful(self, env, success=True):
        """Record the current segment when ``success`` is truthy.

        Returns the CSV path if a row was written, else None.
        """
        if success:
            self.record_goal_reached(env)
            return self.csv_file
        return None


### Running a Directed Agent


In [None]:
def run_agent(episodes=10, max_steps=10000, size=1000, num_objectives=10):
    """Run a greedy nearest-objective agent and log every objective reached.

    In each episode (seeded with the episode index) the agent moves one cell
    per step toward the objective closest in Manhattan distance. It closes the
    x offset first, then y. Every time an objective is reached, the new
    trajectory segment is appended to the recorder's CSV.

    Args:
        episodes: Number of episodes; episode ``i`` uses seed ``i``.
        max_steps: Step budget per episode.
        size: Grid side length passed to ``MultiObjectiveGridEnv``.
        num_objectives: Objectives per episode; also sizes the CSV header.

    Returns:
        One ``env.get_trajectory_data()`` dict per episode that visited all
        objectives within ``max_steps``.
    """
    env = MultiObjectiveGridEnv(size=size, num_objectives=num_objectives)
    recorder = TrajectoryRecorder(num_objectives=num_objectives)

    for episode in range(episodes):
        obs, info = env.reset(seed=episode)
        recorder.reset_for_new_episode()

        for step in range(max_steps):
            agent_pos = obs["agent"]
            objectives = obs["objectives"]

            # If no objectives left, we're done
            if len(objectives) == 0:
                break

            # Find closest objective
            distances = [abs(agent_pos[0] - obj[0]) + abs(agent_pos[1] - obj[1]) for obj in objectives]
            closest_obj = objectives[np.argmin(distances)]

            # Decide action to move closer to closest objective
            if agent_pos[0] < closest_obj[0]:
                action = 3  # RIGHT
            elif agent_pos[0] > closest_obj[0]:
                action = 2  # LEFT
            elif agent_pos[1] < closest_obj[1]:
                action = 0  # UP
            else:
                action = 1  # DOWN

            obs, reward, terminated, truncated, info = env.step(action)

            if info.get('objective_reached', False):
                print(f"Episode {episode}, Step {step}: Objective reached!")
                recorder.record_goal_reached(env)

            if terminated:
                print(f"Episode {episode} completed in {step+1} steps! Visited all objectives.")
                # Without this, successful_trajectories stayed empty and the
                # caller always reported 0 successful runs.
                recorder.successful_trajectories.append(env.get_trajectory_data())
                break

            if step == max_steps - 1:
                print(f"Episode {episode} failed to complete in {max_steps} steps.")

    env.close()
    return recorder.successful_trajectories

In [None]:
# Generate the trajectory CSV with the greedy agent over seeds 0..99.
successful_runs = run_agent(episodes=100)
n_successful = len(successful_runs)
print(f"Total successful runs: {n_successful}")

### Load and execute BTs

In [None]:
%pip install py-trees-meet-groot
%pip install py-trees-parser

In [None]:
import time
import py_trees

class TickStatsVisitor(py_trees.visitors.VisitorBase):
    """Count the behaviours visited during a tick and time the tick.

    py_trees calls ``initialise()`` before a tick, ``run(behaviour)`` for each
    visited behaviour, and ``finalise()`` afterwards. The original class only
    defined ``visit()``, which py_trees never calls, so ``nodes_visited`` stayed
    at 0. ``run()`` now forwards to ``visit()``, which is kept for direct callers.
    """

    def __init__(self):
        # Make sure to set full=True to visit all nodes
        super().__init__(full=True)
        self.nodes_visited = 0
        self.start_time = None
        self.duration = 0

    def initialise(self):
        """Reset the counter and start the timer."""
        self.nodes_visited = 0
        self.start_time = time.time()
        print("Visitor initialized")

    def finalise(self):
        """Stop the timer; ``duration`` holds the elapsed seconds."""
        # Guard: without a preceding initialise() the subtraction would raise TypeError.
        if self.start_time is not None:
            self.duration = time.time() - self.start_time
        print("Visitor finalized")

    def run(self, behaviour):
        """py_trees visitor hook, called once per visited behaviour."""
        self.visit(behaviour)

    def visit(self, node):
        """Count ``node`` and print its name."""
        self.nodes_visited += 1
        print(f"Visiting node: {node.name}")


class NodeCounterVisitor(py_trees.visitors.VisitorBase):
    """Count every behaviour handed to the visitor during a tick.

    py_trees delivers each visited behaviour through ``run()``. ``visit()``
    alone was never called, so ``count`` stayed 0. ``run()`` now forwards to
    ``visit()``, which is kept for direct callers.
    """

    def __init__(self):
        super().__init__(full=True)
        self.count = 0

    def initialise(self):
        """Reset the counter before a tick."""
        self.count = 0

    def finalise(self):
        """Nothing to do after a tick; ``count`` already holds the result."""

    def run(self, behaviour):
        """py_trees visitor hook, called once per visited behaviour."""
        self.visit(behaviour)

    def visit(self, node):
        """Count ``node``."""
        self.count += 1


### Execute BT
Remember to change the path to the behavior tree file and the `SPPA_BT` variable if needed.

In [None]:
import py_trees
import pandas as pd
import json
import time
import numpy as np
import os
import xml.etree.ElementTree as ET
import traceback

# Configure if using SPPA BT or other method BT (BT-Factor or RE:BT Espresso)
SPPA_BT = False

# STEP_NEED: maximum number of ticks load_and_execute_bt gives the tree.
# goal_sucess: value a ``goal_XX_start`` blackboard entry must hold for the
# matching CheckCondition to return SUCCESS.
STEP_NEED, goal_sucess = (1, 0) if SPPA_BT else (10, 1)

# Define your custom condition as a behavior
class CheckCondition(py_trees.behaviour.Behaviour):
    """Condition leaf whose node name encodes what to check on the blackboard.

    Supported names:
      * ``goal_XX_start`` -- SUCCESS when the blackboard ``goal_states`` entry
        for this name equals the module-level ``goal_sucess`` value.
      * ``seed_N`` -- SUCCESS when the blackboard ``seed`` equals ``N``.
    Any other name, and any error while evaluating, yields FAILURE.
    """

    def __init__(self, name):
        super().__init__(name)
        self.condition_str = name  # the node name doubles as the condition
        self.blackboard = py_trees.blackboard.Blackboard()

    def update(self):
        success = py_trees.common.Status.SUCCESS
        failure = py_trees.common.Status.FAILURE
        cond = self.condition_str
        bb = self.blackboard
        try:
            if cond.startswith("goal_") and cond.endswith("_start"):
                states = bb.get("goal_states") if bb.exists("goal_states") else {}
                matched = cond in states and states[cond] == goal_sucess
                return success if matched else failure

            if cond.startswith("seed_"):
                expected = int(cond[5:])  # number after "seed_"
                current = bb.get("seed") if bb.exists("seed") else None
                return success if current == expected else failure

            self.logger.warning(f"Unknown condition: {cond}")
            return failure
        except Exception as e:
            self.logger.error(f"Error in CheckCondition.update: {e}")
            traceback.print_exc()
            return failure


# Define your custom action to execute trajectories from CSV
class ExecuteTrajectoryAction(py_trees.behaviour.Behaviour):
    """Action leaf that replays a recorded trajectory in the environment.

    The node name is used as the ``action_label`` to find a row in the
    trajectories CSV. That row's ``trajectory`` column (a JSON list of
    discrete actions) is stepped through the blackboard's ``env`` in a
    single tick.
    """

    def __init__(self, name, csv_file=None):
        super().__init__(name)
        self.csv_file = csv_file
        self.action_label = name  # Use the node name as the action label
        self.trajectory = None  # action list, cached by _load_trajectory()
        self.env = None
        self.blackboard = py_trees.blackboard.Blackboard()

    def _load_trajectory(self):
        """Read this node's action list from ``self.csv_file`` into ``self.trajectory``.

        Returns True on success, or False (after logging) when no row has this
        node's ``action_label``. Read/parse errors propagate to the caller.
        """
        df = pd.read_csv(self.csv_file)
        matching_rows = df[df['action_label'] == self.action_label]
        if matching_rows.empty:
            self.logger.error(f"No matching action label found: {self.action_label}")
            return False
        self.trajectory = json.loads(matching_rows.iloc[0]['trajectory'])
        return True

    def setup(self):
        """Resolve the CSV path (falling back to the blackboard) and cache the trajectory.

        Returns True on success and False (after logging) on any failure.
        """
        try:
            if self.csv_file is None:
                self.csv_file = self.blackboard.get("csv_file")
                if self.csv_file is None:
                    self.logger.error("CSV file not found in blackboard")
                    return False
            return self._load_trajectory()
        except Exception as e:
            self.logger.error(f"Error in ExecuteTrajectoryAction.setup: {e}")
            traceback.print_exc()
            return False

    def initialise(self):
        """Fetch the environment from the blackboard before the node runs."""
        try:
            self.env = self.blackboard.get("env")
            if self.env is None:
                self.logger.error("Environment not found in blackboard")
        except Exception as e:
            self.logger.error(f"Error in ExecuteTrajectoryAction.initialise: {e}")
            traceback.print_exc()

    def _mark_reached_goals(self):
        """Set ``goal_XX_start`` to 0 on the blackboard for every initial
        objective the environment no longer lists as remaining."""
        if not self.blackboard.exists("goal_states"):
            return
        goal_states = self.blackboard.get("goal_states")
        if hasattr(self.env, 'initial_objectives') and hasattr(self.env, 'objectives'):
            for i, obj in enumerate(self.env.initial_objectives):
                if obj not in self.env.objectives:
                    goal_id = f"goal_{i+1:02d}_start"
                    if goal_id in goal_states:
                        goal_states[goal_id] = 0  # Mark as visited
        self.blackboard.set("goal_states", goal_states)

    def update(self):
        """Execute the whole trajectory in one tick.

        Returns SUCCESS after every action has been stepped, or earlier if the
        environment reports termination. Returns FAILURE if the trajectory or
        the environment is unavailable, or if an error is raised.
        """
        try:
            print(f"Action node ticked: {self.name}")

            # setup() normally caches the trajectory already, so only read the
            # CSV here if it did not (setup skipped or failed). Nothing in this
            # notebook rewrites the CSV while a tree runs; re-reading it on
            # every tick only added pandas I/O inside the timed tick.
            if self.trajectory is None and not self._load_trajectory():
                return py_trees.common.Status.FAILURE

            if self.env is None:
                return py_trees.common.Status.FAILURE

            if self.trajectory is None:
                print(f"Warning: trajectory is None in action {self.action_label}")
                return py_trees.common.Status.FAILURE

            for step, action in enumerate(self.trajectory):
                obs, reward, terminated, truncated, info = self.env.step(int(action))

                # Keep the blackboard in sync with the environment
                self.blackboard.set("agent_pos", obs["agent"])
                self.blackboard.set("objectives", obs["objectives"].tolist() if hasattr(obs["objectives"], "tolist") else obs["objectives"])

                if info.get('objective_reached', False):
                    self._mark_reached_goals()

                if terminated:
                    self.logger.info(f"All objectives visited in {step+1} steps.")
                    return py_trees.common.Status.SUCCESS

            return py_trees.common.Status.SUCCESS  # Return SUCCESS after completing the trajectory

        except Exception as e:
            self.logger.error(f"Error in ExecuteTrajectoryAction.update: {e}")
            traceback.print_exc()
            return py_trees.common.Status.FAILURE


def create_behavior_instance(node_id, behaviors_dict):
    """Return a fresh behaviour for ``node_id`` so each XML occurrence gets its own node.

    ``goal_XX_start`` and ``seed_N`` IDs become new CheckCondition nodes. An
    ID whose ``behaviors_dict`` entry is an ExecuteTrajectoryAction is cloned
    with the same CSV file. For anything else, a message is printed and None
    is returned; errors are also printed and yield None.
    """
    try:
        is_goal = node_id.startswith("goal_") and node_id.endswith("_start")
        if is_goal or node_id.startswith("seed_"):
            return CheckCondition(name=node_id)

        template = behaviors_dict[node_id] if node_id in behaviors_dict else None
        if isinstance(template, ExecuteTrajectoryAction):
            return ExecuteTrajectoryAction(name=node_id, csv_file=template.csv_file)
    except Exception as e:
        print(f"Error creating behavior instance: {e}")
        traceback.print_exc()
        return None

    print(f"Unknown behavior type for ID: {node_id}")
    return None

def build_subtree(xml_node, parent_node, behaviors_dict):
    """
    Recursively build a subtree from an XML node.

    Supported tags:

    * ``Sequence``;
    * ``Fallback``, mapped to a py_trees Selector;
    * ``Parallel``, always built with the SuccessOnOne policy;
    * ``Inverter``, which wraps its first resulting child;
    * ``Action`` and ``Condition`` leaves, created via
      ``create_behavior_instance``.

    Unsupported tags are reported and skipped. Any error is printed and stops
    processing the remaining children of ``xml_node``.

    Args:
        xml_node: XML element (or any iterable of elements) whose children to parse
        parent_node: The parent py_trees node that receives the built children
        behaviors_dict: Dictionary mapping node IDs to behavior objects
    """
    try:
        for child in xml_node:
            tag = child.tag

            if tag == "Sequence":
                node = py_trees.composites.Sequence(name=child.get("name", "Sequence"))
                parent_node.add_child(node)
                build_subtree(child, node, behaviors_dict)

            elif tag == "Fallback":
                node = py_trees.composites.Selector(name=child.get("name", "Fallback"))
                parent_node.add_child(node)
                build_subtree(child, node, behaviors_dict)

            elif tag == "Parallel":
                # Thresholds other than 1 are not mapped to a py_trees policy
                # here; warn instead of silently ignoring the attribute.
                success_threshold = int(child.get("success_threshold", "1"))
                if success_threshold != 1:
                    print(f"Warning: Parallel success_threshold={success_threshold} "
                          f"is not supported; using SuccessOnOne")
                node = py_trees.composites.Parallel(
                    name=child.get("name", "Parallel"),
                    policy=py_trees.common.ParallelPolicy.SuccessOnOne()
                )
                parent_node.add_child(node)
                build_subtree(child, node, behaviors_dict)

            elif tag == "Inverter":
                # Build the decorated child under a throwaway composite first
                temp_node = py_trees.composites.Sequence(name="TempNode")
                for subchild in child:
                    if subchild.tag in ("Action", "Condition"):
                        behavior = create_behavior_instance(subchild.get("ID"), behaviors_dict)
                        if behavior:
                            temp_node.add_child(behavior)
                    else:
                        # Handle nested composites
                        build_subtree([subchild], temp_node, behaviors_dict)

                # An inverter decorates exactly one child: use the first built
                if temp_node.children:
                    inverter = py_trees.decorators.Inverter(
                        name=child.get("name", "Inverter"),
                        child=temp_node.children[0]
                    )
                    parent_node.add_child(inverter)

            elif tag in ("Action", "Condition"):
                node_id = child.get("ID")
                behavior = create_behavior_instance(node_id, behaviors_dict)
                if behavior:
                    parent_node.add_child(behavior)
                else:
                    print(f"Failed to create behavior for ID: {node_id}")

            else:
                # Previously dropped silently, which changes the tree's meaning
                print(f"Unsupported XML node type skipped: {tag}")
    except Exception as e:
        print(f"Error in build_subtree: {e}")
        traceback.print_exc()

def tree_stats(root):
    """Return ``(max_width, depth, node_count)`` for a behaviour tree.

    The tree is walked one level at a time. ``max_width`` is the largest
    number of nodes on any level, ``depth`` is the number of levels, and
    ``node_count`` is the total number of nodes. Nodes without a ``children``
    attribute are leaves. ``None`` yields ``(0, 0, 0)``.
    """
    if root is None:
        return 0, 0, 0

    widest = 0
    levels = 0
    total = 0
    frontier = [root]
    while frontier:
        levels += 1
        total += len(frontier)
        widest = max(widest, len(frontier))
        frontier = [kid for node in frontier for kid in getattr(node, "children", ())]

    return widest, levels, total

def parse_bt_xml(xml_file, behaviors_dict):
    """
    Parse a behaviour-tree XML file into a py_trees tree.

    The first ``BehaviorTree`` element anywhere in the document is used. Its
    first ``Parallel``, ``Sequence`` or ``Fallback`` child becomes a root node
    named "Root", and the remaining structure is built by ``build_subtree``.

    Args:
        xml_file: Path to the XML file
        behaviors_dict: Dictionary mapping node IDs to behavior objects

    Returns:
        The root node of the parsed behavior tree, or None if the file has no
        usable root or any error occurs (the error is printed).
    """
    try:
        print(f"Parsing XML file: {xml_file}")
        xml_root = ET.parse(xml_file).getroot()
        print("XML parsed successfully")

        bt_elem = xml_root.find(".//BehaviorTree")
        if bt_elem is None:
            print("BehaviorTree element not found in XML")
            return None
        print("Found BehaviorTree element")

        # Scan the children (logging each one) until a composite is found
        root_elem = None
        for candidate in bt_elem:
            print(f"Found child of BehaviorTree: {candidate.tag}")
            if candidate.tag in ("Parallel", "Sequence", "Fallback"):
                root_elem = candidate
                break

        if root_elem is None:
            print("No valid root node (Parallel, Sequence, Fallback) found in BehaviorTree")
            return None

        root_tag = root_elem.tag
        print(f"Root element is: {root_tag}")

        if root_tag == "Parallel":
            success_threshold = int(root_elem.get("success_threshold", "1"))
            bt_root = py_trees.composites.Parallel(
                name="Root",
                policy=py_trees.common.ParallelPolicy.SuccessOnOne()
            )
            print(f"Created Parallel root with success_threshold={success_threshold}")
        elif root_tag == "Sequence":
            bt_root = py_trees.composites.Sequence(name="Root")
            print("Created Sequence root")
        elif root_tag == "Fallback":
            bt_root = py_trees.composites.Selector(name="Root")
            print("Created Fallback root")
        else:
            print(f"Unsupported root node type: {root_tag}")
            return None

        print("Building subtree...")
        build_subtree(root_elem, bt_root, behaviors_dict)
        print("Subtree built successfully")
        return bt_root

    except Exception as e:
        print(f"Error parsing XML: {e}")
        traceback.print_exc()
        return None

def _build_behaviors_dict(csv_file, num_objectives=10, num_seeds=100):
    """Return the ID -> behaviour template map consulted while parsing the XML.

    The map contains:

    * a CheckCondition for every ``goal_XX_start`` (XX = 01..num_objectives);
    * a CheckCondition for every ``seed_N`` (N = 0..num_seeds-1);
    * one ExecuteTrajectoryAction per distinct ``action_label`` in ``csv_file``.

    CSV read errors are printed, and the map then holds only the conditions.
    """
    behaviors_dict = {}
    for i in range(num_objectives):
        goal_id = f"goal_{i+1:02d}_start"
        behaviors_dict[goal_id] = CheckCondition(name=goal_id)
    for i in range(num_seeds):
        seed_id = f"seed_{i}"
        behaviors_dict[seed_id] = CheckCondition(name=seed_id)
    try:
        df = pd.read_csv(csv_file)
        for action_label in df['action_label'].unique():
            behaviors_dict[action_label] = ExecuteTrajectoryAction(name=action_label, csv_file=csv_file)
    except Exception as e:
        print(f"Error loading actions from CSV: {e}")
        traceback.print_exc()
    return behaviors_dict


def _initial_goal_states(csv_file, seed, num_objectives=10):
    """Return the starting ``goal_XX_start`` flags for ``seed``.

    The flags are copied from the first CSV row recorded for ``seed``. Every
    goal defaults to 1 (present) when the seed has no rows or the CSV cannot
    be read; read errors are printed.
    """
    defaults = {f"goal_{i+1:02d}_start": 1 for i in range(num_objectives)}
    try:
        df = pd.read_csv(csv_file)
        seed_rows = df[df['seed'] == seed]
        if seed_rows.empty:
            return defaults
        row = seed_rows.iloc[0]
        return {goal_id: row[goal_id] for goal_id in defaults if goal_id in row}
    except Exception as e:
        print(f"Error loading goal states from CSV: {e}")
        traceback.print_exc()
        return defaults


def load_and_execute_bt(xml_file, csv_file, env, seed=42, num_objectives=10, num_seeds=100):
    """Parse a behaviour-tree XML file and tick it against ``env`` for one seed.

    The steps are:

    1. Build the condition and action templates.
    2. Parse the XML into a py_trees tree.
    3. Reset ``env`` with ``seed``.
    4. Populate the global blackboard: env, CSV path, agent position,
       objectives, seed and initial goal states.
    5. Tick the tree up to ``STEP_NEED`` times, stopping early on a
       SUCCESS/FAILURE status.

    Args:
        xml_file: Path to the behaviour-tree XML.
        csv_file: Trajectories CSV (as written by ``TrajectoryRecorder``).
        env: Environment stepped by the action nodes; closed after ticking.
        seed: Environment seed; also selects the CSV row for initial goal states.
        num_objectives: Number of ``goal_XX_start`` conditions / goal states.
        num_seeds: Number of ``seed_N`` conditions created (N = 0..num_seeds-1).

    Returns:
        ``(status, total_visits)``: the value recorded from the last tick
        (``None`` if the tree was never ticked) and the number of nodes visited
        summed over all ticks. Returns ``None`` instead if the XML cannot be
        parsed or an exception (including ``SystemExit``) is raised.
    """
    try:
        behaviors_dict = _build_behaviors_dict(csv_file, num_objectives, num_seeds)
        print(f"Created {len(behaviors_dict)} behaviors")

        print("Parsing behavior tree XML...")
        root = parse_bt_xml(xml_file, behaviors_dict)
        if root is None:
            print("Failed to parse behavior tree")
            return
        print("Behavior tree parsed successfully")

        print("Initializing environment...")
        obs, info = env.reset(seed=seed)

        # Initialize blackboard with environment state
        blackboard = py_trees.blackboard.Blackboard()
        blackboard.set("env", env)
        blackboard.set("csv_file", csv_file)
        blackboard.set("agent_pos", obs["agent"])
        blackboard.set("objectives", obs["objectives"].tolist() if hasattr(obs["objectives"], "tolist") else obs["objectives"])
        blackboard.set("seed", seed)

        goal_states = _initial_goal_states(csv_file, seed, num_objectives)
        print("Goal states initialized:")
        for goal_id, state in goal_states.items():
            print(f"  {goal_id}: {state}")
        blackboard.set("goal_states", goal_states)

        print("Initial Behavior Tree:")
        bt = py_trees.trees.BehaviourTree(root=root)

        max_width, max_depth, node_count = tree_stats(root)
        print(f"Tree statistics: width={max_width}, depth={max_depth}, nodes={node_count}")

        # Snapshot visitor records which nodes each tick visits
        snapshot_visitor = py_trees.visitors.SnapshotVisitor()
        bt.visitors.append(snapshot_visitor)
        bt.setup()

        step_count = 0
        max_steps = STEP_NEED
        total_visits = 0
        status = None  # stays None if the tree is never ticked

        while step_count < max_steps:
            print(f"\n--- Step {step_count} ---")
            snapshot_visitor.initialise()
            start_time = time.time()

            # agent_pos / objectives are deliberately not re-set from the
            # reset() observation here. ExecuteTrajectoryAction updates them
            # after every env step, and re-applying the initial obs rolled the
            # blackboard back to the episode start on every later tick.

            # NOTE(review): in py_trees 2.x BehaviourTree.tick() returns None,
            # in which case the early exit below never fires and every run uses
            # all STEP_NEED ticks (the tree's own result is bt.root.status).
            # Left unchanged because the tick budget may rely on this; confirm
            # against the installed py_trees version.
            status = bt.tick()

            duration = time.time() - start_time
            step_count += 1
            print(f"Tree status after tick: {status}")
            print(f"Tick statistics: visited {len(snapshot_visitor.visited)} nodes in {duration:.6f} seconds")
            total_visits += len(snapshot_visitor.visited)

            if status == py_trees.common.Status.SUCCESS or status == py_trees.common.Status.FAILURE:
                print(f"Behavior tree execution completed with status: {status}")
                break

        print(f"Step count: {step_count}")
        print(f"Total visits: {total_visits}")

        try:
            env.close()
        except SystemExit:
            print("Ignored SystemExit from environment close()")

        return status, total_visits

    except SystemExit:
        print("SystemExit exception caught - this might be from py_trees or another library")
        traceback.print_exc()
    except Exception as e:
        print(f"Error in load_and_execute_bt: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    try:
        # Create the environment
        env = MultiObjectiveGridEnv(size=1000, num_objectives=10)

        # Paths to the behavior tree and the recorded trajectories
        base_dir = os.getcwd()
        bt_file = os.path.join(base_dir, "bt_factor/100-enviroments/augmented/bt.xml")
        csv_file = os.path.join(base_dir, "trajectories/trajectories.csv")

        # Previously undefined, so the final summary raised NameError
        seeds = range(100)
        total_visits_count = 0
        for seed in seeds:
            print(f"\n--- Seed {seed} ---")
            result = load_and_execute_bt(
                xml_file=bt_file,
                csv_file=csv_file,
                env=env,
                seed=seed
            )
            # load_and_execute_bt returns None when parsing or execution fails.
            # Skip that seed instead of aborting the whole sweep on unpacking.
            if result is None:
                print(f"Seed {seed} failed; skipping")
                continue
            status, total_visits = result
            total_visits_count += total_visits

        print(f"Total visits after {len(seeds)} seeds: {total_visits_count}")
    except Exception as e:
        print(f"Top-level exception: {e}")
        traceback.print_exc()


### Execute BT
This implementation checks ticks in all environments.

In [None]:
import py_trees
import pandas as pd
import json
import time
import numpy as np
import os
import xml.etree.ElementTree as ET
import traceback

# Choose which behaviour tree is evaluated: True for the SPPA BT, False for a
# BT produced by another method (BT-Factor or RE:BT Espresso).
SPPA_BT = True

# STEP_NEED   -- number of ticks executed per environment.
# goal_sucess -- goal-state value for which a goal_XX_start condition succeeds.
STEP_NEED = 1 if SPPA_BT else 10
goal_sucess = 0 if SPPA_BT else 1

# Define your custom condition as a behavior
class CheckCondition(py_trees.behaviour.Behaviour):
    """Condition leaf that evaluates the check encoded in its own name.

    Recognised names:

    * ``goal_XX_start`` -- SUCCESS when the blackboard's ``goal_states``
      mapping holds this key with a value equal to the module-level
      ``goal_sucess`` setting.
    * ``seed_N`` -- SUCCESS when the blackboard's ``seed`` equals ``N``.

    Any other name logs a warning and returns FAILURE. Exceptions raised
    while evaluating are logged, their traceback is printed, and FAILURE
    is returned.
    """

    def __init__(self, name):
        super().__init__(name)
        # The node name doubles as the condition expression.
        self.condition_str = name
        self.blackboard = py_trees.blackboard.Blackboard()

    def update(self):
        try:
            cond = self.condition_str
            if cond.startswith("goal_") and cond.endswith("_start"):
                # A missing goal_states entry behaves like an empty mapping.
                states = (
                    self.blackboard.get("goal_states")
                    if self.blackboard.exists("goal_states")
                    else {}
                )
                holds = cond in states and states[cond] == goal_sucess
            elif cond.startswith("seed_"):
                wanted_seed = int(cond[5:])  # digits after the "seed_" prefix
                active_seed = (
                    self.blackboard.get("seed")
                    if self.blackboard.exists("seed")
                    else None
                )
                holds = active_seed == wanted_seed
            else:
                self.logger.warning(f"Unknown condition: {self.condition_str}")
                return py_trees.common.Status.FAILURE

            if holds:
                return py_trees.common.Status.SUCCESS
            return py_trees.common.Status.FAILURE
        except Exception as e:
            self.logger.error(f"Error in CheckCondition.update: {e}")
            traceback.print_exc()
            return py_trees.common.Status.FAILURE


# Define your custom action to execute trajectories from CSV
class ExecuteTrajectoryAction(py_trees.behaviour.Behaviour):
    """Action leaf that replays a recorded trajectory in the environment.

    The node name is the ``action_label`` used to find a row in the
    trajectories CSV; that row's ``trajectory`` column holds a JSON list of
    discrete actions. Each tick sends the list to ``env.step`` one action at
    a time (stopping early if the env reports ``terminated``) and returns
    SUCCESS, or FAILURE on any problem. It never returns RUNNING.
    """

    def __init__(self, name, csv_file=None):
        super().__init__(name)
        self.csv_file = csv_file
        self.action_label = name  # Use the node name as the action label
        self.trajectory = None
        self.env = None
        self.blackboard = py_trees.blackboard.Blackboard()

    def _lookup_trajectory(self):
        """Read the CSV and return ``(found, trajectory)`` for this label.

        ``found`` is False (and an error is logged) when no row carries this
        node's ``action_label``. A found row whose JSON decodes to ``null``
        yields ``(True, None)``. Read/parse errors propagate to the caller.
        """
        df = pd.read_csv(self.csv_file)
        matching_rows = df[df['action_label'] == self.action_label]
        if matching_rows.empty:
            self.logger.error(f"No matching action label found: {self.action_label}")
            return False, None
        return True, json.loads(matching_rows.iloc[0]['trajectory'])

    def _mark_reached_goals(self):
        """Set ``goal_XX_start`` to 0 for each initial objective no longer in
        ``env.objectives`` and write ``goal_states`` back to the blackboard."""
        if not self.blackboard.exists("goal_states"):
            return
        goal_states = self.blackboard.get("goal_states")
        if hasattr(self.env, 'initial_objectives') and hasattr(self.env, 'objectives'):
            for i, obj in enumerate(self.env.initial_objectives):
                if obj not in self.env.objectives:
                    goal_id = f"goal_{i+1:02d}_start"
                    if goal_id in goal_states:
                        goal_states[goal_id] = 0  # Mark as visited
        self.blackboard.set("goal_states", goal_states)

    def setup(self, **kwargs):
        """Pre-load this node's trajectory; return True on success, else False.

        Accepts ``**kwargs`` because py_trees forwards extra keyword arguments
        given to ``BehaviourTree.setup`` to every node's ``setup``.
        """
        try:
            # Fall back to the blackboard when no CSV path was given explicitly.
            if self.csv_file is None:
                self.csv_file = self.blackboard.get("csv_file")
                if self.csv_file is None:
                    self.logger.error("CSV file not found in blackboard")
                    return False

            found, trajectory = self._lookup_trajectory()
            if not found:
                return False
            self.trajectory = trajectory
            return True
        except Exception as e:
            self.logger.error(f"Error in ExecuteTrajectoryAction.setup: {e}")
            traceback.print_exc()
            return False

    def initialise(self):
        """Fetch the environment from the blackboard before the node runs."""
        try:
            self.env = self.blackboard.get("env")
            if self.env is None:
                self.logger.error("Environment not found in blackboard")
        except Exception as e:
            self.logger.error(f"Error in ExecuteTrajectoryAction.initialise: {e}")
            traceback.print_exc()

    def update(self):
        """Replay the trajectory, mirroring env state onto the blackboard."""
        try:
            print(f"Action node ticked: {self.name}")

            # Reload trajectory to ensure we have the latest
            found, trajectory = self._lookup_trajectory()
            if not found:
                return py_trees.common.Status.FAILURE
            self.trajectory = trajectory

            if self.env is None:
                return py_trees.common.Status.FAILURE

            if self.trajectory is None:
                print(f"Warning: trajectory is None in action {self.action_label}")
                return py_trees.common.Status.FAILURE

            # Execute the entire trajectory in one tick
            for step, action in enumerate(self.trajectory):
                obs, reward, terminated, truncated, info = self.env.step(int(action))

                # Update blackboard with new state
                objectives = obs["objectives"]
                self.blackboard.set("agent_pos", obs["agent"])
                self.blackboard.set("objectives", objectives.tolist() if hasattr(objectives, "tolist") else objectives)

                if info.get('objective_reached', False):
                    self._mark_reached_goals()

                if terminated:
                    self.logger.info(f"All objectives visited in {step+1} steps.")
                    return py_trees.common.Status.SUCCESS

            return py_trees.common.Status.SUCCESS  # Return SUCCESS after completing the trajectory

        except Exception as e:
            self.logger.error(f"Error in ExecuteTrajectoryAction.update: {e}")
            traceback.print_exc()
            return py_trees.common.Status.FAILURE


def create_behavior_instance(node_id, behaviors_dict):
    """Return a fresh behaviour for ``node_id``, or None if it cannot be built.

    IDs of the form ``goal_XX_start`` or ``seed_N`` always produce a new
    CheckCondition. Any other ID must map to an ExecuteTrajectoryAction in
    ``behaviors_dict``; a new instance sharing that template's CSV file is
    returned. Unknown IDs print a message; exceptions are printed with their
    traceback. Both cases return None.
    """
    try:
        is_goal_condition = node_id.startswith("goal_") and node_id.endswith("_start")
        if is_goal_condition or node_id.startswith("seed_"):
            return CheckCondition(name=node_id)

        if node_id in behaviors_dict:
            # Actions are cloned from their template so every tree node is distinct.
            template = behaviors_dict[node_id]
            if isinstance(template, ExecuteTrajectoryAction):
                return ExecuteTrajectoryAction(name=node_id, csv_file=template.csv_file)

        print(f"Unknown behavior type for ID: {node_id}")
        return None
    except Exception as e:
        print(f"Error creating behavior instance: {e}")
        traceback.print_exc()
        return None

def build_subtree(xml_node, parent_node, behaviors_dict):
    """
    Recursively build a subtree from an XML node.

    Children tagged ``Sequence``, ``Fallback`` (mapped to a py_trees
    ``Selector``), ``Parallel``, ``Inverter``, ``Action`` and ``Condition`` are
    converted and attached to ``parent_node``; any other tag is reported and
    skipped. Errors are printed and stop processing of the current level.

    Args:
        xml_node: The XML node whose children are parsed (any iterable of XML
            elements works)
        parent_node: The parent py_trees node
        behaviors_dict: Dictionary mapping node IDs to behavior objects
    """
    try:
        for child in xml_node:
            tag = child.tag

            if tag == "Sequence":
                node = py_trees.composites.Sequence(name=child.get("name", "Sequence"))
                parent_node.add_child(node)
                build_subtree(child, node, behaviors_dict)

            elif tag == "Fallback":
                node = py_trees.composites.Selector(name=child.get("name", "Fallback"))
                parent_node.add_child(node)
                build_subtree(child, node, behaviors_dict)

            elif tag == "Parallel":
                # NOTE: the XML ``success_threshold`` attribute is ignored; the
                # policy is always SuccessOnOne.
                node = py_trees.composites.Parallel(
                    name=child.get("name", "Parallel"),
                    policy=py_trees.common.ParallelPolicy.SuccessOnOne()
                )
                parent_node.add_child(node)
                build_subtree(child, node, behaviors_dict)

            elif tag == "Inverter":
                # Build the decorated child under a throw-away Sequence, then
                # hand the first resulting behaviour to the Inverter.
                temp_node = py_trees.composites.Sequence(name="TempNode")
                for subchild in child:
                    if subchild.tag == "Action" or subchild.tag == "Condition":
                        behavior = create_behavior_instance(subchild.get("ID"), behaviors_dict)
                        if behavior:
                            temp_node.add_child(behavior)
                    else:
                        # Handle nested composites
                        build_subtree([subchild], temp_node, behaviors_dict)

                inverter_name = child.get("name", "Inverter")
                if temp_node.children:
                    inverter = py_trees.decorators.Inverter(
                        name=inverter_name,
                        child=temp_node.children[0]
                    )
                    parent_node.add_child(inverter)
                else:
                    print(f"Inverter '{inverter_name}' has no valid child; skipped")

            elif tag == "Action" or tag == "Condition":
                node_id = child.get("ID")
                behavior = create_behavior_instance(node_id, behaviors_dict)
                if behavior:
                    parent_node.add_child(behavior)
                else:
                    print(f"Failed to create behavior for ID: {node_id}")

            else:
                print(f"Skipping unsupported node type: {tag}")
    except Exception as e:
        print(f"Error in build_subtree: {e}")
        traceback.print_exc()

def tree_stats(root):
    """Return ``(max_width, depth, node_count)`` for the tree under ``root``.

    Nodes are expanded through their ``children`` attribute when they have
    one; anything without it is a leaf. ``depth`` counts levels, so a lone
    root has depth 1. ``None`` yields ``(0, 0, 0)``.
    """
    if root is None:
        return 0, 0, 0

    widest = 0
    levels = 0
    total = 0
    frontier = [root]
    # Walk the tree one level at a time; each frontier is a full level.
    while frontier:
        levels += 1
        total += len(frontier)
        widest = max(widest, len(frontier))
        next_frontier = []
        for node in frontier:
            if hasattr(node, 'children'):
                next_frontier.extend(node.children)
        frontier = next_frontier

    return widest, levels, total

def parse_bt_xml(xml_file, behaviors_dict):
    """
    Custom parser for behavior tree XML files.

    The ``<BehaviorTree>`` element may be the document root itself or nested
    anywhere below it (e.g. a ``<root><BehaviorTree>...`` layout). Its first
    ``Parallel``/``Sequence``/``Fallback`` child becomes a py_trees root named
    "Root"; the remaining structure is built by ``build_subtree``.

    Args:
        xml_file: Path to the XML file
        behaviors_dict: Dictionary mapping node IDs to behavior objects

    Returns:
        The root node of the parsed behavior tree, or None when the file
        cannot be parsed or contains no usable BehaviorTree.
    """
    try:
        print(f"Parsing XML file: {xml_file}")
        # Parse the XML file
        tree = ET.parse(xml_file)
        xml_root = tree.getroot()

        print("XML parsed successfully")

        # ``find(".//X")`` only searches descendants, so accept a document
        # whose root element is itself the BehaviorTree.
        if xml_root.tag == "BehaviorTree":
            bt_elem = xml_root
        else:
            bt_elem = xml_root.find(".//BehaviorTree")
        if bt_elem is None:
            print("BehaviorTree element not found in XML")
            return None

        print("Found BehaviorTree element")

        # Find the first child of BehaviorTree (usually a Parallel or Sequence node)
        root_elem = None
        for child in bt_elem:
            print(f"Found child of BehaviorTree: {child.tag}")
            if child.tag in ["Parallel", "Sequence", "Fallback"]:
                root_elem = child
                break

        if root_elem is None:
            print("No valid root node (Parallel, Sequence, Fallback) found in BehaviorTree")
            return None

        print(f"Root element is: {root_elem.tag}")

        # Create the root node based on its type
        if root_elem.tag == "Parallel":
            # NOTE: success_threshold is only reported; the policy is always
            # SuccessOnOne regardless of its value.
            success_threshold = int(root_elem.get("success_threshold", "1"))
            bt_root = py_trees.composites.Parallel(
                name="Root",
                policy=py_trees.common.ParallelPolicy.SuccessOnOne()
            )
            print(f"Created Parallel root with success_threshold={success_threshold}")
        elif root_elem.tag == "Sequence":
            bt_root = py_trees.composites.Sequence(name="Root")
            print("Created Sequence root")
        else:  # "Fallback" -- the only remaining tag accepted above
            bt_root = py_trees.composites.Selector(name="Root")
            print("Created Fallback root")

        # Recursively build the tree
        print("Building subtree...")
        build_subtree(root_elem, bt_root, behaviors_dict)
        print("Subtree built successfully")

        return bt_root

    except Exception as e:
        print(f"Error parsing XML: {e}")
        traceback.print_exc()
        return None

def _initial_goal_states(csv_file, seed, num_goals=10):
    """Return the starting ``goal_XX_start`` states for ``seed``.

    Values come from the first CSV row whose ``seed`` column equals ``seed``
    (only goal columns present in the CSV are filled). When no row matches,
    or the CSV cannot be read, every goal defaults to 1.
    """
    goal_ids = [f"goal_{i+1:02d}_start" for i in range(num_goals)]
    try:
        df = pd.read_csv(csv_file)
        seed_rows = df[df['seed'] == seed]
        if seed_rows.empty:
            return {goal_id: 1 for goal_id in goal_ids}
        row = seed_rows.iloc[0]
        return {goal_id: row[goal_id] for goal_id in goal_ids if goal_id in row}
    except Exception as e:
        print(f"Error loading goal states from CSV: {e}")
        # Default initialization if CSV loading fails
        return {goal_id: 1 for goal_id in goal_ids}


def _first_ticked_action(visited, nodes_by_id):
    """Return the name of the first action node visited in a tick, or None.

    ``visited`` is the SnapshotVisitor's ``visited`` mapping. py_trees 2.x
    stores ``behaviour id -> status`` there (not the behaviours), so ids are
    resolved through ``nodes_by_id``; values that already are behaviours are
    used directly.
    """
    for node_id, value in visited.items():
        if isinstance(value, py_trees.behaviour.Behaviour):
            node = value
        else:
            node = nodes_by_id.get(node_id)
        if node is None:
            continue
        # Flexible check: nodes with a 'trajectory' attribute or named 'actions...'
        if hasattr(node, 'trajectory') or (hasattr(node, 'name') and node.name.startswith('actions')):
            return node.name
    return None


def load_and_execute_bt(xml_file, csv_file, env, seed=42):
    """Build the BT from ``xml_file``, run it on ``env`` for one seed and score it.

    Runs ``STEP_NEED`` ticks. For tick ``k`` (1-based) the expected action is
    ``actions{seed}a{k}``; the first action node visited in that tick is
    compared with it.

    Args:
        xml_file: Path to the behaviour-tree XML.
        csv_file: Path to the trajectories CSV (its ``action_label``,
            ``trajectory``, ``seed`` and ``goal_XX_start`` columns are read).
        env: Environment providing ``reset(seed=...)``, ``step`` and ``close``.
        seed: Seed passed to ``env.reset`` and stored on the blackboard.

    Returns:
        ``(status, total_visits, correct, wrong)``: the root status after the
        last tick, node visits summed over all ticks, and the number of ticks
        whose action matched / did not match the expected one.
        ``(None, 0, 0, 0)`` if parsing fails or an exception is raised.
    """
    try:
        # Condition behaviours for every goal and seed ID.
        behaviors_dict = {}
        for i in range(10):  # Assuming 10 objectives
            goal_id = f"goal_{i+1:02d}_start"
            behaviors_dict[goal_id] = CheckCondition(name=goal_id)
        for i in range(100):  # Assuming seeds 0-99
            seed_id = f"seed_{i}"
            behaviors_dict[seed_id] = CheckCondition(name=seed_id)

        # Action behaviours, one per action_label; create_behavior_instance
        # clones them (with their CSV path) for each node in the tree.
        try:
            df = pd.read_csv(csv_file)
            for action_label in df['action_label'].unique():
                behaviors_dict[action_label] = ExecuteTrajectoryAction(name=action_label, csv_file=csv_file)
        except Exception as e:
            print(f"Error loading actions from CSV: {e}")
            traceback.print_exc()

        # Parse the XML file using our custom parser
        root = parse_bt_xml(xml_file, behaviors_dict)
        if root is None:
            print("Failed to parse behavior tree")
            return None, 0, 0, 0

        # Initialize environment
        obs, info = env.reset(seed=seed)

        # Initialize blackboard with environment state
        objectives = obs["objectives"]
        blackboard = py_trees.blackboard.Blackboard()
        blackboard.set("env", env)
        blackboard.set("csv_file", csv_file)
        blackboard.set("agent_pos", obs["agent"])
        blackboard.set("objectives", objectives.tolist() if hasattr(objectives, "tolist") else objectives)
        blackboard.set("seed", seed)
        blackboard.set("goal_states", _initial_goal_states(csv_file, seed))

        bt = py_trees.trees.BehaviourTree(root=root)

        # Snapshot visitor records which nodes each tick visits.
        snapshot_visitor = py_trees.visitors.SnapshotVisitor()
        bt.visitors.append(snapshot_visitor)
        bt.setup()

        # The tree is fixed after setup; index it once so the ids recorded by
        # the snapshot visitor can be mapped back to behaviours.
        nodes_by_id = {node.id: node for node in bt.root.iterate()}

        step_count = 0
        max_steps = STEP_NEED
        total_visits = 0
        correct_counter = 0
        wrong_counter = 0
        status = None  # remains None only if max_steps <= 0

        # All max_steps ticks are always run. BehaviourTree.tick() returns
        # None, so an early exit keyed on its return value never fires; and
        # since action leaves finish within one tick, exiting on the root's
        # SUCCESS/FAILURE would end multi-tick runs after their first action.
        while step_count < max_steps:
            print(f"\n--- Step {step_count} ---")

            snapshot_visitor.initialise()
            bt.tick()
            # tick() returns nothing; the tick's outcome is the root's status.
            status = bt.root.status

            expected_action = f"actions{seed}a{step_count + 1}"  # Adjust for 1-based indexing seen in logs
            ticked_action_name = _first_ticked_action(snapshot_visitor.visited, nodes_by_id)

            if ticked_action_name is None:
                print("  No action was ticked in this step.")
            else:
                print(f"  Ticked Action: {ticked_action_name}, Expected: {expected_action}")
                if ticked_action_name == expected_action:
                    correct_counter += 1
                    print("  -> Match: Correct action.")
                else:
                    wrong_counter += 1
                    print("  -> Mismatch: Wrong action.")

            total_visits += len(snapshot_visitor.visited)
            step_count += 1

        print(f"Behavior tree execution completed with status: {status}")

        try:
            env.close()
        except SystemExit:
            print("Ignored SystemExit from environment close()")

        return status, total_visits, correct_counter, wrong_counter

    except Exception as e:
        print(f"Error in load_and_execute_bt: {e}")
        traceback.print_exc()
        return None, 0, 0, 0


# Example usage
if __name__ == "__main__":
    try:
        import os

        # A single environment object serves every seed; load_and_execute_bt
        # resets it with the seed before each run.
        env = MultiObjectiveGridEnv(size=1000, num_objectives=10)

        # Inputs: the behaviour tree under evaluation and the recorded trajectories.
        base_dir = os.getcwd()
        bt_file = os.path.join(base_dir, "ours/100-enviroments/bt.xml")
        csv_file = os.path.join(base_dir, "trajectories/trajectories.csv")
        seeds = list(range(100))  # seeds 0-99, one environment each

        total_visits_count = total_correct_actions = total_wrong_actions = 0

        for seed in seeds:
            print(f"\n==================== EXECUTING FOR SEED: {seed} ====================")
            status, total_visits, correct, wrong = load_and_execute_bt(
                xml_file=bt_file, csv_file=csv_file, env=env, seed=seed
            )
            total_visits_count = total_visits_count + total_visits
            total_correct_actions = total_correct_actions + correct
            total_wrong_actions = total_wrong_actions + wrong

        print("\n==================== FINAL SUMMARY ====================")
        print(f"Total visits after {len(seeds)} seeds: {total_visits_count}")
        print(f"Total CORRECT actions ticked: {total_correct_actions}")
        print(f"Total WRONG actions ticked: {total_wrong_actions}")

    except Exception as e:
        print(f"Top-level exception: {e}")
        traceback.print_exc()
