# In [None]:
import gymnasium as gym
from gym import spaces
import numpy as np
import pybullet as p
import pybullet_data
import os


class UR3SortingEnv(gym.Env):
    """PyBullet environment in which a robot arm with a gripper sorts colored objects.

    The scene contains the robot loaded from ``urdf/ur5_rg2.urdf`` (relative to the
    current working directory), a table, a tray and three colored objects
    (red, green, blue).

    Joint indices follow the layout documented by ``rest_pose`` (which is indexed
    by joint index): joint 0 is the fixed base, joints 1-6 are the controllable arm
    joints, joint 7 is the fixed end-effector base and joints 8-13 are the finger
    joints.

    NOTE(review): ``reset`` returns only the observation and ``step`` returns the
    legacy 4-tuple ``(obs, reward, done, info)``. Gymnasium (and Stable-Baselines3
    >= 2.0) expect ``(obs, info)`` and ``(obs, reward, terminated, truncated, info)``.
    The shapes are kept for existing callers but must be migrated before training
    with SB3.
    """

    # Target position (world frame) used by the reward and termination checks.
    _TARGET_POSITION = np.array([0.6, 0.0, 0.1])

    def __init__(self):
        super().__init__()
        # Spaces must come from gymnasium (the library of the base class), not the
        # legacy ``gym`` package: SB3 rejects spaces that are not gymnasium types.
        from gymnasium import spaces as gym_spaces

        # Start from a clean physics-server state (e.g. when a notebook cell is re-run).
        if p.isConnected():
            p.disconnect()

        # Indices of the six controllable arm joints (joint 0 is the fixed base).
        self.arm_joint_indices = tuple(range(1, 7))

        # Action space: one normalized joint target in [-1, 1] per arm joint.
        self.action_space = gym_spaces.Box(low=-1, high=1, shape=(6,), dtype=np.float32)

        # Observation space: normalized arm joint angles + info about the tracked object.
        self.observation_space = gym_spaces.Dict({
            # 6 arm joint angles divided by pi.
            "robot_state": gym_spaces.Box(low=-1, high=1, shape=(6,), dtype=np.float32),
            # x, y, z world position (unbounded; y is sampled in [-0.2, 0.2] on reset)
            # followed by one color channel in [0, 1].
            "object_info": gym_spaces.Box(
                low=np.array([-np.inf, -np.inf, -np.inf, 0.0], dtype=np.float32),
                high=np.array([np.inf, np.inf, np.inf, 1.0], dtype=np.float32),
                dtype=np.float32,
            ),
        })

        # PyBullet settings
        self.client = p.connect(p.GUI)
        if self.client < 0:
            raise RuntimeError("Failed to connect to PyBullet physics server.")

        pb_data_path = pybullet_data.getDataPath()
        p.setAdditionalSearchPath(pb_data_path)
        p.setGravity(0, 0, -9.81)
        project_path = os.getcwd()

        # Load the scene.
        self.robot_uid = p.loadURDF(os.path.join(project_path, "urdf/ur5_rg2.urdf"), useFixedBase=True)
        self.table_uid = p.loadURDF(os.path.join(pb_data_path, "table/table.urdf"), basePosition=[0.5, 0, -0.65])
        self.tray_uid = p.loadURDF(os.path.join(pb_data_path, "tray/traybox.urdf"), basePosition=[0.65, 0, 0])
        self.objects = self._load_objects()

        self.joint_index_last = 13
        self.joint_index_endeffector_base = 7
        p.resetDebugVisualizerCamera(cameraDistance=1.5, cameraYaw=0, cameraPitch=-40, cameraTargetPosition=[0.55, -0.35, 0.2])

        # Initial pose, indexed by joint index (so rest_pose[j] is the target of joint j).
        self.rest_pose: np.ndarray = np.array((
            0.0,    # Base  (Fixed)
            0.0,    # Joint 1
            -2.094, # Joint 2
            1.57,   # Joint 3
            -1.047, # Joint 4
            -1.57,  # Joint 5
            0,      # Joint 6
            0.0,    # EE Base (Fixed)
            0.785,  # EE Finger
        ))
        self.reset()

    def _load_objects(self):
        """Load one object per color (red, green, blue) at random, well-separated positions.

        Returns:
            list of ``{"uid": <pybullet body id>, "color": [r, g, b]}`` dicts.
        """
        objects = []
        positions = []
        for color in ([1, 0, 0], [0, 1, 0], [0, 0, 1]):  # Red, Green, Blue
            # Rejection-sample a position more than 0.5 away from every object placed so far.
            # NOTE(review): this varies x and height (z) with y fixed at 0; reset() moves
            # every object afterwards, so these positions only hold until the first reset.
            while True:
                pos = [self.np_random.uniform(-0.6, 0.6), 0, self.np_random.uniform(-0.2, 0.2)]
                if all(np.linalg.norm(np.array(pos) - np.array(existing_pos)) > 0.5 for existing_pos in positions):
                    positions.append(pos)
                    break

            obj_uid = p.loadURDF("random_urdfs/000/000.urdf", basePosition=pos)
            p.changeVisualShape(obj_uid, -1, rgbaColor=color + [1])  # Set object color
            objects.append({"uid": obj_uid, "color": color})
        return objects

    def reset(self, *, seed=None, options=None):
        """Reset the arm to ``rest_pose`` and re-place the objects.

        Args:
            seed: optional seed forwarded to ``gymnasium.Env.reset``; it seeds
                ``self.np_random``, which drives the object placement below.
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            The observation dict from ``_get_observation`` (see the class NOTE).
        """
        super().reset(seed=seed)

        # Reset the arm joints to their rest positions.
        self.current_pose: np.ndarray = np.copy(self.rest_pose)
        for joint_index in self.arm_joint_indices:
            p.resetJointState(self.robot_uid, joint_index, self.current_pose[joint_index])

        self._finger_control(self.current_pose[self.joint_index_endeffector_base + 1])

        # Put every object at x=0.6, z=0.1 with a random lateral (y) offset.
        # NOTE(review): objects are not checked for overlap with each other here.
        for obj in self.objects:
            pos = [0.6, self.np_random.uniform(-0.2, 0.2), 0.1]
            p.resetBasePositionAndOrientation(obj["uid"], pos, [0, 0, 0, 1])

        return self._get_observation()

    def _finger_control(self, target):
        """Drive the gripper fingers toward ``target``.

        Joints ``ee_base+1`` (left) and ``ee_base+4`` (right) receive ``target``. The
        remaining finger joints then copy the current position/velocity of their driver
        joint, imitating ROS mimic joints.
        """
        driver_left = self.joint_index_endeffector_base + 1
        driver_right = self.joint_index_endeffector_base + 4
        for joint_index in (driver_left, driver_right):
            p.setJointMotorControl2(self.robot_uid, joint_index, p.POSITION_CONTROL,
                                    targetPosition=target)

        # Current (position, velocity) of each driver joint.
        left_pos, left_vel = p.getJointState(self.robot_uid, driver_left)[:2]
        right_pos, right_vel = p.getJointState(self.robot_uid, driver_right)[:2]

        # Propagate the driver state to the follower joints.
        followers = (
            (self.joint_index_endeffector_base + 2, left_pos, left_vel),
            (self.joint_index_endeffector_base + 3, left_pos, left_vel),
            (self.joint_index_endeffector_base + 5, right_pos, right_vel),
            (self.joint_index_endeffector_base + 6, right_pos, right_vel),
        )
        for joint_index, position, velocity in followers:
            p.setJointMotorControl2(self.robot_uid, joint_index, p.POSITION_CONTROL,
                                    targetPosition=position,
                                    targetVelocity=velocity,
                                    positionGain=1.2)

    def _get_observation(self):
        """Build the observation dict.

        Returns:
            ``{"robot_state": float32 (6,) arm joint angles divided by pi,
               "object_info": float32 (4,) base x, y, z of the first object
                              followed by its red color channel}``
        """
        angles = [p.getJointState(self.robot_uid, j)[0] for j in self.arm_joint_indices]
        robot_state = (np.asarray(angles) / np.pi).astype(np.float32)

        # Only the first object is tracked for now.
        tracked = self.objects[0]
        pos, _ = p.getBasePositionAndOrientation(tracked["uid"])
        object_info = np.array([*pos, tracked["color"][0]], dtype=np.float32)

        return {"robot_state": robot_state, "object_info": object_info}

    def step(self, action):
        """Apply one action and advance the simulation by a single step.

        Args:
            action: 6 values; clipped to [-1, 1] and scaled to joint targets in [-pi, pi].

        Returns:
            ``(obs, reward, done, info)`` (legacy 4-tuple; see the class NOTE).
        """
        joint_positions = np.clip(action, -1, 1) * np.pi
        for k, joint_index in enumerate(self.arm_joint_indices):
            p.setJointMotorControl2(self.robot_uid, joint_index, p.POSITION_CONTROL,
                                    targetPosition=joint_positions[k])
        p.stepSimulation()

        obs = self._get_observation()
        reward = self._compute_reward(obs)
        done = self._check_termination(obs)
        return obs, reward, done, {}

    def _compute_reward(self, obs):
        """Return minus the tracked object's distance to the target, +1 if its red channel > 0.8.

        NOTE(review): the tracked object is always the red one, so the +1 bonus is
        always granted, and the reward does not depend on the robot at all.
        """
        obj_pos = obs["object_info"][:3]
        reward = -np.linalg.norm(obj_pos - self._TARGET_POSITION)
        if obs["object_info"][3] > 0.8:  # red channel
            reward += 1
        return float(reward)

    def _check_termination(self, obs):
        """Return True when the tracked object is within 0.05 of the target position.

        NOTE(review): reset() places objects at the target's x and z, so this may
        trigger almost immediately when the sampled |y| is small.
        """
        obj_pos = obs["object_info"][:3]
        return bool(np.linalg.norm(obj_pos - self._TARGET_POSITION) < 0.05)

    def render(self, mode="human"):
        pass  # GUI rendering is handled by PyBullet

    def close(self):
        """Disconnect from the physics server; safe to call more than once."""
        if p.isConnected(physicsClientId=self.client):
            p.disconnect(self.client)

# In [None]:
from stable_baselines3 import PPO

# Smoke test: instantiate the sorting environment (connects to PyBullet in GUI mode).
env = UR3SortingEnv()
print("Environment initialized successfully!")
# PPO training is currently disabled.
# NOTE(review): UR3SortingEnv.reset returns only the observation and step returns a
# 4-tuple; Gymnasium-based SB3 expects (obs, info) and a 5-tuple - confirm before
# re-enabling the lines below.
# model = PPO("MultiInputPolicy", env, verbose=1)
# model.learn(total_timesteps=10000)
# model.save("ur3_sorting_model")


# Soft Q-Learning (SQL) baseline on CartPole


# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import count
from collections import deque
import random
from tensorboardX import SummaryWriter
from torch.distributions import Categorical
import gym
import numpy as np


# Compute on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class Memory(object):
    """Fixed-capacity FIFO replay buffer of arbitrary experience tuples.

    Once ``memory_size`` items are stored, each new item evicts the oldest one.
    """

    def __init__(self, memory_size: int) -> None:
        self.memory_size = memory_size
        self.buffer = deque(maxlen=memory_size)

    def add(self, experience) -> None:
        """Store one experience, evicting the oldest if the buffer is full."""
        self.buffer.append(experience)

    def size(self):
        """Return how many experiences are currently stored."""
        return len(self.buffer)

    def sample(self, batch_size: int, continuous: bool = True):
        """Return up to ``batch_size`` stored experiences as a list.

        ``batch_size`` is capped at the number of stored items. When ``continuous``
        is true, a contiguous run starting at a random offset (drawn with ``random``)
        is returned; otherwise distinct indices are drawn without replacement with
        ``np.random``.
        """
        stored = len(self.buffer)
        n = min(batch_size, stored)
        if not continuous:
            picks = np.random.choice(np.arange(stored), size=n, replace=False)
            return [self.buffer[idx] for idx in picks]
        start = random.randint(0, stored - n)
        return [self.buffer[idx] for idx in range(start, start + n)]

    def clear(self):
        """Drop every stored experience."""
        self.buffer.clear()


class SoftQNetwork(nn.Module):
    """MLP soft Q-network for discrete actions (Soft Q-Learning).

    Maps a state vector to one Q-value per action. The soft state value is
    ``V(s) = alpha * log(sum_a exp(Q(s, a) / alpha))`` and the policy is
    ``pi(a | s) = exp((Q(s, a) - V(s)) / alpha) = softmax(Q(s, .) / alpha)``.

    Args:
        state_dim: length of the input state vector (default 4, as for CartPole).
        action_dim: number of discrete actions (default 2, as for CartPole).
        alpha: entropy temperature (default 4).
    """

    def __init__(self, state_dim: int = 4, action_dim: int = 2, alpha: float = 4):
        super(SoftQNetwork, self).__init__()
        self.alpha = alpha
        # Layer names are kept stable so previously saved state_dicts still load.
        self.fc1 = nn.Linear(state_dim, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 256)
        self.fc3 = nn.Linear(256, action_dim)

    def forward(self, x):
        """Return Q-values of shape ``(batch, action_dim)`` for states ``x`` of shape ``(batch, state_dim)``."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

    def getV(self, q_value):
        """Return the soft value ``alpha * logsumexp(q / alpha)`` over dim 1, keeping that dim."""
        # logsumexp subtracts the max internally, so large Q-values cannot overflow exp().
        return self.alpha * torch.logsumexp(q_value / self.alpha, dim=1, keepdim=True)

    def choose_action(self, state):
        """Sample an action index from the Boltzmann policy ``softmax(Q(state) / alpha)``.

        Args:
            state: a single state (sequence or array of length ``state_dim``).

        Returns:
            int: the sampled action index.
        """
        # Put the input on whatever device the network's weights live on.
        param_device = next(self.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32, device=param_device).unsqueeze(0)
        with torch.no_grad():
            q = self.forward(state)
            # softmax(q / alpha) is exactly exp((q - V) / alpha) normalized, computed stably.
            action = Categorical(logits=q / self.alpha).sample()
        return action.item()


if __name__ == "__main__":
    # Train a Soft Q-Learning agent on CartPole-v1. The episode loop never ends on
    # its own; interrupt the process to stop it.
    env = gym.make('CartPole-v1')
    onlineQNetwork = SoftQNetwork().to(device)
    targetQNetwork = SoftQNetwork().to(device)
    targetQNetwork.load_state_dict(onlineQNetwork.state_dict())

    optimizer = torch.optim.Adam(onlineQNetwork.parameters(), lr=1e-4)

    GAMMA = 0.99
    REPLAY_MEMORY = 50000
    BATCH = 16
    UPDATE_STEPS = 4  # copy online -> target weights every UPDATE_STEPS learning steps

    memory_replay = Memory(REPLAY_MEMORY)
    writer = SummaryWriter('logs/sql')

    learn_steps = 0
    begin_learn = False

    try:
        for epoch in count():
            state, _ = env.reset()
            if state is None or len(state) != 4:
                print("State is invalid:", state)
            episode_reward = 0
            for time_steps in range(200):
                action = onlineQNetwork.choose_action(state)
                # (obs, reward, terminated, truncated, info): only `terminated` is used.
                next_state, reward, done, _, _ = env.step(action)
                episode_reward += reward
                if next_state is None or len(next_state) != 4:
                    print("Next state is invalid:", next_state)
                    break
                memory_replay.add((state, next_state, action, reward, done))
                state = next_state

                # Start learning once the buffer holds more than 128 transitions.
                if memory_replay.size() > 128:
                    if not begin_learn:
                        print('learn begin!')
                        begin_learn = True
                    learn_steps += 1
                    if learn_steps % UPDATE_STEPS == 0:
                        targetQNetwork.load_state_dict(onlineQNetwork.state_dict())
                    batch = memory_replay.sample(BATCH, False)
                    batch_state, batch_next_state, batch_action, batch_reward, batch_done = zip(*batch)

                    # Stack into single ndarrays first: building a tensor from a sequence of
                    # ndarrays is very slow in PyTorch.
                    batch_state = torch.as_tensor(np.asarray(batch_state, dtype=np.float32), device=device)
                    batch_next_state = torch.as_tensor(np.asarray(batch_next_state, dtype=np.float32), device=device)
                    batch_action = torch.as_tensor(np.asarray(batch_action, dtype=np.int64), device=device).unsqueeze(1)
                    batch_reward = torch.as_tensor(np.asarray(batch_reward, dtype=np.float32), device=device).unsqueeze(1)
                    batch_done = torch.as_tensor(np.asarray(batch_done, dtype=np.float32), device=device).unsqueeze(1)

                    with torch.no_grad():
                        next_q = targetQNetwork(batch_next_state)
                        next_v = targetQNetwork.getV(next_q)
                        # Soft Bellman target; no bootstrap from terminal next states.
                        y = batch_reward + (1 - batch_done) * GAMMA * next_v

                    loss = F.mse_loss(onlineQNetwork(batch_state).gather(1, batch_action), y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    writer.add_scalar('loss', loss.item(), global_step=learn_steps)

                if done:
                    break

            writer.add_scalar('episode reward', episode_reward, global_step=epoch)
            if epoch % 10 == 0:
                torch.save(onlineQNetwork.state_dict(), 'sql-policy.para')
                # This is the reward of the latest episode, not a moving average.
                print('Ep {}\tEpisode reward: {:.2f}\t'.format(epoch, episode_reward))
    finally:
        # The loop only exits via interruption or an error: flush logs and release the env.
        writer.close()
        env.close()






