In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from pprint import pprint

import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.factory import make_dataset


from hydra import compose, initialize
from omegaconf import OmegaConf

# context initialization
# Compose the "default" Hydra config found under ../configs and print it as
# YAML so the resolved settings are visible in the cell output.
with initialize(version_base=None, config_path="../configs", job_name="test_app"):
    cfg = compose(config_name="default")
    print(OmegaConf.to_yaml(cfg))

In [None]:
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
import shutil
# --- Dataset variant configuration -----------------------------------------
imi = 0                   # index embedded in repo_id and in the directory names below
use_images = True         # True -> "image" variant, False -> "vstate" variant
filter_zeros = True       # True -> adds the "_zeros" suffix to the variant tag
include_failures = False  # True -> adds the "_failures" suffix to the variant tag

# Tag describing the selected variant; it becomes part of both directory names.
notes = 'image' if use_images else 'vstate'
notes += '_zeros' if filter_zeros else ''
notes += '_failures' if include_failures else ''

repo_id = f"j/{imi}"
root = Path(f'~/workspace/lerobot/local/ros_{imi}_{notes}').expanduser()
dst = Path(f'~/.cache/huggingface/hub/datasets--ros_{imi}_{notes}').expanduser()

# Remove leftovers of a previous run. The message is printed *before* deleting
# (for both directories) so it still shows up if rmtree raises part-way through.
for stale_dir in (dst, root):
    if stale_dir.exists():
        print(f"Removing {stale_dir}")
        shutil.rmtree(stale_dir)
    

In [None]:
import rosbag
import rospy
from kortex_driver.msg import BaseCyclic_Feedback, TwistCommand
from sensor_msgs.msg import Image as RosImage, Joy, JointState
from std_msgs.msg import Float32, Int8
import std_msgs.msg
from sensor_msgs.msg import Image, Joy
import numpy as np
from collections import defaultdict
import cv2
from cv_bridge import CvBridge
import matplotlib.pyplot as plt

# Shared CvBridge instance; its imgmsg_to_cv2 is used below to turn image
# messages into OpenCV arrays.
bridge = CvBridge()

def state_from_basefeedback(msg, include_tool_pose=False):
    """
    Extract state values from a base-feedback message.

    :param msg: message exposing
        ``interconnect.oneof_tool_feedback.gripper_feedback[0].motor[0].position`` and,
        when ``include_tool_pose`` is set, the six ``base.tool_pose_*`` fields.
    :param include_tool_pose: when True, the tool pose is placed in front of the
        gripper position. Defaults to False, which keeps the gripper-only result.
    :return: ``[gripper_pos]`` by default, otherwise
        ``[x, y, z, theta_x, theta_y, theta_z, gripper_pos]``.
    """
    gripper_pos = msg.interconnect.oneof_tool_feedback.gripper_feedback[0].motor[0].position
    if not include_tool_pose:
        # Default path: the tool pose is not read at all.
        return [gripper_pos]

    base = msg.base
    tool_pose = (
        base.tool_pose_x, base.tool_pose_y, base.tool_pose_z,
        base.tool_pose_theta_x, base.tool_pose_theta_y, base.tool_pose_theta_z,
    )
    return [*tool_pose, gripper_pos]

def state_from_jointstate(msg):
    """
    Return the leading six entries of ``msg.position``.

    Only the first six joints are taken here; the gripper value is obtained
    from the base feedback instead.
    """
    leading_six = msg.position[0:6]
    return leading_six

def action_from_joy(msg):
    """
    Map two joystick buttons to a single signed value.

    :param msg: message with an indexable ``buttons`` field.
    :return: the negated value of button 4 when it is set, otherwise the value
        of button 5.
    """
    buttons = msg.buttons
    if buttons[4]:
        return -buttons[4]
    return buttons[5]


def action_from_outvel(msg):
    """
    Flatten the twist of a velocity command into a six-element list.

    :param msg: message whose ``twist`` has ``linear_{x,y,z}`` and ``angular_{x,y,z}``.
    :return: ``[linear_x, linear_y, linear_z, angular_x, angular_y, angular_z]``.
    """
    twist = msg.twist
    components = ('linear_x', 'linear_y', 'linear_z', 'angular_x', 'angular_y', 'angular_z')
    return [getattr(twist, component) for component in components]

crop_dim = 700          # side length of the square crop, in pixels
crop_left_offset = 200  # first column kept by the crop
def top_image_to_cv2(cvimg):
    """
    Crop an image array to the top ``crop_dim`` rows and to ``crop_dim``
    columns starting at ``crop_left_offset``.

    :param cvimg: array indexable as ``[rows, cols, ...]``.
    :return: the cropped view of ``cvimg``.
    """
    kept_rows = slice(None, crop_dim)
    kept_cols = slice(crop_left_offset, crop_left_offset + crop_dim)
    return cvimg[kept_rows, kept_cols]

def top_image_msg_to_cv2(msg):
    """
    Convert an image message to an OpenCV array using the shared ``bridge``,
    requesting the ``"bgr8"`` encoding.
    """
    return bridge.imgmsg_to_cv2(msg, "bgr8")

# Maps each ROS topic of interest to the function that converts one message of
# that topic into the value stored for it. Topics missing from this dict are
# ignored by the bag-reading loops.
topic_to_fn = {
    '/my_gen3_lite/base_feedback': state_from_basefeedback,
    '/my_gen3_lite/joint_states': state_from_jointstate,
    '/joy': action_from_joy,
    '/my_gen3_lite/in/cartesian_velocity': action_from_outvel,
    '/camera_obs__dev_video4_96x96': top_image_msg_to_cv2,
}

# NOTE(review): these two constants are not referenced anywhere else in the
# visible notebook (the feature cell defines its own state_ndims/action_ndims)
# -- confirm whether they are still needed.
action_dims = 7
state_dims = 7

# Inspect one sample trial: run every message through its converter, preview
# the 96x96 camera stream, and dump that stream to <path>/video.avi.
path = Path('~/user_315').expanduser() # sample to pull out state sizes
hz = 10                    # window rate in Hz; the dataset-creation cell also passes this as fps
msg_topics = set()         # every topic name encountered in the bag
frame = defaultdict(list)  # converted messages gathered during the current 1/hz window
video_frames = []          # raw messages of the 96x96 camera topic, in bag order

t0 = None         # start (seconds) of the current window
bag_start = None  # timestamp (seconds) of the first message
bag_end = None    # timestamp (seconds) of the most recent message
n_frames = 0      # counts every message read, not only images

# Open the bag as a context manager so the file is closed even if a converter raises.
with rosbag.Bag(str(path / 'trial_data.bag')) as bag:
    for topic, msg, t in bag.read_messages():
        now = t.to_sec()
        if bag_start is None:  # explicit None check: a timestamp of 0.0 is falsy but valid
            bag_start = t0 = now
        bag_end = now

        if now - t0 > 1/hz:
            # The window has elapsed. No frame is assembled in this preview
            # cell; simply start a new window.
            t0 = now
            frame = defaultdict(list)

        msg_topics.add(topic)

        if topic in topic_to_fn:
            frame[topic].append(topic_to_fn[topic](msg))

        if topic == '/camera_obs__dev_video4_96x96':
            video_frames.append(msg)
            if n_frames % 200 == 0:
                # Decode only the images that are actually displayed. The
                # bridge returns BGR; matplotlib expects RGB, so swap channels.
                preview = top_image_msg_to_cv2(msg)
                plt.imshow(cv2.cvtColor(preview, cv2.COLOR_BGR2RGB))
                plt.show()
        n_frames += 1

# Report the span between the first and last message (t0 is reset every
# window, so it cannot be used for this).
if bag_start is None:
    print('bag contained no messages')
else:
    print(f'bag runtime {bag_end - bag_start:1.2f} seconds')

video = cv2.VideoWriter(str(path / 'video.avi'), cv2.VideoWriter_fourcc(*'XVID'), 10, (96, 96))
try:
    for image_msg in video_frames:
        video.write(bridge.imgmsg_to_cv2(image_msg, "bgr8"))
finally:
    # Always finalize the file, even if a conversion fails half-way.
    video.release()

In [None]:
# Lengths of the per-frame vectors described below.
state_ndims = 7
action_ndims = 7
env_state_ndims = 3  # placeholder environment state used when images are disabled

# Feature specification handed to LeRobotDataset.create.
features = {
    "observation.state": {
        "dtype": "float32",
        "shape": (state_ndims,),
        "names": [f's{i}' for i in range(state_ndims)],
    },
    "action": {
        "dtype": "float32",
        "shape": (action_ndims,),
        "names": [f'a{i}' for i in range(action_ndims)],
    },
    "next.reward": {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    },
    "next.success": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
}

if use_images:
    features["observation.image.top"] = {
        "dtype": "image",
        "shape": (3, 96, 96),
        "names": [
            "channel",
            "height",
            "width",
        ],
        'fps': 10
    }
else:
    features["observation.environment_state"] = {
        "dtype": "float32",
        "shape": (env_state_ndims,),
        # One name per element. Shape and names are both derived from
        # env_state_ndims so they cannot drift apart (the names list used to
        # have two entries for a length-3 vector).
        "names": [f'env_s{i}' for i in range(env_state_ndims)],
    }

# Show which features were registered.
for feature_name in features:
    print(feature_name)


In [None]:
from collections import deque
import time
class VideoLoader:
    """
    Streams frames from several recorded videos in step with external timestamps.

    Every directory passed in must contain ``output.mp4`` and
    ``video_frame_timestamps.txt`` (one integer nanosecond timestamp per line).
    Frames are consumed strictly forward: the loader never seeks backwards and a
    frame is handed out at most once.
    """

    def __init__(self, video_dirs, threshold_ns=2e6, cache_size=100):
        """
        Initialize the VideoLoader with paths to videos and their corresponding timestamps.

        :param video_dirs: directories, one per camera, each holding ``output.mp4``
            and ``video_frame_timestamps.txt``.
        :param threshold_ns: tolerance, in nanoseconds, used when matching a frame
            timestamp against a target timestamp.
        :param cache_size: ``maxlen`` of the per-video deques in ``frame_caches``
            (allocated here; not read by the methods of this class).
        :raises ValueError: if a video cannot be opened or its first frame cannot be read.
        """
        video_dirs = list(video_dirs)  # iterated several times below
        self.dirs = [str(entry).split('/')[-1] for entry in video_dirs]
        self.video_paths = [Path(dirname).absolute() / 'output.mp4' for dirname in video_dirs]
        timestamp_paths = [Path(dirname).absolute() / 'video_frame_timestamps.txt' for dirname in video_dirs]

        self.timestamp_lists = []
        for timestamp_fn in timestamp_paths:
            with open(str(timestamp_fn), 'r') as fp:
                # Skip blank lines (e.g. a trailing newline at the end of the file).
                raw_stamps = [line.strip() for line in fp if line.strip()]
            timestamps = [rospy.Time.from_seconds(int(t) / 1e9) for t in raw_stamps]
            self.timestamp_lists.append(timestamps)

        self.threshold_ns = rospy.Time(secs=0, nsecs=threshold_ns)

        self.frames = [None for _ in self.video_paths]  # current (not yet handed out) frame per video
        self.captures = []
        self.frame_caches = [deque(maxlen=cache_size) for _ in self.video_paths]
        # NOTE: this class dumps its frames, it doesn't search them. It's like a bag.
        self.frame_idx = [0 for _ in self.video_paths]
        self._open_videos()

    def _open_videos(self):
        """
        Open video files for streaming and load the first frame of each.

        :raises ValueError: if a file cannot be opened or yields no first frame.
            Captures opened so far are released before the error propagates.
        """
        self.captures = [cv2.VideoCapture(str(path)) for path in self.video_paths]
        try:
            for i, cap in enumerate(self.captures):
                if not cap.isOpened():
                    raise ValueError(f"Cannot open video file: {self.video_paths[i]}")
                ret, frame = cap.read()
                if not ret:
                    raise ValueError(f"Cannot read from capture {self.video_paths[i]}")
                self.frames[i] = frame
        except Exception:
            self.close()
            raise

    def close(self):
        """Release every open capture. Safe to call more than once."""
        for cap in self.captures:
            cap.release()
        self.captures = []

    # NOTE: Better to drop frames than spend too much time getting them (frames will be dropped in the real world)
    def get_frame_if_available(self, target_timestamp):
        """
        Retrieve, per camera, a frame if the next frame is close enough to ``target_timestamp``.

        For each camera:

        * if the current frame is behind the target, frames are read (and dropped)
          until one is no more than the threshold behind; that frame is returned;
        * if the current frame is ahead of the target by more than the threshold,
          ``None`` is returned and the frame is kept for a later call;
        * otherwise the current frame is returned.

        A returned frame is cleared from the loader, so asking again before the
        loader advances yields ``None``. When a camera runs out of frames or
        timestamps it yields ``None`` from then on instead of raising.

        :param target_timestamp: object with ``to_nsec()`` giving the wanted time.
        :return: ``(frames, ts, percent_complete)`` where ``frames`` has one entry
            (frame or ``None``) per camera, ``ts`` is the timestamp of the last
            camera's current frame (``None`` if no camera had one), and
            ``percent_complete`` maps camera index to the fraction of its
            timestamps already passed.
        """
        threshold = self.threshold_ns.to_nsec()
        target_ns = target_timestamp.to_nsec()
        ret_frames = [None for _ in self.frames]
        percent_complete = {}
        ts = None

        for cam_idx, capture in enumerate(self.captures):
            timestamps = self.timestamp_lists[cam_idx]
            total = len(timestamps)

            # Guard *before* indexing: an exhausted camera has no current timestamp.
            if self.frame_idx[cam_idx] >= total:
                rospy.loginfo(f"Out of video frames for camera {self.video_paths[cam_idx]}")
                percent_complete[cam_idx] = 1.0
                continue

            ts = timestamps[self.frame_idx[cam_idx]]  # the time of the current frame
            dt = ts.to_nsec() - target_ns             # > 0: frame is ahead of the target
            was_behind = dt < 0

            exhausted = False
            while dt < -threshold:  # current frame is too far behind: drop it and read the next
                ret, frame = capture.read()
                self.frame_idx[cam_idx] += 1
                if not ret or self.frame_idx[cam_idx] >= total:
                    # Video ended, or there is no timestamp for the frame just read.
                    self.frames[cam_idx] = None
                    self.frame_idx[cam_idx] = total
                    exhausted = True
                    break
                self.frames[cam_idx] = frame
                ts = timestamps[self.frame_idx[cam_idx]]
                dt = ts.to_nsec() - target_ns

            if exhausted:
                rospy.loginfo(f"Out of video frames for camera {self.video_paths[cam_idx]}")
                percent_complete[cam_idx] = 1.0
                continue

            # A frame reached by catching up is handed out unconditionally; a frame
            # that was already ahead is only handed out when within the threshold.
            if was_behind or dt <= threshold:
                ret_frames[cam_idx] = self.frames[cam_idx]
                self.frames[cam_idx] = None

            remaining_frame_count = total - self.frame_idx[cam_idx]
            percent_complete[cam_idx] = 1 - (remaining_frame_count / total)

        return ret_frames, ts, percent_complete

In [None]:
# Trial (user) ids grouped by cup position and outcome.
# NOTE(review): 331 is listed as both successful and failed for position 0, and
# 322 appears in position 1 (successful) as well as position 2 (failed) --
# confirm which label is correct for each.
trials_position_0_successful = [232, 235, 242, 245, 248, 251, 254, 257, 261, 265, 269, 273, 276, 279, 283, 293, 297, 301, 304, 308, 315, 316, 319, 320, 325, 328, 331, 335]
trials_position_0_failed = [238, 249, 260, 268, 282, 288, 307, 312, 324, 331, 334]

trials_position_1_successful = [233, 236, 239, 243, 246, 252, 255, 258, 262, 266, 274, 277, 280, 284, 287, 294, 298, 302, 305, 309, 317, 321, 326, 329, 322]
trials_position_1_failed = [270, 313]

trials_position_2_successful = [327, 330, 333, 300, 303, 306, 311, 318, 278, 281, 286, 290, 295, 299, 256, 259, 263, 267, 275, 234, 237, 244, 247, 250]
trials_position_2_failed = [240, 253, 285, 289, 296, 310, 314, 322]

trials = trials_position_0_successful + trials_position_1_successful + trials_position_2_successful
if include_failures:
    trials += trials_position_0_failed + trials_position_1_failed + trials_position_2_failed
# Keep the first occurrence of every id (order preserved) so a trial that is
# listed twice is not ingested as two episodes.
trials = list(dict.fromkeys(trials))


def in0(x):
    """Return True if trial id ``x`` belongs to position 0 (failed or successful)."""
    return x in trials_position_0_failed or x in trials_position_0_successful


def in1(x):
    """Return True if trial id ``x`` belongs to position 1 (failed or successful)."""
    return x in trials_position_1_failed or x in trials_position_1_successful


def in2(x):
    """Return True if trial id ``x`` belongs to position 2 (failed or successful)."""
    return x in trials_position_2_failed or x in trials_position_2_successful


In [None]:
# Remove any previous output so the dataset is rebuilt from scratch.
if dst.exists(): shutil.rmtree(dst)
if root.exists(): shutil.rmtree(root)

# metadata = LeRobotDatasetMetadata(repo_id, root, local_files_only=True)
dataset = LeRobotDataset.create(
    repo_id,
    fps=hz, # from pusht.yaml
    root=root,
    use_videos=use_images,
    features=features
)

camera_topic = '/camera_obs__dev_video4_96x96'
velocity_topic = '/my_gen3_lite/in/cartesian_velocity'
joy_topic = '/joy'

msg_topics = set()
# Topics that must have contributed at least one message to a window before a
# frame is built; velocity and joystick are optional and default to zero.
must_have_keys = [entry for entry in topic_to_fn if entry not in [velocity_topic, joy_topic]]
zeros_filtered = 0

for uid in trials: # one episode per trial id
    path = Path(f'~/user_{uid}').expanduser()

    # NOTE(review): iterdir() order is not guaranteed, yet camera index 1 is
    # hard-coded below -- confirm it always selects the intended camera.
    video_dirs = [entry for entry in path.iterdir() if "cam_dev_video" in str(entry)]
    video_loader = VideoLoader(video_dirs)

    bag = rosbag.Bag(str(path / 'trial_data.bag'))

    # Per-trial state. The window rate `hz` is the same value that was passed
    # to LeRobotDataset.create as fps, so it is deliberately not reassigned here.
    t0 = None; frame = defaultdict(list); bag_start = None; bag_end = None; total_frames = 0

    all_actions = []
    all_states = []
    all_video_frames = []

    try:
        for topic, msg, t in bag.read_messages(): # NOTE: we're dropping the last frame
            now = t.to_sec()
            if bag_start is None:  # explicit None check: 0.0 is a valid timestamp
                bag_start = t0 = now
            bag_end = now

            dt = now - t0
            if dt >= 1/hz:
                # create a frame

                # Action and state are the mean of the frames
                # only make a frame if we have data of all topics
                if all(len(frame[k]) > 0 for k in must_have_keys):
                    LR_frame = {}
                    LR_frame['next.success'] = 0
                    LR_frame['next.reward'] = 0

                    joint_state_mean = np.mean(frame['/my_gen3_lite/joint_states'], axis=0)
                    gripper_pos = np.mean(frame['/my_gen3_lite/base_feedback'], axis=0)
                    LR_frame['observation.state'] = np.concatenate([joint_state_mean, gripper_pos])

                    if velocity_topic in frame:
                        joint_action = np.mean(frame[velocity_topic], axis=0)
                    else:
                        joint_action = np.zeros((6,))

                    if joy_topic in frame:
                        gripper_action = min(1, max(-1, np.sum(frame[joy_topic], axis=0))) # sum the actions since they're -1, 1
                    else:
                        gripper_action = 0

                    if abs(gripper_action) > 0 or any([abs(entry) for entry in frame[joy_topic]]):
                        print(f'gripper action {gripper_action}', frame[joy_topic])

                    # make sure gripper action is concatenateable
                    gripper_action = np.array([gripper_action])
                    action = np.concatenate([joint_action, gripper_action], axis=0)

                    # Honour the filter_zeros switch (it is also encoded in the
                    # dataset directory name) instead of always filtering.
                    if filter_zeros and np.all(action == 0):
                        zeros_filtered += 1
                    else:
                        LR_frame['action'] = action

                        if use_images:
                            camera_frames, ts, percent_complete = video_loader.get_frame_if_available(t - rospy.Time.from_sec(0.1))
                            # NOTE(review): this `continue` also skips the window
                            # reset and the bookkeeping for the current message,
                            # so the same window is retried on the next message.
                            # Behaviour kept as-is; confirm it is intended.
                            if camera_frames[1] is None: continue
                            top_img = top_image_to_cv2(camera_frames[1])
                            top_img = cv2.resize(top_img, (96, 96), interpolation=cv2.INTER_AREA)
                            LR_frame['observation.image.top'] = top_img
                        else:
                            # placeholder state for cup position
                            env_state = np.zeros((3,))
                            if in0(uid): env_state[0] = 1
                            elif in1(uid): env_state[1] = 1
                            elif in2(uid): env_state[2] = 1
                            else: raise ValueError(f'uid {uid} not in any position')
                            LR_frame['observation.environment_state'] = env_state

                        dataset.add_frame(LR_frame)
                        total_frames += 1

                        all_actions.append(LR_frame['action'])
                        all_states.append(LR_frame['observation.state'])

                        # Debug video: enlarge the low-res camera image and
                        # overlay the action that was stored for this frame.
                        cv2_img = frame[camera_topic][0]
                        cv2_img = cv2.resize(cv2_img, (256, 256), interpolation=cv2.INTER_AREA)
                        cv2.putText(cv2_img, ', '.join([f'{a:1.2f}' for a in action]), (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 0), 1)
                        all_video_frames.append(cv2_img)

                # Reset t0
                t0 = now
                frame = defaultdict(list)

            if topic not in msg_topics:
                print(topic, type(msg)) #, msg)
                msg_topics.add(topic)

            if topic in topic_to_fn:
                val = topic_to_fn[topic](msg)
                frame[topic].append(val)
    finally:
        # Release file handles even when a trial fails part-way through.
        bag.close()
        for capture in video_loader.captures:
            capture.release()

    # Span between first and last message (t0 is reset every window and cannot
    # be used for this).
    if bag_start is None:
        print('bag contained no messages')
    else:
        print(f'bag runtime {bag_end - bag_start:1.2f} seconds')

    video = cv2.VideoWriter(str(path / 'video.avi'), cv2.VideoWriter_fourcc(*'XVID'), hz, (256, 256))
    try:
        for annotated in all_video_frames:
            video.write(annotated)
    finally:
        video.release()

    dataset.save_episode("Pick up a cup.", encode_videos=False)
    print(f'added {total_frames} frames from {uid}')
print(f'filtered {zeros_filtered} zeros')
dataset.consolidate()
dataset.meta.stats['action']

In [None]:
# plot each action and state by dimension
import matplotlib.pyplot as plt
# One scatter panel per action dimension (values from the last processed trial).
n_action_dims = all_actions[0].shape[0]
sample_indices = list(range(len(all_actions)))
fig, axs = plt.subplots(n_action_dims, 1, figsize=(10, 10))
for dim in range(n_action_dims):
    dim_values = [a[dim] for a in all_actions]
    # count the nonzero actions
    print(np.sum([value != 0 for value in dim_values]))
    panel = axs[dim]
    panel.scatter(sample_indices, dim_values)
    panel.set_title(f'action {dim}')

# One line panel per state dimension.
n_state_dims = all_states[0].shape[0]
fig, axs = plt.subplots(n_state_dims, 1, figsize=(10, 10))
for dim in range(n_state_dims):
    panel = axs[dim]
    panel.plot([s[dim] for s in all_states])
    panel.set_title(f'state {dim}')


