In [1]:
import os
import time
import math
import numpy as np
import gymnasium as gym
from gymnasium import spaces

import pybullet as p
import pybullet_data

import os
import gymnasium as gym
from gymnasium.wrappers import RecordVideo

import multiprocessing as mp

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor

pybullet build time: Jan 29 2025 23:16:28


In [2]:
class PandaPickCubeEnv(gym.Env):
    """
    Franka Panda end-effector control with IK.
    Task: pick up a cuboid (random size) from a table and move it to a target on the table.

    Notes:
      - CPU-based physics (PyBullet) -> scale with SubprocVecEnv for throughput.
      - Observation includes cube size to allow generalization across sizes.

    Action (Box, shape (4,), each component in [-1, 1]):
      - [0:3] end-effector position delta, scaled by ``ee_delta_scale``.
      - [3]   gripper command: negative closes, positive opens.

    Observation (Box, shape (19,), float32), concatenated in this order:
      ee_pos(3), ee_vel(3), cube_pos(3), cube_vel(3), target_pos(3),
      cube full extents(3), gripper opening(1).
    """
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 60}

    def __init__(
        self,
        render_mode=None,
        episode_len=200,
        sim_steps_per_action=8,
        cube_size_range=(0.02, 0.06),  # half-extent range in meters
        target_xy_range=0.15,
        seed=None,
    ):
        """
        Store the configuration and define the spaces; no simulator is started here.

        Args:
            render_mode: "human" opens the PyBullet GUI, "rgb_array" enables
                ``render()``; anything else runs headless without rendering.
            episode_len: number of ``step`` calls after which the episode is truncated.
            sim_steps_per_action: physics steps simulated per ``step`` call.
            cube_size_range: (low, high) bounds used to sample each cube half-extent.
            target_xy_range: the target y coordinate is sampled from
                [-target_xy_range, target_xy_range].
            seed: seed for the internal NumPy random generator.
        """
        super().__init__()
        self.render_mode = render_mode
        self.episode_len = episode_len
        self.sim_steps_per_action = sim_steps_per_action
        self.cube_size_range = cube_size_range
        self.target_xy_range = target_xy_range

        self._rng = np.random.default_rng(seed)

        # Action: delta EE xyz + gripper command
        # dx,dy,dz in [-1,1] scaled internally, grip in [-1,1] (close/open)
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)

        # Observation: ee_pos(3), ee_vel(3), cube_pos(3), cube_vel(3), target_pos(3), cube_size(3), gripper_width(1)
        obs_dim = 3 + 3 + 3 + 3 + 3 + 3 + 1
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)

        # Physics client handle; created lazily on the first reset().
        self._client = None
        self._step_count = 0

        # IDs
        self.panda = None
        self.cube_id = None
        self.table_id = None

        # Panda joints
        self.arm_joint_indices = [0, 1, 2, 3, 4, 5, 6]
        self.finger_joint_indices = [9, 10]  # depends on URDF; works for standard panda in pybullet_data
        # Link used as the end effector for state readout and IK.
        # NOTE(review): originally annotated as "panda_hand"; confirm which link
        # index 11 is in the URDF actually being loaded.
        self.ee_link_index = 11
        self.max_gripper_opening = 0.08  # approximate

        # Goal
        self.target_pos = None

        # Half-extents of the current cube; sampled on every scene load.
        self.cube_half_extents = None

        # Internal control scaling
        self.ee_delta_scale = np.array([0.02, 0.02, 0.02], dtype=np.float32)  # meters per action
        self.ee_min = np.array([0.25, -0.3, 0.05], dtype=np.float32)
        self.ee_max = np.array([0.75, 0.3, 0.55], dtype=np.float32)

    def _configure_physics(self):
        """Clear the simulation and re-apply search path, gravity and time step."""
        p.resetSimulation(physicsClientId=self._client)
        p.setAdditionalSearchPath(pybullet_data.getDataPath(), physicsClientId=self._client)
        p.setGravity(0, 0, -9.81, physicsClientId=self._client)
        p.setTimeStep(1.0 / 240.0, physicsClientId=self._client)

    def _connect(self):
        """Connect to PyBullet once (GUI for "human", DIRECT otherwise); no-op if already connected."""
        if self._client is not None:
            return
        if self.render_mode == "human":
            self._client = p.connect(p.GUI)
        else:
            self._client = p.connect(p.DIRECT)

        self._configure_physics()

    def _load_scene(self):
        """
        Rebuild the whole scene: plane, table, robot at its home pose with an open
        gripper, a randomly sized/placed cube resting on the table, and a new target.
        """
        self._configure_physics()

        p.loadURDF("plane.urdf", physicsClientId=self._client)

        # Table
        table_urdf = os.path.join(pybullet_data.getDataPath(), "table/table.urdf")
        self.table_id = p.loadURDF(table_urdf, basePosition=[0.5, 0.0, -0.65], useFixedBase=True, physicsClientId=self._client)

        # Panda
        self.panda = p.loadURDF(
            "franka_panda/panda.urdf",
            basePosition=[0.0, 0.0, 0.0],
            useFixedBase=True,
            physicsClientId=self._client,
        )

        # Set damping on every link of the robot
        for j in range(p.getNumJoints(self.panda, physicsClientId=self._client)):
            p.changeDynamics(self.panda, j, linearDamping=0.04, angularDamping=0.04, physicsClientId=self._client)

        # Default pose
        home = [0.0, -0.6, 0.0, -2.2, 0.0, 1.6, 0.8]
        for idx, q in zip(self.arm_joint_indices, home):
            p.resetJointState(self.panda, idx, q, physicsClientId=self._client)

        # Open gripper. A motor target alone does nothing until the simulation is
        # stepped, so also set the joint state directly; otherwise the first
        # observation would report a closed gripper.
        finger_open = self.max_gripper_opening / 2.0
        for j in self.finger_joint_indices:
            p.resetJointState(self.panda, j, finger_open, physicsClientId=self._client)
        self._set_gripper(opening=self.max_gripper_opening)

        # Spawn cube with randomized half-extents (variable cuboid size)
        hx = float(self._rng.uniform(*self.cube_size_range))
        hy = float(self._rng.uniform(*self.cube_size_range))
        hz = float(self._rng.uniform(*self.cube_size_range))
        self.cube_half_extents = np.array([hx, hy, hz], dtype=np.float32)

        cube_col = p.createCollisionShape(p.GEOM_BOX, halfExtents=[hx, hy, hz], physicsClientId=self._client)
        cube_vis = p.createVisualShape(p.GEOM_BOX, halfExtents=[hx, hy, hz], rgbaColor=[0.8, 0.2, 0.2, 1.0], physicsClientId=self._client)

        cube_x = float(self._rng.uniform(0.45, 0.65))
        cube_y = float(self._rng.uniform(-0.15, 0.15))
        # Place the cube so its bottom face sits just above the table top. A fixed
        # centre height would put tall cuboids inside the table, because the
        # half-height is randomized.
        table_top_z = float(p.getAABB(self.table_id, physicsClientId=self._client)[1][2])
        cube_z = table_top_z + hz + 0.002
        self.cube_id = p.createMultiBody(
            baseMass=0.2,
            baseCollisionShapeIndex=cube_col,
            baseVisualShapeIndex=cube_vis,
            basePosition=[cube_x, cube_y, cube_z],
            baseOrientation=p.getQuaternionFromEuler([0, 0, float(self._rng.uniform(-math.pi, math.pi))]),
            physicsClientId=self._client,
        )
        p.changeDynamics(self.cube_id, -1, lateralFriction=1.0, spinningFriction=0.01, rollingFriction=0.01, physicsClientId=self._client)

        # Sample target on table (xy random, z fixed)
        tx = float(self._rng.uniform(0.45, 0.65))
        ty = float(self._rng.uniform(-self.target_xy_range, self.target_xy_range))
        tz = 0.02
        self.target_pos = np.array([tx, ty, tz], dtype=np.float32)

    def _set_gripper(self, opening: float):
        """Command both fingers (position control) towards a total opening, clipped to [0, max_gripper_opening]."""
        opening = float(np.clip(opening, 0.0, self.max_gripper_opening))
        # Each finger moves half the opening
        finger_pos = opening / 2.0
        for j in self.finger_joint_indices:
            p.setJointMotorControl2(
                self.panda, j,
                p.POSITION_CONTROL,
                targetPosition=finger_pos,
                force=80,
                physicsClientId=self._client
            )

    def _get_gripper_opening(self):
        """Return the current total opening: the sum of the two finger joint positions."""
        left, right = (
            p.getJointState(self.panda, j, physicsClientId=self._client)[0]
            for j in self.finger_joint_indices
        )
        return float(left + right)

    def _get_ee_state(self):
        """Return (position, linear velocity) of the end-effector link as float32 arrays."""
        link = p.getLinkState(self.panda, self.ee_link_index, computeLinkVelocity=1, physicsClientId=self._client)
        pos = np.array(link[4], dtype=np.float32)  # world position of the link frame
        vel = np.array(link[6], dtype=np.float32)  # world linear velocity
        return pos, vel

    def _ik_to(self, ee_pos, ee_orn=None):
        """
        Solve inverse kinematics for the end-effector link.

        Args:
            ee_pos: target position as a NumPy array (converted with ``tolist``).
            ee_orn: target orientation quaternion; defaults to the gripper pointing down.

        Returns:
            The joint positions returned by ``p.calculateInverseKinematics``.
        """
        if ee_orn is None:
            # fixed orientation: gripper pointing down
            ee_orn = p.getQuaternionFromEuler([math.pi, 0, 0])
        joint_poses = p.calculateInverseKinematics(
            self.panda,
            self.ee_link_index,
            ee_pos.tolist(),
            ee_orn,
            maxNumIterations=50,
            residualThreshold=1e-4,
            physicsClientId=self._client,
        )
        return joint_poses

    def _apply_action(self, action):
        """Turn a 4-vector action into arm joint targets (via IK) and a gripper target."""
        # Enforce the declared action bounds so callers that bypass the usual
        # clipping cannot command arbitrarily large motions.
        action = np.clip(np.asarray(action, dtype=np.float32), -1.0, 1.0)
        dxyz = action[:3] * self.ee_delta_scale
        grip_cmd = float(action[3])

        ee_pos, _ = self._get_ee_state()
        # Keep the commanded end-effector position inside the workspace box.
        target_ee = np.clip(ee_pos + dxyz, self.ee_min, self.ee_max)

        joint_poses = self._ik_to(target_ee)

        # The IK solution is ordered by movable joint; the arm joints come first,
        # so pairing by order matches each arm joint with its target.
        for joint_index, joint_target in zip(self.arm_joint_indices, joint_poses):
            p.setJointMotorControl2(
                self.panda, joint_index,
                p.POSITION_CONTROL,
                targetPosition=float(joint_target),
                force=200,
                physicsClientId=self._client
            )

        # Gripper: grip_cmd < 0 close, > 0 open
        current_open = self._get_gripper_opening()
        delta_open = 0.01 * grip_cmd
        self._set_gripper(current_open + delta_open)

    def _get_obs(self):
        """Assemble the flat float32 observation vector (layout in the class docstring)."""
        ee_pos, ee_vel = self._get_ee_state()
        cube_pos, cube_orn = p.getBasePositionAndOrientation(self.cube_id, physicsClientId=self._client)
        cube_vel_lin, cube_vel_ang = p.getBaseVelocity(self.cube_id, physicsClientId=self._client)

        cube_pos = np.array(cube_pos, dtype=np.float32)
        cube_vel = np.array(cube_vel_lin, dtype=np.float32)

        target = self.target_pos.astype(np.float32)
        cube_size = (2.0 * self.cube_half_extents).astype(np.float32)  # full extents
        grip_open = np.array([self._get_gripper_opening()], dtype=np.float32)

        obs = np.concatenate([ee_pos, ee_vel, cube_pos, cube_vel, target, cube_size, grip_open], axis=0)
        return obs

    def _compute_reward_done_info(self):
        """
        Compute the shaped reward and the episode-end flags for the current state.

        Returns:
            (reward, terminated, truncated, info) where ``terminated`` means success,
            ``truncated`` means the step budget is used up, and ``info`` holds
            "success", "d_ee_cube" and "d_cube_goal".
        """
        ee_pos, _ = self._get_ee_state()
        cube_pos, _ = p.getBasePositionAndOrientation(self.cube_id, physicsClientId=self._client)
        cube_pos = np.array(cube_pos, dtype=np.float32)

        # Distances
        d_ee_cube = float(np.linalg.norm(ee_pos - cube_pos))
        d_cube_goal = float(np.linalg.norm(cube_pos - self.target_pos))

        # "Lifted" heuristic
        lifted = cube_pos[2] > 0.08

        # Dense reward shaping
        reward = -1.0 * d_ee_cube - 2.0 * d_cube_goal
        if lifted:
            reward += 1.0

        # Success: cube close to the target and back down at low height.
        # Cast to a plain bool so ``info`` never carries a NumPy scalar.
        success = bool((d_cube_goal < 0.05) and (cube_pos[2] < 0.06))
        if success:
            reward += 10.0

        terminated = success
        truncated = bool(self._step_count >= self.episode_len)

        info = {"success": success, "d_ee_cube": d_ee_cube, "d_cube_goal": d_cube_goal}
        return reward, terminated, truncated, info

    def reset(self, *, seed=None, options=None):
        """
        Start a new episode: (re)connect if needed, rebuild the scene, zero the step counter.

        Args:
            seed: if given, also re-seeds the internal random generator.
            options: accepted for API compatibility; not used.

        Returns:
            (observation, info) with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            self._rng = np.random.default_rng(seed)
        self._connect()
        self._load_scene()
        self._step_count = 0

        obs = self._get_obs()
        info = {}
        return obs, info

    def step(self, action):
        """Apply ``action``, advance physics by ``sim_steps_per_action`` steps, and return the transition."""
        self._apply_action(action)
        for _ in range(self.sim_steps_per_action):
            p.stepSimulation(physicsClientId=self._client)
            if self.render_mode == "human":
                # Slow the GUI down to one physics time step per sleep.
                time.sleep(1.0 / 240.0)

        self._step_count += 1
        obs = self._get_obs()
        reward, terminated, truncated, info = self._compute_reward_done_info()
        return obs, reward, terminated, truncated, info

    def render(self):
        """
        Return an RGB frame of shape (540, 960, 3), dtype uint8, in "rgb_array" mode.

        Returns None for every other render mode.
        """
        if self.render_mode != "rgb_array":
            return None
        # Simple camera view
        view = p.computeViewMatrixFromYawPitchRoll(
            cameraTargetPosition=[0.55, 0.0, 0.1],
            distance=0.8,
            yaw=45,
            pitch=-35,
            roll=0,
            upAxisIndex=2,
            physicsClientId=self._client,
        )
        proj = p.computeProjectionMatrixFOV(fov=60, aspect=16/9, nearVal=0.01, farVal=3.0)
        w, h, rgba, _, _ = p.getCameraImage(960, 540, viewMatrix=view, projectionMatrix=proj, physicsClientId=self._client)
        frame = np.array(rgba, dtype=np.uint8).reshape(h, w, 4)
        # PyBullet delivers RGBA; the "rgb_array" contract is three channels,
        # so drop the alpha channel.
        return np.ascontiguousarray(frame[:, :, :3])

    def close(self):
        """Disconnect from the physics server if connected; safe to call more than once."""
        if self._client is not None:
            p.disconnect(physicsClientId=self._client)
            self._client = None


In [3]:
def make_env(rank: int, seed: int = 0):
    """
    Build a zero-argument factory for one headless environment.

    Args:
        rank: index of the worker; added to ``seed`` so each worker gets its own seed.
        seed: base seed shared by all workers.

    Returns:
        A callable that constructs ``PandaPickCubeEnv(render_mode=None, seed=seed + rank)``
        when invoked.
    """
    def _thunk():
        # Construction is deferred until the factory is actually called.
        return PandaPickCubeEnv(render_mode=None, seed=seed + rank)

    return _thunk

In [5]:
# Train PPO on several environments running in parallel worker processes.
# force=True allows this cell to be re-run after the start method was already set.
mp.set_start_method("forkserver", force=True)

# One worker per CPU core, capped at 32.
n_envs = min(32, mp.cpu_count())

env = SubprocVecEnv([make_env(i) for i in range(n_envs)])
try:
    env = VecMonitor(env)

    model = PPO(
        "MlpPolicy",
        env,
        device="cpu",
        verbose=1,
        n_steps=1024,
        batch_size=4096,
        n_epochs=10,
        learning_rate=3e-4,
        gamma=0.99,
    )

    model.learn(total_timesteps=2_000_000)
    model.save("ppo_panda_pick")
finally:
    # Always shut the worker processes down, even if training is interrupted
    # or fails, so they are not left running in the background.
    env.close()

pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 23:16:28
pybullet build time: Jan 29 2025 2

Using cpu device
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 180      |
|    ep_rew_mean     | -142     |
| time/              |          |
|    fps             | 4344     |
|    iterations      | 1        |
|    time_elapsed    | 7        |
|    total_timesteps | 30720    |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 178          |
|    ep_rew_mean          | -128         |
| time/                   |              |
|    fps                  | 4218         |
|    iterations           | 2            |
|    time_elapsed         | 14           |
|    total_timesteps      | 61440        |
| train/                  |              |
|    approx_kl            | 0.0050018444 |
|    clip_fraction        | 0.0303       |
|    clip_range           | 0.2          |
|    entropy_loss         | -5.67        |
|    explained_variance   | 0.000636     

In [7]:
# IMPORTANT: evaluation env must be single-process for video
env = PandaPickCubeEnv(render_mode="rgb_array")

# Record only the first episode to ./rl_videos.
env = RecordVideo(
    env,
    video_folder="./rl_videos",
    episode_trigger=lambda ep: ep == 0,
    name_prefix="panda-pick",
    disable_logger=False,
)

try:
    model = PPO.load("ppo_panda_pick", device="cpu")

    # Roll out one episode with the deterministic policy.
    obs, _ = env.reset()
    done = False
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
finally:
    # Closing the wrapper must happen even if the rollout fails, so the
    # recorder and the physics connection are released.
    env.close()

print("Saved video(s) to ./rl_videos")

  logger.warn(


MoviePy - Building video /home/shreyak/rl_videos/panda-pick-episode-0.mp4.
MoviePy - Writing video /home/shreyak/rl_videos/panda-pick-episode-0.mp4



                                                                                                                                                                                      

MoviePy - Done !
MoviePy - video ready /home/shreyak/rl_videos/panda-pick-episode-0.mp4
Saved video(s) to ./rl_videos
