In [None]:
# Had to export this before starting jupyter server
# MUJOCO_GL is read when `mujoco` is first imported, so it must be set before any import that
# (presumably) pulls in mujoco, e.g. gym_lite6.env. It has no effect if mujoco is already loaded.
%env MUJOCO_GL=egl
import gymnasium as gym
import numpy as np
import mediapy as media
import mujoco
import gym_lite6.env, gym_lite6.scripted_policy, gym_lite6.pickup_task
import gym_lite6.utils


In [None]:
from importlib import reload

# Pick up edits to the package without restarting the kernel.
# `reload` only re-executes the given module; modules that imported names from it keep the old
# objects until they are reloaded too. So leaf helpers (utils) are reloaded first, presumably
# ahead of the modules that depend on them. Adjust the order if the package's import graph differs.
reload(gym_lite6.utils)
reload(gym_lite6.env)
reload(gym_lite6.scripted_policy)
reload(gym_lite6.pickup_task)

# Task definition: gripper finger bodies, target object, and the floor body it rests on
task = gym_lite6.pickup_task.GraspTask('gripper_left_finger', 'gripper_right_finger', 'box', 'floor')
# task = gym_lite6.pickup_task.GraspAndLiftTask('gripper_left_finger', 'gripper_right_finger', 'box', 'floor')

env = gym.make(
    "UfactoryCubePickup-v0",
    task=task,
    obs_type="pixels_state",
    action_type="qvel",
    max_episode_steps=100,  # episodes are truncated after this many env.step calls
    visualization_width=320,
    visualization_height=240,
    render_fps=10
)

# Sanity check: reset once and show the side camera view
observation, info = env.reset()
media.show_image(env.unwrapped.render(camera="side_cam"))


In [None]:
# Fixed initial conditions for the tuning rollout below.
# Setting any of these to None (presumably) lets env.reset() choose that value itself.
qpos0 = np.array([0, 0.541, 1.49, 2.961, 0.596, 0.203])  # initial positions of the 6 arm joints
box_pos0 = np.array([0.2, 0, 0.0])  # initial box position
box_quat0 = None  # box orientation left to env.reset()

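## Run a scripted rollout
Important hyperparameters:
- `kv` (actuator gain in the model XML): higher values make the actuators track their setpoint more precisely.
- `Kp`, `Kp_ang`: proportional gains applied to the end-effector position and orientation errors.
- `damping` (the damping passed to `solve_ik_vel`): higher values give smoother motion and less erratic behaviour near singularities.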

In [None]:
# Run a scripted rollout with a task-space proportional controller:
# each step the scripted policy proposes a target end-effector pose, the pose error is scaled
# into an end-effector linear/angular velocity, and damped IK turns that into joint velocities.

policy = gym_lite6.scripted_policy.GraspPolicy(env, 'end_effector', 'box', 'gripper_left_finger', 'gripper_right_finger', max_vel=0.3)
# policy = gym_lite6.scripted_policy.GraspAndLiftPolicy(env, 'end_effector', 'box', 'gripper_left_finger', 'gripper_right_finger', max_vel=0.2)

# Reset the policy and environment to prepare for the rollout
policy.reset()
observation, info = env.reset(seed=69, qpos=qpos0, box_pos=box_pos0, box_quat=box_quat0)

# Controller gains (kept as notebook-level names so later cells can reuse them)
Kp = 0.4
Kp_ang = 0.4
ik_damping = 4e-4  # higher -> smoother motion / less erratic behaviour near singularities

# Scaling an error by the control rate turns "fraction of the error to close this step" into a
# velocity, assuming one env.step spans 1/render_fps seconds.
fps = env.metadata["render_fps"]

step = 0
done = False

ep_dict = {"action.qpos": [], "action.qvel": [], "action.gripper": [],"action.ee_vel": [], "action.ee_ang_vel": [],
           "observation.state.qpos": [], "observation.state.qvel": [], "observation.state.gripper": [], "observation.pixels.side": [], "observation.pixels.gripper": [],
           "observation.ee_pose.pos": [], "observation.ee_pose.quat": [], "observation.ee_pose.vel": [], "observation.ee_pose.ang_vel": [],
           "reward": [], "timestamp": [], "frame_index": [],
           }

while not done:
  action = policy(env.unwrapped.model, env.unwrapped.data, observation, info)

  # Position error between the policy's target and the current end-effector position
  delta = action["pos"] - observation["ee_pose"]["pos"]

  # Orientation error q_err = q_target * conj(q_current), converted to a rotation vector
  # (axis * angle, via quat2Vel with dt=1)
  curr_quat_conj = np.empty(4)
  quat_err = np.empty(4)
  ang_delta = np.empty(3)
  mujoco.mju_negQuat(curr_quat_conj, observation["ee_pose"]["quat"])
  mujoco.mju_mulQuat(quat_err, action["quat"], curr_quat_conj)
  mujoco.mju_quat2Vel(ang_delta, quat_err, 1.0)

  # Commanded end-effector velocities (world frame)
  vel = Kp * delta * fps
  ang_vel = Kp_ang * ang_delta * fps
  # Damped IK: end-effector twist -> joint velocities (local=False: twist given in world frame)
  action["qvel"] = env.unwrapped.solve_ik_vel(vel, ang_vel, ref_frame='end_effector', local=False, damping=ik_damping)

  observation, reward, terminated, truncated, info = env.step(action)
  # Done when the success state is reached (terminated) or the step limit is hit (truncated)
  done = terminated or truncated

  step_record = {
      "action.qpos": action["qpos"],
      "action.qvel": action["qvel"],
      "action.ee_vel": vel,
      "action.ee_ang_vel": ang_vel,
      "action.gripper": action["gripper"],
      "observation.state.qpos": observation["state"]["qpos"],
      "observation.state.qvel": observation["state"]["qvel"],
      "observation.state.gripper": observation["state"]["gripper"],
      "observation.pixels.side": observation["pixels"]["side"],
      "observation.pixels.gripper": observation["pixels"]["gripper"],
      "observation.ee_pose.pos": observation["ee_pose"]["pos"],
      "observation.ee_pose.quat": observation["ee_pose"]["quat"],
      "observation.ee_pose.vel": observation["ee_pose"]["vel"],
      "observation.ee_pose.ang_vel": observation["ee_pose"]["ang_vel"],
      "reward": reward,
      "timestamp": env.unwrapped.data.time,
      "frame_index": step,
  }
  for key, value in step_record.items():
    ep_dict[key].append(value)

  print(f"{step=} {reward=} {terminated=}")

  step += 1

if terminated:
  print("Success!")
else:
  print(f"Failure! Reached {policy.stage}")

media.show_video(ep_dict["observation.pixels.side"], fps=fps)
media.show_video(ep_dict["observation.pixels.gripper"], fps=fps)


In [None]:
print(ep_dict.keys())

gym_lite6.utils.plot_dict_of_arrays(ep_dict, "timestamp", keys=["action.ee_vel", "action.ee_ang_vel", "action.qvel", "observation.state.qpos", "observation.state.qvel", "observation.state.gripper", "reward"], sharey=False)


## Record episodes to a Hugging Face (HF) dataset

In [None]:
import lerobot.common.datasets.push_dataset_to_hub.utils
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
import os
import datetime

# Per-frame columns recorded for every control step, in dataset column order.
_EPISODE_KEYS = (
    "action.qpos", "action.qvel", "action.gripper", "action.ee_vel", "action.ee_ang_vel",
    "observation.state.qpos", "observation.state.qvel", "observation.state.gripper",
    "observation.pixels.side", "observation.pixels.gripper",
    "observation.ee_pose.pos", "observation.ee_pose.quat", "observation.ee_pose.vel", "observation.ee_pose.ang_vel",
    "reward", "timestamp", "frame_index",
)


def _hf_episode_features():
  """Return the Hugging Face column schema (name -> feature) for the recorded dataset."""
  def vec(length):
    # Fixed-length float32 vector column
    return Sequence(length=length, feature=Value(dtype="float32", id=None))

  return {
      "observation.pixels.side": Image(),
      "observation.pixels.gripper": Image(),
      "observation.state.qpos": vec(6),
      "observation.state.qvel": vec(6),
      "observation.state.gripper": Value(dtype="int8", id=None),
      "observation.ee_pose.pos": vec(3),
      "observation.ee_pose.quat": vec(4),
      "observation.ee_pose.vel": vec(3),
      "observation.ee_pose.ang_vel": vec(3),
      # NOTE(review): reward is stored as int8 -- confirm the task's reward is always integral.
      "reward": Value(dtype="int8", id=None),
      "action.qpos": vec(6),
      "action.qvel": vec(6),
      "action.ee_vel": vec(3),
      "action.ee_ang_vel": vec(3),
      "action.gripper": Value(dtype="int8", id=None),
      "episode_index": Value(dtype="int64", id=None),  # which episode
      "frame_index": Value(dtype="int64", id=None),    # which frame within the episode
      "timestamp": Value(dtype="float32", id=None),
      "index": Value(dtype="int64", id=None),          # which frame in the whole dataset
  }


def _ee_velocity_command(env, observation, action, Kp, Kp_ang, ik_damping=None):
  """Proportional task-space controller.

  Turns the policy's target end-effector pose (action["pos"], action["quat"]) into a world-frame
  end-effector velocity and angular velocity, then into joint velocities via damped IK.

  Returns:
    (vel, ang_vel, qvel)
  """
  # Scaling by the control rate assumes one env.step spans 1/render_fps seconds.
  fps = env.metadata["render_fps"]
  delta = action["pos"] - observation["ee_pose"]["pos"]

  # Orientation error q_err = q_target * conj(q_current) as a rotation vector (quat2Vel, dt=1)
  curr_quat_conj = np.empty(4)
  quat_err = np.empty(4)
  ang_delta = np.empty(3)
  mujoco.mju_negQuat(curr_quat_conj, observation["ee_pose"]["quat"])
  mujoco.mju_mulQuat(quat_err, action["quat"], curr_quat_conj)
  mujoco.mju_quat2Vel(ang_delta, quat_err, 1.0)

  vel = Kp * delta * fps
  ang_vel = Kp_ang * ang_delta * fps
  # Only forward damping when given, so the default keeps solve_ik_vel's own default.
  ik_kwargs = {} if ik_damping is None else {"damping": ik_damping}
  qvel = env.unwrapped.solve_ik_vel(vel, ang_vel, ref_frame='end_effector', local=False, **ik_kwargs)
  return vel, ang_vel, qvel


def _frame_record(action, vel, ang_vel, observation, reward, timestamp, frame_index):
  """Flatten one control step into the per-frame columns listed in _EPISODE_KEYS."""
  return {
      "action.qpos": action["qpos"],
      "action.qvel": action["qvel"],
      "action.gripper": action["gripper"],
      "action.ee_vel": vel,
      "action.ee_ang_vel": ang_vel,
      "observation.state.qpos": observation["state"]["qpos"],
      "observation.state.qvel": observation["state"]["qvel"],
      "observation.state.gripper": observation["state"]["gripper"],
      "observation.pixels.side": observation["pixels"]["side"],
      "observation.pixels.gripper": observation["pixels"]["gripper"],
      "observation.ee_pose.pos": observation["ee_pose"]["pos"],
      "observation.ee_pose.quat": observation["ee_pose"]["quat"],
      "observation.ee_pose.vel": observation["ee_pose"]["vel"],
      "observation.ee_pose.ang_vel": observation["ee_pose"]["ang_vel"],
      "reward": reward,
      "timestamp": timestamp,
      "frame_index": frame_index,
  }


def record_episodes_to_hf(env, policy, dataset_dir, num_ep=1, num_frames=300,
                          Kp=0.4, Kp_ang=0.4, ik_damping=None, max_attempts=None, seed=None):
  """Roll out the scripted policy and save the successful episodes as a Hugging Face dataset.

  Each attempt resets the env (initial state chosen by env.reset) and runs exactly `num_frames`
  control steps. The episode is kept only if, after the last step, `policy.done` is true and the
  env reports `terminated`. Failed attempts are discarded and retried.

  Args:
    env: gymnasium env created with obs_type="pixels_state".
    policy: scripted policy exposing reset(), __call__(model, data, obs, info) and `.done`.
    dataset_dir: directory the dataset is written into (created if missing).
    num_ep: number of successful episodes to collect.
    num_frames: control steps recorded per episode (must be >= 1).
    Kp, Kp_ang: proportional gains on the position / orientation error. They previously came
      from notebook globals; the defaults match the tuning rollout.
    ik_damping: damping forwarded to `solve_ik_vel`; None keeps that method's default.
      (The tuning rollout uses 4e-4.)
    max_attempts: optional cap on total rollouts; None retries until `num_ep` successes.
    seed: optional seed passed to the first env.reset() to make the run reproducible.

  Returns:
    The created `datasets.Dataset`. It is also saved under `dataset_dir`.

  Raises:
    ValueError: if num_frames < 1.
    RuntimeError: if `max_attempts` rollouts are used before collecting `num_ep` successes.
  """
  if num_frames < 1:
    raise ValueError(f"num_frames must be >= 1, got {num_frames}")

  os.makedirs(dataset_dir, exist_ok=True)

  data_dict = {key: [] for key in _EPISODE_KEYS + ("episode_index", "index")}
  successful_trajectories = 0
  attempts = 0
  while successful_trajectories < num_ep:
    if max_attempts is not None and attempts >= max_attempts:
      raise RuntimeError(
          f"Only {successful_trajectories}/{num_ep} successful episodes after {attempts} attempts")

    ep_idx = successful_trajectories
    print(f"Episode {ep_idx}")
    reset_kwargs = {"qpos": None, "box_pos": None, "box_quat": None}
    if seed is not None and attempts == 0:
      reset_kwargs["seed"] = seed
    observation, info = env.reset(**reset_kwargs)
    policy.reset()
    attempts += 1

    ep_dict = {key: [] for key in _EPISODE_KEYS}
    for step in range(num_frames):
      action = policy(env.unwrapped.model, env.unwrapped.data, observation, info)
      vel, ang_vel, action["qvel"] = _ee_velocity_command(env, observation, action, Kp, Kp_ang, ik_damping)

      observation, reward, terminated, truncated, info = env.step(action)

      frame = _frame_record(action, vel, ang_vel, observation, reward, env.unwrapped.data.time, step)
      for key, value in frame.items():
        ep_dict[key].append(value)

    # Keep only episodes the policy finished AND that end in the task's success state
    if not (policy.done and terminated):
      print("Failed, retrying", policy.done, terminated)
      continue

    media.show_video(ep_dict["observation.pixels.side"], fps=env.metadata["render_fps"])
    for key, values in ep_dict.items():
      data_dict[key].extend(values)
    data_dict["episode_index"].extend([ep_idx] * num_frames)
    base_idx = successful_trajectories * num_frames
    data_dict["index"].extend(range(base_idx, base_idx + num_frames))

    successful_trajectories += 1

  print("Creating dataset")
  hf_dataset = Dataset.from_dict(data_dict, features=Features(_hf_episode_features()))

  created_at = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
  hf_dataset.save_to_disk(os.path.join(dataset_dir, f"grasp_ee_vel_fixed_{num_ep}_{created_at}.hf"))

  return hf_dataset


In [None]:
# Collect successful scripted grasp episodes into an on-disk Hugging Face dataset.
policy = gym_lite6.scripted_policy.GraspPolicy(
    env, 'end_effector', 'box', 'gripper_left_finger', 'gripper_right_finger', max_vel=0.2
)
# Alternative task:
# policy = gym_lite6.scripted_policy.GraspAndLiftPolicy(env, 'end_effector', 'box', 'gripper_left_finger', 'gripper_right_finger', max_vel=0.2)

hf_dataset = record_episodes_to_hf(env, policy, "../datasets/grasp", num_ep=2, num_frames=100)

# Convert rows to torch tensors on the fly when they are accessed
hf_dataset.set_transform(hf_transform_to_torch)
