In [None]:
import os
import csv
import logging
import matplotlib.pyplot as plt
import numpy as np
import pybullet as p
import pybullet_data
from gymnasium import spaces
from ray import tune
import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from typing import Optional, Dict, Any, Tuple
from math import sqrt, pi, cos, sin, atan2, acos, asin
from numpy import linalg
import cmath

# ----------------------------
# Configure Logging
# ----------------------------
# Records at INFO and above go both to training.log (working directory) and to the console.
_log_handlers = [
    logging.FileHandler("training.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,  # DEBUG records are suppressed to reduce log volume
)
logger = logging.getLogger(__name__)

# ----------------------------
# Custom UR5e Robot Control
# ----------------------------
class CustomUR5eRobot:
    """
    UR5e arm in PyBullet with analytic kinematics and a per-joint PD controller.

    ``AH``/``HTrans``/``invKine``/``FK`` are pure NumPy/math and share the
    Denavit-Hartenberg parameters defined on the class. ``apply_action``
    keeps a low-pass-filtered action and the previous position error between
    calls; ``reset_joints`` clears that memory so it does not carry over from
    one episode into the next.
    """

    # Denavit-Hartenberg parameters shared by AH, invKine and FK
    # (metres / radians); index k corresponds to joint k+1.
    _DH_D = (0.1625, 0.0, 0.0, 0.133, 0.0997, 0.101)
    _DH_A = (0.0, -0.425, -0.39225, 0.0, 0.0, 0.0)
    _DH_ALPHA = (pi / 2, 0.0, 0.0, pi / 2, -pi / 2, 0.0)

    def __init__(self, urdf_path: str):
        """
        :param urdf_path: path of the URDF file that ``load_robot`` loads
        """
        self.urdf_path = urdf_path
        self.robot_id = None  # PyBullet body id; set by load_robot
        self.joint_indices = []  # PyBullet indices of the controlled joints
        self.n_joints = 6  # Assuming UR5e has 6 controllable joints
        self._reset_controller_state()

    def _reset_controller_state(self):
        """Zero the memory that apply_action carries from one call to the next."""
        self.filtered_action = np.zeros(self.n_joints)  # low-pass-filtered action
        self.prev_position_error = np.zeros(self.n_joints)  # for the derivative term
        self.delta_deg_prev = np.zeros(self.n_joints)  # last raw action

    def load_robot(self):
        """
        Loads the UR5e robot into the PyBullet simulation.

        The base is fixed at the origin. The first ``n_joints`` revolute
        joints become ``joint_indices``.

        :raises FileNotFoundError: if ``urdf_path`` does not exist
        :raises ValueError: if the URDF has fewer than ``n_joints`` revolute joints
        """
        if not os.path.exists(self.urdf_path):
            logger.error(f"URDF file {self.urdf_path} does not exist.")
            raise FileNotFoundError(f"URDF file {self.urdf_path} does not exist.")
        self.robot_id = p.loadURDF(self.urdf_path, [0, 0, 0], useFixedBase=True)
        # Assume the first 6 revolute joints are controllable
        self.joint_indices = [
            i for i in range(p.getNumJoints(self.robot_id))
            if p.getJointInfo(self.robot_id, i)[2] == p.JOINT_REVOLUTE
        ][:self.n_joints]
        if len(self.joint_indices) < self.n_joints:
            logger.error(f"Expected {self.n_joints} controllable joints, found {len(self.joint_indices)}.")
            raise ValueError(f"Expected {self.n_joints} controllable joints, found {len(self.joint_indices)}.")
        logger.info(f"Loaded robot with ID {self.robot_id}. Joint indices: {self.joint_indices}")

    def reset_joints(self, initial_positions: np.ndarray, initial_velocities: np.ndarray):
        """
        Reset joint positions, joint velocities and the controller state.

        Each controlled joint ``i`` is set to ``initial_positions[i]`` with
        velocity ``initial_velocities[i]``. Its default velocity motor is
        disabled (force 0). The filter and derivative memory used by
        ``apply_action`` is zeroed, so the first action of a new episode is
        not influenced by the previous episode.
        """
        for i, joint_idx in enumerate(self.joint_indices):
            p.resetJointState(self.robot_id, joint_idx, targetValue=initial_positions[i], targetVelocity=initial_velocities[i])
            # Disable default motor control
            p.setJointMotorControl2(
                bodyIndex=self.robot_id,
                jointIndex=joint_idx,
                controlMode=p.VELOCITY_CONTROL,
                targetVelocity=0,
                force=0
            )
        self._reset_controller_state()
        logger.info("Reset joints to initial positions and velocities.")

    def AH(self, n, th, c):
        """
        Return the 4x4 DH transform of link ``n``.

        :param n: link number, 1..6
        :param th: 6x8 joint-angle matrix; ``th[n-1, c]`` is used as theta_n
        :param c: column of ``th`` (IK branch) to read
        :return: ``np.matrix`` Trans_z(d_n) * Rot_z(theta_n) * Trans_x(a_n) * Rot_x(alpha_n)
        """
        mat = np.matrix
        theta = th[n - 1, c]
        alpha = self._DH_ALPHA[n - 1]

        # Keep every factor float64 so the chained product has uniform precision.
        T_a = mat(np.identity(4))
        T_a[0, 3] = self._DH_A[n - 1]
        T_d = mat(np.identity(4))
        T_d[2, 3] = self._DH_D[n - 1]

        Rzt = mat([[cos(theta), -sin(theta), 0, 0],
                   [sin(theta), cos(theta), 0, 0],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1]])

        Rxa = mat([[1, 0, 0, 0],
                   [0, cos(alpha), -sin(alpha), 0],
                   [0, sin(alpha), cos(alpha), 0],
                   [0, 0, 0, 1]])

        return T_d * Rzt * T_a * Rxa

    def HTrans(self, th, c):
        """
        Return the base-to-tool transform A_1 * A_2 * ... * A_6 for column ``c`` of ``th``.
        """
        T_06 = self.AH(1, th, c)
        for n in range(2, 7):
            T_06 = T_06 * self.AH(n, th, c)
        return T_06

    def invKine(self, xyz, solution=2):
        """
        Analytic inverse kinematics for a tool pointing straight down.

        The desired pose has position ``xyz`` and rotation diag(1, -1, -1).
        All eight closed-form branches are computed:
        shoulder (cols 0-3 / 4-7) x wrist (+/- theta5) x elbow (+/- theta3).

        :param xyz: target (x, y, z) of the tool in the base frame
        :param solution: branch (column 0..7) to return. It defaults to 2, the
            branch used everywhere in this file.
        :return: 1-D float array of the 6 joint angles (radians). Targets outside
            the reachable workspace may yield NaN entries or raise ValueError
            from ``math.acos``.
        """
        mat = np.matrix
        a2, a3 = self._DH_A[1], self._DH_A[2]
        d4, d6 = self._DH_D[3], self._DH_D[5]

        desired_pos = mat([[1, 0, 0, xyz[0]],
                           [0, -1, 0, xyz[1]],
                           [0, 0, -1, xyz[2]],
                           [0, 0, 0, 1]], dtype=np.float64)
        th = mat(np.zeros((6, 8)))
        P_05 = desired_pos * mat([0, 0, -d6, 1]).T - mat([0, 0, 0, 1]).T

        # **** theta1 ****  two shoulder solutions (left / right)
        psi = atan2(P_05[1, 0], P_05[0, 0])
        phi = float(np.arccos(d4 / float(np.hypot(P_05[0, 0], P_05[1, 0]))))
        th[0, 0:4] = pi / 2 + psi + phi
        th[0, 4:8] = pi / 2 + psi - phi

        # **** theta5 ****  wrist up / down for each shoulder solution
        for c in (0, 4):
            T_16 = linalg.inv(self.AH(1, th, c)) * desired_pos
            t5 = acos((T_16[2, 3] - d4) / d6)
            th[4, c:c + 2] = t5
            th[4, c + 2:c + 4] = -t5

        # **** theta6 ****
        # theta6 is not well-defined when sin(theta5) = 0 or when T16(1, 3), T16(2, 3) = 0.
        for c in (0, 2, 4, 6):
            T_61 = linalg.inv(linalg.inv(self.AH(1, th, c)) * desired_pos)
            s5 = sin(th[4, c])
            th[5, c:c + 2] = atan2(-T_61[1, 2] / s5, T_61[0, 2] / s5)

        # **** theta3 ****  elbow up / down
        for c in (0, 2, 4, 6):
            T_16 = linalg.inv(self.AH(1, th, c)) * desired_pos
            T_14 = T_16 * linalg.inv(self.AH(5, th, c) * self.AH(6, th, c))
            P_13 = T_14 * mat([0, -d4, 0, 1]).T - mat([0, 0, 0, 1]).T
            t3 = cmath.acos((linalg.norm(P_13) ** 2 - a2 ** 2 - a3 ** 2) / (2 * a2 * a3))
            th[2, c] = t3.real
            th[2, c + 1] = -t3.real

        # **** theta2 and theta4 ****  for every branch
        for c in range(8):
            T_16 = linalg.inv(self.AH(1, th, c)) * desired_pos
            T_14 = T_16 * linalg.inv(self.AH(6, th, c)) * linalg.inv(self.AH(5, th, c))
            P_13 = T_14 * mat([0, -d4, 0, 1]).T - mat([0, 0, 0, 1]).T

            # Index scalars explicitly ([i, 0]); converting a 1x1 matrix to a
            # float is deprecated in NumPy >= 1.25.
            th[1, c] = -atan2(P_13[1, 0], -P_13[0, 0]) + asin(a3 * sin(th[2, c]) / linalg.norm(P_13))
            T_34 = linalg.inv(self.AH(3, th, c)) * linalg.inv(self.AH(2, th, c)) * T_14
            th[3, c] = atan2(T_34[1, 0], T_34[0, 0])

        return np.array(th[:, solution], dtype=float).ravel()

    def FK(self, joint_angles):
        """
        Computes the forward kinematics for the UR5e robot.

        Closed-form tool position for the class DH parameters. It equals the
        translation column of ``HTrans``.

        :param joint_angles: 6 joint angles [theta1, ..., theta6] in radians.
            theta6 does not affect the position.
        :return: Tuple (x, y, z) in the base frame. Returns (0.0, 0.0, 0.0), and
            logs an error, if the result is NaN or infinite.
        """
        t1, t2, t3, t4, t5, _t6 = joint_angles
        d1, d4, d5, d6 = self._DH_D[0], self._DH_D[3], self._DH_D[4], self._DH_D[5]
        a2, a3 = self._DH_A[1], self._DH_A[2]

        s1, c1 = np.sin(t1), np.cos(t1)
        s5, c5 = np.sin(t5), np.cos(t5)
        s234, c234 = np.sin(t2 + t3 + t4), np.cos(t2 + t3 + t4)

        # In-plane component of the arm (rotated by theta1) and the offset
        # perpendicular to that plane.
        in_plane = a2 * np.cos(t2) + a3 * np.cos(t2 + t3) + d5 * s234 - d6 * c234 * s5
        lateral = d4 + d6 * c5
        x = c1 * in_plane + s1 * lateral
        y = s1 * in_plane - c1 * lateral
        z = d1 + a2 * np.sin(t2) + a3 * np.sin(t2 + t3) - d5 * c234 - d6 * s234 * s5

        # Check for numerical issues
        if not np.all(np.isfinite([x, y, z])):
            logger.error(f"FK computation resulted in invalid values: x={x}, y={y}, z={z}")
            return 0.0, 0.0, 0.0  # Default fallback values

        return x, y, z

    def get_observation(self):
        """
        Return float32 arrays (joint_positions, joint_velocities) for the controlled joints.
        """
        joint_positions = []
        joint_velocities = []
        for joint_idx in self.joint_indices:
            joint_state = p.getJointState(self.robot_id, joint_idx)
            joint_positions.append(joint_state[0])  # Position
            joint_velocities.append(joint_state[1])  # Velocity
        return np.array(joint_positions, dtype=np.float32), np.array(joint_velocities, dtype=np.float32)

    def apply_action(self, delta_deg: np.ndarray, target_positions: np.ndarray):
        """
        Drive the joints toward ``target_positions`` with a PD velocity controller.

        ``delta_deg`` (the RL action, one value per joint) is low-pass
        filtered and added to each joint's position error. The PD output is
        clipped to [-2, 2] and sent as a VELOCITY_CONTROL target with force 500.
        Joints 0, 4 and 5 always receive a zero velocity target.
        """
        kp = [1.2, 0.8, 0.8, 0.8, 0.8, 0.8]  # PD controller proportional gains
        kd = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # PD controller derivative gains

        dt = 1/15  # Time step (assumed to match the simulation step)
        alpha = 0.07  # Low-pass filter coefficient (adjust as needed)

        # Apply the low-pass filter to the RL actions (delta_deg)
        filtered_delta_deg = alpha * delta_deg + (1 - alpha) * self.filtered_action
        self.filtered_action = filtered_delta_deg

        for i in range(self.n_joints):
            current_pos, _ = p.getJointState(self.robot_id, self.joint_indices[i])[:2]
            # Calculate errors
            position_error = target_positions[i] - current_pos + filtered_delta_deg[i]
            d_error = (position_error - self.prev_position_error[i]) / dt
            # PD control
            pd_output_velocity = (kp[i] * position_error) + (kd[i] * d_error)
            pd_output_velocity = np.clip(pd_output_velocity, -2, 2)  # Limit velocities

            # Optional: Zero out control for specific joints if needed
            if i in [0, 4, 5]:
                pd_output_velocity = 0

            p.setJointMotorControl2(
                bodyIndex=self.robot_id,
                jointIndex=self.joint_indices[i],
                controlMode=p.VELOCITY_CONTROL,
                targetVelocity=pd_output_velocity,
                force=500
            )
            self.prev_position_error[i] = position_error
            self.delta_deg_prev[i] = delta_deg[i]

    def disconnect(self):
        """
        Disconnects the robot from the simulation.

        Removes the body if one is loaded and forgets its id, so a repeated
        call does nothing.
        """
        if self.robot_id is not None:
            p.removeBody(self.robot_id)
            logger.info(f"Robot with ID {self.robot_id} removed from simulation.")
            self.robot_id = None

# ----------------------------
# Multi-Agent Environment
# ----------------------------
class MultiAgentUR5eEnv(MultiAgentEnv):
    """
    Multi-agent Gymnasium environment for the UR5e robotic arm.
    Each joint is controlled by a separate agent.

    Agent ``agent_i`` observes ``[position, velocity]`` of joint ``i``, clipped
    to [-pi, pi] and [-2, 2]. It emits one scalar in [-0.1, 0.1] that the
    robot's PD controller filters and adds to that joint's position error.
    The end-effector target follows ``get_target_position`` and is mapped to
    joint targets with ``CustomUR5eRobot.invKine`` every step. Episodes are
    truncated after ``max_step`` steps and never terminate early.
    """

    # End-effector targets (x, y, z) in the robot base frame.
    _HOME_XYZ = (-0.5, -0.5, 0.0)
    _RAISED_XYZ = (-0.5, -0.5, 0.5)
    # Schedule unit in steps: the raised target is active while
    # 5 * _PHASE_STEPS <= step <= 25 * _PHASE_STEPS.
    _PHASE_STEPS = 15

    def __init__(self, render_mode=None, max_step=600):
        """
        :param render_mode: "human" opens the PyBullet GUI; anything else uses DIRECT mode
        :param max_step: number of steps after which every agent is truncated
        """
        super().__init__()
        self.robot = CustomUR5eRobot(urdf_path="ur5e.urdf")  # Ensure ur5e.urdf is in the correct path
        self.n_agents = self.robot.n_joints
        self.agents = [f"agent_{i}" for i in range(self.n_agents)]
        # Action space: one scalar per agent. After low-pass filtering it is
        # added to that joint's position error by CustomUR5eRobot.apply_action.
        self.action_space = spaces.Box(low=-0.1, high=0.1, shape=(1,), dtype=np.float32)
        # Observation space: [joint_position, joint_velocity] of the agent's joint
        self.observation_space = spaces.Box(low=np.array([-2*np.pi, -6]),
                                            high=np.array([2*np.pi, 6]),
                                            dtype=np.float32)
        self.render_mode = render_mode
        self.max_steps = max_step
        self.current_step = 0
        home_joint_positions = self.robot.invKine(self._HOME_XYZ)
        self.target_positions = home_joint_positions  # Target joint positions
        self.target_velocity = np.zeros(6)
        # First step at which the end effector came within threshold (per episode)
        self.target_reached = False
        self.target_reach_step = None

        self.dones = {agent: False for agent in self.agents}

        # Initialize PyBullet physics client
        if self.render_mode == "human":
            self.physics_client = p.connect(p.GUI)
            logger.info("PyBullet GUI mode enabled.")
        else:
            self.physics_client = p.connect(p.DIRECT)
            logger.info("PyBullet DIRECT mode enabled.")

        # Additional setup
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        p.setGravity(0, 0, 0)  # Gravity is disabled
        self.robot.load_robot()
        self.robot.reset_joints(initial_positions=home_joint_positions,
                                initial_velocities=np.zeros(self.n_agents))

    def get_target_position(self, step):
        """
        Returns the desired end-effector position at the given step.

        The schedule is piecewise constant. It returns ``_RAISED_XYZ`` while
        ``5 * _PHASE_STEPS <= step <= 25 * _PHASE_STEPS`` (steps 75..375 with
        the defaults) and ``_HOME_XYZ`` otherwise.

        :return: float array [x, y, z]
        """
        phase = self._PHASE_STEPS
        raised = 5 * phase <= step <= 25 * phase
        return np.array(self._RAISED_XYZ if raised else self._HOME_XYZ)

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
        """
        Resets the environment to an initial state and returns an initial observation.

        Steps performed:
        - Rebuild the PyBullet world with zero gravity and a 1/15 s time step.
        - Reload the robot at the IK solution for ``_HOME_XYZ`` with zero velocity.
        - Clear per-episode bookkeeping: step counter, done flags and target-reached tracking.

        :param seed: if given, seeds NumPy's global RNG
        :param options: unused
        :return: ``(observations, infos)`` with an empty info dict per agent
        """
        if seed is not None:
            np.random.seed(seed)  # PyBullet does not support direct seeding

        self.current_step = 0
        self.dones = {agent: False for agent in self.agents}
        # Target tracking is per episode; without this it stays set forever.
        self.target_reached = False
        self.target_reach_step = None

        # Reset the simulation
        p.resetSimulation()
        p.setGravity(0, 0, 0)
        p.setTimeStep(1./15.)

        # Load the robot model into the simulation
        self.robot.load_robot()

        # Initialize joint positions and velocities
        initial_positions = self.robot.invKine(self._HOME_XYZ)
        self.target_positions = initial_positions
        self.robot.reset_joints(initial_positions, np.zeros(self.n_agents))

        observations, _ = self._get_obs()
        infos = {agent: {} for agent in self.agents}
        return observations, infos

    def step(self, action_dict):
        """
        Executes a step in the environment given the actions of all agents.

        :param action_dict: ``{agent_id: action}``; element 0 of each action is
            passed to the controller for that agent's joint
        :return: ``(observations, rewards, terminations, truncations, infos)``.
            Terminations are always False. Truncations become True once
            ``current_step`` reaches ``max_steps``.
        """
        self.current_step += 1

        # Collect actions (one scalar per joint, in agent order)
        delta_deg = np.array([action_dict[agent][0] for agent in self.agents], dtype=float)

        # Store actions for reward computation
        self.last_actions = action_dict

        # Map this step's end-effector target to joint space and apply PD control
        self.target_end_effector_pos = self.get_target_position(self.current_step)
        self.target_positions = self.robot.invKine(self.target_end_effector_pos)
        self.robot.apply_action(delta_deg, self.target_positions)
        # Advance the simulation by one time step
        p.stepSimulation()

        # Unclipped joint state -> end-effector position via forward kinematics
        joint_positions, joint_velocities = self.robot.get_observation()
        x, y, z = self.robot.FK(joint_positions)

        # Compute individual end-effector errors
        x_error = x - self.target_end_effector_pos[0]
        y_error = y - self.target_end_effector_pos[1]
        z_error = z - self.target_end_effector_pos[2]

        # Compute total end-effector error
        end_effector_error = np.linalg.norm([x_error, y_error, z_error])
        self.end_effector_error = end_effector_error

        # Record the first step of the episode at which the target was reached
        target_threshold = 0.001  # Adjust threshold as needed
        if not self.target_reached and end_effector_error <= target_threshold:
            self.target_reached = True
            self.target_reach_step = self.current_step

        observations, joint_positions_clamped, joint_velocities_clamped = \
            self._clamped_observations(joint_positions, joint_velocities)

        # Calculate rewards
        rewards = self._compute_rewards(observations)

        # Per-agent diagnostics consumed by the logging callback
        infos = {
            agent: {
                "joint_position": joint_positions_clamped[i],
                "joint_velocity": joint_velocities_clamped[i],
                "reward": rewards[agent],
                "end_effector_x": x,
                "end_effector_y": y,
                "end_effector_z": z,
                "target_end_effector_x": self.target_end_effector_pos[0],
                "target_end_effector_y": self.target_end_effector_pos[1],
                "target_end_effector_z": self.target_end_effector_pos[2],
                "x_error": x_error,
                "y_error": y_error,
                "z_error": z_error,
                "end_effector_error": end_effector_error,
            } for i, agent in enumerate(self.agents)
        }

        # Check termination conditions
        done = self.current_step >= self.max_steps

        # Build the "terminations" and "truncations" flag dicts for all agents
        terminations = {agent: False for agent in self.agents}
        truncations = {agent: done for agent in self.agents}
        terminations["__all__"] = False
        truncations["__all__"] = done

        return observations, rewards, terminations, truncations, infos

    def _clamped_observations(self, joint_positions, joint_velocities):
        """
        Clip the joint state and split it into per-agent observations.

        Positions are clipped to [-pi, pi] and velocities to [-2, 2].

        :return: ``(observations, clipped_positions, clipped_velocities)``, where
            each observation is a float32 array ``[position, velocity]``
        """
        positions = np.clip(joint_positions, -np.pi, np.pi)
        velocities = np.clip(joint_velocities, -2, 2)
        observations = {
            agent: np.array([positions[i], velocities[i]], dtype=np.float32)
            for i, agent in enumerate(self.agents)
        }
        return observations, positions, velocities

    def _get_obs(self):
        """
        Return ``(observations, {})`` for the current joint state.

        Applies the same clipping as ``step``, so reset and step observations
        come from one transformation.
        """
        joint_positions, joint_velocities = self.robot.get_observation()
        observations, _, _ = self._clamped_observations(joint_positions, joint_velocities)
        return observations, {}

    def _compute_rewards(self, observations):
        """
        Per-agent reward: the negative absolute error between the agent's
        target joint position and its observed (clipped) joint position.
        No time weighting and no velocity term are applied.
        """
        return {
            agent: -abs(self.target_positions[i] - observations[agent][0])
            for i, agent in enumerate(self.agents)
        }

    def render(self, mode='human'):
        """
        Renders the environment. Rendering is handled by PyBullet.
        """
        pass

    def close(self):
        """
        Closes the environment and disconnects PyBullet.

        The robot body is removed first because PyBullet calls such as
        ``removeBody`` fail once the client is disconnected.
        """
        self.robot.disconnect()
        p.disconnect()
        logger.info("PyBullet simulation disconnected.")



class JointPositionLogger(DefaultCallbacks):
    def __init__(self):
        """Initialise logging state; ``on_algorithm_init`` may replace ``log_dir``."""
        super().__init__()
        # agent_id -> list of total rewards, one entry per finished episode
        self.total_rewards_per_episode = {}
        # Sequential id handed to the next episode; ids start at 1
        self.episode_counter = 1
        # Directory for CSV logs and plots (default until configured)
        self.log_dir = "joint_logs_PPO"

    def on_algorithm_init(self, *, algorithm, **kwargs):
        """
        Take ``log_dir`` from ``algorithm.config["callbacks_config"]`` and create it.

        Falls back to "joint_logs_PPO" when the key or section is missing.
        """
        callbacks_cfg = algorithm.config.get("callbacks_config", {})
        self.log_dir = callbacks_cfg.get("log_dir", "joint_logs_PPO")
        os.makedirs(self.log_dir, exist_ok=True)
        logger.info(f"JointPositionLogger initialized with log directory: {self.log_dir}")

    def on_episode_start(self, *, worker, base_env, policies, episode, **kwargs):
        """
        Assign a sequential episode id and open its CSV log.

        The file ``episode_<id>.csv`` in ``self.log_dir`` is opened for writing
        and its header row is written. The open file, the ``csv.DictWriter``
        and an empty per-agent reward dict are stored in ``episode.user_data``
        for the step and end callbacks.

        NOTE(review): ids come from a per-instance counter. If several rollout
        workers each hold their own instance, ids and file names may collide,
        so confirm the worker setup.
        """
        episode_num = self.episode_counter
        self.episode_counter = episode_num + 1
        store = episode.user_data
        store["sequential_episode_id"] = episode_num

        os.makedirs(self.log_dir, exist_ok=True)  # Ensure the log directory exists
        csv_path = os.path.join(self.log_dir, f"episode_{episode_num}.csv")
        # Left open on purpose: on_episode_end closes it.
        handle = open(csv_path, mode='w', newline='')

        columns = [
            "step", "agent_id", "joint_position", "joint_velocity",
            "end_effector_x", "end_effector_y", "end_effector_z",
            "target_end_effector_x", "target_end_effector_y", "target_end_effector_z",
            "x_error", "y_error", "z_error",
            "end_effector_error",
            "reward",
        ]
        row_writer = csv.DictWriter(handle, fieldnames=columns)
        row_writer.writeheader()

        store["csv_file"] = handle
        store["csv_writer"] = row_writer
        store["episode_rewards"] = {}

    def on_episode_step(self, *, worker, base_env, policies, episode, **kwargs):
        """
        Append one CSV row per agent for the step that just finished.

        Each agent's latest info dict is read via ``episode.last_info_for``,
        or ``episode._last_infos`` if that method is missing. Its ``reward``
        is added to the running total in ``episode.user_data["episode_rewards"]``
        and the logged fields are written, with 0.0 for any missing field.
        Nothing happens if no CSV writer was set up for this episode.
        """
        csv_writer = episode.user_data.get("csv_writer")
        if csv_writer is None:
            return  # Skip logging if writer is not initialized

        # Info keys copied into the CSV row unchanged (missing ones become 0.0)
        copied_fields = (
            "joint_position", "joint_velocity",
            "end_effector_x", "end_effector_y", "end_effector_z",
            "target_end_effector_x", "target_end_effector_y", "target_end_effector_z",
            "x_error", "y_error", "z_error",
            "end_effector_error",
        )

        for agent_id in episode.get_agents():
            try:
                info = episode.last_info_for(agent_id) or {}
            except AttributeError:
                # Fallback when the episode object has no last_info_for
                info = episode._last_infos.get(agent_id, {})
            if info is None:
                continue

            step_reward = info.get("reward", 0.0)  # Get reward from info
            totals = episode.user_data.setdefault("episode_rewards", {})
            totals[agent_id] = totals.get(agent_id, 0.0) + step_reward

            row = {"step": episode.length, "agent_id": agent_id}
            row.update((field, info.get(field, 0.0)) for field in copied_fields)
            row["reward"] = step_reward
            csv_writer.writerow(row)

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs):
        """
        Close the episode's CSV, render its plots and update the reward history.

        If a CSV file was opened for the episode, it is closed, its location
        is logged and the per-episode plots are generated. The episode's
        per-agent reward totals are then appended to
        ``self.total_rewards_per_episode`` and the cross-episode reward plot
        is regenerated.
        """
        user_data = episode.user_data
        episode_num = user_data.get("sequential_episode_id")
        open_csv = user_data.get("csv_file")

        if open_csv:
            open_csv.close()
            saved_to = os.path.join(self.log_dir, f'episode_{episode_num}.csv')
            logger.info(f"Ended Episode {episode_num} | Log saved to {saved_to}")

            # Per-episode plots, generated in a fixed order from the closed CSV
            for plot_method in (
                "generate_joint_position_plot",
                "generate_joint_velocity_plot",
                "generate_end_effector_plot",
                "generate_reward_plot",
                "generate_xyz_error_plot",
            ):
                getattr(self, plot_method)(episode_num)

        history = self.total_rewards_per_episode
        for agent_id, total_reward in user_data.get("episode_rewards", {}).items():
            history.setdefault(agent_id, []).append(total_reward)

        self.generate_total_reward_plot()

    def generate_total_reward_plot(self):
        """
        Plot each agent's total reward against episode number.

        Writes ``total_rewards_per_episode.png`` into ``self.log_dir`` and
        overwrites any earlier version. Logs a warning and returns if no
        rewards have been recorded yet. The figure is closed even if plotting
        or saving raises.
        """
        # Check before creating a figure so the early return cannot leak one.
        if not self.total_rewards_per_episode:
            logger.warning("No rewards to plot.")
            return

        plot_path = os.path.join(self.log_dir, 'total_rewards_per_episode.png')
        plt.figure()
        try:
            for agent_id, rewards in self.total_rewards_per_episode.items():
                # Each agent gets its own x-axis, so agents with a different
                # number of recorded episodes still plot instead of raising.
                plt.plot(range(1, len(rewards) + 1), rewards, label=agent_id)
            plt.xlabel('Episode')
            plt.ylabel('Total Reward')
            plt.title('Total Reward per Episode per Agent')
            plt.legend()
            plt.grid(True)
            plt.savefig(plot_path)
        finally:
            # This runs on every episode end; unclosed figures would accumulate.
            plt.close()
        logger.info(f"Saved total rewards plot to {plot_path}")


    def generate_xyz_error_plot(self, episode_id: int):
        """
        Generates and saves plots for x_error, y_error, and z_error over time with x-axis in seconds.

        The method reads ``episode_<id>.csv`` from ``self.log_dir`` and uses
        only the ``agent_0`` rows, because every agent logs the same
        end-effector values. Steps are converted to seconds with
        ``steps_to_seconds`` and the result is written to
        ``episode_<id>_xyz_errors.png``. Read and save failures are logged
        rather than raised, and the figure is always closed.
        """
        file_path = os.path.join(self.log_dir, f"episode_{episode_id}.csv")
        if not os.path.exists(file_path):
            logger.error(f"CSV file for Episode {episode_id} does not exist.")
            return

        steps, x_errors, y_errors, z_errors = [], [], [], []

        # Read data from CSV
        try:
            with open(file_path, mode='r', newline='') as csvfile:
                for row in csv.DictReader(csvfile):
                    if row["agent_id"] != "agent_0":
                        continue  # Avoid duplicating entries for each agent

                    step = int(row["step"])
                    x_error = float(row["x_error"])
                    y_error = float(row["y_error"])
                    z_error = float(row["z_error"])

                    steps.append(step)
                    x_errors.append(x_error)
                    y_errors.append(y_error)
                    z_errors.append(z_error)
        except Exception as e:
            logger.exception(f"Failed to read CSV for Episode {episode_id}: {e}")
            return

        # Convert steps to seconds
        seconds = steps_to_seconds(steps)

        plt.figure(figsize=(12, 8))
        try:
            plt.plot(seconds, x_errors, label="X Error")
            plt.plot(seconds, y_errors, label="Y Error")
            plt.plot(seconds, z_errors, label="Z Error")

            plt.xlabel("Seconds")
            plt.ylabel("Error (meters)")
            plt.title(f"End-Effector Errors for Episode {episode_id}")
            plt.legend()
            plt.grid(True)

            # Save the plot
            plot_filename = f"episode_{episode_id}_xyz_errors.png"
            plot_path = os.path.join(self.log_dir, plot_filename)
            try:
                plt.savefig(plot_path)
                logger.info(f"Saved end-effector error plot for Episode {episode_id} to {plot_path}")
            except Exception as e:
                logger.exception(f"Failed to save end-effector error plot for Episode {episode_id}: {e}")
        finally:
            # Close on every path so failed saves do not leak figures.
            plt.close()


    def generate_joint_position_plot(self, episode_id: int):
        """
        Generates and saves a joint position plot for the specified episode with x-axis in seconds.

        The method reads ``episode_<id>.csv`` from ``self.log_dir`` and draws
        one line per ``agent_id`` (joint position against time from
        ``steps_to_seconds``). The result is written to
        ``episode_<id>_joint_positions.png``. Read and save failures are
        logged rather than raised, and the figure is always closed.
        """
        file_path = os.path.join(self.log_dir, f"episode_{episode_id}.csv")
        if not os.path.exists(file_path):
            logger.error(f"CSV file for Episode {episode_id} does not exist.")
            return

        # agent_id -> steps / joint positions, in file order
        agent_data = {}
        steps = {}

        # Read data from CSV
        try:
            with open(file_path, mode='r', newline='') as csvfile:
                for row in csv.DictReader(csvfile):
                    step = int(row["step"])
                    agent_id = row["agent_id"]
                    joint_pos = float(row["joint_position"])

                    steps.setdefault(agent_id, []).append(step)
                    agent_data.setdefault(agent_id, []).append(joint_pos)
        except Exception as e:
            logger.exception(f"Failed to read CSV for Episode {episode_id}: {e}")
            return

        plt.figure(figsize=(12, 8))
        try:
            for agent_id, agent_positions in agent_data.items():
                agent_seconds = steps_to_seconds(steps[agent_id])
                plt.plot(agent_seconds, agent_positions, label=agent_id)

            plt.xlabel("Seconds")
            plt.ylabel("Joint Position (radians)")
            plt.title(f"Joint Positions for Episode {episode_id}")
            plt.legend()
            plt.grid(True)

            # Save the plot
            plot_filename = f"episode_{episode_id}_joint_positions.png"
            plot_path = os.path.join(self.log_dir, plot_filename)
            try:
                plt.savefig(plot_path)
                logger.info(f"Saved joint positions plot for Episode {episode_id} to {plot_path}")
            except Exception as e:
                logger.exception(f"Failed to save joint positions plot for Episode {episode_id}: {e}")
        finally:
            # Close on every path so failed saves do not leak figures.
            plt.close()


    def generate_joint_velocity_plot(self, episode_id: int):
        """
        Generates and saves a joint velocity plot for the specified episode with x-axis in seconds.

        Reads ``episode_<episode_id>.csv`` from ``self.log_dir`` and draws one
        line per ``agent_id`` (the ``joint_velocity`` column against that
        agent's ``step`` values converted with ``steps_to_seconds``). The PNG is
        written into ``self.log_dir``. A missing/unreadable CSV or a failed save
        is logged rather than raised.

        :param episode_id: Episode number used to locate the CSV and name the PNG.
        """
        file_path = os.path.join(self.log_dir, f"episode_{episode_id}.csv")
        if not os.path.exists(file_path):
            logger.error(f"CSV file for Episode {episode_id} does not exist.")
            return

        agent_data = {}  # agent_id -> joint velocities, in file order
        steps = {}       # agent_id -> matching step numbers

        # Read data from CSV
        try:
            with open(file_path, mode='r', newline='') as csvfile:
                for row in csv.DictReader(csvfile):
                    step = int(row["step"])
                    agent_id = row["agent_id"]
                    joint_vel = float(row["joint_velocity"])
                    steps.setdefault(agent_id, []).append(step)
                    agent_data.setdefault(agent_id, []).append(joint_vel)
        except Exception as e:
            logger.exception(f"Failed to read CSV for Episode {episode_id}: {e}")
            return

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            for agent_id, velocities in agent_data.items():
                ax.plot(steps_to_seconds(steps[agent_id]), velocities, label=agent_id)

            ax.set_xlabel("Seconds")
            ax.set_ylabel("Joint Velocity (radians per second)")
            ax.set_title(f"Joint Velocities for Episode {episode_id}")
            ax.legend()
            ax.grid(True)

            # Save the plot
            plot_filename = f"episode_{episode_id}_joint_velocities.png"
            plot_path = os.path.join(self.log_dir, plot_filename)
            try:
                fig.savefig(plot_path)
                logger.info(f"Saved joint velocities plot for Episode {episode_id} to {plot_path}")
            except Exception as e:
                logger.exception(f"Failed to save joint velocities plot for Episode {episode_id}: {e}")
        finally:
            # Always release the figure, even if plotting or saving raised;
            # otherwise pyplot keeps it alive and figures pile up over episodes.
            plt.close(fig)

    def generate_end_effector_plot(self, episode_id: int):
        """
        Generates and saves an end-effector position plot (X, Y, Z) for the specified episode with x-axis in seconds.

        Reads ``episode_<episode_id>.csv`` from ``self.log_dir`` and, for each of
        X, Y and Z, plots the actual position (``end_effector_<axis>``, solid)
        and the target (``target_end_effector_<axis>``, dashed) against
        ``step`` converted with ``steps_to_seconds``. The PNG is written into
        ``self.log_dir``. A missing/unreadable CSV or a failed save is logged
        rather than raised.

        :param episode_id: Episode number used to locate the CSV and name the PNG.
        """
        file_path = os.path.join(self.log_dir, f"episode_{episode_id}.csv")
        if not os.path.exists(file_path):
            logger.error(f"CSV file for Episode {episode_id} does not exist.")
            return

        coord_names = ("x", "y", "z")
        steps = []
        actual = {c: [] for c in coord_names}
        target = {c: [] for c in coord_names}

        try:
            with open(file_path, mode='r', newline='') as csvfile:
                for row in csv.DictReader(csvfile):
                    # Parse the whole row before appending so a bad value
                    # cannot leave the lists with mismatched lengths.
                    step = int(row["step"])
                    actual_xyz = [float(row[f"end_effector_{c}"]) for c in coord_names]
                    target_xyz = [float(row[f"target_end_effector_{c}"]) for c in coord_names]

                    steps.append(step)
                    for c, a_val, t_val in zip(coord_names, actual_xyz, target_xyz):
                        actual[c].append(a_val)
                        target[c].append(t_val)
        except Exception as e:
            logger.exception(f"Failed to read CSV for Episode {episode_id}: {e}")
            return

        # NOTE(review): the episode CSV also carries an agent_id column, so a
        # step may appear once per agent here; confirm the end-effector values
        # are identical across those rows (otherwise deduplicate by step).
        seconds = steps_to_seconds(steps)

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            # Same plot order as before (actual/target for X, then Y, then Z)
            # so the color cycle assignment is unchanged.
            for c in coord_names:
                ax.plot(seconds, actual[c], label=f"Actual {c.upper()} Position")
                ax.plot(seconds, target[c], label=f"Target {c.upper()} Position", linestyle='--')

            ax.set_xlabel("Seconds")
            ax.set_ylabel("Position (meters)")
            ax.set_title(f"End-Effector Positions for Episode {episode_id}")
            ax.legend()
            ax.grid(True)

            # Save the plot
            plot_filename = f"episode_{episode_id}_end_effector_positions.png"
            plot_path = os.path.join(self.log_dir, plot_filename)
            try:
                fig.savefig(plot_path)
                logger.info(f"Saved end-effector positions plot for Episode {episode_id} to {plot_path}")
            except Exception as e:
                logger.exception(f"Failed to save end-effector plot for Episode {episode_id}: {e}")
        finally:
            # Always release the figure, even if plotting or saving raised.
            plt.close(fig)



    def generate_reward_plot(self, episode_id: int):
        """
        Generates and saves a reward plot for the specified episode with x-axis in seconds.

        Reads ``episode_<episode_id>.csv`` from ``self.log_dir`` and draws one
        line per ``agent_id``: the ``reward`` column against that agent's own
        ``step`` values converted with ``steps_to_seconds``. Because steps are
        tracked per agent, agents that logged different numbers of rows no
        longer cause an x/y length mismatch. The PNG is written into
        ``self.log_dir``. A missing/unreadable CSV or a failed save is logged
        rather than raised.

        :param episode_id: Episode number used to locate the CSV and name the PNG.
        """
        file_path = os.path.join(self.log_dir, f"episode_{episode_id}.csv")
        if not os.path.exists(file_path):
            logger.error(f"CSV file for Episode {episode_id} does not exist.")
            return

        agent_rewards = {}  # agent_id -> rewards, in file order
        agent_steps = {}    # agent_id -> matching step numbers

        # Read data from CSV
        try:
            with open(file_path, mode='r', newline='') as csvfile:
                for row in csv.DictReader(csvfile):
                    step = int(row["step"])
                    agent_id = row["agent_id"]
                    reward = float(row["reward"])
                    agent_steps.setdefault(agent_id, []).append(step)
                    agent_rewards.setdefault(agent_id, []).append(reward)
        except Exception as e:
            logger.exception(f"Failed to read CSV for Episode {episode_id}: {e}")
            return

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            for agent_id, rewards in agent_rewards.items():
                ax.plot(steps_to_seconds(agent_steps[agent_id]), rewards, label=agent_id)

            ax.set_xlabel("Seconds")
            ax.set_ylabel("Per-Step Reward")
            ax.set_title(f"Per-Step Rewards for Episode {episode_id}")
            ax.legend()
            ax.grid(True)

            # Save the plot
            plot_filename = f"episode_{episode_id}_rewards.png"
            plot_path = os.path.join(self.log_dir, plot_filename)
            try:
                fig.savefig(plot_path)
                logger.info(f"Saved reward plot for Episode {episode_id} to {plot_path}")
            except Exception as e:
                logger.exception(f"Failed to save reward plot for Episode {episode_id}: {e}")
        finally:
            # Always release the figure, even if plotting or saving raised.
            plt.close(fig)


    def generate_error_plot(self, episode_id: int):
        """
        Generates and saves an end-effector error plot for the specified episode.

        Reads ``episode_<episode_id>.csv`` from ``self.log_dir`` and plots the
        ``end_effector_error`` column against the raw ``step`` column (the
        x-axis is in steps, not seconds). The PNG is written into
        ``self.log_dir``. A missing/unreadable CSV or a failed save is logged
        rather than raised.

        :param episode_id: Episode number used to locate the CSV and name the PNG.
        """
        file_path = os.path.join(self.log_dir, f"episode_{episode_id}.csv")
        if not os.path.exists(file_path):
            logger.error(f"CSV file for Episode {episode_id} does not exist.")
            return

        steps = []
        errors = []

        # Read data from CSV
        try:
            with open(file_path, mode='r', newline='') as csvfile:
                for row in csv.DictReader(csvfile):
                    step = int(row["step"])
                    error = float(row["end_effector_error"])
                    steps.append(step)
                    errors.append(error)
        except Exception as e:
            logger.exception(f"Failed to read CSV for Episode {episode_id}: {e}")
            return

        # NOTE(review): the episode CSV also carries an agent_id column, so a
        # step may appear once per agent here; confirm that is intended.
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.plot(steps, errors, label="End-Effector Error")

            ax.set_xlabel("Step")
            ax.set_ylabel("Error (meters)")
            ax.set_title(f"End-Effector Error for Episode {episode_id}")
            ax.legend()
            ax.grid(True)

            # Save the plot
            plot_filename = f"episode_{episode_id}_end_effector_error.png"
            plot_path = os.path.join(self.log_dir, plot_filename)
            try:
                fig.savefig(plot_path)
                logger.info(f"Saved end-effector error plot for Episode {episode_id} to {plot_path}")
            except Exception as e:
                logger.exception(f"Failed to save end-effector error plot for Episode {episode_id}: {e}")
        finally:
            # Always release the figure, even if plotting or saving raised.
            plt.close(fig)


# ----------------------------
# Plotting Function for Rewards
# ----------------------------
def plot_rewards(log_dir: str, output_dir: str = "reward_plots"):
    """
    Generates and saves reward plots for each episode with x-axis in seconds.

    Every ``*.csv`` file in ``log_dir`` is read as one episode log. For each
    file, one line per ``agent_id`` is drawn (the ``reward`` column against
    that agent's own ``step`` values converted with ``steps_to_seconds``) and
    saved as ``episode_<id>_rewards.png`` in ``output_dir``, where ``<id>`` is
    the file name with ``episode_`` and ``.csv`` removed. A file that cannot
    be read, plotted or saved is logged and skipped, so the remaining
    episodes are still processed.

    :param log_dir: Directory where CSV logs are stored.
    :param output_dir: Directory to save the generated reward plots.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving reward plots to directory: {output_dir}")

    # Sorted so episodes are processed in a deterministic order
    # (os.listdir order is arbitrary).
    csv_files = sorted(f for f in os.listdir(log_dir) if f.endswith('.csv'))

    if not csv_files:
        logger.warning(f"No CSV files found in {log_dir} to plot rewards.")
        return

    for csv_file in csv_files:
        episode_id = csv_file.replace("episode_", "").replace(".csv", "")
        file_path = os.path.join(log_dir, csv_file)

        agent_rewards = {}  # agent_id -> rewards, in file order
        agent_steps = {}    # agent_id -> matching step numbers

        # Read data from CSV
        try:
            with open(file_path, mode='r', newline='') as csvfile:
                for row in csv.DictReader(csvfile):
                    step = int(row["step"])
                    agent_id = row["agent_id"]
                    reward = float(row["reward"])
                    agent_steps.setdefault(agent_id, []).append(step)
                    agent_rewards.setdefault(agent_id, []).append(reward)
        except Exception as e:
            logger.exception(f"Failed to read CSV for Episode {episode_id}: {e}")
            continue

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            # Per-agent x values: agents with different row counts would
            # otherwise produce an x/y length mismatch.
            for agent_id, rewards in agent_rewards.items():
                ax.plot(steps_to_seconds(agent_steps[agent_id]), rewards, label=agent_id)

            ax.set_xlabel("Seconds")
            ax.set_ylabel("Per-Step Reward")
            ax.set_title(f"Per-Step Rewards for Episode {episode_id}")
            ax.legend()
            ax.grid(True)

            # Save the plot
            plot_filename = f"episode_{episode_id}_rewards.png"
            plot_path = os.path.join(output_dir, plot_filename)
            fig.savefig(plot_path)
            logger.info(f"Saved reward plot for Episode {episode_id} to {plot_path}")
        except Exception as e:
            logger.exception(f"Failed to save reward plot for Episode {episode_id}: {e}")
        finally:
            # Always release the figure so a long loop does not accumulate them.
            plt.close(fig)



# ----------------------------
# Plotting Average Rewards Across Episodes
# ----------------------------
def plot_average_rewards(log_dir: str, output_dir: str = "average_reward_plots"):
    """
    Generates and saves average reward plots across all episodes with x-axis in seconds.

    For every agent, the ``reward`` column is averaged step-by-step over all
    ``*.csv`` files in ``log_dir``. An agent's i-th reward row in a file
    (file order) is treated as step ``i + 1``. Episodes shorter than the
    longest one are padded with NaN, so they only contribute to the steps
    they cover. ``agent_0`` .. ``agent_5`` are always plotted. Any other
    agent IDs found in the logs are added rather than raising ``KeyError``.
    Unreadable CSVs (e.g. one still being written) are logged and skipped.

    :param log_dir: Directory where CSV logs are stored.
    :param output_dir: Directory to save the generated plots.
    """
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving average reward plots to directory: {output_dir}")

    csv_files = sorted(f for f in os.listdir(log_dir) if f.endswith('.csv'))

    if not csv_files:
        logger.warning(f"No CSV files found in {log_dir} to plot average rewards.")
        return

    # agent_id -> list of per-episode reward lists
    agent_rewards_over_episodes = {f"agent_{i}": [] for i in range(6)}
    max_steps = 0

    for csv_file in csv_files:
        file_path = os.path.join(log_dir, csv_file)
        episode_rewards = {}
        try:
            with open(file_path, mode='r', newline='') as csvfile:
                for row in csv.DictReader(csvfile):
                    agent_id = row["agent_id"]
                    reward = float(row["reward"])
                    episode_rewards.setdefault(agent_id, []).append(reward)
        except Exception as e:
            # This is also called from the training loop while the logging
            # callback may be writing CSVs; one bad file must not abort training.
            logger.exception(f"Failed to read CSV {file_path}: {e}")
            continue

        for agent_id, rewards in episode_rewards.items():
            agent_rewards_over_episodes.setdefault(agent_id, []).append(rewards)
            max_steps = max(max_steps, len(rewards))

    # Calculate average rewards per step for each agent
    average_rewards = {}
    for agent_id, rewards_list in agent_rewards_over_episodes.items():
        # Pad shorter episodes with np.nan for averaging
        padded_rewards = np.full((len(rewards_list), max_steps), np.nan)
        for i, rewards in enumerate(rewards_list):
            padded_rewards[i, :len(rewards)] = rewards
        # Same result as np.nanmean(axis=0), but columns with no data become
        # NaN without a "Mean of empty slice" RuntimeWarning.
        counts = np.count_nonzero(~np.isnan(padded_rewards), axis=0)
        totals = np.nansum(padded_rewards, axis=0)
        average_rewards[agent_id] = np.where(counts > 0, totals / np.maximum(counts, 1), np.nan)

    # Convert steps to seconds
    seconds = steps_to_seconds(range(1, max_steps + 1))

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        for agent_id, avg_rewards in average_rewards.items():
            ax.plot(seconds, avg_rewards, label=agent_id)

        ax.set_xlabel("Seconds")
        ax.set_ylabel("Average Per-Step Reward")
        ax.set_title("Average Per-Step Rewards Across Episodes")
        ax.legend()
        ax.grid(True)

        # Save the plot
        plot_filename = "average_rewards_across_episodes.png"
        plot_path = os.path.join(output_dir, plot_filename)
        fig.savefig(plot_path)
        logger.info(f"Saved average reward plot to {plot_path}")
    finally:
        # Always release the figure; this runs repeatedly during training.
        plt.close(fig)

# ----------------------------
# Policy Mapping Function
# ----------------------------
def policy_mapping_fn(agent_id, *args, **kwargs):
    """
    Route an agent to the policy registered under the same ID.

    The mapping is the identity: the returned policy ID is ``agent_id``
    itself. Any extra positional or keyword arguments are accepted and
    ignored. This assumes agent IDs match the policy keys used in the
    training config (``agent_0`` .. ``agent_5`` in ``main``).
    """
    policy_id = agent_id
    return policy_id



def steps_to_seconds(steps, steps_per_second=15):
    """
    Convert simulation step numbers into elapsed time in seconds.

    :param steps: Step numbers (list, range, or anything ``np.array`` accepts).
    :param steps_per_second: How many steps make up one second (default 15).
    :return: NumPy array with one time value (in seconds) per input step.
    """
    step_array = np.array(steps)
    seconds = step_array / steps_per_second
    return seconds


# ----------------------------
# Main Training Function
# ----------------------------
def main():
    """
    Main function to initialize Ray, configure the PPO trainer, and execute the training loop.

    Training runs until ``result["timesteps_total"]`` reaches
    ``total_timesteps``. Each time another ``checkpoint_interval`` timesteps
    have been crossed, a checkpoint is saved and the average-reward plot is
    regenerated from the callback's CSV logs. On exit (normal completion,
    Ctrl-C, or an error, which is logged and re-raised), the trainer is
    stopped, Ray is shut down and a final average-reward plot is attempted.
    """
    # Initialize Ray if not already initialized
    if not ray.is_initialized():
        ray.init(ignore_reinit_error=True)
        logger.info("Ray initialized.")

    log_dir = "joint_logs_PPO"
    plot_dir = "average_reward_plots"

    config = {
        "env": MultiAgentUR5eEnv,
        "env_config": {},
        "multiagent": {
            "policies": {
                f"agent_{i}": (
                    None,
                    spaces.Box(low=np.array([-2*np.pi, -6]), high=np.array([2*np.pi, 6]), dtype=np.float32),  # Observation space
                    spaces.Box(low=-0.1, high=0.1, shape=(1,), dtype=np.float32),  # Action space (delta positions)
                    {},
                ) for i in range(6)
            },
            "policy_mapping_fn": policy_mapping_fn,
            "replay_mode": "independent",
        },
        "framework": "torch",  # or "tf"
        "num_workers": 1,
        "callbacks": JointPositionLogger,
        "callbacks_config": {
            "log_dir": log_dir,
        },
        "lr": 4e-5,
        "train_batch_size": 1000,
        "gamma": 0.9,
        # NOTE(review): the keys below look carried over from a SAC config
        # (tau, target_entropy, n_step, no_done_at_end, prioritized_replay,
        # optimization); they are not PPO options. Confirm PPO ignores them
        # and remove.
        "tau": 0.005,
        "target_entropy": "auto",
        "n_step": 1,
        "no_done_at_end": True,
        "prioritized_replay": False,
        "optimization": {
            "actor_learning_rate": 4e-5,
            "critic_learning_rate": 4e-5,
            "entropy_learning_rate": 4e-5,
        },
        "model": {
            "fcnet_hiddens": [128, 128],
            "fcnet_activation": "tanh",
        },
    }

    # Initialize trainer
    trainer = PPO(config=config)
    logger.info("PPO trainer initialized.")

    # Training loop parameters
    total_timesteps = 6_000_000
    checkpoint_interval = 20_000  # Save a checkpoint every 20,000 timesteps
    next_checkpoint = checkpoint_interval

    logger.info(f"Starting training for {total_timesteps} timesteps...")
    timesteps = 0
    try:
        while timesteps < total_timesteps:
            result = trainer.train()
            timesteps = result["timesteps_total"]
            logger.info(f"Completed {timesteps} timesteps")

            # timesteps_total need not land exactly on a multiple of the
            # interval, so checkpoint whenever the next threshold is crossed
            # (a `% interval == 0` test could skip checkpoints entirely).
            if timesteps >= next_checkpoint:
                checkpoint_path = trainer.save()
                logger.info(f"Checkpoint saved at {checkpoint_path}")

                # Optionally, plot average rewards across episodes
                plot_average_rewards(log_dir=log_dir, output_dir=plot_dir)

                # Jump past every threshold crossed by this iteration.
                next_checkpoint = (timesteps // checkpoint_interval + 1) * checkpoint_interval
    except KeyboardInterrupt:
        logger.info("Training interrupted by user.")
    except Exception as e:
        # logger.exception also writes the traceback to training.log.
        logger.exception(f"An error occurred during training: {e}")
        raise
    finally:
        logger.info("Training completed or interrupted. Cleaning up.")
        # stop() is the public Trainable API; it flushes loggers and calls cleanup().
        trainer.stop()
        ray.shutdown()
        logger.info("Ray shutdown.")

        # Best-effort final plot: a plotting failure here must not mask
        # an exception that is currently propagating out of training.
        try:
            plot_average_rewards(log_dir=log_dir, output_dir=plot_dir)
        except Exception:
            logger.exception("Failed to generate final average reward plot.")
# Execute the training only when this file is run as a script (not on import).
if __name__ == "__main__":
    main()
